samrat-rm commited on
Commit
775f9bc
Β·
1 Parent(s): dfbe1fe

fix: adding reasoning in episode run loop and refactor commments

Browse files
Files changed (1) hide show
  1. inference.py +23 -9
inference.py CHANGED
@@ -13,7 +13,7 @@ TASKS
13
 
14
  STDOUT FORMAT
15
  [START] task=<task_name> scenarios=<n> model=<model_name>
16
- [EPISODE] scenario=<key> step=<n> action=<json> reward=<0.00> done=<bool>
17
  [RESULT] scenario=<key> score=<0.000> steps=<n> success=<bool>
18
  [SUMMARY] task=<task_name> avg_score=<0.000> pass_rate=<0.00>
19
  """
@@ -62,13 +62,25 @@ SYSTEM_PROMPT = textwrap.dedent("""
62
  inspect_gradients β€” examine gradient norm statistics
63
  submit_diagnosis β€” submit your final diagnosis (ends the episode)
64
 
65
- Respond with a JSON object on a single line. Examples:
66
- {"action_type": "inspect_logs"}
67
- {"action_type": "submit_diagnosis", "diagnosis": "exploding gradients"}
68
- {"action_type": "submit_diagnosis", "diagnosis": "overfitting", "suggested_fix": "add dropout=0.3"}
69
-
70
- Be efficient β€” inspect only what you need. Submit when confident.
71
- The diagnosis should be a short phrase describing the failure mode.
 
 
 
 
 
 
 
 
 
 
 
 
72
  """).strip()
73
 
74
 
@@ -124,6 +136,8 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
124
  obs = result.observation
125
  history: List[str] = []
126
  rewards: List[float] = []
 
 
127
 
128
  for step in range(1, MAX_STEPS + 1):
129
  if result.done:
@@ -146,7 +160,7 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
146
 
147
  rewards.append(reward)
148
  history.append(f"Step {step}: {act_str} β†’ reward={reward:.2f} | {obs.feedback}")
149
- print(f" [EPISODE] scenario={scenario_key} step={step} action={act_str} reward={reward:.2f} done={str(done).lower()}", flush=True)
150
 
151
  if done:
152
  break
 
13
 
14
  STDOUT FORMAT
15
  [START] task=<task_name> scenarios=<n> model=<model_name>
16
+ [STEP] scenario=<key> step=<n> action=<json> reward=<0.00> done=<bool>
17
  [RESULT] scenario=<key> score=<0.000> steps=<n> success=<bool>
18
  [SUMMARY] task=<task_name> avg_score=<0.000> pass_rate=<0.00>
19
  """
 
62
  inspect_gradients β€” examine gradient norm statistics
63
  submit_diagnosis β€” submit your final diagnosis (ends the episode)
64
 
65
+ OUTPUT FORMAT β€” STRICT:
66
+ Output ONLY a raw JSON object. No markdown, no code fences, no backticks, no explanation.
67
+ Start with { and end with }. One line only.
68
+
69
+ Examples:
70
+ {"action_type": "inspect_logs"}
71
+ {"action_type": "submit_diagnosis", "diagnosis": "exploding gradients", "suggested_fix": "reduce learning_rate to 0.001", "reasoning": "Loss spiked to NaN by epoch 3 and lr=10.0 in config, indicating weights diverged due to excessive learning rate causing gradient explosion."}
72
+
73
+ RULES:
74
+ - submit_diagnosis MUST include all three fields: diagnosis, suggested_fix, reasoning.
75
+ - diagnosis is the short failure mode label β€” it is REQUIRED, never omit it.
76
+ - reasoning must cite specific values from the data you inspected (loss values, lr, gradient norms, etc.).
77
+ - Use exact failure mode phrasing for diagnosis: "exploding gradients", "overfitting", "underfitting",
78
+ "learning rate too high", "learning rate too low", "vanishing gradients",
79
+ "dying relu", "missing regularization", "batch size too small",
80
+ "optimizer misconfiguration", "bad weight initialization", "lr scheduler misconfiguration".
81
+ - Before submitting, check the Feedback field. If it says "N required source(s) still unexamined", inspect those sources first β€” do not submit until no required sources remain.
82
+ - If feedback says "This source is not required for this failure mode.", stop investigating that direction and submit.
83
+ - Never inspect the same source twice.
84
  """).strip()
85
 
86
 
 
136
  obs = result.observation
137
  history: List[str] = []
138
  rewards: List[float] = []
139
+ inspection_order: List[str] = []
140
+ submit_action: WhyDidItFailAction | None = None
141
 
142
  for step in range(1, MAX_STEPS + 1):
143
  if result.done:
 
160
 
161
  rewards.append(reward)
162
  history.append(f"Step {step}: {act_str} β†’ reward={reward:.2f} | {obs.feedback}")
163
+ print(f" [STEP] scenario={scenario_key} step={step} action={act_str} reward={reward:.2f} done={str(done).lower()}", flush=True)
164
 
165
  if done:
166
  break