Commit 9c7163b by UjjwalPardeshi · 1 Parent(s): 9442887

rewrite readme

Files changed (1)
  1. README.md +81 -303
README.md CHANGED
@@ -13,358 +13,136 @@ tags:
13
 
14
  # PyTorch Training Run Debugger
15
 
16
- **OpenEnv RL Environment** | Meta PyTorch OpenEnv Hackathon x Scaler School of Technology
17
 
18
- **Live Demo:** [HF Space](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/dashboard) | **API Health:** [/health](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/health) | **API Docs:** [/docs](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/docs)
19
 
20
- An AI agent debugs broken PyTorch training runs by investigating gradients, model weights, data pipelines, and source code to diagnose and fix real ML failure patterns.
21
 
22
  ---
23
 
24
- ## The Problem
25
 
26
- ML teams spend 15-25% of engineer time debugging silent training failures: runs that produce no error and no crash, just mysteriously bad metrics. Each misdiagnosed restart wastes GPU compute at $2-8/hour/card. The diagnostic process is hard because multiple symptoms point to multiple causes, some bugs produce no error at all, and fixing the wrong thing wastes hours.
27
 
28
- No existing OpenEnv environment covers this domain.
29
 
30
- ## What This Does
31
 
32
- The environment recreates the experience of an ML engineer facing a broken training job. The agent receives a snapshot of a failing training run and must:
33
 
34
- 1. **Investigate**: inspect gradients, data batches, model weights, model modes, and code
35
- 2. **Diagnose**: identify the root cause from 7 known ML failure types
36
- 3. **Fix**: apply the correct intervention
37
- 4. **Verify**: restart training and confirm recovery before submitting
38
 
39
- The agent starts with limited information (loss curves, config, error log) and must actively choose what to investigate. Each inspection reveals new data — gradient norms, class overlap scores, model train/eval modes, or buggy source code. This makes it a genuine investigation, not just a classification task.
40
 
41
- ## What Makes This Different
42
-
43
- ### Real PyTorch Model Internals
44
-
45
- Every gradient comes from real `torch.autograd`. Every weight stat comes from real `model.state_dict()`. The environment instantiates actual `torch.nn.Module` models (SimpleCNN ~67K params, SimpleMLP ~412K params), runs 20 real forward+backward epochs per reset, and extracts real tensor statistics. Not synthetic formulas — real PyTorch computation, cached for instant replay.
46
-
47
- ### Context-Gated Reward Shaping
48
-
49
- Standard RL environments use stateless rewards: "did action X happen?" This environment tracks the agent's information state and conditions penalties on what the agent has already observed.
50
-
51
- An agent that adds gradient clipping *before* inspecting gradients follows a reasonable prior — **no penalty**. An agent that inspects gradients, sees they are normal, and *then* adds gradient clipping is ignoring counter-evidence — **-0.20 penalty**.
52
-
53
- The gate requires two conditions to be jointly true (`gradients_inspected AND gradients_were_normal`), both of which depend on prior agent actions. This encodes a transferable skill into the reward signal: don't ignore what you've already learned.
54
-
55
- ### Code-Level Debugging
56
-
57
- Task 6 presents actual buggy PyTorch training loops. The agent reads real Python code, identifies the buggy line, and submits a line-by-line fix. Four bug variants: `model.eval()` instead of `model.train()`, `.detach()` killing gradient flow, missing `optimizer.zero_grad()`, and `inplace=True` on ReLU corrupting the computation graph.
58
-
59
- Fix validation uses a 4-strategy pipeline: whitespace normalization, token-stream comparison via Python's `tokenize` module, semantic pattern matching, and `ast.parse()` fallback. This handles the messy fixes that LLM agents actually produce (trailing spaces, inline comments, different indentation).
60
-
61
- ### Red Herring Injection
62
-
63
- Task 5 (BatchNorm eval mode) deliberately plants misleading signals: a gradient spike in the FC layer that doesn't cross the exploding threshold, a GPU memory warning at 91%, and near-vanishing gradients in conv1. The real problem is only visible through model mode inspection. This separates agents that follow rigid patterns from agents that can reason through ambiguity.
64
 
65
  ## Tasks
66
 
67
- 7 failure scenarios across 3 difficulty tiers, each with configurable difficulty level (1-5):
68
-
69
  | ID | Difficulty | Root Cause | What Goes Wrong |
70
  |----|-----------|------------|-----------------|
71
- | `task_001` | Easy | `lr_too_high` | All gradient layers explode, NaN in loss. Direct signal — inspect gradients, reduce LR. |
72
- | `task_002` | Easy | `vanishing_gradients` | Deeper layers show vanishing norms, loss stays flat. Model can't learn. |
73
- | `task_003` | Medium | `data_leakage` | Suspiciously high val accuracy from epoch 1. `class_overlap_score > 0.5` confirms test data leaked into training. Red herring note about "architecture upgrade." |
74
- | `task_004` | Medium | `overfitting` | Train loss drops to near-zero while val loss climbs. Classic memorization pattern. |
75
- | `task_005` | Hard | `batchnorm_eval_mode` | Slow degradation with compound red herrings. Gradients look normal. The real problem: all layers stuck in eval mode. |
76
- | `task_006` | Hard | `code_bug` | Metrics are anomalous but gradients/data/modes look fine. Root cause is in the Python training loop, with 4 possible bug variants. |
77
- | `task_007` | Med-Hard | `scheduler_misconfigured` | Training improves initially then stagnates. LR scheduler decays too aggressively (low gamma, small step size). |
78
-
79
- ### How Difficulty Scales
80
-
81
- Easy tasks have one obvious signal (all gradients exploding). Medium tasks require checking multiple sources and ruling out alternatives. Hard tasks deliberately mislead — the most obvious signal is wrong, and the real problem is hidden behind layers of investigation.
82
-
83
- ## Observation Space
84
-
85
- | Field | Type | When Visible |
86
- |-------|------|-------------|
87
- | `training_loss_history` | `list[float]` (20 epochs) | Always |
88
- | `val_accuracy_history` | `list[float]` (20 epochs) | Always |
89
- | `val_loss_history` | `list[float]` (20 epochs) | Always |
90
- | `current_config` | `TrainingConfig` | Always |
91
- | `error_log` | `str` or `null` | Always |
92
- | `gradient_stats` | `list[GradientStats]` | After `inspect_gradients` |
93
- | `model_weight_stats` | `list[ModelWeightStats]` | After `inspect_model_weights` |
94
- | `data_batch_stats` | `DataBatchStats` | After `inspect_data_batch` |
95
- | `model_mode_info` | `dict[str, str]` | After `inspect_model_modes` |
96
- | `code_snippet` | `CodeSnippet` | After `inspect_code` |
97
- | `available_actions` | `list[str]` | Always (dynamic) |
98
- | `episode_state` | `EpisodeState` | Always |
99
-
100
- Fields like `gradient_stats`, `data_batch_stats`, `model_mode_info`, and `code_snippet` start as `null` and are only populated after the agent explicitly requests them. The agent must decide what to investigate.
101
-
102
- ## Action Space
103
-
104
- 13 action types in 3 categories:
105
-
106
- **Investigation** — reveal hidden observation fields:
107
- - `inspect_gradients` — per-layer gradient norms, is_exploding/is_vanishing flags
108
- - `inspect_data_batch` — label distribution, class overlap score, confusion matrix
109
- - `inspect_model_modes` — train/eval mode per layer
110
- - `inspect_model_weights` — weight norms, dead neurons, NaN/Inf detection
111
- - `inspect_code` — the actual Python training loop (Task 6)
112
-
113
- **Fix** — apply an intervention:
114
- - `modify_config` — change learning_rate, weight_decay, batch_size, optimizer, etc.
115
- - `add_callback` — add gradient clipping
116
- - `patch_data_loader` — fix data pipeline
117
- - `fix_model_mode` — switch model to train mode
118
- - `fix_code` — fix a specific line of code (requires line number + replacement)
119
- - `replace_optimizer` — switch optimizer
120
-
121
- **Terminal** — end the episode:
122
- - `restart_run` — restart training (only available after a fix)
123
- - `mark_diagnosed` — submit diagnosis from 7 possible root causes
124
-
125
- Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
126
-
127
- ## Reward Function
128
-
129
- Per-step signal, separate from the grader. Hard cap at [-1.0, 1.0].
130
-
131
- | Event | Reward | Condition |
132
- |-------|--------|-----------|
133
- | Any step | -0.01 | Flat, unconditional (encourages efficiency) |
134
- | First-time inspection | +0.05 | Per inspection type, first time only |
135
- | Correct diagnosis | +0.50 | `diagnosis == root_cause` |
136
- | Wrong diagnosis | -0.30 | `diagnosis != root_cause` |
137
- | Convergence after fix+restart | +0.40 | Fix applied, restarted, training recovers |
138
- | Invalid action | -0.05 | Action not in `available_actions` |
139
- | Wrong code fix | -0.10 | `fix_code` with incorrect line/replacement |
140
- | **Context-gated penalty** | **-0.20** | `gradients_inspected AND gradients_were_normal AND action == add_callback` |
141
-
142
- The step penalty is flat -0.01 (never multiplied by step count). Investigation bonuses fire once per type. The context-gated penalty requires the agent to have previously inspected gradients and found them normal — it cannot fire before inspection.
143
-
144
- ## Grading
145
-
146
- Each task has a separate grader that evaluates the complete `EpisodeState` at episode end, returning a normalized 0.0-1.0 score. The grader is **not** a sum of step rewards — it's a holistic evaluation of whether the agent investigated correctly, applied the right fix, restarted training, and diagnosed accurately.
147
-
148
- Example (Task 5 — BatchNorm Eval):
149
-
150
- | Component | Points |
151
- |-----------|--------|
152
- | Inspected gradients | +0.05 |
153
- | Inspected model modes (the revealing action) | +0.05 |
154
- | Fixed model mode | +0.25 |
155
- | Restarted training | +0.30 |
156
- | Correct diagnosis | +0.40 |
157
- | Fell for red herring (add_callback after normal gradients) | -0.20 |
158
-
159
- An agent that chases the gradient spike red herring loses 0.20 points. An agent that goes straight to model modes and finds the real problem scores 1.0.
160
 
161
- ## Baseline Scores
162
 
163
- ### Heuristic vs LLM Comparison
164
 
165
- | Task | Difficulty | Heuristic | Llama 3.1 8B |
166
- |------|-----------|-----------|--------------|
167
- | `task_001` | Easy | **1.00** | 0.60 |
168
- | `task_002` | Easy | **1.00** | 0.05 |
169
- | `task_003` | Medium | **1.00** | 0.40 |
170
- | `task_004` | Medium | **1.00** | 0.60 |
171
- | `task_005` | Hard | **0.80** | 0.38-0.55 |
172
- | `task_006` | Hard | **0.81** | 0.60-1.00 |
173
- | `task_007` | Hard | **0.79** | 0.60 |
174
- | **Average** | | **0.91** | 0.52 |
175
 
176
- **What this tells you:**
177
- - **Hard tasks are genuinely hard:** All three hard tasks (5, 6, 7) require thorough investigation including weight inspection for full credit. The heuristic scores 0.79-0.81 on hard tasks because it skips weight inspection. An LLM that falls for red herrings or skips investigation scores even lower.
178
- - **Red herring traps work:** Task 5 penalizes agents that call `add_callback` after seeing normal gradients (-0.20) or `modify_config` when LR isn't the issue (-0.10). LLMs routinely fall for both traps.
179
- - **Investigation thoroughness matters:** Tasks 6 and 7 scale fix/restart credit based on how thoroughly the agent investigated before acting. Quick fixes without ruling out alternatives score ~60-65% of full credit.
180
- - **8B struggles on multi-step tasks:** Task 2 score of 0.05 shows small models can't maintain investigation strategy across many steps.
181
- - **The heuristic baseline is strong** because it was designed with knowledge of the task structure. An agent that doesn't know the structure has to figure it out from observations alone.
182
 
183
- ### Running Baselines
184
 
185
- ```bash
186
- # Heuristic (deterministic, no API key, bit-exact reproducible)
187
- python3 baseline_heuristic.py
188
-
189
- # LLM (hackathon evaluator format — uses OpenAI client)
190
- API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o OPENAI_API_KEY=sk-... python3 inference.py
191
- ```
192
-
193
- ## API
194
-
195
- ### HTTP Endpoints
196
-
197
- | Endpoint | Method | Description |
198
- |----------|--------|-------------|
199
- | `/health` | GET | `{"status": "healthy", "tasks": 7}` |
200
- | `/tasks` | GET | Task list with IDs, difficulties, action schema |
201
- | `/grader` | POST | Score for last completed episode |
202
- | `/baseline` | POST | Run heuristic on all tasks, return scores |
203
- | `/dashboard` | GET | Live 4-panel diagnostic dashboard |
204
- | `/validation-report` | GET | Simulation fidelity report |
205
- | `/curriculum` | GET | Recommended task order (easy to hard, difficulty 1-5) |
206
- | `/leaderboard` | GET | Sorted episode scores |
207
- | `/replay/{id}` | GET | Full action/observation trace for an episode |
208
- | `/docs` | GET | Swagger UI (auto-generated by FastAPI) |
209
 
210
- ### WebSocket (Primary Agent Interface)
211
 
212
- The WebSocket endpoint at `/ws` maintains session state across a full episode. HTTP endpoints are stateless by framework design.
213
 
214
- **Reset** (start an episode):
215
- ```json
216
- {"type": "reset", "data": {"task_id": "task_003", "seed": 42}}
217
- ```
218
-
219
- **Step** (take an action):
220
- ```json
221
- {"type": "step", "data": {"action_type": "inspect_gradients"}}
222
- {"type": "step", "data": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
223
- {"type": "step", "data": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
224
- ```
225
 
226
- **Response format:**
227
- ```json
228
- {"type": "observation", "data": {"observation": {...}, "reward": 0.04, "done": false}}
229
- ```
230
 
231
- ## Dashboard
232
 
233
- A live 4-panel diagnostic dashboard at `/dashboard`:
234
 
235
- 1. **Training Metrics**: loss/accuracy curves with Plotly.js
236
- 2. **Gradient & Weight Heatmap**: per-layer bars, color-coded (green=normal, red=exploding, blue=vanishing)
237
- 3. **Action Timeline & Rewards**: step-by-step bars showing reward per action and cumulative reward line
238
- 4. **Episode Summary**: state flags, available actions, code snippet (Task 6)
239
 
240
- Select a task, click "Run Baseline", and watch the heuristic agent investigate, fix, and diagnose step by step. The charts update live over WebSocket.
241
 
242
  ## Setup
243
 
244
- ### Local Development
245
-
246
  ```bash
247
- python3 -m venv .venv
248
- source .venv/bin/activate
249
-
250
  pip install torch --index-url https://download.pytorch.org/whl/cpu
251
  pip install openenv-core pydantic fastapi uvicorn
252
- pip install pytest pytest-cov pytest-asyncio httpx websockets
253
-
254
- # Start server
255
  uvicorn server.app:app --host 0.0.0.0 --port 7860
256
 
257
- # Run tests (246 tests, 96% coverage)
258
- pytest tests/ -v --cov=ml_training_debugger
259
-
260
- # Run heuristic baseline
261
- python3 baseline_heuristic.py
262
- ```
263
-
264
- ### Docker
265
-
266
- ```bash
267
  docker build -t pytorch-debugger .
268
  docker run -p 7860:7860 pytorch-debugger
269
- curl http://localhost:7860/health
270
- ```
271
-
272
- ### Smoke Test
273
 
274
- ```bash
275
- # Verify all critical paths
276
- curl http://localhost:7860/health
277
- curl http://localhost:7860/tasks | python3 -m json.tool
278
- curl -X POST http://localhost:7860/baseline | python3 -m json.tool
279
- curl -X POST http://localhost:7860/grader | python3 -m json.tool
280
-
281
- # Reproducibility (must produce no diff)
282
- python3 baseline_heuristic.py > run1.json
283
- python3 baseline_heuristic.py > run2.json
284
- diff run1.json run2.json
285
  ```
286
 
287
- ## Architecture
288
 
289
  ```
290
  ml_training_debugger/
291
- models.py - Pydantic data models (Action, Observation, EpisodeState)
292
- scenarios.py - Task parameter sampling (7 tasks, deterministic per seed)
293
- pytorch_engine.py - Real PyTorch models, fault injection, gradient/weight extraction
294
- simulation.py - 20-epoch real training with fault injection
295
- reward_engine.py - 7-component per-step reward with context gating
296
- graders.py - Per-task holistic 0.0-1.0 scoring
297
- code_templates.py - Task 6 bug variants + 4-strategy fix validation
298
- client.py - Typed client extending EnvClient
299
-
300
  server/
301
- environment.py - MLTrainingEnvironment (reset/step/state)
302
- app.py - FastAPI + custom endpoints
303
- dashboard.html - Live Plotly.js diagnostic dashboard
304
-
305
- tests/ - 246 tests, 96% coverage
306
- baseline_heuristic.py - Rule-based agent (deterministic, no API key)
307
- inference.py - LLM agent (OpenAI client, hackathon format)
308
  ```
309
 
310
- **Key design decisions:**
311
- - **Grader is separate from reward function.** `reward_engine.py` returns a float per step for RL training signal. `graders.py` returns a holistic 0.0-1.0 score at episode end. They are different modules with different purposes.
312
- - **Task IDs are opaque.** `task_001` through `task_007` — the agent cannot infer the diagnosis from the ID.
313
- - **Task 6 diagnosis is always `code_bug`.** Regardless of which bug variant (eval_mode, detach_loss, zero_grad_missing, inplace_relu), the correct diagnosis is `code_bug`.
314
- - **Dual model architectures.** SimpleCNN and SimpleMLP are randomly selected per episode, testing agent robustness to architecture variation.
315
- - **Session isolation.** Each WebSocket connection gets its own environment instance with independent state.
316
- - **`step()` never raises.** All invalid actions return a valid observation with -0.05 penalty and an error note.
317
-
318
- ### Technical Stack
319
-
320
- - Python 3.12 · PyTorch 2.5.1 CPU-only · openenv-core v0.2.2
321
- - `import torch` in every core module — zero numpy in core
322
- - Typed Pydantic v2 models everywhere — no `Dict[str, Any]`
323
- - Deterministic reproducibility via `torch.manual_seed()` at every reset
324
- - Docker image: 885MB (multi-stage build, `strip --strip-unneeded`, transitive dep cleanup)
325
-
326
- ### Validation Suite
327
-
328
- 8/8 validation checks pass. Real PyTorch 20-epoch mini-training with fault injection. Each fault type is validated with behavioral checks (gradient detection, loss patterns, model mode, code fix acceptance). Both SimpleCNN and SimpleMLP architectures verified. Results served live at `GET /validation-report`.
329
-
330
- ## Walkthrough: Solving Task 5 (Hard)
331
-
332
- This is the most interesting task because it has red herrings designed to mislead.
333
-
334
- **What the agent sees on reset:**
335
- - Loss oscillates between 2.1-2.5, never converging
336
- - Val accuracy stuck at ~0.15
337
- - Error log mentions GPU memory at 91%
338
-
339
- **Step 1: Inspect gradients**
340
- ```
341
- conv1: mean_norm=0.15, Normal
342
- conv2: mean_norm=5.2, Normal
343
- conv3: mean_norm=0.8, Normal
344
- fc: mean_norm=1.3, Normal (slight spike — the red herring)
345
- ```
346
- All layers normal. `gradients_were_normal` is now True.
347
-
348
- **Step 2: Inspect data** — class overlap 0.0, data is clean.
349
-
350
- **Step 3: Inspect model modes**
351
- ```
352
- conv1: "eval" ← Problem found!
353
- bn1: "eval"
354
- conv2: "eval"
355
- bn2: "eval"
356
- fc: "eval"
357
- ```
358
- All layers stuck in eval mode. BatchNorm is using running statistics instead of batch statistics during training.
359
-
360
- **Step 4: Fix model mode** — switches all layers to train mode.
361
-
362
- **Step 5: Restart training** — convergence confirmed.
363
-
364
- **Step 6: Diagnose `batchnorm_eval_mode`** — correct. Score: 1.0.
365
 
366
- **What would have gone wrong:**
367
- If the agent had seen the FC gradient spike and called `add_callback` (gradient clipping), it would have received -0.20 context-gated penalty — because it already knew gradients were normal. The penalty only fires when both `gradients_inspected=True` and `gradients_were_normal=True`. Before inspection, the same action would have no penalty.
368
- ---
369
 
370
- *Built for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, 2026.*
 
13
 
14
  # PyTorch Training Run Debugger
15
 
16
+ An OpenEnv RL environment where AI agents debug broken PyTorch training runs.
17
 
18
+ Built for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, 2026.
19
 
20
+ [Live Demo](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/dashboard) | [API Health](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/health) | [API Docs](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/docs)
21
 
22
  ---
23
 
24
+ ## Why I Built This
25
 
26
+ Every ML engineer has been there: your model trains for hours, doesn't crash, doesn't throw errors, but the loss just won't go down. You stare at TensorBoard, tweak the learning rate, restart, repeat. It's tedious, time-consuming, and hard to teach. I wanted to turn that debugging experience into an RL environment so agents can learn to do it too.
27
 
28
+ ## How It Works
29
 
30
+ The environment drops the agent into a broken PyTorch training run. The agent sees loss curves, config, and error logs — but not much else. It has to actively investigate (inspect gradients, look at data, check model modes, read the code) to figure out what's wrong.
31
 
32
+ Once it thinks it knows the problem, it applies a fix, restarts training, and submits a diagnosis. The grader scores the whole episode: not just whether the answer was right, but whether the agent investigated properly before acting.
33
 
34
+ There are 7 tasks covering common ML failures: exploding/vanishing gradients, data leakage, overfitting, BatchNorm stuck in eval mode, bugs in the training loop, and misconfigured LR schedulers. The hard tasks have red herrings that punish agents for jumping to conclusions.
35
 
36
+ ## What's Under the Hood
37
 
38
+ - **Real PyTorch, not fake data.** Gradients come from `torch.autograd`, weights from `model.state_dict()`. The env runs actual `torch.nn.Module` models (SimpleCNN, SimpleMLP), does 20 real forward+backward passes per reset, and caches the results.
39
+ - **Context-gated rewards.** If an agent adds gradient clipping after already seeing that gradients are normal, it gets penalized. If it does it before inspecting, no penalty. The reward depends on what the agent knows, not just what it does.
40
+ - **Code-level debugging.** Task 6 presents buggy Python training loops. The agent reads the code, finds the bug, and submits a fix. Four bug variants: `model.eval()` left in, `.detach()` killing gradients, missing `zero_grad()`, and `inplace=True` on ReLU.
41
+ - **Red herrings on hard tasks.** Task 5 plants a suspicious gradient spike and a GPU memory warning. Both are distractions. The real problem is only visible through model mode inspection.
42
 
43
  ## Tasks
44
 
45
  | ID | Difficulty | Root Cause | What Goes Wrong |
46
  |----|-----------|------------|-----------------|
47
+ | `task_001` | Easy | `lr_too_high` | Gradients explode, NaN in loss |
48
+ | `task_002` | Easy | `vanishing_gradients` | Deeper layers vanish, loss stays flat |
49
+ | `task_003` | Medium | `data_leakage` | Suspiciously high val accuracy from epoch 1 |
50
+ | `task_004` | Medium | `overfitting` | Train loss drops, val loss climbs |
51
+ | `task_005` | Hard | `batchnorm_eval_mode` | Slow degradation, gradient red herrings |
52
+ | `task_006` | Hard | `code_bug` | Buggy training loop (4 variants) |
53
+ | `task_007` | Hard | `scheduler_misconfigured` | LR decays too aggressively |
54
 
55
+ Easy tasks have one obvious signal. Medium tasks need multiple inspections. Hard tasks actively mislead you.
56
 
57
+ ## Actions
58
 
59
+ **Investigate:** `inspect_gradients`, `inspect_data_batch`, `inspect_model_modes`, `inspect_model_weights`, `inspect_code`
60
 
61
+ **Fix:** `modify_config`, `add_callback`, `patch_data_loader`, `fix_model_mode`, `fix_code`, `replace_optimizer`
62
 
63
+ **Terminal:** `restart_run` (needs a fix first), `mark_diagnosed` (submit diagnosis)
64
 
65
+ Actions are dynamic — `fix_code` only unlocks after code inspection, `restart_run` only after a fix.
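The gating described in the Actions section can be sketched roughly as follows. This is only an illustration, not the environment's actual code: the function name `available_actions` and its boolean parameters are hypothetical, and only the two unlock rules stated above are modeled.

```python
# Action names are taken from the Actions section of the README.
INSPECT_ACTIONS = [
    "inspect_gradients", "inspect_data_batch", "inspect_model_modes",
    "inspect_model_weights", "inspect_code",
]
UNGATED_FIX_ACTIONS = [
    "modify_config", "add_callback", "patch_data_loader",
    "fix_model_mode", "replace_optimizer",
]


def available_actions(code_inspected: bool, fix_applied: bool) -> list[str]:
    """Hypothetical helper: list the actions an agent may take right now."""
    actions = INSPECT_ACTIONS + UNGATED_FIX_ACTIONS + ["mark_diagnosed"]
    if code_inspected:
        # fix_code only unlocks after the code has been inspected.
        actions.append("fix_code")
    if fix_applied:
        # restart_run only unlocks after some fix has been applied.
        actions.append("restart_run")
    return actions
```

An agent loop would typically check membership in this list before sending a step, since the reward table penalizes invalid actions.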
66
 
67
+ ## Reward Signal
68
 
69
+ | Event | Reward |
70
+ |-------|--------|
71
+ | Any step | -0.01 |
72
+ | First-time inspection | +0.05 |
73
+ | Correct diagnosis | +0.50 |
74
+ | Wrong diagnosis | -0.30 |
75
+ | Convergence after fix+restart | +0.40 |
76
+ | Invalid action | -0.05 |
77
+ | Context-gated penalty | -0.20 |
78
 
79
+ The context-gated penalty fires when the agent has inspected gradients, seen that they were normal, and still applies gradient clipping. It's a penalty for ignoring evidence.
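A minimal sketch of how such a gate could be computed, assuming the flat step penalty and the context-gated penalty simply add up. The `InfoState` class and `clipping_reward` function are hypothetical names for illustration. The two flags mirror the `gradients_inspected` and `gradients_were_normal` conditions named in the previous README revision, and `add_callback` is the gradient-clipping action.

```python
from dataclasses import dataclass

STEP_PENALTY = -0.01           # "Any step" row of the reward table
CONTEXT_GATED_PENALTY = -0.20  # "Context-gated penalty" row


@dataclass
class InfoState:
    """Hypothetical record of what the agent has already observed."""
    gradients_inspected: bool = False
    gradients_were_normal: bool = False


def clipping_reward(state: InfoState, action_type: str) -> float:
    """Reward for one step, considering only the two rows modeled here."""
    reward = STEP_PENALTY
    gate_open = state.gradients_inspected and state.gradients_were_normal
    if action_type == "add_callback" and gate_open:
        # The agent saw normal gradients and clipped anyway.
        reward += CONTEXT_GATED_PENALTY
    return round(reward, 2)
```

The same action yields different rewards depending on the information state, which is the point of the gate: clipping before any inspection costs only the step penalty.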
80
 
81
+ ## Grading
82
 
83
+ Each task has a holistic grader (separate from the per-step reward) that looks at the full episode: did the agent investigate the right things, apply the correct fix, restart training, and diagnose accurately? Scores are 0-1.
84
 
85
+ ## Baseline Results
86
 
87
+ | Task | Heuristic | Llama 3.1 8B |
88
+ |------|-----------|--------------|
89
+ | task_001 (Easy) | 1.00 | 0.60 |
90
+ | task_002 (Easy) | 1.00 | 0.05 |
91
+ | task_003 (Medium) | 1.00 | 0.40 |
92
+ | task_004 (Medium) | 1.00 | 0.60 |
93
+ | task_005 (Hard) | 0.80 | 0.38-0.55 |
94
+ | task_006 (Hard) | 0.81 | 0.60-1.00 |
95
+ | task_007 (Hard) | 0.79 | 0.60 |
96
+ | **Average** | **0.91** | **0.52** |
97
 
98
+ The heuristic is strong because it knows the task structure. An LLM has to figure it out from observations.
99
 
100
  ## Setup
101
 
102
  ```bash
103
+ # Local
104
+ python3 -m venv .venv && source .venv/bin/activate
 
105
  pip install torch --index-url https://download.pytorch.org/whl/cpu
106
  pip install openenv-core pydantic fastapi uvicorn
107
  uvicorn server.app:app --host 0.0.0.0 --port 7860
108
 
109
+ # Docker
110
  docker build -t pytorch-debugger .
111
  docker run -p 7860:7860 pytorch-debugger
112
 
113
+ # Baselines
114
+ python3 baseline_heuristic.py
115
+ API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o HF_TOKEN=sk-... python3 inference.py
116
  ```
117
 
118
+ ## Project Structure
119
 
120
  ```
121
  ml_training_debugger/
122
+ models.py - Data models (Action, Observation, EpisodeState)
123
+ scenarios.py - Task parameter sampling
124
+ pytorch_engine.py - Real PyTorch models and fault injection
125
+ simulation.py - 20-epoch training with fault injection
126
+ reward_engine.py - Per-step reward with context gating
127
+ graders.py - Per-task holistic scoring
128
+ code_templates.py - Task 6 bug variants + fix validation
129
  server/
130
+ environment.py - MLTrainingEnvironment (reset/step/state)
131
+ app.py - FastAPI app + endpoints
132
+ dashboard.html - Live diagnostic dashboard (Plotly.js)
133
+ inference.py - LLM agent (OpenAI client, hackathon format)
134
+ baseline_heuristic.py - Rule-based agent (no API key needed)
135
  ```
136
 
137
+ ## API
138
 
139
+ | Endpoint | Method | Description |
140
+ |----------|--------|-------------|
141
+ | `/health` | GET | Health check |
142
+ | `/tasks` | GET | Task list with action schema |
143
+ | `/grader` | POST | Score for last completed episode |
144
+ | `/baseline` | POST | Run heuristic on all tasks |
145
+ | `/dashboard` | GET | Live diagnostic dashboard |
146
+ | `/docs` | GET | Swagger UI |
147
 
148
+ WebSocket at `/ws` for full episode sessions (reset, step, observe).
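The rewritten README no longer shows the message format, so here is a small sketch of building and parsing `/ws` messages. The JSON shapes are taken from the examples in the previous README revision (the removed side of this diff). The helper names `reset_message`, `step_message`, and `parse_observation` are hypothetical, and the server may have changed since that revision.

```python
import json


def reset_message(task_id: str, seed: int) -> str:
    """Build a reset message that starts an episode for the given task."""
    return json.dumps({"type": "reset", "data": {"task_id": task_id, "seed": seed}})


def step_message(action_type: str, **fields) -> str:
    """Build a step message; extra fields (target, value, diagnosis) are passed through."""
    return json.dumps({"type": "step", "data": {"action_type": action_type, **fields}})


def parse_observation(raw: str):
    """Split a server reply into (observation, reward, done)."""
    data = json.loads(raw)["data"]
    return data["observation"], data["reward"], data["done"]
```

Sending these strings requires a WebSocket client library of your choice connected to `/ws`; that part is left out so the sketch depends only on the standard library.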