""" Shared trace model inference logic. This module has minimal top-level imports so eval_server can import DEFAULT_MODEL_ID and build_prompt without pulling in torch/transformers. Heavy imports are done lazily inside load_model and run_inference. """ import logging import os import tempfile import torch import re from typing import List, Optional, Tuple, Dict, Any from pathlib import Path logger = logging.getLogger(__name__) # Constants DEFAULT_MODEL_ID = "mihirgrao/trace-model" IGNORE_INDEX = -100 # Global model state _model_state = { "model": None, "processor": None, "model_id": None, } def build_prompt(instruction: str = "", is_oxe: bool = False) -> str: """Build the full prompt from task instruction.""" task = instruction.strip() or "predict the trace" if is_oxe: return f"\nYou are a Franka robot using the joint control. The task is \"{task}\". Can you predict the trace of the end effector?" return f"You are a robot. Your task is: \"{task}\". Can you predict the trace of the end effector in this image to complete the task?" def format_trace_points(trajectories: List) -> str: """Format trajectory points for display.""" if not trajectories: return "No trajectory points extracted." lines = ["## Predicted Trace Points\n"] for i, pt in enumerate(trajectories): if isinstance(pt, (list, tuple)) and len(pt) >= 2: x, y = pt[0], pt[1] lines.append(f"- Point {i + 1}: `[{x:.4f}, {y:.4f}]`") else: lines.append(f"- Point {i + 1}: `{pt}`") return "\n".join(lines) def center_crop_resize(image, size: Tuple[int, int] = (128, 128)): """Center crop to square then resize. Requires PIL Image.""" from PIL import Image w, h = image.size min_dim = min(w, h) left = (w - min_dim) // 2 top = (h - min_dim) // 2 cropped = image.crop((left, top, left + min_dim, top + min_dim)) # return cropped.resize(size, Image.Resampling.LANCZOS) return cropped def preprocess_image_for_trace(image_path: str) -> Tuple: """Load image, center crop and resize to 128x128. Returns (PIL Image, temp_path).""" from PIL import Image img = Image.open(image_path).convert("RGB") img = center_crop_resize(img, (128, 128)) tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png") img.save(tmp.name) return img, tmp.name def _make_abs_paths(base: Path, files: str) -> str: return f"{(base / files).resolve()}" def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]: # Extract and normalize images and videos images = item.get("image") or [] if isinstance(images, str): images = [images] videos = item.get("video") or [] if isinstance(videos, str): videos = [videos] # Build media pools with absolute paths image_pool = [ {"type": "image", "image": _make_abs_paths(base_path, img)} for img in images ] video_pool = [ {"type": "video", "video": _make_abs_paths(base_path, vid)} for vid in videos ] messages = [] for turn in item["conversations"]: role = "user" if turn["from"] == "human" else "assistant" text: str = turn["value"] if role == "user": content = [] # Split text by or