# trace_visualizer / trace_inference.py
# Author: Anthony Liang
# Commit: be80524 (update)
"""
Shared trace model inference logic.
This module has minimal top-level imports so eval_server can import
DEFAULT_MODEL_ID and build_prompt without pulling in torch/transformers.
Heavy imports are done lazily inside load_model and run_inference.
"""
import logging
import os
import tempfile
import torch
import re
from typing import List, Optional, Tuple, Dict, Any
from pathlib import Path
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Constants
# Default Hugging Face model id loaded when no model_id is supplied.
DEFAULT_MODEL_ID = "mihirgrao/trace-model"
# Label value masked out of training labels (PyTorch's conventional
# ignore_index for cross-entropy).
IGNORE_INDEX = -100
# Global model state
# Lazily populated by load_model() and reused across run_inference() calls
# so the model/processor are only loaded once per model_id.
_model_state = {
    "model": None,
    "processor": None,
    "model_id": None,
}
def build_prompt(instruction: str = "", is_oxe: bool = False) -> str:
    """Construct the model prompt for a task instruction.

    A blank instruction falls back to "predict the trace". The OXE variant
    uses the Franka-robot phrasing with a leading <image> placeholder.
    """
    cleaned = instruction.strip()
    if not cleaned:
        cleaned = "predict the trace"
    if is_oxe:
        return f"<image>\nYou are a Franka robot using the joint control. The task is \"{cleaned}\". Can you predict the trace of the end effector?"
    return f"You are a robot. Your task is: \"{cleaned}\". <image> Can you predict the trace of the end effector in this image to complete the task?"
def format_trace_points(trajectories: List) -> str:
    """Render predicted trajectory points as a markdown bullet list.

    Points that look like (x, y) pairs are formatted to 4 decimals;
    anything else is shown verbatim.
    """
    if not trajectories:
        return "No trajectory points extracted."
    rendered = ["## Predicted Trace Points\n"]
    for idx, point in enumerate(trajectories, start=1):
        if isinstance(point, (list, tuple)) and len(point) >= 2:
            rendered.append(f"- Point {idx}: `[{point[0]:.4f}, {point[1]:.4f}]`")
        else:
            rendered.append(f"- Point {idx}: `{point}`")
    return "\n".join(rendered)
def center_crop_resize(image, size: Tuple[int, int] = (128, 128)):
    """Center crop a PIL image to a square, then resize it to ``size``.

    Args:
        image: Source PIL Image.
        size: Target (width, height) after cropping.

    Returns:
        A new PIL Image of dimensions ``size``.
    """
    from PIL import Image

    # Largest centered square that fits inside the source image.
    w, h = image.size
    min_dim = min(w, h)
    left = (w - min_dim) // 2
    top = (h - min_dim) // 2
    cropped = image.crop((left, top, left + min_dim, top + min_dim))
    # Bug fix: the resize was commented out, leaving the `size` parameter
    # unused even though the function name, docstring and the caller
    # (preprocess_image_for_trace, which promises 128x128 output) all
    # expect a resized image.
    return cropped.resize(size, Image.Resampling.LANCZOS)
def preprocess_image_for_trace(image_path: str) -> Tuple:
    """Load an image, center crop and resize to 128x128.

    Args:
        image_path: Path to the source image file.

    Returns:
        (PIL Image, temp_path): the processed image and the path of a
        temporary PNG holding it. The caller is responsible for deleting
        the temporary file (delete=False).
    """
    from PIL import Image

    img = Image.open(image_path).convert("RGB")
    img = center_crop_resize(img, (128, 128))
    # Bug fix: the original never closed the NamedTemporaryFile handle,
    # leaking a file descriptor per call (and on Windows the open handle
    # prevents re-opening the path for the save).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp_path = tmp.name
    img.save(tmp_path)
    return img, tmp_path
def _make_abs_paths(base: Path, files: str) -> str:
return f"{(base / files).resolve()}"
def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]:
    """Convert one conversation item into Qwen-style chat messages.

    Each <image>/<video> placeholder in a human turn consumes one media
    entry (absolute path, resolved against ``base_path``) in order.
    Raises ValueError when placeholders outnumber media or media remain
    unconsumed.
    """

    def _as_list(value):
        # Media fields may be a single path string or a list of paths.
        if isinstance(value, str):
            return [value]
        return list(value or [])

    image_pool = [
        {"type": "image", "image": _make_abs_paths(base_path, name)}
        for name in _as_list(item.get("image"))
    ]
    video_pool = [
        {"type": "video", "video": _make_abs_paths(base_path, name)}
        for name in _as_list(item.get("video"))
    ]

    placeholder_re = re.compile(r"(<image>|<video>)")
    messages: List[Dict[str, Any]] = []
    for turn in item["conversations"]:
        text: str = turn["value"]
        if turn["from"] != "human":
            # Assistant messages contain only text.
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": text}]}
            )
            continue
        content = []
        # Split keeps the delimiters so each placeholder consumes one entry.
        for segment in placeholder_re.split(text):
            if segment == "<image>":
                if not image_pool:
                    raise ValueError(
                        "Number of <image> placeholders exceeds the number of provided images"
                    )
                content.append(image_pool.pop(0))
            elif segment == "<video>":
                if not video_pool:
                    raise ValueError(
                        "Number of <video> placeholders exceeds the number of provided videos"
                    )
                content.append(video_pool.pop(0))
            elif segment.strip():
                content.append({"type": "text", "text": segment.strip()})
        messages.append({"role": "user", "content": content})

    # Every provided media file must have been consumed by a placeholder.
    if image_pool:
        raise ValueError(
            f"{len(image_pool)} image(s) remain unused (not consumed by placeholders)"
        )
    if video_pool:
        raise ValueError(
            f"{len(video_pool)} video(s) remain unused (not consumed by placeholders)"
        )
    return messages
def preprocess_qwen_visual(
    sources: List[Dict[str, Any]],
    processor,
    add_gen_prompt: bool = False,
) -> Dict:
    """
    Preprocess one sample for Qwen-VL.
    Args:
        sources: List of one dict with keys: image, conversations, data_path.
        processor: Qwen-VL processor.
        add_gen_prompt: If True, add generation prompt so the model generates the
            assistant reply (use for inference). If False, full conversation is
            tokenized and labels are built for training.
    Returns:
        The processor output dict, with "input_ids" normalized to a 2-D
        tensor and, in training mode, a "labels" tensor where everything
        except the assistant answer spans is set to IGNORE_INDEX.
    Raises:
        ValueError: If ``sources`` does not contain exactly one item.
    """
    if len(sources) != 1:
        raise ValueError(f"Expected 1 source, got {len(sources)}")
    source = sources[0]
    base_path = Path(source.get("data_path", ""))
    messages = _build_messages(source, base_path)
    full_result = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=add_gen_prompt,
    )
    input_ids = full_result["input_ids"]
    # Some processor versions return a plain list; normalize to shape (1, L).
    if isinstance(input_ids, list):
        input_ids = torch.tensor(input_ids).unsqueeze(0)
    full_result["input_ids"] = input_ids
    # Labels are only needed for training; skip for generation
    if not add_gen_prompt:
        # Start fully masked; only assistant answer spans are unmasked below.
        labels = torch.full_like(input_ids, IGNORE_INDEX)
        input_ids_flat = input_ids[0].tolist()
        L = len(input_ids_flat)
        pos = 0
        while pos < L:
            # 77091 marks the start of an assistant turn — presumably the
            # "assistant" role token in the Qwen vocabulary.
            # TODO(review): confirm the hard-coded ids (77091, 151645)
            # against the actual tokenizer; they fail silently if the
            # vocabulary ever changes.
            if input_ids_flat[pos] == 77091:
                # Skip the role token plus the next token (presumably the
                # newline after the role header) to reach the first answer
                # token.
                ans_start = pos + 2
                ans_end = ans_start
                # Scan forward to the end-of-turn token (151645, presumably
                # <|im_end|>).
                while ans_end < L and input_ids_flat[ans_end] != 151645:
                    ans_end += 1
                if ans_end < L:
                    # Unmask the answer INCLUDING the end-of-turn token and
                    # the one after it (+2), so the model also learns to
                    # emit the turn terminator.
                    labels[0, ans_start : ans_end + 2] = input_ids[
                        0, ans_start : ans_end + 2
                    ]
                pos = ans_end
            pos += 1
        full_result["labels"] = labels
    return full_result
def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
    """Load the trace model and processor into the module cache.

    Reuses the cached model when ``model_id`` matches the one already
    loaded; otherwise the previous model is released first. Never raises:
    returns a (success, message) pair.
    """
    global _model_state
    have_model = _model_state["model"] is not None
    if have_model and _model_state["model_id"] == model_id:
        return True, f"Model already loaded: {model_id}"
    try:
        from transformers import AutoModelForImageTextToText, AutoProcessor

        if have_model:
            # Release the previous model/processor and reclaim GPU memory
            # before loading a different one.
            del _model_state["model"]
            del _model_state["processor"]
            _model_state["model"] = None
            _model_state["processor"] = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        logger.info(f"Loading model from {model_id}...")
        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            dtype=torch.bfloat16,
            device_map="auto",
        )
        processor = AutoProcessor.from_pretrained(model_id)
        _model_state["model"] = model
        _model_state["processor"] = processor
        _model_state["model_id"] = model_id
        return True, f"Model loaded: {model_id}"
    except Exception as e:
        logger.exception("Failed to load model")
        return False, f"Error loading model: {str(e)}"
def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
    """
    Run trace model inference on an image.
    Returns: (prediction_text, overlay_image_path, trace_points_text)
    """
    loaded, message = load_model(model_id)
    if not loaded:
        return message, None, ""
    model = _model_state["model"]
    processor = _model_state["processor"]
    if image_path is None or not os.path.exists(image_path):
        return "Please provide a valid image.", None, ""
    try:
        from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image

        abs_image_path = os.path.abspath(image_path)
        # Wrap the single image + prompt in the conversation format the
        # preprocessor expects.
        sample = {
            "id": "single_inference",
            "image": [abs_image_path],
            "conversations": [{"from": "human", "value": prompt}],
            "data_path": "",
        }
        processed = preprocess_qwen_visual([sample], processor, add_gen_prompt=True)
        # Move only the tensors the model needs onto its device.
        model_inputs = {"input_ids": processed["input_ids"].to(model.device)}
        for key in ("pixel_values", "image_grid_thw"):
            if key in processed:
                model_inputs[key] = processed[key].to(model.device)
        # Greedy generation of the trace prediction.
        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=512,
                do_sample=False,
            )
        # Strip the prompt tokens and decode only the generated tail.
        prompt_len = model_inputs["input_ids"].shape[1]
        prediction = processor.tokenizer.batch_decode(
            generated_ids[:, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        trajectory = extract_trajectory_from_text(prediction)
        overlay_path = None
        if trajectory:
            trace_points_text = format_trace_points(trajectory)
            # Temp file persists after close (delete=False) so the caller
            # can read the overlay image from disk.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
                overlay_path = f.name
            visualize_trajectory_on_image(
                trajectory=trajectory,
                image_path=abs_image_path,
                output_path=overlay_path,
                normalized=True,
            )
        else:
            trace_points_text = "No trajectory points extracted."
        return prediction, overlay_path, trace_points_text
    except Exception as e:
        logger.exception("Inference failed")
        return f"Error: {str(e)}", None, ""