vikash-nuvai commited on
Commit
5d20aef
·
1 Parent(s): 12d9f1b

fix: add structured output markers for validator

Browse files
Files changed (1) hide show
  1. inference.py +165 -137
inference.py CHANGED
@@ -20,9 +20,9 @@ import json
20
  import os
21
  import sys
22
  import time
 
23
 
24
  import requests
25
- from openai import OpenAI
26
 
27
  # ---------------------------------------------------------------------------
28
  # Required environment variables
@@ -32,11 +32,6 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
32
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
33
  ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
34
 
35
- if not HF_TOKEN:
36
- print("WARNING: HF_TOKEN not set. LLM calls will fail.")
37
-
38
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
39
-
40
  # ---------------------------------------------------------------------------
41
  # System prompt
42
  # ---------------------------------------------------------------------------
@@ -69,6 +64,16 @@ STRATEGY:
69
  Respond with ONLY valid JSON. No explanation, no markdown, no extra text."""
70
 
71
 
 
 
 
 
 
 
 
 
 
 
72
  def parse_action(text: str) -> dict:
73
  """Parse LLM output into an action dict."""
74
  text = text.strip()
@@ -101,175 +106,198 @@ def parse_action(text: str) -> dict:
101
  continue
102
 
103
  # Fallback
104
- print(f" [WARN] Could not parse action: {text[:100]}")
105
  return {"command": "observe"}
106
 
107
 
108
- def run_episode(task_id: str) -> dict:
109
  """Run one episode of the tiffin packing task."""
110
- print(f"\n{'='*60}")
111
- print(f" TASK: {task_id.upper()}")
112
- print(f"{'='*60}")
113
 
114
- # Reset the environment
115
- try:
116
- resp = requests.post(
117
- f"{ENV_URL}/reset",
118
- json={"task_id": task_id, "seed": 42},
119
- timeout=30,
120
- )
121
- resp.raise_for_status()
122
- result = resp.json()
123
- obs = result.get("observation", result)
124
- except Exception as e:
125
- print(f" ERROR: Failed to reset environment: {e}")
126
- return {"task_id": task_id, "reward": 0.0, "score": 0.0, "error": str(e)}
127
-
128
- # Initialize conversation
129
- init_scene = obs.get("scene_description", "")
130
- init_feedback = obs.get("step_feedback", "")
131
- messages = [
132
- {"role": "system", "content": SYSTEM_PROMPT},
133
- {
134
- "role": "user",
135
- "content": (
136
- f"Task: {task_id}\n\n"
137
- f"{init_feedback}\n\n"
138
- f"Scene:\n{init_scene}\n\n"
139
- f"Available commands: {obs.get('available_commands', [])}\n\n"
140
- f"What is your first action? Respond with JSON only."
141
- ),
142
- },
143
- ]
144
-
145
- total_reward = 0.0
146
  step = 0
147
- max_steps = 35 # safety limit
148
-
149
- while not obs.get("done", False) and step < max_steps:
150
- step += 1
151
 
152
- # Get LLM decision
153
- try:
154
- response = client.chat.completions.create(
155
- model=MODEL_NAME,
156
- messages=messages,
157
- temperature=0.0,
158
- max_tokens=200,
159
- )
160
- action_text = response.choices[0].message.content.strip()
161
- except Exception as e:
162
- print(f" [Step {step}] LLM error: {e}")
163
- action_text = '{"command": "observe"}'
164
-
165
- action = parse_action(action_text)
166
- print(f" [Step {step}] Action: {json.dumps(action)}")
167
 
168
- # Execute step
169
  try:
170
  resp = requests.post(
171
- f"{ENV_URL}/step",
172
- json={"action": action},
173
  timeout=30,
174
  )
175
  resp.raise_for_status()
176
  result = resp.json()
177
  obs = result.get("observation", result)
178
- reward = result.get("reward", obs.get("reward", 0.0))
179
- total_reward += reward or 0
180
  except Exception as e:
181
- print(f" [Step {step}] Step error: {e}")
182
- break
183
-
184
- # Print feedback
185
- feedback = obs.get("step_feedback", "")[:200]
186
- print(f" Reward: {reward:+.2f} | Feedback: {feedback}")
187
-
188
- # Update conversation with assistant response and new observation
189
- messages.append({"role": "assistant", "content": action_text})
190
-
191
- # Build concise next observation for LLM
192
- held = obs.get("held_item")
193
- held_str = (
194
- f"Holding: {held.get('name', 'unknown')}" if held else "Arm: idle"
195
- )
196
- items_status = [
197
- f"[{i['id']}] {i.get('name', '?')} ({i['status']})"
198
- for i in obs.get("food_items", [])
199
- ]
200
- containers_status = [
201
- f"[{c['id']}] {c['name']} {c.get('fill_percentage',0):.0f}% full"
202
- for c in obs.get("containers", [])
203
- ]
204
-
205
- messages.append(
206
  {
207
  "role": "user",
208
  "content": (
209
- f"Step {step} result (reward={reward:+.2f}):\n"
210
- f"Feedback: {obs.get('step_feedback', '')}\n\n"
211
- f"{held_str}\n"
212
- f"Items: {', '.join(items_status)}\n"
213
- f"Containers: {', '.join(containers_status)}\n"
214
- f"Available: {obs.get('available_commands', [])}\n\n"
215
- f"{'VLM Result: ' + json.dumps(obs.get('vlm_result')) if obs.get('vlm_result') else ''}\n\n"
216
- f"Next action? JSON only."
217
  ),
218
  },
219
- )
220
 
221
- # Extract final score
222
- final_score = obs.get("metadata", {}).get("final_score", 0.0)
223
- grade_breakdown = obs.get("metadata", {}).get("grade_breakdown", {})
224
-
225
- print(f"\n {'─'*40}")
226
- print(f" Steps taken: {step}")
227
- print(f" Total reward: {total_reward:+.2f}")
228
- print(f" Final score: {final_score:.4f}")
229
- if grade_breakdown:
230
- print(f" Breakdown:")
231
- print(f" Validity: {grade_breakdown.get('validity', 0):.4f} (x0.4)")
232
- print(f" Efficiency: {grade_breakdown.get('efficiency', 0):.4f} (x0.3)")
233
- print(f" Constraints: {grade_breakdown.get('constraints', 0):.4f} (x0.2)")
234
- print(f" Neatness: {grade_breakdown.get('neatness', 0):.4f} (x0.1)")
235
-
236
- return {
237
- "task_id": task_id,
238
- "steps": step,
239
- "total_reward": round(total_reward, 4),
240
- "score": final_score,
241
- "grade_breakdown": grade_breakdown,
242
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
 
245
  def main():
246
  """Run all 3 tasks and report results."""
247
- print("=" * 60)
248
- print(" TIFFIN PACKER — INFERENCE SCRIPT")
249
- print(f" Model: {MODEL_NAME}")
250
- print(f" API: {API_BASE_URL}")
251
- print(f" Env: {ENV_URL}")
252
- print("=" * 60)
 
 
 
253
 
254
  start_time = time.time()
255
  results = {}
256
 
257
  for task_id in ["easy", "medium", "hard"]:
258
- result = run_episode(task_id)
259
  results[task_id] = result
260
 
261
  elapsed = time.time() - start_time
262
 
263
  # Summary
264
- print("\n" + "=" * 60)
265
- print(" FINAL RESULTS")
266
- print("=" * 60)
267
  for task_id, r in results.items():
268
- print(f" {task_id:8s}: score={r['score']:.4f} reward={r['total_reward']:+.2f} steps={r.get('steps', '?')}")
269
 
270
  avg_score = sum(r["score"] for r in results.values()) / max(len(results), 1)
271
- print(f"\n Average score: {avg_score:.4f}")
272
- print(f" Total time: {elapsed:.1f}s")
273
 
274
  # Save results
275
  os.makedirs("outputs/evals", exist_ok=True)
@@ -285,7 +313,7 @@ def main():
285
  f,
286
  indent=2,
287
  )
288
- print(f"\n Results saved to outputs/evals/results.json")
289
 
290
 
291
  if __name__ == "__main__":
 
20
  import os
21
  import sys
22
  import time
23
+ import traceback
24
 
25
  import requests
 
26
 
27
  # ---------------------------------------------------------------------------
28
  # Required environment variables
 
32
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
33
  ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
34
 
 
 
 
 
 
35
  # ---------------------------------------------------------------------------
36
  # System prompt
37
  # ---------------------------------------------------------------------------
 
64
  Respond with ONLY valid JSON. No explanation, no markdown, no extra text."""
65
 
66
 
67
+ def get_client():
68
+ """Lazily create an OpenAI client. Returns None if openai is unavailable."""
69
+ try:
70
+ from openai import OpenAI
71
+ return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
72
+ except Exception as e:
73
+ print(f"WARNING: Could not create OpenAI client: {e}", flush=True)
74
+ return None
75
+
76
+
77
  def parse_action(text: str) -> dict:
78
  """Parse LLM output into an action dict."""
79
  text = text.strip()
 
106
  continue
107
 
108
  # Fallback
109
+ print(f" [WARN] Could not parse action: {text[:100]}", flush=True)
110
  return {"command": "observe"}
111
 
112
 
113
+ def run_episode(task_id: str, client) -> dict:
114
  """Run one episode of the tiffin packing task."""
115
+ # Emit [START] structured output for the validator
116
+ print(f"[START] task={task_id}", flush=True)
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  step = 0
 
 
 
 
119
 
120
+ try:
121
+ print(f"\n{'='*60}", flush=True)
122
+ print(f" TASK: {task_id.upper()}", flush=True)
123
+ print(f"{'='*60}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ # Reset the environment
126
  try:
127
  resp = requests.post(
128
+ f"{ENV_URL}/reset",
129
+ json={"task_id": task_id, "seed": 42},
130
  timeout=30,
131
  )
132
  resp.raise_for_status()
133
  result = resp.json()
134
  obs = result.get("observation", result)
 
 
135
  except Exception as e:
136
+ print(f" ERROR: Failed to reset environment: {e}", flush=True)
137
+ print(f"[END] task={task_id} score=0.0 steps=0", flush=True)
138
+ return {"task_id": task_id, "total_reward": 0.0, "reward": 0.0, "score": 0.0, "steps": 0, "error": str(e)}
139
+
140
+ # Initialize conversation
141
+ init_scene = obs.get("scene_description", "")
142
+ init_feedback = obs.get("step_feedback", "")
143
+ messages = [
144
+ {"role": "system", "content": SYSTEM_PROMPT},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  {
146
  "role": "user",
147
  "content": (
148
+ f"Task: {task_id}\n\n"
149
+ f"{init_feedback}\n\n"
150
+ f"Scene:\n{init_scene}\n\n"
151
+ f"Available commands: {obs.get('available_commands', [])}\n\n"
152
+ f"What is your first action? Respond with JSON only."
 
 
 
153
  ),
154
  },
155
+ ]
156
 
157
+ total_reward = 0.0
158
+ max_steps = 35 # safety limit
159
+
160
+ while not obs.get("done", False) and step < max_steps:
161
+ step += 1
162
+
163
+ # Get LLM decision
164
+ try:
165
+ if client is None:
166
+ raise RuntimeError("No OpenAI client available")
167
+ response = client.chat.completions.create(
168
+ model=MODEL_NAME,
169
+ messages=messages,
170
+ temperature=0.0,
171
+ max_tokens=200,
172
+ )
173
+ action_text = response.choices[0].message.content.strip()
174
+ except Exception as e:
175
+ print(f" [Step {step}] LLM error: {e}", flush=True)
176
+ action_text = '{"command": "observe"}'
177
+
178
+ action = parse_action(action_text)
179
+ print(f" [Step {step}] Action: {json.dumps(action)}", flush=True)
180
+
181
+ # Execute step
182
+ try:
183
+ resp = requests.post(
184
+ f"{ENV_URL}/step",
185
+ json={"action": action},
186
+ timeout=30,
187
+ )
188
+ resp.raise_for_status()
189
+ result = resp.json()
190
+ obs = result.get("observation", result)
191
+ reward = result.get("reward", obs.get("reward", 0.0))
192
+ total_reward += reward or 0
193
+ # Emit [STEP] structured output for the validator
194
+ print(f"[STEP] step={step} reward={reward}", flush=True)
195
+ except Exception as e:
196
+ print(f" [Step {step}] Step error: {e}", flush=True)
197
+ break
198
+
199
+ # Print feedback
200
+ feedback = obs.get("step_feedback", "")[:200]
201
+ print(f" Reward: {reward:+.2f} | Feedback: {feedback}", flush=True)
202
+
203
+ # Update conversation with assistant response and new observation
204
+ messages.append({"role": "assistant", "content": action_text})
205
+
206
+ # Build concise next observation for LLM
207
+ held = obs.get("held_item")
208
+ held_str = (
209
+ f"Holding: {held.get('name', 'unknown')}" if held else "Arm: idle"
210
+ )
211
+ items_status = [
212
+ f"[{i['id']}] {i.get('name', '?')} ({i['status']})"
213
+ for i in obs.get("food_items", [])
214
+ ]
215
+ containers_status = [
216
+ f"[{c['id']}] {c['name']} {c.get('fill_percentage',0):.0f}% full"
217
+ for c in obs.get("containers", [])
218
+ ]
219
+
220
+ messages.append(
221
+ {
222
+ "role": "user",
223
+ "content": (
224
+ f"Step {step} result (reward={reward:+.2f}):\n"
225
+ f"Feedback: {obs.get('step_feedback', '')}\n\n"
226
+ f"{held_str}\n"
227
+ f"Items: {', '.join(items_status)}\n"
228
+ f"Containers: {', '.join(containers_status)}\n"
229
+ f"Available: {obs.get('available_commands', [])}\n\n"
230
+ f"{'VLM Result: ' + json.dumps(obs.get('vlm_result')) if obs.get('vlm_result') else ''}\n\n"
231
+ f"Next action? JSON only."
232
+ ),
233
+ },
234
+ )
235
+
236
+ # Extract final score
237
+ final_score = obs.get("metadata", {}).get("final_score", 0.0)
238
+ grade_breakdown = obs.get("metadata", {}).get("grade_breakdown", {})
239
+
240
+ print(f"\n {'─'*40}", flush=True)
241
+ print(f" Steps taken: {step}", flush=True)
242
+ print(f" Total reward: {total_reward:+.2f}", flush=True)
243
+ print(f" Final score: {final_score:.4f}", flush=True)
244
+ if grade_breakdown:
245
+ print(f" Breakdown:", flush=True)
246
+ print(f" Validity: {grade_breakdown.get('validity', 0):.4f} (x0.4)", flush=True)
247
+ print(f" Efficiency: {grade_breakdown.get('efficiency', 0):.4f} (x0.3)", flush=True)
248
+ print(f" Constraints: {grade_breakdown.get('constraints', 0):.4f} (x0.2)", flush=True)
249
+ print(f" Neatness: {grade_breakdown.get('neatness', 0):.4f} (x0.1)", flush=True)
250
+
251
+ # Emit [END] structured output for the validator
252
+ print(f"[END] task={task_id} score={final_score} steps={step}", flush=True)
253
+
254
+ return {
255
+ "task_id": task_id,
256
+ "steps": step,
257
+ "total_reward": round(total_reward, 4),
258
+ "score": final_score,
259
+ "grade_breakdown": grade_breakdown,
260
+ }
261
+
262
+ except Exception as e:
263
+ # Catch-all: ensure [END] is ALWAYS emitted even on unexpected errors
264
+ print(f" FATAL ERROR in episode {task_id}: {e}", flush=True)
265
+ traceback.print_exc()
266
+ print(f"[END] task={task_id} score=0.0 steps={step}", flush=True)
267
+ return {"task_id": task_id, "total_reward": 0.0, "reward": 0.0, "score": 0.0, "steps": step, "error": str(e)}
268
 
269
 
270
  def main():
271
  """Run all 3 tasks and report results."""
272
+ print("=" * 60, flush=True)
273
+ print(" TIFFIN PACKER — INFERENCE SCRIPT", flush=True)
274
+ print(f" Model: {MODEL_NAME}", flush=True)
275
+ print(f" API: {API_BASE_URL}", flush=True)
276
+ print(f" Env: {ENV_URL}", flush=True)
277
+ print("=" * 60, flush=True)
278
+
279
+ # Create client lazily — don't crash on import
280
+ client = get_client()
281
 
282
  start_time = time.time()
283
  results = {}
284
 
285
  for task_id in ["easy", "medium", "hard"]:
286
+ result = run_episode(task_id, client)
287
  results[task_id] = result
288
 
289
  elapsed = time.time() - start_time
290
 
291
  # Summary
292
+ print("\n" + "=" * 60, flush=True)
293
+ print(" FINAL RESULTS", flush=True)
294
+ print("=" * 60, flush=True)
295
  for task_id, r in results.items():
296
+ print(f" {task_id:8s}: score={r['score']:.4f} reward={r['total_reward']:+.2f} steps={r.get('steps', '?')}", flush=True)
297
 
298
  avg_score = sum(r["score"] for r in results.values()) / max(len(results), 1)
299
+ print(f"\n Average score: {avg_score:.4f}", flush=True)
300
+ print(f" Total time: {elapsed:.1f}s", flush=True)
301
 
302
  # Save results
303
  os.makedirs("outputs/evals", exist_ok=True)
 
313
  f,
314
  indent=2,
315
  )
316
+ print(f"\n Results saved to outputs/evals/results.json", flush=True)
317
 
318
 
319
  if __name__ == "__main__":