samrat-rm committed on
Commit
3eeca00
·
1 Parent(s): d8e7a25

fix: harden label rules to prevent missing_regularization misfires

Browse files
Files changed (2) hide show
  1. inference.py +13 -10
  2. server/scenarios.py +2 -2
inference.py CHANGED
@@ -82,11 +82,11 @@ SYSTEM_PROMPT = textwrap.dedent("""
82
  5. Your reasoning MUST quote specific numbers from the Data you received (e.g. "val_loss=2.34 at epoch 20, train_acc=0.99"). If you cannot quote a specific number from the Data, you have not read it — do not submit yet.
83
 
84
  LABEL DECISION RULES β€” use these to pick the exact diagnosis label:
85
- - loss becomes NaN or spikes 100x+ in one epoch → "exploding gradients" (NOT "learning rate too high")
86
- - loss oscillates wildly epoch-to-epoch but stays finite → "learning rate too high"
 
87
  - train_loss low, val_loss rising AND config shows weight_decay=0 AND dropout=0 → "missing regularization" (NOT "overfitting")
88
- - train_loss low, val_loss rising AND regularization is already configured → "overfitting"
89
- - both train_loss and val_loss stay high / plateau → "underfitting"
90
  - gradient norm = 0.0 exactly in hidden layers AND config shows ReLU activation → "dying relu"
91
  - gradient norm tiny but nonzero (e.g. 1e-5, 1e-8) AND config shows sigmoid/tanh → "vanishing gradients"
92
  - config shows lr_scheduler with gamma > 1.0 → "lr scheduler misconfiguration"
@@ -244,10 +244,10 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
244
 
245
  # ── task runners ──────────────────────────────────────────────────────────────
246
 
247
- async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEnv, client: OpenAI) -> None:
248
  if not scenario_keys:
249
  print(f"[SUMMARY] task={task_name} — no scenarios defined yet", flush=True)
250
- return
251
 
252
  if USE_LOCAL:
253
  try:
@@ -269,6 +269,7 @@ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEn
269
  avg_score = sum(r["score"] for r in results) / len(results)
270
  pass_rate = sum(1 for r in results if r["success"]) / len(results)
271
  print(f"[SUMMARY] task={task_name} avg_score={avg_score:.3f} pass_rate={pass_rate:.2f}", flush=True)
 
272
 
273
 
274
  # ── main ──────────────────────────────────────────────────────────────────────
@@ -278,10 +279,12 @@ async def main() -> None:
278
  env = await _make_env()
279
 
280
  try:
281
- await run_task("easy", EASY_SCENARIOS, env, client)
282
- await run_task("medium", MEDIUM_SCENARIOS, env, client)
283
- await run_task("hard", HARD_SCENARIOS, env, client)
284
- print("[END] all tasks complete", flush=True)
 
 
285
  finally:
286
  try:
287
  await env.close()
 
82
  5. Your reasoning MUST quote specific numbers from the Data you received (e.g. "val_loss=2.34 at epoch 20, train_acc=0.99"). If you cannot quote a specific number from the Data, you have not read it — do not submit yet.
83
 
84
  LABEL DECISION RULES β€” use these to pick the exact diagnosis label:
85
+ - train_loss is NaN or inf at ANY epoch → "exploding gradients". ABSOLUTE RULE. No other label applies.
86
+ - loss oscillates wildly epoch-to-epoch but stays finite (no NaN) → "learning rate too high"
87
+ - both train_loss AND val_loss stay high with no gap (train_acc ≈ val_acc, both near random) → "underfitting". ABSOLUTE RULE regardless of config.
88
  - train_loss low, val_loss rising AND config shows weight_decay=0 AND dropout=0 → "missing regularization" (NOT "overfitting")
89
+ - train_loss low, val_loss rising AND regularization is already present in config → "overfitting"
 
90
  - gradient norm = 0.0 exactly in hidden layers AND config shows ReLU activation → "dying relu"
91
  - gradient norm tiny but nonzero (e.g. 1e-5, 1e-8) AND config shows sigmoid/tanh → "vanishing gradients"
92
  - config shows lr_scheduler with gamma > 1.0 → "lr scheduler misconfiguration"
 
244
 
245
  # ── task runners ──────────────────────────────────────────────────────────────
246
 
247
+ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEnv, client: OpenAI) -> List[float]:
248
  if not scenario_keys:
249
  print(f"[SUMMARY] task={task_name} — no scenarios defined yet", flush=True)
250
+ return []
251
 
252
  if USE_LOCAL:
253
  try:
 
269
  avg_score = sum(r["score"] for r in results) / len(results)
270
  pass_rate = sum(1 for r in results if r["success"]) / len(results)
271
  print(f"[SUMMARY] task={task_name} avg_score={avg_score:.3f} pass_rate={pass_rate:.2f}", flush=True)
272
+ return [r["score"] for r in results]
273
 
274
 
275
  # ── main ──────────────────────────────────────────────────────────────────────
 
279
  env = await _make_env()
280
 
281
  try:
282
+ scores = []
283
+ scores += await run_task("easy", EASY_SCENARIOS, env, client)
284
+ scores += await run_task("medium", MEDIUM_SCENARIOS, env, client)
285
+ scores += await run_task("hard", HARD_SCENARIOS, env, client)
286
+ overall = sum(scores) / len(scores) if scores else 0.0
287
+ print(f"[END] score={overall:.3f}", flush=True)
288
  finally:
289
  try:
290
  await env.close()
server/scenarios.py CHANGED
@@ -64,7 +64,7 @@ SCENARIOS: dict[str, dict] = {
64
  "required_sources": ["logs"],
65
  "config": {
66
  "learning_rate": 0.001, "optimizer": "adam",
67
- "batch_size": 32, "weight_decay": 0.0, "dropout": 0.0,
68
  "architecture": "ResNet50", "dataset": "CIFAR-10",
69
  },
70
  "logs": [
@@ -82,7 +82,7 @@ SCENARIOS: dict[str, dict] = {
82
  {"epoch": 20, "norm": 0.24},
83
  ],
84
  "correct_diagnosis": "overfitting",
85
- "correct_fix": "add dropout=0.3 and weight_decay=0.01",
86
  },
87
 
88
  "underfitting": {
 
64
  "required_sources": ["logs"],
65
  "config": {
66
  "learning_rate": 0.001, "optimizer": "adam",
67
+ "batch_size": 32, "weight_decay": 0.001, "dropout": 0.1,
68
  "architecture": "ResNet50", "dataset": "CIFAR-10",
69
  },
70
  "logs": [
 
82
  {"epoch": 20, "norm": 0.24},
83
  ],
84
  "correct_diagnosis": "overfitting",
85
+ "correct_fix": "increase dropout to 0.3 and weight_decay to 0.01",
86
  },
87
 
88
  "underfitting": {