samrat-rm committed on
Commit
dc7aeea
·
1 Parent(s): a310ad6

feat: upgrading the system and user prompt, upgrading the _make_env() function

Browse files
Files changed (1) hide show
  1. inference.py +33 -14
inference.py CHANGED
@@ -71,18 +71,25 @@ SYSTEM_PROMPT = textwrap.dedent("""
71
 
72
  Examples:
73
  {"action_type": "inspect_logs"}
74
- {"action_type": "submit_diagnosis", "diagnosis": "exploding gradients", "suggested_fix": "reduce learning_rate to 0.001", "reasoning": "Loss spiked to NaN by epoch 3 and lr=10.0 in config, indicating weights diverged due to excessive learning rate causing gradient explosion."}
 
 
 
 
 
 
 
75
 
76
  RULES:
77
  - submit_diagnosis MUST include all three fields: diagnosis, suggested_fix, reasoning.
78
  - diagnosis is the short failure mode label — it is REQUIRED, never omit it.
79
- - reasoning must cite specific values from the data you inspected (loss values, lr, gradient norms, etc.).
80
  - Use exact failure mode phrasing for diagnosis: "exploding gradients", "overfitting", "underfitting",
81
  "learning rate too high", "learning rate too low", "vanishing gradients",
82
  "dying relu", "missing regularization", "batch size too small",
83
  "optimizer misconfiguration", "bad weight initialization", "lr scheduler misconfiguration".
84
- - Before submitting, check the Feedback field. If it says "N required source(s) still unexamined", inspect those sources first — do not submit until no required sources remain.
85
- - If feedback says "This source is not required for this failure mode.", stop investigating that direction and submit.
 
86
  - Never inspect the same source twice.
87
  """).strip()
88
 
@@ -98,6 +105,8 @@ def _user_prompt(step: int, obs_summary: str, history: List[str]) -> str:
98
  Recent history:
99
  {history_block}
100
 
 
 
101
  Respond with a JSON action.
102
  """).strip()
103
 
@@ -136,9 +145,23 @@ def _get_action(client: OpenAI, step: int, obs_summary: str, history: List[str])
136
 
137
  # ── episode runner ────────────────────────────────────────────────────────────
138
 
139
- async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -> dict:
140
- """Run one full episode for a specific scenario. Returns result dict."""
141
- result = await env.reset(scenario_key=scenario_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  obs = result.observation
143
  history: List[str] = []
144
  rewards: List[float] = []
@@ -192,7 +215,7 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
192
  print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", flush=True)
193
 
194
  success = score >= SUCCESS_THRESHOLD
195
- return {"scenario_key": scenario_key, "score": score, "steps": len(rewards), "success": success}
196
 
197
 
198
  # ── task runners ──────────────────────────────────────────────────────────────
@@ -206,7 +229,7 @@ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEn
206
 
207
  results = []
208
  for key in scenario_keys:
209
- res = await run_episode(env, client, key)
210
  results.append(res)
211
  print(f"[RESULT] scenario={res['scenario_key']} score={res['score']:.3f} steps={res['steps']} success={str(res['success']).lower()}", flush=True)
212
 
@@ -219,11 +242,7 @@ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEn
219
 
220
  async def main() -> None:
221
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
222
- env = (
223
- await WhyDidItFailEnv.from_docker_image(IMAGE_NAME)
224
- if IMAGE_NAME
225
- else WhyDidItFailEnv(base_url=SERVER_URL)
226
- )
227
 
228
  try:
229
  await run_task("easy", EASY_SCENARIOS, env, client)
 
71
 
72
  Examples:
73
  {"action_type": "inspect_logs"}
74
+ {"action_type": "submit_diagnosis", "diagnosis": "overfitting", "suggested_fix": "add dropout=0.3 and weight_decay=0.01", "reasoning": "train_loss fell to 0.03 by epoch 20 while val_loss rose to 2.34; train_acc=0.99 vs val_acc=0.54 — clear generalization gap. Config shows dropout=0.0 and weight_decay=0.0."}
75
+
76
+ DIAGNOSIS PROCESS — follow this every episode:
77
+ 1. Call inspect_logs first — always.
78
+ 2. Read the Data field carefully. Note the exact numeric values (loss, acc, lr, gradient norms, model).
79
+ 3. If Feedback says "Next required action: inspect_X" — call that action next, no exceptions.
80
+ 4. When no required actions remain, form your diagnosis based ONLY on values you actually saw in Data.
81
+ 5. Your reasoning MUST quote specific numbers from the Data you received (e.g. "val_loss=2.34 at epoch 20, train_acc=0.99"). If you cannot quote a specific number from the Data, you have not read it — do not submit yet.
82
 
83
  RULES:
84
  - submit_diagnosis MUST include all three fields: diagnosis, suggested_fix, reasoning.
85
  - diagnosis is the short failure mode label — it is REQUIRED, never omit it.
 
86
  - Use exact failure mode phrasing for diagnosis: "exploding gradients", "overfitting", "underfitting",
87
  "learning rate too high", "learning rate too low", "vanishing gradients",
88
  "dying relu", "missing regularization", "batch size too small",
89
  "optimizer misconfiguration", "bad weight initialization", "lr scheduler misconfiguration".
90
+ - CRITICAL: If Feedback contains "Next required action: inspect_X", you MUST call that action before submitting. Do not submit while any required source is unexamined.
91
+ - If Feedback says "This source is not required for this failure mode." — submit your diagnosis on the very next step. Do NOT inspect other sources.
92
+ - If Feedback says "Relevant clue found" with no "Next required action" — all sources are covered. Submit on the next step.
93
  - Never inspect the same source twice.
94
  """).strip()
95
 
 
105
  Recent history:
106
  {history_block}
107
 
108
+ Before responding: read the Data above carefully. What specific numeric values do you see?
109
+ Quote at least one value from the Data in your reasoning before submitting a diagnosis.
110
  Respond with a JSON action.
111
  """).strip()
112
 
 
145
 
146
  # ── episode runner ────────────────────────────────────────────────────────────
147
 
148
+ async def _make_env() -> WhyDidItFailEnv:
149
+ return (
150
+ await WhyDidItFailEnv.from_docker_image(IMAGE_NAME)
151
+ if IMAGE_NAME
152
+ else WhyDidItFailEnv(base_url=SERVER_URL)
153
+ )
154
+
155
+
156
+ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -> tuple[dict, WhyDidItFailEnv]:
157
+ """Run one full episode for a specific scenario. Returns (result dict, env).
158
+ env may be a fresh reconnected instance if the WebSocket dropped between episodes."""
159
+ try:
160
+ result = await env.reset(scenario_key=scenario_key)
161
+ except ConnectionClosedError:
162
+ print(f" [WARN] scenario={scenario_key} reconnecting WebSocket...", flush=True)
163
+ env = await _make_env()
164
+ result = await env.reset(scenario_key=scenario_key)
165
  obs = result.observation
166
  history: List[str] = []
167
  rewards: List[float] = []
 
215
  print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", flush=True)
216
 
217
  success = score >= SUCCESS_THRESHOLD
218
+ return {"scenario_key": scenario_key, "score": score, "steps": len(rewards), "success": success}, env
219
 
220
 
221
  # ── task runners ──────────────────────────────────────────────────────────────
 
229
 
230
  results = []
231
  for key in scenario_keys:
232
+ res, env = await run_episode(env, client, key)
233
  results.append(res)
234
  print(f"[RESULT] scenario={res['scenario_key']} score={res['score']:.3f} steps={res['steps']} success={str(res['success']).lower()}", flush=True)
235
 
 
242
 
243
  async def main() -> None:
244
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
245
+ env = await _make_env()
 
 
 
 
246
 
247
  try:
248
  await run_task("easy", EASY_SCENARIOS, env, client)