"""
Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints

Model: allenai/Molmo2-8B

For ProofPath video assessment - video pointing, tracking, and grounded analysis.

Unique capability: Returns pixel-level coordinates for objects in videos.
"""

import base64
import io
import os
import re
import tempfile
from typing import Dict, List, Any, Optional, Tuple, Union
from urllib.parse import urlparse

import numpy as np
import torch


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for Molmo 2.

    A request is dispatched to video, single-image, or multi-image
    processing. In "pointing" / "tracking" mode the coordinates found in the
    generated text (scaled by 1000) are converted to pixel coordinates.
    """

    # Captures the coords="..." attribute of a <points .../> or <tracks .../> tag.
    COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
    # One frame entry: "<timestamp-or-image-index> <id x y [id x y ...]>".
    # Whitespace after the delimiter is tolerated so "a; b" parses like "a;b".
    FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)\s*([0-9\.]+) ([0-9\. ]+)")
    # One point inside a frame entry: "<instance id> <x> <y>".
    POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")

    # File extensions treated as video by _is_video.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v')
    # Timeout in seconds passed to requests.get so a stalled download cannot
    # block a worker indefinitely.
    REQUEST_TIMEOUT = 60

    def __init__(self, path: str = ""):
        """
        Initialize Molmo 2 model for video pointing and tracking.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints). Falls back to "allenai/Molmo2-8B" when empty.
        """
        from transformers import AutoProcessor, AutoModelForImageTextToText

        model_id = path if path else "allenai/Molmo2-8B"

        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")
        device_map = "auto" if has_cuda else None

        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            dtype="auto",
            device_map=device_map,
        )

        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if has_cuda else torch.float32,
            device_map=device_map,
        )

        # With device_map=None the model must be placed explicitly.
        if not has_cuda:
            self.model = self.model.to(self.device)

        self.model.eval()

        # Frame sampling limits used by the OpenCV fallback path.
        self.max_frames = 128
        self.default_fps = 2.0

    # ------------------------------------------------------------------
    # Coordinate parsing
    # ------------------------------------------------------------------

    def _iter_frames(self, text: str):
        """Yield (frame_token, points_text) string pairs found in *text*."""
        for coord_match in self.COORD_REGEX.finditer(text):
            for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
                yield frame_match.group(1), frame_match.group(2)

    def _iter_points(self, points_text: str, width, height):
        """Yield (instance_id, x, y) in pixels for each in-bounds point."""
        for point_match in self.POINTS_REGEX.finditer(points_text):
            x = float(point_match.group(2)) / 1000 * width
            y = float(point_match.group(3)) / 1000 * height
            if 0 <= x <= width and 0 <= y <= height:
                yield int(point_match.group(1)), x, y

    def _parse_video_points(
        self,
        text: str,
        image_w: int,
        image_h: int,
        extract_ids: bool = False
    ) -> List[Tuple]:
        """
        Extract video pointing coordinates from Molmo output.

        Molmo outputs coordinates in XML-like format:
            <points alt="object" coords="8.5 0 183 216; 8.5 1 245 198"/>

        Where:
            - 8.5 = timestamp/frame
            - 0, 1 = instance IDs
            - 183 216, 245 198 = x, y coordinates (scaled by 1000)

        Frame entries whose timestamp is not a valid number (e.g. "1.2.3")
        are skipped instead of aborting the whole parse.

        Returns: List of (timestamp, x, y) or (timestamp, id, x, y) tuples
        """
        all_points = []

        for frame_token, points_text in self._iter_frames(text):
            try:
                timestamp = float(frame_token)
            except ValueError:
                # FRAME_REGEX admits tokens such as "." or "1.2.3".
                continue

            for instance_id, x, y in self._iter_points(points_text, image_w, image_h):
                if extract_ids:
                    all_points.append((timestamp, instance_id, x, y))
                else:
                    all_points.append((timestamp, x, y))

        return all_points

    def _parse_multi_image_points(
        self,
        text: str,
        widths: List[int],
        heights: List[int]
    ) -> List[Tuple]:
        """
        Parse pointing coordinates across multiple images.

        The leading number of each frame entry is read as a 1-based image
        index. Entries whose index is not a whole number, or that fall
        outside the supplied size lists, are skipped.

        Returns: List of (image_number, x, y) tuples, image_number 1-based.
        """
        all_points = []
        num_images = min(len(widths), len(heights))

        for frame_token, points_text in self._iter_frames(text):
            try:
                frame_value = float(frame_token)
            except ValueError:
                continue
            if not frame_value.is_integer():
                continue

            image_idx = int(frame_value) - 1
            if not 0 <= image_idx < num_images:
                continue

            w, h = widths[image_idx], heights[image_idx]
            for _, x, y in self._iter_points(points_text, w, h):
                all_points.append((image_idx + 1, x, y))

        return all_points

    def _attach_video_points(
        self,
        result: Dict[str, Any],
        generated_text: str,
        video_w: int,
        video_h: int,
        mode: str
    ) -> None:
        """Add "points" or "tracks" (plus counts) to *result* for pointing/tracking modes."""
        if mode not in ("pointing", "tracking"):
            return

        tracking = mode == "tracking"
        points = self._parse_video_points(
            generated_text, video_w, video_h, extract_ids=tracking
        )

        if tracking:
            # Group the per-frame points by instance id.
            tracks: Dict[int, List[Tuple]] = {}
            for timestamp, obj_id, x, y in points:
                tracks.setdefault(obj_id, []).append((timestamp, x, y))
            result["tracks"] = tracks
            result["num_objects_tracked"] = len(tracks)
        else:
            result["points"] = points
            result["num_points"] = len(points)

    # ------------------------------------------------------------------
    # Input loading helpers
    # ------------------------------------------------------------------

    def _fetch_url(self, url: str) -> bytes:
        """Download *url* and return the response body; raises on HTTP errors."""
        import requests

        response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.content

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode base64 text, first dropping a "data:...," URI header if present."""
        if payload.startswith('data:'):
            header, separator, payload = payload.partition(',')
            if not separator:
                raise ValueError("Malformed data URI: missing ',' separator.")
        return base64.b64decode(payload)

    def _write_temp_video(self, video_data: Any) -> str:
        """Write base64 text or raw bytes to a temporary .mp4 file and return its path.

        The caller is responsible for deleting the file.
        """
        if isinstance(video_data, str):
            video_bytes = self._decode_base64(video_data)
        else:
            video_bytes = video_data

        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
            f.write(video_bytes)
            return f.name

    def _load_image(self, image_data: Any):
        """
        Load a single image from various formats.

        Accepts a PIL image, an http(s) URL, a data URI, a bare base64
        string, or raw bytes. The returned image is in RGB mode.

        Raises:
            ValueError: if *image_data* is of an unsupported type.
        """
        from PIL import Image

        if isinstance(image_data, Image.Image):
            # Keep the same object when it is already RGB.
            return image_data if image_data.mode == 'RGB' else image_data.convert('RGB')

        if isinstance(image_data, str):
            if image_data.startswith(('http://', 'https://')):
                # Use the decoded body rather than response.raw, which is the
                # undecoded transport stream.
                image_bytes = self._fetch_url(image_data)
            else:
                image_bytes = self._decode_base64(image_data)
        elif isinstance(image_data, (bytes, bytearray)):
            image_bytes = bytes(image_data)
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

        return Image.open(io.BytesIO(image_bytes)).convert('RGB')

    # ------------------------------------------------------------------
    # Generation helpers
    # ------------------------------------------------------------------

    def _tokenize_chat(self, messages: List[Dict[str, Any]]):
        """Apply the processor's chat template and return tensor inputs."""
        return self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )

    def _generate(self, inputs, max_new_tokens: int) -> str:
        """Run generation and decode only the newly generated tokens."""
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.inference_mode():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Drop the prompt tokens that generate() echoes back.
        generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
        return self.processor.tokenizer.decode(
            generated_tokens,
            skip_special_tokens=True
        )

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video or images with Molmo 2.

        Expected input formats:

        1. Video QA:
            {
                "inputs": <video_url_or_base64>,
                "parameters": {
                    "prompt": "What happens in this video?",
                    "max_new_tokens": 2048
                }
            }

        2. Video Pointing (Molmo's unique capability):
            {
                "inputs": <video_url>,
                "parameters": {
                    "prompt": "Point to all the people in this video.",
                    "mode": "pointing",
                    "max_new_tokens": 2048
                }
            }

        3. Video Tracking:
            {
                "inputs": <video_url>,
                "parameters": {
                    "prompt": "Track the person in the red shirt.",
                    "mode": "tracking",
                    "max_new_tokens": 2048
                }
            }

        4. Image Pointing:
            {
                "inputs": <image_url>,
                "parameters": {
                    "prompt": "Point to the Excel cell B2.",
                    "mode": "pointing"
                }
            }

        5. Multi-image comparison:
            {
                "inputs": [<image1>, <image2>],
                "parameters": {
                    "prompt": "Compare these images."
                }
            }

        The input may also be supplied under the "video", "image", or
        "images" key. "parameters" may be omitted or null.

        Returns:
            {
                "generated_text": "...",
                "points": [(timestamp, x, y), ...],  # If pointing mode
                "tracks": {"object_id": [(t, x, y), ...]},  # If tracking mode
                "video_metadata": {...}
            }
            On a processing failure: {"error": "...", "error_type": "..."}.

        Raises:
            ValueError: if no input is provided at all.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video") or data.get("image") or data.get("images")
        if inputs is None:
            raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")

        # "parameters": null must behave like a missing key.
        params = data.get("parameters") or {}
        prompt = params.get("prompt", "Describe this content.")
        max_new_tokens = params.get("max_new_tokens", 2048)

        try:
            if isinstance(inputs, list):
                return self._process_multi_image(inputs, prompt, params, max_new_tokens)
            if self._is_video(inputs, params):
                return self._process_video(inputs, prompt, params, max_new_tokens)
            return self._process_image(inputs, prompt, params, max_new_tokens)
        except Exception as e:
            # Endpoint boundary: report the failure in the response body.
            return {"error": str(e), "error_type": type(e).__name__}

    def _is_video(self, inputs: Any, params: Dict) -> bool:
        """
        Determine if input is video.

        An explicit params["input_type"] of "video" or "image" wins.
        Otherwise a string is a video when it is a "data:video/" URI or when
        its path (for URLs: the path component, ignoring query and fragment)
        ends with a known video extension.
        """
        input_type = params.get("input_type")
        if input_type == "video":
            return True
        if input_type == "image":
            return False

        if not isinstance(inputs, str):
            return False

        # Only the head is needed for prefix checks; inputs may be large base64.
        head = inputs[:16].lower()
        if head.startswith('data:'):
            return head.startswith('data:video/')

        if head.startswith(('http://', 'https://')):
            # Match on the URL path so query strings cannot cause false positives.
            return urlparse(inputs).path.lower().endswith(self.VIDEO_EXTENSIONS)

        return inputs.lower().endswith(self.VIDEO_EXTENSIONS)

    # ------------------------------------------------------------------
    # Video
    # ------------------------------------------------------------------

    def _process_video(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """
        Process video with Molmo 2.

        Uses molmo_utils.process_vision_info when it is importable and
        otherwise delegates to _process_video_fallback. http(s) URLs are
        passed through as-is; any other input is written to a temporary
        file that is removed afterwards.
        """
        try:
            from molmo_utils import process_vision_info
        except ImportError:
            return self._process_video_fallback(video_data, prompt, params, max_new_tokens)

        mode = params.get("mode", "default")

        temp_path = None
        if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
            video_source = video_data
        else:
            temp_path = self._write_temp_video(video_data)
            video_source = temp_path

        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        dict(type="text", text=prompt),
                        dict(type="video", video=video_source),
                    ],
                }
            ]

            _, videos, video_kwargs = process_vision_info(messages)
            videos, video_metadatas = zip(*videos)
            videos, video_metadatas = list(videos), list(video_metadatas)

            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            inputs = self.processor(
                videos=videos,
                video_metadata=video_metadatas,
                text=text,
                padding=True,
                return_tensors="pt",
                **video_kwargs,
            )
            generated_text = self._generate(inputs, max_new_tokens)

            metadata = video_metadatas[0]
            video_w = metadata.get("width", 1920)
            video_h = metadata.get("height", 1080)

            result = {
                "generated_text": generated_text,
                "video_metadata": {
                    "width": video_w,
                    "height": video_h,
                    **{k: v for k, v in metadata.items() if k not in ("width", "height")}
                }
            }

            self._attach_video_points(result, generated_text, video_w, video_h, mode)
            return result

        finally:
            # Only files created here are removed; URLs have no temp file.
            if temp_path is not None and os.path.exists(temp_path):
                os.unlink(temp_path)

    def _process_video_fallback(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """
        Fallback video processing without molmo_utils.

        Samples frames with OpenCV and feeds them to the model as a list of
        images.

        Raises:
            ValueError: if no frame could be decoded from the input.
        """
        import cv2
        from PIL import Image

        if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
            video_path = self._write_temp_video(self._fetch_url(video_data))
        else:
            video_path = self._write_temp_video(video_data)

        try:
            cap = cv2.VideoCapture(video_path)
            try:
                fps = cap.get(cv2.CAP_PROP_FPS)
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                duration = total_frames / fps if fps > 0 else 0

                # Sample at default_fps, capped by max_frames and the frame count.
                target_frames = min(self.max_frames, int(duration * self.default_fps), total_frames)
                frame_indices = np.linspace(
                    0, max(total_frames - 1, 0), max(1, target_frames), dtype=int
                )

                frames = []
                for idx in frame_indices:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                    ret, frame = cap.read()
                    if ret:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frames.append(Image.fromarray(frame_rgb))

                video_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                video_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                # Release the capture even when decoding raises.
                cap.release()

            if not frames:
                raise ValueError("Could not decode any frames from the video input.")

            content = [dict(type="text", text=prompt)]
            content.extend(dict(type="image", image=frame) for frame in frames)
            messages = [{"role": "user", "content": content}]

            generated_text = self._generate(self._tokenize_chat(messages), max_new_tokens)

            mode = params.get("mode", "default")
            result = {
                "generated_text": generated_text,
                "video_metadata": {
                    "width": video_w,
                    "height": video_h,
                    "duration": duration,
                    "sampled_frames": len(frames)
                }
            }

            self._attach_video_points(result, generated_text, video_w, video_h, mode)
            return result

        finally:
            if os.path.exists(video_path):
                os.unlink(video_path)

    # ------------------------------------------------------------------
    # Images
    # ------------------------------------------------------------------

    def _process_image(
        self,
        image_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process a single image; adds "points" / "num_points" in pointing mode."""
        image = self._load_image(image_data)
        mode = params.get("mode", "default")

        messages = [
            {
                "role": "user",
                "content": [
                    dict(type="text", text=prompt),
                    dict(type="image", image=image),
                ],
            }
        ]

        generated_text = self._generate(self._tokenize_chat(messages), max_new_tokens)

        result = {
            "generated_text": generated_text,
            "image_size": {"width": image.width, "height": image.height}
        }

        if mode == "pointing":
            points = self._parse_video_points(generated_text, image.width, image.height)
            result["points"] = points
            result["num_points"] = len(points)

        return result

    def _process_multi_image(
        self,
        images_data: List,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """Process multiple images; adds per-image points in pointing mode."""
        images = [self._load_image(img) for img in images_data]
        mode = params.get("mode", "default")

        content = [dict(type="text", text=prompt)]
        content.extend(dict(type="image", image=image) for image in images)
        messages = [{"role": "user", "content": content}]

        generated_text = self._generate(self._tokenize_chat(messages), max_new_tokens)

        result = {
            "generated_text": generated_text,
            "num_images": len(images),
            "image_sizes": [{"width": img.width, "height": img.height} for img in images]
        }

        if mode == "pointing":
            widths = [img.width for img in images]
            heights = [img.height for img in images]
            points = self._parse_multi_image_points(generated_text, widths, heights)
            result["points"] = points
            result["num_points"] = len(points)

        return result