"""
Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints

Model: allenai/Molmo2-8B
For ProofPath video assessment - video pointing, tracking, and grounded analysis.

Unique capability: Returns pixel-level coordinates for objects in videos.
"""

import base64
import io
import os
import re
import tempfile
from collections import defaultdict
from typing import Dict, List, Any, Optional, Tuple, Union

import numpy as np
import torch


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for Molmo 2.

    Each request is routed to video, single-image, or multi-image processing.
    In ``pointing`` / ``tracking`` mode the coordinate markup emitted by the
    model is parsed into pixel coordinates.
    """

    # Molmo 2 limits: 128 frames max at 2fps (used by the OpenCV fallback).
    max_frames = 128
    default_fps = 2.0

    # Seconds requests waits on a remote server before giving up on a download.
    download_timeout = 60

    # Substrings that mark a string input (URL / path) as a video.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v')

    # Regex patterns for parsing Molmo output. Defined on the class so they are
    # compiled once and remain reachable through ``self``.
    COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
    FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
    POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")

    def __init__(self, path: str = ""):
        """
        Initialize Molmo 2 model for video pointing and tracking.

        Args:
            path: Path to the model directory (provided by HF Inference Endpoints).
                When empty, the ``allenai/Molmo2-8B`` hub id is used.
        """
        from transformers import AutoProcessor, AutoModelForImageTextToText

        # Use the model path provided by the endpoint, or default to HF hub
        model_id = path if path else "allenai/Molmo2-8B"

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Load processor and model
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            dtype="auto",
            device_map="auto" if use_cuda else None,
        )
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
        )
        if not use_cuda:
            # Without device_map the model is not placed automatically.
            self.model = self.model.to(self.device)
        self.model.eval()

    # ------------------------------------------------------------------ #
    # Output parsing
    # ------------------------------------------------------------------ #
    def _parse_video_points(
        self,
        text: str,
        image_w: int,
        image_h: int,
        extract_ids: bool = False
    ) -> List[Tuple]:
        """
        Extract video pointing coordinates from Molmo output.

        Molmo outputs coordinates in an XML-like format, e.g.::

            <points coords="8.5 0 183 216 1 245 198"/>

        Where:
            - 8.5 = timestamp/frame
            - 0, 1 = instance IDs
            - 183 216, 245 198 = x, y coordinates (scaled by 1000)

        Several frames may share one ``coords`` attribute; each frame group is
        introduced by the start of the string or a tab, ``:``, ``,`` or ``;``.
        Points that fall outside the image bounds are dropped, as are frame
        groups whose frame value is not a valid number.

        Returns:
            List of (timestamp, x, y) or (timestamp, id, x, y) tuples.
        """
        all_points = []
        for coord_match in self.COORD_REGEX.finditer(text):
            for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
                try:
                    timestamp = float(frame_match.group(1))
                except ValueError:
                    # e.g. a lone "." matched by [0-9\.]+ — skip malformed frames
                    continue
                for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
                    instance_id = point_match.group(1)
                    # Coordinates are scaled by 1000
                    x = float(point_match.group(2)) / 1000 * image_w
                    y = float(point_match.group(3)) / 1000 * image_h
                    if 0 <= x <= image_w and 0 <= y <= image_h:
                        if extract_ids:
                            all_points.append((timestamp, int(instance_id), x, y))
                        else:
                            all_points.append((timestamp, x, y))
        return all_points

    def _parse_multi_image_points(
        self,
        text: str,
        widths: List[int],
        heights: List[int]
    ) -> List[Tuple]:
        """Parse pointing coordinates across multiple images.

        The frame field is the 1-indexed image number. Integral values written
        as floats (``"2.0"``) are accepted; non-integral or out-of-range image
        numbers are skipped.

        Returns:
            List of (image_number, x, y) tuples with 1-based image numbers.
        """
        all_points = []
        for coord_match in self.COORD_REGEX.finditer(text):
            for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
                try:
                    frame_value = float(frame_match.group(1))
                except ValueError:
                    continue
                if not frame_value.is_integer():
                    continue
                # For multi-image, frame_id is 1-indexed image number
                image_idx = int(frame_value) - 1
                if not 0 <= image_idx < len(widths):
                    continue
                w, h = widths[image_idx], heights[image_idx]
                for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
                    x = float(point_match.group(2)) / 1000 * w
                    y = float(point_match.group(3)) / 1000 * h
                    if 0 <= x <= w and 0 <= y <= h:
                        all_points.append((image_idx + 1, x, y))
        return all_points

    @staticmethod
    def _remap_frame_numbers(points: List[Tuple], frame_times: List[float]) -> List[Tuple]:
        """Replace 1-based frame numbers in ``points[i][0]`` with timestamps.

        Used by the OpenCV fallback, where frames are given to the model as
        separate images and the model therefore reports image numbers. Points
        whose frame field is not an in-range integer are returned unchanged.
        """
        remapped = []
        for point in points:
            frame_number = point[0]
            idx = int(frame_number) - 1
            if float(frame_number).is_integer() and 0 <= idx < len(frame_times):
                remapped.append((frame_times[idx],) + tuple(point[1:]))
            else:
                remapped.append(tuple(point))
        return remapped

    @staticmethod
    def _group_tracks(points: List[Tuple]) -> Dict[int, List[Tuple]]:
        """Group (timestamp, id, x, y) points into {id: [(timestamp, x, y), ...]}."""
        tracks = defaultdict(list)
        for timestamp, obj_id, x, y in points:
            tracks[obj_id].append((timestamp, x, y))
        return dict(tracks)

    def _coordinate_fields(
        self,
        generated_text: str,
        width: int,
        height: int,
        mode: str,
        frame_times: Optional[List[float]] = None,
    ) -> Dict[str, Any]:
        """Build the pointing/tracking result fields for a video response.

        Returns an empty dict unless ``mode`` is ``pointing`` or ``tracking``.
        When ``frame_times`` is given, frame numbers are converted to seconds.
        """
        if mode not in ("pointing", "tracking"):
            return {}
        points = self._parse_video_points(
            generated_text, width, height, extract_ids=(mode == "tracking")
        )
        if frame_times is not None:
            points = self._remap_frame_numbers(points, frame_times)
        if mode == "tracking":
            # Group by object ID for tracking
            tracks = self._group_tracks(points)
            return {"tracks": tracks, "num_objects_tracked": len(tracks)}
        return {"points": points, "num_points": len(points)}

    # ------------------------------------------------------------------ #
    # Input loading
    # ------------------------------------------------------------------ #
    @staticmethod
    def _decode_base64(data: str) -> bytes:
        """Decode base64 text, accepting an optional ``data:<mime>;base64,`` prefix."""
        if data.startswith('data:'):
            # Strip the data-URL header; b64decode would otherwise keep its letters.
            _, _, data = data.partition(',')
        return base64.b64decode(data)

    def _download(self, url: str) -> bytes:
        """Fetch ``url`` and return the (content-decoded) body, raising on HTTP errors."""
        import requests

        response = requests.get(url, timeout=self.download_timeout)
        response.raise_for_status()
        return response.content

    def _read_media_bytes(self, data: Any, kind: str = "image") -> bytes:
        """Return raw bytes for an http(s) URL, base64 / data-URL string, or bytes.

        Raises:
            ValueError: if ``data`` is of an unsupported type.
        """
        if isinstance(data, str):
            if data.startswith(('http://', 'https://')):
                return self._download(data)
            return self._decode_base64(data)
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)
        raise ValueError(f"Unsupported {kind} input type: {type(data)}")

    def _load_image(self, image_data: Any):
        """Load a single image from various formats as an RGB PIL image."""
        from PIL import Image

        if isinstance(image_data, Image.Image):
            return image_data.convert('RGB')
        image_bytes = self._read_media_bytes(image_data, "image")
        return Image.open(io.BytesIO(image_bytes)).convert('RGB')

    @staticmethod
    def _write_temp_video(video_bytes: bytes) -> str:
        """Write ``video_bytes`` to a new ``.mp4`` temp file and return its path.

        The caller owns the file and must delete it. If writing fails the file
        is removed before the exception propagates.
        """
        fd, path = tempfile.mkstemp(suffix='.mp4')
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(video_bytes)
        except BaseException:
            os.unlink(path)
            raise
        return path

    @classmethod
    def _to_jsonable(cls, value: Any) -> Any:
        """Recursively convert numpy/torch values into plain JSON-friendly types.

        NOTE(review): metadata from molmo_utils may carry numpy scalars/arrays;
        this keeps the endpoint response serialisable either way.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().tolist()
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, dict):
            return {k: cls._to_jsonable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._to_jsonable(v) for v in value]
        return value

    # ------------------------------------------------------------------ #
    # Generation
    # ------------------------------------------------------------------ #
    def _generate(self, inputs: Dict[str, Any], max_new_tokens: int) -> str:
        """Run generation on processor outputs and decode only the new tokens.

        Assumes every value in ``inputs`` supports ``.to(device)`` (as the
        processor's ``return_tensors="pt"`` output does).
        """
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.inference_mode():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Drop the prompt tokens so only the model's answer is decoded.
        generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
        return self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def _generate_from_messages(self, messages: List[Dict], max_new_tokens: int) -> str:
        """Tokenize chat ``messages`` with the processor template and generate."""
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        return self._generate(inputs, max_new_tokens)

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video or images with Molmo 2.

        Expected input formats:

        1. Video QA:
            {
                "inputs": "<video URL, base64 string, or data URL>",
                "parameters": {
                    "prompt": "What happens in this video?",
                    "max_new_tokens": 2048
                }
            }

        2. Video Pointing (Molmo's unique capability):
            {
                "inputs": "<video URL, base64 string, or data URL>",
                "parameters": {
                    "prompt": "Point to all the people in this video.",
                    "mode": "pointing",
                    "max_new_tokens": 2048
                }
            }

        3. Video Tracking:
            {
                "inputs": "<video URL, base64 string, or data URL>",
                "parameters": {
                    "prompt": "Track the person in the red shirt.",
                    "mode": "tracking",
                    "max_new_tokens": 2048
                }
            }

        4. Image Pointing:
            {
                "inputs": "<image URL, base64 string, or data URL>",
                "parameters": {
                    "prompt": "Point to the Excel cell B2.",
                    "mode": "pointing"
                }
            }

        5. Multi-image comparison:
            {
                "inputs": ["<image 1>", "<image 2>"],
                "parameters": {
                    "prompt": "Compare these images."
                }
            }

        Set ``parameters.input_type`` to ``"video"`` or ``"image"`` to override
        auto-detection.

        Returns:
            {
                "generated_text": "...",
                "points": [(timestamp, x, y), ...],  # If pointing mode
                "tracks": {"object_id": [(t, x, y), ...]},  # If tracking mode
                "video_metadata": {...}
            }
            On processing failure: {"error": "...", "error_type": "..."}.

        Raises:
            ValueError: if no input is supplied or ``parameters`` is not an object.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video") or data.get("image") or data.get("images")
        if inputs is None:
            raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")

        # "parameters": null is treated the same as an absent key.
        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            raise ValueError("'parameters' must be a JSON object.")

        try:
            prompt = params.get("prompt", "Describe this content.")
            max_new_tokens = int(params.get("max_new_tokens", 2048))
            if isinstance(inputs, list):
                return self._process_multi_image(inputs, prompt, params, max_new_tokens)
            if self._is_video(inputs, params):
                return self._process_video(inputs, prompt, params, max_new_tokens)
            return self._process_image(inputs, prompt, params, max_new_tokens)
        except Exception as e:
            # Endpoint boundary: report failures to the client instead of crashing.
            return {"error": str(e), "error_type": type(e).__name__}

    def _is_video(self, inputs: Any, params: Dict) -> bool:
        """Determine if input is video.

        An explicit ``input_type`` wins. Otherwise, data URLs are classified by
        their MIME type and other strings by a video-extension substring.
        """
        input_type = params.get("input_type")
        if input_type == "video":
            return True
        if input_type == "image":
            return False
        if isinstance(inputs, str):
            lower = inputs.lower()
            if lower.startswith('data:'):
                return lower.startswith('data:video/')
            return any(ext in lower for ext in self.VIDEO_EXTENSIONS)
        return False

    # ------------------------------------------------------------------ #
    # Video
    # ------------------------------------------------------------------ #
    def _process_video(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process video with Molmo 2 using molmo_utils for frame sampling.

        Falls back to OpenCV frame extraction when molmo_utils is unavailable.
        http(s) URLs are passed to molmo_utils unchanged. Other inputs are
        decoded into a temporary file, which is removed afterwards.
        """
        try:
            from molmo_utils import process_vision_info
        except ImportError:
            # Fallback if molmo_utils not available
            return self._process_video_fallback(video_data, prompt, params, max_new_tokens)

        mode = params.get("mode", "default")

        temp_path = None
        if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
            video_source = video_data
        else:
            temp_path = self._write_temp_video(self._read_media_bytes(video_data, "video"))
            video_source = temp_path

        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        dict(type="text", text=prompt),
                        dict(type="video", video=video_source),
                    ],
                }
            ]

            # Process video with molmo_utils
            _, videos, video_kwargs = process_vision_info(messages)
            videos, video_metadatas = zip(*videos)
            videos, video_metadatas = list(videos), list(video_metadatas)

            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            inputs = self.processor(
                videos=videos,
                video_metadata=video_metadatas,
                text=text,
                padding=True,
                return_tensors="pt",
                **video_kwargs,
            )
            generated_text = self._generate(inputs, max_new_tokens)

            metadata = self._to_jsonable({k: v for k, v in video_metadatas[0].items()})
            video_w = metadata.get("width", 1920)
            video_h = metadata.get("height", 1080)

            result = {
                "generated_text": generated_text,
                "video_metadata": {
                    "width": video_w,
                    "height": video_h,
                    **{k: v for k, v in metadata.items() if k not in ["width", "height"]}
                }
            }
            result.update(self._coordinate_fields(generated_text, video_w, video_h, mode))
            return result
        finally:
            if temp_path is not None and os.path.exists(temp_path):
                os.unlink(temp_path)

    def _sample_frames(self, video_path: str) -> Tuple[List[Any], Optional[List[float]], Dict[str, Any]]:
        """Uniformly sample up to ``max_frames`` RGB frames at about ``default_fps``.

        Returns:
            (frames, frame_times, metadata). ``frame_times`` holds each decoded
            frame's time in seconds, or is None when the container reports no FPS.

        Raises:
            ValueError: if the video cannot be opened or no frame decodes.
        """
        import cv2
        from PIL import Image

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise ValueError("Could not open video for frame extraction.")
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps if fps > 0 else 0

            target_frames = min(self.max_frames, int(duration * self.default_fps), total_frames)
            frame_indices = np.linspace(
                0, max(total_frames - 1, 0), max(1, target_frames), dtype=int
            )

            frames, frame_times = [], []
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ok, frame = cap.read()
                if ok:
                    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                    frame_times.append(float(idx) / fps if fps > 0 else 0.0)

            video_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            video_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            cap.release()

        if not frames:
            raise ValueError("Could not decode any frames from the video.")

        metadata = {
            "width": video_w,
            "height": video_h,
            "duration": duration,
            "sampled_frames": len(frames),
        }
        return frames, (frame_times if fps > 0 else None), metadata

    def _process_video_fallback(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Fallback video processing without molmo_utils.

        Frames are sampled with OpenCV and sent to the model as a sequence of
        images. The model then reports 1-based image numbers, which are
        converted back to seconds so points/tracks match the molmo_utils path.
        """
        video_path = self._write_temp_video(self._read_media_bytes(video_data, "video"))
        try:
            frames, frame_times, video_metadata = self._sample_frames(video_path)
        finally:
            # Frames are in memory now; the temp file is no longer needed.
            if os.path.exists(video_path):
                os.unlink(video_path)

        # Process as multi-image
        content = [dict(type="text", text=prompt)]
        content.extend(dict(type="image", image=frame) for frame in frames)
        messages = [{"role": "user", "content": content}]
        generated_text = self._generate_from_messages(messages, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "video_metadata": video_metadata,
        }
        result.update(self._coordinate_fields(
            generated_text,
            video_metadata["width"],
            video_metadata["height"],
            params.get("mode", "default"),
            frame_times,
        ))
        return result

    # ------------------------------------------------------------------ #
    # Images
    # ------------------------------------------------------------------ #
    def _process_image(
        self,
        image_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process a single image; in pointing mode also return pixel points."""
        image = self._load_image(image_data)
        mode = params.get("mode", "default")

        messages = [
            {
                "role": "user",
                "content": [
                    dict(type="text", text=prompt),
                    dict(type="image", image=image),
                ],
            }
        ]
        generated_text = self._generate_from_messages(messages, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "image_size": {"width": image.width, "height": image.height}
        }

        if mode == "pointing":
            points = self._parse_video_points(generated_text, image.width, image.height)
            result["points"] = points
            result["num_points"] = len(points)

        return result

    def _process_multi_image(
        self,
        images_data: List,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process multiple images; in pointing mode points carry 1-based image numbers.

        Raises:
            ValueError: if ``images_data`` is empty.
        """
        if not images_data:
            raise ValueError("No images provided.")
        images = [self._load_image(img) for img in images_data]
        mode = params.get("mode", "default")

        content = [dict(type="text", text=prompt)]
        content.extend(dict(type="image", image=image) for image in images)
        messages = [{"role": "user", "content": content}]
        generated_text = self._generate_from_messages(messages, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "num_images": len(images),
            "image_sizes": [{"width": img.width, "height": img.height} for img in images]
        }

        if mode == "pointing":
            widths = [img.width for img in images]
            heights = [img.height for img in images]
            points = self._parse_multi_image_points(generated_text, widths, heights)
            result["points"] = points
            result["num_points"] = len(points)

        return result