# molmo / handler.py
# Source: Hugging Face Hub file view, uploaded by peterproofpath.
# Last commit: "Update handler.py" (be41be8, verified); file size 14.7 kB.
"""
Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints
Model: allenai/Molmo2-7B-1225
For ProofPath video assessment - video pointing, tracking, and grounded analysis.
Unique capability: Returns pixel-level coordinates for objects in videos.
"""
from typing import Dict, List, Any, Optional, Tuple, Union
import torch
import numpy as np
import base64
import io
import tempfile
import os
import re
class EndpointHandler:
    """Hugging Face Inference Endpoints handler that runs Molmo 2 on images and videos.

    Requests are routed to single-image, multi-image, or sampled-frame video
    processing (see ``__call__``). When the model answers with pointing markup,
    its percentage coordinates are converted to pixel coordinates.
    """

    # Molmo 2 limits
    max_frames = 128    # hard cap on frames decoded from one video
    default_fps = 2.0   # sampling rate used when the request gives no "fps"

    # How many of the decoded frames are actually fed to the model per video.
    MAX_ANALYSIS_FRAMES = 8

    # Timeout (seconds) for HTTP downloads of image/video URLs.
    REQUEST_TIMEOUT = 60

    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v')

    # Regex patterns for parsing Molmo pointing output (coordinates are percentages, 0-100).
    # Single point: <point x="123" y="456" alt="description">
    POINT_REGEX = re.compile(r'<point\s+x="([0-9.]+)"\s+y="([0-9.]+)"(?:\s+alt="([^"]*)")?>')
    # Multiple points: <points x1=".." y1=".." x2=".." y2=".." alt="description">
    # NOTE(review): numbered-attribute format assumed from Molmo v1 output; confirm for Molmo 2.
    POINTS_REGEX = re.compile(r'<points\s+([^>]*)>')
    # key="value" attribute pairs inside a <points ...> tag.
    ATTR_REGEX = re.compile(r'(\w+)="([^"]*)"')

    def __init__(self, path: str = ""):
        """Load the Molmo 2 processor and model.

        Args:
            path: Model directory supplied by Inference Endpoints. Ignored: the
                weights are always pulled from the Hugging Face Hub.
        """
        from transformers import AutoModelForCausalLM, AutoProcessor

        # IMPORTANT: Always load from HF hub, not the repository path
        model_id = "allenai/Molmo2-7B-1225"
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
        )
        if not use_cuda:
            self.model = self.model.to(self.device)
        self.model.eval()

    # ------------------------------------------------------------------ entry point

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video or images with Molmo 2.

        Expected input formats:
        1. Image analysis with pointing:
        {
            "inputs": <image_url_or_base64>,
            "parameters": {
                "prompt": "Point to the Excel cell B2.",
                "max_new_tokens": 1024
            }
        }
        2. Video analysis (processes as multi-frame); a video is recognised by
           its URL extension, a "data:video/..." URI, or "input_type": "video":
        {
            "inputs": <video_url>,
            "parameters": {
                "prompt": "What happens in this video?",
                "max_frames": 64,
                "max_new_tokens": 2048
            }
        }
        3. Multi-image comparison:
        {
            "inputs": [<image1>, <image2>],
            "parameters": {
                "prompt": "Compare these screenshots."
            }
        }

        Returns:
        {
            "generated_text": "...",
            "points": [{"x": 123, "y": 456, "label": "..."}],  # If pointing detected
            "image_size": {...}
        }
        Processing failures are returned as {"error", "error_type", "traceback"}.

        Raises:
            ValueError: if no input is found under any supported key.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video") or data.get("image") or data.get("images")
        if inputs is None:
            raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")

        # "parameters" may be absent or explicitly null in the JSON body.
        params = data.get("parameters") or {}
        prompt = params.get("prompt", "Describe this image.")

        try:
            max_new_tokens = int(params.get("max_new_tokens", 1024))
            if isinstance(inputs, list):
                return self._process_multi_image(inputs, prompt, max_new_tokens)
            if self._is_video(inputs, params):
                return self._process_video(inputs, prompt, params, max_new_tokens)
            return self._process_image(inputs, prompt, max_new_tokens)
        except Exception as e:
            import traceback
            # NOTE(review): the traceback is sent to the client and exposes server
            # internals; consider logging it instead on public endpoints.
            return {"error": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}

    def _is_video(self, inputs: Any, params: Dict) -> bool:
        """Decide whether ``inputs`` should be treated as a video.

        An explicit ``parameters.input_type`` of "video"/"image" wins. Otherwise
        a ``data:`` URI is classified by its MIME type, an ``http(s)`` URL by the
        extension of its path (or of a query-parameter value, for download
        links such as ``?file=clip.mp4``), and any other string by whether it
        contains a known video extension. Non-string inputs count as images.
        """
        from urllib.parse import parse_qsl, urlparse

        input_type = params.get("input_type")
        if input_type == "video":
            return True
        if input_type == "image":
            return False
        if not isinstance(inputs, str):
            return False

        # Inspect only the prefix of possibly huge base64 payloads.
        if inputs[:5].lower() == 'data:':
            # e.g. "data:video/mp4;base64,..." - the payload carries no file extension.
            return inputs[:11].lower() == 'data:video/'

        lower = inputs.lower()
        if lower.startswith(('http://', 'https://')):
            # Match on the path rather than the whole URL so hosts like "x.mov" or
            # names like "the.avian.png" are not mistaken for videos.
            parsed = urlparse(lower)
            candidates = [parsed.path] + [value for _, value in parse_qsl(parsed.query)]
            return any(candidate.endswith(self.VIDEO_EXTENSIONS) for candidate in candidates)
        return any(ext in lower for ext in self.VIDEO_EXTENSIONS)

    # ------------------------------------------------------------------ parsing

    def _parse_points(self, text: str, image_w: int, image_h: int) -> List[Dict]:
        """Extract pointing coordinates from Molmo output as pixel positions.

        Molmo outputs coordinates as percentages (0-100) of the image size.
        Single ``<point x=".." y=".." alt="..">`` tags and multi-point
        ``<points x1=".." y1=".." x2=".." y2=".." alt="..">`` tags are both
        read; points are returned in the order they appear in ``text``.
        Coordinates that do not parse as numbers are skipped.

        Returns:
            List of ``{"x", "y", "x_pct", "y_pct", "label"}`` dicts.
        """
        found: List[Tuple[int, Dict]] = []

        for match in self.POINT_REGEX.finditer(text):
            point = self._make_point(match.group(1), match.group(2), match.group(3) or "", image_w, image_h)
            if point is not None:
                found.append((match.start(), point))

        for match in self.POINTS_REGEX.finditer(text):
            attrs = dict(self.ATTR_REGEX.findall(match.group(1)))
            label = attrs.get("alt", "")
            i = 1
            while f"x{i}" in attrs and f"y{i}" in attrs:
                point = self._make_point(attrs[f"x{i}"], attrs[f"y{i}"], label, image_w, image_h)
                if point is not None:
                    found.append((match.start(), point))
                i += 1

        # Stable sort: points from one <points> tag keep their x1, x2, ... order.
        found.sort(key=lambda item: item[0])
        return [point for _, point in found]

    @staticmethod
    def _make_point(x_raw: str, y_raw: str, label: str, image_w: int, image_h: int) -> Optional[Dict]:
        """Convert one percentage coordinate pair to a point dict, or None if unparsable."""
        try:
            x_pct = float(x_raw)
            y_pct = float(y_raw)
        except ValueError:
            # e.g. "1.2.3" satisfies [0-9.]+ but is not a number.
            return None
        return {
            "x": (x_pct / 100) * image_w,
            "y": (y_pct / 100) * image_h,
            "x_pct": x_pct,
            "y_pct": y_pct,
            "label": label,
        }

    def _attach_points(self, result: Dict[str, Any], text: str, width: int, height: int) -> None:
        """Add ``points``/``num_points`` to ``result`` when ``text`` contains pointing markup."""
        points = self._parse_points(text, width, height)
        if points:
            result["points"] = points
            result["num_points"] = len(points)

    # ------------------------------------------------------------------ input loading

    @staticmethod
    def _decode_base64_payload(data: str) -> bytes:
        """Decode a bare base64 string or a ``data:<mime>;base64,<payload>`` URI.

        Raises:
            ValueError: if a data URI has no ',' separator or the payload is not base64.
        """
        if data.startswith('data:'):
            _header, sep, data = data.partition(',')
            if not sep:
                raise ValueError("Malformed data URI: missing ',' before the payload")
        return base64.b64decode(data)

    def _load_image(self, image_data: Any):
        """Return ``image_data`` as an RGB PIL image.

        Accepts a PIL image, raw bytes, an ``http(s)`` URL, a ``data:`` URI, or
        a bare base64 string.

        Raises:
            ValueError: for any other input type.
            requests.HTTPError: if a URL download returns an error status.
        """
        from PIL import Image

        if isinstance(image_data, Image.Image):
            return image_data.convert('RGB')
        if isinstance(image_data, bytes):
            raw = image_data
        elif isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                import requests
                response = requests.get(image_data, timeout=self.REQUEST_TIMEOUT)
                # Fail loudly instead of trying to decode an HTML error page.
                response.raise_for_status()
                # .content honours Content-Encoding (gzip etc.); response.raw does not.
                raw = response.content
            else:
                raw = self._decode_base64_payload(image_data)
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")
        return Image.open(io.BytesIO(raw)).convert('RGB')

    def _video_to_temp_file(self, video_data: Any) -> str:
        """Write ``video_data`` to a new ``.mp4`` temp file and return its path.

        The caller must delete the returned file. If downloading or decoding
        fails part-way, the partial file is removed before the error propagates.

        Raises:
            ValueError: for inputs that are neither ``str`` nor ``bytes``.
        """
        if not isinstance(video_data, (str, bytes)):
            raise ValueError(f"Unsupported video input type: {type(video_data)}")

        tmp = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        try:
            with tmp:
                if isinstance(video_data, bytes):
                    tmp.write(video_data)
                elif video_data.startswith(('http://', 'https://')):
                    import requests
                    with requests.get(video_data, stream=True, timeout=self.REQUEST_TIMEOUT) as response:
                        # Don't save an HTML error page as "video".
                        response.raise_for_status()
                        for chunk in response.iter_content(chunk_size=8192):
                            tmp.write(chunk)
                else:
                    tmp.write(self._decode_base64_payload(video_data))
        except BaseException:
            os.unlink(tmp.name)
            raise
        return tmp.name

    def _load_video_frames(
        self,
        video_data: Any,
        max_frames: int = 128,
        fps: float = 2.0
    ) -> tuple:
        """Decode up to ``max_frames`` frames evenly spaced across a video.

        Args:
            video_data: ``http(s)`` URL, ``data:`` URI, bare base64 string, or raw bytes.
            max_frames: Upper bound on the number of frames returned.
            fps: Target sampling rate; short clips yield fewer frames.

        Returns:
            ``(frames, metadata)`` - a list of RGB PIL images and a dict with
            ``duration``, ``total_frames``, ``sampled_frames``, ``video_fps``,
            ``width`` and ``height`` as reported by OpenCV.

        Raises:
            ValueError: unsupported input type, or OpenCV cannot open the video.
        """
        import cv2
        from PIL import Image

        video_path = self._video_to_temp_file(video_data)
        try:
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    raise ValueError("Could not open video (unsupported format or corrupt data)")

                video_fps = cap.get(cv2.CAP_PROP_FPS)
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                duration = total_frames / video_fps if video_fps > 0 else 0
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

                # Aim for `fps` samples per second, capped by max_frames and the frame
                # count; if that rounds to zero (very short clip or unknown fps), fall
                # back to as many frames as allowed.
                target_frames = min(max_frames, int(duration * fps), total_frames)
                if target_frames <= 0:
                    target_frames = min(max_frames, total_frames)
                frame_indices = np.linspace(0, total_frames - 1, max(1, target_frames), dtype=int)

                frames = []
                for idx in frame_indices:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                    ok, frame = cap.read()
                    if ok:
                        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            finally:
                cap.release()
        finally:
            if os.path.exists(video_path):
                os.unlink(video_path)

        return frames, {
            "duration": duration,
            "total_frames": total_frames,
            "sampled_frames": len(frames),
            "video_fps": video_fps,
            "width": width,
            "height": height
        }

    # ------------------------------------------------------------------ inference

    def _generate(self, images: List[Any], prompt: str, max_new_tokens: int) -> str:
        """Run Molmo on ``images`` + ``prompt`` and return only the newly generated text."""
        from transformers import GenerationConfig

        batch = self.processor.process(images=images, text=prompt)
        # Add a leading batch dimension of 1 and move everything to the model's device.
        batch = {k: v.to(self.model.device).unsqueeze(0) for k, v in batch.items()}
        # Per the Molmo model card, pixel tensors must match the weight dtype
        # (bfloat16 on GPU here); otherwise the vision encoder hits a dtype mismatch.
        if "images" in batch:
            batch["images"] = batch["images"].to(self.model.dtype)

        with torch.inference_mode():
            output = self.model.generate_from_batch(
                batch,
                generation_config=GenerationConfig(
                    max_new_tokens=max_new_tokens,
                    stop_strings=["<|endoftext|>"],
                ),
                tokenizer=self.processor.tokenizer,
            )

        # Output contains prompt + completion; keep only the completion tokens.
        generated_tokens = output[0, batch['input_ids'].size(1):]
        return self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
        """Answer ``prompt`` about a single image, attaching pixel points if any."""
        image = self._load_image(image_data)
        generated_text = self._generate([image], prompt, max_new_tokens)
        result = {
            "generated_text": generated_text,
            "image_size": {"width": image.width, "height": image.height}
        }
        self._attach_points(result, generated_text, image.width, image.height)
        return result

    def _process_video(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Answer ``prompt`` about a video using a handful of evenly spaced frames.

        Up to ``params["max_frames"]`` frames (default 32, capped at
        ``self.max_frames``) are decoded at ``params["fps"]``; at most
        ``MAX_ANALYSIS_FRAMES`` of them are passed to the model.

        Raises:
            ValueError: if no frame could be decoded.
        """
        max_frames = min(int(params.get("max_frames", 32)), self.max_frames)
        fps = float(params.get("fps", self.default_fps))
        frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
        if not frames:
            raise ValueError("No frames could be extracted from video")

        # Molmo handles multiple images - feed it representative frames only.
        sample_indices = np.linspace(0, len(frames) - 1, min(self.MAX_ANALYSIS_FRAMES, len(frames)), dtype=int)
        sample_frames = [frames[i] for i in sample_indices]

        # Modify prompt to indicate video context
        video_prompt = f"These are {len(sample_frames)} frames from a video. {prompt}"
        generated_text = self._generate(sample_frames, video_prompt, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "video_metadata": video_metadata,
            "frames_analyzed": len(sample_frames)
        }
        # Scale points to the first analyzed frame's decoded size (the pixels the
        # model saw) rather than the container-reported width/height.
        first = sample_frames[0]
        self._attach_points(result, generated_text, first.width, first.height)
        return result

    def _process_multi_image(
        self,
        images_data: List,
        prompt: str,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Answer ``prompt`` about several images at once.

        Any points are scaled to the first image's dimensions.
        """
        images = [self._load_image(img) for img in images_data]
        generated_text = self._generate(images, prompt, max_new_tokens)
        result = {
            "generated_text": generated_text,
            "num_images": len(images),
            "image_sizes": [{"width": img.width, "height": img.height} for img in images]
        }
        # Parse points using first image dimensions
        if images:
            self._attach_points(result, generated_text, images[0].width, images[0].height)
        return result