Anthony Liang committed on
Commit
8c5e6cc
·
1 Parent(s): 5e40307
Files changed (4) hide show
  1. app.py +13 -256
  2. eval_server.py +24 -8
  3. predict_trace.py +2 -2
  4. trace_inference.py +247 -0
app.py CHANGED
@@ -16,30 +16,20 @@ from typing import List, Optional, Tuple
16
 
17
  import gradio as gr
18
  import requests
19
- import numpy as np
20
- import torch
21
- from PIL import Image
22
- from transformers import AutoModelForImageTextToText, AutoProcessor
23
 
24
- from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
25
-
26
- try:
27
- from qwen_vl_utils import process_vision_info
28
- except ImportError:
29
- process_vision_info = None
 
 
 
 
30
 
31
  logger = logging.getLogger(__name__)
32
 
33
- # Default model path (Hugging Face Hub)
34
- DEFAULT_MODEL_ID = "mihirgrao/trace-model"
35
-
36
- # Trace format instruction (always appended)
37
- TRACE_FORMAT = (
38
- "Predict the trajectory or trace in this image. "
39
- "Output the coordinates as a list of [x, y] pairs, e.g. [[0.1, 0.2], [0.3, 0.4], ...]. "
40
- "Use normalized coordinates between 0 and 1."
41
- )
42
-
43
  # Global server state (eval server mode)
44
  _server_state = {"server_url": None, "base_url": "http://localhost"}
45
 
@@ -53,16 +43,19 @@ def discover_available_models(
53
  start_port, end_port = port_range
54
  for port in range(start_port, end_port + 1):
55
  server_url = f"{base_url.rstrip('/')}:{port}"
 
56
  try:
57
  r = requests.get(f"{server_url}/health", timeout=2.0)
58
  if r.status_code == 200:
59
  try:
60
  info = requests.get(f"{server_url}/model_info", timeout=2.0).json()
 
61
  name = info.get("model_id", f"Trace @ port {port}")
62
  except Exception:
63
  name = f"Trace @ port {port}"
64
  available.append((server_url, name))
65
  except requests.exceptions.RequestException:
 
66
  continue
67
  return available
68
 
@@ -108,9 +101,6 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[
108
  return f"Error connecting to server: {str(e)}", None, None
109
 
110
 
111
- PREPROCESS_SIZE = (128, 128)
112
-
113
-
114
  def run_inference_via_server(
115
  image_path: str,
116
  instruction: str,
@@ -157,239 +147,6 @@ def run_inference_via_server(
157
  os.unlink(preprocessed_path)
158
  except Exception:
159
  pass
160
- return prediction, overlay_path, trace_points_text
161
-
162
-
163
- def center_crop_resize(
164
- image: "Image.Image",
165
- size: Tuple[int, int] = PREPROCESS_SIZE,
166
- ) -> "Image.Image":
167
- """Center crop to square then resize to size (default 128x128)."""
168
- w, h = image.size
169
- min_dim = min(w, h)
170
- left = (w - min_dim) // 2
171
- top = (h - min_dim) // 2
172
- cropped = image.crop((left, top, left + min_dim, top + min_dim))
173
- return cropped.resize(size, Image.Resampling.LANCZOS)
174
-
175
-
176
- def preprocess_image_for_trace(image_path: str) -> Tuple["Image.Image", Optional[str]]:
177
- """
178
- Load image, center crop and resize to 128x128.
179
- Returns (preprocessed PIL Image, path to temp file for downstream use).
180
- """
181
- img = Image.open(image_path).convert("RGB")
182
- img = center_crop_resize(img, PREPROCESS_SIZE)
183
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
184
- img.save(tmp.name)
185
- return img, tmp.name
186
-
187
-
188
- def build_prompt(instruction: str = "") -> str:
189
- """Build the full prompt from task instruction + trace format."""
190
- if instruction.strip():
191
- return f"Task: {instruction.strip()}\n\n{TRACE_FORMAT}"
192
- return TRACE_FORMAT
193
-
194
- # Global model state (lazy loading)
195
- _model_state = {
196
- "model": None,
197
- "processor": None,
198
- "model_id": None,
199
- }
200
-
201
-
202
- def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
203
- """Load the trace model and processor. Returns (success, message)."""
204
- global _model_state
205
-
206
- if _model_state["model"] is not None and _model_state["model_id"] == model_id:
207
- return True, f"Model already loaded: {model_id}"
208
-
209
- try:
210
- # Clear previous model
211
- if _model_state["model"] is not None:
212
- del _model_state["model"]
213
- del _model_state["processor"]
214
- _model_state["model"] = None
215
- _model_state["processor"] = None
216
- if torch.cuda.is_available():
217
- torch.cuda.empty_cache()
218
-
219
- # Load model with optional flash attention
220
- load_kwargs = {
221
- "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
222
- "device_map": "auto" if torch.cuda.is_available() else None,
223
- }
224
- try:
225
- if torch.cuda.is_available():
226
- load_kwargs["attn_implementation"] = "flash_attention_2"
227
- model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
228
- except (ValueError, ImportError):
229
- load_kwargs.pop("attn_implementation", None)
230
- model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
231
- processor = AutoProcessor.from_pretrained(model_id)
232
-
233
- _model_state["model"] = model
234
- _model_state["processor"] = processor
235
- _model_state["model_id"] = model_id
236
-
237
- return True, f"Model loaded: {model_id}"
238
- except Exception as e:
239
- logger.exception("Failed to load model")
240
- return False, f"Error loading model: {str(e)}"
241
-
242
-
243
- def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
244
- """
245
- Run trace model inference on an image.
246
-
247
- Returns:
248
- (prediction_text, overlay_image_path, trace_points_text)
249
- """
250
- success, msg = load_model(model_id)
251
- if not success:
252
- return msg, None, ""
253
-
254
- model = _model_state["model"]
255
- processor = _model_state["processor"]
256
-
257
- if image_path is None or not os.path.exists(image_path):
258
- return "Please provide a valid image.", None, ""
259
-
260
- preprocessed_path = None
261
- try:
262
- # Preprocess: center crop and resize to 128x128
263
- _, preprocessed_path = preprocess_image_for_trace(image_path)
264
- image_uri = f"file://{os.path.abspath(preprocessed_path)}"
265
-
266
- messages = [
267
- {
268
- "role": "user",
269
- "content": [
270
- {"type": "image", "image": image_uri},
271
- {"type": "text", "text": prompt},
272
- ],
273
- }
274
- ]
275
-
276
- # Apply chat template
277
- text = processor.apply_chat_template(
278
- messages,
279
- tokenize=False,
280
- add_generation_prompt=True,
281
- )
282
-
283
- # Process vision info
284
- if process_vision_info is not None:
285
- process_kwargs = {"return_video_kwargs": True, "return_video_metadata": True}
286
- if hasattr(processor, "image_processor") and hasattr(
287
- processor.image_processor, "patch_size"
288
- ):
289
- process_kwargs["image_patch_size"] = processor.image_processor.patch_size
290
-
291
- image_inputs, video_inputs, video_kwargs = process_vision_info(
292
- messages, **process_kwargs
293
- )
294
- else:
295
- # Fallback: load image directly and pass to processor
296
- pil_image = Image.open(image_path).convert("RGB")
297
- image_inputs = [pil_image]
298
- video_inputs = None
299
- video_kwargs = {}
300
-
301
- # Prepare inputs
302
- processor_kwargs = {
303
- "text": [text],
304
- "images": image_inputs,
305
- "padding": True,
306
- "return_tensors": "pt",
307
- "do_resize": False,
308
- }
309
- if video_inputs is not None and len(video_inputs) > 0:
310
- if isinstance(video_inputs[0], tuple):
311
- videos, video_metadatas = zip(*video_inputs)
312
- processor_kwargs["videos"] = list(videos)
313
- processor_kwargs["video_metadata"] = list(video_metadatas)
314
- else:
315
- processor_kwargs["videos"] = video_inputs
316
- if video_kwargs:
317
- processor_kwargs.update(video_kwargs)
318
-
319
- inputs = processor(**processor_kwargs)
320
- inputs = {k: v.to(model.device) for k, v in inputs.items() if hasattr(v, "to")}
321
-
322
- # Generate
323
- with torch.no_grad():
324
- generated_ids = model.generate(
325
- **inputs,
326
- max_new_tokens=1024,
327
- do_sample=False,
328
- )
329
-
330
- # Decode output
331
- input_ids = inputs["input_ids"]
332
- generated_ids_trimmed = [
333
- out[len(inp) :] for inp, out in zip(input_ids, generated_ids)
334
- ]
335
- prediction = processor.batch_decode(
336
- generated_ids_trimmed,
337
- skip_special_tokens=True,
338
- clean_up_tokenization_spaces=False,
339
- )[0]
340
-
341
- # Extract trajectory and visualize
342
- trajectories = extract_trajectory_from_text(prediction)
343
- trace_points_text = format_trace_points(trajectories)
344
-
345
- overlay_path = None
346
- if trajectories and len(trajectories) >= 2:
347
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
348
- overlay_path = f.name
349
- # Overlay on preprocessed (128x128) image
350
- img_arr = visualize_trajectory_on_image(
351
- trajectory=trajectories,
352
- image_path=preprocessed_path,
353
- output_path=overlay_path,
354
- normalized=True,
355
- )
356
- if img_arr is None:
357
- visualize_trajectory_on_image(
358
- trajectory=trajectories,
359
- image_path=preprocessed_path,
360
- output_path=overlay_path,
361
- normalized=False,
362
- )
363
-
364
- return prediction, overlay_path, trace_points_text
365
-
366
- except Exception as e:
367
- logger.exception("Inference failed")
368
- return f"Error: {str(e)}", None, ""
369
- finally:
370
- if preprocessed_path and os.path.exists(preprocessed_path):
371
- try:
372
- os.unlink(preprocessed_path)
373
- except Exception:
374
- pass
375
-
376
-
377
- def format_trace_points(trajectories) -> str:
378
- """Format trajectory points for display. trajectories is List[List[float]]."""
379
- if not trajectories:
380
- return "No trajectory points extracted."
381
-
382
- lines = ["## Predicted Trace Points\n"]
383
- for i, pt in enumerate(trajectories):
384
- if isinstance(pt, (list, tuple)) and len(pt) >= 2:
385
- x, y = pt[0], pt[1]
386
- lines.append(f"- Point {i + 1}: `[{x:.4f}, {y:.4f}]`")
387
- else:
388
- lines.append(f"- Point {i + 1}: `{pt}`")
389
-
390
- return "\n".join(lines)
391
-
392
-
393
  # --- Gradio UI ---
394
  try:
395
  demo = gr.Blocks(title="Trace Model Visualizer", theme=gr.themes.Soft())
 
16
 
17
  import gradio as gr
18
  import requests
 
 
 
 
19
 
20
+ from trace_inference import (
21
+ DEFAULT_MODEL_ID,
22
+ TRACE_FORMAT,
23
+ build_prompt,
24
+ format_trace_points,
25
+ load_model,
26
+ preprocess_image_for_trace,
27
+ run_inference,
28
+ )
29
+ from trajectory_viz import visualize_trajectory_on_image
30
 
31
  logger = logging.getLogger(__name__)
32
 
 
 
 
 
 
 
 
 
 
 
33
  # Global server state (eval server mode)
34
  _server_state = {"server_url": None, "base_url": "http://localhost"}
35
 
 
43
  start_port, end_port = port_range
44
  for port in range(start_port, end_port + 1):
45
  server_url = f"{base_url.rstrip('/')}:{port}"
46
+ print(f"Checking {server_url}/health")
47
  try:
48
  r = requests.get(f"{server_url}/health", timeout=2.0)
49
  if r.status_code == 200:
50
  try:
51
  info = requests.get(f"{server_url}/model_info", timeout=2.0).json()
52
+ print(info)
53
  name = info.get("model_id", f"Trace @ port {port}")
54
  except Exception:
55
  name = f"Trace @ port {port}"
56
  available.append((server_url, name))
57
  except requests.exceptions.RequestException:
58
+ print(f"Error checking {server_url}/health")
59
  continue
60
  return available
61
 
 
101
  return f"Error connecting to server: {str(e)}", None, None
102
 
103
 
 
 
 
104
  def run_inference_via_server(
105
  image_path: str,
106
  instruction: str,
 
147
  os.unlink(preprocessed_path)
148
  except Exception:
149
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  # --- Gradio UI ---
151
  try:
152
  demo = gr.Blocks(title="Trace Model Visualizer", theme=gr.themes.Soft())
eval_server.py CHANGED
@@ -3,7 +3,7 @@
3
  FastAPI server for Trace Model inference.
4
 
5
  Usage:
6
- python eval_server.py --model-id mihirgrao/trace-model --port 8001
7
 
8
  Endpoints:
9
  POST /predict - Single image + instruction
@@ -14,8 +14,10 @@ Endpoints:
14
 
15
  import argparse
16
  import base64
 
17
  import logging
18
  import os
 
19
  import tempfile
20
  import time
21
  from concurrent.futures import ThreadPoolExecutor
@@ -26,7 +28,13 @@ import uvicorn
26
  from fastapi import FastAPI, Request
27
  from fastapi.middleware.cors import CORSMiddleware
28
 
29
- from app import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference
 
 
 
 
 
 
30
  from trajectory_viz import extract_trajectory_from_text
31
 
32
  logger = logging.getLogger(__name__)
@@ -72,13 +80,23 @@ class TraceEvalServer:
72
  temp_file_path = None
73
  if image_path is None:
74
  try:
75
- image_bytes = base64.b64decode(image_base64)
 
 
 
 
 
 
 
 
 
 
76
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
77
- f.write(image_bytes)
78
  image_path = f.name
79
  temp_file_path = image_path
80
  except Exception as e:
81
- return {"error": f"Invalid base64 image: {e}"}
82
 
83
  try:
84
  prompt = build_prompt(instruction)
@@ -138,9 +156,7 @@ class TraceEvalServer:
138
  def get_model_info(self) -> Dict[str, Any]:
139
  """Get model information."""
140
  try:
141
- from app import _model_state
142
-
143
- model = _model_state.get("model")
144
  if model is None:
145
  return {"model_id": self.model_id, "status": "not_loaded"}
146
 
 
3
  FastAPI server for Trace Model inference.
4
 
5
  Usage:
6
+ python eval_server.py --model-id mihirgrao/trace-model --port 8000
7
 
8
  Endpoints:
9
  POST /predict - Single image + instruction
 
14
 
15
  import argparse
16
  import base64
17
+ import io
18
  import logging
19
  import os
20
+ import re
21
  import tempfile
22
  import time
23
  from concurrent.futures import ThreadPoolExecutor
 
28
  from fastapi import FastAPI, Request
29
  from fastapi.middleware.cors import CORSMiddleware
30
 
31
+ from trace_inference import (
32
+ DEFAULT_MODEL_ID,
33
+ build_prompt,
34
+ load_model,
35
+ run_inference,
36
+ )
37
+ from trace_inference import _model_state as _trace_model_state
38
  from trajectory_viz import extract_trajectory_from_text
39
 
40
  logger = logging.getLogger(__name__)
 
80
  temp_file_path = None
81
  if image_path is None:
82
  try:
83
+ # Strip data URL prefix if present (e.g. "data:image/png;base64,")
84
+ b64_str = image_base64.strip()
85
+ if b64_str.startswith("data:"):
86
+ match = re.match(r"data:image/[^;]+;base64,(.+)", b64_str, re.DOTALL)
87
+ if match:
88
+ b64_str = match.group(1)
89
+ image_bytes = base64.b64decode(b64_str, validate=False)
90
+ # Load via BytesIO to validate and get proper format, then save
91
+ from PIL import Image
92
+
93
+ img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
94
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
95
+ img.save(f.name, format="PNG")
96
  image_path = f.name
97
  temp_file_path = image_path
98
  except Exception as e:
99
+ return {"error": f"Invalid image data: {e}"}
100
 
101
  try:
102
  prompt = build_prompt(instruction)
 
156
  def get_model_info(self) -> Dict[str, Any]:
157
  """Get model information."""
158
  try:
159
+ model = _trace_model_state.get("model")
 
 
160
  if model is None:
161
  return {"model_id": self.model_id, "status": "not_loaded"}
162
 
predict_trace.py CHANGED
@@ -2,7 +2,7 @@
2
  """
3
  CLI script to predict trace on an image using the trace model.
4
 
5
- Reuses load_model and run_inference from app.
6
  """
7
 
8
  import argparse
@@ -10,7 +10,7 @@ import os
10
  import shutil
11
  import sys
12
 
13
- from app import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference
14
 
15
 
16
  def main():
 
2
  """
3
  CLI script to predict trace on an image using the trace model.
4
 
5
+ Reuses load_model and run_inference from trace_inference.
6
  """
7
 
8
  import argparse
 
10
  import shutil
11
  import sys
12
 
13
+ from trace_inference import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference
14
 
15
 
16
  def main():
trace_inference.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Shared trace model inference logic.

This module has minimal top-level imports so eval_server can import
DEFAULT_MODEL_ID and build_prompt without pulling in torch/transformers.
Heavy imports are done lazily inside load_model and run_inference.
"""

import logging
import os
import tempfile
from typing import List, Optional, Tuple

logger = logging.getLogger(__name__)

# Constants (no heavy deps)
# Default Hugging Face Hub model id used when callers do not supply one.
DEFAULT_MODEL_ID = "mihirgrao/trace-model"
# Instruction suffix appended to every prompt; see build_prompt().
TRACE_FORMAT = (
    "Predict the trajectory or trace in this image. "
    "Output the coordinates as a list of [x, y] pairs, e.g. [[0.1, 0.2], [0.3, 0.4], ...]. "
    "Use normalized coordinates between 0 and 1."
)
# Target (width, height) used by center_crop_resize / preprocess_image_for_trace.
PREPROCESS_SIZE = (128, 128)

# Global model state — populated lazily by load_model, read by run_inference.
_model_state = {
    "model": None,      # loaded transformers model, or None when not loaded
    "processor": None,  # matching AutoProcessor, or None when not loaded
    "model_id": None,   # id of the currently loaded model
}
31
+
32
+
33
def build_prompt(instruction: str = "") -> str:
    """Combine an optional task instruction with the trace-format suffix.

    With a non-empty instruction the prompt is "Task: <instruction>" followed
    by TRACE_FORMAT; otherwise TRACE_FORMAT alone is returned.
    """
    task = instruction.strip()
    if not task:
        return TRACE_FORMAT
    return f"Task: {task}\n\n{TRACE_FORMAT}"
38
+
39
+
40
def format_trace_points(trajectories: List) -> str:
    """Render trajectory points as a markdown bullet list.

    Each [x, y] pair is shown with 4 decimal places; anything that is not a
    2+-element list/tuple is shown verbatim. Returns a fixed message when
    the input is empty or None.
    """
    if not trajectories:
        return "No trajectory points extracted."
    out = ["## Predicted Trace Points\n"]
    for idx, point in enumerate(trajectories, start=1):
        if isinstance(point, (list, tuple)) and len(point) >= 2:
            out.append(f"- Point {idx}: `[{point[0]:.4f}, {point[1]:.4f}]`")
        else:
            out.append(f"- Point {idx}: `{point}`")
    return "\n".join(out)
52
+
53
+
54
def center_crop_resize(image, size: Tuple[int, int] = PREPROCESS_SIZE):
    """Center-crop *image* to a square, then resize it to *size*.

    Expects a PIL Image; PIL is imported lazily so importing this module
    stays free of heavy dependencies.
    """
    from PIL import Image

    width, height = image.size
    side = min(width, height)
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    square = image.crop((x0, y0, x0 + side, y0 + side))
    return square.resize(size, Image.Resampling.LANCZOS)
64
+
65
+
66
def preprocess_image_for_trace(image_path: str) -> Tuple:
    """Load an image, center crop, and resize to 128x128.

    Returns (PIL Image, temp_path) where temp_path is a PNG on disk that
    the caller is responsible for deleting when done.
    """
    from PIL import Image

    img = Image.open(image_path).convert("RGB")
    img = center_crop_resize(img, PREPROCESS_SIZE)
    # Use mkstemp + close so the file handle is released before PIL writes
    # to the path. The previous NamedTemporaryFile(delete=False) left its
    # descriptor open forever (fd leak), and an open handle also breaks
    # re-opening/saving the same path on Windows.
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    img.save(tmp_path)
    return img, tmp_path
75
+
76
+
77
def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
    """Load the trace model and processor. Returns (success, message).

    Caches the loaded objects in the module-level ``_model_state``; calling
    again with the same ``model_id`` is a cheap no-op. torch/transformers
    are imported lazily so merely importing this module stays light.
    """
    global _model_state

    # Fast path: requested model is already resident.
    if _model_state["model"] is not None and _model_state["model_id"] == model_id:
        return True, f"Model already loaded: {model_id}"

    try:
        import torch
        from transformers import AutoModelForImageTextToText, AutoProcessor

        # Evict any previously loaded model and release its GPU memory
        # before allocating the new one.
        if _model_state["model"] is not None:
            del _model_state["model"]
            del _model_state["processor"]
            _model_state["model"] = None
            _model_state["processor"] = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # bf16 + device_map="auto" on GPU; fp32 on CPU.
        load_kwargs = {
            "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            "device_map": "auto" if torch.cuda.is_available() else None,
        }
        # Prefer flash-attention on GPU, falling back to the default
        # attention implementation when it is unavailable (ValueError /
        # ImportError raised by from_pretrained).
        try:
            if torch.cuda.is_available():
                load_kwargs["attn_implementation"] = "flash_attention_2"
            model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
        except (ValueError, ImportError):
            load_kwargs.pop("attn_implementation", None)
            model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
        processor = AutoProcessor.from_pretrained(model_id)

        _model_state["model"] = model
        _model_state["processor"] = processor
        _model_state["model_id"] = model_id

        return True, f"Model loaded: {model_id}"
    except Exception as e:
        # Broad catch is deliberate: callers get a (False, message) result
        # instead of an exception; full traceback goes to the log.
        logger.exception("Failed to load model")
        return False, f"Error loading model: {str(e)}"
117
+
118
+
119
def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
    """
    Run trace model inference on an image.

    The image is center-cropped/resized to 128x128 before being fed to the
    model; any trajectory parsed from the generated text is overlaid on the
    preprocessed image.

    Returns: (prediction_text, overlay_image_path, trace_points_text)
    """
    success, msg = load_model(model_id)
    if not success:
        return msg, None, ""

    model = _model_state["model"]
    processor = _model_state["processor"]

    if image_path is None or not os.path.exists(image_path):
        return "Please provide a valid image.", None, ""

    try:
        from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image

        # qwen_vl_utils is optional; without it we fall back to handing the
        # PIL image straight to the processor.
        try:
            from qwen_vl_utils import process_vision_info
        except ImportError:
            process_vision_info = None

        preprocessed_path = None
        try:
            # Preprocess: center crop and resize to 128x128.
            _, preprocessed_path = preprocess_image_for_trace(image_path)
            image_uri = f"file://{os.path.abspath(preprocessed_path)}"

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_uri},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            if process_vision_info is not None:
                process_kwargs = {"return_video_kwargs": True, "return_video_metadata": True}
                if hasattr(processor, "image_processor") and hasattr(
                    processor.image_processor, "patch_size"
                ):
                    process_kwargs["image_patch_size"] = processor.image_processor.patch_size
                image_inputs, video_inputs, video_kwargs = process_vision_info(
                    messages, **process_kwargs
                )
            else:
                from PIL import Image

                # FIX: use the preprocessed (128x128) image here, matching
                # the image_uri branch above. Previously this branch fed the
                # original, un-preprocessed image to the processor, so the
                # two branches saw different pixels (do_resize=False below
                # means the processor would not normalize the size either).
                pil_image = Image.open(preprocessed_path).convert("RGB")
                image_inputs = [pil_image]
                video_inputs = None
                video_kwargs = {}

            processor_kwargs = {
                "text": [text],
                "images": image_inputs,
                "padding": True,
                "return_tensors": "pt",
                # Image is already preprocessed; don't let the processor resize.
                "do_resize": False,
            }
            if video_inputs is not None and len(video_inputs) > 0:
                if isinstance(video_inputs[0], tuple):
                    videos, video_metadatas = zip(*video_inputs)
                    processor_kwargs["videos"] = list(videos)
                    processor_kwargs["video_metadata"] = list(video_metadatas)
                else:
                    processor_kwargs["videos"] = video_inputs
                if video_kwargs:
                    processor_kwargs.update(video_kwargs)

            import torch

            inputs = processor(**processor_kwargs)
            inputs = {k: v.to(model.device) for k, v in inputs.items() if hasattr(v, "to")}

            # Deterministic (greedy) decoding for reproducible traces.
            with torch.no_grad():
                generated_ids = model.generate(
                    **inputs, max_new_tokens=1024, do_sample=False
                )

            # Strip the prompt tokens so only the generated continuation is decoded.
            input_ids = inputs["input_ids"]
            generated_ids_trimmed = [
                out[len(inp):] for inp, out in zip(input_ids, generated_ids)
            ]
            prediction = processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]

            trajectories = extract_trajectory_from_text(prediction)
            trace_points_text = format_trace_points(trajectories)

            overlay_path = None
            if trajectories and len(trajectories) >= 2:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
                    overlay_path = f.name
                # Try normalized coordinates first; retry as pixel
                # coordinates if the visualizer rejects them.
                img_arr = visualize_trajectory_on_image(
                    trajectory=trajectories,
                    image_path=preprocessed_path,
                    output_path=overlay_path,
                    normalized=True,
                )
                if img_arr is None:
                    visualize_trajectory_on_image(
                        trajectory=trajectories,
                        image_path=preprocessed_path,
                        output_path=overlay_path,
                        normalized=False,
                    )

            return prediction, overlay_path, trace_points_text

        finally:
            # Always remove the temporary preprocessed image.
            if preprocessed_path and os.path.exists(preprocessed_path):
                try:
                    os.unlink(preprocessed_path)
                except Exception:
                    pass

    except Exception as e:
        # Best-effort API: return an error string instead of raising; the
        # full traceback is preserved in the log.
        logger.exception("Inference failed")
        return f"Error: {str(e)}", None, ""