omkarrr88 committed on
Commit ·
9e6a926
1
Parent(s): e2f8b29
Remaining task added + full openenv compliance
Browse files
- Dockerfile +11 -3
- README.md +15 -8
- baseline_heuristic.py +10 -5
- baseline_inference.py +150 -0
- ml_training_debugger/pytorch_engine.py +16 -7
- openenv.yaml +34 -4
- pyproject.toml +3 -0
- requirements.txt +0 -1
- server/app.py +111 -90
- server/dashboard.html +241 -0
- server/environment.py +7 -0
- tests/test_baseline_reproducibility.py +24 -0
- tests/test_endpoints.py +60 -0
- uv.lock +0 -0
- validation/requirements.txt +3 -0
- validation/validate_exploding_gradients.py +77 -0
Dockerfile
CHANGED
|
@@ -2,12 +2,20 @@ FROM python:3.12-slim
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
-
# Install
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
| 7 |
|
| 8 |
-
# Install remaining dependencies
|
| 9 |
COPY requirements.txt .
|
| 10 |
-
RUN pip install --no-cache-dir -r requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# Copy application code
|
| 13 |
COPY ml_training_debugger/ ml_training_debugger/
|
|
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
+
# Install curl for healthcheck
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends curl && \
|
| 7 |
+
rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
# Install PyTorch CPU-only first (largest layer, cached separately)
|
| 10 |
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
| 11 |
|
| 12 |
+
# Install remaining dependencies (torch excluded from requirements.txt)
|
| 13 |
COPY requirements.txt .
|
| 14 |
+
RUN pip install --no-cache-dir -r requirements.txt && \
|
| 15 |
+
find /usr/local/lib/python3.12/site-packages -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null; \
|
| 16 |
+
find /usr/local/lib/python3.12/site-packages -name "*.pyc" -delete 2>/dev/null; \
|
| 17 |
+
rm -rf /usr/local/lib/python3.12/site-packages/gradio/templates 2>/dev/null; \
|
| 18 |
+
true
|
| 19 |
|
| 20 |
# Copy application code
|
| 21 |
COPY ml_training_debugger/ ml_training_debugger/
|
README.md
CHANGED
|
@@ -76,18 +76,24 @@ Dynamic availability: `restart_run` requires a fix first; `fix_code` requires co
|
|
| 76 |
| ID | Difficulty | Root Cause | Description |
|
| 77 |
|----|-----------|------------|-------------|
|
| 78 |
| `task_001` | Easy | `lr_too_high` | Exploding gradients — all layers show `is_exploding: True`, NaN in error log |
|
|
|
|
| 79 |
| `task_003` | Medium | `data_leakage` | Silent data leakage — suspiciously high val accuracy, `class_overlap_score > 0.5` |
|
|
|
|
| 80 |
| `task_005` | Hard | `batchnorm_eval_mode` | Model in eval mode with compound red herrings (FC gradient spike, GPU 91%, near-vanishing conv1) |
|
|
|
|
| 81 |
|
| 82 |
## Baseline Scores
|
| 83 |
|
| 84 |
-
Rule-based heuristic baseline (deterministic, no API key):
|
| 85 |
|
| 86 |
-
| Task | Score |
|
| 87 |
-
|------|-------|
|
| 88 |
-
| `task_001` | 1.00 |
|
| 89 |
-
| `
|
| 90 |
-
| `
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
## Setup
|
| 93 |
|
|
@@ -127,10 +133,11 @@ curl http://localhost:7860/health
|
|
| 127 |
|
| 128 |
| Endpoint | Method | Description |
|
| 129 |
|----------|--------|-------------|
|
| 130 |
-
| `/health` | GET | `{"status": "ready", "tasks":
|
| 131 |
| `/tasks` | GET | Task list with action schema |
|
| 132 |
| `/grader` | POST | Grader score for last completed episode |
|
| 133 |
-
| `/baseline` | POST | Run baseline, return scores |
|
|
|
|
| 134 |
| `/ws` | WebSocket | Primary agent interface |
|
| 135 |
| `/reset` | POST | Reset environment (framework) |
|
| 136 |
| `/step` | POST | Execute action (framework) |
|
|
|
|
| 76 |
| ID | Difficulty | Root Cause | Description |
|
| 77 |
|----|-----------|------------|-------------|
|
| 78 |
| `task_001` | Easy | `lr_too_high` | Exploding gradients — all layers show `is_exploding: True`, NaN in error log |
|
| 79 |
+
| `task_002` | Easy | `vanishing_gradients` | Vanishing gradients — deeper layers show `is_vanishing: True`, flat loss curve |
|
| 80 |
| `task_003` | Medium | `data_leakage` | Silent data leakage — suspiciously high val accuracy, `class_overlap_score > 0.5` |
|
| 81 |
+
| `task_004` | Medium | `overfitting` | Train-val divergence — loss approaches 0 while val loss climbs |
|
| 82 |
| `task_005` | Hard | `batchnorm_eval_mode` | Model in eval mode with compound red herrings (FC gradient spike, GPU 91%, near-vanishing conv1) |
|
| 83 |
+
| `task_006` | Hard | `code_bug` | PyTorch code bug — agent must read and fix actual Python code (4 bug variants) |
|
| 84 |
|
| 85 |
## Baseline Scores
|
| 86 |
|
| 87 |
+
Rule-based heuristic baseline (deterministic, no API key, bit-exact reproducible):
|
| 88 |
|
| 89 |
+
| Task | Score | Notes |
|
| 90 |
+
|------|-------|-------|
|
| 91 |
+
| `task_001` | 1.00 | Direct signal: `is_exploding` on all layers |
|
| 92 |
+
| `task_002` | 1.00 | Direct signal: `is_vanishing` on deeper layers |
|
| 93 |
+
| `task_003` | 1.00 | `class_overlap_score > 0.5` triggers correct path |
|
| 94 |
+
| `task_004` | 0.45 | Heuristic must rule out leakage first |
|
| 95 |
+
| `task_005` | 0.35 | Fixed investigation order misses eval mode, diagnoses overfitting |
|
| 96 |
+
| `task_006` | 1.00 | Pattern-matching catches 2 of 4 bug variants |
|
| 97 |
|
| 98 |
## Setup
|
| 99 |
|
|
|
|
| 133 |
|
| 134 |
| Endpoint | Method | Description |
|
| 135 |
|----------|--------|-------------|
|
| 136 |
+
| `/health` | GET | `{"status": "ready", "tasks": 6}` |
|
| 137 |
| `/tasks` | GET | Task list with action schema |
|
| 138 |
| `/grader` | POST | Grader score for last completed episode |
|
| 139 |
+
| `/baseline` | POST | Run baseline, return scores for all 6 tasks |
|
| 140 |
+
| `/dashboard` | GET | Live diagnostic dashboard (Plotly.js, 4-panel) |
|
| 141 |
| `/ws` | WebSocket | Primary agent interface |
|
| 142 |
| `/reset` | POST | Reset environment (framework) |
|
| 143 |
| `/step` | POST | Execute action (framework) |
|
baseline_heuristic.py
CHANGED
|
@@ -14,12 +14,17 @@ import argparse
|
|
| 14 |
import json
|
| 15 |
import sys
|
| 16 |
|
| 17 |
-
from ml_training_debugger.
|
| 18 |
-
from ml_training_debugger.models import EpisodeState, MLTrainingAction, MLTrainingObservation
|
| 19 |
-
from ml_training_debugger.scenarios import sample_scenario
|
| 20 |
from server.environment import MLTrainingEnvironment
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
|
|
@@ -175,7 +180,7 @@ def main() -> None:
|
|
| 175 |
args = parser.parse_args()
|
| 176 |
|
| 177 |
scores: dict[str, float] = {}
|
| 178 |
-
for task_id in
|
| 179 |
score = run_heuristic_episode(task_id)
|
| 180 |
scores[task_id] = round(score, 4)
|
| 181 |
|
|
|
|
| 14 |
import json
|
| 15 |
import sys
|
| 16 |
|
| 17 |
+
from ml_training_debugger.models import MLTrainingAction
|
|
|
|
|
|
|
| 18 |
from server.environment import MLTrainingEnvironment
|
| 19 |
|
| 20 |
+
ALL_TASKS = [
|
| 21 |
+
"task_001",
|
| 22 |
+
"task_002",
|
| 23 |
+
"task_003",
|
| 24 |
+
"task_004",
|
| 25 |
+
"task_005",
|
| 26 |
+
"task_006",
|
| 27 |
+
]
|
| 28 |
|
| 29 |
|
| 30 |
def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
|
|
|
|
| 180 |
args = parser.parse_args()
|
| 181 |
|
| 182 |
scores: dict[str, float] = {}
|
| 183 |
+
for task_id in ALL_TASKS:
|
| 184 |
score = run_heuristic_episode(task_id)
|
| 185 |
scores[task_id] = round(score, 4)
|
| 186 |
|
baseline_inference.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""LLM baseline agent using OpenAI GPT-4o.
|
| 3 |
+
|
| 4 |
+
Optional — requires OPENAI_API_KEY environment variable.
|
| 5 |
+
Uses temperature=0.0 and seed=42 for near-deterministic behavior.
|
| 6 |
+
Spec reference: Section 17.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
OPENAI_API_KEY=... python baseline_inference.py [--url http://localhost:7860]
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from openai import OpenAI
|
| 21 |
+
except ImportError:
|
| 22 |
+
print("Error: openai package not installed. Run: pip install openai")
|
| 23 |
+
sys.exit(1)
|
| 24 |
+
|
| 25 |
+
from ml_training_debugger.models import MLTrainingAction
|
| 26 |
+
from server.environment import MLTrainingEnvironment
|
| 27 |
+
|
| 28 |
+
ALL_TASKS = [
|
| 29 |
+
"task_001",
|
| 30 |
+
"task_002",
|
| 31 |
+
"task_003",
|
| 32 |
+
"task_004",
|
| 33 |
+
"task_005",
|
| 34 |
+
"task_006",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
SYSTEM_PROMPT = """You are an expert ML engineer debugging a PyTorch training run.
|
| 38 |
+
You are interacting with an environment that simulates a broken training job.
|
| 39 |
+
|
| 40 |
+
Available actions (respond with JSON):
|
| 41 |
+
- {"action_type": "inspect_gradients"} - View gradient statistics per layer
|
| 42 |
+
- {"action_type": "inspect_data_batch"} - View data batch statistics
|
| 43 |
+
- {"action_type": "inspect_model_modes"} - View model layer modes (train/eval)
|
| 44 |
+
- {"action_type": "inspect_model_weights"} - View model weight statistics
|
| 45 |
+
- {"action_type": "inspect_code"} - View PyTorch training code
|
| 46 |
+
- {"action_type": "modify_config", "target": "<field>", "value": <val>} - Change a hyperparameter
|
| 47 |
+
- {"action_type": "add_callback"} - Add gradient clipping/scheduler
|
| 48 |
+
- {"action_type": "patch_data_loader"} - Fix data pipeline issues
|
| 49 |
+
- {"action_type": "fix_model_mode"} - Call model.train()
|
| 50 |
+
- {"action_type": "fix_code", "line": <int>, "replacement": "<code>"} - Fix a code line
|
| 51 |
+
- {"action_type": "restart_run"} - Restart training (requires a fix first)
|
| 52 |
+
- {"action_type": "mark_diagnosed", "diagnosis": "<cause>"} - Submit diagnosis
|
| 53 |
+
|
| 54 |
+
Valid diagnoses: lr_too_high, vanishing_gradients, data_leakage, overfitting, batchnorm_eval_mode, code_bug
|
| 55 |
+
|
| 56 |
+
Strategy:
|
| 57 |
+
1. First investigate by inspecting gradients, data, and model modes
|
| 58 |
+
2. Form a hypothesis based on the evidence
|
| 59 |
+
3. Apply the correct fix
|
| 60 |
+
4. Restart training to verify
|
| 61 |
+
5. Submit your diagnosis
|
| 62 |
+
|
| 63 |
+
Respond with ONLY a valid JSON action object, no explanation."""
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def run_llm_episode(task_id: str, client: OpenAI) -> float:
|
| 67 |
+
"""Run one LLM agent episode."""
|
| 68 |
+
env = MLTrainingEnvironment()
|
| 69 |
+
obs = env.reset(seed=42, episode_id=f"llm_{task_id}", task_id=task_id)
|
| 70 |
+
|
| 71 |
+
messages = [
|
| 72 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 73 |
+
{"role": "user", "content": f"New episode started. Observation:\n{json.dumps(obs.model_dump(), indent=2, default=str)[:3000]}"},
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
for step in range(20):
|
| 77 |
+
if obs.done:
|
| 78 |
+
break
|
| 79 |
+
|
| 80 |
+
response = client.chat.completions.create(
|
| 81 |
+
model="gpt-4o",
|
| 82 |
+
messages=messages,
|
| 83 |
+
temperature=0.0,
|
| 84 |
+
seed=42,
|
| 85 |
+
max_tokens=200,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
action_text = response.choices[0].message.content.strip()
|
| 89 |
+
messages.append({"role": "assistant", "content": action_text})
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
action_data = json.loads(action_text)
|
| 93 |
+
action = MLTrainingAction(**action_data)
|
| 94 |
+
except (json.JSONDecodeError, Exception) as e:
|
| 95 |
+
messages.append({"role": "user", "content": f"Invalid action: {e}. Try again with valid JSON."})
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
obs = env.step(action)
|
| 99 |
+
obs_summary = {
|
| 100 |
+
"reward": obs.reward,
|
| 101 |
+
"done": obs.done,
|
| 102 |
+
"step": obs.episode_state.step_count,
|
| 103 |
+
"available_actions": obs.available_actions,
|
| 104 |
+
"error_log": obs.error_log,
|
| 105 |
+
}
|
| 106 |
+
if obs.gradient_stats:
|
| 107 |
+
obs_summary["gradient_stats"] = [
|
| 108 |
+
{"layer": g.layer_name, "mean_norm": round(g.mean_norm, 4), "exploding": g.is_exploding, "vanishing": g.is_vanishing}
|
| 109 |
+
for g in obs.gradient_stats
|
| 110 |
+
]
|
| 111 |
+
if obs.data_batch_stats:
|
| 112 |
+
obs_summary["data_overlap"] = obs.data_batch_stats.class_overlap_score
|
| 113 |
+
if obs.model_mode_info:
|
| 114 |
+
obs_summary["model_modes"] = obs.model_mode_info
|
| 115 |
+
if obs.code_snippet:
|
| 116 |
+
obs_summary["code"] = obs.code_snippet.code[:500]
|
| 117 |
+
|
| 118 |
+
messages.append({"role": "user", "content": f"Observation:\n{json.dumps(obs_summary, indent=2, default=str)}"})
|
| 119 |
+
|
| 120 |
+
session = env._get_session()
|
| 121 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def main() -> None:
|
| 125 |
+
parser = argparse.ArgumentParser(description="LLM baseline agent (GPT-4o)")
|
| 126 |
+
parser.add_argument("--url", default="http://localhost:7860")
|
| 127 |
+
args = parser.parse_args()
|
| 128 |
+
|
| 129 |
+
api_key = os.environ.get("OPENAI_API_KEY")
|
| 130 |
+
if not api_key:
|
| 131 |
+
print("Error: OPENAI_API_KEY environment variable not set")
|
| 132 |
+
sys.exit(1)
|
| 133 |
+
|
| 134 |
+
client = OpenAI(api_key=api_key)
|
| 135 |
+
scores: dict[str, float] = {}
|
| 136 |
+
|
| 137 |
+
for task_id in ALL_TASKS:
|
| 138 |
+
try:
|
| 139 |
+
score = run_llm_episode(task_id, client)
|
| 140 |
+
scores[task_id] = round(score, 4)
|
| 141 |
+
print(f" {task_id}: {score:.4f}", file=sys.stderr)
|
| 142 |
+
except Exception as e:
|
| 143 |
+
print(f" {task_id}: ERROR — {e}", file=sys.stderr)
|
| 144 |
+
scores[task_id] = 0.0
|
| 145 |
+
|
| 146 |
+
print(json.dumps(scores, indent=2))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
main()
|
ml_training_debugger/pytorch_engine.py
CHANGED
|
@@ -80,15 +80,24 @@ def create_model_and_inject_fault(
|
|
| 80 |
loss.backward()
|
| 81 |
|
| 82 |
elif scenario.root_cause.value == "vanishing_gradients":
|
| 83 |
-
#
|
|
|
|
| 84 |
model.train()
|
| 85 |
optimizer = torch.optim.SGD(model.parameters(), lr=scenario.learning_rate)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
elif scenario.root_cause.value == "data_leakage":
|
| 94 |
# Normal model — no gradient anomaly
|
|
|
|
| 80 |
loss.backward()
|
| 81 |
|
| 82 |
elif scenario.root_cause.value == "vanishing_gradients":
|
| 83 |
+
# Simulate vanishing gradients: run forward/backward then scale grads
|
| 84 |
+
# to simulate gradient decay through deep layers
|
| 85 |
model.train()
|
| 86 |
optimizer = torch.optim.SGD(model.parameters(), lr=scenario.learning_rate)
|
| 87 |
+
optimizer.zero_grad()
|
| 88 |
+
output = model(batch_x)
|
| 89 |
+
loss = criterion(output, batch_y)
|
| 90 |
+
loss.backward()
|
| 91 |
+
# Scale gradients to simulate vanishing: deeper layers get smaller grads
|
| 92 |
+
depth_mult = scenario.depth_multiplier
|
| 93 |
+
layer_idx = 0
|
| 94 |
+
for name, param in model.named_parameters():
|
| 95 |
+
if param.grad is not None:
|
| 96 |
+
decay = torch.tensor(1e-7) * torch.exp(
|
| 97 |
+
torch.tensor(-depth_mult * layer_idx)
|
| 98 |
+
)
|
| 99 |
+
param.grad.data = param.grad.data * decay
|
| 100 |
+
layer_idx += 1
|
| 101 |
|
| 102 |
elif scenario.root_cause.value == "data_leakage":
|
| 103 |
# Normal model — no gradient anomaly
|
openenv.yaml
CHANGED
|
@@ -11,8 +11,8 @@ description: |
|
|
| 11 |
An AI agent investigates, diagnoses, fixes, and verifies broken
|
| 12 |
training runs using real torch.nn.Module models, torch.autograd
|
| 13 |
gradients, state_dict() weight inspection, and PyTorch code-level
|
| 14 |
-
debugging.
|
| 15 |
-
reward shaping.
|
| 16 |
framework: openenv
|
| 17 |
tags:
|
| 18 |
- ml-debugging
|
|
@@ -20,26 +20,55 @@ tags:
|
|
| 20 |
- reinforcement-learning
|
| 21 |
- root-cause-analysis
|
| 22 |
- fault-injection
|
|
|
|
| 23 |
- openenv
|
| 24 |
|
| 25 |
observation_space:
|
| 26 |
type: MLTrainingObservation
|
| 27 |
-
description: "Training run snapshot with progressive reveal — gradients, weights, data stats, model modes revealed on inspection"
|
| 28 |
|
| 29 |
action_space:
|
| 30 |
type: MLTrainingAction
|
| 31 |
-
description: "Investigation, fix, and diagnosis actions with dynamic availability"
|
| 32 |
|
| 33 |
tasks:
|
| 34 |
- id: task_001
|
| 35 |
difficulty: easy
|
| 36 |
max_steps: 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
- id: task_003
|
| 38 |
difficulty: medium
|
| 39 |
max_steps: 25
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
- id: task_005
|
| 41 |
difficulty: hard
|
| 42 |
max_steps: 30
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
reward:
|
| 45 |
range: [-1.0, 1.0]
|
|
@@ -56,3 +85,4 @@ endpoints:
|
|
| 56 |
grader: "POST /grader"
|
| 57 |
baseline: "POST /baseline"
|
| 58 |
health: "GET /health"
|
|
|
|
|
|
| 11 |
An AI agent investigates, diagnoses, fixes, and verifies broken
|
| 12 |
training runs using real torch.nn.Module models, torch.autograd
|
| 13 |
gradients, state_dict() weight inspection, and PyTorch code-level
|
| 14 |
+
debugging. 6 tasks across 3 difficulty tiers with context-gated
|
| 15 |
+
reward shaping and a live diagnostic dashboard.
|
| 16 |
framework: openenv
|
| 17 |
tags:
|
| 18 |
- ml-debugging
|
|
|
|
| 20 |
- reinforcement-learning
|
| 21 |
- root-cause-analysis
|
| 22 |
- fault-injection
|
| 23 |
+
- code-debugging
|
| 24 |
- openenv
|
| 25 |
|
| 26 |
observation_space:
|
| 27 |
type: MLTrainingObservation
|
| 28 |
+
description: "Training run snapshot with progressive reveal — gradients, weights, data stats, model modes, and code snippets revealed on inspection"
|
| 29 |
|
| 30 |
action_space:
|
| 31 |
type: MLTrainingAction
|
| 32 |
+
description: "Investigation, fix, code-fix, and diagnosis actions with dynamic availability"
|
| 33 |
|
| 34 |
tasks:
|
| 35 |
- id: task_001
|
| 36 |
difficulty: easy
|
| 37 |
max_steps: 20
|
| 38 |
+
param_ranges:
|
| 39 |
+
learning_rate: [0.05, 0.08, 0.10, 0.15, 0.30]
|
| 40 |
+
|
| 41 |
+
- id: task_002
|
| 42 |
+
difficulty: easy
|
| 43 |
+
max_steps: 20
|
| 44 |
+
param_ranges:
|
| 45 |
+
learning_rate: [1e-6, 5e-6, 1e-5]
|
| 46 |
+
depth_multiplier: [1.0, 1.5, 2.0]
|
| 47 |
+
|
| 48 |
- id: task_003
|
| 49 |
difficulty: medium
|
| 50 |
max_steps: 25
|
| 51 |
+
param_ranges:
|
| 52 |
+
leakage_pct: [0.12, 0.18, 0.22, 0.28]
|
| 53 |
+
|
| 54 |
+
- id: task_004
|
| 55 |
+
difficulty: medium
|
| 56 |
+
max_steps: 25
|
| 57 |
+
param_ranges:
|
| 58 |
+
weight_decay: [0.0, 0.0001, 0.001]
|
| 59 |
+
divergence_epoch: [5, 8, 12]
|
| 60 |
+
|
| 61 |
- id: task_005
|
| 62 |
difficulty: hard
|
| 63 |
max_steps: 30
|
| 64 |
+
param_ranges:
|
| 65 |
+
red_herring_intensity: [0.8, 2.5]
|
| 66 |
+
|
| 67 |
+
- id: task_006
|
| 68 |
+
difficulty: hard
|
| 69 |
+
max_steps: 30
|
| 70 |
+
param_ranges:
|
| 71 |
+
bug_type: [eval_mode, detach_loss, zero_grad_missing, inplace_relu]
|
| 72 |
|
| 73 |
reward:
|
| 74 |
range: [-1.0, 1.0]
|
|
|
|
| 85 |
grader: "POST /grader"
|
| 86 |
baseline: "POST /baseline"
|
| 87 |
health: "GET /health"
|
| 88 |
+
dashboard: "GET /dashboard"
|
pyproject.toml
CHANGED
|
@@ -11,6 +11,9 @@ dependencies = [
|
|
| 11 |
"uvicorn",
|
| 12 |
]
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
[project.optional-dependencies]
|
| 15 |
dev = [
|
| 16 |
"pytest",
|
|
|
|
| 11 |
"uvicorn",
|
| 12 |
]
|
| 13 |
|
| 14 |
+
[project.scripts]
|
| 15 |
+
server = "server.app:main"
|
| 16 |
+
|
| 17 |
[project.optional-dependencies]
|
| 18 |
dev = [
|
| 19 |
"pytest",
|
requirements.txt
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
torch
|
| 2 |
openenv-core
|
| 3 |
pydantic>=2.0
|
| 4 |
fastapi
|
|
|
|
|
|
|
| 1 |
openenv-core
|
| 2 |
pydantic>=2.0
|
| 3 |
fastapi
|
server/app.py
CHANGED
|
@@ -1,32 +1,60 @@
|
|
| 1 |
"""FastAPI app — openenv create_app() + custom hackathon routes.
|
| 2 |
|
| 3 |
-
Spec reference: Sections 9, 14.
|
| 4 |
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import asyncio
|
|
|
|
| 9 |
import logging
|
|
|
|
| 10 |
from typing import Optional
|
| 11 |
|
| 12 |
from fastapi import FastAPI
|
| 13 |
-
from fastapi.responses import JSONResponse
|
| 14 |
from openenv.core.env_server.http_server import create_app
|
| 15 |
|
| 16 |
from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
|
|
|
|
| 17 |
from server.environment import MLTrainingEnvironment
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
{"id": "task_001", "difficulty": "easy", "max_steps": 20},
|
|
|
|
| 28 |
{"id": "task_003", "difficulty": "medium", "max_steps": 25},
|
|
|
|
| 29 |
{"id": "task_005", "difficulty": "hard", "max_steps": 30},
|
|
|
|
| 30 |
]
|
| 31 |
|
| 32 |
# create_app takes the class (factory), not an instance
|
|
@@ -39,27 +67,34 @@ app: FastAPI = create_app(
|
|
| 39 |
)
|
| 40 |
|
| 41 |
# Override framework's /health route with our custom version
|
| 42 |
-
# Remove the framework's health route first
|
| 43 |
app.routes[:] = [
|
| 44 |
r for r in app.routes if not (hasattr(r, "path") and r.path == "/health")
|
| 45 |
]
|
| 46 |
|
| 47 |
-
#
|
| 48 |
_baseline_lock = asyncio.Lock()
|
| 49 |
-
_baseline_running = False
|
| 50 |
|
| 51 |
|
| 52 |
@app.get("/health")
|
| 53 |
def health_check() -> dict:
|
| 54 |
"""Health check — required by hackathon auto-validator."""
|
| 55 |
-
return {"status": "ready", "tasks": len(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
@app.get("/tasks")
|
| 59 |
def get_tasks() -> list[dict]:
|
| 60 |
"""Return task list with IDs, difficulties, and action schema."""
|
| 61 |
schema = MLTrainingAction.model_json_schema()
|
| 62 |
-
return [{**task, "action_schema": schema} for task in
|
| 63 |
|
| 64 |
|
| 65 |
@app.post("/grader")
|
|
@@ -68,14 +103,8 @@ def post_grader(session_id: Optional[str] = None) -> dict:
|
|
| 68 |
|
| 69 |
Edge cases per spec Section 14:
|
| 70 |
- No episode completed → {"score": null, "error": "no_completed_episode"}
|
| 71 |
-
- Episode in progress → {"score": null, "error": "episode_in_progress"}
|
| 72 |
- Episode completed → {"score": float, "task_id": str, "steps": int}
|
| 73 |
"""
|
| 74 |
-
# Try to find the environment instance
|
| 75 |
-
# The framework manages environment instances internally,
|
| 76 |
-
# so we use the internal baseline results for the /grader endpoint
|
| 77 |
-
from server._baseline_results import get_last_grader_result
|
| 78 |
-
|
| 79 |
result = get_last_grader_result(session_id)
|
| 80 |
if result is None:
|
| 81 |
return {"score": None, "error": "no_completed_episode"}
|
|
@@ -86,36 +115,30 @@ def post_grader(session_id: Optional[str] = None) -> dict:
|
|
| 86 |
async def post_baseline():
|
| 87 |
"""Trigger baseline run, return scores for all tasks.
|
| 88 |
|
| 89 |
-
Returns 409 if already running.
|
| 90 |
"""
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
if _baseline_running:
|
| 94 |
return JSONResponse(
|
| 95 |
status_code=409,
|
| 96 |
content={"error": "baseline_in_progress"},
|
| 97 |
)
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
| 102 |
return {"scores": scores}
|
| 103 |
-
finally:
|
| 104 |
-
_baseline_running = False
|
| 105 |
-
|
| 106 |
|
| 107 |
-
async def _run_baseline() -> dict[str, float]:
|
| 108 |
-
"""Run the rule-based baseline internally."""
|
| 109 |
|
|
|
|
|
|
|
| 110 |
scores: dict[str, float] = {}
|
| 111 |
|
| 112 |
-
for task_info in
|
| 113 |
task_id = task_info["id"]
|
| 114 |
env = MLTrainingEnvironment()
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# Run heuristic decision tree
|
| 118 |
-
score = _run_heuristic_episode(env, obs, task_id)
|
| 119 |
scores[task_id] = round(score, 4)
|
| 120 |
|
| 121 |
return scores
|
|
@@ -123,73 +146,66 @@ async def _run_baseline() -> dict[str, float]:
|
|
| 123 |
|
| 124 |
def _run_heuristic_episode(
|
| 125 |
env: MLTrainingEnvironment,
|
| 126 |
-
obs: MLTrainingObservation,
|
| 127 |
task_id: str,
|
| 128 |
) -> float:
|
| 129 |
-
"""Run one heuristic baseline episode. Returns grader score.
|
|
|
|
|
|
|
|
|
|
| 130 |
# Step 1: inspect_gradients
|
| 131 |
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 132 |
|
| 133 |
-
# Check for exploding gradients
|
| 134 |
if obs.gradient_stats:
|
|
|
|
| 135 |
if any(g.is_exploding for g in obs.gradient_stats):
|
| 136 |
-
|
| 137 |
MLTrainingAction(
|
| 138 |
action_type="modify_config",
|
| 139 |
target="learning_rate",
|
| 140 |
value=0.001,
|
| 141 |
)
|
| 142 |
)
|
| 143 |
-
|
| 144 |
-
|
| 145 |
MLTrainingAction(
|
| 146 |
action_type="mark_diagnosed",
|
| 147 |
diagnosis="lr_too_high",
|
| 148 |
)
|
| 149 |
)
|
| 150 |
-
|
| 151 |
-
if session and session.last_score is not None:
|
| 152 |
-
return session.last_score
|
| 153 |
-
return 0.0
|
| 154 |
|
| 155 |
-
# Check
|
| 156 |
if any(g.is_vanishing for g in obs.gradient_stats):
|
| 157 |
-
|
| 158 |
MLTrainingAction(
|
| 159 |
action_type="modify_config",
|
| 160 |
target="learning_rate",
|
| 161 |
value=0.01,
|
| 162 |
)
|
| 163 |
)
|
| 164 |
-
|
| 165 |
-
|
| 166 |
MLTrainingAction(
|
| 167 |
action_type="mark_diagnosed",
|
| 168 |
diagnosis="vanishing_gradients",
|
| 169 |
)
|
| 170 |
)
|
| 171 |
-
|
| 172 |
-
if session and session.last_score is not None:
|
| 173 |
-
return session.last_score
|
| 174 |
-
return 0.0
|
| 175 |
|
| 176 |
# Step 2: inspect_data_batch
|
| 177 |
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|
| 178 |
if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
MLTrainingAction(
|
| 183 |
action_type="mark_diagnosed",
|
| 184 |
diagnosis="data_leakage",
|
| 185 |
)
|
| 186 |
)
|
| 187 |
-
|
| 188 |
-
if session and session.last_score is not None:
|
| 189 |
-
return session.last_score
|
| 190 |
-
return 0.0
|
| 191 |
|
| 192 |
-
# Check
|
| 193 |
if obs.val_loss_history and len(obs.val_loss_history) >= 10:
|
| 194 |
early = sum(obs.val_loss_history[:5]) / 5
|
| 195 |
late = sum(obs.val_loss_history[-5:]) / 5
|
|
@@ -198,50 +214,43 @@ def _run_heuristic_episode(
|
|
| 198 |
and obs.data_batch_stats
|
| 199 |
and obs.data_batch_stats.class_overlap_score < 0.1
|
| 200 |
):
|
| 201 |
-
|
| 202 |
MLTrainingAction(
|
| 203 |
action_type="modify_config",
|
| 204 |
target="weight_decay",
|
| 205 |
value=0.01,
|
| 206 |
)
|
| 207 |
)
|
| 208 |
-
|
| 209 |
-
|
| 210 |
MLTrainingAction(
|
| 211 |
action_type="mark_diagnosed",
|
| 212 |
diagnosis="overfitting",
|
| 213 |
)
|
| 214 |
)
|
| 215 |
-
|
| 216 |
-
if session and session.last_score is not None:
|
| 217 |
-
return session.last_score
|
| 218 |
-
return 0.0
|
| 219 |
|
| 220 |
# Step 3: inspect_model_modes
|
| 221 |
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
| 222 |
if obs.model_mode_info:
|
| 223 |
has_eval = any(v == "eval" for v in obs.model_mode_info.values())
|
| 224 |
if has_eval:
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
MLTrainingAction(
|
| 229 |
action_type="mark_diagnosed",
|
| 230 |
diagnosis="batchnorm_eval_mode",
|
| 231 |
)
|
| 232 |
)
|
| 233 |
-
|
| 234 |
-
if session and session.last_score is not None:
|
| 235 |
-
return session.last_score
|
| 236 |
-
return 0.0
|
| 237 |
|
| 238 |
# Step 4: inspect_code (for Task 6)
|
| 239 |
obs = env.step(MLTrainingAction(action_type="inspect_code"))
|
| 240 |
if obs.code_snippet:
|
| 241 |
-
# Simple pattern matching for known bugs
|
| 242 |
code = obs.code_snippet.code
|
| 243 |
if "model.eval()" in code and "model.train()" not in code:
|
| 244 |
-
|
| 245 |
MLTrainingAction(
|
| 246 |
action_type="fix_code",
|
| 247 |
line=5,
|
|
@@ -249,39 +258,51 @@ def _run_heuristic_episode(
|
|
| 249 |
)
|
| 250 |
)
|
| 251 |
elif ".detach()" in code:
|
| 252 |
-
|
| 253 |
MLTrainingAction(
|
| 254 |
action_type="fix_code",
|
| 255 |
line=14,
|
| 256 |
replacement=" loss = criterion(output, batch_y)",
|
| 257 |
)
|
| 258 |
)
|
| 259 |
-
else:
|
| 260 |
-
# Can't reliably fix — just diagnose
|
| 261 |
-
pass
|
| 262 |
|
| 263 |
-
if
|
| 264 |
-
|
|
|
|
|
|
|
| 265 |
|
| 266 |
-
|
| 267 |
MLTrainingAction(
|
| 268 |
action_type="mark_diagnosed",
|
| 269 |
diagnosis="code_bug",
|
| 270 |
)
|
| 271 |
)
|
| 272 |
-
|
| 273 |
-
if session and session.last_score is not None:
|
| 274 |
-
return session.last_score
|
| 275 |
-
return 0.0
|
| 276 |
|
| 277 |
# Fallback
|
| 278 |
-
|
| 279 |
MLTrainingAction(
|
| 280 |
action_type="mark_diagnosed",
|
| 281 |
diagnosis="overfitting",
|
| 282 |
)
|
| 283 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
session = env._get_session()
|
| 285 |
if session and session.last_score is not None:
|
| 286 |
return session.last_score
|
| 287 |
return 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""FastAPI app — openenv create_app() + custom hackathon routes.
|
| 2 |
|
| 3 |
+
Spec reference: Sections 9, 14, 15.
|
| 4 |
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import asyncio
|
| 9 |
+
import json
|
| 10 |
import logging
|
| 11 |
+
import sys
|
| 12 |
from typing import Optional
|
| 13 |
|
| 14 |
from fastapi import FastAPI
|
| 15 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
| 16 |
from openenv.core.env_server.http_server import create_app
|
| 17 |
|
| 18 |
from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
|
| 19 |
+
from server._baseline_results import get_last_grader_result
|
| 20 |
from server.environment import MLTrainingEnvironment
|
| 21 |
|
| 22 |
+
|
| 23 |
+
# Structured JSON logging (Spec S15)
|
| 24 |
+
class JSONFormatter(logging.Formatter):
|
| 25 |
+
def format(self, record: logging.LogRecord) -> str:
|
| 26 |
+
log_data = {
|
| 27 |
+
"time": self.formatTime(record),
|
| 28 |
+
"level": record.levelname,
|
| 29 |
+
"msg": record.getMessage(),
|
| 30 |
+
}
|
| 31 |
+
if hasattr(record, "session_id"):
|
| 32 |
+
log_data["session_id"] = record.session_id
|
| 33 |
+
if hasattr(record, "task_id"):
|
| 34 |
+
log_data["task_id"] = record.task_id
|
| 35 |
+
if hasattr(record, "step_count"):
|
| 36 |
+
log_data["step_count"] = record.step_count
|
| 37 |
+
if hasattr(record, "action_type"):
|
| 38 |
+
log_data["action_type"] = record.action_type
|
| 39 |
+
if hasattr(record, "score"):
|
| 40 |
+
log_data["score"] = record.score
|
| 41 |
+
return json.dumps(log_data)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
handler = logging.StreamHandler(sys.stdout)
|
| 45 |
+
handler.setFormatter(JSONFormatter())
|
| 46 |
+
logging.root.handlers = [handler]
|
| 47 |
+
logging.root.setLevel(logging.INFO)
|
| 48 |
logger = logging.getLogger(__name__)
|
| 49 |
|
| 50 |
+
# All 6 tasks (Spec S11)
|
| 51 |
+
ALL_TASKS = [
|
| 52 |
{"id": "task_001", "difficulty": "easy", "max_steps": 20},
|
| 53 |
+
{"id": "task_002", "difficulty": "easy", "max_steps": 20},
|
| 54 |
{"id": "task_003", "difficulty": "medium", "max_steps": 25},
|
| 55 |
+
{"id": "task_004", "difficulty": "medium", "max_steps": 25},
|
| 56 |
{"id": "task_005", "difficulty": "hard", "max_steps": 30},
|
| 57 |
+
{"id": "task_006", "difficulty": "hard", "max_steps": 30},
|
| 58 |
]
|
| 59 |
|
| 60 |
# create_app takes the class (factory), not an instance
|
|
|
|
| 67 |
)
|
| 68 |
|
| 69 |
# Override framework's /health route with our custom version
|
|
|
|
| 70 |
app.routes[:] = [
|
| 71 |
r for r in app.routes if not (hasattr(r, "path") and r.path == "/health")
|
| 72 |
]
|
| 73 |
|
| 74 |
+
# Baseline lock — rejects concurrent /baseline runs on the event loop; note asyncio.Lock is not thread-safe (Fix #14)
|
| 75 |
_baseline_lock = asyncio.Lock()
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
@app.get("/health")
def health_check() -> dict:
    """Health check — required by hackathon auto-validator.

    Returns:
        A dict with a fixed ``"ready"`` status and the number of known tasks.
    """
    task_count = len(ALL_TASKS)
    return {"status": "ready", "tasks": task_count}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@app.get("/dashboard", response_class=HTMLResponse)
def get_dashboard() -> str:
    """Serve live diagnostic dashboard. Spec Section 19.

    Returns:
        The contents of ``dashboard.html`` located next to this module.
    """
    import pathlib

    html_path = pathlib.Path(__file__).parent / "dashboard.html"
    # The page contains non-ASCII characters, so decode explicitly as UTF-8
    # instead of relying on the platform's locale default.
    return html_path.read_text(encoding="utf-8")
|
| 91 |
|
| 92 |
|
| 93 |
@app.get("/tasks")
def get_tasks() -> list[dict]:
    """Return task list with IDs, difficulties, and action schema.

    Each entry is a copy of the task's ``ALL_TASKS`` record with one extra
    ``action_schema`` key holding the JSON schema of ``MLTrainingAction``.
    """
    action_schema = MLTrainingAction.model_json_schema()
    return [dict(entry, action_schema=action_schema) for entry in ALL_TASKS]
|
| 98 |
|
| 99 |
|
| 100 |
@app.post("/grader")
|
|
|
|
| 103 |
|
| 104 |
Edge cases per spec Section 14:
|
| 105 |
- No episode completed → {"score": null, "error": "no_completed_episode"}
|
|
|
|
| 106 |
- Episode completed → {"score": float, "task_id": str, "steps": int}
|
| 107 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
result = get_last_grader_result(session_id)
|
| 109 |
if result is None:
|
| 110 |
return {"score": None, "error": "no_completed_episode"}
|
|
|
|
| 115 |
async def post_baseline():
    """Trigger baseline run, return scores for all tasks.

    Only one baseline run may be active at a time: while ``_baseline_lock`` is
    held, further requests get a 409 response instead of queueing. The lock is
    an ``asyncio.Lock``, so it serializes coroutines on the event loop; it
    does not by itself provide thread safety.

    Returns:
        ``{"scores": ...}`` with the result of ``_run_baseline_sync``, or a
        409 ``JSONResponse`` carrying ``{"error": "baseline_in_progress"}``.
    """
    if _baseline_lock.locked():
        return JSONResponse(
            status_code=409,
            content={"error": "baseline_in_progress"},
        )

    async with _baseline_lock:
        # The baseline is synchronous, so hand it to the default executor
        # instead of running it on the event loop thread. get_running_loop()
        # is the recommended way to obtain the loop from inside a coroutine.
        loop = asyncio.get_running_loop()
        scores = await loop.run_in_executor(None, _run_baseline_sync)
    return {"scores": scores}
|
|
|
|
|
|
|
|
|
|
| 131 |
|
|
|
|
|
|
|
| 132 |
|
| 133 |
+
def _run_baseline_sync() -> dict[str, float]:
    """Run the rule-based baseline synchronously.

    Every task in ``ALL_TASKS`` gets a fresh environment reset with seed 42.

    Returns:
        Mapping of task id to the heuristic episode score, rounded to four
        decimal places.
    """
    results: dict[str, float] = {}
    for task_id in (entry["id"] for entry in ALL_TASKS):
        environment = MLTrainingEnvironment()
        environment.reset(seed=42, episode_id=f"baseline_{task_id}", task_id=task_id)
        results[task_id] = round(_run_heuristic_episode(environment, task_id), 4)
    return results
|
|
|
|
| 146 |
|
| 147 |
def _run_heuristic_episode(
|
| 148 |
env: MLTrainingEnvironment,
|
|
|
|
| 149 |
task_id: str,
|
| 150 |
) -> float:
|
| 151 |
+
"""Run one heuristic baseline episode. Returns grader score.
|
| 152 |
+
|
| 153 |
+
Decision tree per spec Section 17.
|
| 154 |
+
"""
|
| 155 |
# Step 1: inspect_gradients
|
| 156 |
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 157 |
|
|
|
|
| 158 |
if obs.gradient_stats:
|
| 159 |
+
# Check exploding
|
| 160 |
if any(g.is_exploding for g in obs.gradient_stats):
|
| 161 |
+
env.step(
|
| 162 |
MLTrainingAction(
|
| 163 |
action_type="modify_config",
|
| 164 |
target="learning_rate",
|
| 165 |
value=0.001,
|
| 166 |
)
|
| 167 |
)
|
| 168 |
+
env.step(MLTrainingAction(action_type="restart_run"))
|
| 169 |
+
env.step(
|
| 170 |
MLTrainingAction(
|
| 171 |
action_type="mark_diagnosed",
|
| 172 |
diagnosis="lr_too_high",
|
| 173 |
)
|
| 174 |
)
|
| 175 |
+
return _get_score(env)
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
# Check vanishing
|
| 178 |
if any(g.is_vanishing for g in obs.gradient_stats):
|
| 179 |
+
env.step(
|
| 180 |
MLTrainingAction(
|
| 181 |
action_type="modify_config",
|
| 182 |
target="learning_rate",
|
| 183 |
value=0.01,
|
| 184 |
)
|
| 185 |
)
|
| 186 |
+
env.step(MLTrainingAction(action_type="restart_run"))
|
| 187 |
+
env.step(
|
| 188 |
MLTrainingAction(
|
| 189 |
action_type="mark_diagnosed",
|
| 190 |
diagnosis="vanishing_gradients",
|
| 191 |
)
|
| 192 |
)
|
| 193 |
+
return _get_score(env)
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
# Step 2: inspect_data_batch
|
| 196 |
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|
| 197 |
if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
|
| 198 |
+
env.step(MLTrainingAction(action_type="patch_data_loader"))
|
| 199 |
+
env.step(MLTrainingAction(action_type="restart_run"))
|
| 200 |
+
env.step(
|
| 201 |
MLTrainingAction(
|
| 202 |
action_type="mark_diagnosed",
|
| 203 |
diagnosis="data_leakage",
|
| 204 |
)
|
| 205 |
)
|
| 206 |
+
return _get_score(env)
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
+
# Check overfitting (val_loss diverging)
|
| 209 |
if obs.val_loss_history and len(obs.val_loss_history) >= 10:
|
| 210 |
early = sum(obs.val_loss_history[:5]) / 5
|
| 211 |
late = sum(obs.val_loss_history[-5:]) / 5
|
|
|
|
| 214 |
and obs.data_batch_stats
|
| 215 |
and obs.data_batch_stats.class_overlap_score < 0.1
|
| 216 |
):
|
| 217 |
+
env.step(
|
| 218 |
MLTrainingAction(
|
| 219 |
action_type="modify_config",
|
| 220 |
target="weight_decay",
|
| 221 |
value=0.01,
|
| 222 |
)
|
| 223 |
)
|
| 224 |
+
env.step(MLTrainingAction(action_type="restart_run"))
|
| 225 |
+
env.step(
|
| 226 |
MLTrainingAction(
|
| 227 |
action_type="mark_diagnosed",
|
| 228 |
diagnosis="overfitting",
|
| 229 |
)
|
| 230 |
)
|
| 231 |
+
return _get_score(env)
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
# Step 3: inspect_model_modes
|
| 234 |
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
| 235 |
if obs.model_mode_info:
|
| 236 |
has_eval = any(v == "eval" for v in obs.model_mode_info.values())
|
| 237 |
if has_eval:
|
| 238 |
+
env.step(MLTrainingAction(action_type="fix_model_mode"))
|
| 239 |
+
env.step(MLTrainingAction(action_type="restart_run"))
|
| 240 |
+
env.step(
|
| 241 |
MLTrainingAction(
|
| 242 |
action_type="mark_diagnosed",
|
| 243 |
diagnosis="batchnorm_eval_mode",
|
| 244 |
)
|
| 245 |
)
|
| 246 |
+
return _get_score(env)
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
# Step 4: inspect_code (for Task 6)
|
| 249 |
obs = env.step(MLTrainingAction(action_type="inspect_code"))
|
| 250 |
if obs.code_snippet:
|
|
|
|
| 251 |
code = obs.code_snippet.code
|
| 252 |
if "model.eval()" in code and "model.train()" not in code:
|
| 253 |
+
env.step(
|
| 254 |
MLTrainingAction(
|
| 255 |
action_type="fix_code",
|
| 256 |
line=5,
|
|
|
|
| 258 |
)
|
| 259 |
)
|
| 260 |
elif ".detach()" in code:
|
| 261 |
+
env.step(
|
| 262 |
MLTrainingAction(
|
| 263 |
action_type="fix_code",
|
| 264 |
line=14,
|
| 265 |
replacement=" loss = criterion(output, batch_y)",
|
| 266 |
)
|
| 267 |
)
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
+
# Try restart if fix was applied
|
| 270 |
+
session = env._get_session()
|
| 271 |
+
if session and session.state.fix_action_taken:
|
| 272 |
+
env.step(MLTrainingAction(action_type="restart_run"))
|
| 273 |
|
| 274 |
+
env.step(
|
| 275 |
MLTrainingAction(
|
| 276 |
action_type="mark_diagnosed",
|
| 277 |
diagnosis="code_bug",
|
| 278 |
)
|
| 279 |
)
|
| 280 |
+
return _get_score(env)
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
# Fallback
|
| 283 |
+
env.step(
|
| 284 |
MLTrainingAction(
|
| 285 |
action_type="mark_diagnosed",
|
| 286 |
diagnosis="overfitting",
|
| 287 |
)
|
| 288 |
)
|
| 289 |
+
return _get_score(env)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def _get_score(env: MLTrainingEnvironment) -> float:
|
| 293 |
+
"""Extract the grader score from the environment."""
|
| 294 |
session = env._get_session()
|
| 295 |
if session and session.last_score is not None:
|
| 296 |
return session.last_score
|
| 297 |
return 0.0
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Entry point for running the server.

    Args:
        host: Interface to bind to. Defaults to all interfaces.
        port: TCP port to listen on. Defaults to 7860.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
if __name__ == "__main__":
|
| 308 |
+
main()
|
server/dashboard.html
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>PyTorch Training Debugger — Live Dashboard</title>
|
| 7 |
+
<script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
|
| 8 |
+
<style>
|
| 9 |
+
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 10 |
+
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; background: #0d1117; color: #c9d1d9; }
|
| 11 |
+
.header { background: #161b22; padding: 16px 24px; border-bottom: 1px solid #30363d; display: flex; align-items: center; gap: 16px; }
|
| 12 |
+
.header h1 { font-size: 20px; font-weight: 600; }
|
| 13 |
+
.header .status { padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 500; }
|
| 14 |
+
.status.connected { background: #238636; color: #fff; }
|
| 15 |
+
.status.disconnected { background: #da3633; color: #fff; }
|
| 16 |
+
.grid { display: grid; grid-template-columns: 1fr 1fr; grid-template-rows: 1fr 1fr; gap: 12px; padding: 12px; height: calc(100vh - 60px); }
|
| 17 |
+
.panel { background: #161b22; border: 1px solid #30363d; border-radius: 8px; overflow: hidden; display: flex; flex-direction: column; }
|
| 18 |
+
.panel-title { padding: 10px 16px; font-size: 14px; font-weight: 600; color: #58a6ff; border-bottom: 1px solid #30363d; background: #0d1117; }
|
| 19 |
+
.panel-body { flex: 1; padding: 8px; position: relative; min-height: 0; }
|
| 20 |
+
.placeholder { display: flex; align-items: center; justify-content: center; height: 100%; color: #484f58; font-style: italic; }
|
| 21 |
+
#controls { display: flex; gap: 8px; align-items: center; }
|
| 22 |
+
#controls select, #controls button { background: #21262d; color: #c9d1d9; border: 1px solid #30363d; padding: 6px 12px; border-radius: 6px; cursor: pointer; font-size: 13px; }
|
| 23 |
+
#controls button:hover { background: #30363d; }
|
| 24 |
+
#controls button.primary { background: #238636; border-color: #238636; color: #fff; }
|
| 25 |
+
#summary { padding: 16px; font-size: 13px; line-height: 1.8; overflow-y: auto; }
|
| 26 |
+
#summary .row { display: flex; justify-content: space-between; border-bottom: 1px solid #21262d; padding: 4px 0; }
|
| 27 |
+
#summary .label { color: #8b949e; }
|
| 28 |
+
#summary .value { font-weight: 600; }
|
| 29 |
+
#summary .score { font-size: 24px; color: #58a6ff; text-align: center; margin: 12px 0; }
|
| 30 |
+
.actions-list { display: flex; flex-wrap: wrap; gap: 4px; margin-top: 8px; }
|
| 31 |
+
.action-tag { padding: 2px 8px; border-radius: 4px; font-size: 11px; font-weight: 500; }
|
| 32 |
+
.action-tag.investigate { background: #1f6feb33; color: #58a6ff; }
|
| 33 |
+
.action-tag.fix { background: #23863633; color: #3fb950; }
|
| 34 |
+
.action-tag.terminal { background: #da363333; color: #f85149; }
|
| 35 |
+
.action-tag.wrong { background: #da363366; color: #f85149; }
|
| 36 |
+
</style>
|
| 37 |
+
</head>
|
| 38 |
+
<body>
|
| 39 |
+
<div class="header">
|
| 40 |
+
<h1>PyTorch Training Debugger</h1>
|
| 41 |
+
<div id="connStatus" class="status disconnected">Disconnected</div>
|
| 42 |
+
<div id="controls">
|
| 43 |
+
<select id="taskSelect">
|
| 44 |
+
<option value="task_001">Task 1 — Exploding Gradients (Easy)</option>
|
| 45 |
+
<option value="task_002">Task 2 — Vanishing Gradients (Easy)</option>
|
| 46 |
+
<option value="task_003">Task 3 — Data Leakage (Medium)</option>
|
| 47 |
+
<option value="task_004">Task 4 — Overfitting (Medium)</option>
|
| 48 |
+
<option value="task_005">Task 5 — BatchNorm Eval (Hard)</option>
|
| 49 |
+
<option value="task_006">Task 6 — Code Bug (Hard)</option>
|
| 50 |
+
</select>
|
| 51 |
+
<button class="primary" onclick="runBaseline()">Run Baseline</button>
|
| 52 |
+
</div>
|
| 53 |
+
</div>
|
| 54 |
+
<div class="grid">
|
| 55 |
+
<div class="panel">
|
| 56 |
+
<div class="panel-title">Training Metrics</div>
|
| 57 |
+
<div class="panel-body"><div id="metricsChart"><div class="placeholder">Run baseline to see metrics</div></div></div>
|
| 58 |
+
</div>
|
| 59 |
+
<div class="panel">
|
| 60 |
+
<div class="panel-title">Gradient & Weight Heatmap</div>
|
| 61 |
+
<div class="panel-body"><div id="gradientChart"><div class="placeholder">Not yet inspected</div></div></div>
|
| 62 |
+
</div>
|
| 63 |
+
<div class="panel">
|
| 64 |
+
<div class="panel-title">Action Timeline & Rewards</div>
|
| 65 |
+
<div class="panel-body"><div id="timelineChart"><div class="placeholder">No actions yet</div></div></div>
|
| 66 |
+
</div>
|
| 67 |
+
<div class="panel">
|
| 68 |
+
<div class="panel-title">Episode Summary</div>
|
| 69 |
+
<div class="panel-body" id="summary">
|
| 70 |
+
<div class="placeholder">Waiting for episode</div>
|
| 71 |
+
</div>
|
| 72 |
+
</div>
|
| 73 |
+
</div>
|
| 74 |
+
|
| 75 |
+
<script>
|
| 76 |
+
const host = window.location.host;
|
| 77 |
+
const wsProto = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 78 |
+
let ws = null;
|
| 79 |
+
let actions = [];
|
| 80 |
+
let rewards = [];
|
| 81 |
+
let cumRewards = [];
|
| 82 |
+
let obs = null;
|
| 83 |
+
|
| 84 |
+
// Reflect the WebSocket connection state in the header status badge.
function setStatus(connected) {
  const badge = document.getElementById('connStatus');
  const state = connected ? 'connected' : 'disconnected';
  badge.textContent = connected ? 'Connected' : 'Disconnected';
  badge.className = `status ${state}`;
}
|
| 89 |
+
|
| 90 |
+
// Open the WebSocket to /ws. Errors close the socket, and every close
// schedules a reconnect attempt two seconds later.
function connect() {
  ws = new WebSocket(`${wsProto}//${host}/ws`);
  ws.onopen = () => setStatus(true);
  ws.onerror = () => ws.close();
  ws.onclose = () => {
    setStatus(false);
    setTimeout(connect, 2000);
  };
  ws.onmessage = (event) => {
    const message = JSON.parse(event.data);
    // Only messages that carry a payload are treated as observations.
    if (message.data) {
      handleObservation(message.data);
    }
  };
}
|
| 100 |
+
|
| 101 |
+
// Record the latest observation, extend the reward series, and redraw all
// four dashboard panels.
function handleObservation(data) {
  obs = data;

  const reward = data.reward;
  if (reward !== null && reward !== undefined) {
    const runningTotal = cumRewards.length > 0 ? cumRewards[cumRewards.length - 1] : 0;
    rewards.push(reward);
    cumRewards.push(runningTotal + reward);
  }

  const state = data.episode_state;
  if (state && state.actions_taken) {
    actions = state.actions_taken;
  }

  updateMetrics(data);
  updateGradients(data);
  updateTimeline();
  updateSummary(data);
}
|
| 116 |
+
|
| 117 |
+
// Plot train/validation loss (left axis) and validation accuracy (right
// axis) for whichever histories the observation carries.
function updateMetrics(d) {
  // Keep one slot per epoch: a non-finite loss becomes null (a gap in the
  // line) instead of being dropped, which would shift every later epoch left
  // and misalign the loss traces with the accuracy trace.
  const finiteOrGap = (values) => values.map(v => (isFinite(v) ? v : null));

  const traces = [];
  if (d.training_loss_history && d.training_loss_history.length > 0) {
    traces.push({ y: finiteOrGap(d.training_loss_history), name: 'Train Loss', line: { color: '#f85149' } });
  }
  if (d.val_loss_history && d.val_loss_history.length > 0) {
    traces.push({ y: finiteOrGap(d.val_loss_history), name: 'Val Loss', line: { color: '#f0883e', dash: 'dash' } });
  }
  if (d.val_accuracy_history && d.val_accuracy_history.length > 0) {
    traces.push({ y: d.val_accuracy_history, name: 'Val Accuracy', yaxis: 'y2', line: { color: '#3fb950' } });
  }
  // Nothing to draw yet: leave the placeholder in place.
  if (traces.length === 0) return;

  Plotly.newPlot('metricsChart', traces, {
    paper_bgcolor: 'transparent', plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 50 },
    xaxis: { title: 'Epoch', gridcolor: '#21262d' },
    yaxis: { title: 'Loss', gridcolor: '#21262d' },
    yaxis2: { title: 'Accuracy', overlaying: 'y', side: 'right', range: [0, 1], gridcolor: '#21262d' },
    legend: { x: 0, y: 1.15, orientation: 'h' },
    showlegend: true,
  }, { responsive: true });
}
|
| 142 |
+
|
| 143 |
+
// Draw one bar per layer showing its mean gradient norm on a log axis,
// colored and labelled by the exploding/vanishing flags.
function updateGradients(d) {
  const stats = d.gradient_stats;
  if (!stats || stats.length === 0) return;

  const colorOf = (g) => (g.is_exploding ? '#f85149' : g.is_vanishing ? '#1f6feb' : '#3fb950');
  const labelOf = (g) => (g.is_exploding ? 'EXPLODING' : g.is_vanishing ? 'VANISHING' : 'Normal');

  const bars = {
    x: stats.map(g => g.layer_name),
    y: stats.map(g => g.mean_norm),
    type: 'bar',
    marker: { color: stats.map(colorOf) },
    text: stats.map(labelOf),
    textposition: 'auto',
  };
  const layout = {
    paper_bgcolor: 'transparent', plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 20 },
    yaxis: { title: 'Mean Grad Norm', gridcolor: '#21262d', type: 'log' },
    xaxis: { gridcolor: '#21262d' },
  };
  Plotly.newPlot('gradientChart', [bars], layout, { responsive: true });
}
|
| 161 |
+
|
| 162 |
+
// Plot per-step rewards (bars, green when >= 0, red otherwise) and the
// cumulative reward (line) against the sequence of actions taken.
function updateTimeline() {
  if (actions.length === 0) return;

  // Steps are numbered from 1. Tick labels are the action name up to any
  // ':' suffix, with two common prefixes shortened.
  const stepNumbers = () => actions.map((_, i) => i + 1);
  const tickLabels = actions.map(a =>
    a.split(':')[0].replace('inspect_', 'i_').replace('mark_diagnosed', 'diag'));

  Plotly.newPlot('timelineChart', [
    { x: stepNumbers(), y: rewards, type: 'bar', name: 'Step Reward', marker: { color: rewards.map(r => r >= 0 ? '#3fb950' : '#f85149') } },
    { x: stepNumbers(), y: cumRewards, type: 'scatter', name: 'Cumulative', line: { color: '#58a6ff', width: 2 } }
  ], {
    paper_bgcolor: 'transparent', plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 20 },
    xaxis: { title: 'Step', gridcolor: '#21262d', tickvals: stepNumbers(), ticktext: tickLabels },
    yaxis: { title: 'Reward', gridcolor: '#21262d' },
    legend: { x: 0, y: 1.15, orientation: 'h' },
  }, { responsive: true });
}
|
| 183 |
+
|
| 184 |
+
// Rebuild the "Episode Summary" panel from the latest observation.
function updateSummary(d) {
  // Server-supplied values are inserted via innerHTML, so escape the
  // characters that could otherwise be parsed as markup.
  const escapeHtml = (value) => String(value)
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;');
  const row = (label, value) =>
    '<div class="row"><span class="label">' + label + '</span><span class="value">' + escapeHtml(value) + '</span></div>';
  const yesNo = (flag) => (flag ? 'Yes' : 'No');

  const s = d.episode_state || {};
  const avail = d.available_actions || [];

  let html = '';
  if (d.done) {
    html += `<div class="score">Episode Complete</div>`;
  }
  html += row('Task', d.run_id || '-');
  html += row('Steps', s.step_count || 0);
  html += row('Gradients Inspected', yesNo(s.gradients_inspected));
  html += row('Gradients Normal', s.gradients_were_normal ? 'Yes' : '-');
  html += row('Data Inspected', yesNo(s.data_inspected));
  html += row('Model Modes Inspected', yesNo(s.model_modes_inspected));
  html += row('Code Inspected', yesNo(s.code_inspected));
  html += row('Fix Applied', yesNo(s.fix_action_taken));
  html += row('Restarted', yesNo(s.restart_after_fix));
  html += row('Diagnosed', yesNo(s.diagnosis_submitted));

  if (d.code_snippet) {
    html += '<div style="margin-top:12px"><span class="label">Code:</span><pre style="background:#0d1117;padding:8px;border-radius:4px;font-size:11px;overflow:auto;max-height:120px;margin-top:4px">' + escapeHtml(d.code_snippet.code) + '</pre></div>';
  }

  html += '<div style="margin-top:8px"><span class="label">Available Actions:</span></div>';
  html += '<div class="actions-list">';
  avail.forEach(a => {
    // The tag class is chosen from a fixed set; only the label is external.
    let cls = 'investigate';
    if (a.startsWith('fix') || a === 'modify_config' || a === 'patch_data_loader' || a === 'add_callback' || a === 'replace_optimizer') cls = 'fix';
    if (a === 'mark_diagnosed' || a === 'restart_run') cls = 'terminal';
    html += `<span class="action-tag ${cls}">${escapeHtml(a)}</span>`;
  });
  html += '</div>';
  document.getElementById('summary').innerHTML = html;
}
|
| 215 |
+
|
| 216 |
+
// Reset the selected task over the WebSocket and send a fixed sequence of
// inspect actions, stopping early once the latest observation is done.
async function runBaseline() {
  const taskId = document.getElementById('taskSelect').value;
  actions = [];
  rewards = [];
  cumRewards = [];

  // Nothing is sent unless the socket is open.
  if (!ws || ws.readyState !== WebSocket.OPEN) return;

  const pause = (ms) => new Promise(resolve => setTimeout(resolve, ms));

  ws.send(JSON.stringify({ type: 'reset', data: { task_id: taskId, seed: 42 } }));
  await pause(500);

  // Run the heuristic steps
  const inspections = [
    'inspect_gradients',
    'inspect_data_batch',
    'inspect_model_modes',
    'inspect_model_weights',
    'inspect_code',
  ];
  for (const actionType of inspections) {
    ws.send(JSON.stringify({ type: 'step', data: { action_type: actionType } }));
    await pause(300);
    if (obs && obs.done) break;
  }
}
|
| 237 |
+
|
| 238 |
+
connect();
|
| 239 |
+
</script>
|
| 240 |
+
</body>
|
| 241 |
+
</html>
|
server/environment.py
CHANGED
|
@@ -46,6 +46,7 @@ from ml_training_debugger.simulation import (
|
|
| 46 |
gen_val_accuracy_history,
|
| 47 |
gen_val_loss_history,
|
| 48 |
)
|
|
|
|
| 49 |
|
| 50 |
logger = logging.getLogger(__name__)
|
| 51 |
|
|
@@ -160,6 +161,9 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
|
|
| 160 |
"task_id": old.scenario.task_id,
|
| 161 |
"steps": old.state.step_count,
|
| 162 |
}
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
self._current_session_id = session_id
|
| 165 |
|
|
@@ -335,6 +339,9 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
|
|
| 335 |
"task_id": scenario.task_id,
|
| 336 |
"steps": state.step_count,
|
| 337 |
}
|
|
|
|
|
|
|
|
|
|
| 338 |
logger.info(
|
| 339 |
"episode_completed",
|
| 340 |
extra={
|
|
|
|
| 46 |
gen_val_accuracy_history,
|
| 47 |
gen_val_loss_history,
|
| 48 |
)
|
| 49 |
+
from server._baseline_results import store_grader_result
|
| 50 |
|
| 51 |
logger = logging.getLogger(__name__)
|
| 52 |
|
|
|
|
| 161 |
"task_id": old.scenario.task_id,
|
| 162 |
"steps": old.state.step_count,
|
| 163 |
}
|
| 164 |
+
store_grader_result(
|
| 165 |
+
session_id, score, old.scenario.task_id, old.state.step_count
|
| 166 |
+
)
|
| 167 |
|
| 168 |
self._current_session_id = session_id
|
| 169 |
|
|
|
|
| 339 |
"task_id": scenario.task_id,
|
| 340 |
"steps": state.step_count,
|
| 341 |
}
|
| 342 |
+
store_grader_result(
|
| 343 |
+
self._current_session_id, score, scenario.task_id, state.step_count
|
| 344 |
+
)
|
| 345 |
logger.info(
|
| 346 |
"episode_completed",
|
| 347 |
extra={
|
tests/test_baseline_reproducibility.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test baseline produces bit-exact identical scores on two runs."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from baseline_heuristic import ALL_TASKS, run_heuristic_episode
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestBaselineReproducibility:
    """Determinism and sanity checks for the heuristic baseline."""

    @staticmethod
    def _score_all() -> dict:
        """Run the heuristic once per task and map each task to its score."""
        return {task: run_heuristic_episode(task) for task in ALL_TASKS}

    def test_two_runs_identical(self):
        """Run baseline twice, verify bit-exact same scores."""
        first_pass = self._score_all()
        second_pass = self._score_all()
        assert first_pass == second_pass

    def test_all_scores_in_range(self):
        """All scores must be in [0.0, 1.0]."""
        for tid in ALL_TASKS:
            score = run_heuristic_episode(tid)
            assert 0.0 <= score <= 1.0, f"{tid}: score {score} out of range"

    def test_scores_have_meaningful_variance(self):
        """Not all tasks should return the same score."""
        distinct_scores = {run_heuristic_episode(task) for task in ALL_TASKS}
        assert len(distinct_scores) > 1, "All scores identical — no variance"
|
tests/test_endpoints.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integration tests for HTTP endpoints."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from fastapi.testclient import TestClient
|
| 7 |
+
|
| 8 |
+
from server.app import app
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.fixture
def client():
    """Provide a fresh ``TestClient`` bound to the application under test."""
    test_client = TestClient(app)
    return test_client
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestHealthEndpoint:
    """GET /health reports readiness and the number of tasks."""

    def test_returns_ready(self, client):
        response = client.get("/health")
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ready"
        assert body["tasks"] == 6
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TestTasksEndpoint:
    """GET /tasks lists all six tasks together with the action schema."""

    def test_returns_six_tasks(self, client):
        response = client.get("/tasks")
        assert response.status_code == 200
        listing = response.json()
        assert len(listing) == 6
        task_ids = {entry["id"] for entry in listing}
        assert "task_001" in task_ids
        assert "task_006" in task_ids

    def test_tasks_have_action_schema(self, client):
        response = client.get("/tasks")
        for entry in response.json():
            assert "action_schema" in entry
            assert "properties" in entry["action_schema"]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TestGraderEndpoint:
    """POST /grader behaviour when no episode has been completed."""

    def test_no_completed_episode(self, client):
        import server._baseline_results as results_store

        # Reset shared state for clean test
        results_store._last_results.clear()

        response = client.post("/grader")
        assert response.status_code == 200
        body = response.json()
        assert body["score"] is None
        assert body["error"] == "no_completed_episode"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class TestDashboardEndpoint:
    """GET /dashboard serves the dashboard page."""

    def test_returns_html(self, client):
        response = client.get("/dashboard")
        assert response.status_code == 200
        page = response.text
        # The page must reference both the charting library and the socket API.
        for marker in ("Plotly", "WebSocket"):
            assert marker in page
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
validation/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
matplotlib
|
| 3 |
+
scipy
|
validation/validate_exploding_gradients.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Validate parametric exploding gradient curves against real PyTorch training.
|
| 3 |
+
|
| 4 |
+
Trains a CNN with lr=0.1 for 20 epochs, compares loss curve to simulation.
|
| 5 |
+
Reports R² between the real and simulated curves and asserts that both diverge.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
from ml_training_debugger.pytorch_engine import SimpleCNN
|
| 14 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 15 |
+
from ml_training_debugger.simulation import gen_loss_history
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def run_real_training(lr: float = 0.1, epochs: int = 20) -> list[float]:
    """Run real training with high LR and capture loss history.

    Trains a freshly seeded ``SimpleCNN`` with plain SGD, drawing one new
    random batch (16 inputs of shape 3x32x32, labels in ``[0, 10)``) per
    iteration.

    Args:
        lr: Learning rate passed to the SGD optimizer.
        epochs: Number of optimization steps to run.

    Returns:
        One loss value per step. A NaN loss is recorded as ``float("inf")``.
    """
    import math

    torch.manual_seed(42)
    model = SimpleCNN()
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    losses: list[float] = []
    for _ in range(epochs):
        batch_x = torch.randn(16, 3, 32, 32)
        batch_y = torch.randint(0, 10, (16,))
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        loss_val = loss.item()
        # Record NaN as inf so callers only have one "diverged" marker.
        losses.append(float("inf") if math.isnan(loss_val) else loss_val)
    return losses
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def compute_r_squared(real: list[float], simulated: list[float]) -> float:
    """Compute R² between two curves, ignoring inf/nan values.

    Points are paired positionally. A pair is dropped when either value is
    non-finite (``inf``, ``-inf`` or NaN).

    Args:
        real: Reference curve.
        simulated: Curve being compared against ``real``.

    Returns:
        ``1 - SS_res / SS_tot`` over the remaining pairs; ``0.0`` when fewer
        than three pairs remain; ``1.0`` when the remaining reference values
        have zero variance.
    """
    import math

    # isfinite also rejects -inf, which would otherwise turn the result to NaN.
    pairs = [
        (r, s)
        for r, s in zip(real, simulated)
        if math.isfinite(r) and math.isfinite(s)
    ]
    if len(pairs) < 3:
        return 0.0
    real_t = torch.tensor([r for r, _ in pairs])
    sim_t = torch.tensor([s for _, s in pairs])
    ss_res = ((real_t - sim_t) ** 2).sum()
    ss_tot = ((real_t - real_t.mean()) ** 2).sum()
    if ss_tot == 0:
        return 1.0
    return (1 - ss_res / ss_tot).item()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main() -> None:
    """Compare the simulated task_001 loss curve with a real training run.

    Prints the R² between the two curves and their first/last losses, then
    checks that both curves diverge (contain ``inf`` or a loss above 100).

    Raises:
        AssertionError: If either curve fails to diverge.
    """
    scenario = sample_scenario("task_001", seed=42)
    simulated = gen_loss_history(scenario)
    real = run_real_training(lr=scenario.learning_rate, epochs=20)

    r2 = compute_r_squared(real, simulated)
    print(f"Exploding Gradients — R²: {r2:.4f}")
    print(f" Real loss trend: {real[0]:.2f} → {'INF' if real[-1] == float('inf') else f'{real[-1]:.2f}'}")
    print(f" Sim loss trend: {simulated[0]:.2f} → {'INF' if simulated[-1] == float('inf') else f'{simulated[-1]:.2f}'}")

    # Both should diverge — directional agreement is what matters
    real_diverges = any(v == float("inf") or v > 100 for v in real)
    sim_diverges = any(v == float("inf") or v > 100 for v in simulated)
    print(f" Real diverges: {real_diverges}, Sim diverges: {sim_diverges}")
    if not (real_diverges and sim_diverges):
        # Raised explicitly rather than via ``assert`` so the check still
        # runs when Python is started with -O.
        raise AssertionError("Both curves should diverge")
    print(" PASS: Both curves diverge as expected")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
main()
|