Spaces:
Sleeping
Sleeping
Anthony Liang
commited on
Commit
·
be80524
1
Parent(s):
130aa46
update
Browse files- app.py +22 -37
- eval_server.py +6 -17
- trace_inference.py +115 -316
app.py
CHANGED
|
@@ -19,12 +19,9 @@ import requests
|
|
| 19 |
|
| 20 |
from trace_inference import (
|
| 21 |
DEFAULT_MODEL_ID,
|
| 22 |
-
TRACE_FORMAT,
|
| 23 |
-
build_franka_prompt,
|
| 24 |
build_prompt,
|
| 25 |
preprocess_image_for_trace,
|
| 26 |
run_inference,
|
| 27 |
-
run_inference_qwenvl,
|
| 28 |
)
|
| 29 |
from trajectory_viz import visualize_trajectory_on_image
|
| 30 |
|
|
@@ -124,7 +121,7 @@ def run_inference_via_server(
|
|
| 124 |
image_path: str,
|
| 125 |
instruction: str,
|
| 126 |
server_url: str,
|
| 127 |
-
|
| 128 |
) -> Tuple[str, Optional[str]]:
|
| 129 |
"""Run inference via trace eval server. Returns (prediction, overlay_path)."""
|
| 130 |
with open(image_path, "rb") as f:
|
|
@@ -135,7 +132,7 @@ def run_inference_via_server(
|
|
| 135 |
json={
|
| 136 |
"image_base64": image_b64,
|
| 137 |
"instruction": instruction,
|
| 138 |
-
"
|
| 139 |
},
|
| 140 |
timeout=120.0,
|
| 141 |
headers=headers,
|
|
@@ -177,7 +174,7 @@ def run_inference_via_server(
|
|
| 177 |
|
| 178 |
# --- Gradio UI ---
|
| 179 |
try:
|
| 180 |
-
demo = gr.Blocks(title="Trace Model Visualizer"
|
| 181 |
except TypeError:
|
| 182 |
demo = gr.Blocks(title="Trace Model Visualizer")
|
| 183 |
|
|
@@ -303,17 +300,18 @@ with demo:
|
|
| 303 |
lines=4,
|
| 304 |
info="Enter a task description in natural language. The model predicts the trace for this instruction.",
|
| 305 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
gr.Markdown("### Local model (if no eval server selected)")
|
| 307 |
model_id_input = gr.Textbox(
|
| 308 |
label="Model ID",
|
| 309 |
value=DEFAULT_MODEL_ID,
|
| 310 |
info="Hugging Face model ID (auto-loads on first inference if no eval server selected)",
|
| 311 |
)
|
| 312 |
-
use_qwenvl_checkbox = gr.Checkbox(
|
| 313 |
-
label="Use Franka / qwenvl inference",
|
| 314 |
-
value=False,
|
| 315 |
-
info="Uses preprocess_qwen_visual (qwenvl). Requires qwen-vl-finetune on PYTHONPATH.",
|
| 316 |
-
)
|
| 317 |
run_btn = gr.Button("Run Inference", variant="primary")
|
| 318 |
|
| 319 |
with gr.Column(scale=1):
|
|
@@ -334,7 +332,7 @@ with demo:
|
|
| 334 |
"Select an eval server from the sidebar (auto-connects), or run inference with local model."
|
| 335 |
)
|
| 336 |
|
| 337 |
-
def on_run_inference(image_path, instruction, model_id, server_url,
|
| 338 |
if image_path is None:
|
| 339 |
return (
|
| 340 |
"",
|
|
@@ -343,48 +341,34 @@ with demo:
|
|
| 343 |
"**Status:** Please upload an image.",
|
| 344 |
)
|
| 345 |
|
|
|
|
| 346 |
if server_url:
|
| 347 |
-
prompt = (
|
| 348 |
-
build_franka_prompt(instruction or "predict the trace")
|
| 349 |
-
if use_qwenvl
|
| 350 |
-
else build_prompt(instruction)
|
| 351 |
-
)
|
| 352 |
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 353 |
pred, overlay_path = run_inference_via_server(
|
| 354 |
-
image_path, instruction, server_url,
|
| 355 |
)
|
| 356 |
-
elif use_qwenvl:
|
| 357 |
-
prompt = build_franka_prompt(instruction or "predict the trace")
|
| 358 |
-
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 359 |
-
output_dict, pred, overlay_path, trace_text = run_inference_qwenvl(
|
| 360 |
-
image_path, instruction or "predict the trace", model_id
|
| 361 |
-
)
|
| 362 |
-
if not output_dict and trace_text and "qwenvl package not found" in trace_text:
|
| 363 |
-
pred = trace_text
|
| 364 |
-
overlay_path = None
|
| 365 |
else:
|
| 366 |
-
prompt = build_prompt(instruction)
|
| 367 |
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 368 |
pred, overlay_path, _ = run_inference(image_path, prompt, model_id)
|
| 369 |
|
| 370 |
status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
|
| 371 |
return prompt_md, pred, overlay_path, status
|
| 372 |
|
| 373 |
-
def update_prompt_display(instruction: str,
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
else:
|
| 377 |
-
prompt = build_prompt(instruction)
|
| 378 |
return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 379 |
|
| 380 |
instruction_input.change(
|
| 381 |
fn=update_prompt_display,
|
| 382 |
-
inputs=[instruction_input,
|
| 383 |
outputs=[prompt_display],
|
| 384 |
)
|
| 385 |
-
|
| 386 |
fn=update_prompt_display,
|
| 387 |
-
inputs=[instruction_input,
|
| 388 |
outputs=[prompt_display],
|
| 389 |
)
|
| 390 |
|
|
@@ -395,7 +379,7 @@ with demo:
|
|
| 395 |
instruction_input,
|
| 396 |
model_id_input,
|
| 397 |
server_url_state,
|
| 398 |
-
|
| 399 |
],
|
| 400 |
outputs=[
|
| 401 |
prompt_display,
|
|
@@ -413,6 +397,7 @@ def main():
|
|
| 413 |
server_name="0.0.0.0",
|
| 414 |
server_port=7860,
|
| 415 |
share=False,
|
|
|
|
| 416 |
)
|
| 417 |
|
| 418 |
|
|
|
|
| 19 |
|
| 20 |
from trace_inference import (
|
| 21 |
DEFAULT_MODEL_ID,
|
|
|
|
|
|
|
| 22 |
build_prompt,
|
| 23 |
preprocess_image_for_trace,
|
| 24 |
run_inference,
|
|
|
|
| 25 |
)
|
| 26 |
from trajectory_viz import visualize_trajectory_on_image
|
| 27 |
|
|
|
|
| 121 |
image_path: str,
|
| 122 |
instruction: str,
|
| 123 |
server_url: str,
|
| 124 |
+
is_oxe: bool = False,
|
| 125 |
) -> Tuple[str, Optional[str]]:
|
| 126 |
"""Run inference via trace eval server. Returns (prediction, overlay_path)."""
|
| 127 |
with open(image_path, "rb") as f:
|
|
|
|
| 132 |
json={
|
| 133 |
"image_base64": image_b64,
|
| 134 |
"instruction": instruction,
|
| 135 |
+
"is_oxe": is_oxe,
|
| 136 |
},
|
| 137 |
timeout=120.0,
|
| 138 |
headers=headers,
|
|
|
|
| 174 |
|
| 175 |
# --- Gradio UI ---
|
| 176 |
try:
|
| 177 |
+
demo = gr.Blocks(title="Trace Model Visualizer")
|
| 178 |
except TypeError:
|
| 179 |
demo = gr.Blocks(title="Trace Model Visualizer")
|
| 180 |
|
|
|
|
| 300 |
lines=4,
|
| 301 |
info="Enter a task description in natural language. The model predicts the trace for this instruction.",
|
| 302 |
)
|
| 303 |
+
prompt_format = gr.Radio(
|
| 304 |
+
choices=["LIBERO", "OXE"],
|
| 305 |
+
value="LIBERO",
|
| 306 |
+
label="Prompt Format",
|
| 307 |
+
info="Switch between LIBERO and OXE training formats.",
|
| 308 |
+
)
|
| 309 |
gr.Markdown("### Local model (if no eval server selected)")
|
| 310 |
model_id_input = gr.Textbox(
|
| 311 |
label="Model ID",
|
| 312 |
value=DEFAULT_MODEL_ID,
|
| 313 |
info="Hugging Face model ID (auto-loads on first inference if no eval server selected)",
|
| 314 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
run_btn = gr.Button("Run Inference", variant="primary")
|
| 316 |
|
| 317 |
with gr.Column(scale=1):
|
|
|
|
| 332 |
"Select an eval server from the sidebar (auto-connects), or run inference with local model."
|
| 333 |
)
|
| 334 |
|
| 335 |
+
def on_run_inference(image_path, instruction, model_id, server_url, prompt_mode):
|
| 336 |
if image_path is None:
|
| 337 |
return (
|
| 338 |
"",
|
|
|
|
| 341 |
"**Status:** Please upload an image.",
|
| 342 |
)
|
| 343 |
|
| 344 |
+
is_oxe = (prompt_mode == "OXE")
|
| 345 |
if server_url:
|
| 346 |
+
prompt = build_prompt(instruction, is_oxe=is_oxe)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 348 |
pred, overlay_path = run_inference_via_server(
|
| 349 |
+
image_path, instruction, server_url, is_oxe=is_oxe
|
| 350 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
else:
|
| 352 |
+
prompt = build_prompt(instruction, is_oxe=is_oxe)
|
| 353 |
prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 354 |
pred, overlay_path, _ = run_inference(image_path, prompt, model_id)
|
| 355 |
|
| 356 |
status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
|
| 357 |
return prompt_md, pred, overlay_path, status
|
| 358 |
|
| 359 |
+
def update_prompt_display(instruction: str, prompt_mode: str):
|
| 360 |
+
is_oxe = (prompt_mode == "OXE")
|
| 361 |
+
prompt = build_prompt(instruction, is_oxe=is_oxe)
|
|
|
|
|
|
|
| 362 |
return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
|
| 363 |
|
| 364 |
instruction_input.change(
|
| 365 |
fn=update_prompt_display,
|
| 366 |
+
inputs=[instruction_input, prompt_format],
|
| 367 |
outputs=[prompt_display],
|
| 368 |
)
|
| 369 |
+
prompt_format.change(
|
| 370 |
fn=update_prompt_display,
|
| 371 |
+
inputs=[instruction_input, prompt_format],
|
| 372 |
outputs=[prompt_display],
|
| 373 |
)
|
| 374 |
|
|
|
|
| 379 |
instruction_input,
|
| 380 |
model_id_input,
|
| 381 |
server_url_state,
|
| 382 |
+
prompt_format,
|
| 383 |
],
|
| 384 |
outputs=[
|
| 385 |
prompt_display,
|
|
|
|
| 397 |
server_name="0.0.0.0",
|
| 398 |
server_port=7860,
|
| 399 |
share=False,
|
| 400 |
+
theme=gr.themes.Soft(),
|
| 401 |
)
|
| 402 |
|
| 403 |
|
eval_server.py
CHANGED
|
@@ -33,7 +33,6 @@ from trace_inference import (
|
|
| 33 |
build_prompt,
|
| 34 |
load_model,
|
| 35 |
run_inference,
|
| 36 |
-
run_inference_qwenvl,
|
| 37 |
)
|
| 38 |
from trace_inference import _model_state as _trace_model_state
|
| 39 |
from trajectory_viz import extract_trajectory_from_text
|
|
@@ -69,13 +68,12 @@ class TraceEvalServer:
|
|
| 69 |
image_path: Optional[str] = None,
|
| 70 |
image_base64: Optional[str] = None,
|
| 71 |
instruction: str = "",
|
| 72 |
-
|
| 73 |
) -> Dict[str, Any]:
|
| 74 |
"""
|
| 75 |
Run inference on a single image.
|
| 76 |
|
| 77 |
Provide either image_path (file path) or image_base64 (base64-encoded image).
|
| 78 |
-
If use_qwenvl=True, uses run_inference_qwenvl (Franka-style, requires qwenvl).
|
| 79 |
"""
|
| 80 |
if image_path is None and image_base64 is None:
|
| 81 |
return {"error": "Provide image_path or image_base64"}
|
|
@@ -102,15 +100,8 @@ class TraceEvalServer:
|
|
| 102 |
return {"error": f"Invalid image data: {e}"}
|
| 103 |
|
| 104 |
try:
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
image_path, instruction or "predict the trace", self.model_id
|
| 108 |
-
)
|
| 109 |
-
if not output_dict and trace_text and "qwenvl package not found" in trace_text:
|
| 110 |
-
return {"error": trace_text}
|
| 111 |
-
else:
|
| 112 |
-
prompt = build_prompt(instruction)
|
| 113 |
-
prediction, _, _ = run_inference(image_path, prompt, self.model_id)
|
| 114 |
finally:
|
| 115 |
if temp_file_path and os.path.exists(temp_file_path):
|
| 116 |
try:
|
|
@@ -126,8 +117,6 @@ class TraceEvalServer:
|
|
| 126 |
"prediction": prediction,
|
| 127 |
"trajectory": trajectory,
|
| 128 |
}
|
| 129 |
-
if use_qwenvl and output_dict:
|
| 130 |
-
result["output_dict"] = output_dict
|
| 131 |
return result
|
| 132 |
|
| 133 |
def predict_batch(
|
|
@@ -146,7 +135,7 @@ class TraceEvalServer:
|
|
| 146 |
image_path=sample.get("image_path"),
|
| 147 |
image_base64=sample.get("image_base64"),
|
| 148 |
instruction=sample.get("instruction", ""),
|
| 149 |
-
|
| 150 |
)
|
| 151 |
elapsed = time.time() - start
|
| 152 |
|
|
@@ -214,14 +203,14 @@ def create_app(
|
|
| 214 |
- image_path: (optional) path to image file
|
| 215 |
- image_base64: (optional) base64-encoded image
|
| 216 |
- instruction: natural language task description
|
| 217 |
-
-
|
| 218 |
"""
|
| 219 |
body = await request.json()
|
| 220 |
return trace_server.predict_one(
|
| 221 |
image_path=body.get("image_path"),
|
| 222 |
image_base64=body.get("image_base64"),
|
| 223 |
instruction=body.get("instruction", ""),
|
| 224 |
-
|
| 225 |
)
|
| 226 |
|
| 227 |
@app.post("/predict_batch")
|
|
|
|
| 33 |
build_prompt,
|
| 34 |
load_model,
|
| 35 |
run_inference,
|
|
|
|
| 36 |
)
|
| 37 |
from trace_inference import _model_state as _trace_model_state
|
| 38 |
from trajectory_viz import extract_trajectory_from_text
|
|
|
|
| 68 |
image_path: Optional[str] = None,
|
| 69 |
image_base64: Optional[str] = None,
|
| 70 |
instruction: str = "",
|
| 71 |
+
is_oxe: bool = False,
|
| 72 |
) -> Dict[str, Any]:
|
| 73 |
"""
|
| 74 |
Run inference on a single image.
|
| 75 |
|
| 76 |
Provide either image_path (file path) or image_base64 (base64-encoded image).
|
|
|
|
| 77 |
"""
|
| 78 |
if image_path is None and image_base64 is None:
|
| 79 |
return {"error": "Provide image_path or image_base64"}
|
|
|
|
| 100 |
return {"error": f"Invalid image data: {e}"}
|
| 101 |
|
| 102 |
try:
|
| 103 |
+
prompt = build_prompt(instruction, is_oxe=is_oxe)
|
| 104 |
+
prediction, _, _ = run_inference(image_path, prompt, self.model_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
finally:
|
| 106 |
if temp_file_path and os.path.exists(temp_file_path):
|
| 107 |
try:
|
|
|
|
| 117 |
"prediction": prediction,
|
| 118 |
"trajectory": trajectory,
|
| 119 |
}
|
|
|
|
|
|
|
| 120 |
return result
|
| 121 |
|
| 122 |
def predict_batch(
|
|
|
|
| 135 |
image_path=sample.get("image_path"),
|
| 136 |
image_base64=sample.get("image_base64"),
|
| 137 |
instruction=sample.get("instruction", ""),
|
| 138 |
+
is_oxe=sample.get("is_oxe", False),
|
| 139 |
)
|
| 140 |
elapsed = time.time() - start
|
| 141 |
|
|
|
|
| 203 |
- image_path: (optional) path to image file
|
| 204 |
- image_base64: (optional) base64-encoded image
|
| 205 |
- instruction: natural language task description
|
| 206 |
+
- is_oxe: (optional) if true, use OXE prompt format
|
| 207 |
"""
|
| 208 |
body = await request.json()
|
| 209 |
return trace_server.predict_one(
|
| 210 |
image_path=body.get("image_path"),
|
| 211 |
image_base64=body.get("image_base64"),
|
| 212 |
instruction=body.get("instruction", ""),
|
| 213 |
+
is_oxe=body.get("is_oxe", False),
|
| 214 |
)
|
| 215 |
|
| 216 |
@app.post("/predict_batch")
|
trace_inference.py
CHANGED
|
@@ -9,22 +9,16 @@ Heavy imports are done lazily inside load_model and run_inference.
|
|
| 9 |
import logging
|
| 10 |
import os
|
| 11 |
import tempfile
|
| 12 |
-
|
| 13 |
import re
|
|
|
|
| 14 |
from pathlib import Path
|
| 15 |
-
import torch
|
| 16 |
-
from typing import Dict, Any
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
-
# Constants
|
| 21 |
DEFAULT_MODEL_ID = "mihirgrao/trace-model"
|
| 22 |
-
|
| 23 |
-
"Predict the trajectory or trace in this image. "
|
| 24 |
-
"Output the coordinates as a list of [x, y] pairs, e.g. [[0.1, 0.2], [0.3, 0.4], ...]. "
|
| 25 |
-
"Use normalized coordinates between 0 and 1."
|
| 26 |
-
)
|
| 27 |
-
PREPROCESS_SIZE = (128, 128)
|
| 28 |
|
| 29 |
# Global model state
|
| 30 |
_model_state = {
|
|
@@ -34,11 +28,12 @@ _model_state = {
|
|
| 34 |
}
|
| 35 |
|
| 36 |
|
| 37 |
-
def build_prompt(instruction: str = "") -> str:
|
| 38 |
-
"""Build the full prompt from task instruction
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
def format_trace_points(trajectories: List) -> str:
|
|
@@ -55,7 +50,7 @@ def format_trace_points(trajectories: List) -> str:
|
|
| 55 |
return "\n".join(lines)
|
| 56 |
|
| 57 |
|
| 58 |
-
def center_crop_resize(image, size: Tuple[int, int] =
|
| 59 |
"""Center crop to square then resize. Requires PIL Image."""
|
| 60 |
from PIL import Image
|
| 61 |
|
|
@@ -64,7 +59,8 @@ def center_crop_resize(image, size: Tuple[int, int] = PREPROCESS_SIZE):
|
|
| 64 |
left = (w - min_dim) // 2
|
| 65 |
top = (h - min_dim) // 2
|
| 66 |
cropped = image.crop((left, top, left + min_dim, top + min_dim))
|
| 67 |
-
return cropped.resize(size, Image.Resampling.LANCZOS)
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
def preprocess_image_for_trace(image_path: str) -> Tuple:
|
|
@@ -72,195 +68,16 @@ def preprocess_image_for_trace(image_path: str) -> Tuple:
|
|
| 72 |
from PIL import Image
|
| 73 |
|
| 74 |
img = Image.open(image_path).convert("RGB")
|
| 75 |
-
img = center_crop_resize(img,
|
| 76 |
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
|
| 77 |
img.save(tmp.name)
|
| 78 |
return img, tmp.name
|
| 79 |
|
| 80 |
|
| 81 |
-
def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
|
| 82 |
-
"""Load the trace model and processor. Returns (success, message)."""
|
| 83 |
-
global _model_state
|
| 84 |
-
|
| 85 |
-
if _model_state["model"] is not None and _model_state["model_id"] == model_id:
|
| 86 |
-
return True, f"Model already loaded: {model_id}"
|
| 87 |
-
|
| 88 |
-
try:
|
| 89 |
-
import torch
|
| 90 |
-
from transformers import AutoModelForImageTextToText, AutoProcessor
|
| 91 |
-
|
| 92 |
-
if _model_state["model"] is not None:
|
| 93 |
-
del _model_state["model"]
|
| 94 |
-
del _model_state["processor"]
|
| 95 |
-
_model_state["model"] = None
|
| 96 |
-
_model_state["processor"] = None
|
| 97 |
-
if torch.cuda.is_available():
|
| 98 |
-
torch.cuda.empty_cache()
|
| 99 |
-
|
| 100 |
-
load_kwargs = {
|
| 101 |
-
"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 102 |
-
"device_map": "auto" if torch.cuda.is_available() else None,
|
| 103 |
-
}
|
| 104 |
-
try:
|
| 105 |
-
if torch.cuda.is_available():
|
| 106 |
-
load_kwargs["attn_implementation"] = "flash_attention_2"
|
| 107 |
-
model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
|
| 108 |
-
except (ValueError, ImportError):
|
| 109 |
-
load_kwargs.pop("attn_implementation", None)
|
| 110 |
-
model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
|
| 111 |
-
processor = AutoProcessor.from_pretrained(model_id)
|
| 112 |
-
|
| 113 |
-
_model_state["model"] = model
|
| 114 |
-
_model_state["processor"] = processor
|
| 115 |
-
_model_state["model_id"] = model_id
|
| 116 |
-
|
| 117 |
-
return True, f"Model loaded: {model_id}"
|
| 118 |
-
except Exception as e:
|
| 119 |
-
logger.exception("Failed to load model")
|
| 120 |
-
return False, f"Error loading model: {str(e)}"
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
|
| 124 |
-
"""
|
| 125 |
-
Run trace model inference on an image.
|
| 126 |
-
Returns: (prediction_text, overlay_image_path, trace_points_text)
|
| 127 |
-
"""
|
| 128 |
-
success, msg = load_model(model_id)
|
| 129 |
-
if not success:
|
| 130 |
-
return msg, None, ""
|
| 131 |
-
|
| 132 |
-
model = _model_state["model"]
|
| 133 |
-
processor = _model_state["processor"]
|
| 134 |
-
|
| 135 |
-
if image_path is None or not os.path.exists(image_path):
|
| 136 |
-
return "Please provide a valid image.", None, ""
|
| 137 |
-
|
| 138 |
-
try:
|
| 139 |
-
from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
|
| 140 |
-
|
| 141 |
-
try:
|
| 142 |
-
from qwen_vl_utils import process_vision_info
|
| 143 |
-
except ImportError:
|
| 144 |
-
process_vision_info = None
|
| 145 |
-
|
| 146 |
-
preprocessed_path = None
|
| 147 |
-
try:
|
| 148 |
-
_, preprocessed_path = preprocess_image_for_trace(image_path)
|
| 149 |
-
image_uri = f"file://{os.path.abspath(preprocessed_path)}"
|
| 150 |
-
|
| 151 |
-
messages = [
|
| 152 |
-
{
|
| 153 |
-
"role": "user",
|
| 154 |
-
"content": [
|
| 155 |
-
{"type": "image", "image": image_uri},
|
| 156 |
-
{"type": "text", "text": prompt},
|
| 157 |
-
],
|
| 158 |
-
}
|
| 159 |
-
]
|
| 160 |
-
|
| 161 |
-
text = processor.apply_chat_template(
|
| 162 |
-
messages, tokenize=False, add_generation_prompt=True
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
if process_vision_info is not None:
|
| 166 |
-
process_kwargs = {"return_video_kwargs": True, "return_video_metadata": True}
|
| 167 |
-
if hasattr(processor, "image_processor") and hasattr(
|
| 168 |
-
processor.image_processor, "patch_size"
|
| 169 |
-
):
|
| 170 |
-
process_kwargs["image_patch_size"] = processor.image_processor.patch_size
|
| 171 |
-
image_inputs, video_inputs, video_kwargs = process_vision_info(
|
| 172 |
-
messages, **process_kwargs
|
| 173 |
-
)
|
| 174 |
-
else:
|
| 175 |
-
from PIL import Image
|
| 176 |
-
|
| 177 |
-
pil_image = Image.open(image_path).convert("RGB")
|
| 178 |
-
image_inputs = [pil_image]
|
| 179 |
-
video_inputs = None
|
| 180 |
-
video_kwargs = {}
|
| 181 |
-
|
| 182 |
-
processor_kwargs = {
|
| 183 |
-
"text": [text],
|
| 184 |
-
"images": image_inputs,
|
| 185 |
-
"padding": True,
|
| 186 |
-
"return_tensors": "pt",
|
| 187 |
-
"do_resize": False,
|
| 188 |
-
}
|
| 189 |
-
if video_inputs is not None and len(video_inputs) > 0:
|
| 190 |
-
if isinstance(video_inputs[0], tuple):
|
| 191 |
-
videos, video_metadatas = zip(*video_inputs)
|
| 192 |
-
processor_kwargs["videos"] = list(videos)
|
| 193 |
-
processor_kwargs["video_metadata"] = list(video_metadatas)
|
| 194 |
-
else:
|
| 195 |
-
processor_kwargs["videos"] = video_inputs
|
| 196 |
-
if video_kwargs:
|
| 197 |
-
processor_kwargs.update(video_kwargs)
|
| 198 |
-
|
| 199 |
-
import torch
|
| 200 |
-
|
| 201 |
-
inputs = processor(**processor_kwargs)
|
| 202 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items() if hasattr(v, "to")}
|
| 203 |
-
|
| 204 |
-
with torch.no_grad():
|
| 205 |
-
generated_ids = model.generate(
|
| 206 |
-
**inputs, max_new_tokens=1024, do_sample=False
|
| 207 |
-
)
|
| 208 |
-
|
| 209 |
-
input_ids = inputs["input_ids"]
|
| 210 |
-
generated_ids_trimmed = [
|
| 211 |
-
out[len(inp) :] for inp, out in zip(input_ids, generated_ids)
|
| 212 |
-
]
|
| 213 |
-
prediction = processor.batch_decode(
|
| 214 |
-
generated_ids_trimmed,
|
| 215 |
-
skip_special_tokens=True,
|
| 216 |
-
clean_up_tokenization_spaces=False,
|
| 217 |
-
)[0]
|
| 218 |
-
|
| 219 |
-
trajectories = extract_trajectory_from_text(prediction)
|
| 220 |
-
trace_points_text = format_trace_points(trajectories)
|
| 221 |
-
|
| 222 |
-
overlay_path = None
|
| 223 |
-
if trajectories and len(trajectories) >= 2:
|
| 224 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
|
| 225 |
-
overlay_path = f.name
|
| 226 |
-
img_arr = visualize_trajectory_on_image(
|
| 227 |
-
trajectory=trajectories,
|
| 228 |
-
image_path=preprocessed_path,
|
| 229 |
-
output_path=overlay_path,
|
| 230 |
-
normalized=True,
|
| 231 |
-
)
|
| 232 |
-
if img_arr is None:
|
| 233 |
-
visualize_trajectory_on_image(
|
| 234 |
-
trajectory=trajectories,
|
| 235 |
-
image_path=preprocessed_path,
|
| 236 |
-
output_path=overlay_path,
|
| 237 |
-
normalized=False,
|
| 238 |
-
)
|
| 239 |
-
|
| 240 |
-
return prediction, overlay_path, trace_points_text
|
| 241 |
-
|
| 242 |
-
finally:
|
| 243 |
-
if preprocessed_path and os.path.exists(preprocessed_path):
|
| 244 |
-
try:
|
| 245 |
-
os.unlink(preprocessed_path)
|
| 246 |
-
except Exception:
|
| 247 |
-
pass
|
| 248 |
-
|
| 249 |
-
except Exception as e:
|
| 250 |
-
logger.exception("Inference failed")
|
| 251 |
-
return f"Error: {str(e)}", None, ""
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
def build_franka_prompt(task: str) -> str:
|
| 255 |
-
"""Build the Franka-style prompt for trace prediction."""
|
| 256 |
-
return (
|
| 257 |
-
'<image>\nYou are a Franka robot using the joint control. '
|
| 258 |
-
f'The task is "{task}". Can you predict the trace of the end effector?'
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
def _make_abs_paths(base: Path, files: str) -> str:
|
| 262 |
return f"{(base / files).resolve()}"
|
| 263 |
|
|
|
|
| 264 |
def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]:
|
| 265 |
# Extract and normalize images and videos
|
| 266 |
images = item.get("image") or []
|
|
@@ -322,7 +139,6 @@ def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any
|
|
| 322 |
|
| 323 |
return messages
|
| 324 |
|
| 325 |
-
IGNORE_INDEX = -100
|
| 326 |
|
| 327 |
def preprocess_qwen_visual(
|
| 328 |
sources,
|
|
@@ -382,143 +198,126 @@ def preprocess_qwen_visual(
|
|
| 382 |
|
| 383 |
return full_result
|
| 384 |
|
| 385 |
-
def run_inference_qwenvl(
|
| 386 |
-
image_path: str,
|
| 387 |
-
task: str,
|
| 388 |
-
model_id: str = DEFAULT_MODEL_ID,
|
| 389 |
-
data_path: Optional[str] = None,
|
| 390 |
-
) -> Tuple[dict, str, Optional[str], str]:
|
| 391 |
-
"""
|
| 392 |
-
Run trace inference using preprocess_qwen_visual (qwenvl data processor).
|
| 393 |
-
Uses Franka-style prompt and output format.
|
| 394 |
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
model_id: Model to load.
|
| 399 |
-
data_path: Base path for image resolution. Defaults to dirname of image_path.
|
| 400 |
-
|
| 401 |
-
Returns:
|
| 402 |
-
(output_dict, prediction_text, overlay_path, trace_points_text)
|
| 403 |
-
output_dict has format: {"id", "image", "conversations": [human_msg, gpt_msg]}
|
| 404 |
-
"""
|
| 405 |
-
success, msg = load_model(model_id)
|
| 406 |
-
if not success:
|
| 407 |
-
return {}, msg, None, ""
|
| 408 |
|
| 409 |
-
model
|
| 410 |
-
|
| 411 |
|
| 412 |
-
|
| 413 |
-
|
| 414 |
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
"data_path": data_path,
|
| 425 |
-
}
|
| 426 |
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
print(data_path)
|
| 433 |
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
)
|
|
|
|
| 441 |
|
| 442 |
-
print("inference_sample")
|
| 443 |
-
print(inference_sample)
|
| 444 |
-
print("processed_data")
|
| 445 |
-
print(processed_data)
|
| 446 |
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
if "image_grid_thw" in processed_data
|
| 456 |
-
else None
|
| 457 |
-
)
|
| 458 |
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
inputs["pixel_values"] = pixel_values
|
| 462 |
-
if image_grid_thw is not None:
|
| 463 |
-
inputs["image_grid_thw"] = image_grid_thw
|
| 464 |
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
generated_ids_trimmed = [
|
| 468 |
-
out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
|
| 469 |
-
]
|
| 470 |
-
prediction = processor.tokenizer.decode(
|
| 471 |
-
generated_ids_trimmed[0], skip_special_tokens=True
|
| 472 |
-
)
|
| 473 |
|
| 474 |
-
|
| 475 |
-
|
| 476 |
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
"id": sample_id,
|
| 482 |
-
"image": [image_rel],
|
| 483 |
"conversations": [
|
| 484 |
-
{
|
| 485 |
-
|
|
|
|
|
|
|
| 486 |
],
|
|
|
|
| 487 |
}
|
| 488 |
|
| 489 |
-
|
|
|
|
| 490 |
|
| 491 |
-
|
| 492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
overlay_path = None
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
)
|
| 513 |
-
finally:
|
| 514 |
-
if preprocessed_path and os.path.exists(preprocessed_path):
|
| 515 |
-
try:
|
| 516 |
-
os.unlink(preprocessed_path)
|
| 517 |
-
except Exception:
|
| 518 |
-
pass
|
| 519 |
-
|
| 520 |
-
return output_dict, prediction, overlay_path, trace_points_text
|
| 521 |
|
| 522 |
except Exception as e:
|
| 523 |
-
logger.exception("Inference failed
|
| 524 |
-
return
|
|
|
|
| 9 |
import logging
|
| 10 |
import os
|
| 11 |
import tempfile
|
| 12 |
+
import torch
|
| 13 |
import re
|
| 14 |
+
from typing import List, Optional, Tuple, Dict, Any
|
| 15 |
from pathlib import Path
|
|
|
|
|
|
|
| 16 |
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
| 19 |
+
# Constants
|
| 20 |
DEFAULT_MODEL_ID = "mihirgrao/trace-model"
|
| 21 |
+
IGNORE_INDEX = -100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Global model state
|
| 24 |
_model_state = {
|
|
|
|
| 28 |
}
|
| 29 |
|
| 30 |
|
| 31 |
+
def build_prompt(instruction: str = "", is_oxe: bool = False) -> str:
|
| 32 |
+
"""Build the full prompt from task instruction."""
|
| 33 |
+
task = instruction.strip() or "predict the trace"
|
| 34 |
+
if is_oxe:
|
| 35 |
+
return f"<image>\nYou are a Franka robot using the joint control. The task is \"{task}\". Can you predict the trace of the end effector?"
|
| 36 |
+
return f"You are a robot. Your task is: \"{task}\". <image> Can you predict the trace of the end effector in this image to complete the task?"
|
| 37 |
|
| 38 |
|
| 39 |
def format_trace_points(trajectories: List) -> str:
|
|
|
|
| 50 |
return "\n".join(lines)
|
| 51 |
|
| 52 |
|
| 53 |
+
def center_crop_resize(image, size: Tuple[int, int] = (128, 128)):
|
| 54 |
"""Center crop to square then resize. Requires PIL Image."""
|
| 55 |
from PIL import Image
|
| 56 |
|
|
|
|
| 59 |
left = (w - min_dim) // 2
|
| 60 |
top = (h - min_dim) // 2
|
| 61 |
cropped = image.crop((left, top, left + min_dim, top + min_dim))
|
| 62 |
+
# return cropped.resize(size, Image.Resampling.LANCZOS)
|
| 63 |
+
return cropped
|
| 64 |
|
| 65 |
|
| 66 |
def preprocess_image_for_trace(image_path: str) -> Tuple:
|
|
|
|
| 68 |
from PIL import Image
|
| 69 |
|
| 70 |
img = Image.open(image_path).convert("RGB")
|
| 71 |
+
img = center_crop_resize(img, (128, 128))
|
| 72 |
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
|
| 73 |
img.save(tmp.name)
|
| 74 |
return img, tmp.name
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
def _make_abs_paths(base: Path, files: str) -> str:
|
| 78 |
return f"{(base / files).resolve()}"
|
| 79 |
|
| 80 |
+
|
| 81 |
def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]:
|
| 82 |
# Extract and normalize images and videos
|
| 83 |
images = item.get("image") or []
|
|
|
|
| 139 |
|
| 140 |
return messages
|
| 141 |
|
|
|
|
| 142 |
|
| 143 |
def preprocess_qwen_visual(
|
| 144 |
sources,
|
|
|
|
| 198 |
|
| 199 |
return full_result
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
+
def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
|
| 203 |
+
"""Load the trace model and processor. Returns (success, message)."""
|
| 204 |
+
global _model_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
+
if _model_state["model"] is not None and _model_state["model_id"] == model_id:
|
| 207 |
+
return True, f"Model already loaded: {model_id}"
|
| 208 |
|
| 209 |
+
try:
|
| 210 |
+
from transformers import AutoModelForImageTextToText, AutoProcessor
|
| 211 |
|
| 212 |
+
if _model_state["model"] is not None:
|
| 213 |
+
del _model_state["model"]
|
| 214 |
+
del _model_state["processor"]
|
| 215 |
+
_model_state["model"] = None
|
| 216 |
+
_model_state["processor"] = None
|
| 217 |
+
if torch.cuda.is_available():
|
| 218 |
+
torch.cuda.empty_cache()
|
| 219 |
|
| 220 |
+
logger.info(f"Loading model from {model_id}...")
|
| 221 |
+
load_kwargs = {
|
| 222 |
+
"dtype": torch.bfloat16,
|
| 223 |
+
"device_map": "auto",
|
| 224 |
+
}
|
|
|
|
|
|
|
| 225 |
|
| 226 |
+
model = AutoModelForImageTextToText.from_pretrained(
|
| 227 |
+
model_id,
|
| 228 |
+
**load_kwargs,
|
| 229 |
+
)
|
| 230 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
|
|
|
| 231 |
|
| 232 |
+
_model_state["model"] = model
|
| 233 |
+
_model_state["processor"] = processor
|
| 234 |
+
_model_state["model_id"] = model_id
|
| 235 |
|
| 236 |
+
return True, f"Model loaded: {model_id}"
|
| 237 |
+
except Exception as e:
|
| 238 |
+
logger.exception("Failed to load model")
|
| 239 |
+
return False, f"Error loading model: {str(e)}"
|
| 240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
+
def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
|
| 243 |
+
"""
|
| 244 |
+
Run trace model inference on an image.
|
| 245 |
+
Returns: (prediction_text, overlay_image_path, trace_points_text)
|
| 246 |
+
"""
|
| 247 |
+
success, msg = load_model(model_id)
|
| 248 |
+
if not success:
|
| 249 |
+
return msg, None, ""
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
+
model = _model_state["model"]
|
| 252 |
+
processor = _model_state["processor"]
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
+
if image_path is None or not os.path.exists(image_path):
|
| 255 |
+
return "Please provide a valid image.", None, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
+
try:
|
| 258 |
+
from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
|
| 259 |
|
| 260 |
+
abs_image_path = os.path.abspath(image_path)
|
| 261 |
+
raw_item = {
|
| 262 |
+
"id": "single_inference",
|
| 263 |
+
"image": [abs_image_path],
|
|
|
|
|
|
|
| 264 |
"conversations": [
|
| 265 |
+
{
|
| 266 |
+
"from": "human",
|
| 267 |
+
"value": prompt
|
| 268 |
+
}
|
| 269 |
],
|
| 270 |
+
"data_path": ""
|
| 271 |
}
|
| 272 |
|
| 273 |
+
# Preprocessing using internal method
|
| 274 |
+
processed = preprocess_qwen_visual([raw_item], processor, add_gen_prompt=True)
|
| 275 |
|
| 276 |
+
# Prepare inputs - passing only what's necessary as per the new method
|
| 277 |
+
inputs = {"input_ids": processed["input_ids"].to(model.device)}
|
| 278 |
+
if "pixel_values" in processed:
|
| 279 |
+
inputs["pixel_values"] = processed["pixel_values"].to(model.device)
|
| 280 |
+
if "image_grid_thw" in processed:
|
| 281 |
+
inputs["image_grid_thw"] = processed["image_grid_thw"].to(model.device)
|
| 282 |
|
| 283 |
+
# Generate prediction
|
| 284 |
+
with torch.no_grad():
|
| 285 |
+
generated_ids = model.generate(
|
| 286 |
+
**inputs,
|
| 287 |
+
max_new_tokens=512,
|
| 288 |
+
do_sample=False,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Trim prompt tokens
|
| 292 |
+
trimmed = generated_ids[:, inputs["input_ids"].shape[1]:]
|
| 293 |
+
|
| 294 |
+
# Decode
|
| 295 |
+
prediction = processor.tokenizer.batch_decode(
|
| 296 |
+
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 297 |
+
)[0]
|
| 298 |
+
|
| 299 |
+
trajectory = extract_trajectory_from_text(prediction)
|
| 300 |
+
|
| 301 |
+
trace_points_text = ""
|
| 302 |
overlay_path = None
|
| 303 |
+
|
| 304 |
+
if trajectory:
|
| 305 |
+
trace_points_text = format_trace_points(trajectory)
|
| 306 |
+
|
| 307 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
|
| 308 |
+
overlay_path = f.name
|
| 309 |
+
|
| 310 |
+
visualize_trajectory_on_image(
|
| 311 |
+
trajectory=trajectory,
|
| 312 |
+
image_path=abs_image_path,
|
| 313 |
+
output_path=overlay_path,
|
| 314 |
+
normalized=True
|
| 315 |
+
)
|
| 316 |
+
else:
|
| 317 |
+
trace_points_text = "No trajectory points extracted."
|
| 318 |
+
|
| 319 |
+
return prediction, overlay_path, trace_points_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
except Exception as e:
|
| 322 |
+
logger.exception("Inference failed")
|
| 323 |
+
return f"Error: {str(e)}", None, ""
|