Spaces:
Sleeping
Sleeping
Anthony Liang commited on
Commit ·
521dbac
1
Parent(s): cde89ae
update
Browse files- app.py +43 -35
- trace_inference.py +142 -0
app.py
CHANGED
|
@@ -20,11 +20,11 @@ import requests
|
|
| 20 |
from trace_inference import (
|
| 21 |
DEFAULT_MODEL_ID,
|
| 22 |
TRACE_FORMAT,
|
|
|
|
| 23 |
build_prompt,
|
| 24 |
-
format_trace_points,
|
| 25 |
-
load_model,
|
| 26 |
preprocess_image_for_trace,
|
| 27 |
run_inference,
|
|
|
|
| 28 |
)
|
| 29 |
from trajectory_viz import visualize_trajectory_on_image
|
| 30 |
|
|
@@ -124,8 +124,8 @@ def run_inference_via_server(
|
|
| 124 |
image_path: str,
|
| 125 |
instruction: str,
|
| 126 |
server_url: str,
|
| 127 |
-
) -> Tuple[str, Optional[str]
|
| 128 |
-
"""Run inference via trace eval server. Returns (prediction, overlay_path
|
| 129 |
with open(image_path, "rb") as f:
|
| 130 |
image_b64 = base64.b64encode(f.read()).decode("utf-8")
|
| 131 |
headers = {"ngrok-skip-browser-warning": "true"} if "ngrok" in server_url else {}
|
|
@@ -138,10 +138,9 @@ def run_inference_via_server(
|
|
| 138 |
r.raise_for_status()
|
| 139 |
data = r.json()
|
| 140 |
if "error" in data:
|
| 141 |
-
return data["error"], None
|
| 142 |
prediction = data.get("prediction", "")
|
| 143 |
trajectory = data.get("trajectory", [])
|
| 144 |
-
trace_points_text = format_trace_points(trajectory)
|
| 145 |
|
| 146 |
overlay_path = None
|
| 147 |
if trajectory and len(trajectory) >= 2:
|
|
@@ -168,7 +167,7 @@ def run_inference_via_server(
|
|
| 168 |
os.unlink(preprocessed_path)
|
| 169 |
except Exception:
|
| 170 |
pass
|
| 171 |
-
return prediction, overlay_path
|
| 172 |
|
| 173 |
|
| 174 |
# --- Gradio UI ---
|
|
@@ -263,7 +262,7 @@ with demo:
|
|
| 263 |
interactive=True,
|
| 264 |
info="Discover trace eval servers on ports 8000-8010",
|
| 265 |
)
|
| 266 |
-
server_status = gr.Markdown("
|
| 267 |
gr.Markdown("---")
|
| 268 |
gr.Markdown("### 📋 Model Information")
|
| 269 |
model_info_display = gr.Markdown("")
|
|
@@ -303,9 +302,13 @@ with demo:
|
|
| 303 |
model_id_input = gr.Textbox(
|
| 304 |
label="Model ID",
|
| 305 |
value=DEFAULT_MODEL_ID,
|
| 306 |
-
info="Hugging Face model ID (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
)
|
| 308 |
-
load_model_btn = gr.Button("Load Model", variant="secondary")
|
| 309 |
run_btn = gr.Button("Run Inference", variant="primary")
|
| 310 |
|
| 311 |
with gr.Column(scale=1):
|
|
@@ -321,53 +324,58 @@ with demo:
|
|
| 321 |
label="Model Prediction (raw)",
|
| 322 |
lines=6,
|
| 323 |
)
|
| 324 |
-
trace_points_output = gr.Markdown(
|
| 325 |
-
label="Extracted Trace Points",
|
| 326 |
-
)
|
| 327 |
|
| 328 |
status_md = gr.Markdown(
|
| 329 |
-
"Select an eval server from the sidebar, or
|
| 330 |
)
|
| 331 |
|
| 332 |
-
def
|
| 333 |
-
_, msg = load_model(model_id)
|
| 334 |
-
return f"**Status:** {msg}"
|
| 335 |
-
|
| 336 |
-
def on_run_inference(image_path, instruction, model_id, server_url):
|
| 337 |
if image_path is None:
|
| 338 |
return (
|
| 339 |
"",
|
| 340 |
"Please upload an image first.",
|
| 341 |
None,
|
| 342 |
-
"",
|
| 343 |
"**Status:** Please upload an image.",
|
| 344 |
)
|
| 345 |
-
prompt = build_prompt(instruction)
|
| 346 |
-
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 347 |
|
| 348 |
if server_url:
|
| 349 |
-
|
|
|
|
|
|
|
| 350 |
image_path, instruction, server_url
|
| 351 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
else:
|
| 353 |
-
|
|
|
|
|
|
|
| 354 |
|
| 355 |
status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
|
| 356 |
-
return prompt_md, pred, overlay_path,
|
| 357 |
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
def update_prompt_display(instruction: str):
|
| 365 |
-
prompt = build_prompt(instruction)
|
| 366 |
return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 367 |
|
| 368 |
instruction_input.change(
|
| 369 |
fn=update_prompt_display,
|
| 370 |
-
inputs=[instruction_input],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
outputs=[prompt_display],
|
| 372 |
)
|
| 373 |
|
|
@@ -378,12 +386,12 @@ with demo:
|
|
| 378 |
instruction_input,
|
| 379 |
model_id_input,
|
| 380 |
server_url_state,
|
|
|
|
| 381 |
],
|
| 382 |
outputs=[
|
| 383 |
prompt_display,
|
| 384 |
prediction_output,
|
| 385 |
overlay_output,
|
| 386 |
-
trace_points_output,
|
| 387 |
status_md,
|
| 388 |
],
|
| 389 |
api_name="run_inference",
|
|
|
|
| 20 |
from trace_inference import (
|
| 21 |
DEFAULT_MODEL_ID,
|
| 22 |
TRACE_FORMAT,
|
| 23 |
+
build_franka_prompt,
|
| 24 |
build_prompt,
|
|
|
|
|
|
|
| 25 |
preprocess_image_for_trace,
|
| 26 |
run_inference,
|
| 27 |
+
run_inference_qwenvl,
|
| 28 |
)
|
| 29 |
from trajectory_viz import visualize_trajectory_on_image
|
| 30 |
|
|
|
|
| 124 |
image_path: str,
|
| 125 |
instruction: str,
|
| 126 |
server_url: str,
|
| 127 |
+
) -> Tuple[str, Optional[str]]:
|
| 128 |
+
"""Run inference via trace eval server. Returns (prediction, overlay_path)."""
|
| 129 |
with open(image_path, "rb") as f:
|
| 130 |
image_b64 = base64.b64encode(f.read()).decode("utf-8")
|
| 131 |
headers = {"ngrok-skip-browser-warning": "true"} if "ngrok" in server_url else {}
|
|
|
|
| 138 |
r.raise_for_status()
|
| 139 |
data = r.json()
|
| 140 |
if "error" in data:
|
| 141 |
+
return data["error"], None
|
| 142 |
prediction = data.get("prediction", "")
|
| 143 |
trajectory = data.get("trajectory", [])
|
|
|
|
| 144 |
|
| 145 |
overlay_path = None
|
| 146 |
if trajectory and len(trajectory) >= 2:
|
|
|
|
| 167 |
os.unlink(preprocessed_path)
|
| 168 |
except Exception:
|
| 169 |
pass
|
| 170 |
+
return prediction, overlay_path
|
| 171 |
|
| 172 |
|
| 173 |
# --- Gradio UI ---
|
|
|
|
| 262 |
interactive=True,
|
| 263 |
info="Discover trace eval servers on ports 8000-8010",
|
| 264 |
)
|
| 265 |
+
server_status = gr.Markdown("Select an eval server below (auto-connects on selection)")
|
| 266 |
gr.Markdown("---")
|
| 267 |
gr.Markdown("### 📋 Model Information")
|
| 268 |
model_info_display = gr.Markdown("")
|
|
|
|
| 302 |
model_id_input = gr.Textbox(
|
| 303 |
label="Model ID",
|
| 304 |
value=DEFAULT_MODEL_ID,
|
| 305 |
+
info="Hugging Face model ID (auto-loads on first inference if no eval server selected)",
|
| 306 |
+
)
|
| 307 |
+
use_qwenvl_checkbox = gr.Checkbox(
|
| 308 |
+
label="Use Franka / qwenvl inference",
|
| 309 |
+
value=False,
|
| 310 |
+
info="Uses preprocess_qwen_visual (qwenvl). Requires qwen-vl-finetune on PYTHONPATH.",
|
| 311 |
)
|
|
|
|
| 312 |
run_btn = gr.Button("Run Inference", variant="primary")
|
| 313 |
|
| 314 |
with gr.Column(scale=1):
|
|
|
|
| 324 |
label="Model Prediction (raw)",
|
| 325 |
lines=6,
|
| 326 |
)
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
status_md = gr.Markdown(
|
| 329 |
+
"Select an eval server from the sidebar (auto-connects), or run inference with local model."
|
| 330 |
)
|
| 331 |
|
| 332 |
+
def on_run_inference(image_path, instruction, model_id, server_url, use_qwenvl):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
if image_path is None:
|
| 334 |
return (
|
| 335 |
"",
|
| 336 |
"Please upload an image first.",
|
| 337 |
None,
|
|
|
|
| 338 |
"**Status:** Please upload an image.",
|
| 339 |
)
|
|
|
|
|
|
|
| 340 |
|
| 341 |
if server_url:
|
| 342 |
+
prompt = build_prompt(instruction)
|
| 343 |
+
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 344 |
+
pred, overlay_path = run_inference_via_server(
|
| 345 |
image_path, instruction, server_url
|
| 346 |
)
|
| 347 |
+
elif use_qwenvl:
|
| 348 |
+
prompt = build_franka_prompt(instruction or "predict the trace")
|
| 349 |
+
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 350 |
+
output_dict, pred, overlay_path, trace_text = run_inference_qwenvl(
|
| 351 |
+
image_path, instruction or "predict the trace", model_id
|
| 352 |
+
)
|
| 353 |
+
if not output_dict and trace_text and "qwenvl package not found" in trace_text:
|
| 354 |
+
pred = trace_text
|
| 355 |
+
overlay_path = None
|
| 356 |
else:
|
| 357 |
+
prompt = build_prompt(instruction)
|
| 358 |
+
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 359 |
+
pred, overlay_path, _ = run_inference(image_path, prompt, model_id)
|
| 360 |
|
| 361 |
status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
|
| 362 |
+
return prompt_md, pred, overlay_path, status
|
| 363 |
|
| 364 |
+
def update_prompt_display(instruction: str, use_qwenvl: bool):
|
| 365 |
+
if use_qwenvl:
|
| 366 |
+
prompt = build_franka_prompt(instruction or "predict the trace")
|
| 367 |
+
else:
|
| 368 |
+
prompt = build_prompt(instruction)
|
|
|
|
|
|
|
|
|
|
| 369 |
return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 370 |
|
| 371 |
instruction_input.change(
|
| 372 |
fn=update_prompt_display,
|
| 373 |
+
inputs=[instruction_input, use_qwenvl_checkbox],
|
| 374 |
+
outputs=[prompt_display],
|
| 375 |
+
)
|
| 376 |
+
use_qwenvl_checkbox.change(
|
| 377 |
+
fn=update_prompt_display,
|
| 378 |
+
inputs=[instruction_input, use_qwenvl_checkbox],
|
| 379 |
outputs=[prompt_display],
|
| 380 |
)
|
| 381 |
|
|
|
|
| 386 |
instruction_input,
|
| 387 |
model_id_input,
|
| 388 |
server_url_state,
|
| 389 |
+
use_qwenvl_checkbox,
|
| 390 |
],
|
| 391 |
outputs=[
|
| 392 |
prompt_display,
|
| 393 |
prediction_output,
|
| 394 |
overlay_output,
|
|
|
|
| 395 |
status_md,
|
| 396 |
],
|
| 397 |
api_name="run_inference",
|
trace_inference.py
CHANGED
|
@@ -245,3 +245,145 @@ def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Opt
|
|
| 245 |
except Exception as e:
|
| 246 |
logger.exception("Inference failed")
|
| 247 |
return f"Error: {str(e)}", None, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
except Exception as e:
|
| 246 |
logger.exception("Inference failed")
|
| 247 |
return f"Error: {str(e)}", None, ""
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def build_franka_prompt(task: str) -> str:
|
| 251 |
+
"""Build the Franka-style prompt for trace prediction."""
|
| 252 |
+
return (
|
| 253 |
+
'<image>\nYou are a Franka robot using the joint control. '
|
| 254 |
+
f'The task is "{task}". Can you predict the trace of the end effector?'
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def run_inference_qwenvl(
|
| 259 |
+
image_path: str,
|
| 260 |
+
task: str,
|
| 261 |
+
model_id: str = DEFAULT_MODEL_ID,
|
| 262 |
+
data_path: Optional[str] = None,
|
| 263 |
+
) -> Tuple[dict, str, Optional[str], str]:
|
| 264 |
+
"""
|
| 265 |
+
Run trace inference using preprocess_qwen_visual (qwenvl data processor).
|
| 266 |
+
Uses Franka-style prompt and output format.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
image_path: Full path to the image file.
|
| 270 |
+
task: Task description (e.g. "open pot, then pick bread and place inside pot").
|
| 271 |
+
model_id: Model to load.
|
| 272 |
+
data_path: Base path for image resolution. Defaults to dirname of image_path.
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
(output_dict, prediction_text, overlay_path, trace_points_text)
|
| 276 |
+
output_dict has format: {"id", "image", "conversations": [human_msg, gpt_msg]}
|
| 277 |
+
"""
|
| 278 |
+
try:
|
| 279 |
+
from qwenvl.data.data_processor import preprocess_qwen_visual
|
| 280 |
+
except ImportError as e:
|
| 281 |
+
return (
|
| 282 |
+
{},
|
| 283 |
+
"",
|
| 284 |
+
None,
|
| 285 |
+
f"qwenvl package not found: {e}. Install qwen-vl-finetune or add to PYTHONPATH.",
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
success, msg = load_model(model_id)
|
| 289 |
+
if not success:
|
| 290 |
+
return {}, msg, None, ""
|
| 291 |
+
|
| 292 |
+
model = _model_state["model"]
|
| 293 |
+
processor = _model_state["processor"]
|
| 294 |
+
|
| 295 |
+
if not image_path or not os.path.exists(image_path):
|
| 296 |
+
return {}, "Please provide a valid image.", None, ""
|
| 297 |
+
|
| 298 |
+
data_path = data_path or os.path.dirname(os.path.abspath(image_path))
|
| 299 |
+
image_rel = os.path.basename(image_path)
|
| 300 |
+
sample_id = hash(image_path) % 100000 # deterministic id for display
|
| 301 |
+
|
| 302 |
+
prompt = build_franka_prompt(task)
|
| 303 |
+
inference_sample = {
|
| 304 |
+
"id": sample_id,
|
| 305 |
+
"image": [image_rel],
|
| 306 |
+
"conversations": [{"from": "human", "value": prompt}],
|
| 307 |
+
"data_path": data_path,
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
try:
|
| 311 |
+
import torch
|
| 312 |
+
from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
|
| 313 |
+
|
| 314 |
+
processed_data = preprocess_qwen_visual(
|
| 315 |
+
[inference_sample], processor, add_gen_prompt=True
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
input_ids = processed_data["input_ids"].to(model.device)
|
| 319 |
+
pixel_values = (
|
| 320 |
+
processed_data["pixel_values"].to(model.device)
|
| 321 |
+
if "pixel_values" in processed_data
|
| 322 |
+
else None
|
| 323 |
+
)
|
| 324 |
+
image_grid_thw = (
|
| 325 |
+
processed_data["image_grid_thw"].to(model.device)
|
| 326 |
+
if "image_grid_thw" in processed_data
|
| 327 |
+
else None
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
inputs = {"input_ids": input_ids}
|
| 331 |
+
if pixel_values is not None:
|
| 332 |
+
inputs["pixel_values"] = pixel_values
|
| 333 |
+
if image_grid_thw is not None:
|
| 334 |
+
inputs["image_grid_thw"] = image_grid_thw
|
| 335 |
+
|
| 336 |
+
with torch.no_grad():
|
| 337 |
+
generated_ids = model.generate(**inputs, max_new_tokens=1024)
|
| 338 |
+
generated_ids_trimmed = [
|
| 339 |
+
out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
|
| 340 |
+
]
|
| 341 |
+
prediction = processor.tokenizer.decode(
|
| 342 |
+
generated_ids_trimmed[0], skip_special_tokens=True
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# Format output like the example: "Trace: [[x,y], [x,y], ...]"
|
| 346 |
+
trajectories = extract_trajectory_from_text(prediction)
|
| 347 |
+
trace_value = f"Trace: {trajectories}" if trajectories else f"Trace: {prediction}"
|
| 348 |
+
output_dict = {
|
| 349 |
+
"id": sample_id,
|
| 350 |
+
"image": [image_rel],
|
| 351 |
+
"conversations": [
|
| 352 |
+
{"from": "human", "value": prompt},
|
| 353 |
+
{"from": "gpt", "value": trace_value},
|
| 354 |
+
],
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
trace_points_text = format_trace_points(trajectories)
|
| 358 |
+
|
| 359 |
+
overlay_path = None
|
| 360 |
+
if trajectories and len(trajectories) >= 2:
|
| 361 |
+
_, preprocessed_path = preprocess_image_for_trace(image_path)
|
| 362 |
+
try:
|
| 363 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
|
| 364 |
+
overlay_path = f.name
|
| 365 |
+
img_arr = visualize_trajectory_on_image(
|
| 366 |
+
trajectory=trajectories,
|
| 367 |
+
image_path=preprocessed_path,
|
| 368 |
+
output_path=overlay_path,
|
| 369 |
+
normalized=True,
|
| 370 |
+
)
|
| 371 |
+
if img_arr is None:
|
| 372 |
+
visualize_trajectory_on_image(
|
| 373 |
+
trajectory=trajectories,
|
| 374 |
+
image_path=preprocessed_path,
|
| 375 |
+
output_path=overlay_path,
|
| 376 |
+
normalized=False,
|
| 377 |
+
)
|
| 378 |
+
finally:
|
| 379 |
+
if preprocessed_path and os.path.exists(preprocessed_path):
|
| 380 |
+
try:
|
| 381 |
+
os.unlink(preprocessed_path)
|
| 382 |
+
except Exception:
|
| 383 |
+
pass
|
| 384 |
+
|
| 385 |
+
return output_dict, prediction, overlay_path, trace_points_text
|
| 386 |
+
|
| 387 |
+
except Exception as e:
|
| 388 |
+
logger.exception("Inference failed (qwenvl)")
|
| 389 |
+
return {}, f"Error: {str(e)}", None, ""
|