Anthony Liang committed on
Commit
be80524
·
1 Parent(s): 130aa46
Files changed (3) hide show
  1. app.py +22 -37
  2. eval_server.py +6 -17
  3. trace_inference.py +115 -316
app.py CHANGED
@@ -19,12 +19,9 @@ import requests
19
 
20
  from trace_inference import (
21
  DEFAULT_MODEL_ID,
22
- TRACE_FORMAT,
23
- build_franka_prompt,
24
  build_prompt,
25
  preprocess_image_for_trace,
26
  run_inference,
27
- run_inference_qwenvl,
28
  )
29
  from trajectory_viz import visualize_trajectory_on_image
30
 
@@ -124,7 +121,7 @@ def run_inference_via_server(
124
  image_path: str,
125
  instruction: str,
126
  server_url: str,
127
- use_qwenvl: bool = False,
128
  ) -> Tuple[str, Optional[str]]:
129
  """Run inference via trace eval server. Returns (prediction, overlay_path)."""
130
  with open(image_path, "rb") as f:
@@ -135,7 +132,7 @@ def run_inference_via_server(
135
  json={
136
  "image_base64": image_b64,
137
  "instruction": instruction,
138
- "use_qwenvl": use_qwenvl,
139
  },
140
  timeout=120.0,
141
  headers=headers,
@@ -177,7 +174,7 @@ def run_inference_via_server(
177
 
178
  # --- Gradio UI ---
179
  try:
180
- demo = gr.Blocks(title="Trace Model Visualizer", theme=gr.themes.Soft())
181
  except TypeError:
182
  demo = gr.Blocks(title="Trace Model Visualizer")
183
 
@@ -303,17 +300,18 @@ with demo:
303
  lines=4,
304
  info="Enter a task description in natural language. The model predicts the trace for this instruction.",
305
  )
 
 
 
 
 
 
306
  gr.Markdown("### Local model (if no eval server selected)")
307
  model_id_input = gr.Textbox(
308
  label="Model ID",
309
  value=DEFAULT_MODEL_ID,
310
  info="Hugging Face model ID (auto-loads on first inference if no eval server selected)",
311
  )
312
- use_qwenvl_checkbox = gr.Checkbox(
313
- label="Use Franka / qwenvl inference",
314
- value=False,
315
- info="Uses preprocess_qwen_visual (qwenvl). Requires qwen-vl-finetune on PYTHONPATH.",
316
- )
317
  run_btn = gr.Button("Run Inference", variant="primary")
318
 
319
  with gr.Column(scale=1):
@@ -334,7 +332,7 @@ with demo:
334
  "Select an eval server from the sidebar (auto-connects), or run inference with local model."
335
  )
336
 
337
- def on_run_inference(image_path, instruction, model_id, server_url, use_qwenvl):
338
  if image_path is None:
339
  return (
340
  "",
@@ -343,48 +341,34 @@ with demo:
343
  "**Status:** Please upload an image.",
344
  )
345
 
 
346
  if server_url:
347
- prompt = (
348
- build_franka_prompt(instruction or "predict the trace")
349
- if use_qwenvl
350
- else build_prompt(instruction)
351
- )
352
  prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
353
  pred, overlay_path = run_inference_via_server(
354
- image_path, instruction, server_url, use_qwenvl
355
  )
356
- elif use_qwenvl:
357
- prompt = build_franka_prompt(instruction or "predict the trace")
358
- prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
359
- output_dict, pred, overlay_path, trace_text = run_inference_qwenvl(
360
- image_path, instruction or "predict the trace", model_id
361
- )
362
- if not output_dict and trace_text and "qwenvl package not found" in trace_text:
363
- pred = trace_text
364
- overlay_path = None
365
  else:
366
- prompt = build_prompt(instruction)
367
  prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
368
  pred, overlay_path, _ = run_inference(image_path, prompt, model_id)
369
 
370
  status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
371
  return prompt_md, pred, overlay_path, status
372
 
373
- def update_prompt_display(instruction: str, use_qwenvl: bool):
374
- if use_qwenvl:
375
- prompt = build_franka_prompt(instruction or "predict the trace")
376
- else:
377
- prompt = build_prompt(instruction)
378
  return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
379
 
380
  instruction_input.change(
381
  fn=update_prompt_display,
382
- inputs=[instruction_input, use_qwenvl_checkbox],
383
  outputs=[prompt_display],
384
  )
385
- use_qwenvl_checkbox.change(
386
  fn=update_prompt_display,
387
- inputs=[instruction_input, use_qwenvl_checkbox],
388
  outputs=[prompt_display],
389
  )
390
 
@@ -395,7 +379,7 @@ with demo:
395
  instruction_input,
396
  model_id_input,
397
  server_url_state,
398
- use_qwenvl_checkbox,
399
  ],
400
  outputs=[
401
  prompt_display,
@@ -413,6 +397,7 @@ def main():
413
  server_name="0.0.0.0",
414
  server_port=7860,
415
  share=False,
 
416
  )
417
 
418
 
 
19
 
20
  from trace_inference import (
21
  DEFAULT_MODEL_ID,
 
 
22
  build_prompt,
23
  preprocess_image_for_trace,
24
  run_inference,
 
25
  )
26
  from trajectory_viz import visualize_trajectory_on_image
27
 
 
121
  image_path: str,
122
  instruction: str,
123
  server_url: str,
124
+ is_oxe: bool = False,
125
  ) -> Tuple[str, Optional[str]]:
126
  """Run inference via trace eval server. Returns (prediction, overlay_path)."""
127
  with open(image_path, "rb") as f:
 
132
  json={
133
  "image_base64": image_b64,
134
  "instruction": instruction,
135
+ "is_oxe": is_oxe,
136
  },
137
  timeout=120.0,
138
  headers=headers,
 
174
 
175
  # --- Gradio UI ---
176
  try:
177
+ demo = gr.Blocks(title="Trace Model Visualizer")
178
  except TypeError:
179
  demo = gr.Blocks(title="Trace Model Visualizer")
180
 
 
300
  lines=4,
301
  info="Enter a task description in natural language. The model predicts the trace for this instruction.",
302
  )
303
+ prompt_format = gr.Radio(
304
+ choices=["LIBERO", "OXE"],
305
+ value="LIBERO",
306
+ label="Prompt Format",
307
+ info="Switch between LIBERO and OXE training formats.",
308
+ )
309
  gr.Markdown("### Local model (if no eval server selected)")
310
  model_id_input = gr.Textbox(
311
  label="Model ID",
312
  value=DEFAULT_MODEL_ID,
313
  info="Hugging Face model ID (auto-loads on first inference if no eval server selected)",
314
  )
 
 
 
 
 
315
  run_btn = gr.Button("Run Inference", variant="primary")
316
 
317
  with gr.Column(scale=1):
 
332
  "Select an eval server from the sidebar (auto-connects), or run inference with local model."
333
  )
334
 
335
+ def on_run_inference(image_path, instruction, model_id, server_url, prompt_mode):
336
  if image_path is None:
337
  return (
338
  "",
 
341
  "**Status:** Please upload an image.",
342
  )
343
 
344
+ is_oxe = (prompt_mode == "OXE")
345
  if server_url:
346
+ prompt = build_prompt(instruction, is_oxe=is_oxe)
 
 
 
 
347
  prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
348
  pred, overlay_path = run_inference_via_server(
349
+ image_path, instruction, server_url, is_oxe=is_oxe
350
  )
 
 
 
 
 
 
 
 
 
351
  else:
352
+ prompt = build_prompt(instruction, is_oxe=is_oxe)
353
  prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
354
  pred, overlay_path, _ = run_inference(image_path, prompt, model_id)
355
 
356
  status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
357
  return prompt_md, pred, overlay_path, status
358
 
359
+ def update_prompt_display(instruction: str, prompt_mode: str):
360
+ is_oxe = (prompt_mode == "OXE")
361
+ prompt = build_prompt(instruction, is_oxe=is_oxe)
 
 
362
  return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
363
 
364
  instruction_input.change(
365
  fn=update_prompt_display,
366
+ inputs=[instruction_input, prompt_format],
367
  outputs=[prompt_display],
368
  )
369
+ prompt_format.change(
370
  fn=update_prompt_display,
371
+ inputs=[instruction_input, prompt_format],
372
  outputs=[prompt_display],
373
  )
374
 
 
379
  instruction_input,
380
  model_id_input,
381
  server_url_state,
382
+ prompt_format,
383
  ],
384
  outputs=[
385
  prompt_display,
 
397
  server_name="0.0.0.0",
398
  server_port=7860,
399
  share=False,
400
+ theme=gr.themes.Soft(),
401
  )
402
 
403
 
eval_server.py CHANGED
@@ -33,7 +33,6 @@ from trace_inference import (
33
  build_prompt,
34
  load_model,
35
  run_inference,
36
- run_inference_qwenvl,
37
  )
38
  from trace_inference import _model_state as _trace_model_state
39
  from trajectory_viz import extract_trajectory_from_text
@@ -69,13 +68,12 @@ class TraceEvalServer:
69
  image_path: Optional[str] = None,
70
  image_base64: Optional[str] = None,
71
  instruction: str = "",
72
- use_qwenvl: bool = False,
73
  ) -> Dict[str, Any]:
74
  """
75
  Run inference on a single image.
76
 
77
  Provide either image_path (file path) or image_base64 (base64-encoded image).
78
- If use_qwenvl=True, uses run_inference_qwenvl (Franka-style, requires qwenvl).
79
  """
80
  if image_path is None and image_base64 is None:
81
  return {"error": "Provide image_path or image_base64"}
@@ -102,15 +100,8 @@ class TraceEvalServer:
102
  return {"error": f"Invalid image data: {e}"}
103
 
104
  try:
105
- if use_qwenvl:
106
- output_dict, prediction, _, trace_text = run_inference_qwenvl(
107
- image_path, instruction or "predict the trace", self.model_id
108
- )
109
- if not output_dict and trace_text and "qwenvl package not found" in trace_text:
110
- return {"error": trace_text}
111
- else:
112
- prompt = build_prompt(instruction)
113
- prediction, _, _ = run_inference(image_path, prompt, self.model_id)
114
  finally:
115
  if temp_file_path and os.path.exists(temp_file_path):
116
  try:
@@ -126,8 +117,6 @@ class TraceEvalServer:
126
  "prediction": prediction,
127
  "trajectory": trajectory,
128
  }
129
- if use_qwenvl and output_dict:
130
- result["output_dict"] = output_dict
131
  return result
132
 
133
  def predict_batch(
@@ -146,7 +135,7 @@ class TraceEvalServer:
146
  image_path=sample.get("image_path"),
147
  image_base64=sample.get("image_base64"),
148
  instruction=sample.get("instruction", ""),
149
- use_qwenvl=sample.get("use_qwenvl", False),
150
  )
151
  elapsed = time.time() - start
152
 
@@ -214,14 +203,14 @@ def create_app(
214
  - image_path: (optional) path to image file
215
  - image_base64: (optional) base64-encoded image
216
  - instruction: natural language task description
217
- - use_qwenvl: (optional) if true, use Franka/qwenvl inference (requires qwenvl)
218
  """
219
  body = await request.json()
220
  return trace_server.predict_one(
221
  image_path=body.get("image_path"),
222
  image_base64=body.get("image_base64"),
223
  instruction=body.get("instruction", ""),
224
- use_qwenvl=body.get("use_qwenvl", False),
225
  )
226
 
227
  @app.post("/predict_batch")
 
33
  build_prompt,
34
  load_model,
35
  run_inference,
 
36
  )
37
  from trace_inference import _model_state as _trace_model_state
38
  from trajectory_viz import extract_trajectory_from_text
 
68
  image_path: Optional[str] = None,
69
  image_base64: Optional[str] = None,
70
  instruction: str = "",
71
+ is_oxe: bool = False,
72
  ) -> Dict[str, Any]:
73
  """
74
  Run inference on a single image.
75
 
76
  Provide either image_path (file path) or image_base64 (base64-encoded image).
 
77
  """
78
  if image_path is None and image_base64 is None:
79
  return {"error": "Provide image_path or image_base64"}
 
100
  return {"error": f"Invalid image data: {e}"}
101
 
102
  try:
103
+ prompt = build_prompt(instruction, is_oxe=is_oxe)
104
+ prediction, _, _ = run_inference(image_path, prompt, self.model_id)
 
 
 
 
 
 
 
105
  finally:
106
  if temp_file_path and os.path.exists(temp_file_path):
107
  try:
 
117
  "prediction": prediction,
118
  "trajectory": trajectory,
119
  }
 
 
120
  return result
121
 
122
  def predict_batch(
 
135
  image_path=sample.get("image_path"),
136
  image_base64=sample.get("image_base64"),
137
  instruction=sample.get("instruction", ""),
138
+ is_oxe=sample.get("is_oxe", False),
139
  )
140
  elapsed = time.time() - start
141
 
 
203
  - image_path: (optional) path to image file
204
  - image_base64: (optional) base64-encoded image
205
  - instruction: natural language task description
206
+ - is_oxe: (optional) if true, use OXE prompt format
207
  """
208
  body = await request.json()
209
  return trace_server.predict_one(
210
  image_path=body.get("image_path"),
211
  image_base64=body.get("image_base64"),
212
  instruction=body.get("instruction", ""),
213
+ is_oxe=body.get("is_oxe", False),
214
  )
215
 
216
  @app.post("/predict_batch")
trace_inference.py CHANGED
@@ -9,22 +9,16 @@ Heavy imports are done lazily inside load_model and run_inference.
9
  import logging
10
  import os
11
  import tempfile
12
- from typing import List, Optional, Tuple
13
  import re
 
14
  from pathlib import Path
15
- import torch
16
- from typing import Dict, Any
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
- # Constants (no heavy deps)
21
  DEFAULT_MODEL_ID = "mihirgrao/trace-model"
22
- TRACE_FORMAT = (
23
- "Predict the trajectory or trace in this image. "
24
- "Output the coordinates as a list of [x, y] pairs, e.g. [[0.1, 0.2], [0.3, 0.4], ...]. "
25
- "Use normalized coordinates between 0 and 1."
26
- )
27
- PREPROCESS_SIZE = (128, 128)
28
 
29
  # Global model state
30
  _model_state = {
@@ -34,11 +28,12 @@ _model_state = {
34
  }
35
 
36
 
37
- def build_prompt(instruction: str = "") -> str:
38
- """Build the full prompt from task instruction + trace format."""
39
- if instruction.strip():
40
- return f"Task: {instruction.strip()}\n\n{TRACE_FORMAT}"
41
- return TRACE_FORMAT
 
42
 
43
 
44
  def format_trace_points(trajectories: List) -> str:
@@ -55,7 +50,7 @@ def format_trace_points(trajectories: List) -> str:
55
  return "\n".join(lines)
56
 
57
 
58
- def center_crop_resize(image, size: Tuple[int, int] = PREPROCESS_SIZE):
59
  """Center crop to square then resize. Requires PIL Image."""
60
  from PIL import Image
61
 
@@ -64,7 +59,8 @@ def center_crop_resize(image, size: Tuple[int, int] = PREPROCESS_SIZE):
64
  left = (w - min_dim) // 2
65
  top = (h - min_dim) // 2
66
  cropped = image.crop((left, top, left + min_dim, top + min_dim))
67
- return cropped.resize(size, Image.Resampling.LANCZOS)
 
68
 
69
 
70
  def preprocess_image_for_trace(image_path: str) -> Tuple:
@@ -72,195 +68,16 @@ def preprocess_image_for_trace(image_path: str) -> Tuple:
72
  from PIL import Image
73
 
74
  img = Image.open(image_path).convert("RGB")
75
- img = center_crop_resize(img, PREPROCESS_SIZE)
76
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
77
  img.save(tmp.name)
78
  return img, tmp.name
79
 
80
 
81
- def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
82
- """Load the trace model and processor. Returns (success, message)."""
83
- global _model_state
84
-
85
- if _model_state["model"] is not None and _model_state["model_id"] == model_id:
86
- return True, f"Model already loaded: {model_id}"
87
-
88
- try:
89
- import torch
90
- from transformers import AutoModelForImageTextToText, AutoProcessor
91
-
92
- if _model_state["model"] is not None:
93
- del _model_state["model"]
94
- del _model_state["processor"]
95
- _model_state["model"] = None
96
- _model_state["processor"] = None
97
- if torch.cuda.is_available():
98
- torch.cuda.empty_cache()
99
-
100
- load_kwargs = {
101
- "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
102
- "device_map": "auto" if torch.cuda.is_available() else None,
103
- }
104
- try:
105
- if torch.cuda.is_available():
106
- load_kwargs["attn_implementation"] = "flash_attention_2"
107
- model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
108
- except (ValueError, ImportError):
109
- load_kwargs.pop("attn_implementation", None)
110
- model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
111
- processor = AutoProcessor.from_pretrained(model_id)
112
-
113
- _model_state["model"] = model
114
- _model_state["processor"] = processor
115
- _model_state["model_id"] = model_id
116
-
117
- return True, f"Model loaded: {model_id}"
118
- except Exception as e:
119
- logger.exception("Failed to load model")
120
- return False, f"Error loading model: {str(e)}"
121
-
122
-
123
- def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
124
- """
125
- Run trace model inference on an image.
126
- Returns: (prediction_text, overlay_image_path, trace_points_text)
127
- """
128
- success, msg = load_model(model_id)
129
- if not success:
130
- return msg, None, ""
131
-
132
- model = _model_state["model"]
133
- processor = _model_state["processor"]
134
-
135
- if image_path is None or not os.path.exists(image_path):
136
- return "Please provide a valid image.", None, ""
137
-
138
- try:
139
- from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
140
-
141
- try:
142
- from qwen_vl_utils import process_vision_info
143
- except ImportError:
144
- process_vision_info = None
145
-
146
- preprocessed_path = None
147
- try:
148
- _, preprocessed_path = preprocess_image_for_trace(image_path)
149
- image_uri = f"file://{os.path.abspath(preprocessed_path)}"
150
-
151
- messages = [
152
- {
153
- "role": "user",
154
- "content": [
155
- {"type": "image", "image": image_uri},
156
- {"type": "text", "text": prompt},
157
- ],
158
- }
159
- ]
160
-
161
- text = processor.apply_chat_template(
162
- messages, tokenize=False, add_generation_prompt=True
163
- )
164
-
165
- if process_vision_info is not None:
166
- process_kwargs = {"return_video_kwargs": True, "return_video_metadata": True}
167
- if hasattr(processor, "image_processor") and hasattr(
168
- processor.image_processor, "patch_size"
169
- ):
170
- process_kwargs["image_patch_size"] = processor.image_processor.patch_size
171
- image_inputs, video_inputs, video_kwargs = process_vision_info(
172
- messages, **process_kwargs
173
- )
174
- else:
175
- from PIL import Image
176
-
177
- pil_image = Image.open(image_path).convert("RGB")
178
- image_inputs = [pil_image]
179
- video_inputs = None
180
- video_kwargs = {}
181
-
182
- processor_kwargs = {
183
- "text": [text],
184
- "images": image_inputs,
185
- "padding": True,
186
- "return_tensors": "pt",
187
- "do_resize": False,
188
- }
189
- if video_inputs is not None and len(video_inputs) > 0:
190
- if isinstance(video_inputs[0], tuple):
191
- videos, video_metadatas = zip(*video_inputs)
192
- processor_kwargs["videos"] = list(videos)
193
- processor_kwargs["video_metadata"] = list(video_metadatas)
194
- else:
195
- processor_kwargs["videos"] = video_inputs
196
- if video_kwargs:
197
- processor_kwargs.update(video_kwargs)
198
-
199
- import torch
200
-
201
- inputs = processor(**processor_kwargs)
202
- inputs = {k: v.to(model.device) for k, v in inputs.items() if hasattr(v, "to")}
203
-
204
- with torch.no_grad():
205
- generated_ids = model.generate(
206
- **inputs, max_new_tokens=1024, do_sample=False
207
- )
208
-
209
- input_ids = inputs["input_ids"]
210
- generated_ids_trimmed = [
211
- out[len(inp) :] for inp, out in zip(input_ids, generated_ids)
212
- ]
213
- prediction = processor.batch_decode(
214
- generated_ids_trimmed,
215
- skip_special_tokens=True,
216
- clean_up_tokenization_spaces=False,
217
- )[0]
218
-
219
- trajectories = extract_trajectory_from_text(prediction)
220
- trace_points_text = format_trace_points(trajectories)
221
-
222
- overlay_path = None
223
- if trajectories and len(trajectories) >= 2:
224
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
225
- overlay_path = f.name
226
- img_arr = visualize_trajectory_on_image(
227
- trajectory=trajectories,
228
- image_path=preprocessed_path,
229
- output_path=overlay_path,
230
- normalized=True,
231
- )
232
- if img_arr is None:
233
- visualize_trajectory_on_image(
234
- trajectory=trajectories,
235
- image_path=preprocessed_path,
236
- output_path=overlay_path,
237
- normalized=False,
238
- )
239
-
240
- return prediction, overlay_path, trace_points_text
241
-
242
- finally:
243
- if preprocessed_path and os.path.exists(preprocessed_path):
244
- try:
245
- os.unlink(preprocessed_path)
246
- except Exception:
247
- pass
248
-
249
- except Exception as e:
250
- logger.exception("Inference failed")
251
- return f"Error: {str(e)}", None, ""
252
-
253
-
254
- def build_franka_prompt(task: str) -> str:
255
- """Build the Franka-style prompt for trace prediction."""
256
- return (
257
- '<image>\nYou are a Franka robot using the joint control. '
258
- f'The task is "{task}". Can you predict the trace of the end effector?'
259
- )
260
-
261
  def _make_abs_paths(base: Path, files: str) -> str:
262
  return f"{(base / files).resolve()}"
263
 
 
264
  def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]:
265
  # Extract and normalize images and videos
266
  images = item.get("image") or []
@@ -322,7 +139,6 @@ def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any
322
 
323
  return messages
324
 
325
- IGNORE_INDEX = -100
326
 
327
  def preprocess_qwen_visual(
328
  sources,
@@ -382,143 +198,126 @@ def preprocess_qwen_visual(
382
 
383
  return full_result
384
 
385
- def run_inference_qwenvl(
386
- image_path: str,
387
- task: str,
388
- model_id: str = DEFAULT_MODEL_ID,
389
- data_path: Optional[str] = None,
390
- ) -> Tuple[dict, str, Optional[str], str]:
391
- """
392
- Run trace inference using preprocess_qwen_visual (qwenvl data processor).
393
- Uses Franka-style prompt and output format.
394
 
395
- Args:
396
- image_path: Full path to the image file.
397
- task: Task description (e.g. "open pot, then pick bread and place inside pot").
398
- model_id: Model to load.
399
- data_path: Base path for image resolution. Defaults to dirname of image_path.
400
-
401
- Returns:
402
- (output_dict, prediction_text, overlay_path, trace_points_text)
403
- output_dict has format: {"id", "image", "conversations": [human_msg, gpt_msg]}
404
- """
405
- success, msg = load_model(model_id)
406
- if not success:
407
- return {}, msg, None, ""
408
 
409
- model = _model_state["model"]
410
- processor = _model_state["processor"]
411
 
412
- if not image_path or not os.path.exists(image_path):
413
- return {}, "Please provide a valid image.", None, ""
414
 
415
- data_path = data_path or os.path.dirname(os.path.abspath(image_path))
416
- image_rel = os.path.basename(image_path)
417
- sample_id = hash(image_path) % 100000 # deterministic id for display
 
 
 
 
418
 
419
- prompt = build_franka_prompt(task)
420
- inference_sample = {
421
- "id": sample_id,
422
- "image": [image_rel],
423
- "conversations": [{"from": "human", "value": prompt}],
424
- "data_path": data_path,
425
- }
426
 
427
- print("prompt")
428
- print(prompt)
429
- print("image_path")
430
- print(image_rel)
431
- print("data_path")
432
- print(data_path)
433
 
434
- try:
435
- import torch
436
- from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
437
 
438
- processed_data = preprocess_qwen_visual(
439
- [inference_sample], processor, add_gen_prompt=True
440
- )
 
441
 
442
- print("inference_sample")
443
- print(inference_sample)
444
- print("processed_data")
445
- print(processed_data)
446
 
447
- input_ids = processed_data["input_ids"].to(model.device)
448
- pixel_values = (
449
- processed_data["pixel_values"].to(model.device)
450
- if "pixel_values" in processed_data
451
- else None
452
- )
453
- image_grid_thw = (
454
- processed_data["image_grid_thw"].to(model.device)
455
- if "image_grid_thw" in processed_data
456
- else None
457
- )
458
 
459
- inputs = {"input_ids": input_ids}
460
- if pixel_values is not None:
461
- inputs["pixel_values"] = pixel_values
462
- if image_grid_thw is not None:
463
- inputs["image_grid_thw"] = image_grid_thw
464
 
465
- with torch.no_grad():
466
- generated_ids = model.generate(**inputs, max_new_tokens=512)
467
- generated_ids_trimmed = [
468
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
469
- ]
470
- prediction = processor.tokenizer.decode(
471
- generated_ids_trimmed[0], skip_special_tokens=True
472
- )
473
 
474
- print("prediction")
475
- print(prediction)
476
 
477
- # Format output like the example: "Trace: [[x,y], [x,y], ...]"
478
- trajectories = extract_trajectory_from_text(prediction)
479
- trace_value = f"Trace: {trajectories}" if trajectories else f"Trace: {prediction}"
480
- output_dict = {
481
- "id": sample_id,
482
- "image": [image_rel],
483
  "conversations": [
484
- {"from": "human", "value": prompt},
485
- {"from": "gpt", "value": trace_value},
 
 
486
  ],
 
487
  }
488
 
489
- trace_points_text = format_trace_points(trajectories)
 
490
 
491
- print("trace_points_text")
492
- print(trace_points_text)
 
 
 
 
493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  overlay_path = None
495
- if trajectories and len(trajectories) >= 2:
496
- _, preprocessed_path = preprocess_image_for_trace(image_path)
497
- try:
498
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
499
- overlay_path = f.name
500
- img_arr = visualize_trajectory_on_image(
501
- trajectory=trajectories,
502
- image_path=preprocessed_path,
503
- output_path=overlay_path,
504
- normalized=True,
505
- )
506
- if img_arr is None:
507
- visualize_trajectory_on_image(
508
- trajectory=trajectories,
509
- image_path=preprocessed_path,
510
- output_path=overlay_path,
511
- normalized=False,
512
- )
513
- finally:
514
- if preprocessed_path and os.path.exists(preprocessed_path):
515
- try:
516
- os.unlink(preprocessed_path)
517
- except Exception:
518
- pass
519
-
520
- return output_dict, prediction, overlay_path, trace_points_text
521
 
522
  except Exception as e:
523
- logger.exception("Inference failed (qwenvl)")
524
- return {}, f"Error: {str(e)}", None, ""
 
9
  import logging
10
  import os
11
  import tempfile
12
+ import torch
13
  import re
14
+ from typing import List, Optional, Tuple, Dict, Any
15
  from pathlib import Path
 
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
+ # Constants
20
  DEFAULT_MODEL_ID = "mihirgrao/trace-model"
21
+ IGNORE_INDEX = -100
 
 
 
 
 
22
 
23
  # Global model state
24
  _model_state = {
 
28
  }
29
 
30
 
31
+ def build_prompt(instruction: str = "", is_oxe: bool = False) -> str:
32
+ """Build the full prompt from task instruction."""
33
+ task = instruction.strip() or "predict the trace"
34
+ if is_oxe:
35
+ return f"<image>\nYou are a Franka robot using the joint control. The task is \"{task}\". Can you predict the trace of the end effector?"
36
+ return f"You are a robot. Your task is: \"{task}\". <image> Can you predict the trace of the end effector in this image to complete the task?"
37
 
38
 
39
  def format_trace_points(trajectories: List) -> str:
 
50
  return "\n".join(lines)
51
 
52
 
53
+ def center_crop_resize(image, size: Tuple[int, int] = (128, 128)):
54
  """Center crop to square then resize. Requires PIL Image."""
55
  from PIL import Image
56
 
 
59
  left = (w - min_dim) // 2
60
  top = (h - min_dim) // 2
61
  cropped = image.crop((left, top, left + min_dim, top + min_dim))
62
+ # return cropped.resize(size, Image.Resampling.LANCZOS)
63
+ return cropped
64
 
65
 
66
  def preprocess_image_for_trace(image_path: str) -> Tuple:
 
68
  from PIL import Image
69
 
70
  img = Image.open(image_path).convert("RGB")
71
+ img = center_crop_resize(img, (128, 128))
72
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
73
  img.save(tmp.name)
74
  return img, tmp.name
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def _make_abs_paths(base: Path, files: str) -> str:
78
  return f"{(base / files).resolve()}"
79
 
80
+
81
  def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]:
82
  # Extract and normalize images and videos
83
  images = item.get("image") or []
 
139
 
140
  return messages
141
 
 
142
 
143
  def preprocess_qwen_visual(
144
  sources,
 
198
 
199
  return full_result
200
 
 
 
 
 
 
 
 
 
 
201
 
202
+ def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
203
+ """Load the trace model and processor. Returns (success, message)."""
204
+ global _model_state
 
 
 
 
 
 
 
 
 
 
205
 
206
+ if _model_state["model"] is not None and _model_state["model_id"] == model_id:
207
+ return True, f"Model already loaded: {model_id}"
208
 
209
+ try:
210
+ from transformers import AutoModelForImageTextToText, AutoProcessor
211
 
212
+ if _model_state["model"] is not None:
213
+ del _model_state["model"]
214
+ del _model_state["processor"]
215
+ _model_state["model"] = None
216
+ _model_state["processor"] = None
217
+ if torch.cuda.is_available():
218
+ torch.cuda.empty_cache()
219
 
220
+ logger.info(f"Loading model from {model_id}...")
221
+ load_kwargs = {
222
+ "dtype": torch.bfloat16,
223
+ "device_map": "auto",
224
+ }
 
 
225
 
226
+ model = AutoModelForImageTextToText.from_pretrained(
227
+ model_id,
228
+ **load_kwargs,
229
+ )
230
+ processor = AutoProcessor.from_pretrained(model_id)
 
231
 
232
+ _model_state["model"] = model
233
+ _model_state["processor"] = processor
234
+ _model_state["model_id"] = model_id
235
 
236
+ return True, f"Model loaded: {model_id}"
237
+ except Exception as e:
238
+ logger.exception("Failed to load model")
239
+ return False, f"Error loading model: {str(e)}"
240
 
 
 
 
 
241
 
242
+ def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
243
+ """
244
+ Run trace model inference on an image.
245
+ Returns: (prediction_text, overlay_image_path, trace_points_text)
246
+ """
247
+ success, msg = load_model(model_id)
248
+ if not success:
249
+ return msg, None, ""
 
 
 
250
 
251
+ model = _model_state["model"]
252
+ processor = _model_state["processor"]
 
 
 
253
 
254
+ if image_path is None or not os.path.exists(image_path):
255
+ return "Please provide a valid image.", None, ""
 
 
 
 
 
 
256
 
257
+ try:
258
+ from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
259
 
260
+ abs_image_path = os.path.abspath(image_path)
261
+ raw_item = {
262
+ "id": "single_inference",
263
+ "image": [abs_image_path],
 
 
264
  "conversations": [
265
+ {
266
+ "from": "human",
267
+ "value": prompt
268
+ }
269
  ],
270
+ "data_path": ""
271
  }
272
 
273
+ # Preprocessing using internal method
274
+ processed = preprocess_qwen_visual([raw_item], processor, add_gen_prompt=True)
275
 
276
+ # Prepare inputs - passing only what's necessary as per the new method
277
+ inputs = {"input_ids": processed["input_ids"].to(model.device)}
278
+ if "pixel_values" in processed:
279
+ inputs["pixel_values"] = processed["pixel_values"].to(model.device)
280
+ if "image_grid_thw" in processed:
281
+ inputs["image_grid_thw"] = processed["image_grid_thw"].to(model.device)
282
 
283
+ # Generate prediction
284
+ with torch.no_grad():
285
+ generated_ids = model.generate(
286
+ **inputs,
287
+ max_new_tokens=512,
288
+ do_sample=False,
289
+ )
290
+
291
+ # Trim prompt tokens
292
+ trimmed = generated_ids[:, inputs["input_ids"].shape[1]:]
293
+
294
+ # Decode
295
+ prediction = processor.tokenizer.batch_decode(
296
+ trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
297
+ )[0]
298
+
299
+ trajectory = extract_trajectory_from_text(prediction)
300
+
301
+ trace_points_text = ""
302
  overlay_path = None
303
+
304
+ if trajectory:
305
+ trace_points_text = format_trace_points(trajectory)
306
+
307
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
308
+ overlay_path = f.name
309
+
310
+ visualize_trajectory_on_image(
311
+ trajectory=trajectory,
312
+ image_path=abs_image_path,
313
+ output_path=overlay_path,
314
+ normalized=True
315
+ )
316
+ else:
317
+ trace_points_text = "No trajectory points extracted."
318
+
319
+ return prediction, overlay_path, trace_points_text
 
 
 
 
 
 
 
 
 
320
 
321
  except Exception as e:
322
+ logger.exception("Inference failed")
323
+ return f"Error: {str(e)}", None, ""