"""
Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints
Model: allenai/Molmo2-7B-1225

For ProofPath video assessment - video pointing, tracking, and grounded analysis.
Unique capability: Returns pixel-level coordinates for objects in videos.
"""

from typing import Dict, List, Any, Optional, Tuple, Union
import torch
import numpy as np
import base64
import io
import tempfile
import os
import re


class EndpointHandler:
    """Custom handler that runs Molmo on an image, a list of images, or a video.

    A request is a dict with an ``inputs`` entry (URL, data URI, base64 string,
    raw bytes, or a list of those) and an optional ``parameters`` dict. The
    response carries the generated text and, when pointing tags are found in
    that text, the points converted to pixel coordinates.
    """

    # Pointing-tag patterns.
    # NOTE(review): these assume Molmo-1 style tags with percentage coordinates:
    #   <point x="12.5" y="45.2" alt="label">label</point>
    #   <points x1="1.0" y1="2.0" x2="3.0" y2="4.0" alt="label">label</points>
    # Confirm against real Molmo 2 output before relying on them.
    POINT_REGEX = re.compile(
        r'<point\s+x="\s*([0-9]+(?:\.[0-9]+)?)\s*"'
        r'\s+y="\s*([0-9]+(?:\.[0-9]+)?)\s*"'
        r'(?:\s+alt="([^"]*)")?[^>]*>'
    )
    POINTS_REGEX = re.compile(r'<points\s+([^>]*)>(.*?)</points>', re.DOTALL)
    _INDEXED_COORD_REGEX = re.compile(r'\b([xy])(\d+)="\s*([0-9]+(?:\.[0-9]+)?)\s*"')
    _ALT_REGEX = re.compile(r'\balt="([^"]*)"')

    # Substrings that mark a string input as a video.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v')

    # (connect, read) timeout in seconds for remote downloads, so a stalled
    # server cannot hang the endpoint worker forever.
    HTTP_TIMEOUT = (10, 60)

    # Number of frames actually handed to the model for a video request.
    VIDEO_KEYFRAMES = 8

    def __init__(self, path: str = ""):
        """
        Initialize Molmo 2 model for video pointing and tracking.

        Args:
            path: Path to the model directory (ignored - we always load from HF hub)
        """
        # IMPORTANT: Always load from HF hub, not the repository path
        model_id = "allenai/Molmo2-7B-1225"

        # Determine device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load processor and model with trust_remote_code
        from transformers import AutoProcessor, AutoModelForCausalLM

        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )

        # With device_map=None the model is not placed automatically.
        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)

        self.model.eval()

        # Molmo 2 limits
        self.max_frames = 128
        self.default_fps = 2.0

    def _parse_points(self, text: str, image_w: int, image_h: int) -> List[Dict]:
        """
        Extract pointing coordinates from Molmo output.

        Coordinates in the tags are treated as percentages (0-100) of the
        image size and converted to pixels. Both single ``<point ...>`` tags
        and multi-point ``<points x1=.. y1=.. x2=.. y2=..>`` tags are read.

        Args:
            text: Generated text that may contain pointing tags.
            image_w: Width in pixels used to scale x percentages.
            image_h: Height in pixels used to scale y percentages.

        Returns:
            A list of dicts with keys ``x``, ``y`` (pixels), ``x_pct``,
            ``y_pct`` (percent) and ``label`` (the tag's ``alt`` text, or ""),
            in order of appearance in ``text``. Empty when no tag is found.
        """
        # Each entry: (tag position, index within tag, x_pct, y_pct, label).
        found = []

        for match in self.POINT_REGEX.finditer(text):
            found.append((
                match.start(),
                0,
                float(match.group(1)),
                float(match.group(2)),
                match.group(3) or "",
            ))

        for match in self.POINTS_REGEX.finditer(text):
            attrs = match.group(1)
            alt_match = self._ALT_REGEX.search(attrs)
            label = alt_match.group(1) if alt_match else ""

            # Pair up x<N>/y<N> attributes by their numeric suffix.
            pairs: Dict[int, Dict[str, float]] = {}
            for axis, index, value in self._INDEXED_COORD_REGEX.findall(attrs):
                pairs.setdefault(int(index), {})[axis] = float(value)

            for index in sorted(pairs):
                pair = pairs[index]
                if "x" in pair and "y" in pair:  # skip half-specified points
                    found.append((match.start(), index, pair["x"], pair["y"], label))

        found.sort(key=lambda item: item[:2])

        return [
            {
                "x": (x_pct / 100) * image_w,
                "y": (y_pct / 100) * image_h,
                "x_pct": x_pct,
                "y_pct": y_pct,
                "label": label,
            }
            for _, _, x_pct, y_pct, label in found
        ]

    @staticmethod
    def _decode_base64_payload(data: str) -> bytes:
        """Decode a base64 string, stripping a leading ``data:...,`` header if present."""
        if data.startswith('data:'):
            _, data = data.split(',', 1)
        return base64.b64decode(data)

    def _load_image(self, image_data: Any):
        """Load a single image from various formats.

        Args:
            image_data: A PIL image (returned unchanged), an http(s) URL, a
                data URI, a raw base64 string, or encoded image bytes.

        Returns:
            A PIL image; everything except a PIL input is converted to RGB.

        Raises:
            ValueError: If ``image_data`` is of an unsupported type.
            requests.HTTPError: If a URL download returns an error status.
        """
        from PIL import Image

        if isinstance(image_data, Image.Image):
            return image_data

        if isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                # NOTE(review): the URL is caller-supplied and fetched as-is;
                # consider restricting allowed hosts if requests are untrusted.
                import requests
                response = requests.get(image_data, timeout=self.HTTP_TIMEOUT)
                # Fail loudly instead of handing an HTML error page to PIL.
                response.raise_for_status()
                image_bytes = response.content
            else:
                image_bytes = self._decode_base64_payload(image_data)
        elif isinstance(image_data, (bytes, bytearray)):
            image_bytes = bytes(image_data)
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

        return Image.open(io.BytesIO(image_bytes)).convert('RGB')

    def _write_temp_video(self, video_data: Any) -> str:
        """Write the video to a temporary ``.mp4`` file and return its path.

        The caller owns the returned file and must delete it. If writing
        fails, the partially written file is removed before re-raising.
        """
        if not isinstance(video_data, (str, bytes, bytearray)):
            raise ValueError(f"Unsupported video input type: {type(video_data)}")

        is_url = isinstance(video_data, str) and video_data.startswith(('http://', 'https://'))

        # Decode before creating the file so a bad payload leaves nothing on disk.
        payload = b""
        if not is_url:
            if isinstance(video_data, str):
                payload = self._decode_base64_payload(video_data)
            else:
                payload = bytes(video_data)

        handle = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        try:
            with handle:
                if is_url:
                    import requests
                    with requests.get(video_data, stream=True, timeout=self.HTTP_TIMEOUT) as response:
                        response.raise_for_status()
                        for chunk in response.iter_content(chunk_size=8192):
                            handle.write(chunk)
                else:
                    handle.write(payload)
        except BaseException:
            # Cleanup only; the exception is always re-raised.
            if os.path.exists(handle.name):
                os.unlink(handle.name)
            raise
        return handle.name

    def _load_video_frames(
        self,
        video_data: Any,
        max_frames: int = 128,
        fps: float = 2.0
    ) -> tuple:
        """Load video frames from various input formats.

        Args:
            video_data: An http(s) URL, a data URI, a raw base64 string, or
                the encoded video bytes.
            max_frames: Upper bound on the number of frames to sample.
            fps: Target sampling rate; together with the video duration it
                determines how many evenly spaced frames are requested.

        Returns:
            ``(frames, metadata)`` where ``frames`` is a list of RGB PIL
            images and ``metadata`` has ``duration``, ``total_frames``,
            ``sampled_frames``, ``video_fps``, ``width`` and ``height``.
            ``frames`` is empty when nothing could be decoded.

        Raises:
            ValueError: If ``video_data`` is of an unsupported type.
        """
        import cv2
        from PIL import Image

        video_path = self._write_temp_video(video_data)
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / video_fps if video_fps > 0 else 0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Calculate frame indices
            target_frames = min(max_frames, int(duration * fps), total_frames)
            if target_frames <= 0:
                # Very short clip or unknown fps: fall back to the frame count.
                target_frames = min(max_frames, total_frames)

            frame_indices = np.linspace(0, total_frames - 1, max(1, target_frames), dtype=int)

            frames = []
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if ret:
                    # OpenCV decodes to BGR; PIL expects RGB.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))

            return frames, {
                "duration": duration,
                "total_frames": total_frames,
                "sampled_frames": len(frames),
                "video_fps": video_fps,
                "width": width,
                "height": height
            }
        finally:
            # Release the capture before deleting: an open handle can block
            # the unlink on some platforms, and must not leak on errors.
            if cap is not None:
                cap.release()
            if os.path.exists(video_path):
                os.unlink(video_path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video or images with Molmo 2.

        Expected input formats:

        1. Image analysis with pointing:
        {
            "inputs": <image_url_or_base64>,
            "parameters": {
                "prompt": "Point to the Excel cell B2.",
                "max_new_tokens": 1024
            }
        }

        2. Video analysis (processes as multi-frame):
        {
            "inputs": <video_url_or_base64>,
            "parameters": {
                "prompt": "What happens in this video?",
                "max_frames": 64,
                "max_new_tokens": 2048
            }
        }

        3. Multi-image comparison:
        {
            "inputs": [<image1>, <image2>],
            "parameters": {
                "prompt": "Compare these screenshots."
            }
        }

        Returns:
        {
            "generated_text": "...",
            "points": [{"x": 123, "y": 456, "label": "..."}],  # If pointing detected
            "image_size": {...}
        }

        On a processing failure the return value is instead
        {"error": ..., "error_type": ..., "traceback": ...}.

        Raises:
            ValueError: If no input is found under 'inputs', 'video',
                'image' or 'images'.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video") or data.get("image") or data.get("images")
        if inputs is None:
            raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")

        # "parameters" may be present but null; treat that like a missing key.
        params = data.get("parameters") or {}
        prompt = params.get("prompt", "Describe this image.")
        max_new_tokens = params.get("max_new_tokens", 1024)

        try:
            if isinstance(inputs, list):
                return self._process_multi_image(inputs, prompt, max_new_tokens)
            elif self._is_video(inputs, params):
                return self._process_video(inputs, prompt, params, max_new_tokens)
            else:
                return self._process_image(inputs, prompt, max_new_tokens)

        except Exception as e:
            import traceback
            # NOTE(review): the traceback is returned to the caller, which
            # exposes internals; kept because clients may rely on this key.
            return {"error": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}

    def _is_video(self, inputs: Any, params: Dict) -> bool:
        """Determine if input is video.

        An explicit ``params["input_type"]`` of "video" or "image" wins.
        Otherwise a string input counts as video when it is a ``data:video/``
        URI or contains one of the known video extensions. Non-string inputs
        are never treated as video.
        """
        input_type = params.get("input_type")
        if input_type == "video":
            return True
        if input_type == "image":
            return False

        if not isinstance(inputs, str):
            return False

        if inputs.startswith('data:'):
            # Only the header is informative; the base64 payload cannot
            # contain a ".ext" substring.
            header = inputs.split(',', 1)[0].lower()
            return header.startswith('data:video/') or any(
                ext in header for ext in self.VIDEO_EXTENSIONS
            )

        lower = inputs.lower()
        return any(ext in lower for ext in self.VIDEO_EXTENSIONS)

    def _generate(self, images: List[Any], prompt: str, max_new_tokens: int) -> str:
        """Run the processor and model on ``images`` + ``prompt``; return the decoded new text."""
        inputs = self.processor.process(
            images=images,
            text=prompt,
        )

        # Move to the model device and add a batch dimension.
        inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}

        with torch.inference_mode():
            # NOTE(review): generation_config is passed as a plain dict;
            # confirm the model's remote code accepts that (the Molmo 1 model
            # card passes a transformers GenerationConfig here).
            output = self.model.generate_from_batch(
                inputs,
                generation_config={"max_new_tokens": max_new_tokens, "stop_strings": ["<|endoftext|>"]},
                tokenizer=self.processor.tokenizer,
            )

        # Drop the prompt tokens; keep only what was generated.
        generated_tokens = output[0, inputs['input_ids'].size(1):]
        return self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def _attach_points(self, result: Dict[str, Any], text: str, width: int, height: int) -> Dict[str, Any]:
        """Add ``points`` / ``num_points`` to ``result`` when ``text`` contains pointing tags."""
        points = self._parse_points(text, width, height)
        if points:
            result["points"] = points
            result["num_points"] = len(points)
        return result

    def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Process a single image.

        Returns a dict with ``generated_text`` and ``image_size``, plus
        ``points`` / ``num_points`` when pointing tags are found.
        """
        image = self._load_image(image_data)
        generated_text = self._generate([image], prompt, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "image_size": {"width": image.width, "height": image.height}
        }
        return self._attach_points(result, generated_text, image.width, image.height)

    def _process_video(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process video by sampling frames.

        Reads ``max_frames`` (default 32, clamped to 1..``self.max_frames``)
        and ``fps`` (default ``self.default_fps``) from ``params``, decodes
        evenly spaced frames, then sends at most ``VIDEO_KEYFRAMES`` of them
        to the model with a prompt prefix stating they come from a video.

        Raises:
            ValueError: If no frame could be extracted.
        """
        requested = params.get("max_frames")
        max_frames = 32 if requested is None else int(requested)
        # Clamp so a zero/negative request cannot disable sampling.
        max_frames = max(1, min(max_frames, self.max_frames))

        fps = params.get("fps")
        if fps is None:
            fps = self.default_fps

        frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)

        if not frames:
            raise ValueError("No frames could be extracted from video")

        # For video, we process key frames
        # Molmo can handle multiple images - we'll sample representative frames
        sample_indices = np.linspace(
            0, len(frames) - 1, min(self.VIDEO_KEYFRAMES, len(frames)), dtype=int
        )
        sample_frames = [frames[i] for i in sample_indices]

        # Modify prompt to indicate video context
        video_prompt = f"These are {len(sample_frames)} frames from a video. {prompt}"

        generated_text = self._generate(sample_frames, video_prompt, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "video_metadata": video_metadata,
            "frames_analyzed": len(sample_frames)
        }

        # Points are scaled with the video's reported frame size.
        return self._attach_points(
            result, generated_text, video_metadata["width"], video_metadata["height"]
        )

    def _process_multi_image(
        self,
        images_data: List,
        prompt: str,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process multiple images.

        Returns a dict with ``generated_text``, ``num_images`` and
        ``image_sizes``, plus ``points`` / ``num_points`` when pointing tags
        are found. Points are scaled with the FIRST image's dimensions.

        Raises:
            ValueError: If ``images_data`` is empty.
        """
        if not images_data:
            raise ValueError("No images provided: 'inputs' list is empty.")

        images = [self._load_image(img) for img in images_data]
        generated_text = self._generate(images, prompt, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "num_images": len(images),
            "image_sizes": [{"width": img.width, "height": img.height} for img in images]
        }

        # Parse points using first image dimensions
        return self._attach_points(result, generated_text, images[0].width, images[0].height)