samrat-rm committed on
Commit
aa1c27d
·
1 Parent(s): 2ae2b18

fix: error handling for episode run loop

Browse files
Files changed (1) hide show
  1. inference.py +60 -57
inference.py CHANGED
@@ -190,10 +190,11 @@ async def run_episode(
190
  ) -> tuple[dict, WhyDidItFailEnv]:
191
  """Run one full episode for a specific scenario. Returns (result dict, env).
192
  env may be a fresh reconnected instance if the WebSocket dropped between episodes."""
 
193
  try:
194
  result = await env.reset(scenario_key=scenario_key)
195
  except ConnectionClosedError:
196
- print(f" [WARN] scenario={scenario_key} reconnecting WebSocket...", flush=True)
197
  env = await _make_env()
198
  result = await env.reset(scenario_key=scenario_key)
199
 
@@ -204,64 +205,66 @@ async def run_episode(
204
  rewards: List[float] = []
205
  inspection_order: List[str] = []
206
  submit_action: WhyDidItFailAction | None = None
207
- last_error: str | None = None
 
208
 
209
- for step in range(1, MAX_STEPS + 1):
210
- if result.done:
211
- break
212
-
213
- action = _get_action(client, step, _summarize(obs), history)
214
- last_error = None
215
- try:
216
- result = await env.step(action)
217
- except ConnectionClosedError as e:
218
- last_error = str(e)
219
- print(f"[STEP] step={step} action={action.action_type} reward=0.00 done=true error={last_error}", flush=True)
220
- break
221
- obs = result.observation
222
- reward = result.reward or 0.0
223
- done = result.done
224
- act_str = action.model_dump_json(exclude_none=True, exclude_defaults=True)
225
-
226
- if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
227
- source = action.action_type.replace("inspect_", "")
228
- if source not in inspection_order:
229
- inspection_order.append(source)
230
-
231
- if action.action_type == "submit_diagnosis":
232
- submit_action = action # judge runs after loop — WebSocket is closed by then
233
-
234
- rewards.append(reward)
235
- history.append(f"Step {step}: {act_str} reward={reward:.2f} | {obs.feedback}")
236
- print(f"[STEP] step={step} action={act_str} reward={reward:.2f} done={str(done).lower()} error=null", flush=True)
237
-
238
- if done:
239
- break
240
-
241
- # WebSocket is closed safe to call the judge now
242
- keyword_score = rewards[-1] if rewards else 0.0
243
- judge_score: float | None = None
244
- if submit_action is not None:
245
- judge_score = llm_judge(
246
- client=client,
247
- model=MODEL_NAME,
248
- diagnosis=submit_action.diagnosis or "",
249
- reasoning=submit_action.reasoning,
250
- suggested_fix=submit_action.suggested_fix,
251
- scenario=SCENARIOS[scenario_key],
252
- inspection_order=inspection_order,
253
- )
254
- if judge_score is None:
255
- score = round(keyword_score, 4)
256
- print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning=n/a total={score:.3f}", flush=True)
257
- else:
258
- score = round(0.85 * keyword_score + 0.15 * judge_score, 4)
259
- print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", flush=True)
 
260
 
261
- success = score >= SUCCESS_THRESHOLD
262
- steps_taken = len(rewards)
263
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
264
- print(f"[END] success={str(success).lower()} steps={steps_taken} rewards={rewards_str}", flush=True)
265
 
266
  return {"scenario_key": scenario_key, "score": score, "steps": steps_taken, "success": success}, env
267
 
 
190
  ) -> tuple[dict, WhyDidItFailEnv]:
191
  """Run one full episode for a specific scenario. Returns (result dict, env).
192
  env may be a fresh reconnected instance if the WebSocket dropped between episodes."""
193
+ import sys
194
  try:
195
  result = await env.reset(scenario_key=scenario_key)
196
  except ConnectionClosedError:
197
+ print(f" [WARN] scenario={scenario_key} reconnecting WebSocket...", file=sys.stderr, flush=True)
198
  env = await _make_env()
199
  result = await env.reset(scenario_key=scenario_key)
200
 
 
205
  rewards: List[float] = []
206
  inspection_order: List[str] = []
207
  submit_action: WhyDidItFailAction | None = None
208
+ score = 0.0
209
+ success = False
210
 
211
+ try:
212
+ for step in range(1, MAX_STEPS + 1):
213
+ if result.done:
214
+ break
215
+
216
+ action = _get_action(client, step, _summarize(obs), history)
217
+ try:
218
+ result = await env.step(action)
219
+ except ConnectionClosedError as e:
220
+ print(f"[STEP] step={step} action={action.action_type} reward=0.00 done=true error={e}", flush=True)
221
+ break
222
+ obs = result.observation
223
+ reward = result.reward or 0.0
224
+ done = result.done
225
+ act_str = action.model_dump_json(exclude_none=True, exclude_defaults=True)
226
+
227
+ if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
228
+ source = action.action_type.replace("inspect_", "")
229
+ if source not in inspection_order:
230
+ inspection_order.append(source)
231
+
232
+ if action.action_type == "submit_diagnosis":
233
+ submit_action = action # judge runs after loop — WebSocket is closed by then
234
+
235
+ rewards.append(reward)
236
+ history.append(f"Step {step}: {act_str} → reward={reward:.2f} | {obs.feedback}")
237
+ print(f"[STEP] step={step} action={act_str} reward={reward:.2f} done={str(done).lower()} error=null", flush=True)
238
+
239
+ if done:
240
+ break
241
+
242
+ # WebSocket is closed — safe to call the judge now
243
+ keyword_score = rewards[-1] if rewards else 0.0
244
+ judge_score: float | None = None
245
+ if submit_action is not None:
246
+ judge_score = llm_judge(
247
+ client=client,
248
+ model=MODEL_NAME,
249
+ diagnosis=submit_action.diagnosis or "",
250
+ reasoning=submit_action.reasoning,
251
+ suggested_fix=submit_action.suggested_fix,
252
+ scenario=SCENARIOS[scenario_key],
253
+ inspection_order=inspection_order,
254
+ )
255
+ if judge_score is None:
256
+ score = round(keyword_score, 4)
257
+ print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning=n/a total={score:.3f}", file=__import__("sys").stderr, flush=True)
258
+ else:
259
+ score = round(0.85 * keyword_score + 0.15 * judge_score, 4)
260
+ print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", file=__import__("sys").stderr, flush=True)
261
+
262
+ success = score >= SUCCESS_THRESHOLD
263
 
264
+ finally:
265
+ steps_taken = len(rewards)
266
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.00"
267
+ print(f"[END] success={str(success).lower()} steps={steps_taken} rewards={rewards_str}", flush=True)
268
 
269
  return {"scenario_key": scenario_key, "score": score, "steps": steps_taken, "success": success}, env
270