"""
Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints
Model: allenai/Molmo2-8B
For ProofPath video assessment - video pointing, tracking, and grounded analysis.
Unique capability: Returns pixel-level coordinates for objects in videos.
"""
from typing import Dict, List, Any, Optional, Tuple, Union
import torch
import numpy as np
import base64
import io
import tempfile
import os
import re
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for allenai/Molmo2-8B.

    Supports video QA plus Molmo's distinctive pointing/tracking modes,
    which return pixel-level coordinates parsed from the model's output.
    """

    def __init__(self, path: str = ""):
        """
        Initialize Molmo 2 model for video pointing and tracking.

        Args:
            path: Path to the model directory (provided by HF Inference Endpoints)
        """
        from transformers import AutoProcessor, AutoModelForImageTextToText
        # Use the model path provided by the endpoint, or default to HF hub
        model_id = path if path else "allenai/Molmo2-8B"
        # Determine device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load processor and model
        # NOTE(review): dtype/device_map are normally *model* kwargs; whether the
        # Molmo2 processor accepts or silently ignores them depends on its
        # trust_remote_code implementation — confirm against the model repo.
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            dtype="auto",
            device_map="auto" if torch.cuda.is_available() else None
        )
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            trust_remote_code=True,
            # bf16 on GPU for memory/throughput; fp32 on CPU for compatibility
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        if not torch.cuda.is_available():
            # device_map is None on CPU, so move the weights explicitly
            self.model = self.model.to(self.device)
        self.model.eval()
        # Molmo 2 limits: 128 frames max at 2fps
        self.max_frames = 128
        self.default_fps = 2.0
        # Regex patterns for parsing Molmo output:
        # COORD_REGEX  — captures the coords="..." payload of a <points>/<tracks> tag
        # FRAME_REGEX  — splits the payload into (timestamp, points-blob) chunks
        # POINTS_REGEX — (instance_id, x, y) triples with x/y scaled to 0-1000
        self.COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
        self.FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
        self.POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")
def _parse_video_points(
self,
text: str,
image_w: int,
image_h: int,
extract_ids: bool = False
) -> List[Tuple]:
"""
Extract video pointing coordinates from Molmo output.
Molmo outputs coordinates in XML-like format:
Where:
- 8.5 = timestamp/frame
- 0, 1 = instance IDs
- 183 216, 245 198 = x, y coordinates (scaled by 1000)
Returns: List of (timestamp, x, y) or (timestamp, id, x, y) tuples
"""
all_points = []
for coord_match in self.COORD_REGEX.finditer(text):
for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
timestamp = float(frame_match.group(1))
for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
instance_id = point_match.group(1)
# Coordinates are scaled by 1000
x = float(point_match.group(2)) / 1000 * image_w
y = float(point_match.group(3)) / 1000 * image_h
if 0 <= x <= image_w and 0 <= y <= image_h:
if extract_ids:
all_points.append((timestamp, int(instance_id), x, y))
else:
all_points.append((timestamp, x, y))
return all_points
def _parse_multi_image_points(
self,
text: str,
widths: List[int],
heights: List[int]
) -> List[Tuple]:
"""Parse pointing coordinates across multiple images."""
all_points = []
for coord_match in self.COORD_REGEX.finditer(text):
for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
# For multi-image, frame_id is 1-indexed image number
image_idx = int(frame_match.group(1)) - 1
if 0 <= image_idx < len(widths):
w, h = widths[image_idx], heights[image_idx]
for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
x = float(point_match.group(2)) / 1000 * w
y = float(point_match.group(3)) / 1000 * h
if 0 <= x <= w and 0 <= y <= h:
all_points.append((image_idx + 1, x, y))
return all_points
def _load_image(self, image_data: Any):
"""Load a single image from various formats."""
from PIL import Image
import requests
if isinstance(image_data, Image.Image):
return image_data
elif isinstance(image_data, str):
if image_data.startswith(('http://', 'https://')):
response = requests.get(image_data, stream=True)
return Image.open(response.raw).convert('RGB')
elif image_data.startswith('data:'):
header, encoded = image_data.split(',', 1)
image_bytes = base64.b64decode(encoded)
return Image.open(io.BytesIO(image_bytes)).convert('RGB')
else:
image_bytes = base64.b64decode(image_data)
return Image.open(io.BytesIO(image_bytes)).convert('RGB')
elif isinstance(image_data, bytes):
return Image.open(io.BytesIO(image_data)).convert('RGB')
else:
raise ValueError(f"Unsupported image input type: {type(image_data)}")
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process video or images with Molmo 2.
Expected input formats:
1. Video QA:
{
"inputs": ,
"parameters": {
"prompt": "What happens in this video?",
"max_new_tokens": 2048
}
}
2. Video Pointing (Molmo's unique capability):
{
"inputs": ,
"parameters": {
"prompt": "Point to all the people in this video.",
"mode": "pointing",
"max_new_tokens": 2048
}
}
3. Video Tracking:
{
"inputs": ,
"parameters": {
"prompt": "Track the person in the red shirt.",
"mode": "tracking",
"max_new_tokens": 2048
}
}
4. Image Pointing:
{
"inputs": ,
"parameters": {
"prompt": "Point to the Excel cell B2.",
"mode": "pointing"
}
}
5. Multi-image comparison:
{
"inputs": [, ],
"parameters": {
"prompt": "Compare these images."
}
}
Returns:
{
"generated_text": "...",
"points": [(timestamp, x, y), ...], # If pointing mode
"tracks": {"object_id": [(t, x, y), ...]}, # If tracking mode
"video_metadata": {...}
}
"""
inputs = data.get("inputs")
if inputs is None:
inputs = data.get("video") or data.get("image") or data.get("images")
if inputs is None:
raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")
params = data.get("parameters", {})
mode = params.get("mode", "default")
prompt = params.get("prompt", "Describe this content.")
max_new_tokens = params.get("max_new_tokens", 2048)
try:
if isinstance(inputs, list):
return self._process_multi_image(inputs, prompt, params, max_new_tokens)
elif self._is_video(inputs, params):
return self._process_video(inputs, prompt, params, max_new_tokens)
else:
return self._process_image(inputs, prompt, params, max_new_tokens)
except Exception as e:
return {"error": str(e), "error_type": type(e).__name__}
def _is_video(self, inputs: Any, params: Dict) -> bool:
"""Determine if input is video."""
if params.get("input_type") == "video":
return True
if params.get("input_type") == "image":
return False
if isinstance(inputs, str):
lower = inputs.lower()
video_exts = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v']
return any(ext in lower for ext in video_exts)
return False
    def _process_video(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process video with Molmo 2 via molmo_utils.

        Materializes the video (URLs are passed through; base64/bytes are
        written to a temp file), runs generation, and in pointing/tracking
        modes parses pixel coordinates out of the generated text.

        Args:
            video_data: http(s) URL, base64-encoded string, or raw bytes.
            prompt: User prompt text.
            params: Request parameters; "mode" selects pointing/tracking.
            max_new_tokens: Generation budget.

        Returns:
            Dict with "generated_text", "video_metadata", and mode-specific
            "points"/"tracks" keys.
        """
        try:
            from molmo_utils import process_vision_info
        except ImportError:
            # Fallback if molmo_utils not available
            return self._process_video_fallback(video_data, prompt, params, max_new_tokens)
        mode = params.get("mode", "default")
        # Prepare video URL or path
        if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
            # Presumably process_vision_info can fetch URLs directly — confirm
            video_source = video_data
        else:
            # Write to temp file (delete=False so the path outlives the `with`;
            # cleanup happens in the finally block below)
            if isinstance(video_data, str):
                video_bytes = base64.b64decode(video_data)
            else:
                video_bytes = video_data
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
                f.write(video_bytes)
                video_source = f.name
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        dict(type="text", text=prompt),
                        dict(type="video", video=video_source),
                    ],
                }
            ]
            # Process video with molmo_utils.
            # NOTE(review): each entry of `videos` appears to be a
            # (frames, metadata) pair — confirm against molmo_utils.
            _, videos, video_kwargs = process_vision_info(messages)
            videos, video_metadatas = zip(*videos)
            videos, video_metadatas = list(videos), list(video_metadatas)
            # Get chat template (text only; tensors are built by the processor)
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            # Process inputs
            inputs = self.processor(
                videos=videos,
                video_metadata=video_metadatas,
                text=text,
                padding=True,
                return_tensors="pt",
                **video_kwargs,
            )
            # NOTE(review): assumes every processor output value is a tensor
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            # Generate
            with torch.inference_mode():
                generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Decode only the newly generated tokens (skip the prompt portion)
            generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
            generated_text = self.processor.tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True
            )
            # Get video dimensions (fall back to 1080p if metadata lacks them)
            video_w = video_metadatas[0].get("width", 1920)
            video_h = video_metadatas[0].get("height", 1080)
            result = {
                "generated_text": generated_text,
                "video_metadata": {
                    "width": video_w,
                    "height": video_h,
                    **{k: v for k, v in video_metadatas[0].items() if k not in ["width", "height"]}
                }
            }
            # Parse coordinates based on mode
            if mode in ["pointing", "tracking"]:
                points = self._parse_video_points(
                    generated_text,
                    video_w,
                    video_h,
                    extract_ids=(mode == "tracking")
                )
                if mode == "tracking":
                    # Group by object ID for tracking: id -> [(t, x, y), ...]
                    from collections import defaultdict
                    tracks = defaultdict(list)
                    for point in points:
                        obj_id = point[1]
                        tracks[obj_id].append((point[0], point[2], point[3]))
                    result["tracks"] = dict(tracks)
                    result["num_objects_tracked"] = len(tracks)
                else:
                    result["points"] = points
                    result["num_points"] = len(points)
            return result
        finally:
            # Clean up temp file if created (URL inputs were never written to disk)
            if not isinstance(video_data, str) or not video_data.startswith(('http://', 'https://')):
                if os.path.exists(video_source):
                    os.unlink(video_source)
    def _process_video_fallback(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Fallback video processing without molmo_utils.

        Samples frames with OpenCV (up to self.max_frames, targeting
        self.default_fps) and feeds them to the model as a multi-image
        chat prompt.

        Args:
            video_data: http(s) URL, base64-encoded string, or raw bytes.
            prompt: User prompt text.
            params: Request parameters; "mode" selects pointing/tracking.
            max_new_tokens: Generation budget.

        Returns:
            Dict with "generated_text", "video_metadata", and mode-specific
            "points"/"tracks" keys.
        """
        # Extract frames manually
        import cv2
        from PIL import Image
        # Write video to temp file
        if isinstance(video_data, str):
            if video_data.startswith(('http://', 'https://')):
                import requests
                # NOTE(review): no timeout / status check here — a bad URL
                # surfaces later as an OpenCV decode failure; confirm whether
                # explicit HTTP error handling is wanted.
                response = requests.get(video_data, stream=True)
                video_bytes = response.content
            else:
                video_bytes = base64.b64decode(video_data)
        else:
            video_bytes = video_data
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
            f.write(video_bytes)
            video_path = f.name
        try:
            # Extract frames at 2fps, max 128
            cap = cv2.VideoCapture(video_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Guard against fps == 0 reported by some containers
            duration = total_frames / fps if fps > 0 else 0
            # Sample target_frames evenly across the whole clip
            target_frames = min(self.max_frames, int(duration * self.default_fps), total_frames)
            frame_indices = np.linspace(0, total_frames - 1, max(1, target_frames), dtype=int)
            frames = []
            for idx in frame_indices:
                # Random-access seek per sampled frame; may be slow on some codecs
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if ret:
                    # OpenCV decodes BGR; PIL expects RGB
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))
            video_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            video_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            cap.release()
            # Process as multi-image
            content = [dict(type="text", text=prompt)]
            for frame in frames:
                content.append(dict(type="image", image=frame))
            messages = [{"role": "user", "content": content}]
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.inference_mode():
                generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Decode only the newly generated tokens (skip the prompt portion)
            generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
            generated_text = self.processor.tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True
            )
            mode = params.get("mode", "default")
            result = {
                "generated_text": generated_text,
                "video_metadata": {
                    "width": video_w,
                    "height": video_h,
                    "duration": duration,
                    "sampled_frames": len(frames)
                }
            }
            if mode in ["pointing", "tracking"]:
                # NOTE(review): timestamps parsed here are whatever the model
                # emitted; with sampled frames they may not map 1:1 to seconds
                points = self._parse_video_points(
                    generated_text,
                    video_w,
                    video_h,
                    extract_ids=(mode == "tracking")
                )
                if mode == "tracking":
                    from collections import defaultdict
                    # Group points by instance id -> [(t, x, y), ...]
                    tracks = defaultdict(list)
                    for point in points:
                        tracks[point[1]].append((point[0], point[2], point[3]))
                    result["tracks"] = dict(tracks)
                else:
                    result["points"] = points
            return result
        finally:
            # Always remove the temp video file
            if os.path.exists(video_path):
                os.unlink(video_path)
def _process_image(
self,
image_data: Any,
prompt: str,
params: Dict,
max_new_tokens: int
) -> Dict[str, Any]:
"""Process a single image."""
image = self._load_image(image_data)
mode = params.get("mode", "default")
messages = [
{
"role": "user",
"content": [
dict(type="text", text=prompt),
dict(type="image", image=image),
],
}
]
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
)
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
with torch.inference_mode():
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
generated_text = self.processor.tokenizer.decode(
generated_tokens,
skip_special_tokens=True
)
result = {
"generated_text": generated_text,
"image_size": {"width": image.width, "height": image.height}
}
if mode == "pointing":
points = self._parse_video_points(generated_text, image.width, image.height)
result["points"] = points
result["num_points"] = len(points)
return result
def _process_multi_image(
self,
images_data: List,
prompt: str,
params: Dict,
max_new_tokens: int
) -> Dict[str, Any]:
"""Process multiple images."""
images = [self._load_image(img) for img in images_data]
mode = params.get("mode", "default")
content = [dict(type="text", text=prompt)]
for image in images:
content.append(dict(type="image", image=image))
messages = [{"role": "user", "content": content}]
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
)
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
with torch.inference_mode():
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
generated_text = self.processor.tokenizer.decode(
generated_tokens,
skip_special_tokens=True
)
result = {
"generated_text": generated_text,
"num_images": len(images),
"image_sizes": [{"width": img.width, "height": img.height} for img in images]
}
if mode == "pointing":
widths = [img.width for img in images]
heights = [img.height for img in images]
points = self._parse_multi_image_points(generated_text, widths, heights)
result["points"] = points
result["num_points"] = len(points)
return result