Anthony Liang committed on
Commit
fad52c2
·
1 Parent(s): 521dbac

more app fix

Browse files
Files changed (2) hide show
  1. app.py +12 -3
  2. eval_server.py +19 -3
app.py CHANGED
@@ -124,6 +124,7 @@ def run_inference_via_server(
124
  image_path: str,
125
  instruction: str,
126
  server_url: str,
 
127
  ) -> Tuple[str, Optional[str]]:
128
  """Run inference via trace eval server. Returns (prediction, overlay_path)."""
129
  with open(image_path, "rb") as f:
@@ -131,7 +132,11 @@ def run_inference_via_server(
131
  headers = {"ngrok-skip-browser-warning": "true"} if "ngrok" in server_url else {}
132
  r = requests.post(
133
  f"{server_url.rstrip('/')}/predict",
134
- json={"image_base64": image_b64, "instruction": instruction},
 
 
 
 
135
  timeout=120.0,
136
  headers=headers,
137
  )
@@ -339,10 +344,14 @@ with demo:
339
  )
340
 
341
  if server_url:
342
- prompt = build_prompt(instruction)
 
 
 
 
343
  prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
344
  pred, overlay_path = run_inference_via_server(
345
- image_path, instruction, server_url
346
  )
347
  elif use_qwenvl:
348
  prompt = build_franka_prompt(instruction or "predict the trace")
 
124
  image_path: str,
125
  instruction: str,
126
  server_url: str,
127
+ use_qwenvl: bool = False,
128
  ) -> Tuple[str, Optional[str]]:
129
  """Run inference via trace eval server. Returns (prediction, overlay_path)."""
130
  with open(image_path, "rb") as f:
 
132
  headers = {"ngrok-skip-browser-warning": "true"} if "ngrok" in server_url else {}
133
  r = requests.post(
134
  f"{server_url.rstrip('/')}/predict",
135
+ json={
136
+ "image_base64": image_b64,
137
+ "instruction": instruction,
138
+ "use_qwenvl": use_qwenvl,
139
+ },
140
  timeout=120.0,
141
  headers=headers,
142
  )
 
344
  )
345
 
346
  if server_url:
347
+ prompt = (
348
+ build_franka_prompt(instruction or "predict the trace")
349
+ if use_qwenvl
350
+ else build_prompt(instruction)
351
+ )
352
  prompt_md = f"**Prompt sent to model:**\n\n```\n{prompt}\n```"
353
  pred, overlay_path = run_inference_via_server(
354
+ image_path, instruction, server_url, use_qwenvl
355
  )
356
  elif use_qwenvl:
357
  prompt = build_franka_prompt(instruction or "predict the trace")
eval_server.py CHANGED
@@ -33,6 +33,7 @@ from trace_inference import (
33
  build_prompt,
34
  load_model,
35
  run_inference,
 
36
  )
37
  from trace_inference import _model_state as _trace_model_state
38
  from trajectory_viz import extract_trajectory_from_text
@@ -68,11 +69,13 @@ class TraceEvalServer:
68
  image_path: Optional[str] = None,
69
  image_base64: Optional[str] = None,
70
  instruction: str = "",
 
71
  ) -> Dict[str, Any]:
72
  """
73
  Run inference on a single image.
74
 
75
  Provide either image_path (file path) or image_base64 (base64-encoded image).
 
76
  """
77
  if image_path is None and image_base64 is None:
78
  return {"error": "Provide image_path or image_base64"}
@@ -99,8 +102,15 @@ class TraceEvalServer:
99
  return {"error": f"Invalid image data: {e}"}
100
 
101
  try:
102
- prompt = build_prompt(instruction)
103
- prediction, overlay_path, _ = run_inference(image_path, prompt, self.model_id)
 
 
 
 
 
 
 
104
  finally:
105
  if temp_file_path and os.path.exists(temp_file_path):
106
  try:
@@ -112,10 +122,13 @@ class TraceEvalServer:
112
  return {"error": prediction}
113
 
114
  trajectory = extract_trajectory_from_text(prediction)
115
- return {
116
  "prediction": prediction,
117
  "trajectory": trajectory,
118
  }
 
 
 
119
 
120
  def predict_batch(
121
  self,
@@ -133,6 +146,7 @@ class TraceEvalServer:
133
  image_path=sample.get("image_path"),
134
  image_base64=sample.get("image_base64"),
135
  instruction=sample.get("instruction", ""),
 
136
  )
137
  elapsed = time.time() - start
138
 
@@ -200,12 +214,14 @@ def create_app(
200
  - image_path: (optional) path to image file
201
  - image_base64: (optional) base64-encoded image
202
  - instruction: natural language task description
 
203
  """
204
  body = await request.json()
205
  return trace_server.predict_one(
206
  image_path=body.get("image_path"),
207
  image_base64=body.get("image_base64"),
208
  instruction=body.get("instruction", ""),
 
209
  )
210
 
211
  @app.post("/predict_batch")
 
33
  build_prompt,
34
  load_model,
35
  run_inference,
36
+ run_inference_qwenvl,
37
  )
38
  from trace_inference import _model_state as _trace_model_state
39
  from trajectory_viz import extract_trajectory_from_text
 
69
  image_path: Optional[str] = None,
70
  image_base64: Optional[str] = None,
71
  instruction: str = "",
72
+ use_qwenvl: bool = False,
73
  ) -> Dict[str, Any]:
74
  """
75
  Run inference on a single image.
76
 
77
  Provide either image_path (file path) or image_base64 (base64-encoded image).
78
+ If use_qwenvl=True, uses run_inference_qwenvl (Franka-style, requires qwenvl).
79
  """
80
  if image_path is None and image_base64 is None:
81
  return {"error": "Provide image_path or image_base64"}
 
102
  return {"error": f"Invalid image data: {e}"}
103
 
104
  try:
105
+ if use_qwenvl:
106
+ output_dict, prediction, _, trace_text = run_inference_qwenvl(
107
+ image_path, instruction or "predict the trace", self.model_id
108
+ )
109
+ if not output_dict and trace_text and "qwenvl package not found" in trace_text:
110
+ return {"error": trace_text}
111
+ else:
112
+ prompt = build_prompt(instruction)
113
+ prediction, _, _ = run_inference(image_path, prompt, self.model_id)
114
  finally:
115
  if temp_file_path and os.path.exists(temp_file_path):
116
  try:
 
122
  return {"error": prediction}
123
 
124
  trajectory = extract_trajectory_from_text(prediction)
125
+ result: Dict[str, Any] = {
126
  "prediction": prediction,
127
  "trajectory": trajectory,
128
  }
129
+ if use_qwenvl and output_dict:
130
+ result["output_dict"] = output_dict
131
+ return result
132
 
133
  def predict_batch(
134
  self,
 
146
  image_path=sample.get("image_path"),
147
  image_base64=sample.get("image_base64"),
148
  instruction=sample.get("instruction", ""),
149
+ use_qwenvl=sample.get("use_qwenvl", False),
150
  )
151
  elapsed = time.time() - start
152
 
 
214
  - image_path: (optional) path to image file
215
  - image_base64: (optional) base64-encoded image
216
  - instruction: natural language task description
217
+ - use_qwenvl: (optional) if true, use Franka/qwenvl inference (requires qwenvl)
218
  """
219
  body = await request.json()
220
  return trace_server.predict_one(
221
  image_path=body.get("image_path"),
222
  image_base64=body.get("image_base64"),
223
  instruction=body.get("instruction", ""),
224
+ use_qwenvl=body.get("use_qwenvl", False),
225
  )
226
 
227
  @app.post("/predict_batch")