omkarrr88 commited on
Commit
7336adb
·
1 Parent(s): deea97b

Minor fixes

Browse files
.gitignore CHANGED
@@ -13,16 +13,7 @@ validation/reports/*.png
13
  .mypy_cache/
14
  .ruff_cache/
15
  .coverage
16
- .claude/
17
- CLAUDE.md
18
  .hf-space/
19
  .python-version
20
  deploy-hf.sh
21
  deploy.sh
22
- AUDIT_REPORT.md
23
- baseline_inference.py
24
- run_all_baselines.py
25
- docs/PAPER.md
26
- docs/PRD.md
27
- docs/ROADMAP.md
28
- docs/PROJECT_GUIDE.md
 
13
  .mypy_cache/
14
  .ruff_cache/
15
  .coverage
 
 
16
  .hf-space/
17
  .python-version
18
  deploy-hf.sh
19
  deploy.sh
 
 
 
 
 
 
 
README.md CHANGED
@@ -29,7 +29,7 @@ The agent starts with limited information (loss curves, config, error log) and m
29
 
30
  ### Real PyTorch Model Internals
31
 
32
- Every gradient comes from real `torch.autograd`. Every weight stat comes from real `model.state_dict()`. The environment instantiates actual `torch.nn.Module` models (SimpleCNN ~50K params, SimpleMLP ~20K params), runs 20 real forward+backward epochs per reset, and extracts real tensor statistics. Not synthetic formulas — real PyTorch computation, cached for instant replay.
33
 
34
  ### Context-Gated Reward Shaping
35
 
@@ -88,7 +88,7 @@ Fields like `gradient_stats`, `data_batch_stats`, `model_mode_info`, and `code_s
88
 
89
  ## Action Space
90
 
91
- 14 action types in 3 categories:
92
 
93
  **Investigation** — reveal hidden observation fields:
94
  - `inspect_gradients` — per-layer gradient norms, is_exploding/is_vanishing flags
@@ -107,7 +107,6 @@ Fields like `gradient_stats`, `data_batch_stats`, `model_mode_info`, and `code_s
107
 
108
  **Terminal** — end the episode:
109
  - `restart_run` — restart training (only available after a fix)
110
- - `rollback_checkpoint` — rollback to pre-fix state (only available after restart)
111
  - `mark_diagnosed` — submit diagnosis from 7 possible root causes
112
 
113
  Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
@@ -174,14 +173,8 @@ An agent that chases the gradient spike red herring loses 0.20 points. An agent
174
  # Heuristic (deterministic, no API key, bit-exact reproducible)
175
  python3 baseline_heuristic.py
176
 
177
- # LLM (multi-provider support)
178
- python3 baseline_inference.py # Groq — Llama 3.3 70B (free)
179
- python3 baseline_inference.py --provider cerebras # Cerebras — Llama 3.1 8B (free)
180
- python3 baseline_inference.py --provider gemini # Google Gemini 2.0 Flash
181
- python3 baseline_inference.py --provider openai # OpenAI GPT-4o
182
-
183
- # Run all baselines with comparison table
184
- python3 run_all_baselines.py
185
  ```
186
 
187
  ## API
@@ -299,7 +292,7 @@ server/
299
 
300
  tests/ — 246 tests, 96% coverage
301
  baseline_heuristic.py — Rule-based agent (deterministic, no API key)
302
- baseline_inference.py — LLM agent (Groq/Cerebras/Gemini/OpenAI)
303
  ```
304
 
305
  **Key design decisions:**
 
29
 
30
  ### Real PyTorch Model Internals
31
 
32
+ Every gradient comes from real `torch.autograd`. Every weight stat comes from real `model.state_dict()`. The environment instantiates actual `torch.nn.Module` models (SimpleCNN ~67K params, SimpleMLP ~412K params), runs 20 real forward+backward epochs per reset, and extracts real tensor statistics. Not synthetic formulas — real PyTorch computation, cached for instant replay.
33
 
34
  ### Context-Gated Reward Shaping
35
 
 
88
 
89
  ## Action Space
90
 
91
+ 13 action types in 3 categories:
92
 
93
  **Investigation** — reveal hidden observation fields:
94
  - `inspect_gradients` — per-layer gradient norms, is_exploding/is_vanishing flags
 
107
 
108
  **Terminal** — end the episode:
109
  - `restart_run` — restart training (only available after a fix)
 
110
  - `mark_diagnosed` — submit diagnosis from 7 possible root causes
111
 
112
  Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
 
173
  # Heuristic (deterministic, no API key, bit-exact reproducible)
174
  python3 baseline_heuristic.py
175
 
176
+ # LLM (hackathon evaluator format — uses OpenAI client)
177
+ API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o OPENAI_API_KEY=sk-... python3 inference.py
 
 
 
 
 
 
178
  ```
179
 
180
  ## API
 
292
 
293
  tests/ — 246 tests, 96% coverage
294
  baseline_heuristic.py — Rule-based agent (deterministic, no API key)
295
+ inference.py — LLM agent (OpenAI client, hackathon format)
296
  ```
297
 
298
  **Key design decisions:**
baseline_inference.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """LLM baseline agent using Google Gemini (via OpenAI-compatible SDK).
3
+
4
+ Requires GEMINI_API_KEY environment variable (or pass via --api-key).
5
+ Uses temperature=0.0 for near-deterministic behavior.
6
+ Spec reference: Section 17.
7
+
8
+ Usage:
9
+ GEMINI_API_KEY=... python baseline_inference.py
10
+ python baseline_inference.py --api-key YOUR_KEY
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ # Load .env file if present
22
+ _env_path = Path(__file__).parent / ".env"
23
+ if _env_path.exists():
24
+ for line in _env_path.read_text().splitlines():
25
+ line = line.strip()
26
+ if line and not line.startswith("#") and "=" in line:
27
+ key, _, value = line.partition("=")
28
+ os.environ.setdefault(key.strip(), value.strip())
29
+
30
+ try:
31
+ from openai import OpenAI
32
+ except ImportError:
33
+ print("Error: openai package not installed. Run: pip install openai")
34
+ sys.exit(1)
35
+
36
+ from ml_training_debugger.models import MLTrainingAction
37
+ from server.environment import MLTrainingEnvironment
38
+
39
+ ALL_TASKS = [
40
+ "task_001",
41
+ "task_002",
42
+ "task_003",
43
+ "task_004",
44
+ "task_005",
45
+ "task_006",
46
+ "task_007",
47
+ ]
48
+
49
+ SYSTEM_PROMPT = """You are an expert ML engineer debugging a PyTorch training run.
50
+ You are interacting with an environment that simulates a broken training job.
51
+
52
+ Available actions (respond with JSON only, no explanation):
53
+ - {"action_type": "inspect_gradients"} - View gradient statistics per layer
54
+ - {"action_type": "inspect_data_batch"} - View data batch statistics and confusion matrix
55
+ - {"action_type": "inspect_model_modes"} - View model layer modes (train/eval)
56
+ - {"action_type": "inspect_model_weights"} - View model weight statistics
57
+ - {"action_type": "inspect_code"} - View PyTorch training code
58
+ - {"action_type": "modify_config", "target": "<field>", "value": <val>} - Change a hyperparameter
59
+ - {"action_type": "add_callback"} - Add gradient clipping/scheduler
60
+ - {"action_type": "patch_data_loader"} - Fix data pipeline issues
61
+ - {"action_type": "fix_model_mode"} - Call model.train()
62
+ - {"action_type": "fix_code", "line": <int>, "replacement": "<code>"} - Fix a code line
63
+ - {"action_type": "restart_run"} - Restart training (requires a fix first)
64
+ - {"action_type": "mark_diagnosed", "diagnosis": "<cause>"} - Submit diagnosis
65
+
66
+ Valid diagnoses: lr_too_high, vanishing_gradients, data_leakage, overfitting, batchnorm_eval_mode, code_bug, scheduler_misconfigured
67
+
68
+ Strategy:
69
+ 1. First investigate by inspecting gradients, data, model modes, and code
70
+ 2. Form a hypothesis based on the evidence gathered
71
+ 3. Apply the correct fix for the identified root cause
72
+ 4. Restart training to verify the fix works
73
+ 5. Submit your diagnosis
74
+
75
+ IMPORTANT: Respond with ONLY a valid JSON action object. No explanation, no markdown, no code blocks."""
76
+
77
+
78
+ def run_llm_episode(task_id: str, client: OpenAI, model_name: str) -> float:
79
+ """Run one LLM agent episode."""
80
+ env = MLTrainingEnvironment()
81
+ obs = env.reset(seed=42, episode_id=f"llm_{task_id}", task_id=task_id)
82
+
83
+ initial_obs = {
84
+ "training_loss_history": obs.training_loss_history[:5],
85
+ "val_accuracy_history": obs.val_accuracy_history[:5],
86
+ "current_config": obs.current_config.model_dump(),
87
+ "error_log": obs.error_log,
88
+ "available_actions": obs.available_actions,
89
+ "notes": obs.notes,
90
+ "gpu_memory_used_gb": obs.gpu_memory_used_gb,
91
+ }
92
+
93
+ messages = [
94
+ {"role": "system", "content": SYSTEM_PROMPT},
95
+ {
96
+ "role": "user",
97
+ "content": f"New episode started for a broken PyTorch training run.\n\nInitial observation:\n{json.dumps(initial_obs, indent=2, default=str)}",
98
+ },
99
+ ]
100
+
101
+ for step in range(25):
102
+ if obs.done:
103
+ break
104
+
105
+ try:
106
+ response = client.chat.completions.create(
107
+ model=model_name,
108
+ messages=messages,
109
+ temperature=0.0,
110
+ max_tokens=300,
111
+ )
112
+ action_text = response.choices[0].message.content.strip()
113
+ except Exception as e:
114
+ print(f" Step {step}: API error — {e}", file=sys.stderr)
115
+ break
116
+
117
+ # Clean up common LLM formatting issues
118
+ action_text = action_text.strip("`").strip()
119
+ if action_text.startswith("json"):
120
+ action_text = action_text[4:].strip()
121
+
122
+ messages.append({"role": "assistant", "content": action_text})
123
+
124
+ try:
125
+ action_data = json.loads(action_text)
126
+ action = MLTrainingAction(**action_data)
127
+ except (json.JSONDecodeError, Exception) as e:
128
+ messages.append(
129
+ {
130
+ "role": "user",
131
+ "content": f"Invalid action format: {e}. Respond with ONLY valid JSON.",
132
+ }
133
+ )
134
+ continue
135
+
136
+ obs = env.step(action)
137
+
138
+ obs_summary: dict = {
139
+ "reward": obs.reward,
140
+ "done": obs.done,
141
+ "step": obs.episode_state.step_count,
142
+ "available_actions": obs.available_actions,
143
+ }
144
+ if obs.error_log:
145
+ obs_summary["error_log"] = obs.error_log
146
+ if obs.gradient_stats:
147
+ obs_summary["gradient_stats"] = [
148
+ {
149
+ "layer": g.layer_name,
150
+ "mean_norm": round(g.mean_norm, 4),
151
+ "exploding": g.is_exploding,
152
+ "vanishing": g.is_vanishing,
153
+ }
154
+ for g in obs.gradient_stats
155
+ ]
156
+ if obs.data_batch_stats:
157
+ obs_summary["data_overlap"] = obs.data_batch_stats.class_overlap_score
158
+ obs_summary["duplicate_ratio"] = obs.data_batch_stats.duplicate_ratio
159
+ if obs.model_mode_info:
160
+ obs_summary["model_modes"] = obs.model_mode_info
161
+ if obs.code_snippet:
162
+ obs_summary["code"] = obs.code_snippet.code[:600]
163
+ obs_summary["hint"] = obs.code_snippet.hint
164
+
165
+ messages.append(
166
+ {
167
+ "role": "user",
168
+ "content": f"Observation after your action:\n{json.dumps(obs_summary, indent=2, default=str)}",
169
+ }
170
+ )
171
+
172
+ session = env._get_session()
173
+ return session.last_score if session and session.last_score is not None else 0.0
174
+
175
+
176
+ PROVIDERS = {
177
+ "groq": {
178
+ "env_key": "GROQ_API_KEY",
179
+ "base_url": "https://api.groq.com/openai/v1",
180
+ "default_model": "llama-3.3-70b-versatile",
181
+ },
182
+ "cerebras": {
183
+ "env_key": "CEREBRAS_API_KEY",
184
+ "base_url": "https://api.cerebras.ai/v1",
185
+ "default_model": "llama3.1-8b",
186
+ },
187
+ "gemini": {
188
+ "env_key": "GEMINI_API_KEY",
189
+ "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
190
+ "default_model": "gemini-2.0-flash",
191
+ },
192
+ "openai": {
193
+ "env_key": "OPENAI_API_KEY",
194
+ "base_url": None,
195
+ "default_model": "gpt-4o",
196
+ },
197
+ }
198
+
199
+
200
+ def main() -> None:
201
+ parser = argparse.ArgumentParser(description="LLM baseline agent")
202
+ parser.add_argument("--url", default="http://localhost:7860")
203
+ parser.add_argument("--api-key", default=None, help="API key")
204
+ parser.add_argument(
205
+ "--provider",
206
+ default="groq",
207
+ choices=list(PROVIDERS.keys()),
208
+ help="LLM provider (default: groq)",
209
+ )
210
+ parser.add_argument("--model", default=None, help="Model name (auto-detected from provider)")
211
+ args = parser.parse_args()
212
+
213
+ prov = PROVIDERS[args.provider]
214
+ api_key = args.api_key or os.environ.get(prov["env_key"])
215
+ if not api_key:
216
+ print(f"Error: Set {prov['env_key']} env var or pass --api-key")
217
+ sys.exit(1)
218
+
219
+ model_name = args.model or prov["default_model"]
220
+ client_kwargs: dict = {"api_key": api_key}
221
+ if prov["base_url"]:
222
+ client_kwargs["base_url"] = prov["base_url"]
223
+ client = OpenAI(**client_kwargs)
224
+
225
+ scores: dict[str, float] = {}
226
+ print(f"Running LLM baseline with {args.provider}/{model_name}...", file=sys.stderr)
227
+
228
+ for task_id in ALL_TASKS:
229
+ try:
230
+ score = run_llm_episode(task_id, client, model_name)
231
+ scores[task_id] = round(score, 4)
232
+ print(f" {task_id}: {score:.4f}", file=sys.stderr)
233
+ except Exception as e:
234
+ print(f" {task_id}: ERROR — {e}", file=sys.stderr)
235
+ scores[task_id] = 0.0
236
+
237
+ print(json.dumps(scores, indent=2))
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
docs/EXPLANATION.md DELETED
@@ -1,340 +0,0 @@
1
- # PyTorch Training Run Debugger — Explained Simply
2
-
3
- > This file explains the entire project as if you're 10 years old. No jargon. Just simple language.
4
-
5
- ---
6
-
7
- ## What Is This Project?
8
-
9
- Imagine you're a doctor, but instead of fixing sick people, you fix **sick computers that are trying to learn**.
10
-
11
- When computers learn (this is called "Machine Learning" or ML), they look at thousands of examples — like pictures of cats and dogs — and slowly get better at telling them apart. This learning process is called **training**.
12
-
13
- But sometimes, training goes wrong. The computer makes mistakes, gets confused, or learns the wrong things. When that happens, a human engineer has to figure out what went wrong and fix it — just like a doctor diagnosing a patient.
14
-
15
- **This project builds a practice hospital for AI doctors.** It creates fake "sick training runs" with known problems, and then an AI agent (the doctor) has to:
16
-
17
- 1. **Investigate** — Look at clues (like checking temperature or blood pressure)
18
- 2. **Diagnose** — Figure out what's wrong
19
- 3. **Fix** — Apply the right treatment
20
- 4. **Verify** — Check if the patient recovered
21
-
22
- ---
23
-
24
- ## Why Does This Matter?
25
-
26
- Real companies like Meta, Google, and OpenAI spend millions of dollars training AI models. When training breaks, engineers waste hours (sometimes days!) figuring out what went wrong. Each hour of broken training can cost **$2-$8 per GPU** — and some companies use thousands of GPUs at once.
27
-
28
- If we could train an AI to automatically find and fix these problems, it would save enormous amounts of time and money.
29
-
30
- This project is a **training ground** where AI agents can practice debugging — like a flight simulator for pilots, but for ML engineers.
31
-
32
- ---
33
-
34
- ## How Does It Work? (The Big Picture)
35
-
36
- Think of it like a detective game with 6 mystery cases:
37
-
38
- ### The Game Rules
39
-
40
- 1. **The computer shows you a broken training run** — You see charts showing how the training is going (spoiler: it's going badly!)
41
- 2. **You can investigate** — You have 5 different "magnifying glasses" to look at different parts of the problem
42
- 3. **You figure out what's wrong** — You pick from a list of 6 possible problems
43
- 4. **You fix it** — You apply the right fix
44
- 5. **You restart and check** — You restart the training and see if it works now
45
- 6. **You submit your answer** — "I think the problem was X"
46
-
47
- If you're right, you get points. If you're wrong, you lose points. If you investigate smartly, you get bonus points. If you ignore evidence and do something silly, you get penalty points.
48
-
49
- ---
50
-
51
- ## The 6 Mystery Cases (Tasks)
52
-
53
- ### Easy Cases (Like finding a broken window)
54
-
55
- **Case 1: Learning Rate Too High (task_001)**
56
- > Imagine you're learning to ride a bike, but someone set the speed to 100 mph. You'd crash immediately!
57
-
58
- That's what happens here. The computer is learning too fast and everything explodes. The numbers go crazy and become "NaN" (Not a Number — like dividing by zero).
59
-
60
- **Clues:** Every part of the computer shows "EXPLODING!" when you check the gradients (the direction signals that guide learning).
61
-
62
- **Fix:** Turn down the speed (reduce the learning rate from 0.1 to 0.001).
63
-
64
- ---
65
-
66
- **Case 2: Vanishing Gradients (task_002)**
67
- > Now imagine you're whispering instructions to someone 100 rooms away. By the time the message reaches them, it's too quiet to hear.
68
-
69
- The learning signals get weaker and weaker as they travel through the computer's brain layers. The deeper layers get almost zero signal — so they can't learn anything.
70
-
71
- **Clues:** Deeper layers show "VANISHING!" gradients. The loss curve is flat — nothing is being learned.
72
-
73
- **Fix:** Increase the learning rate so the signals are louder.
74
-
75
- ---
76
-
77
- ### Medium Cases (Like finding a hidden leak)
78
-
79
- **Case 3: Data Leakage (task_003)**
80
- > Imagine taking a math test, but the answer key is mixed into your practice problems. You'd score 100% — but you didn't actually learn anything!
81
-
82
- The training data and test data got mixed together. The computer looks amazing on tests, but it's just memorizing answers — it hasn't actually learned.
83
-
84
- **Clues:** Suspiciously high test scores from the very start. When you check the data, you find a "class overlap score" above 0.5 — meaning lots of test answers leaked into the training set.
85
-
86
- **Trick:** There's a misleading note saying "we upgraded the model architecture" — making you think the high scores are from a better model, not leaked data.
87
-
88
- **Fix:** Clean the data pipeline to remove the overlap.
89
-
90
- ---
91
-
92
- **Case 4: Overfitting (task_004)**
93
- > Imagine memorizing every single answer to last year's exam, but then failing this year's exam because the questions are slightly different.
94
-
95
- The computer has memorized the training data perfectly (train loss near zero!) but fails on new data it hasn't seen before (validation loss keeps rising).
96
-
97
- **Clues:** Training loss drops to almost zero while validation loss goes up — the classic "train-val divergence."
98
-
99
- **Fix:** Add regularization (weight decay) — this is like telling the computer "don't memorize, understand the patterns instead."
100
-
101
- ---
102
-
103
- ### Hard Cases (Like solving a mystery with fake clues)
104
-
105
- **Case 5: BatchNorm Eval Mode (task_005)**
106
- > Imagine a student who studies perfectly at home but freezes during the actual exam because they switched into "test mode" too early.
107
-
108
- The computer's model has a special feature called BatchNorm that behaves differently during training vs testing. Someone accidentally left it in "test mode" during training. This causes subtle, slow degradation — not an obvious crash.
109
-
110
- **The Trap:** This case has **red herrings** — fake clues designed to mislead you:
111
- - One layer's gradient suddenly spikes (but it's not actually exploding)
112
- - GPU memory is at 91% (looks scary, but it's not the problem)
113
- - One layer has near-vanishing gradients (but that's normal for this layer)
114
- - An error log warns about GPU memory (irrelevant to the real problem)
115
-
116
- **Clues:** When you check the model modes, you find all layers are in "eval" (test) mode instead of "train" mode. That's the real problem.
117
-
118
- **Why it's hard:** Most agents see the gradient spike and immediately try to fix gradients — falling for the trap. The smart agent checks model modes and finds the real issue.
119
-
120
- ---
121
-
122
- **Case 6: Code Bug (task_006)**
123
- > Imagine a recipe that says "bake for 30 minutes" but someone accidentally changed it to "bake for 0 minutes." The oven runs, but nothing gets cooked.
124
-
125
- There's an actual bug in the Python code. The agent sees the source code and has to find the buggy line and fix it. There are 4 possible bugs:
126
-
127
- 1. **eval_mode** — `model.eval()` instead of `model.train()` (wrong mode)
128
- 2. **detach_loss** — `loss.detach()` before `.backward()` (disconnects the learning signal)
129
- 3. **zero_grad_missing** — Forgot to clear old gradients (gradients pile up incorrectly)
130
- 4. **inplace_relu** — `inplace=True` on ReLU (corrupts the computation graph)
131
-
132
- **Why it's hard:** The agent must actually READ code and understand what each line does — not just look at numbers and charts.
133
-
134
- ---
135
-
136
- ## The Scoring System
137
-
138
- ### Rewards (Points You Earn)
139
-
140
- Think of it like a video game:
141
-
142
- | What You Do | Points | Why |
143
- |-------------|--------|-----|
144
- | Take any action | **-0.01** | Every move costs a tiny bit (encourages efficiency) |
145
- | Investigate something for the first time | **+0.05** | Looking at clues is good! |
146
- | Correct diagnosis | **+0.50** | You found the answer! |
147
- | Fix works and training recovers | **+0.40** | Your fix actually helped! |
148
-
149
- ### Penalties (Points You Lose)
150
-
151
- | What You Do | Points | Why |
152
- |-------------|--------|-----|
153
- | Do something invalid | **-0.05** | You tried something that's not allowed |
154
- | Wrong code fix | **-0.10** | Your code fix didn't work |
155
- | Wrong diagnosis | **-0.30** | You guessed wrong |
156
-
157
- ### The Special Penalty: Context-Gated Penalty
158
-
159
- This is the **coolest part** of the project. Here's how it works:
160
-
161
- > You check the gradients and see they're all normal. Then you add gradient clipping anyway (a fix for gradient problems). But wait — YOU ALREADY KNOW the gradients are fine! You're ignoring your own evidence!
162
-
163
- **Penalty: -0.20 points**
164
-
165
- But if you add gradient clipping BEFORE checking gradients? No penalty — you haven't seen any evidence yet, so it's a reasonable guess.
166
-
167
- This teaches the AI: **"Don't ignore what you've already learned."**
168
-
169
- ---
170
-
171
- ### The Grader (Final Score)
172
-
173
- At the end of each case, a grader gives you a score from **0.0 to 1.0**:
174
-
175
- - **1.0** = Perfect — investigated, fixed, restarted, and diagnosed correctly
176
- - **0.5-0.8** = Partial — got some things right, missed others
177
- - **0.0** = Failed — wrong diagnosis, no fix, or ran out of steps
178
-
179
- The grader looks at the WHOLE story of what you did, not just the final answer.
180
-
181
- ---
182
-
183
- ## How the Code Is Organized
184
-
185
- ```
186
- ML Debugger/
187
-
188
- ├── ml_training_debugger/ ← The brain of the project
189
- │ ├── models.py ← Data shapes (what observations and actions look like)
190
- │ ├── scenarios.py ← Creates the 6 mystery cases with random parameters
191
- │ ├── pytorch_engine.py ← Real PyTorch model that gets "sick" (fault injection)
192
- │ ├── simulation.py ← Generates fake training charts (loss curves, accuracy)
193
- │ ├── reward_engine.py ← Calculates points for each action
194
- │ ├── graders.py ← Final scoring (0.0 to 1.0) at episode end
195
- │ ├── code_templates.py ← The buggy code snippets for Task 6
196
- │ └── client.py ← Helper for connecting to the environment
197
-
198
- ├── server/ ← The web server
199
- │ ├── app.py ← Main server with all API endpoints
200
- │ ├── environment.py ← The game logic (reset, step, state)
201
- │ └── _baseline_results.py ← Stores grader results
202
-
203
- ├── tests/ ← 183 tests making sure everything works
204
-
205
- ├── baseline_heuristic.py ← A simple robot that plays the game using rules
206
- ├── baseline_inference.py ← A smart AI (GPT-4) that plays the game
207
- ├── Dockerfile ← Instructions to package everything in a container
208
- ├── openenv.yaml ← Configuration file for the OpenEnv framework
209
- └── README.md ← Technical documentation
210
- ```
211
-
212
- ---
213
-
214
- ## How a Game Session Works (Step by Step)
215
-
216
- Let's walk through a complete game:
217
-
218
- ### Step 1: Start a New Game
219
- ```
220
- Agent: "Start task_001 please"
221
- Environment: "Here's your broken training run:"
222
- - Loss history: [2.3, 3.5, 8.2, 45.0, inf, inf, inf, ...] ← Yikes, numbers exploding!
223
- - Error log: "Loss is NaN at epoch 12"
224
- - Available actions: [inspect_gradients, inspect_data_batch, ...]
225
- ```
226
-
227
- ### Step 2: Investigate
228
- ```
229
- Agent: "Let me inspect the gradients"
230
- Environment: "Here's what I found:"
231
- - conv1: mean_norm=51.1, is_exploding=True
232
- - conv2: mean_norm=91.3, is_exploding=True
233
- - conv3: mean_norm=111.8, is_exploding=True
234
- - fc: mean_norm=37.7, is_exploding=True
235
- Reward: +0.04 (step penalty + investigation bonus)
236
- ```
237
-
238
- ### Step 3: Fix
239
- ```
240
- Agent: "Reduce learning rate to 0.001"
241
- Environment: "Config updated. learning_rate = 0.001"
242
- Reward: -0.01 (step penalty only)
243
- ```
244
-
245
- ### Step 4: Restart
246
- ```
247
- Agent: "Restart the training run"
248
- Environment: "Training restarted. Convergence detected!"
249
- Reward: +0.39 (step penalty + convergence bonus)
250
- ```
251
-
252
- ### Step 5: Diagnose
253
- ```
254
- Agent: "The problem was lr_too_high"
255
- Environment: "CORRECT! Episode complete."
256
- Reward: +0.49 (step penalty + correct diagnosis)
257
- Final grader score: 1.0 ← Perfect!
258
- ```
259
-
260
- ---
261
-
262
- ## What Makes This Project Special?
263
-
264
- ### 1. It Uses REAL PyTorch
265
- This isn't fake data. When you inspect gradients, you're looking at real numbers computed by a real neural network using `torch.autograd`. The model has ~50,000 parameters and runs real forward/backward passes. This matters because the hackathon is organized by **Meta (the company that makes PyTorch)**.
266
-
267
- ### 2. Context-Gated Rewards
268
- No other OpenEnv environment does this. The reward system tracks what the agent has learned and penalizes it for ignoring evidence. This teaches AI to reason like a real engineer — gather evidence first, then act.
269
-
270
- ### 3. Code-Level Debugging (Task 6)
271
- The agent reads actual Python code and submits line-by-line fixes. This tests code understanding — not just number crunching. Meta cares about this because they want AI that can debug PyTorch code.
272
-
273
- ### 4. Red Herrings in Hard Tasks
274
- Task 5 deliberately plants misleading clues. This separates agents that follow rigid patterns from agents that can reason through ambiguity — exactly like real debugging.
275
-
276
- ### 5. Progressive Information Reveal
277
- The agent starts with limited information and must actively choose what to investigate. Each inspection reveals new data. This makes it a genuine investigation — not just a classification task.
278
-
279
- ---
280
-
281
- ## The Two Baselines (Robot Players)
282
-
283
- ### Baseline 1: The Rule-Following Robot (`baseline_heuristic.py`)
284
- This robot follows a fixed checklist:
285
- 1. Check gradients → if exploding, fix learning rate
286
- 2. Check data → if leaking, patch data
287
- 3. Check model modes → if eval, fix mode
288
- 4. Check code → if bug found, fix it
289
- 5. If nothing works, guess "overfitting"
290
-
291
- **Scores:** Perfect on easy/medium tasks, but only 0.35 on Task 5 because its fixed order means it tries to fix gradients before checking model modes — falling for the red herring.
292
-
293
- ### Baseline 2: The Smart AI (`baseline_inference.py`)
294
- This uses GPT-4 to reason about the evidence. It reads the observations, thinks about what to do, and makes decisions. It should score higher on hard tasks because it can reason, not just follow rules.
295
-
296
- ---
297
-
298
- ## The Technology Stack
299
-
300
- | Component | What It Is | Why We Use It |
301
- |-----------|-----------|---------------|
302
- | **Python 3.12** | Programming language | Modern, fast, supports type hints |
303
- | **PyTorch (CPU)** | Machine learning framework | Real neural networks, real gradients (Meta's framework!) |
304
- | **FastAPI** | Web framework | Fast, modern, auto-generates docs |
305
- | **OpenEnv** | RL environment framework | Standard interface for AI agents (step/reset/state) |
306
- | **Pydantic** | Data validation | Ensures all data is properly typed |
307
- | **Plotly.js** | Charting library | Live dashboard with interactive charts |
308
- | **Docker** | Containerization | Package everything so it runs anywhere |
309
-
310
- ---
311
-
312
- ## How to Think About This Project
313
-
314
- **Analogy 1: Medical Training Simulator**
315
- Medical students practice on mannequins before treating real patients. This project is a mannequin for AI debugging — the "patients" have known problems, and the "doctor" (AI agent) learns to diagnose them.
316
-
317
- **Analogy 2: Escape Room**
318
- Each task is like an escape room. You're locked in with clues scattered around. Some clues are helpful, some are red herrings. You need to investigate systematically, not randomly try everything.
319
-
320
- **Analogy 3: Car Mechanic School**
321
- A car comes in making weird noises. The mechanic can:
322
- - Check the engine (inspect_gradients)
323
- - Check the fuel (inspect_data_batch)
324
- - Check the gearbox (inspect_model_modes)
325
- - Read the error codes (inspect_code)
326
- Then they fix the right part and test-drive it to confirm.
327
-
328
- ---
329
-
330
- ## Summary
331
-
332
- | Question | Answer |
333
- |----------|--------|
334
- | **What?** | A practice environment where AI agents learn to debug broken PyTorch training runs |
335
- | **Why?** | Real ML debugging costs companies millions. Training AI to do it has huge value. |
336
- | **How?** | 7 mystery cases with real PyTorch training (CNN + MLP), progressive clue reveal, and smart scoring |
337
- | **What's special?** | Real 20-epoch training, dual architectures, context-gated rewards, code-level debugging, red herrings, difficulty scaling |
338
- | **Who's it for?** | AI researchers building smarter debugging agents |
339
- | **Built with?** | Python, PyTorch, FastAPI, OpenEnv, Pydantic, Docker |
340
- | **For what event?** | Meta PyTorch OpenEnv Hackathon x Scaler School of Technology |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/ml-training-debugger-spec.md DELETED
The diff for this file is too large to render. See raw diff
 
ml_training_debugger/models.py CHANGED
@@ -113,7 +113,6 @@ class EpisodeState(BaseModel):
113
 
114
  Rules from spec Section 10 — Dynamic available_actions:
115
  - restart_run: only after fix_action_taken
116
- - rollback_checkpoint: only after restart_after_fix
117
  - fix_code: only after code_inspected
118
  - mark_diagnosed: disappears after diagnosis_submitted
119
  """
@@ -133,8 +132,6 @@ class EpisodeState(BaseModel):
133
  actions.append("fix_code")
134
  if self.fix_action_taken:
135
  actions.append("restart_run")
136
- if self.restart_after_fix:
137
- actions.append("rollback_checkpoint")
138
  if not self.diagnosis_submitted:
139
  actions.append("mark_diagnosed")
140
  return actions
@@ -154,7 +151,6 @@ ALL_ACTION_TYPES: set[str] = {
154
  "fix_code",
155
  "restart_run",
156
  "mark_diagnosed",
157
- "rollback_checkpoint",
158
  }
159
 
160
 
 
113
 
114
  Rules from spec Section 10 — Dynamic available_actions:
115
  - restart_run: only after fix_action_taken
 
116
  - fix_code: only after code_inspected
117
  - mark_diagnosed: disappears after diagnosis_submitted
118
  """
 
132
  actions.append("fix_code")
133
  if self.fix_action_taken:
134
  actions.append("restart_run")
 
 
135
  if not self.diagnosis_submitted:
136
  actions.append("mark_diagnosed")
137
  return actions
 
151
  "fix_code",
152
  "restart_run",
153
  "mark_diagnosed",
 
154
  }
155
 
156
 
run_all_baselines.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Run heuristic + multiple LLM baselines and show comparison table.
3
+
4
+ Usage:
5
+ python3 run_all_baselines.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import os
12
+ import sys
13
+ import time
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+ from pathlib import Path
16
+
17
+ # Load .env
18
+ _env_path = Path(__file__).parent / ".env"
19
+ if _env_path.exists():
20
+ for line in _env_path.read_text().splitlines():
21
+ line = line.strip()
22
+ if line and not line.startswith("#") and "=" in line:
23
+ key, _, value = line.partition("=")
24
+ os.environ.setdefault(key.strip(), value.strip())
25
+
26
+ from baseline_heuristic import ALL_TASKS
27
+ from baseline_heuristic import run_heuristic_episode
28
+ from baseline_inference import PROVIDERS, run_llm_episode
29
+
30
+ try:
31
+ from openai import OpenAI
32
+ except ImportError:
33
+ print("Error: pip install openai")
34
+ sys.exit(1)
35
+
36
+
37
+ def run_heuristic() -> dict[str, float]:
38
+ scores = {}
39
+ for task_id in ALL_TASKS:
40
+ scores[task_id] = round(run_heuristic_episode(task_id), 4)
41
+ return scores
42
+
43
+
44
+ def run_llm_provider(provider_name: str, model: str | None = None) -> dict[str, float]:
45
+ prov = PROVIDERS[provider_name]
46
+ api_key = os.environ.get(prov["env_key"])
47
+ if not api_key:
48
+ return {t: -1.0 for t in ALL_TASKS} # -1 = no key
49
+
50
+ model_name = model or prov["default_model"]
51
+ client_kwargs: dict = {"api_key": api_key}
52
+ if prov["base_url"]:
53
+ client_kwargs["base_url"] = prov["base_url"]
54
+ client = OpenAI(**client_kwargs)
55
+
56
+ scores: dict[str, float] = {}
57
+ for task_id in ALL_TASKS:
58
+ try:
59
+ score = run_llm_episode(task_id, client, model_name)
60
+ scores[task_id] = round(score, 4)
61
+ print(f" [{provider_name}/{model_name}] {task_id}: {score:.4f}", file=sys.stderr)
62
+ except Exception as e:
63
+ err_str = str(e)[:80]
64
+ print(f" [{provider_name}/{model_name}] {task_id}: ERROR — {err_str}", file=sys.stderr)
65
+ scores[task_id] = 0.0
66
+ return scores
67
+
68
+
69
+ def main() -> None:
70
+ print("Running all baselines...\n", file=sys.stderr)
71
+
72
+ results: dict[str, dict[str, float]] = {}
73
+
74
+ # Run heuristic first (fast, deterministic)
75
+ print("--- Heuristic baseline ---", file=sys.stderr)
76
+ results["Heuristic"] = run_heuristic()
77
+ print(f" Done: {json.dumps(results['Heuristic'])}", file=sys.stderr)
78
+
79
+ # Run LLM providers sequentially (avoids thread hang issues)
80
+ llm_runs = [
81
+ ("Cerebras/Llama-3.1-8B", "cerebras", "llama3.1-8b"),
82
+ ("Groq/Llama-3.1-8B", "groq", "llama-3.1-8b-instant"),
83
+ ]
84
+
85
+ for label, provider, model in llm_runs:
86
+ print(f"\n--- {label} ---", file=sys.stderr)
87
+ try:
88
+ results[label] = run_llm_provider(provider, model)
89
+ except Exception as e:
90
+ print(f" {label}: FAILED — {e}", file=sys.stderr)
91
+ results[label] = {t: 0.0 for t in ALL_TASKS}
92
+
93
+ # Print comparison table
94
+ print("\n" + "=" * 80)
95
+ print("BASELINE COMPARISON TABLE")
96
+ print("=" * 80)
97
+
98
+ headers = list(results.keys())
99
+ print(f"\n{'Task':<12}", end="")
100
+ for h in headers:
101
+ print(f"{h:>25}", end="")
102
+ print()
103
+ print("-" * (12 + 25 * len(headers)))
104
+
105
+ for task_id in ALL_TASKS:
106
+ print(f"{task_id:<12}", end="")
107
+ for h in headers:
108
+ score = results[h].get(task_id, 0.0)
109
+ if score < 0:
110
+ print(f"{'no key':>25}", end="")
111
+ else:
112
+ print(f"{score:>25.4f}", end="")
113
+ print()
114
+
115
+ print("-" * (12 + 25 * len(headers)))
116
+
117
+ # Averages
118
+ print(f"{'AVERAGE':<12}", end="")
119
+ for h in headers:
120
+ valid = [v for v in results[h].values() if v >= 0]
121
+ avg = sum(valid) / len(valid) if valid else 0
122
+ print(f"{avg:>25.4f}", end="")
123
+ print()
124
+
125
+ # Save JSON
126
+ print(json.dumps(results, indent=2))
127
+
128
+
129
+ if __name__ == "__main__":
130
+ main()
server/_baseline_results.py CHANGED
@@ -2,9 +2,11 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  from typing import Optional
6
 
7
- # Store last completed episode results
 
8
  _last_results: dict[str, dict] = {}
9
 
10
 
@@ -12,16 +14,19 @@ def store_grader_result(
12
  session_id: str, score: float, task_id: str, steps: int
13
  ) -> None:
14
  """Store a grader result for retrieval."""
15
- _last_results[session_id] = {
16
  "score": round(score, 4),
17
  "task_id": task_id,
18
  "steps": steps,
19
  }
20
- _last_results["_latest"] = _last_results[session_id]
 
 
21
 
22
 
23
  def get_last_grader_result(session_id: Optional[str] = None) -> dict | None:
24
  """Get grader result for a session, or the most recent one."""
25
- if session_id:
26
- return _last_results.get(session_id)
27
- return _last_results.get("_latest")
 
 
2
 
3
  from __future__ import annotations
4
 
5
+ import threading
6
  from typing import Optional
7
 
8
+ # Thread-safe store for completed episode results
9
+ _lock = threading.Lock()
10
  _last_results: dict[str, dict] = {}
11
 
12
 
 
14
  session_id: str, score: float, task_id: str, steps: int
15
  ) -> None:
16
  """Store a grader result for retrieval."""
17
+ entry = {
18
  "score": round(score, 4),
19
  "task_id": task_id,
20
  "steps": steps,
21
  }
22
+ with _lock:
23
+ _last_results[session_id] = entry
24
+ _last_results["_latest"] = entry
25
 
26
 
27
  def get_last_grader_result(session_id: Optional[str] = None) -> dict | None:
28
  """Get grader result for a session, or the most recent one."""
29
+ with _lock:
30
+ if session_id:
31
+ return _last_results.get(session_id)
32
+ return _last_results.get("_latest")
server/environment.py CHANGED
@@ -467,9 +467,6 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
467
  state.diagnosis_submitted = True
468
  session.done = True
469
 
470
- elif at == "rollback_checkpoint":
471
- pass # No-op for now
472
-
473
  return is_correct_fix, convergence
474
 
475
  def _check_convergence(self, session: SessionData) -> bool:
@@ -516,11 +513,21 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
516
  session = self._get_session()
517
  if session is None:
518
  return {"status": "no_active_episode"}
 
519
  return {
520
  "status": "active",
521
  "task_id": session.scenario.task_id,
522
- "step_count": session.state.step_count,
523
  "done": session.done,
 
 
 
 
 
 
 
 
 
524
  }
525
 
526
  def get_last_completed(self, session_id: str | None = None) -> dict | None:
 
467
  state.diagnosis_submitted = True
468
  session.done = True
469
 
 
 
 
470
  return is_correct_fix, convergence
471
 
472
  def _check_convergence(self, session: SessionData) -> bool:
 
513
  session = self._get_session()
514
  if session is None:
515
  return {"status": "no_active_episode"}
516
+ st = session.state
517
  return {
518
  "status": "active",
519
  "task_id": session.scenario.task_id,
520
+ "step_count": st.step_count,
521
  "done": session.done,
522
+ "gradients_inspected": st.gradients_inspected,
523
+ "data_inspected": st.data_inspected,
524
+ "model_modes_inspected": st.model_modes_inspected,
525
+ "model_weights_inspected": st.model_weights_inspected,
526
+ "code_inspected": st.code_inspected,
527
+ "fix_action_taken": st.fix_action_taken,
528
+ "restart_after_fix": st.restart_after_fix,
529
+ "diagnosis_submitted": st.diagnosis_submitted,
530
+ "available_actions": st.compute_available_actions(),
531
  }
532
 
533
  def get_last_completed(self, session_id: str | None = None) -> dict | None:
tests/test_models.py CHANGED
@@ -93,8 +93,6 @@ class TestEpisodeState:
93
  assert "mark_diagnosed" in actions
94
  assert "fix_code" not in actions
95
  assert "restart_run" not in actions
96
- assert "rollback_checkpoint" not in actions
97
-
98
  def test_fix_code_available_after_code_inspected(self):
99
  state = EpisodeState(code_inspected=True)
100
  actions = state.compute_available_actions()
@@ -105,11 +103,6 @@ class TestEpisodeState:
105
  actions = state.compute_available_actions()
106
  assert "restart_run" in actions
107
 
108
- def test_rollback_available_after_restart(self):
109
- state = EpisodeState(restart_after_fix=True)
110
- actions = state.compute_available_actions()
111
- assert "rollback_checkpoint" in actions
112
-
113
  def test_mark_diagnosed_disappears_after_submission(self):
114
  state = EpisodeState(diagnosis_submitted=True)
115
  actions = state.compute_available_actions()
 
93
  assert "mark_diagnosed" in actions
94
  assert "fix_code" not in actions
95
  assert "restart_run" not in actions
 
 
96
  def test_fix_code_available_after_code_inspected(self):
97
  state = EpisodeState(code_inspected=True)
98
  actions = state.compute_available_actions()
 
103
  actions = state.compute_available_actions()
104
  assert "restart_run" in actions
105
 
 
 
 
 
 
106
  def test_mark_diagnosed_disappears_after_submission(self):
107
  state = EpisodeState(diagnosis_submitted=True)
108
  actions = state.compute_available_actions()