| | """ |
| | Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints |
| | Model: allenai/Molmo2-7B-1225 |
| | |
| | For ProofPath video assessment - video pointing, tracking, and grounded analysis. |
| | Unique capability: Returns pixel-level coordinates for objects in videos. |
| | """ |

from typing import Any, Dict, List

import base64
import io
import os
import re
import tempfile

import numpy as np
import torch


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize Molmo 2 for video pointing and tracking.

        Args:
            path: Path to the model directory (ignored; the model is always
                loaded from the Hugging Face Hub).
        """
        model_id = "allenai/Molmo2-7B-1225"

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        from transformers import AutoProcessor, AutoModelForCausalLM

        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )

        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)

        self.model.eval()

        # Frame-sampling limits for video inputs.
        self.max_frames = 128
        self.default_fps = 2.0

        # Molmo emits pointing coordinates as percentages (0-100) of the image
        # size. Single points: <point x="61.5" y="68.1" alt="dog">dog</point>.
        self.POINT_REGEX = re.compile(
            r'<point\s+x="([0-9.]+)"\s+y="([0-9.]+)"(?:\s+alt="([^"]*)")?>'
        )
        # Multi-point tags carry numbered attribute pairs:
        # <points x1="..." y1="..." x2="..." y2="..." alt="...">...</points>.
        self.POINTS_REGEX = re.compile(r'<points\s+([^>]*)>')
    def _parse_points(self, text: str, image_w: int, image_h: int) -> List[Dict]:
        """
        Extract pointing coordinates from Molmo output.

        Molmo outputs coordinates as percentages (0-100) of the image size;
        they are converted to pixels here.
        """
        points = []

        def add_point(x_pct: float, y_pct: float, label: str) -> None:
            points.append({
                "x": (x_pct / 100) * image_w,
                "y": (y_pct / 100) * image_h,
                "x_pct": x_pct,
                "y_pct": y_pct,
                "label": label,
            })

        # Single <point x=".." y=".."> tags.
        for match in self.POINT_REGEX.finditer(text):
            add_point(float(match.group(1)), float(match.group(2)), match.group(3) or "")

        # <points x1=".." y1=".." x2=".." ...> tags: one label, many coordinates.
        for match in self.POINTS_REGEX.finditer(text):
            attrs = match.group(1)
            alt = re.search(r'alt="([^"]*)"', attrs)
            label = alt.group(1) if alt else ""
            for pair in re.finditer(r'x\d*="([0-9.]+)"\s+y\d*="([0-9.]+)"', attrs):
                add_point(float(pair.group(1)), float(pair.group(2)), label)

        return points
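
    # Illustrative behavior (assuming Molmo's percentage-coordinate format):
    #   _parse_points('<point x="61.5" y="68.1" alt="dog">dog</point>', 1000, 800)
    #   -> [{"x": 615.0, "y": 544.8, "x_pct": 61.5, "y_pct": 68.1, "label": "dog"}]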

    def _load_image(self, image_data: Any):
        """Load a single image from a PIL image, URL, base64 string, or raw bytes."""
        from PIL import Image
        import requests

        if isinstance(image_data, Image.Image):
            return image_data.convert('RGB')
        elif isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                response = requests.get(image_data, timeout=30)
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content)).convert('RGB')
            elif image_data.startswith('data:'):
                # Data URI: strip the "data:image/...;base64," header.
                header, encoded = image_data.split(',', 1)
                image_bytes = base64.b64decode(encoded)
                return Image.open(io.BytesIO(image_bytes)).convert('RGB')
            else:
                # Assume a bare base64-encoded image.
                image_bytes = base64.b64decode(image_data)
                return Image.open(io.BytesIO(image_bytes)).convert('RGB')
        elif isinstance(image_data, bytes):
            return Image.open(io.BytesIO(image_data)).convert('RGB')
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

    def _load_video_frames(
        self,
        video_data: Any,
        max_frames: int = 128,
        fps: float = 2.0,
    ) -> tuple:
        """Decode a video (URL, base64, or raw bytes) into evenly sampled PIL frames."""
        import cv2
        from PIL import Image

        # Materialize the video to a temp file so OpenCV can read it.
        if isinstance(video_data, str):
            if video_data.startswith(('http://', 'https://')):
                import requests
                response = requests.get(video_data, stream=True, timeout=120)
                response.raise_for_status()
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                    video_path = f.name
            elif video_data.startswith('data:'):
                header, encoded = video_data.split(',', 1)
                video_bytes = base64.b64decode(encoded)
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    f.write(video_bytes)
                    video_path = f.name
            else:
                video_bytes = base64.b64decode(video_data)
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                    f.write(video_bytes)
                    video_path = f.name
        elif isinstance(video_data, bytes):
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                f.write(video_data)
                video_path = f.name
        else:
            raise ValueError(f"Unsupported video input type: {type(video_data)}")

        try:
            cap = cv2.VideoCapture(video_path)
            try:
                video_fps = cap.get(cv2.CAP_PROP_FPS)
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                duration = total_frames / video_fps if video_fps > 0 else 0
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

                # Sample at the requested rate, capped by max_frames and the
                # number of frames actually in the video.
                target_frames = min(max_frames, int(duration * fps), total_frames)
                if target_frames <= 0:
                    target_frames = min(max_frames, total_frames)

                frame_indices = np.linspace(0, total_frames - 1, max(1, target_frames), dtype=int)

                frames = []
                for idx in frame_indices:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                    ret, frame = cap.read()
                    if ret:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frames.append(Image.fromarray(frame_rgb))
            finally:
                cap.release()

            return frames, {
                "duration": duration,
                "total_frames": total_frames,
                "sampled_frames": len(frames),
                "video_fps": video_fps,
                "width": width,
                "height": height,
            }

        finally:
            if os.path.exists(video_path):
                os.unlink(video_path)
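
    # Illustrative return value: a 30 s clip at 30 fps sampled at 2 fps yields
    # 60 frames (min of max_frames=128, 30 * 2 = 60, and 900 total) plus metadata:
    #   {"duration": 30.0, "total_frames": 900, "sampled_frames": 60,
    #    "video_fps": 30.0, "width": 1280, "height": 720}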

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video or images with Molmo 2.

        Expected input formats:

        1. Image analysis with pointing:
           {
               "inputs": <image_url_or_base64>,
               "parameters": {
                   "prompt": "Point to the Excel cell B2.",
                   "max_new_tokens": 1024
               }
           }

        2. Video analysis (processed as sampled multi-frame input):
           {
               "inputs": <video_url>,
               "parameters": {
                   "prompt": "What happens in this video?",
                   "max_frames": 64,
                   "max_new_tokens": 2048
               }
           }

        3. Multi-image comparison:
           {
               "inputs": [<image1>, <image2>],
               "parameters": {
                   "prompt": "Compare these screenshots."
               }
           }

        Returns:
            {
                "generated_text": "...",
                "points": [{"x": 123, "y": 456, "label": "..."}],  # if pointing detected
                "image_size": {...}
            }
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video") or data.get("image") or data.get("images")
        if inputs is None:
            raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")

        params = data.get("parameters", {})
        prompt = params.get("prompt", "Describe this image.")
        max_new_tokens = params.get("max_new_tokens", 1024)

        try:
            if isinstance(inputs, list):
                return self._process_multi_image(inputs, prompt, max_new_tokens)
            elif self._is_video(inputs, params):
                return self._process_video(inputs, prompt, params, max_new_tokens)
            else:
                return self._process_image(inputs, prompt, max_new_tokens)

        except Exception as e:
            import traceback
            return {"error": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}

    def _is_video(self, inputs: Any, params: Dict) -> bool:
        """
        Determine whether the input is a video.

        Base64-encoded video cannot be detected from content alone; flag it
        explicitly with parameters["input_type"] = "video".
        """
        if params.get("input_type") == "video":
            return True
        if params.get("input_type") == "image":
            return False

        if isinstance(inputs, str):
            lower = inputs.lower()
            video_exts = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v']
            return any(ext in lower for ext in video_exts)

        return False

    def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Process a single image."""
        from transformers import GenerationConfig

        image = self._load_image(image_data)

        inputs = self.processor.process(
            images=[image],
            text=prompt,
        )

        # The processor returns unbatched tensors; add a batch dimension and
        # move everything to the model's device.
        inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}

        with torch.inference_mode():
            # Pass a GenerationConfig (as in the published Molmo examples)
            # rather than a plain dict, which generate() cannot consume.
            output = self.model.generate_from_batch(
                inputs,
                generation_config=GenerationConfig(
                    max_new_tokens=max_new_tokens, stop_strings="<|endoftext|>"
                ),
                tokenizer=self.processor.tokenizer,
            )

        # Decode only the newly generated tokens, not the prompt.
        generated_tokens = output[0, inputs['input_ids'].size(1):]
        generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        result = {
            "generated_text": generated_text,
            "image_size": {"width": image.width, "height": image.height}
        }

        points = self._parse_points(generated_text, image.width, image.height)
        if points:
            result["points"] = points
            result["num_points"] = len(points)

        return result

    def _process_video(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process a video by sampling frames and running a multi-frame prompt."""
        from transformers import GenerationConfig

        max_frames = min(params.get("max_frames", 32), self.max_frames)
        fps = params.get("fps", self.default_fps)

        frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)

        if not frames:
            raise ValueError("No frames could be extracted from video")

        # Subsample to at most 8 evenly spaced frames so the multi-image
        # prompt stays within the model's context budget.
        sample_indices = np.linspace(0, len(frames) - 1, min(8, len(frames)), dtype=int)
        sample_frames = [frames[i] for i in sample_indices]

        # Tell the model it is looking at video frames, not unrelated images.
        video_prompt = f"These are {len(sample_frames)} frames from a video. {prompt}"

        inputs = self.processor.process(
            images=sample_frames,
            text=video_prompt,
        )

        inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}

        with torch.inference_mode():
            output = self.model.generate_from_batch(
                inputs,
                generation_config=GenerationConfig(
                    max_new_tokens=max_new_tokens, stop_strings="<|endoftext|>"
                ),
                tokenizer=self.processor.tokenizer,
            )

        generated_tokens = output[0, inputs['input_ids'].size(1):]
        generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        result = {
            "generated_text": generated_text,
            "video_metadata": video_metadata,
            "frames_analyzed": len(sample_frames)
        }

        # Point percentages are mapped onto the original frame resolution.
        points = self._parse_points(generated_text, video_metadata["width"], video_metadata["height"])
        if points:
            result["points"] = points
            result["num_points"] = len(points)

        return result

    def _process_multi_image(
        self,
        images_data: List,
        prompt: str,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process multiple images in a single prompt."""
        from transformers import GenerationConfig

        images = [self._load_image(img) for img in images_data]

        inputs = self.processor.process(
            images=images,
            text=prompt,
        )

        inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}

        with torch.inference_mode():
            output = self.model.generate_from_batch(
                inputs,
                generation_config=GenerationConfig(
                    max_new_tokens=max_new_tokens, stop_strings="<|endoftext|>"
                ),
                tokenizer=self.processor.tokenizer,
            )

        generated_tokens = output[0, inputs['input_ids'].size(1):]
        generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        result = {
            "generated_text": generated_text,
            "num_images": len(images),
            "image_sizes": [{"width": img.width, "height": img.height} for img in images]
        }

        # Points are scaled against the first image's size; with mixed sizes
        # there is no way to tell which image a point refers to.
        if images:
            points = self._parse_points(generated_text, images[0].width, images[0].height)
            if points:
                result["points"] = points
                result["num_points"] = len(points)

        return result
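

if __name__ == "__main__":
    # Minimal local smoke test: a sketch, not part of the endpoint contract.
    # Assumes network access to download the model weights and the placeholder
    # image URL below; swap in your own test asset.
    handler = EndpointHandler()
    result = handler({
        "inputs": "https://picsum.photos/640/480",  # placeholder test image
        "parameters": {"prompt": "Point to the most prominent object.", "max_new_tokens": 256},
    })
    print(result.get("generated_text", result))
    for p in result.get("points", []):
        print(f"  point at ({p['x']:.0f}, {p['y']:.0f}) px, label={p['label']!r}")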