omkarrr88 committed on
Commit
9e6a926
·
1 Parent(s): e2f8b29

Remaining task added + full openenv compliance

Browse files
Dockerfile CHANGED
@@ -2,12 +2,20 @@ FROM python:3.12-slim
2
 
3
  WORKDIR /app
4
 
5
- # Install PyTorch CPU-only first (largest layer, cached)
 
 
 
 
6
  RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
7
 
8
- # Install remaining dependencies
9
  COPY requirements.txt .
10
- RUN pip install --no-cache-dir -r requirements.txt
 
 
 
 
11
 
12
  # Copy application code
13
  COPY ml_training_debugger/ ml_training_debugger/
 
2
 
3
  WORKDIR /app
4
 
5
+ # Install curl for healthcheck
6
+ RUN apt-get update && apt-get install -y --no-install-recommends curl && \
7
+ rm -rf /var/lib/apt/lists/*
8
+
9
+ # Install PyTorch CPU-only first (largest layer, cached separately)
10
  RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
11
 
12
+ # Install remaining dependencies (torch excluded from requirements.txt)
13
  COPY requirements.txt .
14
# Cleanup steps are best-effort (each is guarded with `|| true`), but they are
# chained with `&&` so that a failed `pip install` still fails the build.
# The previous `a && b; c; d; true` form always exited 0, masking pip errors.
RUN pip install --no-cache-dir -r requirements.txt && \
    (find /usr/local/lib/python3.12/site-packages -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true) && \
    (find /usr/local/lib/python3.12/site-packages -name "*.pyc" -delete 2>/dev/null || true) && \
    rm -rf /usr/local/lib/python3.12/site-packages/gradio/templates
19
 
20
  # Copy application code
21
  COPY ml_training_debugger/ ml_training_debugger/
README.md CHANGED
@@ -76,18 +76,24 @@ Dynamic availability: `restart_run` requires a fix first; `fix_code` requires co
76
  | ID | Difficulty | Root Cause | Description |
77
  |----|-----------|------------|-------------|
78
  | `task_001` | Easy | `lr_too_high` | Exploding gradients — all layers show `is_exploding: True`, NaN in error log |
 
79
  | `task_003` | Medium | `data_leakage` | Silent data leakage — suspiciously high val accuracy, `class_overlap_score > 0.5` |
 
80
  | `task_005` | Hard | `batchnorm_eval_mode` | Model in eval mode with compound red herrings (FC gradient spike, GPU 91%, near-vanishing conv1) |
 
81
 
82
  ## Baseline Scores
83
 
84
- Rule-based heuristic baseline (deterministic, no API key):
85
 
86
- | Task | Score |
87
- |------|-------|
88
- | `task_001` | 1.00 |
89
- | `task_003` | 1.00 |
90
- | `task_005` | 0.35 |
 
 
 
91
 
92
  ## Setup
93
 
@@ -127,10 +133,11 @@ curl http://localhost:7860/health
127
 
128
  | Endpoint | Method | Description |
129
  |----------|--------|-------------|
130
- | `/health` | GET | `{"status": "ready", "tasks": 3}` |
131
  | `/tasks` | GET | Task list with action schema |
132
  | `/grader` | POST | Grader score for last completed episode |
133
- | `/baseline` | POST | Run baseline, return scores |
 
134
  | `/ws` | WebSocket | Primary agent interface |
135
  | `/reset` | POST | Reset environment (framework) |
136
  | `/step` | POST | Execute action (framework) |
 
76
  | ID | Difficulty | Root Cause | Description |
77
  |----|-----------|------------|-------------|
78
  | `task_001` | Easy | `lr_too_high` | Exploding gradients — all layers show `is_exploding: True`, NaN in error log |
79
+ | `task_002` | Easy | `vanishing_gradients` | Vanishing gradients — deeper layers show `is_vanishing: True`, flat loss curve |
80
  | `task_003` | Medium | `data_leakage` | Silent data leakage — suspiciously high val accuracy, `class_overlap_score > 0.5` |
81
+ | `task_004` | Medium | `overfitting` | Train-val divergence — loss approaches 0 while val loss climbs |
82
  | `task_005` | Hard | `batchnorm_eval_mode` | Model in eval mode with compound red herrings (FC gradient spike, GPU 91%, near-vanishing conv1) |
83
+ | `task_006` | Hard | `code_bug` | PyTorch code bug — agent must read and fix actual Python code (4 bug variants) |
84
 
85
  ## Baseline Scores
86
 
87
+ Rule-based heuristic baseline (deterministic, no API key, bit-exact reproducible):
88
 
89
+ | Task | Score | Notes |
90
+ |------|-------|-------|
91
+ | `task_001` | 1.00 | Direct signal: `is_exploding` on all layers |
92
+ | `task_002` | 1.00 | Direct signal: `is_vanishing` on deeper layers |
93
+ | `task_003` | 1.00 | `class_overlap_score > 0.5` triggers correct path |
94
+ | `task_004` | 0.45 | Heuristic must rule out leakage first |
95
+ | `task_005` | 0.35 | Fixed investigation order misses eval mode, diagnoses overfitting |
96
+ | `task_006` | 1.00 | Pattern-matching catches 2 of 4 bug variants |
97
 
98
  ## Setup
99
 
 
133
 
134
  | Endpoint | Method | Description |
135
  |----------|--------|-------------|
136
+ | `/health` | GET | `{"status": "ready", "tasks": 6}` |
137
  | `/tasks` | GET | Task list with action schema |
138
  | `/grader` | POST | Grader score for last completed episode |
139
+ | `/baseline` | POST | Run baseline, return scores for all 6 tasks |
140
+ | `/dashboard` | GET | Live diagnostic dashboard (Plotly.js, 4-panel) |
141
  | `/ws` | WebSocket | Primary agent interface |
142
  | `/reset` | POST | Reset environment (framework) |
143
  | `/step` | POST | Execute action (framework) |
baseline_heuristic.py CHANGED
@@ -14,12 +14,17 @@ import argparse
14
  import json
15
  import sys
16
 
17
- from ml_training_debugger.graders import grade_episode
18
- from ml_training_debugger.models import EpisodeState, MLTrainingAction, MLTrainingObservation
19
- from ml_training_debugger.scenarios import sample_scenario
20
  from server.environment import MLTrainingEnvironment
21
 
22
- MVP_TASKS = ["task_001", "task_003", "task_005"]
 
 
 
 
 
 
 
23
 
24
 
25
  def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
@@ -175,7 +180,7 @@ def main() -> None:
175
  args = parser.parse_args()
176
 
177
  scores: dict[str, float] = {}
178
- for task_id in MVP_TASKS:
179
  score = run_heuristic_episode(task_id)
180
  scores[task_id] = round(score, 4)
181
 
 
14
  import json
15
  import sys
16
 
17
+ from ml_training_debugger.models import MLTrainingAction
 
 
18
  from server.environment import MLTrainingEnvironment
19
 
20
+ ALL_TASKS = [
21
+ "task_001",
22
+ "task_002",
23
+ "task_003",
24
+ "task_004",
25
+ "task_005",
26
+ "task_006",
27
+ ]
28
 
29
 
30
  def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
 
180
  args = parser.parse_args()
181
 
182
  scores: dict[str, float] = {}
183
+ for task_id in ALL_TASKS:
184
  score = run_heuristic_episode(task_id)
185
  scores[task_id] = round(score, 4)
186
 
baseline_inference.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """LLM baseline agent using OpenAI GPT-4o.
3
+
4
+ Optional — requires OPENAI_API_KEY environment variable.
5
+ Uses temperature=0.0 and seed=42 for near-deterministic behavior.
6
+ Spec reference: Section 17.
7
+
8
+ Usage:
9
+ OPENAI_API_KEY=... python baseline_inference.py [--url http://localhost:7860]
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ import sys
18
+
19
+ try:
20
+ from openai import OpenAI
21
+ except ImportError:
22
+ print("Error: openai package not installed. Run: pip install openai")
23
+ sys.exit(1)
24
+
25
+ from ml_training_debugger.models import MLTrainingAction
26
+ from server.environment import MLTrainingEnvironment
27
+
28
+ ALL_TASKS = [
29
+ "task_001",
30
+ "task_002",
31
+ "task_003",
32
+ "task_004",
33
+ "task_005",
34
+ "task_006",
35
+ ]
36
+
37
+ SYSTEM_PROMPT = """You are an expert ML engineer debugging a PyTorch training run.
38
+ You are interacting with an environment that simulates a broken training job.
39
+
40
+ Available actions (respond with JSON):
41
+ - {"action_type": "inspect_gradients"} - View gradient statistics per layer
42
+ - {"action_type": "inspect_data_batch"} - View data batch statistics
43
+ - {"action_type": "inspect_model_modes"} - View model layer modes (train/eval)
44
+ - {"action_type": "inspect_model_weights"} - View model weight statistics
45
+ - {"action_type": "inspect_code"} - View PyTorch training code
46
+ - {"action_type": "modify_config", "target": "<field>", "value": <val>} - Change a hyperparameter
47
+ - {"action_type": "add_callback"} - Add gradient clipping/scheduler
48
+ - {"action_type": "patch_data_loader"} - Fix data pipeline issues
49
+ - {"action_type": "fix_model_mode"} - Call model.train()
50
+ - {"action_type": "fix_code", "line": <int>, "replacement": "<code>"} - Fix a code line
51
+ - {"action_type": "restart_run"} - Restart training (requires a fix first)
52
+ - {"action_type": "mark_diagnosed", "diagnosis": "<cause>"} - Submit diagnosis
53
+
54
+ Valid diagnoses: lr_too_high, vanishing_gradients, data_leakage, overfitting, batchnorm_eval_mode, code_bug
55
+
56
+ Strategy:
57
+ 1. First investigate by inspecting gradients, data, and model modes
58
+ 2. Form a hypothesis based on the evidence
59
+ 3. Apply the correct fix
60
+ 4. Restart training to verify
61
+ 5. Submit your diagnosis
62
+
63
+ Respond with ONLY a valid JSON action object, no explanation."""
64
+
65
+
66
def run_llm_episode(task_id: str, client: OpenAI) -> float:
    """Run one LLM-driven episode for ``task_id`` and return its grader score.

    A fresh ``MLTrainingEnvironment`` is reset with ``seed=42``; the model is
    then queried for up to 20 turns. Each reply is parsed as a JSON action,
    executed with ``env.step`` and a condensed observation is fed back. Replies
    that cannot be parsed or validated are answered with an error message and
    still consume a turn.

    Args:
        task_id: Identifier of the task to run (e.g. ``"task_001"``).
        client: OpenAI client used for ``chat.completions.create`` calls.

    Returns:
        ``session.last_score`` of the environment's session, or ``0.0`` when
        there is no session or no score has been recorded.
    """
    env = MLTrainingEnvironment()
    obs = env.reset(seed=42, episode_id=f"llm_{task_id}", task_id=task_id)

    # The initial observation is truncated so the first prompt stays bounded.
    initial_obs = json.dumps(obs.model_dump(), indent=2, default=str)[:3000]
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"New episode started. Observation:\n{initial_obs}"},
    ]

    for _ in range(20):
        if obs.done:
            break

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            temperature=0.0,
            seed=42,
            max_tokens=200,
        )

        # ``content`` can be None (e.g. on a refusal); treat that as an empty
        # reply so it goes through the normal "invalid action" feedback path.
        action_text = (response.choices[0].message.content or "").strip()
        messages.append({"role": "assistant", "content": action_text})

        try:
            action_data = json.loads(_strip_code_fence(action_text))
            action = MLTrainingAction(**action_data)
        except (TypeError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and pydantic validation
            # errors; TypeError covers a JSON payload that is not an object.
            messages.append({"role": "user", "content": f"Invalid action: {e}. Try again with valid JSON."})
            continue

        obs = env.step(action)
        obs_summary = {
            "reward": obs.reward,
            "done": obs.done,
            "step": obs.episode_state.step_count,
            "available_actions": obs.available_actions,
            "error_log": obs.error_log,
        }
        if obs.gradient_stats:
            obs_summary["gradient_stats"] = [
                {
                    "layer": g.layer_name,
                    "mean_norm": round(g.mean_norm, 4),
                    "exploding": g.is_exploding,
                    "vanishing": g.is_vanishing,
                }
                for g in obs.gradient_stats
            ]
        if obs.data_batch_stats:
            obs_summary["data_overlap"] = obs.data_batch_stats.class_overlap_score
        if obs.model_mode_info:
            obs_summary["model_modes"] = obs.model_mode_info
        if obs.code_snippet:
            obs_summary["code"] = obs.code_snippet.code[:500]

        messages.append({"role": "user", "content": f"Observation:\n{json.dumps(obs_summary, indent=2, default=str)}"})

    session = env._get_session()
    return session.last_score if session and session.last_score is not None else 0.0


def _strip_code_fence(text: str) -> str:
    """Return ``text`` without a surrounding Markdown code fence, if present."""
    if not text.startswith("```"):
        return text
    body = text[3:]
    # Drop the rest of the opening fence line (empty or a language tag such as
    # "json") unless the JSON payload itself starts on that line.
    first_line, newline, remainder = body.partition("\n")
    if newline and not first_line.strip().startswith(("{", "[")):
        body = remainder
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()
122
+
123
+
124
def main() -> None:
    """Run the LLM baseline on every task and print the scores as JSON.

    Exits with status 1 when ``OPENAI_API_KEY`` is not set. Per-task progress
    and errors are written to stderr; stdout carries only the final JSON
    object mapping task id to score, so it can be piped or parsed directly.
    A task whose episode raises is recorded with a score of ``0.0``.
    """
    parser = argparse.ArgumentParser(description="LLM baseline agent (GPT-4o)")
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="accepted for CLI compatibility; currently unused because "
        "episodes run against an in-process environment",
    )
    parser.parse_args()

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        # Diagnostics go to stderr so stdout stays reserved for the JSON result.
        print("Error: OPENAI_API_KEY environment variable not set", file=sys.stderr)
        sys.exit(1)

    client = OpenAI(api_key=api_key)
    scores: dict[str, float] = {}

    for task_id in ALL_TASKS:
        try:
            score = run_llm_episode(task_id, client)
        except Exception as e:
            # Top-level boundary: one failing task must not abort the others.
            print(f"  {task_id}: ERROR — {e}", file=sys.stderr)
            scores[task_id] = 0.0
        else:
            scores[task_id] = round(score, 4)
            print(f"  {task_id}: {score:.4f}", file=sys.stderr)

    print(json.dumps(scores, indent=2))
147
+
148
+
149
+ if __name__ == "__main__":
150
+ main()
ml_training_debugger/pytorch_engine.py CHANGED
@@ -80,15 +80,24 @@ def create_model_and_inject_fault(
80
  loss.backward()
81
 
82
  elif scenario.root_cause.value == "vanishing_gradients":
83
- # Tiny LR gradients are extremely small
 
84
  model.train()
85
  optimizer = torch.optim.SGD(model.parameters(), lr=scenario.learning_rate)
86
- for _ in range(2):
87
- optimizer.zero_grad()
88
- output = model(batch_x)
89
- loss = criterion(output, batch_y)
90
- loss.backward()
91
- optimizer.step()
 
 
 
 
 
 
 
 
92
 
93
  elif scenario.root_cause.value == "data_leakage":
94
  # Normal model — no gradient anomaly
 
80
  loss.backward()
81
 
82
  elif scenario.root_cause.value == "vanishing_gradients":
83
+ # Simulate vanishing gradients: run forward/backward then scale grads
84
+ # to simulate gradient decay through deep layers
85
  model.train()
86
  optimizer = torch.optim.SGD(model.parameters(), lr=scenario.learning_rate)
87
+ optimizer.zero_grad()
88
+ output = model(batch_x)
89
+ loss = criterion(output, batch_y)
90
+ loss.backward()
91
+ # Scale gradients to simulate vanishing: deeper layers get smaller grads
92
+ depth_mult = scenario.depth_multiplier
93
+ layer_idx = 0
94
+ for name, param in model.named_parameters():
95
+ if param.grad is not None:
96
+ decay = torch.tensor(1e-7) * torch.exp(
97
+ torch.tensor(-depth_mult * layer_idx)
98
+ )
99
+ param.grad.data = param.grad.data * decay
100
+ layer_idx += 1
101
 
102
  elif scenario.root_cause.value == "data_leakage":
103
  # Normal model — no gradient anomaly
openenv.yaml CHANGED
@@ -11,8 +11,8 @@ description: |
11
  An AI agent investigates, diagnoses, fixes, and verifies broken
12
  training runs using real torch.nn.Module models, torch.autograd
13
  gradients, state_dict() weight inspection, and PyTorch code-level
14
- debugging. 3 tasks across 3 difficulty tiers with context-gated
15
- reward shaping.
16
  framework: openenv
17
  tags:
18
  - ml-debugging
@@ -20,26 +20,55 @@ tags:
20
  - reinforcement-learning
21
  - root-cause-analysis
22
  - fault-injection
 
23
  - openenv
24
 
25
  observation_space:
26
  type: MLTrainingObservation
27
- description: "Training run snapshot with progressive reveal — gradients, weights, data stats, model modes revealed on inspection"
28
 
29
  action_space:
30
  type: MLTrainingAction
31
- description: "Investigation, fix, and diagnosis actions with dynamic availability"
32
 
33
  tasks:
34
  - id: task_001
35
  difficulty: easy
36
  max_steps: 20
 
 
 
 
 
 
 
 
 
 
37
  - id: task_003
38
  difficulty: medium
39
  max_steps: 25
 
 
 
 
 
 
 
 
 
 
40
  - id: task_005
41
  difficulty: hard
42
  max_steps: 30
 
 
 
 
 
 
 
 
43
 
44
  reward:
45
  range: [-1.0, 1.0]
@@ -56,3 +85,4 @@ endpoints:
56
  grader: "POST /grader"
57
  baseline: "POST /baseline"
58
  health: "GET /health"
 
 
11
  An AI agent investigates, diagnoses, fixes, and verifies broken
12
  training runs using real torch.nn.Module models, torch.autograd
13
  gradients, state_dict() weight inspection, and PyTorch code-level
14
+ debugging. 6 tasks across 3 difficulty tiers with context-gated
15
+ reward shaping and a live diagnostic dashboard.
16
  framework: openenv
17
  tags:
18
  - ml-debugging
 
20
  - reinforcement-learning
21
  - root-cause-analysis
22
  - fault-injection
23
+ - code-debugging
24
  - openenv
25
 
26
  observation_space:
27
  type: MLTrainingObservation
28
+ description: "Training run snapshot with progressive reveal — gradients, weights, data stats, model modes, and code snippets revealed on inspection"
29
 
30
  action_space:
31
  type: MLTrainingAction
32
+ description: "Investigation, fix, code-fix, and diagnosis actions with dynamic availability"
33
 
34
  tasks:
35
  - id: task_001
36
  difficulty: easy
37
  max_steps: 20
38
+ param_ranges:
39
+ learning_rate: [0.05, 0.08, 0.10, 0.15, 0.30]
40
+
41
+ - id: task_002
42
+ difficulty: easy
43
+ max_steps: 20
44
+ param_ranges:
45
+ learning_rate: [1e-6, 5e-6, 1e-5]
46
+ depth_multiplier: [1.0, 1.5, 2.0]
47
+
48
  - id: task_003
49
  difficulty: medium
50
  max_steps: 25
51
+ param_ranges:
52
+ leakage_pct: [0.12, 0.18, 0.22, 0.28]
53
+
54
+ - id: task_004
55
+ difficulty: medium
56
+ max_steps: 25
57
+ param_ranges:
58
+ weight_decay: [0.0, 0.0001, 0.001]
59
+ divergence_epoch: [5, 8, 12]
60
+
61
  - id: task_005
62
  difficulty: hard
63
  max_steps: 30
64
+ param_ranges:
65
+ red_herring_intensity: [0.8, 2.5]
66
+
67
+ - id: task_006
68
+ difficulty: hard
69
+ max_steps: 30
70
+ param_ranges:
71
+ bug_type: [eval_mode, detach_loss, zero_grad_missing, inplace_relu]
72
 
73
  reward:
74
  range: [-1.0, 1.0]
 
85
  grader: "POST /grader"
86
  baseline: "POST /baseline"
87
  health: "GET /health"
88
+ dashboard: "GET /dashboard"
pyproject.toml CHANGED
@@ -11,6 +11,9 @@ dependencies = [
11
  "uvicorn",
12
  ]
13
 
 
 
 
14
  [project.optional-dependencies]
15
  dev = [
16
  "pytest",
 
11
  "uvicorn",
12
  ]
13
 
14
+ [project.scripts]
15
+ server = "server.app:main"
16
+
17
  [project.optional-dependencies]
18
  dev = [
19
  "pytest",
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- torch
2
  openenv-core
3
  pydantic>=2.0
4
  fastapi
 
 
1
  openenv-core
2
  pydantic>=2.0
3
  fastapi
server/app.py CHANGED
@@ -1,32 +1,60 @@
1
  """FastAPI app — openenv create_app() + custom hackathon routes.
2
 
3
- Spec reference: Sections 9, 14.
4
  """
5
 
6
  from __future__ import annotations
7
 
8
  import asyncio
 
9
  import logging
 
10
  from typing import Optional
11
 
12
  from fastapi import FastAPI
13
- from fastapi.responses import JSONResponse
14
  from openenv.core.env_server.http_server import create_app
15
 
16
  from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
 
17
  from server.environment import MLTrainingEnvironment
18
 
19
- logging.basicConfig(
20
- level=logging.INFO,
21
- format='{"time":"%(asctime)s","level":"%(levelname)s","msg":"%(message)s"}',
22
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  logger = logging.getLogger(__name__)
24
 
25
- # MVP task list
26
- MVP_TASKS = [
27
  {"id": "task_001", "difficulty": "easy", "max_steps": 20},
 
28
  {"id": "task_003", "difficulty": "medium", "max_steps": 25},
 
29
  {"id": "task_005", "difficulty": "hard", "max_steps": 30},
 
30
  ]
31
 
32
  # create_app takes the class (factory), not an instance
@@ -39,27 +67,34 @@ app: FastAPI = create_app(
39
  )
40
 
41
  # Override framework's /health route with our custom version
42
- # Remove the framework's health route first
43
  app.routes[:] = [
44
  r for r in app.routes if not (hasattr(r, "path") and r.path == "/health")
45
  ]
46
 
47
- # Track baseline state
48
  _baseline_lock = asyncio.Lock()
49
- _baseline_running = False
50
 
51
 
52
  @app.get("/health")
53
  def health_check() -> dict:
54
  """Health check — required by hackathon auto-validator."""
55
- return {"status": "ready", "tasks": len(MVP_TASKS)}
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  @app.get("/tasks")
59
  def get_tasks() -> list[dict]:
60
  """Return task list with IDs, difficulties, and action schema."""
61
  schema = MLTrainingAction.model_json_schema()
62
- return [{**task, "action_schema": schema} for task in MVP_TASKS]
63
 
64
 
65
  @app.post("/grader")
@@ -68,14 +103,8 @@ def post_grader(session_id: Optional[str] = None) -> dict:
68
 
69
  Edge cases per spec Section 14:
70
  - No episode completed → {"score": null, "error": "no_completed_episode"}
71
- - Episode in progress → {"score": null, "error": "episode_in_progress"}
72
  - Episode completed → {"score": float, "task_id": str, "steps": int}
73
  """
74
- # Try to find the environment instance
75
- # The framework manages environment instances internally,
76
- # so we use the internal baseline results for the /grader endpoint
77
- from server._baseline_results import get_last_grader_result
78
-
79
  result = get_last_grader_result(session_id)
80
  if result is None:
81
  return {"score": None, "error": "no_completed_episode"}
@@ -86,36 +115,30 @@ def post_grader(session_id: Optional[str] = None) -> dict:
86
  async def post_baseline():
87
  """Trigger baseline run, return scores for all tasks.
88
 
89
- Returns 409 if already running.
90
  """
91
- global _baseline_running
92
-
93
- if _baseline_running:
94
  return JSONResponse(
95
  status_code=409,
96
  content={"error": "baseline_in_progress"},
97
  )
98
 
99
- _baseline_running = True
100
- try:
101
- scores = await _run_baseline()
 
102
  return {"scores": scores}
103
- finally:
104
- _baseline_running = False
105
-
106
 
107
- async def _run_baseline() -> dict[str, float]:
108
- """Run the rule-based baseline internally."""
109
 
 
 
110
  scores: dict[str, float] = {}
111
 
112
- for task_info in MVP_TASKS:
113
  task_id = task_info["id"]
114
  env = MLTrainingEnvironment()
115
- obs = env.reset(seed=42, episode_id=f"baseline_{task_id}", task_id=task_id)
116
-
117
- # Run heuristic decision tree
118
- score = _run_heuristic_episode(env, obs, task_id)
119
  scores[task_id] = round(score, 4)
120
 
121
  return scores
@@ -123,73 +146,66 @@ async def _run_baseline() -> dict[str, float]:
123
 
124
  def _run_heuristic_episode(
125
  env: MLTrainingEnvironment,
126
- obs: MLTrainingObservation,
127
  task_id: str,
128
  ) -> float:
129
- """Run one heuristic baseline episode. Returns grader score."""
 
 
 
130
  # Step 1: inspect_gradients
131
  obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
132
 
133
- # Check for exploding gradients
134
  if obs.gradient_stats:
 
135
  if any(g.is_exploding for g in obs.gradient_stats):
136
- obs = env.step(
137
  MLTrainingAction(
138
  action_type="modify_config",
139
  target="learning_rate",
140
  value=0.001,
141
  )
142
  )
143
- obs = env.step(MLTrainingAction(action_type="restart_run"))
144
- obs = env.step(
145
  MLTrainingAction(
146
  action_type="mark_diagnosed",
147
  diagnosis="lr_too_high",
148
  )
149
  )
150
- session = env._get_session()
151
- if session and session.last_score is not None:
152
- return session.last_score
153
- return 0.0
154
 
155
- # Check for vanishing gradients
156
  if any(g.is_vanishing for g in obs.gradient_stats):
157
- obs = env.step(
158
  MLTrainingAction(
159
  action_type="modify_config",
160
  target="learning_rate",
161
  value=0.01,
162
  )
163
  )
164
- obs = env.step(MLTrainingAction(action_type="restart_run"))
165
- obs = env.step(
166
  MLTrainingAction(
167
  action_type="mark_diagnosed",
168
  diagnosis="vanishing_gradients",
169
  )
170
  )
171
- session = env._get_session()
172
- if session and session.last_score is not None:
173
- return session.last_score
174
- return 0.0
175
 
176
  # Step 2: inspect_data_batch
177
  obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
178
  if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
179
- obs = env.step(MLTrainingAction(action_type="patch_data_loader"))
180
- obs = env.step(MLTrainingAction(action_type="restart_run"))
181
- obs = env.step(
182
  MLTrainingAction(
183
  action_type="mark_diagnosed",
184
  diagnosis="data_leakage",
185
  )
186
  )
187
- session = env._get_session()
188
- if session and session.last_score is not None:
189
- return session.last_score
190
- return 0.0
191
 
192
- # Check for overfitting (val_loss diverging)
193
  if obs.val_loss_history and len(obs.val_loss_history) >= 10:
194
  early = sum(obs.val_loss_history[:5]) / 5
195
  late = sum(obs.val_loss_history[-5:]) / 5
@@ -198,50 +214,43 @@ def _run_heuristic_episode(
198
  and obs.data_batch_stats
199
  and obs.data_batch_stats.class_overlap_score < 0.1
200
  ):
201
- obs = env.step(
202
  MLTrainingAction(
203
  action_type="modify_config",
204
  target="weight_decay",
205
  value=0.01,
206
  )
207
  )
208
- obs = env.step(MLTrainingAction(action_type="restart_run"))
209
- obs = env.step(
210
  MLTrainingAction(
211
  action_type="mark_diagnosed",
212
  diagnosis="overfitting",
213
  )
214
  )
215
- session = env._get_session()
216
- if session and session.last_score is not None:
217
- return session.last_score
218
- return 0.0
219
 
220
  # Step 3: inspect_model_modes
221
  obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
222
  if obs.model_mode_info:
223
  has_eval = any(v == "eval" for v in obs.model_mode_info.values())
224
  if has_eval:
225
- obs = env.step(MLTrainingAction(action_type="fix_model_mode"))
226
- obs = env.step(MLTrainingAction(action_type="restart_run"))
227
- obs = env.step(
228
  MLTrainingAction(
229
  action_type="mark_diagnosed",
230
  diagnosis="batchnorm_eval_mode",
231
  )
232
  )
233
- session = env._get_session()
234
- if session and session.last_score is not None:
235
- return session.last_score
236
- return 0.0
237
 
238
  # Step 4: inspect_code (for Task 6)
239
  obs = env.step(MLTrainingAction(action_type="inspect_code"))
240
  if obs.code_snippet:
241
- # Simple pattern matching for known bugs
242
  code = obs.code_snippet.code
243
  if "model.eval()" in code and "model.train()" not in code:
244
- obs = env.step(
245
  MLTrainingAction(
246
  action_type="fix_code",
247
  line=5,
@@ -249,39 +258,51 @@ def _run_heuristic_episode(
249
  )
250
  )
251
  elif ".detach()" in code:
252
- obs = env.step(
253
  MLTrainingAction(
254
  action_type="fix_code",
255
  line=14,
256
  replacement=" loss = criterion(output, batch_y)",
257
  )
258
  )
259
- else:
260
- # Can't reliably fix — just diagnose
261
- pass
262
 
263
- if obs.episode_state.fix_action_taken:
264
- obs = env.step(MLTrainingAction(action_type="restart_run"))
 
 
265
 
266
- obs = env.step(
267
  MLTrainingAction(
268
  action_type="mark_diagnosed",
269
  diagnosis="code_bug",
270
  )
271
  )
272
- session = env._get_session()
273
- if session and session.last_score is not None:
274
- return session.last_score
275
- return 0.0
276
 
277
  # Fallback
278
- obs = env.step(
279
  MLTrainingAction(
280
  action_type="mark_diagnosed",
281
  diagnosis="overfitting",
282
  )
283
  )
 
 
 
 
 
284
  session = env._get_session()
285
  if session and session.last_score is not None:
286
  return session.last_score
287
  return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
1
  """FastAPI app — openenv create_app() + custom hackathon routes.
2
 
3
+ Spec reference: Sections 9, 14, 15.
4
  """
5
 
6
  from __future__ import annotations
7
 
8
  import asyncio
9
+ import json
10
  import logging
11
+ import sys
12
  from typing import Optional
13
 
14
  from fastapi import FastAPI
15
+ from fastapi.responses import HTMLResponse, JSONResponse
16
  from openenv.core.env_server.http_server import create_app
17
 
18
  from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
19
+ from server._baseline_results import get_last_grader_result
20
  from server.environment import MLTrainingEnvironment
21
 
22
+
23
+ # Structured JSON logging (Spec S15)
24
class JSONFormatter(logging.Formatter):
    """Format log records as single-line JSON objects (Spec S15).

    Every record yields ``time``, ``level`` and ``msg``. The attributes named
    in ``_CONTEXT_FIELDS`` are copied when they are present on the record
    (for example when supplied through logging's ``extra=`` argument), and
    ``exc_info`` carries the formatted traceback when the record has one.
    """

    # Optional per-record context attributes, in output order.
    _CONTEXT_FIELDS = ("session_id", "task_id", "step_count", "action_type", "score")

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record`` serialized as a JSON string."""
        log_data = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "msg": record.getMessage(),
        }
        for field in self._CONTEXT_FIELDS:
            if hasattr(record, field):
                log_data[field] = getattr(record, field)
        if record.exc_info:
            # Keep tracebacks from logger.exception()/exc_info=True calls.
            log_data["exc_info"] = self.formatException(record.exc_info)
        # default=str is deliberate: a non-serializable context value should
        # be stringified rather than make the log call fail.
        return json.dumps(log_data, default=str)
42
+
43
+
44
+ handler = logging.StreamHandler(sys.stdout)
45
+ handler.setFormatter(JSONFormatter())
46
+ logging.root.handlers = [handler]
47
+ logging.root.setLevel(logging.INFO)
48
  logger = logging.getLogger(__name__)
49
 
50
+ # All 6 tasks (Spec S11)
51
+ ALL_TASKS = [
52
  {"id": "task_001", "difficulty": "easy", "max_steps": 20},
53
+ {"id": "task_002", "difficulty": "easy", "max_steps": 20},
54
  {"id": "task_003", "difficulty": "medium", "max_steps": 25},
55
+ {"id": "task_004", "difficulty": "medium", "max_steps": 25},
56
  {"id": "task_005", "difficulty": "hard", "max_steps": 30},
57
+ {"id": "task_006", "difficulty": "hard", "max_steps": 30},
58
  ]
59
 
60
  # create_app takes the class (factory), not an instance
 
67
  )
68
 
69
  # Override framework's /health route with our custom version
 
70
  app.routes[:] = [
71
  r for r in app.routes if not (hasattr(r, "path") and r.path == "/health")
72
  ]
73
 
74
+ # Baseline lock: serializes /baseline runs on the event loop; note that asyncio.Lock is not thread-safe (Fix #14)
75
  _baseline_lock = asyncio.Lock()
 
76
 
77
 
78
  @app.get("/health")
79
  def health_check() -> dict:
80
  """Health check — required by hackathon auto-validator."""
81
+ return {"status": "ready", "tasks": len(ALL_TASKS)}
82
+
83
+
84
@app.get("/dashboard", response_class=HTMLResponse)
def get_dashboard() -> str:
    """Serve the live diagnostic dashboard page (Spec Section 19).

    Returns:
        The contents of ``dashboard.html`` located next to this module.
    """
    import pathlib

    html_path = pathlib.Path(__file__).parent / "dashboard.html"
    # Decode explicitly as UTF-8: the page declares that charset and contains
    # non-ASCII characters, so the locale's default encoding must not be used.
    return html_path.read_text(encoding="utf-8")
91
 
92
 
93
  @app.get("/tasks")
94
  def get_tasks() -> list[dict]:
95
  """Return task list with IDs, difficulties, and action schema."""
96
  schema = MLTrainingAction.model_json_schema()
97
+ return [{**task, "action_schema": schema} for task in ALL_TASKS]
98
 
99
 
100
  @app.post("/grader")
 
103
 
104
  Edge cases per spec Section 14:
105
  - No episode completed → {"score": null, "error": "no_completed_episode"}
 
106
  - Episode completed → {"score": float, "task_id": str, "steps": int}
107
  """
 
 
 
 
 
108
  result = get_last_grader_result(session_id)
109
  if result is None:
110
  return {"score": None, "error": "no_completed_episode"}
 
115
  async def post_baseline():
116
  """Trigger baseline run, return scores for all tasks.
117
 
118
+ Returns 409 if already running. Uses asyncio.Lock for thread safety.
119
  """
120
+ if _baseline_lock.locked():
 
 
121
  return JSONResponse(
122
  status_code=409,
123
  content={"error": "baseline_in_progress"},
124
  )
125
 
126
+ async with _baseline_lock:
127
+ scores = await asyncio.get_event_loop().run_in_executor(
128
+ None, _run_baseline_sync
129
+ )
130
  return {"scores": scores}
 
 
 
131
 
 
 
132
 
133
def _run_baseline_sync() -> dict[str, float]:
    """Run the heuristic baseline once per task and collect the rounded scores.

    Returns:
        Mapping of task id to the episode score rounded to four decimals.
    """
    results: dict[str, float] = {}

    for task_id in (entry["id"] for entry in ALL_TASKS):
        # Each task gets its own environment, reset with the fixed seed 42.
        environment = MLTrainingEnvironment()
        environment.reset(
            seed=42,
            episode_id=f"baseline_{task_id}",
            task_id=task_id,
        )
        results[task_id] = round(_run_heuristic_episode(environment, task_id), 4)

    return results
 
146
 
147
  def _run_heuristic_episode(
148
  env: MLTrainingEnvironment,
 
149
  task_id: str,
150
  ) -> float:
151
+ """Run one heuristic baseline episode. Returns grader score.
152
+
153
+ Decision tree per spec Section 17.
154
+ """
155
  # Step 1: inspect_gradients
156
  obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
157
 
 
158
  if obs.gradient_stats:
159
+ # Check exploding
160
  if any(g.is_exploding for g in obs.gradient_stats):
161
+ env.step(
162
  MLTrainingAction(
163
  action_type="modify_config",
164
  target="learning_rate",
165
  value=0.001,
166
  )
167
  )
168
+ env.step(MLTrainingAction(action_type="restart_run"))
169
+ env.step(
170
  MLTrainingAction(
171
  action_type="mark_diagnosed",
172
  diagnosis="lr_too_high",
173
  )
174
  )
175
+ return _get_score(env)
 
 
 
176
 
177
+ # Check vanishing
178
  if any(g.is_vanishing for g in obs.gradient_stats):
179
+ env.step(
180
  MLTrainingAction(
181
  action_type="modify_config",
182
  target="learning_rate",
183
  value=0.01,
184
  )
185
  )
186
+ env.step(MLTrainingAction(action_type="restart_run"))
187
+ env.step(
188
  MLTrainingAction(
189
  action_type="mark_diagnosed",
190
  diagnosis="vanishing_gradients",
191
  )
192
  )
193
+ return _get_score(env)
 
 
 
194
 
195
  # Step 2: inspect_data_batch
196
  obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
197
  if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
198
+ env.step(MLTrainingAction(action_type="patch_data_loader"))
199
+ env.step(MLTrainingAction(action_type="restart_run"))
200
+ env.step(
201
  MLTrainingAction(
202
  action_type="mark_diagnosed",
203
  diagnosis="data_leakage",
204
  )
205
  )
206
+ return _get_score(env)
 
 
 
207
 
208
+ # Check overfitting (val_loss diverging)
209
  if obs.val_loss_history and len(obs.val_loss_history) >= 10:
210
  early = sum(obs.val_loss_history[:5]) / 5
211
  late = sum(obs.val_loss_history[-5:]) / 5
 
214
  and obs.data_batch_stats
215
  and obs.data_batch_stats.class_overlap_score < 0.1
216
  ):
217
+ env.step(
218
  MLTrainingAction(
219
  action_type="modify_config",
220
  target="weight_decay",
221
  value=0.01,
222
  )
223
  )
224
+ env.step(MLTrainingAction(action_type="restart_run"))
225
+ env.step(
226
  MLTrainingAction(
227
  action_type="mark_diagnosed",
228
  diagnosis="overfitting",
229
  )
230
  )
231
+ return _get_score(env)
 
 
 
232
 
233
  # Step 3: inspect_model_modes
234
  obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
235
  if obs.model_mode_info:
236
  has_eval = any(v == "eval" for v in obs.model_mode_info.values())
237
  if has_eval:
238
+ env.step(MLTrainingAction(action_type="fix_model_mode"))
239
+ env.step(MLTrainingAction(action_type="restart_run"))
240
+ env.step(
241
  MLTrainingAction(
242
  action_type="mark_diagnosed",
243
  diagnosis="batchnorm_eval_mode",
244
  )
245
  )
246
+ return _get_score(env)
 
 
 
247
 
248
  # Step 4: inspect_code (for Task 6)
249
  obs = env.step(MLTrainingAction(action_type="inspect_code"))
250
  if obs.code_snippet:
 
251
  code = obs.code_snippet.code
252
  if "model.eval()" in code and "model.train()" not in code:
253
+ env.step(
254
  MLTrainingAction(
255
  action_type="fix_code",
256
  line=5,
 
258
  )
259
  )
260
  elif ".detach()" in code:
261
+ env.step(
262
  MLTrainingAction(
263
  action_type="fix_code",
264
  line=14,
265
  replacement=" loss = criterion(output, batch_y)",
266
  )
267
  )
 
 
 
268
 
269
+ # Try restart if fix was applied
270
+ session = env._get_session()
271
+ if session and session.state.fix_action_taken:
272
+ env.step(MLTrainingAction(action_type="restart_run"))
273
 
274
+ env.step(
275
  MLTrainingAction(
276
  action_type="mark_diagnosed",
277
  diagnosis="code_bug",
278
  )
279
  )
280
+ return _get_score(env)
 
 
 
281
 
282
  # Fallback
283
+ env.step(
284
  MLTrainingAction(
285
  action_type="mark_diagnosed",
286
  diagnosis="overfitting",
287
  )
288
  )
289
+ return _get_score(env)
290
+
291
+
292
def _get_score(env: MLTrainingEnvironment) -> float:
    """Return the last grader score recorded on the environment's session.

    Falls back to ``0.0`` when the environment has no session or the session
    has no score yet.
    """
    current = env._get_session()
    if not current:
        return 0.0
    recorded = current.last_score
    return 0.0 if recorded is None else recorded
298
+
299
+
300
def main() -> None:
    """Start the uvicorn server for ``app`` on all interfaces, port 7860."""
    from uvicorn import run as serve

    bind_host, bind_port = "0.0.0.0", 7860
    serve(app, host=bind_host, port=bind_port)
305
+
306
+
307
+ if __name__ == "__main__":
308
+ main()
server/dashboard.html ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>PyTorch Training Debugger — Live Dashboard</title>
7
+ <script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
8
+ <style>
9
+ * { margin: 0; padding: 0; box-sizing: border-box; }
10
+ body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; background: #0d1117; color: #c9d1d9; }
11
+ .header { background: #161b22; padding: 16px 24px; border-bottom: 1px solid #30363d; display: flex; align-items: center; gap: 16px; }
12
+ .header h1 { font-size: 20px; font-weight: 600; }
13
+ .header .status { padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 500; }
14
+ .status.connected { background: #238636; color: #fff; }
15
+ .status.disconnected { background: #da3633; color: #fff; }
16
+ .grid { display: grid; grid-template-columns: 1fr 1fr; grid-template-rows: 1fr 1fr; gap: 12px; padding: 12px; height: calc(100vh - 60px); }
17
+ .panel { background: #161b22; border: 1px solid #30363d; border-radius: 8px; overflow: hidden; display: flex; flex-direction: column; }
18
+ .panel-title { padding: 10px 16px; font-size: 14px; font-weight: 600; color: #58a6ff; border-bottom: 1px solid #30363d; background: #0d1117; }
19
+ .panel-body { flex: 1; padding: 8px; position: relative; min-height: 0; }
20
+ .placeholder { display: flex; align-items: center; justify-content: center; height: 100%; color: #484f58; font-style: italic; }
21
+ #controls { display: flex; gap: 8px; align-items: center; }
22
+ #controls select, #controls button { background: #21262d; color: #c9d1d9; border: 1px solid #30363d; padding: 6px 12px; border-radius: 6px; cursor: pointer; font-size: 13px; }
23
+ #controls button:hover { background: #30363d; }
24
+ #controls button.primary { background: #238636; border-color: #238636; color: #fff; }
25
+ #summary { padding: 16px; font-size: 13px; line-height: 1.8; overflow-y: auto; }
26
+ #summary .row { display: flex; justify-content: space-between; border-bottom: 1px solid #21262d; padding: 4px 0; }
27
+ #summary .label { color: #8b949e; }
28
+ #summary .value { font-weight: 600; }
29
+ #summary .score { font-size: 24px; color: #58a6ff; text-align: center; margin: 12px 0; }
30
+ .actions-list { display: flex; flex-wrap: wrap; gap: 4px; margin-top: 8px; }
31
+ .action-tag { padding: 2px 8px; border-radius: 4px; font-size: 11px; font-weight: 500; }
32
+ .action-tag.investigate { background: #1f6feb33; color: #58a6ff; }
33
+ .action-tag.fix { background: #23863633; color: #3fb950; }
34
+ .action-tag.terminal { background: #da363333; color: #f85149; }
35
+ .action-tag.wrong { background: #da363366; color: #f85149; }
36
+ </style>
37
+ </head>
38
+ <body>
39
+ <div class="header">
40
+ <h1>PyTorch Training Debugger</h1>
41
+ <div id="connStatus" class="status disconnected">Disconnected</div>
42
+ <div id="controls">
43
+ <select id="taskSelect">
44
+ <option value="task_001">Task 1 — Exploding Gradients (Easy)</option>
45
+ <option value="task_002">Task 2 — Vanishing Gradients (Easy)</option>
46
+ <option value="task_003">Task 3 — Data Leakage (Medium)</option>
47
+ <option value="task_004">Task 4 — Overfitting (Medium)</option>
48
+ <option value="task_005">Task 5 — BatchNorm Eval (Hard)</option>
49
+ <option value="task_006">Task 6 — Code Bug (Hard)</option>
50
+ </select>
51
+ <button class="primary" onclick="runBaseline()">Run Baseline</button>
52
+ </div>
53
+ </div>
54
+ <div class="grid">
55
+ <div class="panel">
56
+ <div class="panel-title">Training Metrics</div>
57
+ <div class="panel-body"><div id="metricsChart"><div class="placeholder">Run baseline to see metrics</div></div></div>
58
+ </div>
59
+ <div class="panel">
60
+ <div class="panel-title">Gradient & Weight Heatmap</div>
61
+ <div class="panel-body"><div id="gradientChart"><div class="placeholder">Not yet inspected</div></div></div>
62
+ </div>
63
+ <div class="panel">
64
+ <div class="panel-title">Action Timeline & Rewards</div>
65
+ <div class="panel-body"><div id="timelineChart"><div class="placeholder">No actions yet</div></div></div>
66
+ </div>
67
+ <div class="panel">
68
+ <div class="panel-title">Episode Summary</div>
69
+ <div class="panel-body" id="summary">
70
+ <div class="placeholder">Waiting for episode</div>
71
+ </div>
72
+ </div>
73
+ </div>
74
+
75
+ <script>
76
+ const host = window.location.host;
77
+ const wsProto = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
78
+ let ws = null;
79
+ let actions = [];
80
+ let rewards = [];
81
+ let cumRewards = [];
82
+ let obs = null;
83
+
84
function setStatus(connected) {
  // Mirror the socket state in the header badge: label text plus CSS class.
  const badge = document.getElementById('connStatus');
  if (connected) {
    badge.textContent = 'Connected';
    badge.className = 'status connected';
  } else {
    badge.textContent = 'Disconnected';
    badge.className = 'status disconnected';
  }
}
89
+
90
function connect() {
  // Open (or re-open) the socket on the same host that served this page.
  ws = new WebSocket(`${wsProto}//${host}/ws`);

  ws.onopen = function () {
    setStatus(true);
  };

  ws.onclose = function () {
    setStatus(false);
    // Retry two seconds after every close.
    setTimeout(connect, 2000);
  };

  // An error closes the current global socket, which in turn triggers onclose.
  ws.onerror = function () {
    ws.close();
  };

  ws.onmessage = function (ev) {
    // Only messages carrying a truthy `data` payload are treated as observations.
    const { data } = JSON.parse(ev.data);
    if (data) handleObservation(data);
  };
}
100
+
101
function handleObservation(data) {
  // Remember the latest observation; runBaseline() polls it for `done`.
  obs = data;

  // `!= null` rejects both null and undefined but accepts a reward of 0.
  if (data.reward != null) {
    const runningTotal = cumRewards.length ? cumRewards[cumRewards.length - 1] : 0;
    rewards.push(data.reward);
    cumRewards.push(runningTotal + data.reward);
  }

  // Adopt the action history reported in the episode state when present.
  const taken = data.episode_state && data.episode_state.actions_taken;
  if (taken) {
    actions = taken;
  }

  // Redraw all four panels.
  updateMetrics(data);
  updateGradients(data);
  updateTimeline();
  updateSummary(data);
}
116
+
117
function updateMetrics(d) {
  // Draw train/val loss (left axis) and val accuracy (right axis) against epoch.
  //
  // Non-finite samples (NaN, +/-Infinity, or null/missing) are replaced with
  // null instead of being removed. Removing them would shift every later point
  // to an earlier epoch and misalign the loss curves with the accuracy curve;
  // null keeps each value at its own index and Plotly leaves a gap there.
  // Number.isFinite is used because the global isFinite(null) is true, which
  // would plot a missing value as 0.
  const toSeries = (values) => values.map((v) => (Number.isFinite(v) ? v : null));

  const traces = [];
  if (d.training_loss_history && d.training_loss_history.length > 0) {
    traces.push({ y: toSeries(d.training_loss_history), name: 'Train Loss', line: { color: '#f85149' } });
  }
  if (d.val_loss_history && d.val_loss_history.length > 0) {
    traces.push({ y: toSeries(d.val_loss_history), name: 'Val Loss', line: { color: '#f0883e', dash: 'dash' } });
  }
  if (d.val_accuracy_history && d.val_accuracy_history.length > 0) {
    traces.push({ y: d.val_accuracy_history, name: 'Val Accuracy', yaxis: 'y2', line: { color: '#3fb950' } });
  }
  // Nothing to plot yet: keep whatever the panel currently shows.
  if (traces.length === 0) return;

  Plotly.newPlot('metricsChart', traces, {
    paper_bgcolor: 'transparent', plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 50 },
    xaxis: { title: 'Epoch', gridcolor: '#21262d' },
    yaxis: { title: 'Loss', gridcolor: '#21262d' },
    yaxis2: { title: 'Accuracy', overlaying: 'y', side: 'right', range: [0, 1], gridcolor: '#21262d' },
    legend: { x: 0, y: 1.15, orientation: 'h' },
    showlegend: true,
  }, { responsive: true });
}
142
+
143
function updateGradients(d) {
  // Bar chart of per-layer mean gradient norm on a log axis, coloured and
  // labelled by the exploding / vanishing flags on each stat entry.
  const stats = d.gradient_stats;
  if (!stats || stats.length === 0) return;

  const layers = [];
  const norms = [];
  const colors = [];
  const labels = [];
  for (const g of stats) {
    layers.push(g.layer_name);
    norms.push(g.mean_norm);
    // Exploding takes precedence over vanishing when both flags are set.
    if (g.is_exploding) {
      colors.push('#f85149');
      labels.push('EXPLODING');
    } else if (g.is_vanishing) {
      colors.push('#1f6feb');
      labels.push('VANISHING');
    } else {
      colors.push('#3fb950');
      labels.push('Normal');
    }
  }

  const bars = {
    x: layers, y: norms, type: 'bar',
    marker: { color: colors },
    text: labels,
    textposition: 'auto',
  };
  const layout = {
    paper_bgcolor: 'transparent', plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 20 },
    yaxis: { title: 'Mean Grad Norm', gridcolor: '#21262d', type: 'log' },
    xaxis: { gridcolor: '#21262d' },
  };
  Plotly.newPlot('gradientChart', [bars], layout, { responsive: true });
}
161
+
162
function updateTimeline() {
  // Plot per-step rewards (bars) and the cumulative reward (line), with one
  // x tick per action taken so far. Reads the module-level `actions`,
  // `rewards` and `cumRewards` arrays.
  //
  // The previous version also built a per-action colour array that was never
  // passed to Plotly; that dead computation has been removed. Bars are
  // coloured by reward sign, as before.
  if (actions.length === 0) return;

  // 1-based step numbers; a fresh array per use, as in the original.
  const stepNumbers = () => actions.map((_, i) => i + 1);
  // Shorten action names for tick labels: drop any ':' suffix and abbreviate.
  const tickLabels = actions.map((a) =>
    a.split(':')[0].replace('inspect_', 'i_').replace('mark_diagnosed', 'diag'));
  const barColors = rewards.map((r) => (r >= 0 ? '#3fb950' : '#f85149'));

  Plotly.newPlot('timelineChart', [
    { x: stepNumbers(), y: rewards, type: 'bar', name: 'Step Reward', marker: { color: barColors } },
    { x: stepNumbers(), y: cumRewards, type: 'scatter', name: 'Cumulative', line: { color: '#58a6ff', width: 2 } }
  ], {
    paper_bgcolor: 'transparent', plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 20 },
    xaxis: { title: 'Step', gridcolor: '#21262d', tickvals: stepNumbers(), ticktext: tickLabels },
    yaxis: { title: 'Reward', gridcolor: '#21262d' },
    legend: { x: 0, y: 1.15, orientation: 'h' },
  }, { responsive: true });
}
183
+
184
function updateSummary(d) {
  // Rebuild the "Episode Summary" panel from the latest observation `d`.
  //
  // Every value that originates from the server (run id, step count, code
  // snippet, action names) is HTML-escaped before it is written via innerHTML.
  // Previously only '<' in the code snippet was escaped, so '&' sequences in
  // code rendered incorrectly and other fields were inserted as raw markup.
  const esc = (value) => String(value == null ? '' : value)
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#39;');
  const row = (label, value) =>
    '<div class="row"><span class="label">' + label + '</span><span class="value">' + esc(value) + '</span></div>';
  const yesNo = (flag) => (flag ? 'Yes' : 'No');
  // Non-"fix*" action names that are still styled as fixes.
  const FIX_ACTIONS = ['modify_config', 'patch_data_loader', 'add_callback', 'replace_optimizer'];

  const s = d.episode_state || {};
  const avail = d.available_actions || [];

  let html = '';
  if (d.done) {
    html += '<div class="score">Episode Complete</div>';
  }
  html += row('Task', d.run_id || '-');
  html += row('Steps', s.step_count || 0);
  html += row('Gradients Inspected', yesNo(s.gradients_inspected));
  // Shows '-' rather than 'No' when the flag is falsy.
  html += row('Gradients Normal', s.gradients_were_normal ? 'Yes' : '-');
  html += row('Data Inspected', yesNo(s.data_inspected));
  html += row('Model Modes Inspected', yesNo(s.model_modes_inspected));
  html += row('Code Inspected', yesNo(s.code_inspected));
  html += row('Fix Applied', yesNo(s.fix_action_taken));
  html += row('Restarted', yesNo(s.restart_after_fix));
  html += row('Diagnosed', yesNo(s.diagnosis_submitted));

  if (d.code_snippet) {
    html += '<div style="margin-top:12px"><span class="label">Code:</span><pre style="background:#0d1117;padding:8px;border-radius:4px;font-size:11px;overflow:auto;max-height:120px;margin-top:4px">' + esc(d.code_snippet.code) + '</pre></div>';
  }

  html += '<div style="margin-top:8px"><span class="label">Available Actions:</span></div>';
  html += '<div class="actions-list">';
  avail.forEach((a) => {
    const name = String(a);
    let cls = 'investigate';
    if (name.startsWith('fix') || FIX_ACTIONS.includes(name)) cls = 'fix';
    // Terminal styling wins over the two categories above.
    if (name === 'mark_diagnosed' || name === 'restart_run') cls = 'terminal';
    html += '<span class="action-tag ' + cls + '">' + esc(name) + '</span>';
  });
  html += '</div>';

  document.getElementById('summary').innerHTML = html;
}
215
+
216
async function runBaseline() {
  // Reset the selected task over the socket, then send a fixed sequence of
  // inspection actions, pausing between sends so responses can arrive.
  const taskId = document.getElementById('taskSelect').value;
  actions = []; rewards = []; cumRewards = [];
  // Drop the previous episode's observation. Otherwise a stale `done: true`
  // could end the new run after its first step if the reset reply is slow.
  obs = null;

  const socketOpen = () => ws && ws.readyState === WebSocket.OPEN;
  if (!socketOpen()) return;

  const pause = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

  ws.send(JSON.stringify({ type: 'reset', data: { task_id: taskId, seed: 42 } }));
  await pause(500);

  // Run the heuristic steps
  const steps = [
    { action_type: 'inspect_gradients' },
    { action_type: 'inspect_data_batch' },
    { action_type: 'inspect_model_modes' },
    { action_type: 'inspect_model_weights' },
    { action_type: 'inspect_code' },
  ];
  for (const step of steps) {
    // The socket may have closed or been replaced by connect() while we were
    // waiting; sending on a socket that is still connecting would throw.
    if (!socketOpen()) break;
    ws.send(JSON.stringify({ type: 'step', data: step }));
    await pause(300);
    if (obs && obs.done) break;
  }
}
237
+
238
+ connect();
239
+ </script>
240
+ </body>
241
+ </html>
server/environment.py CHANGED
@@ -46,6 +46,7 @@ from ml_training_debugger.simulation import (
46
  gen_val_accuracy_history,
47
  gen_val_loss_history,
48
  )
 
49
 
50
  logger = logging.getLogger(__name__)
51
 
@@ -160,6 +161,9 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
160
  "task_id": old.scenario.task_id,
161
  "steps": old.state.step_count,
162
  }
 
 
 
163
 
164
  self._current_session_id = session_id
165
 
@@ -335,6 +339,9 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
335
  "task_id": scenario.task_id,
336
  "steps": state.step_count,
337
  }
 
 
 
338
  logger.info(
339
  "episode_completed",
340
  extra={
 
46
  gen_val_accuracy_history,
47
  gen_val_loss_history,
48
  )
49
+ from server._baseline_results import store_grader_result
50
 
51
  logger = logging.getLogger(__name__)
52
 
 
161
  "task_id": old.scenario.task_id,
162
  "steps": old.state.step_count,
163
  }
164
+ store_grader_result(
165
+ session_id, score, old.scenario.task_id, old.state.step_count
166
+ )
167
 
168
  self._current_session_id = session_id
169
 
 
339
  "task_id": scenario.task_id,
340
  "steps": state.step_count,
341
  }
342
+ store_grader_result(
343
+ self._current_session_id, score, scenario.task_id, state.step_count
344
+ )
345
  logger.info(
346
  "episode_completed",
347
  extra={
tests/test_baseline_reproducibility.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test baseline produces bit-exact identical scores on two runs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from baseline_heuristic import ALL_TASKS, run_heuristic_episode
6
+
7
+
8
class TestBaselineReproducibility:
    """Determinism and sanity checks for the heuristic baseline's scores."""

    @staticmethod
    def _score_every_task():
        """Run the heuristic once per task and map task id -> score."""
        return {tid: run_heuristic_episode(tid) for tid in ALL_TASKS}

    def test_two_runs_identical(self):
        """Two complete passes over all tasks must give exactly equal scores."""
        first_pass = self._score_every_task()
        second_pass = self._score_every_task()
        assert first_pass == second_pass

    def test_all_scores_in_range(self):
        """Every task's score lies within the closed interval [0.0, 1.0]."""
        for tid in ALL_TASKS:
            score = run_heuristic_episode(tid)
            assert 0.0 <= score <= 1.0, f"{tid}: score {score} out of range"

    def test_scores_have_meaningful_variance(self):
        """At least two tasks must score differently."""
        distinct_scores = {run_heuristic_episode(tid) for tid in ALL_TASKS}
        assert len(distinct_scores) > 1, "All scores identical — no variance"
tests/test_endpoints.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Integration tests for HTTP endpoints."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+ from fastapi.testclient import TestClient
7
+
8
+ from server.app import app
9
+
10
+
11
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the server ``app`` for each test."""
    test_client = TestClient(app)
    return test_client
14
+
15
+
16
class TestHealthEndpoint:
    """GET /health reports readiness and the number of tasks."""

    def test_returns_ready(self, client):
        """The endpoint answers 200 with status "ready" and six tasks."""
        response = client.get("/health")
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "ready"
        assert payload["tasks"] == 6
23
+
24
+
25
class TestTasksEndpoint:
    """GET /tasks lists every task together with its action schema."""

    def test_returns_six_tasks(self, client):
        """Six tasks are returned, including the first and last task ids."""
        response = client.get("/tasks")
        assert response.status_code == 200
        listing = response.json()
        assert len(listing) == 6
        task_ids = [entry["id"] for entry in listing]
        assert "task_001" in task_ids
        assert "task_006" in task_ids

    def test_tasks_have_action_schema(self, client):
        """Each task entry carries an ``action_schema`` with ``properties``."""
        listing = client.get("/tasks").json()
        for entry in listing:
            assert "action_schema" in entry
            assert "properties" in entry["action_schema"]
41
+
42
+
43
class TestGraderEndpoint:
    """POST /grader behaviour when no episode result has been stored."""

    def test_no_completed_episode(self, client):
        """With the stored results emptied, the grader reports no episode."""
        import server._baseline_results as br

        # Reset shared state for clean test
        br._last_results.clear()

        response = client.post("/grader")
        assert response.status_code == 200
        payload = response.json()
        assert payload["score"] is None
        assert payload["error"] == "no_completed_episode"
53
+
54
+
55
class TestDashboardEndpoint:
    """GET /dashboard serves the live dashboard page."""

    def test_returns_html(self, client):
        """The page body mentions both Plotly and WebSocket."""
        response = client.get("/dashboard")
        assert response.status_code == 200
        body = response.text
        assert "Plotly" in body
        assert "WebSocket" in body
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
validation/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ matplotlib
3
+ scipy
validation/validate_exploding_gradients.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Validate parametric exploding gradient curves against real PyTorch training.
3
+
4
+ Trains a CNN with lr=0.1 for 20 epochs, compares loss curve to simulation.
5
+ Prints R² between the real and simulated curves and asserts that both diverge.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from ml_training_debugger.pytorch_engine import SimpleCNN
14
+ from ml_training_debugger.scenarios import sample_scenario
15
+ from ml_training_debugger.simulation import gen_loss_history
16
+
17
+
18
def run_real_training(lr: float = 0.1, epochs: int = 20) -> list[float]:
    """Train a fresh ``SimpleCNN`` with plain SGD and return the loss per step.

    The RNG is seeded with 42 before the model is built. Each of the
    ``epochs`` iterations draws one random batch of 16 inputs shaped
    (3, 32, 32) with integer labels in [0, 10) and takes a single optimizer
    step on the cross-entropy loss. A NaN loss is recorded as ``inf``.
    """
    torch.manual_seed(42)
    net = SimpleCNN()
    net.train()
    sgd = torch.optim.SGD(net.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    history: list[float] = []
    for _ in range(epochs):
        inputs = torch.randn(16, 3, 32, 32)
        targets = torch.randint(0, 10, (16,))

        sgd.zero_grad()
        batch_loss = loss_fn(net(inputs), targets)
        batch_loss.backward()
        sgd.step()

        value = batch_loss.item()
        # NaN is the only float that compares unequal to itself.
        if value != value:
            value = float("inf")
        history.append(value)
    return history
38
+
39
+
40
def compute_r_squared(real: list[float], simulated: list[float]) -> float:
    """Compute R² of ``simulated`` against ``real``, skipping non-finite points.

    The curves are paired position by position (extra trailing values of the
    longer curve are ignored). A pair is dropped when either value is NaN,
    ``+inf`` or ``-inf``.

    Returns:
        ``1 - SS_res / SS_tot`` over the remaining pairs. ``0.0`` when fewer
        than three finite pairs remain. When ``real`` has zero variance the
        ratio is undefined, so the result is ``1.0`` for an exact match and
        ``0.0`` otherwise.
    """
    import math

    # math.isfinite rejects NaN and both infinities; the earlier hand-written
    # check let -inf through, which turned the whole result into NaN.
    pairs = [
        (r, s)
        for r, s in zip(real, simulated)
        if math.isfinite(r) and math.isfinite(s)
    ]
    if len(pairs) < 3:
        return 0.0

    real_t = torch.tensor([r for r, _ in pairs], dtype=torch.float32)
    sim_t = torch.tensor([s for _, s in pairs], dtype=torch.float32)
    ss_res = ((real_t - sim_t) ** 2).sum()
    ss_tot = ((real_t - real_t.mean()) ** 2).sum()
    if ss_tot.item() == 0:
        # Constant reference curve: only a perfect prediction counts as a fit.
        return 1.0 if ss_res.item() == 0 else 0.0
    return (1 - ss_res / ss_tot).item()
56
+
57
+
58
def main() -> None:
    """Compare the simulated task_001 loss curve with a real high-LR run.

    Samples the ``task_001`` scenario with seed 42, generates its simulated
    loss history, trains for 20 steps at the scenario's learning rate, then
    prints R² and the first/last loss of each curve.

    Raises:
        AssertionError: if either curve fails to diverge, i.e. contains no
            value that is ``inf`` or greater than 100.
    """
    scenario = sample_scenario("task_001", seed=42)
    simulated = gen_loss_history(scenario)
    real = run_real_training(lr=scenario.learning_rate, epochs=20)

    infinity = float("inf")

    def show(value: float) -> str:
        """Format a loss for display, printing infinity as ``INF``."""
        return "INF" if value == infinity else f"{value:.2f}"

    def diverges(curve: list[float]) -> bool:
        """Return True when any point is infinite or exceeds 100."""
        return any(v == infinity or v > 100 for v in curve)

    r2 = compute_r_squared(real, simulated)
    print(f"Exploding Gradients — R²: {r2:.4f}")
    print(f" Real loss trend: {real[0]:.2f} → {show(real[-1])}")
    print(f" Sim loss trend: {simulated[0]:.2f} → {show(simulated[-1])}")

    # Both should diverge — directional agreement is what matters
    real_diverges = diverges(real)
    sim_diverges = diverges(simulated)
    print(f" Real diverges: {real_diverges}, Sim diverges: {sim_diverges}")
    # Raised explicitly rather than via `assert` so the check still runs
    # under `python -O`, which strips assert statements.
    if not (real_diverges and sim_diverges):
        raise AssertionError("Both curves should diverge")
    print(" PASS: Both curves diverge as expected")
74
+
75
+
76
+ if __name__ == "__main__":
77
+ main()