samrat-rm committed on
Commit
66d62a2
·
1 Parent(s): e216a2f

feat: 3 modes of difficulty and updating the logs

Browse files
Files changed (1) hide show
  1. inference.py +115 -101
inference.py CHANGED
@@ -6,17 +6,23 @@ MANDATORY environment variables:
6
  MODEL_NAME The model identifier to use for inference.
7
  HF_TOKEN / API_KEY Your Hugging Face / API key.
8
 
 
 
 
 
 
9
  STDOUT FORMAT
10
- [START] task=<task_name> env=<benchmark> model=<model_name>
11
- [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
12
- [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
 
13
  """
14
 
15
  import asyncio
16
  import json
17
  import os
18
  import textwrap
19
- from typing import List, Optional
20
 
21
  from dotenv import load_dotenv
22
  load_dotenv()
@@ -25,154 +31,162 @@ from openai import OpenAI
25
 
26
  from client import WhyDidItFailEnv
27
  from models import WhyDidItFailAction
 
28
 
29
- IMAGE_NAME = os.getenv("IMAGE_NAME")
30
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
31
- API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
32
- MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
33
- TASK_NAME = os.getenv("WHYDIDITFAIL_TASK", "whydiditfail")
34
- BENCHMARK = os.getenv("WHYDIDITFAIL_BENCHMARK", "whydiditfail")
35
- MAX_STEPS = 8
36
- TEMPERATURE = 0.3
37
- MAX_TOKENS = 256
38
- SUCCESS_SCORE_THRESHOLD = 0.5 # reward >= 0.5 counts as success
39
-
40
- SYSTEM_PROMPT = textwrap.dedent(
41
- """
42
- You are a machine learning engineer diagnosing a failed training run.
43
- Each turn you will receive data from the training run and must decide what to investigate next.
44
 
45
- Available actions:
46
- - inspect_logs : examine training loss curves
47
- - inspect_config : examine hyperparameter config (lr, optimizer, etc.)
48
- - inspect_gradients : examine gradient statistics
49
- - submit_diagnosis : submit your final diagnosis (ends the episode)
50
 
51
- You must respond with a JSON object on a single line. Examples:
52
- {"action_type": "inspect_logs"}
53
- {"action_type": "inspect_config"}
54
- {"action_type": "submit_diagnosis", "diagnosis": "exploding gradients"}
55
 
56
- Only submit_diagnosis when you are confident. The diagnosis should describe the failure mode
57
- in plain terms (e.g. "exploding gradients", "overfitting", "vanishing gradients").
58
- """
59
- ).strip()
60
-
61
-
62
- def log_start(task: str, env: str, model: str) -> None:
63
- print(f"[START] task={task} env={env} model={model}", flush=True)
64
 
 
 
 
65
 
66
- def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
67
- error_val = error if error else "null"
68
- print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
 
 
69
 
 
 
 
 
70
 
71
- def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
72
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
73
- print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
74
 
75
 
76
- def build_user_prompt(step: int, observation_summary: str, history: List[str]) -> str:
77
  history_block = "\n".join(history[-4:]) if history else "None"
78
- return textwrap.dedent(
79
- f"""
80
- Step: {step}
81
 
82
- Current observation:
83
- {observation_summary}
84
 
85
- History:
86
  {history_block}
87
 
88
  Respond with a JSON action.
89
- """
90
- ).strip()
 
 
 
 
 
 
 
 
 
91
 
92
 
93
- def get_model_action(client: OpenAI, step: int, observation_summary: str, history: List[str]) -> WhyDidItFailAction:
94
- user_prompt = build_user_prompt(step, observation_summary, history)
95
  try:
96
  completion = client.chat.completions.create(
97
  model=MODEL_NAME,
98
  messages=[
99
  {"role": "system", "content": SYSTEM_PROMPT},
100
- {"role": "user", "content": user_prompt},
101
  ],
102
  temperature=TEMPERATURE,
103
  max_tokens=MAX_TOKENS,
104
- stream=False,
105
  )
106
  text = (completion.choices[0].message.content or "").strip()
107
- data = json.loads(text)
108
- return WhyDidItFailAction(**data)
109
  except Exception as exc:
110
- print(f"[DEBUG] Model request/parse failed: {exc}", flush=True)
111
- # Fallback: inspect logs if early, otherwise give up and submit empty diagnosis
112
  if step <= 2:
113
- return WhyDidItFailAction(action_type="inspect_logs")
114
- return WhyDidItFailAction(action_type="submit_diagnosis", diagnosis="unknown")
115
 
 
116
 
117
- def summarize_observation(obs) -> str:
118
- lines = [
119
- f"Task: {obs.task_description}",
120
- f"Feedback: {obs.feedback}",
121
- f"Available actions: {', '.join(obs.available_actions)}",
122
- ]
123
- if obs.visible_data:
124
- lines.append(f"Data: {json.dumps(obs.visible_data, indent=2)}")
125
- return "\n".join(lines)
126
 
 
 
 
127
 
128
- async def main() -> None:
129
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
130
- env = await WhyDidItFailEnv.from_docker_image(IMAGE_NAME or "")
 
 
 
131
 
132
- history: List[str] = []
133
- rewards: List[float] = []
134
- steps_taken = 0
135
- score = 0.0
136
- success = False
137
 
138
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
 
139
 
140
- try:
141
- result = await env.reset()
142
- obs = result.observation
 
143
 
144
- for step in range(1, MAX_STEPS + 1):
145
- if result.done:
146
- break
147
 
148
- obs_summary = summarize_observation(obs)
149
- action = get_model_action(client, step, obs_summary, history)
150
 
151
- result = await env.step(action)
152
- obs = result.observation
 
 
153
 
154
- reward = result.reward or 0.0
155
- done = result.done
156
- action_str = action.model_dump_json(exclude_none=True)
157
 
158
- rewards.append(reward)
159
- steps_taken = step
 
 
 
160
 
161
- log_step(step=step, action=action_str, reward=reward, done=done, error=None)
162
- history.append(f"Step {step}: {action_str} -> reward={reward:.2f} feedback={obs.feedback!r}")
 
163
 
164
- if done:
165
- break
166
 
167
- score = max(rewards) if rewards else 0.0 # final diagnosis reward is what matters
168
- success = score >= SUCCESS_SCORE_THRESHOLD
169
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  finally:
171
  try:
172
  await env.close()
173
  except Exception as e:
174
  print(f"[DEBUG] env.close() error: {e}", flush=True)
175
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
176
 
177
 
178
  if __name__ == "__main__":
 
6
  MODEL_NAME The model identifier to use for inference.
7
  HF_TOKEN / API_KEY Your Hugging Face / API key.
8
 
9
+ TASKS
10
+ Task 1 (easy) — identify failure mode from logs only
11
+ Task 2 (medium) — identify failure mode from logs + config [coming soon]
12
+ Task 3 (hard) — identify failure mode + provide correct fix [coming soon]
13
+
14
  STDOUT FORMAT
15
+ [START] task=<task_name> scenarios=<n> model=<model_name>
16
+ [EPISODE] scenario=<key> step=<n> action=<json> reward=<0.00> done=<bool>
17
+ [RESULT] scenario=<key> score=<0.000> steps=<n> success=<bool>
18
+ [SUMMARY] task=<task_name> avg_score=<0.000> pass_rate=<0.00>
19
  """
20
 
21
  import asyncio
22
  import json
23
  import os
24
  import textwrap
25
+ from typing import List
26
 
27
  from dotenv import load_dotenv
28
  load_dotenv()
 
31
 
32
  from client import WhyDidItFailEnv
33
  from models import WhyDidItFailAction
34
+ from server.scenarios import SCENARIOS
35
 
36
+ IMAGE_NAME = os.getenv("IMAGE_NAME", "")
37
+ SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
38
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
39
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
40
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
41
+ MAX_STEPS = 8
42
+ TEMPERATURE = 0.3
43
+ MAX_TOKENS = 256
44
+ SUCCESS_THRESHOLD = 0.5
 
 
 
 
 
 
45
 
46
+ # ── scenario lists by difficulty ─────────────────────────────────────────────
 
 
 
 
47
 
48
+ EASY_SCENARIOS = [k for k, v in SCENARIOS.items() if v["difficulty"] == "easy"]
49
+ MEDIUM_SCENARIOS = [k for k, v in SCENARIOS.items() if v["difficulty"] == "medium"]
50
+ HARD_SCENARIOS = [k for k, v in SCENARIOS.items() if v["difficulty"] == "hard"]
 
51
 
52
+ # ── prompts ───────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
53
 
54
+ SYSTEM_PROMPT = textwrap.dedent("""
55
+ You are a machine learning engineer diagnosing a failed training run.
56
+ Each turn you receive data and must decide what to investigate next.
57
 
58
+ Available actions:
59
+ inspect_logs — examine training loss/accuracy curves
60
+ inspect_config — examine hyperparameter config (lr, optimizer, etc.)
61
+ inspect_gradients — examine gradient norm statistics
62
+ submit_diagnosis — submit your final diagnosis (ends the episode)
63
 
64
+ Respond with a JSON object on a single line. Examples:
65
+ {"action_type": "inspect_logs"}
66
+ {"action_type": "submit_diagnosis", "diagnosis": "exploding gradients"}
67
+ {"action_type": "submit_diagnosis", "diagnosis": "overfitting", "suggested_fix": "add dropout=0.3"}
68
 
69
+ Be efficient — inspect only what you need. Submit when confident.
70
+ The diagnosis should be a short phrase describing the failure mode.
71
+ """).strip()
72
 
73
 
74
+ def _user_prompt(step: int, obs_summary: str, history: List[str]) -> str:
75
  history_block = "\n".join(history[-4:]) if history else "None"
76
+ return textwrap.dedent(f"""
77
+ Step {step}
 
78
 
79
+ Observation:
80
+ {obs_summary}
81
 
82
+ Recent history:
83
  {history_block}
84
 
85
  Respond with a JSON action.
86
+ """).strip()
87
+
88
+
89
+ def _summarize(obs) -> str:
90
+ lines = [
91
+ f"Task: {obs.task_description}",
92
+ f"Feedback: {obs.feedback}",
93
+ ]
94
+ if obs.visible_data:
95
+ lines.append(f"Data:\n{json.dumps(obs.visible_data, indent=2)}")
96
+ return "\n".join(lines)
97
 
98
 
99
+ def _get_action(client: OpenAI, step: int, obs_summary: str, history: List[str]) -> WhyDidItFailAction:
 
100
  try:
101
  completion = client.chat.completions.create(
102
  model=MODEL_NAME,
103
  messages=[
104
  {"role": "system", "content": SYSTEM_PROMPT},
105
+ {"role": "user", "content": _user_prompt(step, obs_summary, history)},
106
  ],
107
  temperature=TEMPERATURE,
108
  max_tokens=MAX_TOKENS,
 
109
  )
110
  text = (completion.choices[0].message.content or "").strip()
111
+ return WhyDidItFailAction(**json.loads(text))
 
112
  except Exception as exc:
113
+ print(f" [DEBUG] parse error: {exc}", flush=True)
 
114
  if step <= 2:
115
+ return WhyDidItFailAction(action_type="inspect_logs", diagnosis=None, suggested_fix=None)
116
+ return WhyDidItFailAction(action_type="submit_diagnosis", diagnosis="unknown", suggested_fix=None)
117
 
118
+ # ── episode runner ────────────────────────────────────────────────────────────
119
 
120
+ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -> dict:
121
+ """Run one full episode for a specific scenario. Returns result dict."""
122
+ result = await env.reset(scenario_key=scenario_key)
123
+ obs = result.observation
124
+ history: List[str] = []
125
+ rewards: List[float] = []
 
 
 
126
 
127
+ for step in range(1, MAX_STEPS + 1):
128
+ if result.done:
129
+ break
130
 
131
+ action = _get_action(client, step, _summarize(obs), history)
132
+ result = await env.step(action)
133
+ obs = result.observation
134
+ reward = result.reward or 0.0
135
+ done = result.done
136
+ act_str = action.model_dump_json(exclude_none=True)
137
 
138
+ rewards.append(reward)
139
+ history.append(f"Step {step}: {act_str} → reward={reward:.2f} | {obs.feedback}")
140
+ print(f" [EPISODE] scenario={scenario_key} step={step} action={act_str} reward={reward:.2f} done={str(done).lower()}", flush=True)
 
 
141
 
142
+ if done:
143
+ break
144
 
145
+ # Final score = reward on submit_diagnosis (last reward)
146
+ score = rewards[-1] if rewards else 0.0
147
+ success = score >= SUCCESS_THRESHOLD
148
+ return {"scenario_key": scenario_key, "score": score, "steps": len(rewards), "success": success}
149
 
 
 
 
150
 
151
+ # ── task runners ──────────────────────────────────────────────────────────────
 
152
 
153
+ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEnv, client: OpenAI) -> None:
154
+ if not scenario_keys:
155
+ print(f"[SUMMARY] task={task_name} — no scenarios defined yet", flush=True)
156
+ return
157
 
158
+ print(f"\n[START] task={task_name} scenarios={len(scenario_keys)} model={MODEL_NAME}", flush=True)
 
 
159
 
160
+ results = []
161
+ for key in scenario_keys:
162
+ res = await run_episode(env, client, key)
163
+ results.append(res)
164
+ print(f"[RESULT] scenario={res['scenario_key']} score={res['score']:.3f} steps={res['steps']} success={str(res['success']).lower()}", flush=True)
165
 
166
+ avg_score = sum(r["score"] for r in results) / len(results)
167
+ pass_rate = sum(1 for r in results if r["success"]) / len(results)
168
+ print(f"[SUMMARY] task={task_name} avg_score={avg_score:.3f} pass_rate={pass_rate:.2f}", flush=True)
169
 
 
 
170
 
171
+ # ── main ──────────────────────────────────────────────────────────────────────
 
172
 
173
+ async def main() -> None:
174
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
175
+ env = (
176
+ await WhyDidItFailEnv.from_docker_image(IMAGE_NAME)
177
+ if IMAGE_NAME
178
+ else WhyDidItFailEnv(base_url=SERVER_URL)
179
+ )
180
+
181
+ try:
182
+ await run_task("easy", EASY_SCENARIOS, env, client)
183
+ await run_task("medium", MEDIUM_SCENARIOS, env, client)
184
+ await run_task("hard", HARD_SCENARIOS, env, client)
185
  finally:
186
  try:
187
  await env.close()
188
  except Exception as e:
189
  print(f"[DEBUG] env.close() error: {e}", flush=True)
 
190
 
191
 
192
  if __name__ == "__main__":