---
title: PyTorch Training Run Debugger
emoji: 🔧
colorFrom: purple
colorTo: blue
sdk: docker
app_port: 7860
tags:
  - openenv
  - pytorch
  - reinforcement-learning
---

# PyTorch Training Run Debugger

An OpenEnv RL environment where AI agents debug broken PyTorch training runs.

Built for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, 2026.

[Live Demo](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/dashboard) | [API Health](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/health) | [API Docs](https://ujjwalpardeshi-pytorch-training-debugger.hf.space/docs)

---

## Why I Built This

Every ML engineer has been there: your model trains for hours, doesn't crash, doesn't throw errors, but the loss just won't go down. You stare at TensorBoard, tweak the learning rate, restart, repeat. It's tedious, time-consuming, and hard to teach. I wanted to turn that debugging experience into an RL environment so agents can learn to do it too.

## How It Works

The environment drops the agent into a broken PyTorch training run. The agent sees loss curves, config, and error logs — but not much else. It has to actively investigate (inspect gradients, look at data, check model modes, read the code) to figure out what's wrong.

Once it thinks it knows the problem, it applies a fix, restarts training, and submits a diagnosis. The grader scores the whole episode — not just whether the answer was right, but whether the agent investigated properly before acting.

There are 7 tasks covering common ML failures: exploding/vanishing gradients, data leakage, overfitting, BatchNorm stuck in eval mode, bugs in the training loop, and misconfigured LR schedulers. The hard tasks have red herrings that punish agents for jumping to conclusions.

## What's Under the Hood

- **Real PyTorch, not fake data.** Gradients come from `torch.autograd`, weights from `model.state_dict()`. The env runs actual `torch.nn.Module` models (SimpleCNN, SimpleMLP), does 20 real forward+backward passes per reset, and caches the results.
- **Context-gated rewards.** If an agent adds gradient clipping after already seeing that gradients are normal, it gets penalized. If it adds clipping before inspecting gradients, there is no penalty. The reward depends on what the agent knows, not just what it does.
- **Code-level debugging.** Task 6 presents buggy Python training loops. The agent reads the code, finds the bug, and submits a fix. Four bug variants: `model.eval()` left in, `.detach()` killing gradients, missing `zero_grad()`, and `inplace=True` on ReLU.
- **Red herrings on hard tasks.** Task 5 plants a suspicious gradient spike and a GPU memory warning. Both are distractions. The real problem is only visible through model mode inspection.
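To make the first and third bullets concrete, here is a small sketch of reading real `.grad` tensors after `loss.backward()` and of what the `.detach()` variant does to them. The two-layer `nn.Sequential` model, the `grad_norms` and `step` helpers, and the random data are hypothetical stand-ins, not the repo's `SimpleMLP` or `code_templates.py`; the comments only mark where each of the four Task 6 bugs would sit in an otherwise healthy step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model. Bug variant: nn.ReLU(inplace=True) here instead.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
loss_fn = nn.CrossEntropyLoss()
x, y = torch.randn(32, 8), torch.randint(0, 2, (32,))


def grad_norms(module):
    """L2 norm of each parameter's .grad, or None if autograd never reached it."""
    return {
        name: (p.grad.norm().item() if p.grad is not None else None)
        for name, p in module.named_parameters()
    }


def step(detach_bug=False):
    model.train()  # bug variant: model.eval() left in the loop
    model.zero_grad(set_to_none=True)  # bug variant: zero_grad() missing, grads accumulate
    hidden = model[:2](x)  # first Linear + ReLU
    if detach_bug:
        hidden = hidden.detach()  # bug variant: cuts the graph below this point
    loss = loss_fn(model[2](hidden), y)
    loss.backward()  # real gradients from torch.autograd
    return grad_norms(model)


healthy = step()
broken = step(detach_bug=True)
print(healthy["0.weight"])  # a positive float
print(broken["0.weight"])   # None: the first layer received no gradient
```

Because this toy model has no Dropout or BatchNorm layers, the `eval()` variant would change nothing here; it only matters in models that contain such layers.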

## Tasks

| ID | Difficulty | Root Cause | What Goes Wrong |
|----|-----------|------------|-----------------|
| `task_001` | Easy | `lr_too_high` | Gradients explode, NaN in loss |
| `task_002` | Easy | `vanishing_gradients` | Gradients vanish in deeper layers, loss stays flat |
| `task_003` | Medium | `data_leakage` | Suspiciously high val accuracy from epoch 1 |
| `task_004` | Medium | `overfitting` | Train loss drops, val loss climbs |
| `task_005` | Hard | `batchnorm_eval_mode` | Slow degradation, gradient red herrings |
| `task_006` | Hard | `code_bug` | Buggy training loop (4 variants) |
| `task_007` | Hard | `scheduler_misconfigured` | LR decays too aggressively |

Easy tasks have one obvious signal. Medium tasks need multiple inspections. Hard tasks actively mislead you.
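The "What Goes Wrong" column suggests a cheap first pass an agent could run on the curves before spending steps on inspections. This sketch is hypothetical: `curve_symptoms`, its thresholds, and the list-of-floats curve format are assumptions, not the environment's observation schema. It only proposes candidates for the easy and medium tasks; the hard tasks need the inspection actions and are built to mislead exactly this kind of shortcut.

```python
import math


def curve_symptoms(train_loss, val_loss, val_acc=None, flat_tol=0.01, leak_acc=0.95):
    """Candidate root causes suggested by the curves alone (thresholds are hypothetical)."""
    if any(not math.isfinite(v) for v in train_loss):
        return ["lr_too_high"]  # NaN/inf loss: nothing else on the curve is trustworthy
    candidates = []
    if max(train_loss) - min(train_loss) < flat_tol:
        candidates.append("vanishing_gradients")  # loss barely moves
    if val_acc and val_acc[0] >= leak_acc:
        candidates.append("data_leakage")  # near-perfect val accuracy from epoch 1
    best = val_loss.index(min(val_loss))
    if train_loss[-1] < train_loss[best] and val_loss[-1] > val_loss[best]:
        candidates.append("overfitting")  # train keeps falling after val bottoms out
    return candidates


print(curve_symptoms([1.0, 0.5, 0.2, 0.1], [1.0, 0.8, 0.9, 1.1]))  # ['overfitting']
```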

## Actions

**Investigate:** `inspect_gradients`, `inspect_data_batch`, `inspect_model_modes`, `inspect_model_weights`, `inspect_code`

**Fix:** `modify_config`, `add_callback`, `patch_data_loader`, `fix_model_mode`, `fix_code`, `replace_optimizer`

**Terminal:** `restart_run` (needs a fix first), `mark_diagnosed` (submit diagnosis)

Actions are dynamic — `fix_code` only unlocks after code inspection, `restart_run` only after a fix.
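The unlocking rules above can be sketched as a function of the action history. This is a hypothetical model of the rules as described, not the code in `server/environment.py`; in particular it assumes every other fix action is available from the first step and that `mark_diagnosed` can be submitted at any time.

```python
INVESTIGATE = {"inspect_gradients", "inspect_data_batch", "inspect_model_modes",
               "inspect_model_weights", "inspect_code"}
FIX = {"modify_config", "add_callback", "patch_data_loader", "fix_model_mode",
       "fix_code", "replace_optimizer"}


def available_actions(history):
    """Actions an agent may take next, given the names of actions already taken."""
    allowed = INVESTIGATE | (FIX - {"fix_code"}) | {"mark_diagnosed"}
    if "inspect_code" in history:
        allowed.add("fix_code")  # code fixes unlock only after reading the code
    if any(action in FIX for action in history):
        allowed.add("restart_run")  # restarting only makes sense after some fix
    return allowed


print("restart_run" in available_actions(["inspect_gradients"]))  # False
print("restart_run" in available_actions(["modify_config"]))      # True
```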

## Reward Signal

| Event | Reward |
|-------|--------|
| Any step | -0.01 |
| First-time inspection | +0.05 |
| Correct diagnosis | +0.50 |
| Wrong diagnosis | -0.30 |
| Convergence after fix+restart | +0.40 |
| Invalid action | -0.05 |
| Context-gated penalty | -0.20 |

The context-gated penalty fires when the agent has inspected gradients, seen that they are normal, and still applies gradient clipping. It penalizes ignoring evidence.
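A minimal sketch of how the table and the gating rule could combine into one per-step reward. It assumes the step cost is added on every step, including invalid ones, and that gradient clipping arrives as an `add_callback` action with `{"callback": "gradient_clipping"}`; `Knowledge`, `step_reward`, and that action shape are hypothetical, and the real `reward_engine.py` may differ.

```python
from dataclasses import dataclass, field

# Values from the Reward Signal table.
STEP_COST, FIRST_INSPECTION, INVALID_ACTION = -0.01, 0.05, -0.05
CORRECT_DIAGNOSIS, WRONG_DIAGNOSIS = 0.50, -0.30
CONVERGENCE, CONTEXT_PENALTY = 0.40, -0.20


@dataclass
class Knowledge:
    """Hypothetical record of what the agent has seen so far in the episode."""
    inspected: set = field(default_factory=set)
    saw_normal_gradients: bool = False


def step_reward(action, k, *, valid=True, gradients_normal=False,
                converged=False, diagnosis_correct=False):
    reward = STEP_COST  # every step costs something
    if not valid:
        return reward + INVALID_ACTION
    kind = action["type"]
    if kind.startswith("inspect_"):
        if kind not in k.inspected:  # only the first inspection of each kind pays
            reward += FIRST_INSPECTION
            k.inspected.add(kind)
        if kind == "inspect_gradients" and gradients_normal:
            k.saw_normal_gradients = True
    elif kind == "add_callback" and action.get("callback") == "gradient_clipping":
        if k.saw_normal_gradients:  # context gate: the evidence said clipping is unneeded
            reward += CONTEXT_PENALTY
    elif kind == "restart_run" and converged:
        reward += CONVERGENCE
    elif kind == "mark_diagnosed":
        reward += CORRECT_DIAGNOSIS if diagnosis_correct else WRONG_DIAGNOSIS
    return reward


k = Knowledge()
print(round(step_reward({"type": "inspect_gradients"}, k, gradients_normal=True), 2))  # 0.04
print(round(step_reward({"type": "add_callback", "callback": "gradient_clipping"}, k), 2))  # -0.21
```

The second call shows the gate at work: the same clipping action costs only the step penalty when the agent has not yet seen normal gradients.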

## Grading

Each task has a holistic grader (separate from the per-step reward) that looks at the full episode: did the agent investigate the right things, apply the correct fix, restart training, and diagnose accurately? Scores range from 0 to 1.
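The four questions above map naturally onto a score in [0, 1]. The sketch below weights them equally; `grade_episode`, the equal weights, and its inputs are hypothetical, so it will not reproduce the exact numbers in the baseline table. The example assumes, for illustration only, that `modify_config` is the correct fix for a task_001-style episode.

```python
def grade_episode(actions, diagnosis, *, required_inspections, correct_fix, root_cause):
    """Hypothetical holistic grade: four equally weighted criteria, result in [0, 1]."""
    inspected = {a for a in actions if a.startswith("inspect_")}
    investigated = (len(inspected & required_inspections) / len(required_inspections)
                    if required_inspections else 1.0)
    fixed = correct_fix in actions
    # Training must be restarted after the correct fix for the restart to count.
    restarted = fixed and "restart_run" in actions[actions.index(correct_fix):]
    diagnosed = diagnosis == root_cause
    return round(0.25 * (investigated + fixed + restarted + diagnosed), 2)


req = {"inspect_gradients"}
full = ["inspect_gradients", "modify_config", "restart_run", "mark_diagnosed"]
print(grade_episode(full, "lr_too_high", required_inspections=req,
                    correct_fix="modify_config", root_cause="lr_too_high"))  # 1.0
print(grade_episode(["mark_diagnosed"], "lr_too_high", required_inspections=req,
                    correct_fix="modify_config", root_cause="lr_too_high"))  # 0.25
```

The second call is the "jump to conclusions" case: a correct guess with no investigation, fix, or restart earns only a quarter of the score.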

## Baseline Results

| Task | Heuristic | Llama 3.1 8B |
|------|-----------|--------------|
| task_001 (Easy) | 1.00 | 0.60 |
| task_002 (Easy) | 1.00 | 0.05 |
| task_003 (Medium) | 1.00 | 0.40 |
| task_004 (Medium) | 1.00 | 0.60 |
| task_005 (Hard) | 0.80 | 0.38-0.55 |
| task_006 (Hard) | 0.81 | 0.60-1.00 |
| task_007 (Hard) | 0.79 | 0.60 |
| **Average** | **0.91** | **0.52** |

The heuristic is strong because it knows the task structure. An LLM has to figure it out from observations.

## Setup

```bash
# Local
python3 -m venv .venv && source .venv/bin/activate
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install openenv-core pydantic fastapi uvicorn
uvicorn server.app:app --host 0.0.0.0 --port 7860

# Docker
docker build -t pytorch-debugger .
docker run -p 7860:7860 pytorch-debugger

# Baselines
python3 baseline_heuristic.py
API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o HF_TOKEN=sk-... python3 inference.py
```

## Project Structure

```
ml_training_debugger/
    models.py            - Data models (Action, Observation, EpisodeState)
    scenarios.py         - Task parameter sampling
    pytorch_engine.py    - Real PyTorch models and fault injection
    simulation.py        - 20-epoch training with fault injection
    reward_engine.py     - Per-step reward with context gating
    graders.py           - Per-task holistic scoring
    code_templates.py    - Task 6 bug variants + fix validation
server/
    environment.py       - MLTrainingEnvironment (reset/step/state)
    app.py               - FastAPI app + endpoints
    dashboard.html       - Live diagnostic dashboard (Plotly.js)
inference.py             - LLM agent (OpenAI client, hackathon format)
baseline_heuristic.py    - Rule-based agent (no API key needed)
```

## API

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check |
| `/tasks` | GET | Task list with action schema |
| `/grader` | POST | Score for last completed episode |
| `/baseline` | POST | Run heuristic on all tasks |
| `/dashboard` | GET | Live diagnostic dashboard |
| `/docs` | GET | Swagger UI |

WebSocket at `/ws` for full episode sessions (reset, step, observe).
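For scripting against the REST endpoints without extra dependencies, a small standard-library client could look like this. It is a hypothetical helper: it assumes the JSON endpoints in the table return JSON and that `/grader` and `/baseline` accept a POST with no body. The response fields are not documented here, so it returns the parsed JSON as-is.

```python
import json
import urllib.request

BASE_URL = "https://ujjwalpardeshi-pytorch-training-debugger.hf.space"


def build_request(path, method="GET", payload=None, base_url=BASE_URL):
    """Build a request for one endpoint; JSON-encode the payload if there is one."""
    data = json.dumps(payload).encode("utf-8") if payload is not None else None
    headers = {"Content-Type": "application/json"} if data is not None else {}
    return urllib.request.Request(base_url + path, data=data, headers=headers, method=method)


def call(path, method="GET", payload=None, base_url=BASE_URL, timeout=30):
    """Send the request and parse the JSON body (not for /dashboard or /docs, which serve HTML)."""
    request = build_request(path, method, payload, base_url)
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Usage (needs network access, or pass base_url="http://localhost:7860" for a local server):
# print(call("/health"))
# print(call("/grader", method="POST"))
```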