Anthony Liang committed on
Commit
5e40307
·
1 Parent(s): 7c21061

added eval server code

Browse files
Files changed (5) hide show
  1. README.md +13 -0
  2. app.py +307 -23
  3. eval_server.py +280 -0
  4. predict_trace.py +1 -1
  5. requirements.txt +3 -0
README.md CHANGED
@@ -41,6 +41,19 @@ Then open the URL (default: http://localhost:7860).
41
  3. Click **Run Inference**
42
  4. View the overlay image and predicted trace points
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ### CLI script
45
 
46
  ```bash
 
41
  3. Click **Run Inference**
42
  4. View the overlay image and predicted trace points
43
 
44
+ ### Eval server
45
+
46
+ Run a FastAPI server for batch evaluation (e.g. from scripts or the Gradio app):
47
+
48
+ ```bash
49
+ python eval_server.py --model-id mihirgrao/trace-model --port 8001
50
+ ```
51
+
52
+ Endpoints:
53
+ - `POST /predict` – single image + instruction
54
+ - `POST /predict_batch` – batch of `{image_path?|image_base64?, instruction}` samples
55
+ - `GET /health`, `GET /model_info`
56
+
57
  ### CLI script
58
 
59
  ```bash
app.py CHANGED
@@ -8,12 +8,14 @@ overlays the trace on the image, and displays the predicted coordinates.
8
  Model: https://huggingface.co/mihirgrao/trace-model
9
  """
10
 
 
11
  import os
12
  import tempfile
13
  import logging
14
- from typing import Optional, Tuple
15
 
16
  import gradio as gr
 
17
  import numpy as np
18
  import torch
19
  from PIL import Image
@@ -38,6 +40,150 @@ TRACE_FORMAT = (
38
  "Use normalized coordinates between 0 and 1."
39
  )
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def build_prompt(instruction: str = "") -> str:
43
  """Build the full prompt from task instruction + trace format."""
@@ -111,12 +257,11 @@ def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Opt
111
  if image_path is None or not os.path.exists(image_path):
112
  return "Please provide a valid image.", None, ""
113
 
 
114
  try:
115
- # Ensure file:// format for qwen_vl_utils
116
- if not image_path.startswith("file://") and not image_path.startswith("http"):
117
- image_uri = f"file://{os.path.abspath(image_path)}"
118
- else:
119
- image_uri = image_path
120
 
121
  messages = [
122
  {
@@ -201,18 +346,17 @@ def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Opt
201
  if trajectories and len(trajectories) >= 2:
202
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
203
  overlay_path = f.name
204
- # Try normalized first (common for VLMs)
205
  img_arr = visualize_trajectory_on_image(
206
  trajectory=trajectories,
207
- image_path=image_path,
208
  output_path=overlay_path,
209
  normalized=True,
210
  )
211
  if img_arr is None:
212
- # Fallback: pixel coordinates
213
  visualize_trajectory_on_image(
214
  trajectory=trajectories,
215
- image_path=image_path,
216
  output_path=overlay_path,
217
  normalized=False,
218
  )
@@ -222,6 +366,12 @@ def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Opt
222
  except Exception as e:
223
  logger.exception("Inference failed")
224
  return f"Error: {str(e)}", None, ""
 
 
 
 
 
 
225
 
226
 
227
  def format_trace_points(trajectories) -> str:
@@ -251,12 +401,109 @@ with demo:
251
  """
252
  # Trace Model Visualizer
253
 
254
- Upload an image to predict the trajectory/trace using [mihirgrao/trace-model](https://huggingface.co/mihirgrao/trace-model).
255
 
256
- The model predicts coordinate points; they are overlaid on the image (green → red gradient) and listed below.
257
  """
258
  )
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  with gr.Row():
261
  with gr.Column(scale=1):
262
  image_input = gr.Image(
@@ -265,21 +512,26 @@ with demo:
265
  height=400,
266
  )
267
  instruction_input = gr.Textbox(
268
- label="Task / Language instruction",
269
- placeholder="e.g. Pick up the red block and place it on the table",
270
  value="",
271
- lines=2,
272
- info="Describe the task. The model will predict the trace for this instruction.",
273
  )
 
274
  model_id_input = gr.Textbox(
275
  label="Model ID",
276
  value=DEFAULT_MODEL_ID,
277
- info="Hugging Face model ID",
278
  )
279
  load_model_btn = gr.Button("Load Model", variant="secondary")
280
  run_btn = gr.Button("Run Inference", variant="primary")
281
 
282
  with gr.Column(scale=1):
 
 
 
 
283
  overlay_output = gr.Image(
284
  label="Image with Trace Overlay",
285
  height=400,
@@ -292,19 +544,35 @@ with demo:
292
  label="Extracted Trace Points",
293
  )
294
 
295
- status_md = gr.Markdown("Click 'Load Model' to load the trace model, then 'Run Inference' on an image.")
 
 
296
 
297
  def on_load_model(model_id: str):
298
  _, msg = load_model(model_id)
299
  return f"**Status:** {msg}"
300
 
301
- def on_run_inference(image_path, instruction, model_id):
302
  if image_path is None:
303
- return "Please upload an image first.", None, "", "**Status:** Please upload an image."
 
 
 
 
 
 
304
  prompt = build_prompt(instruction)
305
- pred, overlay_path, trace_text = run_inference(image_path, prompt, model_id)
 
 
 
 
 
 
 
 
306
  status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
307
- return pred, overlay_path, trace_text, status
308
 
309
  load_model_btn.click(
310
  fn=on_load_model,
@@ -312,10 +580,26 @@ with demo:
312
  outputs=[status_md],
313
  )
314
 
 
 
 
 
 
 
 
 
 
 
315
  run_btn.click(
316
  fn=on_run_inference,
317
- inputs=[image_input, instruction_input, model_id_input],
 
 
 
 
 
318
  outputs=[
 
319
  prediction_output,
320
  overlay_output,
321
  trace_points_output,
 
8
  Model: https://huggingface.co/mihirgrao/trace-model
9
  """
10
 
11
+ import base64
12
  import os
13
  import tempfile
14
  import logging
15
+ from typing import List, Optional, Tuple
16
 
17
  import gradio as gr
18
+ import requests
19
  import numpy as np
20
  import torch
21
  from PIL import Image
 
40
  "Use normalized coordinates between 0 and 1."
41
  )
42
 
43
+ # Global server state (eval server mode)
44
+ _server_state = {"server_url": None, "base_url": "http://localhost"}
45
+
46
+
47
+ def discover_available_models(
48
+ base_url: str = "http://localhost",
49
+ port_range: Tuple[int, int] = (8000, 8010),
50
+ ) -> List[Tuple[str, str]]:
51
+ """Discover trace eval servers by pinging /health on ports. Returns [(server_url, model_name), ...]."""
52
+ available = []
53
+ start_port, end_port = port_range
54
+ for port in range(start_port, end_port + 1):
55
+ server_url = f"{base_url.rstrip('/')}:{port}"
56
+ try:
57
+ r = requests.get(f"{server_url}/health", timeout=2.0)
58
+ if r.status_code == 200:
59
+ try:
60
+ info = requests.get(f"{server_url}/model_info", timeout=2.0).json()
61
+ name = info.get("model_id", f"Trace @ port {port}")
62
+ except Exception:
63
+ name = f"Trace @ port {port}"
64
+ available.append((server_url, name))
65
+ except requests.exceptions.RequestException:
66
+ continue
67
+ return available
68
+
69
+
70
+ def get_model_info_for_url(server_url: str) -> Optional[str]:
71
+ """Get formatted model info for a trace eval server."""
72
+ if not server_url:
73
+ return None
74
+ try:
75
+ r = requests.get(f"{server_url.rstrip('/')}/model_info", timeout=5.0)
76
+ if r.status_code == 200:
77
+ return format_trace_model_info(r.json())
78
+ except Exception as e:
79
+ logger.warning(f"Could not fetch model info: {e}")
80
+ return None
81
+
82
+
83
+ def format_trace_model_info(info: dict) -> str:
84
+ """Format trace model info as markdown."""
85
+ lines = ["## Model Information\n"]
86
+ lines.append(f"**Model ID:** `{info.get('model_id', 'Unknown')}`\n")
87
+ if "model_class" in info:
88
+ lines.append(f"**Model Class:** `{info.get('model_class')}`\n")
89
+ if "total_parameters" in info:
90
+ lines.append(f"**Parameters:** {info.get('total_parameters', 0):,}\n")
91
+ if "error" in info:
92
+ lines.append(f"**Error:** {info['error']}\n")
93
+ return "".join(lines)
94
+
95
+
96
+ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
97
+ """Check trace eval server health. Returns (status_msg, health_data, model_info_text)."""
98
+ if not server_url:
99
+ return "Please provide a server URL.", None, None
100
+ try:
101
+ r = requests.get(f"{server_url.rstrip('/')}/health", timeout=5.0)
102
+ r.raise_for_status()
103
+ data = r.json()
104
+ info = get_model_info_for_url(server_url)
105
+ _server_state["server_url"] = server_url
106
+ return f"Server connected: {data.get('status', 'ok')}", data, info
107
+ except requests.exceptions.RequestException as e:
108
+ return f"Error connecting to server: {str(e)}", None, None
109
+
110
+
111
+ PREPROCESS_SIZE = (128, 128)
112
+
113
+
114
+ def run_inference_via_server(
115
+ image_path: str,
116
+ instruction: str,
117
+ server_url: str,
118
+ ) -> Tuple[str, Optional[str], str]:
119
+ """Run inference via trace eval server. Returns (prediction, overlay_path, trace_points_text)."""
120
+ with open(image_path, "rb") as f:
121
+ image_b64 = base64.b64encode(f.read()).decode("utf-8")
122
+ r = requests.post(
123
+ f"{server_url.rstrip('/')}/predict",
124
+ json={"image_base64": image_b64, "instruction": instruction},
125
+ timeout=120.0,
126
+ )
127
+ r.raise_for_status()
128
+ data = r.json()
129
+ if "error" in data:
130
+ return data["error"], None, ""
131
+ prediction = data.get("prediction", "")
132
+ trajectory = data.get("trajectory", [])
133
+ trace_points_text = format_trace_points(trajectory)
134
+
135
+ overlay_path = None
136
+ if trajectory and len(trajectory) >= 2:
137
+ _, preprocessed_path = preprocess_image_for_trace(image_path)
138
+ try:
139
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
140
+ overlay_path = f.name
141
+ img_arr = visualize_trajectory_on_image(
142
+ trajectory=trajectory,
143
+ image_path=preprocessed_path,
144
+ output_path=overlay_path,
145
+ normalized=True,
146
+ )
147
+ if img_arr is None:
148
+ visualize_trajectory_on_image(
149
+ trajectory=trajectory,
150
+ image_path=preprocessed_path,
151
+ output_path=overlay_path,
152
+ normalized=False,
153
+ )
154
+ finally:
155
+ if os.path.exists(preprocessed_path):
156
+ try:
157
+ os.unlink(preprocessed_path)
158
+ except Exception:
159
+ pass
160
+ return prediction, overlay_path, trace_points_text
161
+
162
+
163
+ def center_crop_resize(
164
+ image: "Image.Image",
165
+ size: Tuple[int, int] = PREPROCESS_SIZE,
166
+ ) -> "Image.Image":
167
+ """Center crop to square then resize to size (default 128x128)."""
168
+ w, h = image.size
169
+ min_dim = min(w, h)
170
+ left = (w - min_dim) // 2
171
+ top = (h - min_dim) // 2
172
+ cropped = image.crop((left, top, left + min_dim, top + min_dim))
173
+ return cropped.resize(size, Image.Resampling.LANCZOS)
174
+
175
+
176
+ def preprocess_image_for_trace(image_path: str) -> Tuple["Image.Image", Optional[str]]:
177
+ """
178
+ Load image, center crop and resize to 128x128.
179
+ Returns (preprocessed PIL Image, path to temp file for downstream use).
180
+ """
181
+ img = Image.open(image_path).convert("RGB")
182
+ img = center_crop_resize(img, PREPROCESS_SIZE)
183
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
184
+ img.save(tmp.name)
185
+ return img, tmp.name
186
+
187
 
188
  def build_prompt(instruction: str = "") -> str:
189
  """Build the full prompt from task instruction + trace format."""
 
257
  if image_path is None or not os.path.exists(image_path):
258
  return "Please provide a valid image.", None, ""
259
 
260
+ preprocessed_path = None
261
  try:
262
+ # Preprocess: center crop and resize to 128x128
263
+ _, preprocessed_path = preprocess_image_for_trace(image_path)
264
+ image_uri = f"file://{os.path.abspath(preprocessed_path)}"
 
 
265
 
266
  messages = [
267
  {
 
346
  if trajectories and len(trajectories) >= 2:
347
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
348
  overlay_path = f.name
349
+ # Overlay on preprocessed (128x128) image
350
  img_arr = visualize_trajectory_on_image(
351
  trajectory=trajectories,
352
+ image_path=preprocessed_path,
353
  output_path=overlay_path,
354
  normalized=True,
355
  )
356
  if img_arr is None:
 
357
  visualize_trajectory_on_image(
358
  trajectory=trajectories,
359
+ image_path=preprocessed_path,
360
  output_path=overlay_path,
361
  normalized=False,
362
  )
 
366
  except Exception as e:
367
  logger.exception("Inference failed")
368
  return f"Error: {str(e)}", None, ""
369
+ finally:
370
+ if preprocessed_path and os.path.exists(preprocessed_path):
371
+ try:
372
+ os.unlink(preprocessed_path)
373
+ except Exception:
374
+ pass
375
 
376
 
377
  def format_trace_points(trajectories) -> str:
 
401
  """
402
  # Trace Model Visualizer
403
 
404
+ Upload an image and provide a natural language task instruction to predict the trajectory/trace using [mihirgrao/trace-model](https://huggingface.co/mihirgrao/trace-model).
405
 
406
+ The model predicts coordinate points from your instruction; they are overlaid on the image (green → red gradient) and listed below.
407
  """
408
  )
409
 
410
+ server_url_state = gr.State(value=None)
411
+ model_url_mapping_state = gr.State(value={})
412
+
413
+ def discover_and_select_models(base_url: str):
414
+ if not base_url:
415
+ return (
416
+ gr.update(choices=[], value=None),
417
+ gr.update(value="Please provide a base URL", visible=True),
418
+ gr.update(value="", visible=True),
419
+ None,
420
+ {},
421
+ )
422
+ _server_state["base_url"] = base_url
423
+ models = discover_available_models(base_url, port_range=(8000, 8010))
424
+ if not models:
425
+ return (
426
+ gr.update(choices=[], value=None),
427
+ gr.update(
428
+ value="❌ No trace eval servers found on ports 8000-8010.",
429
+ visible=True,
430
+ ),
431
+ gr.update(value="", visible=True),
432
+ None,
433
+ {},
434
+ )
435
+ choices = []
436
+ url_map = {}
437
+ for url, name in models:
438
+ choices.append(name)
439
+ url_map[name] = url
440
+ selected = choices[0] if choices else None
441
+ selected_url = url_map.get(selected) if selected else None
442
+ model_info_text = get_model_info_for_url(selected_url) if selected_url else ""
443
+ status = f"✅ Found {len(models)} server(s). Auto-selected first."
444
+ _server_state["server_url"] = selected_url
445
+ return (
446
+ gr.update(choices=choices, value=selected),
447
+ gr.update(value=status, visible=True),
448
+ gr.update(value=model_info_text, visible=True),
449
+ selected_url,
450
+ url_map,
451
+ )
452
+
453
+ def on_model_selected(model_choice: str, url_mapping: dict):
454
+ if not model_choice:
455
+ return gr.update(value="No model selected", visible=True), gr.update(value="", visible=True), None
456
+ server_url = url_mapping.get(model_choice) if url_mapping else None
457
+ if not server_url:
458
+ return (
459
+ gr.update(value="Could not find server URL. Please rediscover.", visible=True),
460
+ gr.update(value="", visible=True),
461
+ None,
462
+ )
463
+ model_info_text = get_model_info_for_url(server_url) or ""
464
+ status, _, _ = check_server_health(server_url)
465
+ _server_state["server_url"] = server_url
466
+ return gr.update(value=status, visible=True), gr.update(value=model_info_text, visible=True), server_url
467
+
468
+ with gr.Sidebar():
469
+ gr.Markdown("### 🔧 Model Configuration")
470
+
471
+ base_url_input = gr.Textbox(
472
+ label="Base Server URL",
473
+ placeholder="http://localhost",
474
+ value="http://localhost",
475
+ interactive=True,
476
+ )
477
+ discover_btn = gr.Button("🔍 Discover Eval Servers", variant="primary", size="lg")
478
+ model_dropdown = gr.Dropdown(
479
+ label="Select Eval Server",
480
+ choices=[],
481
+ value=None,
482
+ interactive=True,
483
+ info="Discover trace eval servers on ports 8000-8010",
484
+ )
485
+ server_status = gr.Markdown("Click 'Discover Eval Servers' or use local model below")
486
+ gr.Markdown("---")
487
+ gr.Markdown("### 📋 Model Information")
488
+ model_info_display = gr.Markdown("")
489
+
490
+ discover_btn.click(
491
+ fn=discover_and_select_models,
492
+ inputs=[base_url_input],
493
+ outputs=[
494
+ model_dropdown,
495
+ server_status,
496
+ model_info_display,
497
+ server_url_state,
498
+ model_url_mapping_state,
499
+ ],
500
+ )
501
+ model_dropdown.change(
502
+ fn=on_model_selected,
503
+ inputs=[model_dropdown, model_url_mapping_state],
504
+ outputs=[server_status, model_info_display, server_url_state],
505
+ )
506
+
507
  with gr.Row():
508
  with gr.Column(scale=1):
509
  image_input = gr.Image(
 
512
  height=400,
513
  )
514
  instruction_input = gr.Textbox(
515
+ label="Natural language instruction",
516
+ placeholder="e.g. Pick up the red block and place it on the table. Stack the cube on top of the block.",
517
  value="",
518
+ lines=4,
519
+ info="Enter a task description in natural language. The model predicts the trace for this instruction.",
520
  )
521
+ gr.Markdown("### Local model (if no eval server selected)")
522
  model_id_input = gr.Textbox(
523
  label="Model ID",
524
  value=DEFAULT_MODEL_ID,
525
+ info="Hugging Face model ID (used when no eval server is selected)",
526
  )
527
  load_model_btn = gr.Button("Load Model", variant="secondary")
528
  run_btn = gr.Button("Run Inference", variant="primary")
529
 
530
  with gr.Column(scale=1):
531
+ prompt_display = gr.Markdown(
532
+ f"**Prompt sent to model:**\n\n```\n{build_prompt('')}\n```",
533
+ label="Model prompt",
534
+ )
535
  overlay_output = gr.Image(
536
  label="Image with Trace Overlay",
537
  height=400,
 
544
  label="Extracted Trace Points",
545
  )
546
 
547
+ status_md = gr.Markdown(
548
+ "Select an eval server from the sidebar, or load a local model and run inference."
549
+ )
550
 
551
  def on_load_model(model_id: str):
552
  _, msg = load_model(model_id)
553
  return f"**Status:** {msg}"
554
 
555
+ def on_run_inference(image_path, instruction, model_id, server_url):
556
  if image_path is None:
557
+ return (
558
+ "",
559
+ "Please upload an image first.",
560
+ None,
561
+ "",
562
+ "**Status:** Please upload an image.",
563
+ )
564
  prompt = build_prompt(instruction)
565
+ prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
566
+
567
+ if server_url:
568
+ pred, overlay_path, trace_text = run_inference_via_server(
569
+ image_path, instruction, server_url
570
+ )
571
+ else:
572
+ pred, overlay_path, trace_text = run_inference(image_path, prompt, model_id)
573
+
574
  status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
575
+ return prompt_md, pred, overlay_path, trace_text, status
576
 
577
  load_model_btn.click(
578
  fn=on_load_model,
 
580
  outputs=[status_md],
581
  )
582
 
583
+ def update_prompt_display(instruction: str):
584
+ prompt = build_prompt(instruction)
585
+ return f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
586
+
587
+ instruction_input.change(
588
+ fn=update_prompt_display,
589
+ inputs=[instruction_input],
590
+ outputs=[prompt_display],
591
+ )
592
+
593
  run_btn.click(
594
  fn=on_run_inference,
595
+ inputs=[
596
+ image_input,
597
+ instruction_input,
598
+ model_id_input,
599
+ server_url_state,
600
+ ],
601
  outputs=[
602
+ prompt_display,
603
  prediction_output,
604
  overlay_output,
605
  trace_points_output,
eval_server.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ FastAPI server for Trace Model inference.
4
+
5
+ Usage:
6
+ python eval_server.py --model-id mihirgrao/trace-model --port 8001
7
+
8
+ Endpoints:
9
+ POST /predict - Single image + instruction
10
+ POST /predict_batch - Batch of (image, instruction) pairs
11
+ GET /health - Health check
12
+ GET /model_info - Model information
13
+ """
14
+
15
+ import argparse
16
+ import base64
17
+ import logging
18
+ import os
19
+ import tempfile
20
+ import time
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ from threading import Lock
23
+ from typing import Any, Dict, List, Optional
24
+
25
+ import uvicorn
26
+ from fastapi import FastAPI, Request
27
+ from fastapi.middleware.cors import CORSMiddleware
28
+
29
+ from app import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference
30
+ from trajectory_viz import extract_trajectory_from_text
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # --- Trace Eval Server ---
35
+
36
+
37
+ class TraceEvalServer:
38
+ """Inference server for the trace model."""
39
+
40
+ def __init__(
41
+ self,
42
+ model_id: str = DEFAULT_MODEL_ID,
43
+ max_workers: int = 1,
44
+ ):
45
+ self.model_id = model_id
46
+ self.max_workers = max_workers
47
+ self._job_counter = 0
48
+ self._completed_jobs = 0
49
+ self._lock = Lock()
50
+ self.executor = ThreadPoolExecutor(max_workers=max_workers)
51
+
52
+ logger.info(f"Loading trace model: {model_id}")
53
+ success, msg = load_model(model_id)
54
+ if not success:
55
+ raise RuntimeError(f"Failed to load model: {msg}")
56
+ logger.info(msg)
57
+
58
+ def predict_one(
59
+ self,
60
+ image_path: Optional[str] = None,
61
+ image_base64: Optional[str] = None,
62
+ instruction: str = "",
63
+ ) -> Dict[str, Any]:
64
+ """
65
+ Run inference on a single image.
66
+
67
+ Provide either image_path (file path) or image_base64 (base64-encoded image).
68
+ """
69
+ if image_path is None and image_base64 is None:
70
+ return {"error": "Provide image_path or image_base64"}
71
+
72
+ temp_file_path = None
73
+ if image_path is None:
74
+ try:
75
+ image_bytes = base64.b64decode(image_base64)
76
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
77
+ f.write(image_bytes)
78
+ image_path = f.name
79
+ temp_file_path = image_path
80
+ except Exception as e:
81
+ return {"error": f"Invalid base64 image: {e}"}
82
+
83
+ try:
84
+ prompt = build_prompt(instruction)
85
+ prediction, overlay_path, _ = run_inference(image_path, prompt, self.model_id)
86
+ finally:
87
+ if temp_file_path and os.path.exists(temp_file_path):
88
+ try:
89
+ os.unlink(temp_file_path)
90
+ except Exception:
91
+ pass
92
+
93
+ if prediction.startswith("Error:") or prediction.startswith("Please "):
94
+ return {"error": prediction}
95
+
96
+ trajectory = extract_trajectory_from_text(prediction)
97
+ return {
98
+ "prediction": prediction,
99
+ "trajectory": trajectory,
100
+ }
101
+
102
+ def predict_batch(
103
+ self,
104
+ samples: List[Dict[str, Any]],
105
+ ) -> Dict[str, Any]:
106
+ """Process a batch of (image_path or image_base64, instruction) samples."""
107
+ results = []
108
+ for sample in samples:
109
+ with self._lock:
110
+ self._job_counter += 1
111
+ job_id = self._job_counter
112
+
113
+ start = time.time()
114
+ result = self.predict_one(
115
+ image_path=sample.get("image_path"),
116
+ image_base64=sample.get("image_base64"),
117
+ instruction=sample.get("instruction", ""),
118
+ )
119
+ elapsed = time.time() - start
120
+
121
+ with self._lock:
122
+ self._completed_jobs += 1
123
+
124
+ logger.debug(f"[job {job_id}] completed in {elapsed:.3f}s")
125
+ results.append(result)
126
+
127
+ return {"results": results}
128
+
129
+ def get_status(self) -> Dict[str, Any]:
130
+ """Get server status."""
131
+ return {
132
+ "model_id": self.model_id,
133
+ "max_workers": self.max_workers,
134
+ "completed_jobs": self._completed_jobs,
135
+ "job_counter": self._job_counter,
136
+ }
137
+
138
+ def get_model_info(self) -> Dict[str, Any]:
139
+ """Get model information."""
140
+ try:
141
+ from app import _model_state
142
+
143
+ model = _model_state.get("model")
144
+ if model is None:
145
+ return {"model_id": self.model_id, "status": "not_loaded"}
146
+
147
+ all_params = sum(p.numel() for p in model.parameters())
148
+ return {
149
+ "model_id": self.model_id,
150
+ "model_class": model.__class__.__name__,
151
+ "total_parameters": all_params,
152
+ }
153
+ except Exception as e:
154
+ return {"model_id": self.model_id, "error": str(e)}
155
+
156
+ def shutdown(self):
157
+ """Shutdown the executor."""
158
+ self.executor.shutdown(wait=True)
159
+
160
+
161
+ def create_app(
162
+ model_id: str = DEFAULT_MODEL_ID,
163
+ max_workers: int = 1,
164
+ server: Optional[TraceEvalServer] = None,
165
+ ) -> FastAPI:
166
+ app = FastAPI(title="Trace Model Evaluation Server")
167
+
168
+ app.add_middleware(
169
+ CORSMiddleware,
170
+ allow_origins=["*"],
171
+ allow_credentials=True,
172
+ allow_methods=["*"],
173
+ allow_headers=["*"],
174
+ )
175
+
176
+ trace_server = server or TraceEvalServer(model_id=model_id, max_workers=max_workers)
177
+
178
+ @app.post("/predict")
179
+ async def predict(request: Request) -> Dict[str, Any]:
180
+ """
181
+ Predict trace for a single image.
182
+
183
+ JSON body:
184
+ - image_path: (optional) path to image file
185
+ - image_base64: (optional) base64-encoded image
186
+ - instruction: natural language task description
187
+ """
188
+ body = await request.json()
189
+ return trace_server.predict_one(
190
+ image_path=body.get("image_path"),
191
+ image_base64=body.get("image_base64"),
192
+ instruction=body.get("instruction", ""),
193
+ )
194
+
195
+ @app.post("/predict_batch")
196
+ async def predict_batch(request: Request) -> Dict[str, Any]:
197
+ """
198
+ Predict trace for a batch of images.
199
+
200
+ JSON body:
201
+ - samples: list of {image_path?, image_base64?, instruction}
202
+ """
203
+ body = await request.json()
204
+ samples = body.get("samples", [])
205
+ if not samples:
206
+ return {"error": "samples list is required", "results": []}
207
+ return trace_server.predict_batch(samples)
208
+
209
+ @app.post("/evaluate_batch")
210
+ async def evaluate_batch(request: Request) -> Dict[str, Any]:
211
+ """
212
+ Alias for /predict_batch for compatibility with RFM-style clients.
213
+ Accepts same format as /predict_batch.
214
+ """
215
+ return await predict_batch(request)
216
+
217
+ @app.get("/health")
218
+ def health() -> Dict[str, Any]:
219
+ """Health check."""
220
+ status = trace_server.get_status()
221
+ return {
222
+ "status": "healthy",
223
+ "model_id": status["model_id"],
224
+ }
225
+
226
+ @app.get("/model_info")
227
+ def model_info() -> Dict[str, Any]:
228
+ """Get model information."""
229
+ return trace_server.get_model_info()
230
+
231
+ @app.get("/gpu_status")
232
+ def gpu_status() -> Dict[str, Any]:
233
+ """Get server status (RFM-compatible endpoint name)."""
234
+ return trace_server.get_status()
235
+
236
+ @app.on_event("shutdown")
237
+ async def shutdown_event():
238
+ trace_server.shutdown()
239
+
240
+ return app
241
+
242
+
243
+ def main():
244
+ parser = argparse.ArgumentParser(description="Trace Model Evaluation Server")
245
+ parser.add_argument(
246
+ "--model-id",
247
+ type=str,
248
+ default=DEFAULT_MODEL_ID,
249
+ help=f"Model ID (default: {DEFAULT_MODEL_ID})",
250
+ )
251
+ parser.add_argument(
252
+ "--host",
253
+ type=str,
254
+ default="0.0.0.0",
255
+ help="Server host",
256
+ )
257
+ parser.add_argument(
258
+ "--port",
259
+ type=int,
260
+ default=8001,
261
+ help="Server port",
262
+ )
263
+ parser.add_argument(
264
+ "--max-workers",
265
+ type=int,
266
+ default=1,
267
+ help="Max worker threads for batch processing",
268
+ )
269
+ args = parser.parse_args()
270
+
271
+ logging.basicConfig(level=logging.INFO)
272
+
273
+ app = create_app(model_id=args.model_id, max_workers=args.max_workers)
274
+ print(f"Trace eval server starting on {args.host}:{args.port}")
275
+ print(f"Model: {args.model_id}")
276
+ uvicorn.run(app, host=args.host, port=args.port)
277
+
278
+
279
+ if __name__ == "__main__":
280
+ main()
predict_trace.py CHANGED
@@ -37,7 +37,7 @@ def main():
37
  "--instruction",
38
  type=str,
39
  default="",
40
- help="Task / language instruction (e.g. 'Pick up the red block')",
41
  )
42
  parser.add_argument(
43
  "-p",
 
37
  "--instruction",
38
  type=str,
39
  default="",
40
+ help="Natural language task instruction (e.g. 'Pick up the red block and place it on the table')",
41
  )
42
  parser.add_argument(
43
  "-p",
requirements.txt CHANGED
@@ -1,7 +1,10 @@
 
1
  gradio>=4.0.0
 
2
  torch>=2.0.0
3
  transformers>=4.45.0
4
  accelerate>=0.25.0
5
  Pillow>=9.0.0
6
  numpy>=1.20.0
 
7
  qwen-vl-utils>=0.0.8
 
1
+ fastapi>=0.100.0
2
  gradio>=4.0.0
3
+ uvicorn>=0.22.0
4
  torch>=2.0.0
5
  transformers>=4.45.0
6
  accelerate>=0.25.0
7
  Pillow>=9.0.0
8
  numpy>=1.20.0
9
+ requests>=2.28.0
10
  qwen-vl-utils>=0.0.8