samrat-rm committed on
Commit
77f9568
Β·
1 Parent(s): aa1c27d

fix: comply with openenv stdout spec, preserve inspection data in history, sharpen medium-tier label rules

Browse files
Files changed (1) hide show
  1. inference.py +13 -10
inference.py CHANGED
@@ -22,6 +22,7 @@ STDOUT FORMAT
22
  import asyncio
23
  import json
24
  import os
 
25
  import textwrap
26
  from typing import List
27
 
@@ -87,8 +88,10 @@ SYSTEM_PROMPT = textwrap.dedent("""
87
  LABEL DECISION RULES β€” use these to pick the exact diagnosis label:
88
  - train_loss is NaN from epoch 1 AND config shows extreme weight_init (e.g. std=100) AND gradient norms are massive (>10000) β†’ "bad weight initialization". Check config FIRST before applying the NaN rule below.
89
  - train_loss is NaN or inf AFTER at least one finite epoch β†’ "exploding gradients". ABSOLUTE RULE. No other label applies.
90
- - loss oscillates wildly epoch-to-epoch but stays finite (no NaN) β†’ "learning rate too high"
91
- - both train_loss AND val_loss stay high with no gap (train_acc β‰ˆ val_acc, both near random baseline ~10%) β†’ "underfitting". ABSOLUTE RULE. The config is IRRELEVANT. Do NOT wait for gradients. Submit immediately after seeing the logs.
 
 
92
  - train_loss low, val_loss rising AND config shows weight_decay=0.0 exactly AND dropout=0.0 exactly β†’ "missing regularization" (NOT "overfitting")
93
  - train_loss low, val_loss rising AND config shows ANY non-zero weight_decay OR ANY non-zero dropout β†’ "overfitting" (NOT "missing regularization")
94
  - gradient norm = 0.0 exactly in hidden layers AND config shows ReLU activation β†’ "dying relu"
@@ -190,7 +193,6 @@ async def run_episode(
190
  ) -> tuple[dict, WhyDidItFailEnv]:
191
  """Run one full episode for a specific scenario. Returns (result dict, env).
192
  env may be a fresh reconnected instance if the WebSocket dropped between episodes."""
193
- import sys
194
  try:
195
  result = await env.reset(scenario_key=scenario_key)
196
  except ConnectionClosedError:
@@ -233,7 +235,8 @@ async def run_episode(
233
  submit_action = action # judge runs after loop β€” WebSocket is closed by then
234
 
235
  rewards.append(reward)
236
- history.append(f"Step {step}: {act_str} β†’ reward={reward:.2f} | {obs.feedback}")
 
237
  print(f"[STEP] step={step} action={act_str} reward={reward:.2f} done={str(done).lower()} error=null", flush=True)
238
 
239
  if done:
@@ -254,10 +257,10 @@ async def run_episode(
254
  )
255
  if judge_score is None:
256
  score = round(keyword_score, 4)
257
- print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning=n/a total={score:.3f}", file=__import__("sys").stderr, flush=True)
258
  else:
259
  score = round(0.85 * keyword_score + 0.15 * judge_score, 4)
260
- print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", file=__import__("sys").stderr, flush=True)
261
 
262
  success = score >= SUCCESS_THRESHOLD
263
 
@@ -289,11 +292,11 @@ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEn
289
  for key in scenario_keys:
290
  res, env = await run_episode(env, client, key, task_name, effective_model)
291
  results.append(res)
292
- print(f" [RESULT] scenario={res['scenario_key']} score={res['score']:.3f} steps={res['steps']} success={str(res['success']).lower()}", flush=True)
293
 
294
  avg_score = sum(r["score"] for r in results) / len(results)
295
  pass_rate = sum(1 for r in results if r["success"]) / len(results)
296
- print(f" [SUMMARY] task={task_name} avg_score={avg_score:.3f} pass_rate={pass_rate:.2f}", flush=True)
297
  return [r["score"] for r in results]
298
 
299
 
@@ -309,12 +312,12 @@ async def main() -> None:
309
  scores += await run_task("task_medium", MEDIUM_SCENARIOS, env, client)
310
  scores += await run_task("task_hard", HARD_SCENARIOS, env, client)
311
  overall = sum(scores) / len(scores) if scores else 0.0
312
- print(f" [OVERALL] avg_score={overall:.3f}", flush=True)
313
  finally:
314
  try:
315
  await env.close()
316
  except Exception as e:
317
- print(f" [DEBUG] env.close() error: {e}", flush=True)
318
 
319
 
320
  if __name__ == "__main__":
 
22
  import asyncio
23
  import json
24
  import os
25
+ import sys
26
  import textwrap
27
  from typing import List
28
 
 
88
  LABEL DECISION RULES β€” use these to pick the exact diagnosis label:
89
  - train_loss is NaN from epoch 1 AND config shows extreme weight_init (e.g. std=100) AND gradient norms are massive (>10000) β†’ "bad weight initialization". Check config FIRST before applying the NaN rule below.
90
  - train_loss is NaN or inf AFTER at least one finite epoch β†’ "exploding gradients". ABSOLUTE RULE. No other label applies.
91
+ - loss oscillates wildly epoch-to-epoch but stays finite (no NaN) AND config shows batch_size ≀ 4 β†’ "batch size too small" (NOT "learning rate too high"). PRIORITY RULE: check batch_size in config before applying the oscillation β†’ lr rule.
92
+ - loss oscillates wildly epoch-to-epoch but stays finite (no NaN) AND config shows batch_size > 4 β†’ "learning rate too high"
93
+ - both train_loss AND val_loss stay high with no gap (train_acc β‰ˆ val_acc, both near random baseline ~10%) AND config shows SGD optimizer with momentum=0.0 β†’ "optimizer misconfiguration" (NOT "underfitting"). Check config for SGD momentum before applying the underfitting rule.
94
+ - both train_loss AND val_loss stay high with no gap (train_acc β‰ˆ val_acc, both near random baseline ~10%) AND config does NOT show SGD with momentum=0.0 β†’ "underfitting". ABSOLUTE RULE. Do NOT wait for gradients. Submit immediately after seeing the logs.
95
  - train_loss low, val_loss rising AND config shows weight_decay=0.0 exactly AND dropout=0.0 exactly β†’ "missing regularization" (NOT "overfitting")
96
  - train_loss low, val_loss rising AND config shows ANY non-zero weight_decay OR ANY non-zero dropout β†’ "overfitting" (NOT "missing regularization")
97
  - gradient norm = 0.0 exactly in hidden layers AND config shows ReLU activation β†’ "dying relu"
 
193
  ) -> tuple[dict, WhyDidItFailEnv]:
194
  """Run one full episode for a specific scenario. Returns (result dict, env).
195
  env may be a fresh reconnected instance if the WebSocket dropped between episodes."""
 
196
  try:
197
  result = await env.reset(scenario_key=scenario_key)
198
  except ConnectionClosedError:
 
235
  submit_action = action # judge runs after loop β€” WebSocket is closed by then
236
 
237
  rewards.append(reward)
238
+ data_seen = json.dumps(obs.visible_data) if obs.visible_data else "{}"
239
+ history.append(f"Step {step}: {act_str} β†’ reward={reward:.2f} | {obs.feedback}\n Data: {data_seen}")
240
  print(f"[STEP] step={step} action={act_str} reward={reward:.2f} done={str(done).lower()} error=null", flush=True)
241
 
242
  if done:
 
257
  )
258
  if judge_score is None:
259
  score = round(keyword_score, 4)
260
+ print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning=n/a total={score:.3f}", file=sys.stderr, flush=True)
261
  else:
262
  score = round(0.85 * keyword_score + 0.15 * judge_score, 4)
263
+ print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", file=sys.stderr, flush=True)
264
 
265
  success = score >= SUCCESS_THRESHOLD
266
 
 
292
  for key in scenario_keys:
293
  res, env = await run_episode(env, client, key, task_name, effective_model)
294
  results.append(res)
295
+ print(f" [RESULT] scenario={res['scenario_key']} score={res['score']:.3f} steps={res['steps']} success={str(res['success']).lower()}", file=sys.stderr, flush=True)
296
 
297
  avg_score = sum(r["score"] for r in results) / len(results)
298
  pass_rate = sum(1 for r in results if r["success"]) / len(results)
299
+ print(f" [SUMMARY] task={task_name} avg_score={avg_score:.3f} pass_rate={pass_rate:.2f}", file=sys.stderr, flush=True)
300
  return [r["score"] for r in results]
301
 
302
 
 
312
  scores += await run_task("task_medium", MEDIUM_SCENARIOS, env, client)
313
  scores += await run_task("task_hard", HARD_SCENARIOS, env, client)
314
  overall = sum(scores) / len(scores) if scores else 0.0
315
+ print(f" [OVERALL] avg_score={overall:.3f}", file=sys.stderr, flush=True)
316
  finally:
317
  try:
318
  await env.close()
319
  except Exception as e:
320
+ print(f" [DEBUG] env.close() error: {e}", file=sys.stderr, flush=True)
321
 
322
 
323
  if __name__ == "__main__":