omkarrr88 committed on
Commit · 02e58fe
Parent(s): a3e1032
task 6 and 7 made hard
- README.md +9 -7
- ml_training_debugger/graders.py +45 -10
- pyproject.toml +1 -1
- server/app.py +3 -4
- tests/test_graders.py +146 -2
README.md
CHANGED
|
@@ -107,6 +107,7 @@ Fields like `gradient_stats`, `data_batch_stats`, `model_mode_info`, and `code_s
|
|
| 107 |
|
| 108 |
**Terminal** — end the episode:
|
| 109 |
- `restart_run` — restart training (only available after a fix)
|
|
|
|
| 110 |
- `mark_diagnosed` — submit diagnosis from 7 possible root causes
|
| 111 |
|
| 112 |
Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
|
|
@@ -156,13 +157,14 @@ An agent that chases the gradient spike red herring loses 0.20 points. An agent
|
|
| 156 |
| `task_003` | Medium | **1.00** | 0.40 |
|
| 157 |
| `task_004` | Medium | **1.00** | 0.60 |
|
| 158 |
| `task_005` | Hard | **0.80** | 0.38-0.55 |
|
| 159 |
-
| `task_006` | Hard | **
|
| 160 |
-
| `task_007` | Hard | **
|
| 161 |
-
| **Average** | | **0.
|
| 162 |
|
| 163 |
**What this tells you:**
|
| 164 |
-
- **Hard tasks are genuinely hard:**
|
| 165 |
- **Red herring traps work:** Task 5 penalizes agents that call `add_callback` after seeing normal gradients (-0.20) or `modify_config` when LR isn't the issue (-0.10). LLMs routinely fall for both traps.
|
|
|
|
| 166 |
- **8B struggles on multi-step tasks:** Task 2 score of 0.05 shows small models can't maintain investigation strategy across many steps.
|
| 167 |
- **The heuristic baseline is strong** because it was designed with knowledge of the task structure. An agent that doesn't know the structure has to figure it out from observations alone.
|
| 168 |
|
|
@@ -247,7 +249,7 @@ pip install pytest pytest-cov pytest-asyncio httpx websockets
|
|
| 247 |
# Start server
|
| 248 |
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 249 |
|
| 250 |
-
# Run tests (
|
| 251 |
pytest tests/ -v --cov=ml_training_debugger
|
| 252 |
|
| 253 |
# Run heuristic baseline
|
|
@@ -284,7 +286,7 @@ ml_training_debugger/
|
|
| 284 |
models.py — Pydantic data models (Action, Observation, EpisodeState)
|
| 285 |
scenarios.py — Task parameter sampling (7 tasks, deterministic per seed)
|
| 286 |
pytorch_engine.py — Real PyTorch models, fault injection, gradient/weight extraction
|
| 287 |
-
simulation.py — 20-epoch real training with
|
| 288 |
reward_engine.py — 7-component per-step reward with context gating
|
| 289 |
graders.py — Per-task holistic 0.0-1.0 scoring
|
| 290 |
code_templates.py — Task 6 bug variants + 4-strategy fix validation
|
|
@@ -295,7 +297,7 @@ server/
|
|
| 295 |
app.py — FastAPI + custom endpoints
|
| 296 |
dashboard.html — Live Plotly.js diagnostic dashboard
|
| 297 |
|
| 298 |
-
tests/ —
|
| 299 |
baseline_heuristic.py — Rule-based agent (deterministic, no API key)
|
| 300 |
baseline_inference.py — LLM agent (Groq/Cerebras/Gemini/OpenAI)
|
| 301 |
```
|
|
|
|
| 107 |
|
| 108 |
**Terminal** — end the episode:
|
| 109 |
- `restart_run` — restart training (only available after a fix)
|
| 110 |
+
- `rollback_checkpoint` — rollback to pre-fix state (only available after restart)
|
| 111 |
- `mark_diagnosed` — submit diagnosis from 7 possible root causes
|
| 112 |
|
| 113 |
Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
|
|
|
|
| 157 |
| `task_003` | Medium | **1.00** | 0.40 |
|
| 158 |
| `task_004` | Medium | **1.00** | 0.60 |
|
| 159 |
| `task_005` | Hard | **0.80** | 0.38-0.55 |
|
| 160 |
+
| `task_006` | Hard | **0.81** | 0.60-1.00 |
|
| 161 |
+
| `task_007` | Hard | **0.79** | 0.60 |
|
| 162 |
+
| **Average** | | **0.91** | 0.52 |
|
| 163 |
|
| 164 |
**What this tells you:**
|
| 165 |
+
- **Hard tasks are genuinely hard:** All three hard tasks (5, 6, 7) require thorough investigation including weight inspection for full credit. The heuristic scores 0.79-0.81 on hard tasks because it skips weight inspection. An LLM that falls for red herrings or skips investigation scores even lower.
|
| 166 |
- **Red herring traps work:** Task 5 penalizes agents that call `add_callback` after seeing normal gradients (-0.20) or `modify_config` when LR isn't the issue (-0.10). LLMs routinely fall for both traps.
|
| 167 |
+
- **Investigation thoroughness matters:** Tasks 6 and 7 scale fix/restart credit based on how thoroughly the agent investigated before acting. Quick fixes without ruling out alternatives score ~60-65% of full credit.
|
| 168 |
- **8B struggles on multi-step tasks:** Task 2 score of 0.05 shows small models can't maintain investigation strategy across many steps.
|
| 169 |
- **The heuristic baseline is strong** because it was designed with knowledge of the task structure. An agent that doesn't know the structure has to figure it out from observations alone.
|
| 170 |
|
|
|
|
| 249 |
# Start server
|
| 250 |
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 251 |
|
| 252 |
+
# Run tests (246 tests, 96% coverage)
|
| 253 |
pytest tests/ -v --cov=ml_training_debugger
|
| 254 |
|
| 255 |
# Run heuristic baseline
|
|
|
|
| 286 |
models.py — Pydantic data models (Action, Observation, EpisodeState)
|
| 287 |
scenarios.py — Task parameter sampling (7 tasks, deterministic per seed)
|
| 288 |
pytorch_engine.py — Real PyTorch models, fault injection, gradient/weight extraction
|
| 289 |
+
simulation.py — 20-epoch real training with fault injection
|
| 290 |
reward_engine.py — 7-component per-step reward with context gating
|
| 291 |
graders.py — Per-task holistic 0.0-1.0 scoring
|
| 292 |
code_templates.py — Task 6 bug variants + 4-strategy fix validation
|
|
|
|
| 297 |
app.py — FastAPI + custom endpoints
|
| 298 |
dashboard.html — Live Plotly.js diagnostic dashboard
|
| 299 |
|
| 300 |
+
tests/ — 246 tests, 96% coverage
|
| 301 |
baseline_heuristic.py — Rule-based agent (deterministic, no API key)
|
| 302 |
baseline_inference.py — LLM agent (Groq/Cerebras/Gemini/OpenAI)
|
| 303 |
```
|
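The README hunk above lists when gated actions become available: `fix_code` after code inspection, `restart_run` after a fix, `rollback_checkpoint` after a restart, and `mark_diagnosed` until a diagnosis is submitted. A minimal sketch of that gating follows. `EpisodeFlags` and `gated_actions` are hypothetical names for illustration; the flag names mirror the `EpisodeState` fields used in the tests in this commit, but the environment's real availability logic is not part of this diff and may differ.

```python
from dataclasses import dataclass


@dataclass
class EpisodeFlags:
    """Hypothetical stand-in for the relevant EpisodeState fields."""
    code_inspected: bool = False
    fix_action_taken: bool = False
    restart_after_fix: bool = False
    diagnosis_submitted: bool = False


def gated_actions(flags: EpisodeFlags) -> list[str]:
    """Return the gated actions that the README rules would expose."""
    actions = []
    if flags.code_inspected:          # fix_code requires prior code inspection
        actions.append("fix_code")
    if flags.fix_action_taken:        # restart_run requires a fix
        actions.append("restart_run")
    if flags.restart_after_fix:       # rollback_checkpoint requires a restart
        actions.append("rollback_checkpoint")
    if not flags.diagnosis_submitted:  # mark_diagnosed disappears after submission
        actions.append("mark_diagnosed")
    return actions


if __name__ == "__main__":
    print(gated_actions(EpisodeFlags()))
    print(gated_actions(EpisodeFlags(code_inspected=True, fix_action_taken=True)))
```

The sketch only encodes the four rules stated in the README; whether an action stays available after it has been used once is not specified there and is left out.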
ml_training_debugger/graders.py
CHANGED
|
@@ -183,26 +183,35 @@ def grade_task_006(state: EpisodeState, scenario: ScenarioParams) -> float:
|
|
| 183 |
|
| 184 |
Diagnosis must ALWAYS be 'code_bug' regardless of bug variant.
|
| 185 |
Hard task rewards thorough investigation before fixing.
|
|
|
|
| 186 |
"""
|
| 187 |
score = 0.0
|
| 188 |
|
| 189 |
-
# +0.05 for
|
| 190 |
if state.code_inspected:
|
| 191 |
score += 0.05
|
| 192 |
-
|
| 193 |
-
# Thoroughness bonus: inspecting other systems first rules out non-code causes
|
| 194 |
if state.gradients_inspected:
|
| 195 |
score += 0.05
|
| 196 |
if state.model_modes_inspected:
|
| 197 |
score += 0.05
|
| 198 |
|
| 199 |
-
# Code fix credit
|
| 200 |
if _has_action(state, "fix_code") and state.fix_action_taken:
|
| 201 |
-
|
| 202 |
|
| 203 |
-
# Restart credit
|
| 204 |
if state.restart_after_fix:
|
| 205 |
-
|
| 206 |
|
| 207 |
# +0.45 for correct diagnosis (must be code_bug)
|
| 208 |
if _correct_diagnosis(state, scenario):
|
|
@@ -212,20 +221,46 @@ def grade_task_006(state: EpisodeState, scenario: ScenarioParams) -> float:
|
|
| 212 |
|
| 213 |
|
| 214 |
def grade_task_007(state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 215 |
-
"""Grade Task 7 — LR Scheduler Misconfigured (
|
| 216 |
score = 0.0
|
| 217 |
|
|
|
|
| 218 |
if state.gradients_inspected:
|
| 219 |
score += 0.05
|
| 220 |
if state.data_inspected:
|
| 221 |
score += 0.05
|
| 222 |
if _has_action(state, "modify_config"):
|
| 223 |
-
|
| 224 |
if state.restart_after_fix:
|
| 225 |
-
|
| 226 |
if _correct_diagnosis(state, scenario):
|
| 227 |
score += 0.40
|
| 228 |
|
| 229 |
return min(1.0, max(0.0, score))
|
| 230 |
|
| 231 |
|
|
|
|
| 183 |
|
| 184 |
Diagnosis must ALWAYS be 'code_bug' regardless of bug variant.
|
| 185 |
Hard task rewards thorough investigation before fixing.
|
| 186 |
+
Full credit requires ruling out non-code causes via weight inspection.
|
| 187 |
"""
|
| 188 |
score = 0.0
|
| 189 |
|
| 190 |
+
# Investigation credits (+0.05 each, up to +0.25 for all 5 types)
|
| 191 |
if state.code_inspected:
|
| 192 |
score += 0.05
|
|
|
|
|
|
|
| 193 |
if state.gradients_inspected:
|
| 194 |
score += 0.05
|
| 195 |
if state.model_modes_inspected:
|
| 196 |
score += 0.05
|
| 197 |
+
if state.model_weights_inspected:
|
| 198 |
+
score += 0.05
|
| 199 |
+
if state.data_inspected:
|
| 200 |
+
score += 0.05
|
| 201 |
|
| 202 |
+
# Code fix credit scaled by investigation thoroughness
|
| 203 |
if _has_action(state, "fix_code") and state.fix_action_taken:
|
| 204 |
+
if state.model_weights_inspected:
|
| 205 |
+
score += 0.15 # Thorough: ruled out weight-related causes
|
| 206 |
+
else:
|
| 207 |
+
score += 0.08 # Quick fix without full investigation
|
| 208 |
|
| 209 |
+
# Restart credit scaled by thoroughness
|
| 210 |
if state.restart_after_fix:
|
| 211 |
+
if state.model_weights_inspected:
|
| 212 |
+
score += 0.15 # Full restart credit
|
| 213 |
+
else:
|
| 214 |
+
score += 0.08 # Partial credit
|
| 215 |
|
| 216 |
# +0.45 for correct diagnosis (must be code_bug)
|
| 217 |
if _correct_diagnosis(state, scenario):
|
|
|
|
| 221 |
|
| 222 |
|
| 223 |
def grade_task_007(state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 224 |
+
"""Grade Task 7 — LR Scheduler Misconfigured (hard). Spec extension.
|
| 225 |
+
|
| 226 |
+
Requires thorough investigation: agents must inspect weights to rule out
|
| 227 |
+
weight-related issues before concluding scheduler is the root cause.
|
| 228 |
+
Penalizes wrong fixes (e.g. patch_data_loader when data is fine).
|
| 229 |
+
"""
|
| 230 |
score = 0.0
|
| 231 |
|
| 232 |
+
# Investigation credits (+0.05 each, up to +0.20 for all 4 types)
|
| 233 |
if state.gradients_inspected:
|
| 234 |
score += 0.05
|
| 235 |
if state.data_inspected:
|
| 236 |
score += 0.05
|
| 237 |
+
if state.model_weights_inspected:
|
| 238 |
+
score += 0.05
|
| 239 |
+
if state.model_modes_inspected:
|
| 240 |
+
score += 0.05
|
| 241 |
+
|
| 242 |
+
# Fix credit scaled by investigation thoroughness
|
| 243 |
if _has_action(state, "modify_config"):
|
| 244 |
+
if state.model_weights_inspected:
|
| 245 |
+
score += 0.20 # Thorough: ruled out weight issues
|
| 246 |
+
else:
|
| 247 |
+
score += 0.12 # Partial: didn't check weights
|
| 248 |
+
|
| 249 |
+
# Restart credit scaled by thoroughness
|
| 250 |
if state.restart_after_fix:
|
| 251 |
+
if state.model_weights_inspected:
|
| 252 |
+
score += 0.20 # Full restart credit
|
| 253 |
+
else:
|
| 254 |
+
score += 0.12 # Partial credit
|
| 255 |
+
|
| 256 |
+
# Diagnosis
|
| 257 |
if _correct_diagnosis(state, scenario):
|
| 258 |
score += 0.40
|
| 259 |
|
| 260 |
+
# Wrong-fix penalty: patch_data_loader when data is clean
|
| 261 |
+
if _has_action(state, "patch_data_loader"):
|
| 262 |
+
score -= 0.10
|
| 263 |
+
|
| 264 |
return min(1.0, max(0.0, score))
|
| 265 |
|
| 266 |
|
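The weights in the two graders above can be checked by hand against the README scores (0.81 for Task 6 and 0.79 for Task 7 when weight inspection is skipped). The sketch below restates those weights as standalone functions. The function names and signatures are hypothetical: the real graders take `(state, scenario)` and use `_has_action` and `_correct_diagnosis`, and the Task 6 fix check (which needs both the `fix_code` action and `fix_action_taken`) is collapsed here into one boolean.

```python
import math


def score_task_006(inspected, fixed, restarted, correct_diagnosis):
    """Task 6 weights from the diff: 5 inspection types, fix, restart, diagnosis."""
    kinds = {"code", "gradients", "model_modes", "model_weights", "data"}
    score = 0.05 * len(kinds & set(inspected))
    thorough = "model_weights" in inspected
    if fixed:
        score += 0.15 if thorough else 0.08
    if restarted:
        score += 0.15 if thorough else 0.08
    if correct_diagnosis:
        score += 0.45
    return min(1.0, max(0.0, score))


def score_task_007(inspected, modified_config, restarted, correct_diagnosis,
                   patched_data_loader=False):
    """Task 7 weights from the diff, including the wrong-fix penalty."""
    kinds = {"gradients", "data", "model_weights", "model_modes"}
    score = 0.05 * len(kinds & set(inspected))
    thorough = "model_weights" in inspected
    if modified_config:
        score += 0.20 if thorough else 0.12
    if restarted:
        score += 0.20 if thorough else 0.12
    if correct_diagnosis:
        score += 0.40
    if patched_data_loader:
        score -= 0.10  # wrong fix: data is clean in this task
    return min(1.0, max(0.0, score))


if __name__ == "__main__":
    no_weights_006 = score_task_006(
        {"code", "gradients", "model_modes", "data"}, True, True, True)
    no_weights_007 = score_task_007(
        {"gradients", "data", "model_modes"}, True, True, True)
    # Sums of decimal floats carry rounding error, so compare with a tolerance,
    # in the same spirit as the pytest.approx assertions in tests/test_graders.py.
    print(math.isclose(no_weights_006, 0.81, abs_tol=1e-9))
    print(math.isclose(no_weights_007, 0.79, abs_tol=1e-9))
```

With these weights, an agent that only inspects gradients and data, applies both `patch_data_loader` and `modify_config`, restarts, and diagnoses correctly lands at about 0.64, which is consistent with the `score < 0.75` assertion in `test_wrong_fix_penalty`.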
pyproject.toml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
[project]
|
| 2 |
name = "pytorch-training-debugger"
|
| 3 |
-
version = "1.
|
| 4 |
description = "OpenEnv RL environment for PyTorch training failure debugging"
|
| 5 |
requires-python = ">=3.12"
|
| 6 |
dependencies = [
|
|
|
|
| 1 |
[project]
|
| 2 |
name = "pytorch-training-debugger"
|
| 3 |
+
version = "1.1.0"
|
| 4 |
description = "OpenEnv RL environment for PyTorch training failure debugging"
|
| 5 |
requires-python = ">=3.12"
|
| 6 |
dependencies = [
|
server/app.py
CHANGED
|
@@ -12,7 +12,7 @@ import sys
|
|
| 12 |
from typing import Optional
|
| 13 |
|
| 14 |
from fastapi import FastAPI
|
| 15 |
-
from fastapi.responses import HTMLResponse, JSONResponse
|
| 16 |
from openenv.core.env_server.http_server import create_app
|
| 17 |
|
| 18 |
from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
|
|
@@ -77,9 +77,8 @@ _baseline_lock = asyncio.Lock()
|
|
| 77 |
|
| 78 |
|
| 79 |
@app.get("/")
|
| 80 |
-
def root():
|
| 81 |
"""Redirect root to dashboard."""
|
| 82 |
-
from fastapi.responses import RedirectResponse
|
| 83 |
return RedirectResponse(url="/dashboard")
|
| 84 |
|
| 85 |
|
|
@@ -174,7 +173,7 @@ def post_grader(session_id: Optional[str] = None) -> dict:
|
|
| 174 |
|
| 175 |
|
| 176 |
@app.post("/baseline", response_model=None)
|
| 177 |
-
async def post_baseline():
|
| 178 |
"""Trigger baseline run, return scores for all tasks.
|
| 179 |
|
| 180 |
Returns 409 if already running. Uses asyncio.Lock for thread safety.
|
|
|
|
| 12 |
from typing import Optional
|
| 13 |
|
| 14 |
from fastapi import FastAPI
|
| 15 |
+
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
| 16 |
from openenv.core.env_server.http_server import create_app
|
| 17 |
|
| 18 |
from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
@app.get("/")
|
| 80 |
+
def root() -> RedirectResponse:
|
| 81 |
"""Redirect root to dashboard."""
|
|
|
|
| 82 |
return RedirectResponse(url="/dashboard")
|
| 83 |
|
| 84 |
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
@app.post("/baseline", response_model=None)
|
| 176 |
+
async def post_baseline() -> JSONResponse | dict:
|
| 177 |
"""Trigger baseline run, return scores for all tasks.
|
| 178 |
|
| 179 |
Returns 409 if already running. Uses asyncio.Lock for thread safety.
|
tests/test_graders.py
CHANGED
|
@@ -10,6 +10,7 @@ from ml_training_debugger.graders import (
|
|
| 10 |
grade_task_001,
|
| 11 |
grade_task_003,
|
| 12 |
grade_task_005,
|
|
|
|
| 13 |
grade_task_007,
|
| 14 |
)
|
| 15 |
from ml_training_debugger.models import EpisodeState
|
|
@@ -241,25 +242,168 @@ class TestGradeEpisode:
|
|
| 241 |
assert score == 0.0
|
| 242 |
|
| 243 |
|
| 244 |
class TestGradeTask007:
|
| 245 |
-
def
|
|
|
|
| 246 |
scenario = sample_scenario("task_007", seed=42)
|
| 247 |
state = EpisodeState(
|
| 248 |
gradients_inspected=True,
|
| 249 |
data_inspected=True,
|
|
|
|
|
|
|
| 250 |
fix_action_taken=True,
|
| 251 |
restart_after_fix=True,
|
| 252 |
diagnosis_submitted=True,
|
| 253 |
actions_taken=[
|
| 254 |
"inspect_gradients",
|
| 255 |
"inspect_data_batch",
|
|
|
|
|
|
|
| 256 |
"modify_config",
|
| 257 |
"restart_run",
|
| 258 |
"mark_diagnosed:scheduler_misconfigured",
|
| 259 |
],
|
| 260 |
)
|
| 261 |
score = grade_task_007(state, scenario)
|
| 262 |
-
assert score == 1.0
|
| 263 |
|
| 264 |
def test_wrong_diagnosis(self):
|
| 265 |
scenario = sample_scenario("task_007", seed=42)
|
|
|
|
| 10 |
grade_task_001,
|
| 11 |
grade_task_003,
|
| 12 |
grade_task_005,
|
| 13 |
+
grade_task_006,
|
| 14 |
grade_task_007,
|
| 15 |
)
|
| 16 |
from ml_training_debugger.models import EpisodeState
|
|
|
|
| 242 |
assert score == 0.0
|
| 243 |
|
| 244 |
|
| 245 |
+
class TestGradeTask006:
|
| 246 |
+
@pytest.fixture
|
| 247 |
+
def scenario_006(self):
|
| 248 |
+
return sample_scenario("task_006", seed=42)
|
| 249 |
+
|
| 250 |
+
def test_perfect_score_thorough(self, scenario_006):
|
| 251 |
+
"""Thorough agent inspects ALL systems including weights — gets perfect score."""
|
| 252 |
+
state = EpisodeState(
|
| 253 |
+
code_inspected=True,
|
| 254 |
+
gradients_inspected=True,
|
| 255 |
+
model_modes_inspected=True,
|
| 256 |
+
model_weights_inspected=True,
|
| 257 |
+
data_inspected=True,
|
| 258 |
+
fix_action_taken=True,
|
| 259 |
+
restart_after_fix=True,
|
| 260 |
+
diagnosis_submitted=True,
|
| 261 |
+
actions_taken=[
|
| 262 |
+
"inspect_gradients",
|
| 263 |
+
"inspect_data_batch",
|
| 264 |
+
"inspect_model_weights",
|
| 265 |
+
"inspect_model_modes",
|
| 266 |
+
"inspect_code",
|
| 267 |
+
"fix_code",
|
| 268 |
+
"restart_run",
|
| 269 |
+
"mark_diagnosed:code_bug",
|
| 270 |
+
],
|
| 271 |
+
)
|
| 272 |
+
score = grade_task_006(state, scenario_006)
|
| 273 |
+
assert score == pytest.approx(1.0)
|
| 274 |
+
|
| 275 |
+
def test_no_weights_inspection_partial(self, scenario_006):
|
| 276 |
+
"""Agent that skips weight inspection gets reduced fix/restart credit."""
|
| 277 |
+
state = EpisodeState(
|
| 278 |
+
code_inspected=True,
|
| 279 |
+
gradients_inspected=True,
|
| 280 |
+
model_modes_inspected=True,
|
| 281 |
+
data_inspected=True,
|
| 282 |
+
fix_action_taken=True,
|
| 283 |
+
restart_after_fix=True,
|
| 284 |
+
diagnosis_submitted=True,
|
| 285 |
+
actions_taken=[
|
| 286 |
+
"inspect_gradients",
|
| 287 |
+
"inspect_data_batch",
|
| 288 |
+
"inspect_model_modes",
|
| 289 |
+
"inspect_code",
|
| 290 |
+
"fix_code",
|
| 291 |
+
"restart_run",
|
| 292 |
+
"mark_diagnosed:code_bug",
|
| 293 |
+
],
|
| 294 |
+
)
|
| 295 |
+
score = grade_task_006(state, scenario_006)
|
| 296 |
+
# 0.05*4 + 0.08 + 0.08 + 0.45 = 0.81
|
| 297 |
+
assert score == pytest.approx(0.81)
|
| 298 |
+
assert score < 1.0 # Must not be perfect without weights
|
| 299 |
+
|
| 300 |
+
def test_minimal_investigation(self, scenario_006):
|
| 301 |
+
"""Agent that only inspects code, fixes, and diagnoses."""
|
| 302 |
+
state = EpisodeState(
|
| 303 |
+
code_inspected=True,
|
| 304 |
+
fix_action_taken=True,
|
| 305 |
+
restart_after_fix=True,
|
| 306 |
+
diagnosis_submitted=True,
|
| 307 |
+
actions_taken=[
|
| 308 |
+
"inspect_code",
|
| 309 |
+
"fix_code",
|
| 310 |
+
"restart_run",
|
| 311 |
+
"mark_diagnosed:code_bug",
|
| 312 |
+
],
|
| 313 |
+
)
|
| 314 |
+
score = grade_task_006(state, scenario_006)
|
| 315 |
+
# 0.05 + 0.08 + 0.08 + 0.45 = 0.66
|
| 316 |
+
assert score == pytest.approx(0.66)
|
| 317 |
+
|
| 318 |
+
def test_wrong_diagnosis(self, scenario_006):
|
| 319 |
+
"""Submitting batchnorm_eval_mode on a code_bug task fails."""
|
| 320 |
+
state = EpisodeState(
|
| 321 |
+
code_inspected=True,
|
| 322 |
+
diagnosis_submitted=True,
|
| 323 |
+
actions_taken=[
|
| 324 |
+
"inspect_code",
|
| 325 |
+
"mark_diagnosed:batchnorm_eval_mode",
|
| 326 |
+
],
|
| 327 |
+
)
|
| 328 |
+
score = grade_task_006(state, scenario_006)
|
| 329 |
+
assert score < 0.2 # Only gets code_inspected bonus
|
| 330 |
+
|
| 331 |
+
def test_score_in_range(self, scenario_006):
|
| 332 |
+
state = EpisodeState()
|
| 333 |
+
score = grade_task_006(state, scenario_006)
|
| 334 |
+
assert 0.0 <= score <= 1.0
|
| 335 |
+
|
| 336 |
+
|
| 337 |
class TestGradeTask007:
|
| 338 |
+
def test_perfect_score_thorough(self):
|
| 339 |
+
"""Thorough agent inspects weights — gets perfect score."""
|
| 340 |
scenario = sample_scenario("task_007", seed=42)
|
| 341 |
state = EpisodeState(
|
| 342 |
gradients_inspected=True,
|
| 343 |
data_inspected=True,
|
| 344 |
+
model_weights_inspected=True,
|
| 345 |
+
model_modes_inspected=True,
|
| 346 |
fix_action_taken=True,
|
| 347 |
restart_after_fix=True,
|
| 348 |
diagnosis_submitted=True,
|
| 349 |
actions_taken=[
|
| 350 |
"inspect_gradients",
|
| 351 |
"inspect_data_batch",
|
| 352 |
+
"inspect_model_weights",
|
| 353 |
+
"inspect_model_modes",
|
| 354 |
"modify_config",
|
| 355 |
"restart_run",
|
| 356 |
"mark_diagnosed:scheduler_misconfigured",
|
| 357 |
],
|
| 358 |
)
|
| 359 |
score = grade_task_007(state, scenario)
|
| 360 |
+
assert score == pytest.approx(1.0)
|
| 361 |
+
|
| 362 |
+
def test_no_weights_partial(self):
|
| 363 |
+
"""Agent without weight inspection gets reduced fix/restart credit."""
|
| 364 |
+
scenario = sample_scenario("task_007", seed=42)
|
| 365 |
+
state = EpisodeState(
|
| 366 |
+
gradients_inspected=True,
|
| 367 |
+
data_inspected=True,
|
| 368 |
+
model_modes_inspected=True,
|
| 369 |
+
fix_action_taken=True,
|
| 370 |
+
restart_after_fix=True,
|
| 371 |
+
diagnosis_submitted=True,
|
| 372 |
+
actions_taken=[
|
| 373 |
+
"inspect_gradients",
|
| 374 |
+
"inspect_data_batch",
|
| 375 |
+
"inspect_model_modes",
|
| 376 |
+
"modify_config",
|
| 377 |
+
"restart_run",
|
| 378 |
+
"mark_diagnosed:scheduler_misconfigured",
|
| 379 |
+
],
|
| 380 |
+
)
|
| 381 |
+
score = grade_task_007(state, scenario)
|
| 382 |
+
# 0.05*3 + 0.12 + 0.12 + 0.40 = 0.79
|
| 383 |
+
assert score == pytest.approx(0.79)
|
| 384 |
+
assert score < 1.0
|
| 385 |
+
|
| 386 |
+
def test_wrong_fix_penalty(self):
|
| 387 |
+
"""Agent that patches data loader (wrong fix) gets penalized."""
|
| 388 |
+
scenario = sample_scenario("task_007", seed=42)
|
| 389 |
+
state = EpisodeState(
|
| 390 |
+
gradients_inspected=True,
|
| 391 |
+
data_inspected=True,
|
| 392 |
+
fix_action_taken=True,
|
| 393 |
+
restart_after_fix=True,
|
| 394 |
+
diagnosis_submitted=True,
|
| 395 |
+
actions_taken=[
|
| 396 |
+
"inspect_gradients",
|
| 397 |
+
"inspect_data_batch",
|
| 398 |
+
"patch_data_loader",
|
| 399 |
+
"modify_config",
|
| 400 |
+
"restart_run",
|
| 401 |
+
"mark_diagnosed:scheduler_misconfigured",
|
| 402 |
+
],
|
| 403 |
+
)
|
| 404 |
+
score = grade_task_007(state, scenario)
|
| 405 |
+
# Normal partial score minus 0.10 penalty
|
| 406 |
+
assert score < 0.75
|
| 407 |
|
| 408 |
def test_wrong_diagnosis(self):
|
| 409 |
scenario = sample_scenario("task_007", seed=42)
|