UjjwalPardeshi committed on
Commit ·
d222546
1
Parent(s): 8435256
update readme
Browse files- README.md +270 -142
- baseline_heuristic.py +35 -28
README.md
CHANGED
|
@@ -1,114 +1,169 @@
|
|
| 1 |
# PyTorch Training Run Debugger
|
| 2 |
|
| 3 |
-
**OpenEnv RL Environment**
|
| 4 |
|
| 5 |
An AI agent debugs broken PyTorch training runs by investigating gradients, model weights, data pipelines, and source code to diagnose and fix real ML failure patterns.
|
| 6 |
|
| 7 |
-
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
1. **Investigate** — inspect gradients, data batches, model weights, model modes, and code
|
| 12 |
-
2. **Diagnose** β identify the root cause from
|
| 13 |
-
3. **Fix** β apply the correct intervention
|
| 14 |
-
4. **Verify** β restart training and confirm recovery before submitting
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
-
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
##
|
| 25 |
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
| `training_loss_history` | `list[float]` (20 epochs) | Always |
|
| 31 |
| `val_accuracy_history` | `list[float]` (20 epochs) | Always |
|
| 32 |
| `val_loss_history` | `list[float]` (20 epochs) | Always |
|
| 33 |
| `current_config` | `TrainingConfig` | Always |
|
| 34 |
-
| `error_log` | `
|
| 35 |
| `gradient_stats` | `list[GradientStats]` | After `inspect_gradients` |
|
| 36 |
-
| `model_weight_stats` | `
|
| 37 |
-
| `data_batch_stats` | `
|
| 38 |
-
| `model_mode_info` | `
|
| 39 |
-
| `code_snippet` | `
|
| 40 |
| `available_actions` | `list[str]` | Always (dynamic) |
|
| 41 |
| `episode_state` | `EpisodeState` | Always |
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|-------|-------------|
|
| 57 |
-
| `lr_too_high` | Learning rate too large |
|
| 58 |
-
| `vanishing_gradients` | Gradients decay to near-zero |
|
| 59 |
-
| `data_leakage` | Validation samples in training |
|
| 60 |
-
| `overfitting` | Model memorizing, failing to generalize |
|
| 61 |
-
| `batchnorm_eval_mode` | Model in eval mode during training |
|
| 62 |
-
| `code_bug` | Bug in PyTorch training code |
|
| 63 |
|
| 64 |
-
##
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
| 69 |
-
|
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
| Invalid action | -0.05 | Action not in `available_actions` |
|
| 72 |
-
|
|
| 73 |
-
|
|
| 74 |
-
| Convergence after fix+restart | +0.40 | All gates met |
|
| 75 |
|
| 76 |
-
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|----|-----------|------------|-------------|
|
| 80 |
-
| `task_001` | Easy | `lr_too_high` | Exploding gradients β all layers show `is_exploding: True`, NaN in error log |
|
| 81 |
-
| `task_002` | Easy | `vanishing_gradients` | Vanishing gradients β deeper layers show `is_vanishing: True`, flat loss curve |
|
| 82 |
-
| `task_003` | Medium | `data_leakage` | Silent data leakage β suspiciously high val accuracy, `class_overlap_score > 0.5` |
|
| 83 |
-
| `task_004` | Medium | `overfitting` | Train-val divergence β loss approaches 0 while val loss climbs |
|
| 84 |
-
| `task_005` | Hard | `batchnorm_eval_mode` | Model in eval mode with compound red herrings (FC gradient spike, GPU 91%, near-vanishing conv1) |
|
| 85 |
-
| `task_006` | Hard | `code_bug` | PyTorch code bug β agent must read and fix actual Python code (4 bug variants) |
|
| 86 |
-
| `task_007` | Med-Hard | `scheduler_misconfigured` | LR scheduler with wrong gamma/step_size β training stagnates after initial progress |
|
| 87 |
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
## Baseline Scores
|
| 91 |
|
| 92 |
-
### Heuristic vs LLM Comparison
|
| 93 |
|
| 94 |
-
| Task | Difficulty | Heuristic | Llama 3.3 70B | Llama 3.1 8B |
|
| 95 |
-
|------|-----------|-----------|---------------|--------------|
|
| 96 |
-
| `task_001` | Easy | **1.00** | 1.00 | 0.60 |
|
| 97 |
-
| `task_002` | Easy | **1.00** | 1.00 | 0.05 |
|
| 98 |
-
| `task_003` | Medium | **1.00** | 0.40 | 0.40 |
|
| 99 |
-
| `task_004` | Medium |
|
| 100 |
-
| `task_005` | Hard | **1.00** | 1.00 | 1.00 |
|
| 101 |
-
| `task_006` | Hard | **1.00** | β | 0.60
|
| 102 |
-
| `task_007` | Med-Hard | **1.00** | β | 0.60 |
|
| 103 |
-
| **Average** | | **
|
| 104 |
|
| 105 |
-
*Llama 3.3 70B results are partial (5/7 tasks before rate limit).
|
| 106 |
|
| 107 |
-
**
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
4. **8B struggles on multi-step tasks:** Task 2 (0.05) shows small models can't maintain investigation strategy across many steps
|
| 112 |
|
| 113 |
### Running Baselines
|
| 114 |
|
|
@@ -116,40 +171,86 @@ All tasks support `difficulty_level` (1-5) via reset: `{"type": "reset", "data":
|
|
| 116 |
# Heuristic (deterministic, no API key, bit-exact reproducible)
|
| 117 |
python3 baseline_heuristic.py
|
| 118 |
|
| 119 |
-
# LLM (multi-provider support
|
| 120 |
-
python3 baseline_inference.py # Groq
|
| 121 |
-
python3 baseline_inference.py --provider cerebras # Cerebras (free)
|
| 122 |
-
python3 baseline_inference.py --provider gemini # Google Gemini
|
| 123 |
python3 baseline_inference.py --provider openai # OpenAI GPT-4o
|
| 124 |
|
| 125 |
# Run all baselines with comparison table
|
| 126 |
python3 run_all_baselines.py
|
| 127 |
```
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
## Setup
|
| 130 |
|
| 131 |
### Local Development
|
| 132 |
|
| 133 |
```bash
|
| 134 |
-
# Create virtual environment
|
| 135 |
python3 -m venv .venv
|
| 136 |
source .venv/bin/activate
|
| 137 |
|
| 138 |
-
# Install dependencies
|
| 139 |
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 140 |
pip install openenv-core pydantic fastapi uvicorn
|
| 141 |
-
|
| 142 |
-
# Install dev tools
|
| 143 |
-
pip install pytest pytest-cov black ruff isort
|
| 144 |
|
| 145 |
# Start server
|
| 146 |
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 147 |
|
| 148 |
-
# Run tests
|
| 149 |
pytest tests/ -v --cov=ml_training_debugger
|
| 150 |
|
| 151 |
-
# Run baseline
|
| 152 |
-
|
| 153 |
```
|
| 154 |
|
| 155 |
### Docker
|
|
@@ -160,82 +261,109 @@ docker run -p 7860:7860 pytorch-debugger
|
|
| 160 |
curl http://localhost:7860/health
|
| 161 |
```
|
| 162 |
|
| 163 |
-
##
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
| `/schema` | GET | Action/observation schemas (framework) |
|
| 177 |
-
| `/docs` | GET | Swagger UI (framework) |
|
| 178 |
-
|
| 179 |
-
### WebSocket Message Format
|
| 180 |
-
|
| 181 |
-
The primary agent interface is the WebSocket endpoint at `/ws`. Messages use JSON:
|
| 182 |
-
|
| 183 |
-
**Reset** (start a new episode, optionally select task):
|
| 184 |
-
```json
|
| 185 |
-
{"type": "reset"}
|
| 186 |
-
{"type": "reset", "data": {"task_id": "task_003", "seed": 42}}
|
| 187 |
```
|
| 188 |
-
Without `data`, defaults to `task_001`. With `data`, selects the specified task.
|
| 189 |
|
| 190 |
-
|
| 191 |
|
| 192 |
-
**Step** (execute an action):
|
| 193 |
-
```json
|
| 194 |
-
{"type": "step", "data": {"action_type": "inspect_gradients"}}
|
| 195 |
```
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
```
|
| 199 |
-
```json
|
| 200 |
-
{"type": "step", "data": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
|
| 201 |
-
```
|
| 202 |
-
Returns: `{"type": "observation", "data": {"observation": {...}, "reward": float, "done": bool}}`
|
| 203 |
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
-
|
| 207 |
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
-
|
| 211 |
|
| 212 |
-
|
| 213 |
|
| 214 |
-
|
| 215 |
|
| 216 |
-
|
| 217 |
|
| 218 |
-
**
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
-
|
| 223 |
-
- **Dual model architectures**: SimpleCNN (~50K params) + SimpleMLP (~20K params)
|
| 224 |
-
- **Real 20-epoch mini-training** per reset (cached per task/seed for instant replay)
|
| 225 |
-
- Typed Pydantic models everywhere β no `Dict[str, Any]`
|
| 226 |
-
- `import torch` in every core module β zero numpy in core
|
| 227 |
-
- Session isolation via per-session `EpisodeState`
|
| 228 |
-
- Deterministic reproducibility via `torch.manual_seed()`
|
| 229 |
-
- **251 tests, 95% coverage**
|
| 230 |
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
-
|
| 234 |
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
See [PAPER.md](PAPER.md) — "Context-Gated Reward Shaping for Evidence-Based ML Debugging"
|
| 238 |
|
| 239 |
-
|
|
|
|
|
|
|
| 240 |
|
| 241 |
-
|
|
|
|
| 1 |
# PyTorch Training Run Debugger
|
| 2 |
|
| 3 |
+
**OpenEnv RL Environment** | Meta PyTorch OpenEnv Hackathon x Scaler School of Technology
|
| 4 |
|
| 5 |
An AI agent debugs broken PyTorch training runs by investigating gradients, model weights, data pipelines, and source code to diagnose and fix real ML failure patterns.
|
| 6 |
|
| 7 |
+
---
|
| 8 |
|
| 9 |
+
## The Problem
|
| 10 |
+
|
| 11 |
+
ML teams spend 15-25% of engineer time debugging silent training failures — runs that produce no error, no crash, just mysteriously bad metrics. Each misdiagnosed restart wastes GPU compute at $2-8/hour/card. The diagnostic process is hard because multiple symptoms point to multiple causes, some bugs produce no error at all, and fixing the wrong thing wastes hours.
|
| 12 |
+
|
| 13 |
+
No existing OpenEnv environment covers this domain.
|
| 14 |
+
|
| 15 |
+
## What This Does
|
| 16 |
+
|
| 17 |
+
The environment recreates the experience of an ML engineer facing a broken training job. The agent receives a snapshot of a failing training run and must:
|
| 18 |
|
| 19 |
1. **Investigate** — inspect gradients, data batches, model weights, model modes, and code
|
| 20 |
+
2. **Diagnose** — identify the root cause from 7 known ML failure types
|
| 21 |
+
3. **Fix** — apply the correct intervention
|
| 22 |
+
4. **Verify** — restart training and confirm recovery before submitting
|
| 23 |
+
|
| 24 |
+
The agent starts with limited information (loss curves, config, error log) and must actively choose what to investigate. Each inspection reveals new data — gradient norms, class overlap scores, model train/eval modes, or buggy source code. This makes it a genuine investigation, not just a classification task.
|
| 25 |
+
|
| 26 |
+
## What Makes This Different
|
| 27 |
+
|
| 28 |
+
### Real PyTorch Model Internals
|
| 29 |
+
|
| 30 |
+
Every gradient comes from real `torch.autograd`. Every weight stat comes from real `model.state_dict()`. The environment instantiates actual `torch.nn.Module` models (SimpleCNN ~50K params, SimpleMLP ~20K params), runs 20 real forward+backward epochs per reset, and extracts real tensor statistics. Not synthetic formulas — real PyTorch computation, cached for instant replay.
|
| 31 |
+
|
| 32 |
+
### Context-Gated Reward Shaping
|
| 33 |
+
|
| 34 |
+
Standard RL environments use stateless rewards: "did action X happen?" This environment tracks the agent's information state and conditions penalties on what the agent has already observed.
|
| 35 |
+
|
| 36 |
+
An agent that adds gradient clipping *before* inspecting gradients follows a reasonable prior — **no penalty**. An agent that inspects gradients, sees they are normal, and *then* adds gradient clipping is ignoring counter-evidence — **-0.20 penalty**.
|
| 37 |
+
|
| 38 |
+
The gate requires two conditions to be jointly true (`gradients_inspected AND gradients_were_normal`), both of which depend on prior agent actions. This encodes a transferable skill into the reward signal: don't ignore what you've already learned.
|
| 39 |
+
|
| 40 |
+
### Code-Level Debugging
|
| 41 |
+
|
| 42 |
+
Task 6 presents actual buggy PyTorch training loops. The agent reads real Python code, identifies the buggy line, and submits a line-by-line fix. Four bug variants: `model.eval()` instead of `model.train()`, `.detach()` killing gradient flow, missing `optimizer.zero_grad()`, and `inplace=True` on ReLU corrupting the computation graph.
|
| 43 |
+
|
| 44 |
+
Fix validation uses a 4-strategy pipeline: whitespace normalization, token-stream comparison via Python's `tokenize` module, semantic pattern matching, and `ast.parse()` fallback. This handles the messy fixes that LLM agents actually produce (trailing spaces, inline comments, different indentation).
|
| 45 |
+
|
| 46 |
+
### Red Herring Injection
|
| 47 |
+
|
| 48 |
+
Task 5 (BatchNorm eval mode) deliberately plants misleading signals: a gradient spike in the FC layer that doesn't cross the exploding threshold, a GPU memory warning at 91%, and near-vanishing gradients in conv1. The real problem is only visible through model mode inspection. This separates agents that follow rigid patterns from agents that can reason through ambiguity.
|
| 49 |
+
|
| 50 |
+
## Tasks
|
| 51 |
|
| 52 |
+
7 failure scenarios across 3 difficulty tiers, each with configurable difficulty level (1-5):
|
| 53 |
|
| 54 |
+
| ID | Difficulty | Root Cause | What Goes Wrong |
|
| 55 |
+
|----|-----------|------------|-----------------|
|
| 56 |
+
| `task_001` | Easy | `lr_too_high` | All gradient layers explode, NaN in loss. Direct signal — inspect gradients, reduce LR. |
|
| 57 |
+
| `task_002` | Easy | `vanishing_gradients` | Deeper layers show vanishing norms, loss stays flat. Model can't learn. |
|
| 58 |
+
| `task_003` | Medium | `data_leakage` | Suspiciously high val accuracy from epoch 1. `class_overlap_score > 0.5` confirms test data leaked into training. Red herring note about "architecture upgrade." |
|
| 59 |
+
| `task_004` | Medium | `overfitting` | Train loss drops to near-zero while val loss climbs. Classic memorization pattern. |
|
| 60 |
+
| `task_005` | Hard | `batchnorm_eval_mode` | Slow degradation with compound red herrings. Gradients look normal. The real problem: all layers stuck in eval mode. |
|
| 61 |
+
| `task_006` | Hard | `code_bug` | Metrics are anomalous but gradients/data/modes look fine. Root cause is in the Python training loop — 4 possible bug variants. |
|
| 62 |
+
| `task_007` | Med-Hard | `scheduler_misconfigured` | Training improves initially then stagnates. LR scheduler decays too aggressively (low gamma, small step size). |
|
| 63 |
|
| 64 |
+
### How Difficulty Scales
|
| 65 |
|
| 66 |
+
Easy tasks have one obvious signal (all gradients exploding). Medium tasks require checking multiple sources and ruling out alternatives. Hard tasks deliberately mislead — the most obvious signal is wrong, and the real problem is hidden behind layers of investigation.
|
| 67 |
|
| 68 |
+
## Observation Space
|
| 69 |
+
|
| 70 |
+
| Field | Type | When Visible |
|
| 71 |
+
|-------|------|-------------|
|
| 72 |
| `training_loss_history` | `list[float]` (20 epochs) | Always |
|
| 73 |
| `val_accuracy_history` | `list[float]` (20 epochs) | Always |
|
| 74 |
| `val_loss_history` | `list[float]` (20 epochs) | Always |
|
| 75 |
| `current_config` | `TrainingConfig` | Always |
|
| 76 |
+
| `error_log` | `str` or `null` | Always |
|
| 77 |
| `gradient_stats` | `list[GradientStats]` | After `inspect_gradients` |
|
| 78 |
+
| `model_weight_stats` | `list[ModelWeightStats]` | After `inspect_model_weights` |
|
| 79 |
+
| `data_batch_stats` | `DataBatchStats` | After `inspect_data_batch` |
|
| 80 |
+
| `model_mode_info` | `dict[str, str]` | After `inspect_model_modes` |
|
| 81 |
+
| `code_snippet` | `CodeSnippet` | After `inspect_code` |
|
| 82 |
| `available_actions` | `list[str]` | Always (dynamic) |
|
| 83 |
| `episode_state` | `EpisodeState` | Always |
|
| 84 |
|
| 85 |
+
Fields like `gradient_stats`, `data_batch_stats`, `model_mode_info`, and `code_snippet` start as `null` and are only populated after the agent explicitly requests them. The agent must decide what to investigate.
|
| 86 |
+
|
| 87 |
+
## Action Space
|
| 88 |
+
|
| 89 |
+
14 action types in 3 categories:
|
| 90 |
|
| 91 |
+
**Investigation** — reveal hidden observation fields:
|
| 92 |
+
- `inspect_gradients` — per-layer gradient norms, is_exploding/is_vanishing flags
|
| 93 |
+
- `inspect_data_batch` — label distribution, class overlap score, confusion matrix
|
| 94 |
+
- `inspect_model_modes` — train/eval mode per layer
|
| 95 |
+
- `inspect_model_weights` — weight norms, dead neurons, NaN/Inf detection
|
| 96 |
+
- `inspect_code` — the actual Python training loop (Task 6)
|
| 97 |
|
| 98 |
+
**Fix** — apply an intervention:
|
| 99 |
+
- `modify_config` — change learning_rate, weight_decay, batch_size, optimizer, etc.
|
| 100 |
+
- `add_callback` — add gradient clipping
|
| 101 |
+
- `patch_data_loader` — fix data pipeline
|
| 102 |
+
- `fix_model_mode` — switch model to train mode
|
| 103 |
+
- `fix_code` — fix a specific line of code (requires line number + replacement)
|
| 104 |
+
- `replace_optimizer` — switch optimizer
|
| 105 |
|
| 106 |
+
**Terminal** — end the episode:
|
| 107 |
+
- `restart_run` — restart training (only available after a fix)
|
| 108 |
+
- `mark_diagnosed` — submit diagnosis from 7 possible root causes
|
| 109 |
|
| 110 |
+
Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
## Reward Function
|
| 113 |
|
| 114 |
+
Per-step signal, separate from the grader. Hard cap at [-1.0, 1.0].
|
| 115 |
+
|
| 116 |
+
| Event | Reward | Condition |
|
| 117 |
+
|-------|--------|-----------|
|
| 118 |
+
| Any step | -0.01 | Flat, unconditional (encourages efficiency) |
|
| 119 |
+
| First-time inspection | +0.05 | Per inspection type, first time only |
|
| 120 |
+
| Correct diagnosis | +0.50 | `diagnosis == root_cause` |
|
| 121 |
+
| Wrong diagnosis | -0.30 | `diagnosis != root_cause` |
|
| 122 |
+
| Convergence after fix+restart | +0.40 | Fix applied, restarted, training recovers |
|
| 123 |
| Invalid action | -0.05 | Action not in `available_actions` |
|
| 124 |
+
| Wrong code fix | -0.10 | `fix_code` with incorrect line/replacement |
|
| 125 |
+
| **Context-gated penalty** | **-0.20** | `gradients_inspected AND gradients_were_normal AND action == add_callback` |
|
|
|
|
| 126 |
|
| 127 |
+
The step penalty is flat -0.01 (never multiplied by step count). Investigation bonuses fire once per type. The context-gated penalty requires the agent to have previously inspected gradients and found them normal — it cannot fire before inspection.
|
| 128 |
|
| 129 |
+
## Grading
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
+
Each task has a separate grader that evaluates the complete `EpisodeState` at episode end, returning a normalized 0.0-1.0 score. The grader is **not** a sum of step rewards — it's a holistic evaluation of whether the agent investigated correctly, applied the right fix, restarted training, and diagnosed accurately.
|
| 132 |
+
|
| 133 |
+
Example (Task 5 — BatchNorm Eval):
|
| 134 |
+
|
| 135 |
+
| Component | Points |
|
| 136 |
+
|-----------|--------|
|
| 137 |
+
| Inspected gradients | +0.05 |
|
| 138 |
+
| Inspected model modes (the revealing action) | +0.05 |
|
| 139 |
+
| Fixed model mode | +0.25 |
|
| 140 |
+
| Restarted training | +0.30 |
|
| 141 |
+
| Correct diagnosis | +0.40 |
|
| 142 |
+
| Fell for red herring (add_callback after normal gradients) | -0.20 |
|
| 143 |
+
|
| 144 |
+
An agent that chases the gradient spike red herring loses 0.20 points. An agent that goes straight to model modes and finds the real problem scores 1.0.
|
| 145 |
|
| 146 |
## Baseline Scores
|
| 147 |
|
| 148 |
+
### Heuristic vs LLM Comparison
|
| 149 |
|
| 150 |
+
| Task | Difficulty | Heuristic | Llama 3.3 70B | Llama 3.1 8B |
|
| 151 |
+
|------|-----------|-----------|---------------|--------------|
|
| 152 |
+
| `task_001` | Easy | **1.00** | 1.00 | 0.60 |
|
| 153 |
+
| `task_002` | Easy | **1.00** | 1.00 | 0.05 |
|
| 154 |
+
| `task_003` | Medium | **1.00** | 0.40 | 0.40 |
|
| 155 |
+
| `task_004` | Medium | **1.00** | 0.45 | 0.60 |
|
| 156 |
+
| `task_005` | Hard | **1.00** | 1.00 | 1.00 |
|
| 157 |
+
| `task_006` | Hard | **1.00** | — | 0.60-1.00 |
|
| 158 |
+
| `task_007` | Med-Hard | **1.00** | — | 0.60 |
|
| 159 |
+
| **Average** | | **1.00** | ~0.69* | 0.55 |
|
| 160 |
|
| 161 |
+
*Llama 3.3 70B results are partial (5/7 tasks before rate limit).
|
| 162 |
|
| 163 |
+
**What this tells you:**
|
| 164 |
+
- **Model size matters:** 70B scores ~25% higher than 8B. The environment scales with model capability.
|
| 165 |
+
- **8B struggles on multi-step tasks:** Task 2 score of 0.05 shows small models can't maintain investigation strategy across many steps.
|
| 166 |
+
- **The heuristic baseline is strong** because it was designed with knowledge of the task structure. An agent that doesn't know the structure has to figure it out from observations alone.
|
|
|
|
| 167 |
|
| 168 |
### Running Baselines
|
| 169 |
|
|
|
|
| 171 |
# Heuristic (deterministic, no API key, bit-exact reproducible)
|
| 172 |
python3 baseline_heuristic.py
|
| 173 |
|
| 174 |
+
# LLM (multi-provider support)
|
| 175 |
+
python3 baseline_inference.py # Groq — Llama 3.3 70B (free)
|
| 176 |
+
python3 baseline_inference.py --provider cerebras # Cerebras — Llama 3.1 8B (free)
|
| 177 |
+
python3 baseline_inference.py --provider gemini # Google Gemini 2.0 Flash
|
| 178 |
python3 baseline_inference.py --provider openai # OpenAI GPT-4o
|
| 179 |
|
| 180 |
# Run all baselines with comparison table
|
| 181 |
python3 run_all_baselines.py
|
| 182 |
```
|
| 183 |
|
| 184 |
+
## API
|
| 185 |
+
|
| 186 |
+
### HTTP Endpoints
|
| 187 |
+
|
| 188 |
+
| Endpoint | Method | Description |
|
| 189 |
+
|----------|--------|-------------|
|
| 190 |
+
| `/health` | GET | `{"status": "ready", "tasks": 7}` |
|
| 191 |
+
| `/tasks` | GET | Task list with IDs, difficulties, action schema |
|
| 192 |
+
| `/grader` | POST | Score for last completed episode |
|
| 193 |
+
| `/baseline` | POST | Run heuristic on all tasks, return scores |
|
| 194 |
+
| `/dashboard` | GET | Live 4-panel diagnostic dashboard |
|
| 195 |
+
| `/validation-report` | GET | Simulation fidelity report |
|
| 196 |
+
| `/curriculum` | GET | Recommended task order (easy to hard, difficulty 1-5) |
|
| 197 |
+
| `/leaderboard` | GET | Sorted episode scores |
|
| 198 |
+
| `/replay/{id}` | GET | Full action/observation trace for an episode |
|
| 199 |
+
| `/schema` | GET | Action/observation JSON schemas |
|
| 200 |
+
| `/docs` | GET | Swagger UI |
|
| 201 |
+
|
| 202 |
+
### WebSocket (Primary Agent Interface)
|
| 203 |
+
|
| 204 |
+
The WebSocket endpoint at `/ws` maintains session state across a full episode. HTTP endpoints are stateless by framework design.
|
| 205 |
+
|
| 206 |
+
**Reset** (start an episode):
|
| 207 |
+
```json
|
| 208 |
+
{"type": "reset", "data": {"task_id": "task_003", "seed": 42}}
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
**Step** (take an action):
|
| 212 |
+
```json
|
| 213 |
+
{"type": "step", "data": {"action_type": "inspect_gradients"}}
|
| 214 |
+
{"type": "step", "data": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
|
| 215 |
+
{"type": "step", "data": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
**Response format:**
|
| 219 |
+
```json
|
| 220 |
+
{"type": "observation", "data": {"observation": {...}, "reward": 0.04, "done": false}}
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
## Dashboard
|
| 224 |
+
|
| 225 |
+
A live 4-panel diagnostic dashboard at `/dashboard`:
|
| 226 |
+
|
| 227 |
+
1. **Training Metrics** — loss/accuracy curves with Plotly.js
|
| 228 |
+
2. **Gradient & Weight Heatmap** — per-layer bars, color-coded (green=normal, red=exploding, blue=vanishing)
|
| 229 |
+
3. **Action Timeline & Rewards** — step-by-step bars showing reward per action and cumulative reward line
|
| 230 |
+
4. **Episode Summary** — state flags, available actions, code snippet (Task 6)
|
| 231 |
+
|
| 232 |
+
Select a task, click "Run Baseline", and watch the heuristic agent investigate, fix, and diagnose step by step. The charts update live over WebSocket.
|
| 233 |
+
|
| 234 |
## Setup
|
| 235 |
|
| 236 |
### Local Development
|
| 237 |
|
| 238 |
```bash
|
|
|
|
| 239 |
python3 -m venv .venv
|
| 240 |
source .venv/bin/activate
|
| 241 |
|
|
|
|
| 242 |
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 243 |
pip install openenv-core pydantic fastapi uvicorn
|
| 244 |
+
pip install pytest pytest-cov pytest-asyncio httpx websockets
|
|
|
|
|
|
|
| 245 |
|
| 246 |
# Start server
|
| 247 |
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 248 |
|
| 249 |
+
# Run tests (255 tests, 97% coverage)
|
| 250 |
pytest tests/ -v --cov=ml_training_debugger
|
| 251 |
|
| 252 |
+
# Run heuristic baseline
|
| 253 |
+
python3 baseline_heuristic.py
|
| 254 |
```
|
| 255 |
|
| 256 |
### Docker
|
|
|
|
| 261 |
curl http://localhost:7860/health
|
| 262 |
```
|
| 263 |
|
| 264 |
+
### Smoke Test
|
| 265 |
|
| 266 |
+
```bash
|
| 267 |
+
# Verify all critical paths
|
| 268 |
+
curl http://localhost:7860/health
|
| 269 |
+
curl http://localhost:7860/tasks | python3 -m json.tool
|
| 270 |
+
curl -X POST http://localhost:7860/baseline | python3 -m json.tool
|
| 271 |
+
curl -X POST http://localhost:7860/grader | python3 -m json.tool
|
| 272 |
+
|
| 273 |
+
# Reproducibility (must produce no diff)
|
| 274 |
+
python3 baseline_heuristic.py > run1.json
|
| 275 |
+
python3 baseline_heuristic.py > run2.json
|
| 276 |
+
diff run1.json run2.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
```
|
|
|
|
| 278 |
|
| 279 |
+
## Architecture
|
| 280 |
|
|
|
|
|
|
|
|
|
|
| 281 |
```
|
| 282 |
+
ml_training_debugger/
|
| 283 |
+
models.py — Pydantic data models (Action, Observation, EpisodeState)
|
| 284 |
+
scenarios.py — Task parameter sampling (7 tasks, deterministic per seed)
|
| 285 |
+
pytorch_engine.py — Real PyTorch models, fault injection, gradient/weight extraction
|
| 286 |
+
simulation.py — 20-epoch real training with parametric fallback
|
| 287 |
+
reward_engine.py — 7-component per-step reward with context gating
|
| 288 |
+
graders.py — Per-task holistic 0.0-1.0 scoring
|
| 289 |
+
code_templates.py — Task 6 bug variants + 4-strategy fix validation
|
| 290 |
+
client.py — Typed client extending EnvClient
|
| 291 |
+
|
| 292 |
+
server/
|
| 293 |
+
environment.py — MLTrainingEnvironment (reset/step/state)
|
| 294 |
+
app.py — FastAPI + custom endpoints
|
| 295 |
+
dashboard.html — Live Plotly.js diagnostic dashboard
|
| 296 |
+
|
| 297 |
+
tests/ — 255 tests, 97% coverage
|
| 298 |
+
baseline_heuristic.py — Rule-based agent (deterministic, no API key)
|
| 299 |
+
baseline_inference.py — LLM agent (Groq/Cerebras/Gemini/OpenAI)
|
| 300 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
+
**Key design decisions:**
|
| 303 |
+
- **Grader is separate from reward function.** `reward_engine.py` returns a float per step for RL training signal. `graders.py` returns a holistic 0.0-1.0 score at episode end. They are different modules with different purposes.
|
| 304 |
+
- **Task IDs are opaque.** `task_001` through `task_007` — the agent cannot infer the diagnosis from the ID.
|
| 305 |
+
- **Task 6 diagnosis is always `code_bug`.** Regardless of which bug variant (eval_mode, detach_loss, zero_grad_missing, inplace_relu), the correct diagnosis is `code_bug`.
|
| 306 |
+
- **Dual model architectures.** SimpleCNN and SimpleMLP are randomly selected per episode, testing agent robustness to architecture variation.
|
| 307 |
+
- **Session isolation.** Each WebSocket connection gets its own environment instance with independent state.
|
| 308 |
+
- **`step()` never raises.** All invalid actions return a valid observation with -0.05 penalty and an error note.
|
| 309 |
|
| 310 |
+
### Technical Stack
|
| 311 |
|
| 312 |
+
- Python 3.12 · PyTorch 2.5.1 CPU-only · openenv-core v0.2.2
|
| 313 |
+
- `import torch` in every core module — zero numpy in core
|
| 314 |
+
- Typed Pydantic v2 models everywhere — no `Dict[str, Any]`
|
| 315 |
+
- Deterministic reproducibility via `torch.manual_seed()` at every reset
|
| 316 |
+
- Docker image: 885MB (multi-stage build, `strip --strip-unneeded`, transitive dep cleanup)
|
| 317 |
|
| 318 |
+
### Validation Suite
|
| 319 |
|
| 320 |
+
8/8 validation checks pass. Real PyTorch 20-epoch mini-training with fault injection. Each fault type is validated with behavioral checks (gradient detection, loss patterns, model mode, code fix acceptance). Both SimpleCNN and SimpleMLP architectures verified. Results served live at `GET /validation-report`.
|
| 321 |
|
| 322 |
+
## Walkthrough: Solving Task 5 (Hard)
|
| 323 |
|
| 324 |
+
This is the most interesting task because it has red herrings designed to mislead.
|
| 325 |
|
| 326 |
+
**What the agent sees on reset:**
|
| 327 |
+
- Loss oscillates between 2.1 and 2.5, never converging
|
| 328 |
+
- Val accuracy stuck at ~0.15
|
| 329 |
+
- Error log mentions GPU memory at 91%
|
| 330 |
|
| 331 |
+
**Step 1: Inspect gradients**
|
| 332 |
+
```
|
| 333 |
+
conv1: mean_norm=0.15, Normal
|
| 334 |
+
conv2: mean_norm=5.2, Normal
|
| 335 |
+
conv3: mean_norm=0.8, Normal
|
| 336 |
+
fc: mean_norm=1.3, Normal (slight spike — the red herring)
|
| 337 |
+
```
|
| 338 |
+
All layers normal. `gradients_were_normal` is now True.
|
| 339 |
|
| 340 |
+
**Step 2: Inspect data** — class overlap 0.0, data is clean.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
+
**Step 3: Inspect model modes**
|
| 343 |
+
```
|
| 344 |
+
conv1: "eval" ← Problem found!
|
| 345 |
+
bn1: "eval"
|
| 346 |
+
conv2: "eval"
|
| 347 |
+
bn2: "eval"
|
| 348 |
+
fc: "eval"
|
| 349 |
+
```
|
| 350 |
+
All layers stuck in eval mode. BatchNorm is using running statistics instead of batch statistics during training.
|
| 351 |
|
| 352 |
+
**Step 4: Fix model mode** — switches all layers to train mode.
|
| 353 |
|
| 354 |
+
**Step 5: Restart training** — convergence confirmed.
|
| 355 |
+
|
| 356 |
+
**Step 6: Diagnose `batchnorm_eval_mode`** — correct. Score: 1.0.
|
| 357 |
+
|
| 358 |
+
**What would have gone wrong:**
|
| 359 |
+
If the agent had seen the FC gradient spike and called `add_callback` (gradient clipping), it would have received a -0.20 context-gated penalty — because it already knew gradients were normal. The penalty only fires when both `gradients_inspected=True` and `gradients_were_normal=True`. Before inspection, the same action would have no penalty.
|
| 360 |
+
|
| 361 |
+
## Research Summary
|
| 362 |
|
| 363 |
See [PAPER.md](PAPER.md) — "Context-Gated Reward Shaping for Evidence-Based ML Debugging"
|
| 364 |
|
| 365 |
+
Core claim: by conditioning penalties on the agent's accumulated information state (not just action outcomes), we create environments that reward systematic investigation over pattern-matching — a capability with direct transfer value to real-world MLOps debugging.
|
| 366 |
+
|
| 367 |
+
---
|
| 368 |
|
| 369 |
+
*Built for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, 2026.*
|
baseline_heuristic.py
CHANGED
|
@@ -89,36 +89,24 @@ def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
|
|
| 89 |
session = env._get_session()
|
| 90 |
return session.last_score if session and session.last_score is not None else 0.0
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
if (
|
| 102 |
-
(
|
|
|
|
| 103 |
and obs.data_batch_stats
|
| 104 |
-
and obs.data_batch_stats.class_overlap_score < 0.
|
| 105 |
):
|
| 106 |
-
|
| 107 |
-
MLTrainingAction(
|
| 108 |
-
action_type="modify_config",
|
| 109 |
-
target="weight_decay",
|
| 110 |
-
value=0.01,
|
| 111 |
-
)
|
| 112 |
-
)
|
| 113 |
-
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 114 |
-
obs = env.step(
|
| 115 |
-
MLTrainingAction(
|
| 116 |
-
action_type="mark_diagnosed",
|
| 117 |
-
diagnosis="overfitting",
|
| 118 |
-
)
|
| 119 |
-
)
|
| 120 |
-
session = env._get_session()
|
| 121 |
-
return session.last_score if session and session.last_score is not None else 0.0
|
| 122 |
|
| 123 |
# Step 3: inspect_model_modes
|
| 124 |
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
|
@@ -193,7 +181,26 @@ def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
|
|
| 193 |
session = env._get_session()
|
| 194 |
return session.last_score if session and session.last_score is not None else 0.0
|
| 195 |
|
| 196 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
obs = env.step(
|
| 198 |
MLTrainingAction(
|
| 199 |
action_type="mark_diagnosed",
|
|
|
|
| 89 |
session = env._get_session()
|
| 90 |
return session.last_score if session and session.last_score is not None else 0.0
|
| 91 |
|
| 92 |
+
# Detect overfitting pattern (used later, after ruling out code bugs)
|
| 93 |
+
_looks_like_overfitting = False
|
| 94 |
+
if obs.val_loss_history and obs.training_loss_history and len(obs.val_loss_history) >= 10:
|
| 95 |
+
early_train = sum(obs.training_loss_history[:5]) / 5
|
| 96 |
+
late_train = sum(obs.training_loss_history[-5:]) / 5
|
| 97 |
+
early_val = sum(obs.val_loss_history[:5]) / 5
|
| 98 |
+
late_val = sum(obs.val_loss_history[-5:]) / 5
|
| 99 |
+
train_dropped = late_train < early_train * 0.5
|
| 100 |
+
train_loss_low = late_train < 0.15
|
| 101 |
+
val_not_improving = late_val >= early_val * 0.95
|
| 102 |
+
gap_widening = (late_val - late_train) > (early_val - early_train)
|
| 103 |
if (
|
| 104 |
+
(train_dropped or train_loss_low)
|
| 105 |
+
and (val_not_improving or gap_widening)
|
| 106 |
and obs.data_batch_stats
|
| 107 |
+
and obs.data_batch_stats.class_overlap_score < 0.3
|
| 108 |
):
|
| 109 |
+
_looks_like_overfitting = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
# Step 3: inspect_model_modes
|
| 112 |
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
|
|
|
| 181 |
session = env._get_session()
|
| 182 |
return session.last_score if session and session.last_score is not None else 0.0
|
| 183 |
|
| 184 |
+
# Overfitting fallback β only if code inspection didn't find a bug
|
| 185 |
+
if _looks_like_overfitting:
|
| 186 |
+
obs = env.step(
|
| 187 |
+
MLTrainingAction(
|
| 188 |
+
action_type="modify_config",
|
| 189 |
+
target="weight_decay",
|
| 190 |
+
value=0.01,
|
| 191 |
+
)
|
| 192 |
+
)
|
| 193 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 194 |
+
obs = env.step(
|
| 195 |
+
MLTrainingAction(
|
| 196 |
+
action_type="mark_diagnosed",
|
| 197 |
+
diagnosis="overfitting",
|
| 198 |
+
)
|
| 199 |
+
)
|
| 200 |
+
session = env._get_session()
|
| 201 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 202 |
+
|
| 203 |
+
# Final fallback
|
| 204 |
obs = env.step(
|
| 205 |
MLTrainingAction(
|
| 206 |
action_type="mark_diagnosed",
|