Anthony Liang commited on
Commit
521dbac
·
1 Parent(s): cde89ae
Files changed (2) hide show
  1. app.py +43 -35
  2. trace_inference.py +142 -0
app.py CHANGED
@@ -20,11 +20,11 @@ import requests
20
  from trace_inference import (
21
  DEFAULT_MODEL_ID,
22
  TRACE_FORMAT,
 
23
  build_prompt,
24
- format_trace_points,
25
- load_model,
26
  preprocess_image_for_trace,
27
  run_inference,
 
28
  )
29
  from trajectory_viz import visualize_trajectory_on_image
30
 
@@ -124,8 +124,8 @@ def run_inference_via_server(
124
  image_path: str,
125
  instruction: str,
126
  server_url: str,
127
- ) -> Tuple[str, Optional[str], str]:
128
- """Run inference via trace eval server. Returns (prediction, overlay_path, trace_points_text)."""
129
  with open(image_path, "rb") as f:
130
  image_b64 = base64.b64encode(f.read()).decode("utf-8")
131
  headers = {"ngrok-skip-browser-warning": "true"} if "ngrok" in server_url else {}
@@ -138,10 +138,9 @@ def run_inference_via_server(
138
  r.raise_for_status()
139
  data = r.json()
140
  if "error" in data:
141
- return data["error"], None, ""
142
  prediction = data.get("prediction", "")
143
  trajectory = data.get("trajectory", [])
144
- trace_points_text = format_trace_points(trajectory)
145
 
146
  overlay_path = None
147
  if trajectory and len(trajectory) >= 2:
@@ -168,7 +167,7 @@ def run_inference_via_server(
168
  os.unlink(preprocessed_path)
169
  except Exception:
170
  pass
171
- return prediction, overlay_path, trace_points_text
172
 
173
 
174
  # --- Gradio UI ---
@@ -263,7 +262,7 @@ with demo:
263
  interactive=True,
264
  info="Discover trace eval servers on ports 8000-8010",
265
  )
266
- server_status = gr.Markdown("Click 'Discover Eval Servers' or use local model below")
267
  gr.Markdown("---")
268
  gr.Markdown("### 📋 Model Information")
269
  model_info_display = gr.Markdown("")
@@ -303,9 +302,13 @@ with demo:
303
  model_id_input = gr.Textbox(
304
  label="Model ID",
305
  value=DEFAULT_MODEL_ID,
306
- info="Hugging Face model ID (used when no eval server is selected)",
 
 
 
 
 
307
  )
308
- load_model_btn = gr.Button("Load Model", variant="secondary")
309
  run_btn = gr.Button("Run Inference", variant="primary")
310
 
311
  with gr.Column(scale=1):
@@ -321,53 +324,58 @@ with demo:
321
  label="Model Prediction (raw)",
322
  lines=6,
323
  )
324
- trace_points_output = gr.Markdown(
325
- label="Extracted Trace Points",
326
- )
327
 
328
  status_md = gr.Markdown(
329
- "Select an eval server from the sidebar, or load a local model and run inference."
330
  )
331
 
332
- def on_load_model(model_id: str):
333
- _, msg = load_model(model_id)
334
- return f"**Status:** {msg}"
335
-
336
- def on_run_inference(image_path, instruction, model_id, server_url):
337
  if image_path is None:
338
  return (
339
  "",
340
  "Please upload an image first.",
341
  None,
342
- "",
343
  "**Status:** Please upload an image.",
344
  )
345
- prompt = build_prompt(instruction)
346
- prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
347
 
348
  if server_url:
349
- pred, overlay_path, trace_text = run_inference_via_server(
 
 
350
  image_path, instruction, server_url
351
  )
 
 
 
 
 
 
 
 
 
352
  else:
353
- pred, overlay_path, trace_text = run_inference(image_path, prompt, model_id)
 
 
354
 
355
  status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
356
- return prompt_md, pred, overlay_path, trace_text, status
357
 
358
- load_model_btn.click(
359
- fn=on_load_model,
360
- inputs=[model_id_input],
361
- outputs=[status_md],
362
- )
363
-
364
- def update_prompt_display(instruction: str):
365
- prompt = build_prompt(instruction)
366
  return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
367
 
368
  instruction_input.change(
369
  fn=update_prompt_display,
370
- inputs=[instruction_input],
 
 
 
 
 
371
  outputs=[prompt_display],
372
  )
373
 
@@ -378,12 +386,12 @@ with demo:
378
  instruction_input,
379
  model_id_input,
380
  server_url_state,
 
381
  ],
382
  outputs=[
383
  prompt_display,
384
  prediction_output,
385
  overlay_output,
386
- trace_points_output,
387
  status_md,
388
  ],
389
  api_name="run_inference",
 
20
  from trace_inference import (
21
  DEFAULT_MODEL_ID,
22
  TRACE_FORMAT,
23
+ build_franka_prompt,
24
  build_prompt,
 
 
25
  preprocess_image_for_trace,
26
  run_inference,
27
+ run_inference_qwenvl,
28
  )
29
  from trajectory_viz import visualize_trajectory_on_image
30
 
 
124
  image_path: str,
125
  instruction: str,
126
  server_url: str,
127
+ ) -> Tuple[str, Optional[str]]:
128
+ """Run inference via trace eval server. Returns (prediction, overlay_path)."""
129
  with open(image_path, "rb") as f:
130
  image_b64 = base64.b64encode(f.read()).decode("utf-8")
131
  headers = {"ngrok-skip-browser-warning": "true"} if "ngrok" in server_url else {}
 
138
  r.raise_for_status()
139
  data = r.json()
140
  if "error" in data:
141
+ return data["error"], None
142
  prediction = data.get("prediction", "")
143
  trajectory = data.get("trajectory", [])
 
144
 
145
  overlay_path = None
146
  if trajectory and len(trajectory) >= 2:
 
167
  os.unlink(preprocessed_path)
168
  except Exception:
169
  pass
170
+ return prediction, overlay_path
171
 
172
 
173
  # --- Gradio UI ---
 
262
  interactive=True,
263
  info="Discover trace eval servers on ports 8000-8010",
264
  )
265
+ server_status = gr.Markdown("Select an eval server below (auto-connects on selection)")
266
  gr.Markdown("---")
267
  gr.Markdown("### 📋 Model Information")
268
  model_info_display = gr.Markdown("")
 
302
  model_id_input = gr.Textbox(
303
  label="Model ID",
304
  value=DEFAULT_MODEL_ID,
305
+ info="Hugging Face model ID (auto-loads on first inference if no eval server selected)",
306
+ )
307
+ use_qwenvl_checkbox = gr.Checkbox(
308
+ label="Use Franka / qwenvl inference",
309
+ value=False,
310
+ info="Uses preprocess_qwen_visual (qwenvl). Requires qwen-vl-finetune on PYTHONPATH.",
311
  )
 
312
  run_btn = gr.Button("Run Inference", variant="primary")
313
 
314
  with gr.Column(scale=1):
 
324
  label="Model Prediction (raw)",
325
  lines=6,
326
  )
 
 
 
327
 
328
  status_md = gr.Markdown(
329
+ "Select an eval server from the sidebar (auto-connects), or run inference with local model."
330
  )
331
 
332
+ def on_run_inference(image_path, instruction, model_id, server_url, use_qwenvl):
 
 
 
 
333
  if image_path is None:
334
  return (
335
  "",
336
  "Please upload an image first.",
337
  None,
 
338
  "**Status:** Please upload an image.",
339
  )
 
 
340
 
341
  if server_url:
342
+ prompt = build_prompt(instruction)
343
+ prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
344
+ pred, overlay_path = run_inference_via_server(
345
  image_path, instruction, server_url
346
  )
347
+ elif use_qwenvl:
348
+ prompt = build_franka_prompt(instruction or "predict the trace")
349
+ prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
350
+ output_dict, pred, overlay_path, trace_text = run_inference_qwenvl(
351
+ image_path, instruction or "predict the trace", model_id
352
+ )
353
+ if not output_dict and trace_text and "qwenvl package not found" in trace_text:
354
+ pred = trace_text
355
+ overlay_path = None
356
  else:
357
+ prompt = build_prompt(instruction)
358
+ prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
359
+ pred, overlay_path, _ = run_inference(image_path, prompt, model_id)
360
 
361
  status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
362
+ return prompt_md, pred, overlay_path, status
363
 
364
+ def update_prompt_display(instruction: str, use_qwenvl: bool):
365
+ if use_qwenvl:
366
+ prompt = build_franka_prompt(instruction or "predict the trace")
367
+ else:
368
+ prompt = build_prompt(instruction)
 
 
 
369
  return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
370
 
371
  instruction_input.change(
372
  fn=update_prompt_display,
373
+ inputs=[instruction_input, use_qwenvl_checkbox],
374
+ outputs=[prompt_display],
375
+ )
376
+ use_qwenvl_checkbox.change(
377
+ fn=update_prompt_display,
378
+ inputs=[instruction_input, use_qwenvl_checkbox],
379
  outputs=[prompt_display],
380
  )
381
 
 
386
  instruction_input,
387
  model_id_input,
388
  server_url_state,
389
+ use_qwenvl_checkbox,
390
  ],
391
  outputs=[
392
  prompt_display,
393
  prediction_output,
394
  overlay_output,
 
395
  status_md,
396
  ],
397
  api_name="run_inference",
trace_inference.py CHANGED
@@ -245,3 +245,145 @@ def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Opt
245
  except Exception as e:
246
  logger.exception("Inference failed")
247
  return f"Error: {str(e)}", None, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  except Exception as e:
246
  logger.exception("Inference failed")
247
  return f"Error: {str(e)}", None, ""
248
+
249
+
250
+ def build_franka_prompt(task: str) -> str:
251
+ """Build the Franka-style prompt for trace prediction."""
252
+ return (
253
+ '<image>\nYou are a Franka robot using the joint control. '
254
+ f'The task is "{task}". Can you predict the trace of the end effector?'
255
+ )
256
+
257
+
258
+ def run_inference_qwenvl(
259
+ image_path: str,
260
+ task: str,
261
+ model_id: str = DEFAULT_MODEL_ID,
262
+ data_path: Optional[str] = None,
263
+ ) -> Tuple[dict, str, Optional[str], str]:
264
+ """
265
+ Run trace inference using preprocess_qwen_visual (qwenvl data processor).
266
+ Uses Franka-style prompt and output format.
267
+
268
+ Args:
269
+ image_path: Full path to the image file.
270
+ task: Task description (e.g. "open pot, then pick bread and place inside pot").
271
+ model_id: Model to load.
272
+ data_path: Base path for image resolution. Defaults to dirname of image_path.
273
+
274
+ Returns:
275
+ (output_dict, prediction_text, overlay_path, trace_points_text)
276
+ output_dict has format: {"id", "image", "conversations": [human_msg, gpt_msg]}
277
+ """
278
+ try:
279
+ from qwenvl.data.data_processor import preprocess_qwen_visual
280
+ except ImportError as e:
281
+ return (
282
+ {},
283
+ "",
284
+ None,
285
+ f"qwenvl package not found: {e}. Install qwen-vl-finetune or add to PYTHONPATH.",
286
+ )
287
+
288
+ success, msg = load_model(model_id)
289
+ if not success:
290
+ return {}, msg, None, ""
291
+
292
+ model = _model_state["model"]
293
+ processor = _model_state["processor"]
294
+
295
+ if not image_path or not os.path.exists(image_path):
296
+ return {}, "Please provide a valid image.", None, ""
297
+
298
+ data_path = data_path or os.path.dirname(os.path.abspath(image_path))
299
+ image_rel = os.path.basename(image_path)
300
+ sample_id = hash(image_path) % 100000 # deterministic id for display
301
+
302
+ prompt = build_franka_prompt(task)
303
+ inference_sample = {
304
+ "id": sample_id,
305
+ "image": [image_rel],
306
+ "conversations": [{"from": "human", "value": prompt}],
307
+ "data_path": data_path,
308
+ }
309
+
310
+ try:
311
+ import torch
312
+ from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
313
+
314
+ processed_data = preprocess_qwen_visual(
315
+ [inference_sample], processor, add_gen_prompt=True
316
+ )
317
+
318
+ input_ids = processed_data["input_ids"].to(model.device)
319
+ pixel_values = (
320
+ processed_data["pixel_values"].to(model.device)
321
+ if "pixel_values" in processed_data
322
+ else None
323
+ )
324
+ image_grid_thw = (
325
+ processed_data["image_grid_thw"].to(model.device)
326
+ if "image_grid_thw" in processed_data
327
+ else None
328
+ )
329
+
330
+ inputs = {"input_ids": input_ids}
331
+ if pixel_values is not None:
332
+ inputs["pixel_values"] = pixel_values
333
+ if image_grid_thw is not None:
334
+ inputs["image_grid_thw"] = image_grid_thw
335
+
336
+ with torch.no_grad():
337
+ generated_ids = model.generate(**inputs, max_new_tokens=1024)
338
+ generated_ids_trimmed = [
339
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
340
+ ]
341
+ prediction = processor.tokenizer.decode(
342
+ generated_ids_trimmed[0], skip_special_tokens=True
343
+ )
344
+
345
+ # Format output like the example: "Trace: [[x,y], [x,y], ...]"
346
+ trajectories = extract_trajectory_from_text(prediction)
347
+ trace_value = f"Trace: {trajectories}" if trajectories else f"Trace: {prediction}"
348
+ output_dict = {
349
+ "id": sample_id,
350
+ "image": [image_rel],
351
+ "conversations": [
352
+ {"from": "human", "value": prompt},
353
+ {"from": "gpt", "value": trace_value},
354
+ ],
355
+ }
356
+
357
+ trace_points_text = format_trace_points(trajectories)
358
+
359
+ overlay_path = None
360
+ if trajectories and len(trajectories) >= 2:
361
+ _, preprocessed_path = preprocess_image_for_trace(image_path)
362
+ try:
363
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
364
+ overlay_path = f.name
365
+ img_arr = visualize_trajectory_on_image(
366
+ trajectory=trajectories,
367
+ image_path=preprocessed_path,
368
+ output_path=overlay_path,
369
+ normalized=True,
370
+ )
371
+ if img_arr is None:
372
+ visualize_trajectory_on_image(
373
+ trajectory=trajectories,
374
+ image_path=preprocessed_path,
375
+ output_path=overlay_path,
376
+ normalized=False,
377
+ )
378
+ finally:
379
+ if preprocessed_path and os.path.exists(preprocessed_path):
380
+ try:
381
+ os.unlink(preprocessed_path)
382
+ except Exception:
383
+ pass
384
+
385
+ return output_dict, prediction, overlay_path, trace_points_text
386
+
387
+ except Exception as e:
388
+ logger.exception("Inference failed (qwenvl)")
389
+ return {}, f"Error: {str(e)}", None, ""