UjjwalPardeshi committed on
Commit ·
9c7163b
1
Parent(s): 9442887
rewrite readme
README.md
CHANGED
|
@@ -13,358 +13,136 @@ tags:
|
|
| 13 |
|
| 14 |
# PyTorch Training Run Debugger
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
|
| 20 |
-
|
| 21 |
|
| 22 |
---
|
| 23 |
|
| 24 |
-
##
|
| 25 |
|
| 26 |
-
ML
|
| 27 |
|
| 28 |
-
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
|
| 34 |
-
|
| 35 |
-
2. **Diagnose** — identify the root cause from 7 known ML failure types
|
| 36 |
-
3. **Fix** — apply the correct intervention
|
| 37 |
-
4. **Verify** — restart training and confirm recovery before submitting
|
| 38 |
|
| 39 |
-
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
Every gradient comes from real `torch.autograd`. Every weight stat comes from real `model.state_dict()`. The environment instantiates actual `torch.nn.Module` models (SimpleCNN ~67K params, SimpleMLP ~412K params), runs 20 real forward+backward epochs per reset, and extracts real tensor statistics. Not synthetic formulas — real PyTorch computation, cached for instant replay.
|
| 46 |
-
|
| 47 |
-
### Context-Gated Reward Shaping
|
| 48 |
-
|
| 49 |
-
Standard RL environments use stateless rewards: "did action X happen?" This environment tracks the agent's information state and conditions penalties on what the agent has already observed.
|
| 50 |
-
|
| 51 |
-
An agent that adds gradient clipping *before* inspecting gradients follows a reasonable prior — **no penalty**. An agent that inspects gradients, sees they are normal, and *then* adds gradient clipping is ignoring counter-evidence — **-0.20 penalty**.
|
| 52 |
-
|
| 53 |
-
The gate requires two conditions to be jointly true (`gradients_inspected AND gradients_were_normal`), both of which depend on prior agent actions. This encodes a transferable skill into the reward signal: don't ignore what you've already learned.
|
| 54 |
-
|
| 55 |
-
### Code-Level Debugging
|
| 56 |
-
|
| 57 |
-
Task 6 presents actual buggy PyTorch training loops. The agent reads real Python code, identifies the buggy line, and submits a line-by-line fix. Four bug variants: `model.eval()` instead of `model.train()`, `.detach()` killing gradient flow, missing `optimizer.zero_grad()`, and `inplace=True` on ReLU corrupting the computation graph.
|
| 58 |
-
|
| 59 |
-
Fix validation uses a 4-strategy pipeline: whitespace normalization, token-stream comparison via Python's `tokenize` module, semantic pattern matching, and `ast.parse()` fallback. This handles the messy fixes that LLM agents actually produce (trailing spaces, inline comments, different indentation).
|
| 60 |
-
|
| 61 |
-
### Red Herring Injection
|
| 62 |
-
|
| 63 |
-
Task 5 (BatchNorm eval mode) deliberately plants misleading signals: a gradient spike in the FC layer that doesn't cross the exploding threshold, a GPU memory warning at 91%, and near-vanishing gradients in conv1. The real problem is only visible through model mode inspection. This separates agents that follow rigid patterns from agents that can reason through ambiguity.
|
| 64 |
|
| 65 |
## Tasks
|
| 66 |
|
| 67 |
-
7 failure scenarios across 3 difficulty tiers, each with configurable difficulty level (1-5):
|
| 68 |
-
|
| 69 |
| ID | Difficulty | Root Cause | What Goes Wrong |
|
| 70 |
|----|-----------|------------|-----------------|
|
| 71 |
-
| `task_001` | Easy | `lr_too_high` |
|
| 72 |
-
| `task_002` | Easy | `vanishing_gradients` | Deeper layers
|
| 73 |
-
| `task_003` | Medium | `data_leakage` | Suspiciously high val accuracy from epoch 1
|
| 74 |
-
| `task_004` | Medium | `overfitting` | Train loss drops
|
| 75 |
-
| `task_005` | Hard | `batchnorm_eval_mode` | Slow degradation
|
| 76 |
-
| `task_006` | Hard | `code_bug` |
|
| 77 |
-
| `task_007` |
|
| 78 |
-
|
| 79 |
-
### How Difficulty Scales
|
| 80 |
-
|
| 81 |
-
Easy tasks have one obvious signal (all gradients exploding). Medium tasks require checking multiple sources and ruling out alternatives. Hard tasks deliberately mislead — the most obvious signal is wrong, and the real problem is hidden behind layers of investigation.
|
| 82 |
-
|
| 83 |
-
## Observation Space
|
| 84 |
-
|
| 85 |
-
| Field | Type | When Visible |
|
| 86 |
-
|-------|------|-------------|
|
| 87 |
-
| `training_loss_history` | `list[float]` (20 epochs) | Always |
|
| 88 |
-
| `val_accuracy_history` | `list[float]` (20 epochs) | Always |
|
| 89 |
-
| `val_loss_history` | `list[float]` (20 epochs) | Always |
|
| 90 |
-
| `current_config` | `TrainingConfig` | Always |
|
| 91 |
-
| `error_log` | `str` or `null` | Always |
|
| 92 |
-
| `gradient_stats` | `list[GradientStats]` | After `inspect_gradients` |
|
| 93 |
-
| `model_weight_stats` | `list[ModelWeightStats]` | After `inspect_model_weights` |
|
| 94 |
-
| `data_batch_stats` | `DataBatchStats` | After `inspect_data_batch` |
|
| 95 |
-
| `model_mode_info` | `dict[str, str]` | After `inspect_model_modes` |
|
| 96 |
-
| `code_snippet` | `CodeSnippet` | After `inspect_code` |
|
| 97 |
-
| `available_actions` | `list[str]` | Always (dynamic) |
|
| 98 |
-
| `episode_state` | `EpisodeState` | Always |
|
| 99 |
-
|
| 100 |
-
Fields like `gradient_stats`, `data_batch_stats`, `model_mode_info`, and `code_snippet` start as `null` and are only populated after the agent explicitly requests them. The agent must decide what to investigate.
|
| 101 |
-
|
| 102 |
-
## Action Space
|
| 103 |
-
|
| 104 |
-
13 action types in 3 categories:
|
| 105 |
-
|
| 106 |
-
**Investigation** — reveal hidden observation fields:
|
| 107 |
-
- `inspect_gradients` — per-layer gradient norms, is_exploding/is_vanishing flags
|
| 108 |
-
- `inspect_data_batch` — label distribution, class overlap score, confusion matrix
|
| 109 |
-
- `inspect_model_modes` — train/eval mode per layer
|
| 110 |
-
- `inspect_model_weights` — weight norms, dead neurons, NaN/Inf detection
|
| 111 |
-
- `inspect_code` — the actual Python training loop (Task 6)
|
| 112 |
-
|
| 113 |
-
**Fix** — apply an intervention:
|
| 114 |
-
- `modify_config` — change learning_rate, weight_decay, batch_size, optimizer, etc.
|
| 115 |
-
- `add_callback` — add gradient clipping
|
| 116 |
-
- `patch_data_loader` — fix data pipeline
|
| 117 |
-
- `fix_model_mode` — switch model to train mode
|
| 118 |
-
- `fix_code` — fix a specific line of code (requires line number + replacement)
|
| 119 |
-
- `replace_optimizer` — switch optimizer
|
| 120 |
-
|
| 121 |
-
**Terminal** — end the episode:
|
| 122 |
-
- `restart_run` — restart training (only available after a fix)
|
| 123 |
-
- `mark_diagnosed` — submit diagnosis from 7 possible root causes
|
| 124 |
-
|
| 125 |
-
Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
|
| 126 |
-
|
| 127 |
-
## Reward Function
|
| 128 |
-
|
| 129 |
-
Per-step signal, separate from the grader. Hard cap at [-1.0, 1.0].
|
| 130 |
-
|
| 131 |
-
| Event | Reward | Condition |
|
| 132 |
-
|-------|--------|-----------|
|
| 133 |
-
| Any step | -0.01 | Flat, unconditional (encourages efficiency) |
|
| 134 |
-
| First-time inspection | +0.05 | Per inspection type, first time only |
|
| 135 |
-
| Correct diagnosis | +0.50 | `diagnosis == root_cause` |
|
| 136 |
-
| Wrong diagnosis | -0.30 | `diagnosis != root_cause` |
|
| 137 |
-
| Convergence after fix+restart | +0.40 | Fix applied, restarted, training recovers |
|
| 138 |
-
| Invalid action | -0.05 | Action not in `available_actions` |
|
| 139 |
-
| Wrong code fix | -0.10 | `fix_code` with incorrect line/replacement |
|
| 140 |
-
| **Context-gated penalty** | **-0.20** | `gradients_inspected AND gradients_were_normal AND action == add_callback` |
|
| 141 |
-
|
| 142 |
-
The step penalty is flat -0.01 (never multiplied by step count). Investigation bonuses fire once per type. The context-gated penalty requires the agent to have previously inspected gradients and found them normal — it cannot fire before inspection.
|
| 143 |
-
|
| 144 |
-
## Grading
|
| 145 |
-
|
| 146 |
-
Each task has a separate grader that evaluates the complete `EpisodeState` at episode end, returning a normalized 0.0-1.0 score. The grader is **not** a sum of step rewards — it's a holistic evaluation of whether the agent investigated correctly, applied the right fix, restarted training, and diagnosed accurately.
|
| 147 |
-
|
| 148 |
-
Example (Task 5 — BatchNorm Eval):
|
| 149 |
-
|
| 150 |
-
| Component | Points |
|
| 151 |
-
|-----------|--------|
|
| 152 |
-
| Inspected gradients | +0.05 |
|
| 153 |
-
| Inspected model modes (the revealing action) | +0.05 |
|
| 154 |
-
| Fixed model mode | +0.25 |
|
| 155 |
-
| Restarted training | +0.30 |
|
| 156 |
-
| Correct diagnosis | +0.40 |
|
| 157 |
-
| Fell for red herring (add_callback after normal gradients) | -0.20 |
|
| 158 |
-
|
| 159 |
-
An agent that chases the gradient spike red herring loses 0.20 points. An agent that goes straight to model modes and finds the real problem scores 1.0.
|
| 160 |
|
| 161 |
-
|
| 162 |
|
| 163 |
-
##
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|------|-----------|-----------|--------------|
|
| 167 |
-
| `task_001` | Easy | **1.00** | 0.60 |
|
| 168 |
-
| `task_002` | Easy | **1.00** | 0.05 |
|
| 169 |
-
| `task_003` | Medium | **1.00** | 0.40 |
|
| 170 |
-
| `task_004` | Medium | **1.00** | 0.60 |
|
| 171 |
-
| `task_005` | Hard | **0.80** | 0.38-0.55 |
|
| 172 |
-
| `task_006` | Hard | **0.81** | 0.60-1.00 |
|
| 173 |
-
| `task_007` | Hard | **0.79** | 0.60 |
|
| 174 |
-
| **Average** | | **0.91** | 0.52 |
|
| 175 |
|
| 176 |
-
**
|
| 177 |
-
- **Hard tasks are genuinely hard:** All three hard tasks (5, 6, 7) require thorough investigation including weight inspection for full credit. The heuristic scores 0.79-0.81 on hard tasks because it skips weight inspection. An LLM that falls for red herrings or skips investigation scores even lower.
|
| 178 |
-
- **Red herring traps work:** Task 5 penalizes agents that call `add_callback` after seeing normal gradients (-0.20) or `modify_config` when LR isn't the issue (-0.10). LLMs routinely fall for both traps.
|
| 179 |
-
- **Investigation thoroughness matters:** Tasks 6 and 7 scale fix/restart credit based on how thoroughly the agent investigated before acting. Quick fixes without ruling out alternatives score ~60-65% of full credit.
|
| 180 |
-
- **8B struggles on multi-step tasks:** Task 2 score of 0.05 shows small models can't maintain investigation strategy across many steps.
|
| 181 |
-
- **The heuristic baseline is strong** because it was designed with knowledge of the task structure. An agent that doesn't know the structure has to figure it out from observations alone.
|
| 182 |
|
| 183 |
-
|
| 184 |
|
| 185 |
-
```
|
| 186 |
-
# Heuristic (deterministic, no API key, bit-exact reproducible)
|
| 187 |
-
python3 baseline_heuristic.py
|
| 188 |
-
|
| 189 |
-
# LLM (hackathon evaluator format — uses OpenAI client)
|
| 190 |
-
API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o OPENAI_API_KEY=sk-... python3 inference.py
|
| 191 |
-
```
|
| 192 |
-
|
| 193 |
-
## API
|
| 194 |
-
|
| 195 |
-
### HTTP Endpoints
|
| 196 |
-
|
| 197 |
-
| Endpoint | Method | Description |
|
| 198 |
-
|----------|--------|-------------|
|
| 199 |
-
| `/health` | GET | `{"status": "healthy", "tasks": 7}` |
|
| 200 |
-
| `/tasks` | GET | Task list with IDs, difficulties, action schema |
|
| 201 |
-
| `/grader` | POST | Score for last completed episode |
|
| 202 |
-
| `/baseline` | POST | Run heuristic on all tasks, return scores |
|
| 203 |
-
| `/dashboard` | GET | Live 4-panel diagnostic dashboard |
|
| 204 |
-
| `/validation-report` | GET | Simulation fidelity report |
|
| 205 |
-
| `/curriculum` | GET | Recommended task order (easy to hard, difficulty 1-5) |
|
| 206 |
-
| `/leaderboard` | GET | Sorted episode scores |
|
| 207 |
-
| `/replay/{id}` | GET | Full action/observation trace for an episode |
|
| 208 |
-
| `/docs` | GET | Swagger UI (auto-generated by FastAPI) |
|
| 209 |
|
| 210 |
-
##
|
| 211 |
|
| 212 |
-
|
| 213 |
|
| 214 |
-
|
| 215 |
-
```json
|
| 216 |
-
{"type": "reset", "data": {"task_id": "task_003", "seed": 42}}
|
| 217 |
-
```
|
| 218 |
-
|
| 219 |
-
**Step** (take an action):
|
| 220 |
-
```json
|
| 221 |
-
{"type": "step", "data": {"action_type": "inspect_gradients"}}
|
| 222 |
-
{"type": "step", "data": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
|
| 223 |
-
{"type": "step", "data": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
|
| 224 |
-
```
|
| 225 |
|
| 226 |
-
|
| 227 |
-
```json
|
| 228 |
-
{"type": "observation", "data": {"observation": {...}, "reward": 0.04, "done": false}}
|
| 229 |
-
```
|
| 230 |
|
| 231 |
-
|
| 232 |
|
| 233 |
-
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
|
| 240 |
-
|
| 241 |
|
| 242 |
## Setup
|
| 243 |
|
| 244 |
-
### Local Development
|
| 245 |
-
|
| 246 |
```bash
|
| 247 |
-
|
| 248 |
-
source .venv/bin/activate
|
| 249 |
-
|
| 250 |
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 251 |
pip install openenv-core pydantic fastapi uvicorn
|
| 252 |
-
pip install pytest pytest-cov pytest-asyncio httpx websockets
|
| 253 |
-
|
| 254 |
-
# Start server
|
| 255 |
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 256 |
|
| 257 |
-
#
|
| 258 |
-
pytest tests/ -v --cov=ml_training_debugger
|
| 259 |
-
|
| 260 |
-
# Run heuristic baseline
|
| 261 |
-
python3 baseline_heuristic.py
|
| 262 |
-
```
|
| 263 |
-
|
| 264 |
-
### Docker
|
| 265 |
-
|
| 266 |
-
```bash
|
| 267 |
docker build -t pytorch-debugger .
|
| 268 |
docker run -p 7860:7860 pytorch-debugger
|
| 269 |
-
curl http://localhost:7860/health
|
| 270 |
-
```
|
| 271 |
-
|
| 272 |
-
### Smoke Test
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
curl http://localhost:7860/tasks | python3 -m json.tool
|
| 278 |
-
curl -X POST http://localhost:7860/baseline | python3 -m json.tool
|
| 279 |
-
curl -X POST http://localhost:7860/grader | python3 -m json.tool
|
| 280 |
-
|
| 281 |
-
# Reproducibility (must produce no diff)
|
| 282 |
-
python3 baseline_heuristic.py > run1.json
|
| 283 |
-
python3 baseline_heuristic.py > run2.json
|
| 284 |
-
diff run1.json run2.json
|
| 285 |
```
|
| 286 |
|
| 287 |
-
##
|
| 288 |
|
| 289 |
```
|
| 290 |
ml_training_debugger/
|
| 291 |
-
models.py
|
| 292 |
-
scenarios.py
|
| 293 |
-
pytorch_engine.py
|
| 294 |
-
simulation.py
|
| 295 |
-
reward_engine.py
|
| 296 |
-
graders.py
|
| 297 |
-
code_templates.py
|
| 298 |
-
client.py — Typed client extending EnvClient
|
| 299 |
-
|
| 300 |
server/
|
| 301 |
-
environment.py
|
| 302 |
-
app.py
|
| 303 |
-
dashboard.html
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
baseline_heuristic.py — Rule-based agent (deterministic, no API key)
|
| 307 |
-
inference.py — LLM agent (OpenAI client, hackathon format)
|
| 308 |
```
|
| 309 |
|
| 310 |
-
|
| 311 |
-
- **Grader is separate from reward function.** `reward_engine.py` returns a float per step for RL training signal. `graders.py` returns a holistic 0.0-1.0 score at episode end. They are different modules with different purposes.
|
| 312 |
-
- **Task IDs are opaque.** `task_001` through `task_007` — the agent cannot infer the diagnosis from the ID.
|
| 313 |
-
- **Task 6 diagnosis is always `code_bug`.** Regardless of which bug variant (eval_mode, detach_loss, zero_grad_missing, inplace_relu), the correct diagnosis is `code_bug`.
|
| 314 |
-
- **Dual model architectures.** SimpleCNN and SimpleMLP are randomly selected per episode, testing agent robustness to architecture variation.
|
| 315 |
-
- **Session isolation.** Each WebSocket connection gets its own environment instance with independent state.
|
| 316 |
-
- **`step()` never raises.** All invalid actions return a valid observation with -0.05 penalty and an error note.
|
| 317 |
-
|
| 318 |
-
### Technical Stack
|
| 319 |
-
|
| 320 |
-
- Python 3.12 · PyTorch 2.5.1 CPU-only · openenv-core v0.2.2
|
| 321 |
-
- `import torch` in every core module — zero numpy in core
|
| 322 |
-
- Typed Pydantic v2 models everywhere — no `Dict[str, Any]`
|
| 323 |
-
- Deterministic reproducibility via `torch.manual_seed()` at every reset
|
| 324 |
-
- Docker image: 885MB (multi-stage build, `strip --strip-unneeded`, transitive dep cleanup)
|
| 325 |
-
|
| 326 |
-
### Validation Suite
|
| 327 |
-
|
| 328 |
-
8/8 validation checks pass. Real PyTorch 20-epoch mini-training with fault injection. Each fault type is validated with behavioral checks (gradient detection, loss patterns, model mode, code fix acceptance). Both SimpleCNN and SimpleMLP architectures verified. Results served live at `GET /validation-report`.
|
| 329 |
-
|
| 330 |
-
## Walkthrough: Solving Task 5 (Hard)
|
| 331 |
-
|
| 332 |
-
This is the most interesting task because it has red herrings designed to mislead.
|
| 333 |
-
|
| 334 |
-
**What the agent sees on reset:**
|
| 335 |
-
- Loss oscillates between 2.1-2.5, never converging
|
| 336 |
-
- Val accuracy stuck at ~0.15
|
| 337 |
-
- Error log mentions GPU memory at 91%
|
| 338 |
-
|
| 339 |
-
**Step 1: Inspect gradients**
|
| 340 |
-
```
|
| 341 |
-
conv1: mean_norm=0.15, Normal
|
| 342 |
-
conv2: mean_norm=5.2, Normal
|
| 343 |
-
conv3: mean_norm=0.8, Normal
|
| 344 |
-
fc: mean_norm=1.3, Normal (slight spike — the red herring)
|
| 345 |
-
```
|
| 346 |
-
All layers normal. `gradients_were_normal` is now True.
|
| 347 |
-
|
| 348 |
-
**Step 2: Inspect data** — class overlap 0.0, data is clean.
|
| 349 |
-
|
| 350 |
-
**Step 3: Inspect model modes**
|
| 351 |
-
```
|
| 352 |
-
conv1: "eval" ← Problem found!
|
| 353 |
-
bn1: "eval"
|
| 354 |
-
conv2: "eval"
|
| 355 |
-
bn2: "eval"
|
| 356 |
-
fc: "eval"
|
| 357 |
-
```
|
| 358 |
-
All layers stuck in eval mode. BatchNorm is using running statistics instead of batch statistics during training.
|
| 359 |
-
|
| 360 |
-
**Step 4: Fix model mode** — switches all layers to train mode.
|
| 361 |
-
|
| 362 |
-
**Step 5: Restart training** — convergence confirmed.
|
| 363 |
-
|
| 364 |
-
**Step 6: Diagnose `batchnorm_eval_mode`** — correct. Score: 1.0.
|
| 365 |
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
|
| 370 |
-
|
|
|
|
| 13 |
|
| 14 |
# PyTorch Training Run Debugger
|
| 15 |
|
| 16 |
+
An OpenEnv RL environment where AI agents debug broken PyTorch training runs.
|
| 17 |
|
| 18 |
+
Built for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, 2026.
|
| 19 |
|
| 20 |
+
[Live Demo](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/dashboard) | [API Health](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/health) | [API Docs](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/docs)
|
| 21 |
|
| 22 |
---
|
| 23 |
|
| 24 |
+
## Why I Built This
|
| 25 |
|
| 26 |
+
Every ML engineer has been there: your model trains for hours, doesn't crash, doesn't throw errors, but the loss just won't go down. You stare at TensorBoard, tweak the learning rate, restart, repeat. It's tedious, time-consuming, and hard to teach. I wanted to turn that debugging experience into an RL environment so agents can learn to do it too.
|
| 27 |
|
| 28 |
+
## How It Works
|
| 29 |
|
| 30 |
+
The environment drops the agent into a broken PyTorch training run. The agent sees loss curves, config, and error logs — but not much else. It has to actively investigate (inspect gradients, look at data, check model modes, read the code) to figure out what's wrong.
|
| 31 |
|
| 32 |
+
Once it thinks it knows the problem, it applies a fix, restarts training, and submits a diagnosis. The grader scores the whole episode — not just whether the answer was right, but whether the agent investigated properly before acting.
|
| 33 |
|
| 34 |
+
There are 7 tasks covering common ML failures: exploding/vanishing gradients, data leakage, overfitting, BatchNorm stuck in eval mode, bugs in the training loop, and misconfigured LR schedulers. The hard tasks have red herrings that punish agents for jumping to conclusions.
|
| 35 |
|
| 36 |
+
## What's Under the Hood
|
| 37 |
|
| 38 |
+
- **Real PyTorch, not fake data.** Gradients come from `torch.autograd`, weights from `model.state_dict()`. The env instantiates actual `torch.nn.Module` models (SimpleCNN, SimpleMLP), runs 20 real training epochs (forward + backward) per reset, and caches the results.
|
| 39 |
+
- **Context-gated rewards.** If an agent adds gradient clipping after already seeing that gradients are normal, it gets penalized. If it does it before inspecting, no penalty. The reward depends on what the agent knows, not just what it does.
|
| 40 |
+
- **Code-level debugging.** Task 6 presents buggy Python training loops. The agent reads the code, finds the bug, and submits a fix. Four bug variants: `model.eval()` left in, `.detach()` killing gradients, missing `zero_grad()`, and `inplace=True` on ReLU.
|
| 41 |
+
- **Red herrings on hard tasks.** Task 5 plants a suspicious gradient spike and a GPU memory warning. Both are distractions. The real problem is only visible through model mode inspection.
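
For the code-level debugging task, a submitted fix has to be compared against the expected line while tolerating cosmetic differences such as indentation, trailing spaces, or inline comments. The earlier README revision in this diff lists token-stream comparison via Python's `tokenize` module as one of its validation strategies. The snippet below is a minimal sketch of that single idea, not the project's actual validator; the function names `token_stream` and `same_fix` are hypothetical.

```python
import io
import tokenize

# Token types that carry no meaning for a single-line fix comparison.
_IGNORED = {
    tokenize.COMMENT,
    tokenize.NL,
    tokenize.NEWLINE,
    tokenize.INDENT,
    tokenize.DEDENT,
    tokenize.ENDMARKER,
}


def token_stream(line: str) -> list[str]:
    """Return the meaningful token strings of one line of Python code."""
    source = line.strip() + "\n"  # drop indentation so no INDENT error can occur
    tokens = tokenize.generate_tokens(io.StringIO(source).readline)
    return [tok.string for tok in tokens if tok.type not in _IGNORED]


def same_fix(submitted: str, expected: str) -> bool:
    """True if both lines tokenize to the same sequence (hypothetical helper)."""
    try:
        return token_stream(submitted) == token_stream(expected)
    except (tokenize.TokenError, SyntaxError):
        # Unparseable submissions (e.g. an unclosed parenthesis) never match.
        return False
```

With this sketch, `same_fix("    model.train()   # switch back", "model.train()")` is true, while `same_fix("model.eval()", "model.train()")` is false. Comparing tokens rather than raw strings is what lets spacing and comments vary freely.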
|
| 42 |
|
| 43 |
## Tasks
|
| 44 |
|
|
|
|
|
|
|
| 45 |
| ID | Difficulty | Root Cause | What Goes Wrong |
|
| 46 |
|----|-----------|------------|-----------------|
|
| 47 |
+
| `task_001` | Easy | `lr_too_high` | Gradients explode, NaN in loss |
|
| 48 |
+
| `task_002` | Easy | `vanishing_gradients` | Deeper layers vanish, loss stays flat |
|
| 49 |
+
| `task_003` | Medium | `data_leakage` | Suspiciously high val accuracy from epoch 1 |
|
| 50 |
+
| `task_004` | Medium | `overfitting` | Train loss drops, val loss climbs |
|
| 51 |
+
| `task_005` | Hard | `batchnorm_eval_mode` | Slow degradation, gradient red herrings |
|
| 52 |
+
| `task_006` | Hard | `code_bug` | Buggy training loop (4 variants) |
|
| 53 |
+
| `task_007` | Hard | `scheduler_misconfigured` | LR decays too aggressively |
|
| 54 |
|
| 55 |
+
Easy tasks have one obvious signal. Medium tasks need multiple inspections. Hard tasks actively mislead the agent.
|
| 56 |
|
| 57 |
+
## Actions
|
| 58 |
|
| 59 |
+
**Investigate:** `inspect_gradients`, `inspect_data_batch`, `inspect_model_modes`, `inspect_model_weights`, `inspect_code`
|
| 60 |
|
| 61 |
+
**Fix:** `modify_config`, `add_callback`, `patch_data_loader`, `fix_model_mode`, `fix_code`, `replace_optimizer`
|
| 62 |
|
| 63 |
+
**Terminal:** `restart_run` (needs a fix first), `mark_diagnosed` (submit diagnosis)
|
| 64 |
|
| 65 |
+
Actions are dynamic — `fix_code` only unlocks after code inspection, `restart_run` only after a fix.
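
The gating rules can be pictured as a small function over the episode state. This is a sketch under stated assumptions, not the environment's real code: the function name and its three boolean flags are hypothetical, and the rule that `mark_diagnosed` disappears after submission is taken from the earlier README revision in this diff.

```python
INVESTIGATE = [
    "inspect_gradients",
    "inspect_data_batch",
    "inspect_model_modes",
    "inspect_model_weights",
    "inspect_code",
]
FIXES = [
    "modify_config",
    "add_callback",
    "patch_data_loader",
    "fix_model_mode",
    "fix_code",
    "replace_optimizer",
]


def available_actions(code_inspected: bool, fix_applied: bool, diagnosed: bool) -> list[str]:
    """Compute the action list for the current episode state (hypothetical flags)."""
    actions = list(INVESTIGATE)  # investigation is always allowed
    for fix in FIXES:
        if fix == "fix_code" and not code_inspected:
            continue  # fix_code unlocks only after inspect_code
        actions.append(fix)
    if fix_applied:
        actions.append("restart_run")  # restart requires a prior fix
    if not diagnosed:
        actions.append("mark_diagnosed")  # assumed to vanish once submitted
    return actions
```

When every gate is open and no diagnosis has been submitted, the list contains all 13 action types named above.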
|
| 66 |
|
| 67 |
+
## Reward Signal
|
| 68 |
|
| 69 |
+
| Event | Reward |
|
| 70 |
+
|-------|--------|
|
| 71 |
+
| Any step | -0.01 |
|
| 72 |
+
| First-time inspection | +0.05 |
|
| 73 |
+
| Correct diagnosis | +0.50 |
|
| 74 |
+
| Wrong diagnosis | -0.30 |
|
| 75 |
+
| Convergence after fix+restart | +0.40 |
|
| 76 |
+
| Invalid action | -0.05 |
|
| 77 |
+
| Context-gated penalty | -0.20 |
|
| 78 |
|
| 79 |
+
The context-gated penalty fires when the agent has inspected gradients, seen that they were normal, and still applies gradient clipping (`add_callback`). It penalizes ignoring evidence the agent already has.
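
To make the gate concrete, here is a minimal sketch of a per-step reward that tracks the agent's information state. It covers only three rows of the table (step cost, first-time inspection bonus, context-gated penalty). The `InfoState` class, the `step_reward` signature, and the `gradients_normal` argument are assumptions for illustration; the flag names and the clamp to [-1.0, 1.0] follow the earlier README revision in this diff.

```python
from dataclasses import dataclass, field

STEP_PENALTY = -0.01
FIRST_INSPECTION_BONUS = 0.05
CONTEXT_GATED_PENALTY = -0.20


@dataclass
class InfoState:
    """What the agent has observed so far (hypothetical structure)."""
    gradients_inspected: bool = False
    gradients_were_normal: bool = False
    seen_inspections: set = field(default_factory=set)


def step_reward(state: InfoState, action: str, gradients_normal: bool = True) -> float:
    """Return the reward for one action and update the information state."""
    reward = STEP_PENALTY  # flat cost on every step
    if action.startswith("inspect_"):
        if action not in state.seen_inspections:
            state.seen_inspections.add(action)
            reward += FIRST_INSPECTION_BONUS  # paid once per inspection type
        if action == "inspect_gradients":
            state.gradients_inspected = True
            state.gradients_were_normal = gradients_normal
    elif action == "add_callback":
        # The gate: clipping is penalized only after normal gradients were seen.
        if state.gradients_inspected and state.gradients_were_normal:
            reward += CONTEXT_GATED_PENALTY
    # Round away float noise, then clamp to the documented range.
    return max(-1.0, min(1.0, round(reward, 2)))
```

Calling `add_callback` on a fresh `InfoState` costs only the step penalty, while the same action after a normal `inspect_gradients` result costs the step penalty plus the gated penalty. The reward depends on the agent's history, not just on the action.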
|
| 80 |
|
| 81 |
+
## Grading
|
| 82 |
|
| 83 |
+
Each task has a holistic grader (separate from the per-step reward) that looks at the full episode: did the agent investigate the right things, apply the correct fix, restart training, and diagnose accurately? Scores are 0-1.
|
| 84 |
|
| 85 |
+
## Baseline Results
|
| 86 |
|
| 87 |
+
| Task | Heuristic | Llama 3.1 8B |
|
| 88 |
+
|------|-----------|--------------|
|
| 89 |
+
| task_001 (Easy) | 1.00 | 0.60 |
|
| 90 |
+
| task_002 (Easy) | 1.00 | 0.05 |
|
| 91 |
+
| task_003 (Medium) | 1.00 | 0.40 |
|
| 92 |
+
| task_004 (Medium) | 1.00 | 0.60 |
|
| 93 |
+
| task_005 (Hard) | 0.80 | 0.38-0.55 |
|
| 94 |
+
| task_006 (Hard) | 0.81 | 0.60-1.00 |
|
| 95 |
+
| task_007 (Hard) | 0.79 | 0.60 |
|
| 96 |
+
| **Average** | **0.91** | **0.52** |
|
| 97 |
|
| 98 |
+
The heuristic baseline scores well because it was written with knowledge of the task structure. An LLM agent has to infer that structure from observations alone.
|
| 99 |
|
| 100 |
## Setup
|
| 101 |
|
|
|
|
|
|
|
| 102 |
```bash
|
| 103 |
+
# Local
|
| 104 |
+
python3 -m venv .venv && source .venv/bin/activate
|
|
|
|
| 105 |
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 106 |
pip install openenv-core pydantic fastapi uvicorn
|
| 107 |
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 108 |
|
| 109 |
+
# Docker
|
| 110 |
docker build -t pytorch-debugger .
|
| 111 |
docker run -p 7860:7860 pytorch-debugger
|
| 112 |
|
| 113 |
+
# Baselines
|
| 114 |
+
python3 baseline_heuristic.py
|
| 115 |
+
API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o HF_TOKEN=sk-... python3 inference.py
|
| 116 |
```
|
| 117 |
|
| 118 |
+
## Project Structure
|
| 119 |
|
| 120 |
```
|
| 121 |
ml_training_debugger/
|
| 122 |
+
models.py - Data models (Action, Observation, EpisodeState)
|
| 123 |
+
scenarios.py - Task parameter sampling
|
| 124 |
+
pytorch_engine.py - Real PyTorch models and fault injection
|
| 125 |
+
simulation.py - 20-epoch training with fault injection
|
| 126 |
+
reward_engine.py - Per-step reward with context gating
|
| 127 |
+
graders.py - Per-task holistic scoring
|
| 128 |
+
code_templates.py - Task 6 bug variants + fix validation
|
|
|
|
|
|
|
| 129 |
server/
|
| 130 |
+
environment.py - MLTrainingEnvironment (reset/step/state)
|
| 131 |
+
app.py - FastAPI app + endpoints
|
| 132 |
+
dashboard.html - Live diagnostic dashboard (Plotly.js)
|
| 133 |
+
inference.py - LLM agent (OpenAI client, hackathon format)
|
| 134 |
+
baseline_heuristic.py - Rule-based agent (no API key needed)
|
|
|
|
|
|
|
| 135 |
```
|
| 136 |
|
| 137 |
+
## API
|
| 138 |
|
| 139 |
+
| Endpoint | Method | Description |
|
| 140 |
+
|----------|--------|-------------|
|
| 141 |
+
| `/health` | GET | Health check |
|
| 142 |
+
| `/tasks` | GET | Task list with action schema |
|
| 143 |
+
| `/grader` | POST | Score for last completed episode |
|
| 144 |
+
| `/baseline` | POST | Run heuristic on all tasks |
|
| 145 |
+
| `/dashboard` | GET | Live diagnostic dashboard |
|
| 146 |
+
| `/docs` | GET | Swagger UI |
|
| 147 |
|
| 148 |
+
WebSocket at `/ws` for full episode sessions (reset, step, observe).
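
A client talking to `/ws` mostly needs to build and parse small JSON messages. The helpers below are a sketch that assumes the message shapes shown in the earlier README revision in this diff (`reset`, `step`, and `observation` envelopes with a `data` payload) still apply. The helper names are hypothetical, and the actual transport (opening the socket, sending, receiving) is left out.

```python
import json


def reset_message(task_id: str, seed: int) -> str:
    """Build the message that starts a new episode."""
    return json.dumps({"type": "reset", "data": {"task_id": task_id, "seed": seed}})


def step_message(action_type: str, **params) -> str:
    """Build a step message; extra fields such as target/value go in params."""
    return json.dumps({"type": "step", "data": {"action_type": action_type, **params}})


def parse_observation(raw: str) -> tuple[dict, float, bool]:
    """Split a server reply into (observation, reward, done)."""
    msg = json.loads(raw)
    if msg.get("type") != "observation":
        raise ValueError(f"unexpected message type: {msg.get('type')!r}")
    data = msg["data"]
    return data["observation"], float(data["reward"]), bool(data["done"])
```

For example, `step_message("modify_config", target="learning_rate", value=0.001)` yields a payload in the same shape as the step examples in the earlier revision. Each reply can then be unpacked with `parse_observation` before choosing the next action.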
|