Spaces:
Sleeping
Sleeping
fix: error handling for episode run loop
Browse files- inference.py +60 -57
inference.py
CHANGED
|
@@ -190,10 +190,11 @@ async def run_episode(
|
|
| 190 |
) -> tuple[dict, WhyDidItFailEnv]:
|
| 191 |
"""Run one full episode for a specific scenario. Returns (result dict, env).
|
| 192 |
env may be a fresh reconnected instance if the WebSocket dropped between episodes."""
|
|
|
|
| 193 |
try:
|
| 194 |
result = await env.reset(scenario_key=scenario_key)
|
| 195 |
except ConnectionClosedError:
|
| 196 |
-
print(f" [WARN] scenario={scenario_key} reconnecting WebSocket...", flush=True)
|
| 197 |
env = await _make_env()
|
| 198 |
result = await env.reset(scenario_key=scenario_key)
|
| 199 |
|
|
@@ -204,64 +205,66 @@ async def run_episode(
|
|
| 204 |
rewards: List[float] = []
|
| 205 |
inspection_order: List[str] = []
|
| 206 |
submit_action: WhyDidItFailAction | None = None
|
| 207 |
-
|
|
|
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
|
| 266 |
return {"scenario_key": scenario_key, "score": score, "steps": steps_taken, "success": success}, env
|
| 267 |
|
|
|
|
| 190 |
) -> tuple[dict, WhyDidItFailEnv]:
|
| 191 |
"""Run one full episode for a specific scenario. Returns (result dict, env).
|
| 192 |
env may be a fresh reconnected instance if the WebSocket dropped between episodes."""
|
| 193 |
+
import sys
|
| 194 |
try:
|
| 195 |
result = await env.reset(scenario_key=scenario_key)
|
| 196 |
except ConnectionClosedError:
|
| 197 |
+
print(f" [WARN] scenario={scenario_key} reconnecting WebSocket...", file=sys.stderr, flush=True)
|
| 198 |
env = await _make_env()
|
| 199 |
result = await env.reset(scenario_key=scenario_key)
|
| 200 |
|
|
|
|
| 205 |
rewards: List[float] = []
|
| 206 |
inspection_order: List[str] = []
|
| 207 |
submit_action: WhyDidItFailAction | None = None
|
| 208 |
+
score = 0.0
|
| 209 |
+
success = False
|
| 210 |
|
| 211 |
+
try:
|
| 212 |
+
for step in range(1, MAX_STEPS + 1):
|
| 213 |
+
if result.done:
|
| 214 |
+
break
|
| 215 |
+
|
| 216 |
+
action = _get_action(client, step, _summarize(obs), history)
|
| 217 |
+
try:
|
| 218 |
+
result = await env.step(action)
|
| 219 |
+
except ConnectionClosedError as e:
|
| 220 |
+
print(f"[STEP] step={step} action={action.action_type} reward=0.00 done=true error={e}", flush=True)
|
| 221 |
+
break
|
| 222 |
+
obs = result.observation
|
| 223 |
+
reward = result.reward or 0.0
|
| 224 |
+
done = result.done
|
| 225 |
+
act_str = action.model_dump_json(exclude_none=True, exclude_defaults=True)
|
| 226 |
+
|
| 227 |
+
if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
|
| 228 |
+
source = action.action_type.replace("inspect_", "")
|
| 229 |
+
if source not in inspection_order:
|
| 230 |
+
inspection_order.append(source)
|
| 231 |
+
|
| 232 |
+
if action.action_type == "submit_diagnosis":
|
| 233 |
+
submit_action = action # judge runs after loop — WebSocket is closed by then
|
| 234 |
+
|
| 235 |
+
rewards.append(reward)
|
| 236 |
+
history.append(f"Step {step}: {act_str} → reward={reward:.2f} | {obs.feedback}")
|
| 237 |
+
print(f"[STEP] step={step} action={act_str} reward={reward:.2f} done={str(done).lower()} error=null", flush=True)
|
| 238 |
+
|
| 239 |
+
if done:
|
| 240 |
+
break
|
| 241 |
+
|
| 242 |
+
# WebSocket is closed — safe to call the judge now
|
| 243 |
+
keyword_score = rewards[-1] if rewards else 0.0
|
| 244 |
+
judge_score: float | None = None
|
| 245 |
+
if submit_action is not None:
|
| 246 |
+
judge_score = llm_judge(
|
| 247 |
+
client=client,
|
| 248 |
+
model=MODEL_NAME,
|
| 249 |
+
diagnosis=submit_action.diagnosis or "",
|
| 250 |
+
reasoning=submit_action.reasoning,
|
| 251 |
+
suggested_fix=submit_action.suggested_fix,
|
| 252 |
+
scenario=SCENARIOS[scenario_key],
|
| 253 |
+
inspection_order=inspection_order,
|
| 254 |
+
)
|
| 255 |
+
if judge_score is None:
|
| 256 |
+
score = round(keyword_score, 4)
|
| 257 |
+
print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning=n/a total={score:.3f}", file=__import__("sys").stderr, flush=True)
|
| 258 |
+
else:
|
| 259 |
+
score = round(0.85 * keyword_score + 0.15 * judge_score, 4)
|
| 260 |
+
print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", file=__import__("sys").stderr, flush=True)
|
| 261 |
+
|
| 262 |
+
success = score >= SUCCESS_THRESHOLD
|
| 263 |
|
| 264 |
+
finally:
|
| 265 |
+
steps_taken = len(rewards)
|
| 266 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.00"
|
| 267 |
+
print(f"[END] success={str(success).lower()} steps={steps_taken} rewards={rewards_str}", flush=True)
|
| 268 |
|
| 269 |
return {"scenario_key": scenario_key, "score": score, "steps": steps_taken, "success": success}, env
|
| 270 |
|