UjjwalPardeshi commited on
Commit
d222546
Β·
1 Parent(s): 8435256

update readme

Browse files
Files changed (2) hide show
  1. README.md +270 -142
  2. baseline_heuristic.py +35 -28
README.md CHANGED
@@ -1,114 +1,169 @@
1
  # PyTorch Training Run Debugger
2
 
3
- **OpenEnv RL Environment** β€” Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, Round 1
4
 
5
  An AI agent debugs broken PyTorch training runs by investigating gradients, model weights, data pipelines, and source code to diagnose and fix real ML failure patterns.
6
 
7
- ## What Is This?
8
 
9
- This environment recreates the experience of an ML engineer facing a broken PyTorch training job. The agent receives a snapshot of a failing training run and must:
 
 
 
 
 
 
 
 
10
 
11
  1. **Investigate** β€” inspect gradients, data batches, model weights, model modes, and code
12
- 2. **Diagnose** β€” identify the root cause from a closed set of known ML failures
13
- 3. **Fix** β€” apply the correct intervention (reduce LR, patch data, fix model mode, etc.)
14
- 4. **Verify** β€” restart training and confirm recovery before submitting diagnosis
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- ### Key Differentiators
17
 
18
- - **Real PyTorch mini-training** β€” 20 real forward+backward epochs per reset, cached for instant replay. Loss/accuracy curves come from real training, not parametric formulas.
19
- - **Dual model architectures** β€” SimpleCNN (~50K params) and SimpleMLP (~20K params) randomly selected per episode
20
- - **Context-gated reward shaping** β€” Penalty fires only when agent ignores evidence it already gathered; no penalty for reasonable priors
21
- - **Progressive information reveal** β€” Gradient stats, weight stats, data batch stats, confusion matrices only populated after corresponding inspection actions
22
- - **7 tasks with difficulty scaling** β€” Easy to hard, with configurable difficulty level (1-5) per task
 
 
 
 
23
 
24
- ## Environment Design
25
 
26
- ### Observation Space (`MLTrainingObservation`)
27
 
28
- | Field | Type | Visibility |
29
- |-------|------|-----------|
 
 
30
  | `training_loss_history` | `list[float]` (20 epochs) | Always |
31
  | `val_accuracy_history` | `list[float]` (20 epochs) | Always |
32
  | `val_loss_history` | `list[float]` (20 epochs) | Always |
33
  | `current_config` | `TrainingConfig` | Always |
34
- | `error_log` | `Optional[str]` | Always |
35
  | `gradient_stats` | `list[GradientStats]` | After `inspect_gradients` |
36
- | `model_weight_stats` | `Optional[list[ModelWeightStats]]` | After `inspect_model_weights` |
37
- | `data_batch_stats` | `Optional[DataBatchStats]` | After `inspect_data_batch` |
38
- | `model_mode_info` | `Optional[dict[str, str]]` | After `inspect_model_modes` |
39
- | `code_snippet` | `Optional[CodeSnippet]` | After `inspect_code` |
40
  | `available_actions` | `list[str]` | Always (dynamic) |
41
  | `episode_state` | `EpisodeState` | Always |
42
 
43
- ### Action Space (`MLTrainingAction`)
 
 
 
 
44
 
45
- | Category | Actions |
46
- |----------|---------|
47
- | **Investigation** | `inspect_gradients`, `inspect_data_batch`, `inspect_model_modes`, `inspect_model_weights`, `inspect_code` |
48
- | **Fix** | `modify_config`, `add_callback`, `replace_optimizer`, `patch_data_loader`, `fix_model_mode`, `fix_code` |
49
- | **Terminal** | `restart_run`, `mark_diagnosed` |
 
50
 
51
- Dynamic availability: `restart_run` requires a fix first; `fix_code` requires code inspection; `mark_diagnosed` disappears after submission.
 
 
 
 
 
 
52
 
53
- ### Diagnosis Enum
 
 
54
 
55
- | Value | Description |
56
- |-------|-------------|
57
- | `lr_too_high` | Learning rate too large |
58
- | `vanishing_gradients` | Gradients decay to near-zero |
59
- | `data_leakage` | Validation samples in training |
60
- | `overfitting` | Model memorizing, failing to generalize |
61
- | `batchnorm_eval_mode` | Model in eval mode during training |
62
- | `code_bug` | Bug in PyTorch training code |
63
 
64
- ### Reward Function
65
 
66
- | Event | Reward | Gate |
67
- |-------|--------|------|
68
- | Any step | -0.01 | Flat, unconditional |
69
- | First-time inspection | +0.05 | Per inspection type |
70
- | `add_callback` after normal gradients | -0.20 | `gradients_inspected AND gradients_were_normal` |
 
 
 
 
71
  | Invalid action | -0.05 | Action not in `available_actions` |
72
- | Correct diagnosis | +0.50 | Equality check |
73
- | Wrong diagnosis | -0.30 | Inequality check |
74
- | Convergence after fix+restart | +0.40 | All gates met |
75
 
76
- ## Tasks
77
 
78
- | ID | Difficulty | Root Cause | Description |
79
- |----|-----------|------------|-------------|
80
- | `task_001` | Easy | `lr_too_high` | Exploding gradients β€” all layers show `is_exploding: True`, NaN in error log |
81
- | `task_002` | Easy | `vanishing_gradients` | Vanishing gradients β€” deeper layers show `is_vanishing: True`, flat loss curve |
82
- | `task_003` | Medium | `data_leakage` | Silent data leakage β€” suspiciously high val accuracy, `class_overlap_score > 0.5` |
83
- | `task_004` | Medium | `overfitting` | Train-val divergence β€” loss approaches 0 while val loss climbs |
84
- | `task_005` | Hard | `batchnorm_eval_mode` | Model in eval mode with compound red herrings (FC gradient spike, GPU 91%, near-vanishing conv1) |
85
- | `task_006` | Hard | `code_bug` | PyTorch code bug β€” agent must read and fix actual Python code (4 bug variants) |
86
- | `task_007` | Med-Hard | `scheduler_misconfigured` | LR scheduler with wrong gamma/step_size β€” training stagnates after initial progress |
87
 
88
- All tasks support `difficulty_level` (1-5) via reset: `{"type": "reset", "data": {"task_id": "task_005", "difficulty_level": 4}}`
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  ## Baseline Scores
91
 
92
- ### Heuristic vs LLM Comparison (3 agents, 7 tasks)
93
 
94
- | Task | Difficulty | Heuristic | Llama 3.3 70B | Llama 3.1 8B | Notes |
95
- |------|-----------|-----------|---------------|--------------|-------|
96
- | `task_001` | Easy | **1.00** | 1.00 | 0.60 | 8B finds issue but misses fix+restart sequence |
97
- | `task_002` | Easy | **1.00** | 1.00 | 0.05 | 8B barely investigates β€” struggles with multi-step reasoning |
98
- | `task_003` | Medium | **1.00** | 0.40 | 0.40 | Both LLMs explore inefficiently vs heuristic's direct path |
99
- | `task_004` | Medium | 0.45 | 0.45 | **0.60** | LLM's flexible investigation finds overfitting signals heuristic misses |
100
- | `task_005` | Hard | **1.00** | 1.00 | 1.00 | All agents find eval mode via model inspection |
101
- | `task_006` | Hard | **1.00** | β€” | 0.60–1.00 | Code debugging β€” 8B varies across providers |
102
- | `task_007` | Med-Hard | **1.00** | β€” | 0.60 | Scheduler detection β€” heuristic's pattern matching excels |
103
- | **Average** | | **0.92** | **0.69*** | **0.55** | |
104
 
105
- *Llama 3.3 70B results are partial (5/7 tasks before rate limit). Projected average ~0.69.
106
 
107
- **Key insights:**
108
- 1. **Model size matters:** 70B scores ~25% higher than 8B β€” the environment scales with model capability
109
- 2. **Heuristic beats LLMs:** A domain-specific decision tree (0.92) outperforms general-purpose LLMs (0.55-0.69) β€” proving the environment rewards systematic debugging strategy
110
- 3. **Task 4 is the exception:** LLMs outperform the heuristic on overfitting because real training curves require flexible reasoning, not rigid pattern matching
111
- 4. **8B struggles on multi-step tasks:** Task 2 (0.05) shows small models can't maintain investigation strategy across many steps
112
 
113
  ### Running Baselines
114
 
@@ -116,40 +171,86 @@ All tasks support `difficulty_level` (1-5) via reset: `{"type": "reset", "data":
116
  # Heuristic (deterministic, no API key, bit-exact reproducible)
117
  python3 baseline_heuristic.py
118
 
119
- # LLM (multi-provider support β€” set API key in .env)
120
- python3 baseline_inference.py # Groq (default, free)
121
- python3 baseline_inference.py --provider cerebras # Cerebras (free)
122
- python3 baseline_inference.py --provider gemini # Google Gemini
123
  python3 baseline_inference.py --provider openai # OpenAI GPT-4o
124
 
125
  # Run all baselines with comparison table
126
  python3 run_all_baselines.py
127
  ```
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  ## Setup
130
 
131
  ### Local Development
132
 
133
  ```bash
134
- # Create virtual environment
135
  python3 -m venv .venv
136
  source .venv/bin/activate
137
 
138
- # Install dependencies
139
  pip install torch --index-url https://download.pytorch.org/whl/cpu
140
  pip install openenv-core pydantic fastapi uvicorn
141
-
142
- # Install dev tools
143
- pip install pytest pytest-cov black ruff isort
144
 
145
  # Start server
146
  uvicorn server.app:app --host 0.0.0.0 --port 7860
147
 
148
- # Run tests
149
  pytest tests/ -v --cov=ml_training_debugger
150
 
151
- # Run baseline
152
- python baseline_heuristic.py
153
  ```
154
 
155
  ### Docker
@@ -160,82 +261,109 @@ docker run -p 7860:7860 pytorch-debugger
160
  curl http://localhost:7860/health
161
  ```
162
 
163
- ## API Endpoints
164
 
165
- | Endpoint | Method | Description |
166
- |----------|--------|-------------|
167
- | `/health` | GET | `{"status": "ready", "tasks": 6}` |
168
- | `/tasks` | GET | Task list with action schema |
169
- | `/grader` | POST | Grader score for last completed episode |
170
- | `/baseline` | POST | Run baseline, return scores for all 6 tasks |
171
- | `/dashboard` | GET | Live diagnostic dashboard (Plotly.js, 4-panel) |
172
- | `/ws` | WebSocket | Primary agent interface |
173
- | `/reset` | POST | Reset environment (framework) |
174
- | `/step` | POST | Execute action (framework) |
175
- | `/state` | GET | Current state (framework) |
176
- | `/schema` | GET | Action/observation schemas (framework) |
177
- | `/docs` | GET | Swagger UI (framework) |
178
-
179
- ### WebSocket Message Format
180
-
181
- The primary agent interface is the WebSocket endpoint at `/ws`. Messages use JSON:
182
-
183
- **Reset** (start a new episode, optionally select task):
184
- ```json
185
- {"type": "reset"}
186
- {"type": "reset", "data": {"task_id": "task_003", "seed": 42}}
187
  ```
188
- Without `data`, defaults to `task_001`. With `data`, selects the specified task.
189
 
190
- Returns: `{"type": "observation", "data": {"observation": {...}, "reward": 0.0, "done": false}}`
191
 
192
- **Step** (execute an action):
193
- ```json
194
- {"type": "step", "data": {"action_type": "inspect_gradients"}}
195
  ```
196
- ```json
197
- {"type": "step", "data": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  ```
199
- ```json
200
- {"type": "step", "data": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
201
- ```
202
- Returns: `{"type": "observation", "data": {"observation": {...}, "reward": float, "done": bool}}`
203
 
204
- ### HTTP vs WebSocket
 
 
 
 
 
 
205
 
206
- **WebSocket `/ws`** is the primary agent interface β€” it maintains a persistent session across reset/step/diagnose. Use this for full episodes.
207
 
208
- **HTTP `POST /reset` and `POST /step`** are stateless per the OpenEnv framework design β€” each request creates a fresh environment instance. Use these for single-action queries or health checks, not full episodes.
 
 
 
 
209
 
210
- **Custom endpoints** (`POST /baseline`, `POST /grader`, `GET /tasks`, `GET /health`) work independently of sessions.
211
 
212
- ## Validation Suite
213
 
214
- 8/8 validation checks pass β€” served live at `GET /validation-report`:
215
 
216
- **Methodology:** Real PyTorch 20-epoch mini-training with fault injection. Each fault type is validated with behavioral checks (gradient detection, loss patterns, model mode, code fix acceptance). Both SimpleCNN and SimpleMLP architectures verified.
217
 
218
- **Coverage:** Exploding gradients, vanishing gradients, data leakage, overfitting, BatchNorm eval mode, code bugs (4 variants), scheduler misconfigured, dual architecture.
 
 
 
219
 
220
- ## Architecture
 
 
 
 
 
 
 
221
 
222
- - **Python 3.12** Β· PyTorch 2.5.1 CPU-only Β· openenv-core v0.2.2
223
- - **Dual model architectures**: SimpleCNN (~50K params) + SimpleMLP (~20K params)
224
- - **Real 20-epoch mini-training** per reset (cached per task/seed for instant replay)
225
- - Typed Pydantic models everywhere β€” no `Dict[str, Any]`
226
- - `import torch` in every core module β€” zero numpy in core
227
- - Session isolation via per-session `EpisodeState`
228
- - Deterministic reproducibility via `torch.manual_seed()`
229
- - **251 tests, 95% coverage**
230
 
231
- ### Docker Image Size
 
 
 
 
 
 
 
 
232
 
233
- The Docker image is **885MB** (optimized from 1.96GB via multi-stage build, torch 2.5.1, `strip --strip-unneeded`, and removal of unused transitive dependencies). The core `libtorch_cpu.so` (329MB stripped) is the irreducible minimum for real `torch.nn.Module`, `torch.autograd`, and `model.state_dict()` support β€” the intentional trade-off for authentic PyTorch computation vs synthetic data.
234
 
235
- ### Research Paper
 
 
 
 
 
 
 
236
 
237
  See [PAPER.md](PAPER.md) β€” "Context-Gated Reward Shaping for Evidence-Based ML Debugging"
238
 
239
- ### Project Explanation
 
 
240
 
241
- See [EXPLANATION.md](EXPLANATION.md) β€” full project explanation in simple language
 
1
  # PyTorch Training Run Debugger
2
 
3
+ **OpenEnv RL Environment** | Meta PyTorch OpenEnv Hackathon x Scaler School of Technology
4
 
5
  An AI agent debugs broken PyTorch training runs by investigating gradients, model weights, data pipelines, and source code to diagnose and fix real ML failure patterns.
6
 
7
+ ---
8
 
9
+ ## The Problem
10
+
11
+ ML teams spend 15-25% of engineer time debugging silent training failures β€” runs that produce no error, no crash, just mysteriously bad metrics. Each misdiagnosed restart wastes GPU compute at $2-8/hour/card. The diagnostic process is hard because multiple symptoms point to multiple causes, some bugs produce no error at all, and fixing the wrong thing wastes hours.
12
+
13
+ No existing OpenEnv environment covers this domain.
14
+
15
+ ## What This Does
16
+
17
+ The environment recreates the experience of an ML engineer facing a broken training job. The agent receives a snapshot of a failing training run and must:
18
 
19
  1. **Investigate** β€” inspect gradients, data batches, model weights, model modes, and code
20
+ 2. **Diagnose** β€” identify the root cause from 7 known ML failure types
21
+ 3. **Fix** β€” apply the correct intervention
22
+ 4. **Verify** β€” restart training and confirm recovery before submitting
23
+
24
+ The agent starts with limited information (loss curves, config, error log) and must actively choose what to investigate. Each inspection reveals new data β€” gradient norms, class overlap scores, model train/eval modes, or buggy source code. This makes it a genuine investigation, not just a classification task.
25
+
26
+ ## What Makes This Different
27
+
28
+ ### Real PyTorch Model Internals
29
+
30
+ Every gradient comes from real `torch.autograd`. Every weight stat comes from real `model.state_dict()`. The environment instantiates actual `torch.nn.Module` models (SimpleCNN ~50K params, SimpleMLP ~20K params), runs 20 real forward+backward epochs per reset, and extracts real tensor statistics. Not synthetic formulas β€” real PyTorch computation, cached for instant replay.
31
+
32
+ ### Context-Gated Reward Shaping
33
+
34
+ Standard RL environments use stateless rewards: "did action X happen?" This environment tracks the agent's information state and conditions penalties on what the agent has already observed.
35
+
36
+ An agent that adds gradient clipping *before* inspecting gradients follows a reasonable prior β€” **no penalty**. An agent that inspects gradients, sees they are normal, and *then* adds gradient clipping is ignoring counter-evidence β€” **-0.20 penalty**.
37
+
38
+ The gate requires two conditions to be jointly true (`gradients_inspected AND gradients_were_normal`), both of which depend on prior agent actions. This encodes a transferable skill into the reward signal: don't ignore what you've already learned.
39
+
40
+ ### Code-Level Debugging
41
+
42
+ Task 6 presents actual buggy PyTorch training loops. The agent reads real Python code, identifies the buggy line, and submits a line-by-line fix. Four bug variants: `model.eval()` instead of `model.train()`, `.detach()` killing gradient flow, missing `optimizer.zero_grad()`, and `inplace=True` on ReLU corrupting the computation graph.
43
+
44
+ Fix validation uses a 4-strategy pipeline: whitespace normalization, token-stream comparison via Python's `tokenize` module, semantic pattern matching, and `ast.parse()` fallback. This handles the messy fixes that LLM agents actually produce (trailing spaces, inline comments, different indentation).
45
+
46
+ ### Red Herring Injection
47
+
48
+ Task 5 (BatchNorm eval mode) deliberately plants misleading signals: a gradient spike in the FC layer that doesn't cross the exploding threshold, a GPU memory warning at 91%, and near-vanishing gradients in conv1. The real problem is only visible through model mode inspection. This separates agents that follow rigid patterns from agents that can reason through ambiguity.
49
+
50
+ ## Tasks
51
 
52
+ 7 failure scenarios across 3 difficulty tiers, each with configurable difficulty level (1-5):
53
 
54
+ | ID | Difficulty | Root Cause | What Goes Wrong |
55
+ |----|-----------|------------|-----------------|
56
+ | `task_001` | Easy | `lr_too_high` | All gradient layers explode, NaN in loss. Direct signal β€” inspect gradients, reduce LR. |
57
+ | `task_002` | Easy | `vanishing_gradients` | Deeper layers show vanishing norms, loss stays flat. Model can't learn. |
58
+ | `task_003` | Medium | `data_leakage` | Suspiciously high val accuracy from epoch 1. `class_overlap_score > 0.5` confirms test data leaked into training. Red herring note about "architecture upgrade." |
59
+ | `task_004` | Medium | `overfitting` | Train loss drops to near-zero while val loss climbs. Classic memorization pattern. |
60
+ | `task_005` | Hard | `batchnorm_eval_mode` | Slow degradation with compound red herrings. Gradients look normal. The real problem: all layers stuck in eval mode. |
61
+ | `task_006` | Hard | `code_bug` | Metrics are anomalous but gradients/data/modes look fine. Root cause is in the Python training loop β€” 4 possible bug variants. |
62
+ | `task_007` | Med-Hard | `scheduler_misconfigured` | Training improves initially then stagnates. LR scheduler decays too aggressively (low gamma, small step size). |
63
 
64
+ ### How Difficulty Scales
65
 
66
+ Easy tasks have one obvious signal (all gradients exploding). Medium tasks require checking multiple sources and ruling out alternatives. Hard tasks deliberately mislead β€” the most obvious signal is wrong, and the real problem is hidden behind layers of investigation.
67
 
68
+ ## Observation Space
69
+
70
+ | Field | Type | When Visible |
71
+ |-------|------|-------------|
72
  | `training_loss_history` | `list[float]` (20 epochs) | Always |
73
  | `val_accuracy_history` | `list[float]` (20 epochs) | Always |
74
  | `val_loss_history` | `list[float]` (20 epochs) | Always |
75
  | `current_config` | `TrainingConfig` | Always |
76
+ | `error_log` | `str` or `null` | Always |
77
  | `gradient_stats` | `list[GradientStats]` | After `inspect_gradients` |
78
+ | `model_weight_stats` | `list[ModelWeightStats]` | After `inspect_model_weights` |
79
+ | `data_batch_stats` | `DataBatchStats` | After `inspect_data_batch` |
80
+ | `model_mode_info` | `dict[str, str]` | After `inspect_model_modes` |
81
+ | `code_snippet` | `CodeSnippet` | After `inspect_code` |
82
  | `available_actions` | `list[str]` | Always (dynamic) |
83
  | `episode_state` | `EpisodeState` | Always |
84
 
85
+ Fields like `gradient_stats`, `data_batch_stats`, `model_mode_info`, and `code_snippet` start as `null` and are only populated after the agent explicitly requests them. The agent must decide what to investigate.
86
+
87
+ ## Action Space
88
+
89
+ 14 action types in 3 categories:
90
 
91
+ **Investigation** β€” reveal hidden observation fields:
92
+ - `inspect_gradients` β€” per-layer gradient norms, is_exploding/is_vanishing flags
93
+ - `inspect_data_batch` β€” label distribution, class overlap score, confusion matrix
94
+ - `inspect_model_modes` β€” train/eval mode per layer
95
+ - `inspect_model_weights` β€” weight norms, dead neurons, NaN/Inf detection
96
+ - `inspect_code` β€” the actual Python training loop (Task 6)
97
 
98
+ **Fix** β€” apply an intervention:
99
+ - `modify_config` β€” change learning_rate, weight_decay, batch_size, optimizer, etc.
100
+ - `add_callback` β€” add gradient clipping
101
+ - `patch_data_loader` β€” fix data pipeline
102
+ - `fix_model_mode` β€” switch model to train mode
103
+ - `fix_code` β€” fix a specific line of code (requires line number + replacement)
104
+ - `replace_optimizer` β€” switch optimizer
105
 
106
+ **Terminal** β€” end the episode:
107
+ - `restart_run` β€” restart training (only available after a fix)
108
+ - `mark_diagnosed` β€” submit diagnosis from 7 possible root causes
109
 
110
+ Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
 
 
 
 
 
 
 
111
 
112
+ ## Reward Function
113
 
114
+ Per-step signal, separate from the grader. Hard cap at [-1.0, 1.0].
115
+
116
+ | Event | Reward | Condition |
117
+ |-------|--------|-----------|
118
+ | Any step | -0.01 | Flat, unconditional (encourages efficiency) |
119
+ | First-time inspection | +0.05 | Per inspection type, first time only |
120
+ | Correct diagnosis | +0.50 | `diagnosis == root_cause` |
121
+ | Wrong diagnosis | -0.30 | `diagnosis != root_cause` |
122
+ | Convergence after fix+restart | +0.40 | Fix applied, restarted, training recovers |
123
  | Invalid action | -0.05 | Action not in `available_actions` |
124
+ | Wrong code fix | -0.10 | `fix_code` with incorrect line/replacement |
125
+ | **Context-gated penalty** | **-0.20** | `gradients_inspected AND gradients_were_normal AND action == add_callback` |
 
126
 
127
+ The step penalty is flat -0.01 (never multiplied by step count). Investigation bonuses fire once per type. The context-gated penalty requires the agent to have previously inspected gradients and found them normal β€” it cannot fire before inspection.
128
 
129
+ ## Grading
 
 
 
 
 
 
 
 
130
 
131
+ Each task has a separate grader that evaluates the complete `EpisodeState` at episode end, returning a normalized 0.0-1.0 score. The grader is **not** a sum of step rewards β€” it's a holistic evaluation of whether the agent investigated correctly, applied the right fix, restarted training, and diagnosed accurately.
132
+
133
+ Example (Task 5 β€” BatchNorm Eval):
134
+
135
+ | Component | Points |
136
+ |-----------|--------|
137
+ | Inspected gradients | +0.05 |
138
+ | Inspected model modes (the revealing action) | +0.05 |
139
+ | Fixed model mode | +0.25 |
140
+ | Restarted training | +0.30 |
141
+ | Correct diagnosis | +0.40 |
142
+ | Fell for red herring (add_callback after normal gradients) | -0.20 |
143
+
144
+ An agent that chases the gradient spike red herring loses 0.20 points. An agent that goes straight to model modes and finds the real problem scores 1.0.
145
 
146
  ## Baseline Scores
147
 
148
+ ### Heuristic vs LLM Comparison
149
 
150
+ | Task | Difficulty | Heuristic | Llama 3.3 70B | Llama 3.1 8B |
151
+ |------|-----------|-----------|---------------|--------------|
152
+ | `task_001` | Easy | **1.00** | 1.00 | 0.60 |
153
+ | `task_002` | Easy | **1.00** | 1.00 | 0.05 |
154
+ | `task_003` | Medium | **1.00** | 0.40 | 0.40 |
155
+ | `task_004` | Medium | **1.00** | 0.45 | 0.60 |
156
+ | `task_005` | Hard | **1.00** | 1.00 | 1.00 |
157
+ | `task_006` | Hard | **1.00** | β€” | 0.60-1.00 |
158
+ | `task_007` | Med-Hard | **1.00** | β€” | 0.60 |
159
+ | **Average** | | **1.00** | ~0.69* | 0.55 |
160
 
161
+ *Llama 3.3 70B results are partial (5/7 tasks before rate limit).
162
 
163
+ **What this tells you:**
164
+ - **Model size matters:** 70B scores ~25% higher than 8B. The environment scales with model capability.
165
+ - **8B struggles on multi-step tasks:** Task 2 score of 0.05 shows small models can't maintain investigation strategy across many steps.
166
+ - **The heuristic baseline is strong** because it was designed with knowledge of the task structure. An agent that doesn't know the structure has to figure it out from observations alone.
 
167
 
168
  ### Running Baselines
169
 
 
171
  # Heuristic (deterministic, no API key, bit-exact reproducible)
172
  python3 baseline_heuristic.py
173
 
174
+ # LLM (multi-provider support)
175
+ python3 baseline_inference.py # Groq β€” Llama 3.3 70B (free)
176
+ python3 baseline_inference.py --provider cerebras # Cerebras β€” Llama 3.1 8B (free)
177
+ python3 baseline_inference.py --provider gemini # Google Gemini 2.0 Flash
178
  python3 baseline_inference.py --provider openai # OpenAI GPT-4o
179
 
180
  # Run all baselines with comparison table
181
  python3 run_all_baselines.py
182
  ```
183
 
184
+ ## API
185
+
186
+ ### HTTP Endpoints
187
+
188
+ | Endpoint | Method | Description |
189
+ |----------|--------|-------------|
190
+ | `/health` | GET | `{"status": "ready", "tasks": 7}` |
191
+ | `/tasks` | GET | Task list with IDs, difficulties, action schema |
192
+ | `/grader` | POST | Score for last completed episode |
193
+ | `/baseline` | POST | Run heuristic on all tasks, return scores |
194
+ | `/dashboard` | GET | Live 4-panel diagnostic dashboard |
195
+ | `/validation-report` | GET | Simulation fidelity report |
196
+ | `/curriculum` | GET | Recommended task order (easy to hard, difficulty 1-5) |
197
+ | `/leaderboard` | GET | Sorted episode scores |
198
+ | `/replay/{id}` | GET | Full action/observation trace for an episode |
199
+ | `/schema` | GET | Action/observation JSON schemas |
200
+ | `/docs` | GET | Swagger UI |
201
+
202
+ ### WebSocket (Primary Agent Interface)
203
+
204
+ The WebSocket endpoint at `/ws` maintains session state across a full episode. HTTP endpoints are stateless by framework design.
205
+
206
+ **Reset** (start an episode):
207
+ ```json
208
+ {"type": "reset", "data": {"task_id": "task_003", "seed": 42}}
209
+ ```
210
+
211
+ **Step** (take an action):
212
+ ```json
213
+ {"type": "step", "data": {"action_type": "inspect_gradients"}}
214
+ {"type": "step", "data": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
215
+ {"type": "step", "data": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
216
+ ```
217
+
218
+ **Response format:**
219
+ ```json
220
+ {"type": "observation", "data": {"observation": {...}, "reward": 0.04, "done": false}}
221
+ ```
222
+
223
+ ## Dashboard
224
+
225
+ A live 4-panel diagnostic dashboard at `/dashboard`:
226
+
227
+ 1. **Training Metrics** β€” loss/accuracy curves with Plotly.js
228
+ 2. **Gradient & Weight Heatmap** β€” per-layer bars, color-coded (green=normal, red=exploding, blue=vanishing)
229
+ 3. **Action Timeline & Rewards** β€” step-by-step bars showing reward per action and cumulative reward line
230
+ 4. **Episode Summary** β€” state flags, available actions, code snippet (Task 6)
231
+
232
+ Select a task, click "Run Baseline", and watch the heuristic agent investigate, fix, and diagnose step by step. The charts update live over WebSocket.
233
+
234
  ## Setup
235
 
236
  ### Local Development
237
 
238
  ```bash
 
239
  python3 -m venv .venv
240
  source .venv/bin/activate
241
 
 
242
  pip install torch --index-url https://download.pytorch.org/whl/cpu
243
  pip install openenv-core pydantic fastapi uvicorn
244
+ pip install pytest pytest-cov pytest-asyncio httpx websockets
 
 
245
 
246
  # Start server
247
  uvicorn server.app:app --host 0.0.0.0 --port 7860
248
 
249
+ # Run tests (255 tests, 97% coverage)
250
  pytest tests/ -v --cov=ml_training_debugger
251
 
252
+ # Run heuristic baseline
253
+ python3 baseline_heuristic.py
254
  ```
255
 
256
  ### Docker
 
261
  curl http://localhost:7860/health
262
  ```
263
 
264
+ ### Smoke Test
265
 
266
+ ```bash
267
+ # Verify all critical paths
268
+ curl http://localhost:7860/health
269
+ curl http://localhost:7860/tasks | python3 -m json.tool
270
+ curl -X POST http://localhost:7860/baseline | python3 -m json.tool
271
+ curl -X POST http://localhost:7860/grader | python3 -m json.tool
272
+
273
+ # Reproducibility (must produce no diff)
274
+ python3 baseline_heuristic.py > run1.json
275
+ python3 baseline_heuristic.py > run2.json
276
+ diff run1.json run2.json
 
 
 
 
 
 
 
 
 
 
 
277
  ```
 
278
 
279
+ ## Architecture
280
 
 
 
 
281
  ```
282
+ ml_training_debugger/
283
+ models.py β€” Pydantic data models (Action, Observation, EpisodeState)
284
+ scenarios.py β€” Task parameter sampling (7 tasks, deterministic per seed)
285
+ pytorch_engine.py β€” Real PyTorch models, fault injection, gradient/weight extraction
286
+ simulation.py β€” 20-epoch real training with parametric fallback
287
+ reward_engine.py β€” 7-component per-step reward with context gating
288
+ graders.py β€” Per-task holistic 0.0-1.0 scoring
289
+ code_templates.py β€” Task 6 bug variants + 4-strategy fix validation
290
+ client.py β€” Typed client extending EnvClient
291
+
292
+ server/
293
+ environment.py β€” MLTrainingEnvironment (reset/step/state)
294
+ app.py β€” FastAPI + custom endpoints
295
+ dashboard.html β€” Live Plotly.js diagnostic dashboard
296
+
297
+ tests/ β€” 255 tests, 97% coverage
298
+ baseline_heuristic.py β€” Rule-based agent (deterministic, no API key)
299
+ baseline_inference.py β€” LLM agent (Groq/Cerebras/Gemini/OpenAI)
300
  ```
 
 
 
 
301
 
302
+ **Key design decisions:**
303
+ - **Grader is separate from reward function.** `reward_engine.py` returns a float per step for RL training signal. `graders.py` returns a holistic 0.0-1.0 score at episode end. They are different modules with different purposes.
304
+ - **Task IDs are opaque.** `task_001` through `task_007` β€” the agent cannot infer the diagnosis from the ID.
305
+ - **Task 6 diagnosis is always `code_bug`.** Regardless of which bug variant (eval_mode, detach_loss, zero_grad_missing, inplace_relu), the correct diagnosis is `code_bug`.
306
+ - **Dual model architectures.** SimpleCNN and SimpleMLP are randomly selected per episode, testing agent robustness to architecture variation.
307
+ - **Session isolation.** Each WebSocket connection gets its own environment instance with independent state.
308
+ - **`step()` never raises.** All invalid actions return a valid observation with -0.05 penalty and an error note.
309
 
310
+ ### Technical Stack
311
 
312
+ - Python 3.12 Β· PyTorch 2.5.1 CPU-only Β· openenv-core v0.2.2
313
+ - `import torch` in every core module β€” zero numpy in core
314
+ - Typed Pydantic v2 models everywhere β€” no `Dict[str, Any]`
315
+ - Deterministic reproducibility via `torch.manual_seed()` at every reset
316
+ - Docker image: 885MB (multi-stage build, `strip --strip-unneeded`, transitive dep cleanup)
317
 
318
+ ### Validation Suite
319
 
320
+ 8/8 validation checks pass. Real PyTorch 20-epoch mini-training with fault injection. Each fault type is validated with behavioral checks (gradient detection, loss patterns, model mode, code fix acceptance). Both SimpleCNN and SimpleMLP architectures verified. Results served live at `GET /validation-report`.
321
 
322
+ ## Walkthrough: Solving Task 5 (Hard)
323
 
324
+ This is the most interesting task because it has red herrings designed to mislead.
325
 
326
+ **What the agent sees on reset:**
327
+ - Loss oscillates between 2.1-2.5, never converging
328
+ - Val accuracy stuck at ~0.15
329
+ - Error log mentions GPU memory at 91%
330
 
331
+ **Step 1: Inspect gradients**
332
+ ```
333
+ conv1: mean_norm=0.15, Normal
334
+ conv2: mean_norm=5.2, Normal
335
+ conv3: mean_norm=0.8, Normal
336
+ fc: mean_norm=1.3, Normal (slight spike β€” the red herring)
337
+ ```
338
+ All layers normal. `gradients_were_normal` is now True.
339
 
340
+ **Step 2: Inspect data** β€” class overlap 0.0, data is clean.
 
 
 
 
 
 
 
341
 
342
+ **Step 3: Inspect model modes**
343
+ ```
344
+ conv1: "eval" ← Problem found!
345
+ bn1: "eval"
346
+ conv2: "eval"
347
+ bn2: "eval"
348
+ fc: "eval"
349
+ ```
350
+ All layers stuck in eval mode. BatchNorm is using running statistics instead of batch statistics during training.
351
 
352
+ **Step 4: Fix model mode** β€” switches all layers to train mode.
353
 
354
+ **Step 5: Restart training** β€” convergence confirmed.
355
+
356
+ **Step 6: Diagnose `batchnorm_eval_mode`** β€” correct. Score: 1.0.
357
+
358
+ **What would have gone wrong:**
359
+ If the agent had seen the FC gradient spike and called `add_callback` (gradient clipping), it would have received -0.20 context-gated penalty β€” because it already knew gradients were normal. The penalty only fires when both `gradients_inspected=True` and `gradients_were_normal=True`. Before inspection, the same action would have no penalty.
360
+
361
+ ## Research Summary
362
 
363
  See [PAPER.md](PAPER.md) β€” "Context-Gated Reward Shaping for Evidence-Based ML Debugging"
364
 
365
+ Core claim: by conditioning penalties on the agent's accumulated information state (not just action outcomes), we create environments that reward systematic investigation over pattern-matching β€” a capability with direct transfer value to real-world MLOps debugging.
366
+
367
+ ---
368
 
369
+ *Built for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, 2026.*
baseline_heuristic.py CHANGED
@@ -89,36 +89,24 @@ def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
89
  session = env._get_session()
90
  return session.last_score if session and session.last_score is not None else 0.0
91
 
92
- # Check overfitting (val_loss diverging OR train loss near-zero with rising val_loss)
93
- if obs.val_loss_history and len(obs.val_loss_history) >= 10:
94
- early = sum(obs.val_loss_history[:5]) / 5
95
- late = sum(obs.val_loss_history[-5:]) / 5
96
- train_loss_low = (
97
- obs.training_loss_history
98
- and obs.training_loss_history[-1] < 0.1
99
- )
100
- val_loss_rising = late > early * 1.05
 
 
101
  if (
102
- (val_loss_rising or train_loss_low)
 
103
  and obs.data_batch_stats
104
- and obs.data_batch_stats.class_overlap_score < 0.1
105
  ):
106
- obs = env.step(
107
- MLTrainingAction(
108
- action_type="modify_config",
109
- target="weight_decay",
110
- value=0.01,
111
- )
112
- )
113
- obs = env.step(MLTrainingAction(action_type="restart_run"))
114
- obs = env.step(
115
- MLTrainingAction(
116
- action_type="mark_diagnosed",
117
- diagnosis="overfitting",
118
- )
119
- )
120
- session = env._get_session()
121
- return session.last_score if session and session.last_score is not None else 0.0
122
 
123
  # Step 3: inspect_model_modes
124
  obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
@@ -193,7 +181,26 @@ def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
193
  session = env._get_session()
194
  return session.last_score if session and session.last_score is not None else 0.0
195
 
196
- # Fallback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  obs = env.step(
198
  MLTrainingAction(
199
  action_type="mark_diagnosed",
 
89
  session = env._get_session()
90
  return session.last_score if session and session.last_score is not None else 0.0
91
 
92
+ # Detect overfitting pattern (used later, after ruling out code bugs)
93
+ _looks_like_overfitting = False
94
+ if obs.val_loss_history and obs.training_loss_history and len(obs.val_loss_history) >= 10:
95
+ early_train = sum(obs.training_loss_history[:5]) / 5
96
+ late_train = sum(obs.training_loss_history[-5:]) / 5
97
+ early_val = sum(obs.val_loss_history[:5]) / 5
98
+ late_val = sum(obs.val_loss_history[-5:]) / 5
99
+ train_dropped = late_train < early_train * 0.5
100
+ train_loss_low = late_train < 0.15
101
+ val_not_improving = late_val >= early_val * 0.95
102
+ gap_widening = (late_val - late_train) > (early_val - early_train)
103
  if (
104
+ (train_dropped or train_loss_low)
105
+ and (val_not_improving or gap_widening)
106
  and obs.data_batch_stats
107
+ and obs.data_batch_stats.class_overlap_score < 0.3
108
  ):
109
+ _looks_like_overfitting = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  # Step 3: inspect_model_modes
112
  obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
 
181
  session = env._get_session()
182
  return session.last_score if session and session.last_score is not None else 0.0
183
 
184
+ # Overfitting fallback β€” only if code inspection didn't find a bug
185
+ if _looks_like_overfitting:
186
+ obs = env.step(
187
+ MLTrainingAction(
188
+ action_type="modify_config",
189
+ target="weight_decay",
190
+ value=0.01,
191
+ )
192
+ )
193
+ obs = env.step(MLTrainingAction(action_type="restart_run"))
194
+ obs = env.step(
195
+ MLTrainingAction(
196
+ action_type="mark_diagnosed",
197
+ diagnosis="overfitting",
198
+ )
199
+ )
200
+ session = env._get_session()
201
+ return session.last_score if session and session.last_score is not None else 0.0
202
+
203
+ # Final fallback
204
  obs = env.step(
205
  MLTrainingAction(
206
  action_type="mark_diagnosed",