Spaces:
Sleeping
Sleeping
| """ | |
Shared trace model inference logic.

Top-level imports are kept light so that eval_server can import
DEFAULT_MODEL_ID and build_prompt without pulling in transformers, PIL, or
trajectory_viz; those are imported lazily inside the functions that use them.
NOTE: torch is currently still imported at module top level, so importing this
module does load torch despite the intent above.
| """ | |
| import logging | |
| import os | |
| import tempfile | |
| import torch | |
| import re | |
| from typing import List, Optional, Tuple, Dict, Any | |
| from pathlib import Path | |
logger = logging.getLogger(__name__)

# Model id handed to from_pretrained when a caller does not choose one.
DEFAULT_MODEL_ID = "mihirgrao/trace-model"
# Label fill value for every token outside an assistant reply (see preprocess_qwen_visual).
IGNORE_INDEX = -100

# Module-level cache written by load_model and read by run_inference: the live
# model, its processor, and the id both were loaded from.
_model_state = dict(model=None, processor=None, model_id=None)
def build_prompt(instruction: str = "", is_oxe: bool = False) -> str:
    """Return the full model prompt for a task instruction.

    A blank or whitespace-only ``instruction`` falls back to "predict the trace".
    When ``is_oxe`` is truthy the Franka / joint-control wording is used and the
    ``<image>`` placeholder leads the prompt; otherwise the placeholder sits in
    the middle of the sentence.
    """
    task = instruction.strip()
    if not task:
        task = "predict the trace"
    if not is_oxe:
        return f"You are a robot. Your task is: \"{task}\". <image> Can you predict the trace of the end effector in this image to complete the task?"
    return f"<image>\nYou are a Franka robot using the joint control. The task is \"{task}\". Can you predict the trace of the end effector?"
def format_trace_points(trajectories: List) -> str:
    """Render trajectory points as a Markdown bullet list under a heading.

    A point that is a list/tuple with at least two items is shown as its first
    two values with four decimals; any other point is shown as-is. A falsy
    ``trajectories`` (empty or None) yields a fixed "no points" message.
    """
    if not trajectories:
        return "No trajectory points extracted."

    def _bullet(number, point):
        if isinstance(point, (list, tuple)) and len(point) >= 2:
            return f"- Point {number}: `[{point[0]:.4f}, {point[1]:.4f}]`"
        return f"- Point {number}: `{point}`"

    bullets = [_bullet(number, point) for number, point in enumerate(trajectories, start=1)]
    # The heading carries its own trailing newline, giving a blank line before the list.
    return "\n".join(["## Predicted Trace Points\n", *bullets])
def center_crop_resize(image, size: Tuple[int, int] = (128, 128), *, resize: bool = False):
    """Center-crop ``image`` to the largest square that fits inside it.

    Args:
        image: PIL-style image. Only ``image.size`` (``(width, height)``) and
            ``image.crop(box)`` are used unless ``resize`` is True.
        size: Target ``(width, height)``; used only when ``resize`` is True.
        resize: When True, also resize the square crop to ``size`` with LANCZOS
            resampling. Defaults to False, which matches the previous behavior
            of returning the crop at its native resolution.

    Returns:
        The square crop, resized to ``size`` only if ``resize`` is True.
    """
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    cropped = image.crop((left, top, left + side, top + side))
    if not resize:
        return cropped
    # Imported only on the resize path, the one place a PIL constant is needed.
    from PIL import Image
    return cropped.resize(size, Image.Resampling.LANCZOS)
def preprocess_image_for_trace(image_path: str) -> Tuple:
    """Load an image as RGB, center-crop it to a square, and save it as a temp PNG.

    The ``(128, 128)`` size is passed to ``center_crop_resize``, which, when
    called this way, only crops to a square. The image is not resized to 128x128.

    Args:
        image_path: Path of the image file to load.

    Returns:
        ``(image, temp_path)``: the cropped PIL image and the path of the PNG it
        was written to. Nothing here deletes that file; the caller owns it.
    """
    from PIL import Image

    # Context manager releases the source file handle once the RGB copy exists.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img = center_crop_resize(img, (128, 128))

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    try:
        # Write through the already-open handle and close it. Reopening the
        # file by name while it is still open fails on Windows, and the
        # previous version never closed this handle at all.
        with tmp:
            img.save(tmp, format="PNG")
    except Exception:
        # Don't leave an empty/partial temp file behind on failure.
        os.unlink(tmp.name)
        raise
    return img, tmp.name
def _make_abs_paths(base: Path, files: str) -> str:
    """Join ``files`` onto ``base`` and return the resolved absolute path as a string.

    Standard ``Path`` joining applies, so an absolute ``files`` replaces ``base``.
    """
    joined = base / files
    return str(joined.resolve())
def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]:
    """Convert one dataset item into chat messages for the processor.

    ``item["image"]`` and ``item["video"]`` may each be missing, a single path
    string, or a list of paths. Every path is resolved against ``base_path``.

    In each human turn, every ``<image>`` / ``<video>`` placeholder is replaced,
    in order, by the next unused media entry. The non-blank text pieces between
    placeholders become stripped text parts. Every other turn is emitted as an
    assistant message holding its raw text.

    Raises:
        ValueError: a placeholder has no media left to consume, or media
            entries remain unconsumed after all turns are processed.
        KeyError: ``item`` lacks ``"conversations"`` or a turn lacks
            ``"from"`` / ``"value"``.
    """

    def _as_list(value):
        # Falsy values (None, "", []) become []; a lone string becomes [string].
        value = value or []
        return [value] if isinstance(value, str) else value

    # Stored reversed so list.pop() hands entries out in their original order.
    image_stack = list(reversed([
        {"type": "image", "image": _make_abs_paths(base_path, img)}
        for img in _as_list(item.get("image"))
    ]))
    video_stack = list(reversed([
        {"type": "video", "video": _make_abs_paths(base_path, vid)}
        for vid in _as_list(item.get("video"))
    ]))

    # Placeholder -> (stack it draws from, error raised when that stack is empty).
    media_for_placeholder = {
        "<image>": (
            image_stack,
            "Number of <image> placeholders exceeds the number of provided images",
        ),
        "<video>": (
            video_stack,
            "Number of <video> placeholders exceeds the number of provided videos",
        ),
    }

    messages = []
    for turn in item["conversations"]:
        is_human = turn["from"] == "human"
        text: str = turn["value"]
        if not is_human:
            # Non-human turns are passed through as a single text part.
            messages.append({"role": "assistant", "content": [{"type": "text", "text": text}]})
            continue

        parts = []
        # The capture group keeps the placeholders as their own split pieces.
        for piece in re.split(r"(<image>|<video>)", text):
            if piece in media_for_placeholder:
                stack, overflow_error = media_for_placeholder[piece]
                if not stack:
                    raise ValueError(overflow_error)
                parts.append(stack.pop())
            elif piece.strip():
                parts.append({"type": "text", "text": piece.strip()})
        messages.append({"role": "user", "content": parts})

    if image_stack:
        raise ValueError(
            f"{len(image_stack)} image(s) remain unused (not consumed by placeholders)"
        )
    if video_stack:
        raise ValueError(
            f"{len(video_stack)} video(s) remain unused (not consumed by placeholders)"
        )
    return messages
def _assistant_label_spans(token_ids: List[int]):
    """Yield ``(start, stop)`` slice bounds for each supervised assistant reply.

    The hard-coded ids look like Qwen chat-template tokens (77091 presumably
    ``assistant``, 151645 presumably ``<|im_end|>``); TODO confirm against the
    tokenizer.

    Each span:
    - begins two tokens after a role token;
    - ends two tokens past the first end marker found at or after that start.

    A role token with no later end marker contributes nothing, and scanning
    resumes right after it.
    """
    role_id, end_id = 77091, 151645
    cursor = 0
    while True:
        try:
            role_pos = token_ids.index(role_id, cursor)
        except ValueError:
            return
        start = role_pos + 2
        try:
            end_pos = token_ids.index(end_id, start)
        except ValueError:
            cursor = role_pos + 1
            continue
        yield start, end_pos + 2
        cursor = end_pos + 1


def preprocess_qwen_visual(
    sources,
    processor,
    add_gen_prompt: bool = False,
) -> Dict:
    """
    Tokenize a single conversation sample with the Qwen-VL chat template.

    Args:
        sources: List holding exactly one dict with keys image, conversations,
            and (optionally) data_path.
        processor: Qwen-VL processor.
        add_gen_prompt: If True, append the generation prompt so the model
            produces the assistant reply (inference) and skip labels. If False,
            tokenize the whole conversation and attach training labels.

    Returns:
        The processor's output dict. ``input_ids`` is guaranteed to be a tensor:
        a plain list is wrapped into shape ``(1, len)``. When ``add_gen_prompt``
        is False a ``labels`` tensor is added. It is IGNORE_INDEX everywhere
        except the spans yielded by ``_assistant_label_spans``, which copy
        ``input_ids``.

    Raises:
        ValueError: ``sources`` does not contain exactly one item, or the
            media/placeholder counts disagree (from ``_build_messages``).
    """
    if len(sources) != 1:
        raise ValueError(f"Expected 1 source, got {len(sources)}")
    source = sources[0]
    messages = _build_messages(source, Path(source.get("data_path", "")))

    result = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=add_gen_prompt,
    )

    ids = result["input_ids"]
    if isinstance(ids, list):
        ids = torch.tensor(ids).unsqueeze(0)
    result["input_ids"] = ids

    if add_gen_prompt:
        # Generation needs no labels.
        return result

    labels = torch.full_like(ids, IGNORE_INDEX)
    for start, stop in _assistant_label_spans(ids[0].tolist()):
        labels[0, start:stop] = ids[0, start:stop]
    result["labels"] = labels
    return result
def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
    """Load (or reuse) the trace model and its processor into ``_model_state``.

    If ``model_id`` is already loaded this returns immediately. Otherwise any
    previously loaded model and processor are released first, so at most one
    model is held at a time.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both the model
            and the processor.

    Returns:
        ``(success, message)``. Errors are logged with a traceback and reported
        through ``message`` instead of being raised.
    """
    if _model_state["model"] is not None and _model_state["model_id"] == model_id:
        return True, f"Model already loaded: {model_id}"
    try:
        from transformers import AutoModelForImageTextToText, AutoProcessor

        if _model_state["model"] is not None:
            import gc

            # Drop every reference to the old model. The id is cleared too, so
            # a failed load below never leaves state naming a model that is
            # gone. A GC pass then frees reference cycles inside the model
            # before CUDA is asked to release its cached blocks.
            _model_state["model"] = None
            _model_state["processor"] = None
            _model_state["model_id"] = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        logger.info("Loading model from %s...", model_id)
        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            dtype=torch.bfloat16,
            device_map="auto",
        )
        processor = AutoProcessor.from_pretrained(model_id)

        _model_state["model"] = model
        _model_state["processor"] = processor
        _model_state["model_id"] = model_id
        return True, f"Model loaded: {model_id}"
    except Exception as e:
        logger.exception("Failed to load model")
        return False, f"Error loading model: {str(e)}"
def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
    """
    Run trace-model inference on one image and render the predicted trace.

    Args:
        image_path: Path to the input image.
        prompt: User-turn text. It must contain exactly one ``<image>``
            placeholder for the single image passed; a mismatch surfaces as an
            ``"Error: ..."`` result.
        model_id: Model to use; loaded through ``load_model`` if not already resident.

    Returns:
        ``(prediction_text, overlay_image_path, trace_points_text)``. On failure
        the first element is an error message, the overlay path is None, and the
        trace text is empty. The overlay PNG is a temp file this function does
        not delete.
    """
    # Check the cheap input first so a missing image never triggers a model load.
    if image_path is None or not os.path.exists(image_path):
        return "Please provide a valid image.", None, ""

    success, msg = load_model(model_id)
    if not success:
        return msg, None, ""
    model = _model_state["model"]
    processor = _model_state["processor"]

    try:
        from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image

        abs_image_path = os.path.abspath(image_path)
        raw_item = {
            "id": "single_inference",
            "image": [abs_image_path],
            "conversations": [{"from": "human", "value": prompt}],
            "data_path": "",
        }
        processed = preprocess_qwen_visual([raw_item], processor, add_gen_prompt=True)

        # Forward only the processor outputs generate() needs. attention_mask
        # is passed explicitly rather than left for generate() to infer.
        inputs = {"input_ids": processed["input_ids"].to(model.device)}
        for key in ("attention_mask", "pixel_values", "image_grid_thw"):
            if key in processed:
                inputs[key] = processed[key].to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
            )

        # Keep only the newly generated tokens (drop the echoed prompt).
        trimmed = generated_ids[:, inputs["input_ids"].shape[1]:]
        prediction = processor.tokenizer.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        trajectory = extract_trajectory_from_text(prediction)
        if not trajectory:
            return prediction, None, "No trajectory points extracted."

        trace_points_text = format_trace_points(trajectory)
        # Close our handle before the visualizer writes to the file by path;
        # reopening a still-open NamedTemporaryFile by name fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
            overlay_path = f.name
        visualize_trajectory_on_image(
            trajectory=trajectory,
            image_path=abs_image_path,
            output_path=overlay_path,
            normalized=True,
        )
        return prediction, overlay_path, trace_points_text
    except Exception as e:
        logger.exception("Inference failed")
        return f"Error: {str(e)}", None, ""