Adit1Sharma commited on
Commit
55a5567
·
1 Parent(s): 1d08ca6

Update inference.py to exact STDOUT format requirements

Browse files
Files changed (1) hide show
  1. inference.py +11 -3
inference.py CHANGED
@@ -63,18 +63,26 @@ def run_llm(task_id):
63
  task = env.current_task
64
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
65
  taken = []
 
66
 
67
- print(f"[START] task={task_id}", flush=True)
68
  for i in range(task["max_steps"]):
69
  action = call_llm(obs, messages)
70
  obs, reward, done, info = env.step(action)
71
  taken.append(action)
72
- print(f"[STEP] step={i+1} reward={reward}", flush=True)
 
 
 
 
 
73
  if done:
74
  break
75
 
76
  score = grade_task(task, taken)
77
- print(f"[END] task={task_id} score={score} steps={len(taken)}", flush=True)
 
 
78
  return score
79
 
80
  def main():
 
63
  task = env.current_task
64
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
65
  taken = []
66
+ rewards = []
67
 
68
+ print(f"[START] task={task_id} env=customer-support model={MODEL_NAME}", flush=True)
69
  for i in range(task["max_steps"]):
70
  action = call_llm(obs, messages)
71
  obs, reward, done, info = env.step(action)
72
  taken.append(action)
73
+ rewards.append(reward)
74
+
75
+ action_str = action.model_dump_json() if hasattr(action, 'model_dump_json') else json.dumps(action.__dict__)
76
+ action_str = action_str.replace(" ", "")
77
+
78
+ print(f"[STEP] step={i+1} action={action_str} reward={reward:.2f} done={str(done).lower()} error=null", flush=True)
79
  if done:
80
  break
81
 
82
  score = grade_task(task, taken)
83
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
84
+ success = score >= 0.5
85
+ print(f"[END] success={str(success).lower()} steps={len(taken)} score={score:.3f} rewards={rewards_str}", flush=True)
86
  return score
87
 
88
  def main():