omkarrr88 committed on
Commit ·
45eee48
1
Parent(s): 43647d3
docker size reduced
Browse files
- .coverage +0 -0
- .dockerignore +3 -1
- Dockerfile +78 -15
- tests/test_graders.py +59 -0
- tests/test_new_endpoints.py +80 -0
- tests/test_real_training.py +135 -0
- tests/test_simulation_fallback.py +149 -0
.coverage
CHANGED
|
Binary files a/.coverage and b/.coverage differ
|
|
|
.dockerignore
CHANGED
|
@@ -3,7 +3,8 @@ __pycache__/
|
|
| 3 |
.git/
|
| 4 |
.pytest_cache/
|
| 5 |
tests/
|
| 6 |
-
validation/
|
|
|
|
| 7 |
*.md
|
| 8 |
!README.md
|
| 9 |
.claude/
|
|
@@ -11,3 +12,4 @@ run*.json
|
|
| 11 |
htmlcov/
|
| 12 |
.mypy_cache/
|
| 13 |
.ruff_cache/
|
|
|
|
|
|
| 3 |
.git/
|
| 4 |
.pytest_cache/
|
| 5 |
tests/
|
| 6 |
+
validation/*.py
|
| 7 |
+
validation/requirements.txt
|
| 8 |
*.md
|
| 9 |
!README.md
|
| 10 |
.claude/
|
|
|
|
| 12 |
htmlcov/
|
| 13 |
.mypy_cache/
|
| 14 |
.ruff_cache/
|
| 15 |
+
.env
|
Dockerfile
CHANGED
|
@@ -1,22 +1,29 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
WORKDIR /app
|
| 4 |
|
| 5 |
-
|
| 6 |
-
RUN apt-get update && apt-get install -y --no-install-recommends curl && \
|
| 7 |
rm -rf /var/lib/apt/lists/*
|
| 8 |
|
| 9 |
-
|
| 10 |
-
# Docker layers are immutable — cleanup in a separate RUN saves nothing.
|
| 11 |
-
# PyTorch CPU-only (~280MB wheel, ~460MB installed) is the minimum for real
|
| 12 |
-
# torch.nn.Module, torch.autograd, and state_dict() support.
|
| 13 |
COPY requirements.txt .
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
rm -rf /usr/local/lib/python3.12/site-packages/torch/test \
|
| 18 |
/usr/local/lib/python3.12/site-packages/torch/include \
|
| 19 |
/usr/local/lib/python3.12/site-packages/torch/share \
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
/usr/local/lib/python3.12/site-packages/torch/utils/benchmark \
|
| 21 |
/usr/local/lib/python3.12/site-packages/torch/utils/bottleneck \
|
| 22 |
/usr/local/lib/python3.12/site-packages/torch/utils/tensorboard \
|
|
@@ -24,18 +31,74 @@ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/wh
|
|
| 24 |
/usr/local/lib/python3.12/site-packages/torch/lib/libtorchbind_test.so \
|
| 25 |
/usr/local/lib/python3.12/site-packages/torch/lib/libjitbackend_test.so \
|
| 26 |
/usr/local/lib/python3.12/site-packages/torch/lib/libbackend_with_compiler.so \
|
| 27 |
-
/usr/local/lib/python3.12/site-packages/
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
find /usr/local/lib/python3.12/site-packages -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null; \
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
true
|
| 31 |
|
| 32 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
COPY ml_training_debugger/ ml_training_debugger/
|
| 34 |
COPY server/ server/
|
| 35 |
COPY openenv.yaml .
|
| 36 |
COPY baseline_heuristic.py .
|
| 37 |
COPY baseline_inference.py .
|
| 38 |
COPY README.md .
|
|
|
|
| 39 |
|
| 40 |
EXPOSE 7860
|
| 41 |
|
|
|
|
| 1 |
+
# ---- Stage 1: Builder — install + strip aggressively ----
|
| 2 |
+
FROM python:3.12-slim AS builder
|
|
|
|
| 3 |
|
| 4 |
+
RUN apt-get update && apt-get install -y --no-install-recommends binutils && \
|
|
|
|
| 5 |
rm -rf /var/lib/apt/lists/*
|
| 6 |
|
| 7 |
+
WORKDIR /build
|
|
|
|
|
|
|
|
|
|
| 8 |
COPY requirements.txt .
|
| 9 |
+
|
| 10 |
+
RUN pip install --no-cache-dir --no-compile \
|
| 11 |
+
torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu && \
|
| 12 |
+
pip install --no-cache-dir --no-compile -r requirements.txt && \
|
| 13 |
+
#
|
| 14 |
+
# === STRIP DEBUG SYMBOLS FROM ALL .so FILES (saves ~100-200MB) ===
|
| 15 |
+
find /usr/local/lib/python3.12/site-packages -name "*.so" -exec strip --strip-unneeded {} + 2>/dev/null; \
|
| 16 |
+
find /usr/local/lib/python3.12/site-packages -name "*.so.*" -exec strip --strip-unneeded {} + 2>/dev/null; \
|
| 17 |
+
#
|
| 18 |
+
# === TORCH CLEANUP ===
|
| 19 |
rm -rf /usr/local/lib/python3.12/site-packages/torch/test \
|
| 20 |
/usr/local/lib/python3.12/site-packages/torch/include \
|
| 21 |
/usr/local/lib/python3.12/site-packages/torch/share \
|
| 22 |
+
/usr/local/lib/python3.12/site-packages/torch/bin/FileStore* \
|
| 23 |
+
/usr/local/lib/python3.12/site-packages/torch/bin/HashStore* \
|
| 24 |
+
/usr/local/lib/python3.12/site-packages/torch/bin/TCPStore* \
|
| 25 |
+
/usr/local/lib/python3.12/site-packages/torch/bin/protoc* \
|
| 26 |
+
/usr/local/lib/python3.12/site-packages/torch/bin/test_* \
|
| 27 |
/usr/local/lib/python3.12/site-packages/torch/utils/benchmark \
|
| 28 |
/usr/local/lib/python3.12/site-packages/torch/utils/bottleneck \
|
| 29 |
/usr/local/lib/python3.12/site-packages/torch/utils/tensorboard \
|
|
|
|
| 31 |
/usr/local/lib/python3.12/site-packages/torch/lib/libtorchbind_test.so \
|
| 32 |
/usr/local/lib/python3.12/site-packages/torch/lib/libjitbackend_test.so \
|
| 33 |
/usr/local/lib/python3.12/site-packages/torch/lib/libbackend_with_compiler.so \
|
| 34 |
+
/usr/local/lib/python3.12/site-packages/torch/lib/libaoti_custom_ops.so \
|
| 35 |
+
/usr/local/lib/python3.12/site-packages/torch/lib/libshm_windows \
|
| 36 |
+
/usr/local/lib/python3.12/site-packages/caffe2 \
|
| 37 |
+
#
|
| 38 |
+
# === BLOATED TRANSITIVE DEPS ===
|
| 39 |
+
/usr/local/lib/python3.12/site-packages/gradio \
|
| 40 |
+
/usr/local/lib/python3.12/site-packages/gradio_client \
|
| 41 |
+
/usr/local/lib/python3.12/site-packages/hf_gradio \
|
| 42 |
+
/usr/local/lib/python3.12/site-packages/pandas \
|
| 43 |
+
/usr/local/lib/python3.12/site-packages/PIL \
|
| 44 |
+
/usr/local/lib/python3.12/site-packages/Pillow* \
|
| 45 |
+
/usr/local/lib/python3.12/site-packages/pillow* \
|
| 46 |
+
/usr/local/lib/python3.12/site-packages/networkx \
|
| 47 |
+
/usr/local/lib/python3.12/site-packages/scipy \
|
| 48 |
+
/usr/local/lib/python3.12/site-packages/matplotlib \
|
| 49 |
+
/usr/local/lib/python3.12/site-packages/hf_xet \
|
| 50 |
+
/usr/local/lib/python3.12/site-packages/ffmpy \
|
| 51 |
+
/usr/local/lib/python3.12/site-packages/pydub \
|
| 52 |
+
/usr/local/lib/python3.12/site-packages/groovy \
|
| 53 |
+
/usr/local/lib/python3.12/site-packages/tomlkit \
|
| 54 |
+
/usr/local/lib/python3.12/site-packages/semantic_version* \
|
| 55 |
+
/usr/local/lib/python3.12/site-packages/safehttpx* \
|
| 56 |
+
/usr/local/lib/python3.12/site-packages/brotli* \
|
| 57 |
+
/usr/local/lib/python3.12/site-packages/Brotli* \
|
| 58 |
+
/usr/local/lib/python3.12/site-packages/pip \
|
| 59 |
+
/usr/local/lib/python3.12/site-packages/setuptools \
|
| 60 |
+
/usr/local/lib/python3.12/site-packages/docutils \
|
| 61 |
+
/usr/local/lib/python3.12/site-packages/cryptography \
|
| 62 |
+
/usr/local/lib/python3.12/site-packages/cryptography* \
|
| 63 |
+
/usr/local/lib/python3.12/site-packages/pytz 2>/dev/null; \
|
| 64 |
+
#
|
| 65 |
+
# === FILE-LEVEL CLEANUP ===
|
| 66 |
+
find /usr/local/lib/python3.12/site-packages -name "*.pyi" -delete 2>/dev/null; \
|
| 67 |
+
find /usr/local/lib/python3.12/site-packages -name "*.pyc" -delete 2>/dev/null; \
|
| 68 |
find /usr/local/lib/python3.12/site-packages -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null; \
|
| 69 |
+
find /usr/local/lib/python3.12/site-packages -name "*.egg-info" -type d -exec rm -rf {} + 2>/dev/null; \
|
| 70 |
+
find /usr/local/lib/python3.12/site-packages -name "tests" -type d -exec rm -rf {} + 2>/dev/null; \
|
| 71 |
+
find /usr/local/lib/python3.12/site-packages -name "test" -type d -exec rm -rf {} + 2>/dev/null; \
|
| 72 |
+
# Remove stale dist-info for packages we already deleted
|
| 73 |
+
rm -rf /usr/local/lib/python3.12/site-packages/gradio*.dist-info \
|
| 74 |
+
/usr/local/lib/python3.12/site-packages/pandas*.dist-info \
|
| 75 |
+
/usr/local/lib/python3.12/site-packages/Pillow*.dist-info \
|
| 76 |
+
/usr/local/lib/python3.12/site-packages/hf_xet*.dist-info \
|
| 77 |
+
/usr/local/lib/python3.12/site-packages/Brotli*.dist-info \
|
| 78 |
+
/usr/local/lib/python3.12/site-packages/networkx*.dist-info \
|
| 79 |
+
/usr/local/lib/python3.12/site-packages/pip \
|
| 80 |
+
/usr/local/lib/python3.12/site-packages/pip*.dist-info 2>/dev/null; \
|
| 81 |
true
|
| 82 |
|
| 83 |
+
# ---- Stage 2: Runtime — minimal clean image ----
|
| 84 |
+
FROM python:3.12-slim
|
| 85 |
+
|
| 86 |
+
WORKDIR /app
|
| 87 |
+
|
| 88 |
+
RUN apt-get update && apt-get install -y --no-install-recommends curl && \
|
| 89 |
+
rm -rf /var/lib/apt/lists/*
|
| 90 |
+
|
| 91 |
+
# Copy only what's needed from builder
|
| 92 |
+
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
| 93 |
+
COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/uvicorn
|
| 94 |
+
|
| 95 |
COPY ml_training_debugger/ ml_training_debugger/
|
| 96 |
COPY server/ server/
|
| 97 |
COPY openenv.yaml .
|
| 98 |
COPY baseline_heuristic.py .
|
| 99 |
COPY baseline_inference.py .
|
| 100 |
COPY README.md .
|
| 101 |
+
COPY validation/reports/ validation/reports/
|
| 102 |
|
| 103 |
EXPOSE 7860
|
| 104 |
|
tests/test_graders.py
CHANGED
|
@@ -5,10 +5,12 @@ from __future__ import annotations
|
|
| 5 |
import pytest
|
| 6 |
|
| 7 |
from ml_training_debugger.graders import (
|
|
|
|
| 8 |
grade_episode,
|
| 9 |
grade_task_001,
|
| 10 |
grade_task_003,
|
| 11 |
grade_task_005,
|
|
|
|
| 12 |
)
|
| 13 |
from ml_training_debugger.models import EpisodeState
|
| 14 |
from ml_training_debugger.scenarios import sample_scenario
|
|
@@ -166,3 +168,60 @@ class TestGradeEpisode:
|
|
| 166 |
state = EpisodeState()
|
| 167 |
score = grade_episode("task_999", state, scenario_001)
|
| 168 |
assert score == 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import pytest
|
| 6 |
|
| 7 |
from ml_training_debugger.graders import (
|
| 8 |
+
_submitted_diagnosis,
|
| 9 |
grade_episode,
|
| 10 |
grade_task_001,
|
| 11 |
grade_task_003,
|
| 12 |
grade_task_005,
|
| 13 |
+
grade_task_007,
|
| 14 |
)
|
| 15 |
from ml_training_debugger.models import EpisodeState
|
| 16 |
from ml_training_debugger.scenarios import sample_scenario
|
|
|
|
| 168 |
state = EpisodeState()
|
| 169 |
score = grade_episode("task_999", state, scenario_001)
|
| 170 |
assert score == 0.0
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class TestGradeTask007:
|
| 174 |
+
def test_perfect_score(self):
|
| 175 |
+
scenario = sample_scenario("task_007", seed=42)
|
| 176 |
+
state = EpisodeState(
|
| 177 |
+
gradients_inspected=True,
|
| 178 |
+
data_inspected=True,
|
| 179 |
+
fix_action_taken=True,
|
| 180 |
+
restart_after_fix=True,
|
| 181 |
+
diagnosis_submitted=True,
|
| 182 |
+
actions_taken=[
|
| 183 |
+
"inspect_gradients",
|
| 184 |
+
"inspect_data_batch",
|
| 185 |
+
"modify_config",
|
| 186 |
+
"restart_run",
|
| 187 |
+
"mark_diagnosed:scheduler_misconfigured",
|
| 188 |
+
],
|
| 189 |
+
)
|
| 190 |
+
score = grade_task_007(state, scenario)
|
| 191 |
+
assert score == 1.0
|
| 192 |
+
|
| 193 |
+
def test_wrong_diagnosis(self):
|
| 194 |
+
scenario = sample_scenario("task_007", seed=42)
|
| 195 |
+
state = EpisodeState(
|
| 196 |
+
diagnosis_submitted=True,
|
| 197 |
+
actions_taken=["mark_diagnosed:overfitting"],
|
| 198 |
+
)
|
| 199 |
+
score = grade_task_007(state, scenario)
|
| 200 |
+
assert score < 0.5
|
| 201 |
+
|
| 202 |
+
def test_score_in_range(self):
|
| 203 |
+
scenario = sample_scenario("task_007", seed=42)
|
| 204 |
+
state = EpisodeState()
|
| 205 |
+
score = grade_task_007(state, scenario)
|
| 206 |
+
assert 0.0 <= score <= 1.0
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class TestSubmittedDiagnosis:
|
| 210 |
+
def test_finds_diagnosis(self):
|
| 211 |
+
state = EpisodeState(
|
| 212 |
+
actions_taken=["inspect_gradients", "mark_diagnosed:lr_too_high"],
|
| 213 |
+
)
|
| 214 |
+
assert _submitted_diagnosis(state) == "lr_too_high"
|
| 215 |
+
|
| 216 |
+
def test_no_diagnosis(self):
|
| 217 |
+
state = EpisodeState(actions_taken=["inspect_gradients"])
|
| 218 |
+
assert _submitted_diagnosis(state) is None
|
| 219 |
+
|
| 220 |
+
def test_latest_diagnosis(self):
|
| 221 |
+
state = EpisodeState(
|
| 222 |
+
actions_taken=[
|
| 223 |
+
"mark_diagnosed:overfitting",
|
| 224 |
+
"mark_diagnosed:lr_too_high",
|
| 225 |
+
],
|
| 226 |
+
)
|
| 227 |
+
assert _submitted_diagnosis(state) == "lr_too_high"
|
tests/test_new_endpoints.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for new endpoints: curriculum, leaderboard, replay, validation-report."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from fastapi.testclient import TestClient
|
| 7 |
+
|
| 8 |
+
from server.app import app
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.fixture
|
| 12 |
+
def client():
|
| 13 |
+
return TestClient(app)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestCurriculumEndpoint:
|
| 17 |
+
def test_returns_curriculum(self, client) -> None:
|
| 18 |
+
resp = client.get("/curriculum")
|
| 19 |
+
assert resp.status_code == 200
|
| 20 |
+
data = resp.json()
|
| 21 |
+
assert "curriculum" in data
|
| 22 |
+
assert "total_episodes" in data
|
| 23 |
+
assert data["total_episodes"] > 0
|
| 24 |
+
|
| 25 |
+
def test_curriculum_has_difficulty_levels(self, client) -> None:
|
| 26 |
+
resp = client.get("/curriculum")
|
| 27 |
+
curriculum = resp.json()["curriculum"]
|
| 28 |
+
levels = {entry["difficulty_level"] for entry in curriculum}
|
| 29 |
+
assert 1 in levels
|
| 30 |
+
assert 3 in levels
|
| 31 |
+
assert 5 in levels
|
| 32 |
+
|
| 33 |
+
def test_curriculum_covers_all_tasks(self, client) -> None:
|
| 34 |
+
resp = client.get("/curriculum")
|
| 35 |
+
curriculum = resp.json()["curriculum"]
|
| 36 |
+
task_ids = {entry["task_id"] for entry in curriculum}
|
| 37 |
+
assert "task_001" in task_ids
|
| 38 |
+
assert "task_007" in task_ids
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TestLeaderboardEndpoint:
|
| 42 |
+
def test_returns_leaderboard(self, client) -> None:
|
| 43 |
+
resp = client.get("/leaderboard")
|
| 44 |
+
assert resp.status_code == 200
|
| 45 |
+
data = resp.json()
|
| 46 |
+
assert "entries" in data
|
| 47 |
+
assert "total" in data
|
| 48 |
+
|
| 49 |
+
def test_leaderboard_after_baseline(self, client) -> None:
|
| 50 |
+
# Run baseline to populate scores
|
| 51 |
+
client.post("/baseline")
|
| 52 |
+
resp = client.get("/leaderboard")
|
| 53 |
+
data = resp.json()
|
| 54 |
+
assert data["total"] > 0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TestReplayEndpoint:
|
| 58 |
+
def test_missing_episode(self, client) -> None:
|
| 59 |
+
resp = client.get("/replay/nonexistent_episode_123")
|
| 60 |
+
assert resp.status_code == 200
|
| 61 |
+
data = resp.json()
|
| 62 |
+
assert "error" in data
|
| 63 |
+
|
| 64 |
+
def test_replay_after_baseline(self, client) -> None:
|
| 65 |
+
# Run baseline to create episodes
|
| 66 |
+
client.post("/baseline")
|
| 67 |
+
resp = client.get("/replay/baseline_task_001")
|
| 68 |
+
data = resp.json()
|
| 69 |
+
# Should have episode data or error
|
| 70 |
+
assert "episode_id" in data or "error" in data
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class TestValidationReportEndpoint:
|
| 74 |
+
def test_returns_real_report(self, client) -> None:
|
| 75 |
+
resp = client.get("/validation-report")
|
| 76 |
+
assert resp.status_code == 200
|
| 77 |
+
data = resp.json()
|
| 78 |
+
assert "results" in data
|
| 79 |
+
assert "summary" in data
|
| 80 |
+
assert data["summary"]["passed"] > 0
|
tests/test_real_training.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for real mini-training in pytorch_engine.py."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ml_training_debugger.pytorch_engine import (
|
| 8 |
+
SimpleCNN,
|
| 9 |
+
SimpleMLP,
|
| 10 |
+
_TRAINING_CACHE,
|
| 11 |
+
run_real_training,
|
| 12 |
+
)
|
| 13 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestRunRealTraining:
|
| 17 |
+
def test_returns_20_epoch_curves(self) -> None:
|
| 18 |
+
s = sample_scenario("task_001", seed=42)
|
| 19 |
+
curves = run_real_training(s)
|
| 20 |
+
assert len(curves["loss_history"]) == 20
|
| 21 |
+
assert len(curves["val_loss_history"]) == 20
|
| 22 |
+
assert len(curves["val_acc_history"]) == 20
|
| 23 |
+
|
| 24 |
+
def test_all_values_are_floats(self) -> None:
|
| 25 |
+
s = sample_scenario("task_003", seed=42)
|
| 26 |
+
curves = run_real_training(s)
|
| 27 |
+
for key in ["loss_history", "val_loss_history", "val_acc_history"]:
|
| 28 |
+
for v in curves[key]:
|
| 29 |
+
assert isinstance(v, float), f"{key} has non-float: {type(v)}"
|
| 30 |
+
|
| 31 |
+
def test_caching_works(self) -> None:
|
| 32 |
+
_TRAINING_CACHE.clear()
|
| 33 |
+
s = sample_scenario("task_001", seed=42)
|
| 34 |
+
c1 = run_real_training(s)
|
| 35 |
+
c2 = run_real_training(s)
|
| 36 |
+
assert c1 is c2 # Same object reference = cached
|
| 37 |
+
|
| 38 |
+
def test_reproducible_across_calls(self) -> None:
|
| 39 |
+
_TRAINING_CACHE.clear()
|
| 40 |
+
s = sample_scenario("task_002", seed=42)
|
| 41 |
+
c1 = run_real_training(s)
|
| 42 |
+
_TRAINING_CACHE.clear()
|
| 43 |
+
c2 = run_real_training(s)
|
| 44 |
+
assert c1["loss_history"] == c2["loss_history"]
|
| 45 |
+
assert c1["val_acc_history"] == c2["val_acc_history"]
|
| 46 |
+
|
| 47 |
+
def test_different_seeds_different_curves(self) -> None:
|
| 48 |
+
s1 = sample_scenario("task_001", seed=42)
|
| 49 |
+
s2 = sample_scenario("task_001", seed=99)
|
| 50 |
+
c1 = run_real_training(s1)
|
| 51 |
+
c2 = run_real_training(s2)
|
| 52 |
+
assert c1["loss_history"] != c2["loss_history"]
|
| 53 |
+
|
| 54 |
+
def test_task_001_high_lr_instability(self) -> None:
|
| 55 |
+
s = sample_scenario("task_001", seed=42)
|
| 56 |
+
curves = run_real_training(s)
|
| 57 |
+
max_loss = max(v for v in curves["loss_history"] if v != float("inf"))
|
| 58 |
+
assert max_loss > 3.0 # High LR causes loss spikes
|
| 59 |
+
|
| 60 |
+
def test_task_002_vanishing_slow_learning(self) -> None:
|
| 61 |
+
s = sample_scenario("task_002", seed=42)
|
| 62 |
+
curves = run_real_training(s)
|
| 63 |
+
assert len(curves["loss_history"]) == 20
|
| 64 |
+
|
| 65 |
+
def test_task_003_data_leakage(self) -> None:
|
| 66 |
+
s = sample_scenario("task_003", seed=42)
|
| 67 |
+
curves = run_real_training(s)
|
| 68 |
+
# With leakage, val accuracy may be elevated
|
| 69 |
+
assert len(curves["val_acc_history"]) == 20
|
| 70 |
+
|
| 71 |
+
def test_task_004_overfitting(self) -> None:
|
| 72 |
+
s = sample_scenario("task_004", seed=42)
|
| 73 |
+
curves = run_real_training(s)
|
| 74 |
+
assert len(curves["loss_history"]) == 20
|
| 75 |
+
|
| 76 |
+
def test_task_005_batchnorm_eval(self) -> None:
|
| 77 |
+
s = sample_scenario("task_005", seed=42)
|
| 78 |
+
curves = run_real_training(s)
|
| 79 |
+
assert len(curves["loss_history"]) == 20
|
| 80 |
+
|
| 81 |
+
def test_task_006_code_bug(self) -> None:
|
| 82 |
+
s = sample_scenario("task_006", seed=42)
|
| 83 |
+
curves = run_real_training(s)
|
| 84 |
+
assert len(curves["loss_history"]) == 20
|
| 85 |
+
|
| 86 |
+
def test_task_007_scheduler(self) -> None:
|
| 87 |
+
s = sample_scenario("task_007", seed=42)
|
| 88 |
+
curves = run_real_training(s)
|
| 89 |
+
assert len(curves["loss_history"]) == 20
|
| 90 |
+
|
| 91 |
+
def test_mlp_architecture(self) -> None:
|
| 92 |
+
"""Find a scenario that uses MLP and verify training works."""
|
| 93 |
+
for seed in range(1, 20):
|
| 94 |
+
s = sample_scenario("task_001", seed=seed)
|
| 95 |
+
if s.model_type == "mlp":
|
| 96 |
+
curves = run_real_training(s)
|
| 97 |
+
assert len(curves["loss_history"]) == 20
|
| 98 |
+
return
|
| 99 |
+
# If no MLP found in 20 seeds, test directly
|
| 100 |
+
from ml_training_debugger.scenarios import ScenarioParams
|
| 101 |
+
from ml_training_debugger.models import RootCauseDiagnosis
|
| 102 |
+
s = ScenarioParams(
|
| 103 |
+
task_id="task_001",
|
| 104 |
+
root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
|
| 105 |
+
seed=999,
|
| 106 |
+
learning_rate=0.1,
|
| 107 |
+
model_type="mlp",
|
| 108 |
+
)
|
| 109 |
+
curves = run_real_training(s)
|
| 110 |
+
assert len(curves["loss_history"]) == 20
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class TestSimpleMLP:
|
| 114 |
+
def test_is_nn_module(self) -> None:
|
| 115 |
+
model = SimpleMLP()
|
| 116 |
+
assert isinstance(model, torch.nn.Module)
|
| 117 |
+
|
| 118 |
+
def test_param_count(self) -> None:
|
| 119 |
+
model = SimpleMLP()
|
| 120 |
+
count = sum(p.numel() for p in model.parameters())
|
| 121 |
+
assert 10_000 < count < 500_000
|
| 122 |
+
|
| 123 |
+
def test_forward_pass(self) -> None:
|
| 124 |
+
model = SimpleMLP()
|
| 125 |
+
x = torch.randn(4, 3, 32, 32)
|
| 126 |
+
out = model(x)
|
| 127 |
+
assert out.shape == (4, 10)
|
| 128 |
+
|
| 129 |
+
def test_has_batchnorm(self) -> None:
|
| 130 |
+
model = SimpleMLP()
|
| 131 |
+
has_bn = any(
|
| 132 |
+
isinstance(m, torch.nn.BatchNorm1d)
|
| 133 |
+
for m in model.modules()
|
| 134 |
+
)
|
| 135 |
+
assert has_bn
|
tests/test_simulation_fallback.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for parametric fallback in simulation.py.
|
| 2 |
+
|
| 3 |
+
These test the fallback paths that run when real training is unavailable.
|
| 4 |
+
We force fallback by monkeypatching _get_real_curves to return None.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from unittest.mock import patch
|
| 10 |
+
|
| 11 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 12 |
+
from ml_training_debugger.simulation import (
|
| 13 |
+
gen_loss_history,
|
| 14 |
+
gen_val_accuracy_history,
|
| 15 |
+
gen_val_loss_history,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _force_fallback(*args, **kwargs):
|
| 20 |
+
return None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TestParametricFallbackLoss:
|
| 24 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 25 |
+
def test_task_001_fallback(self) -> None:
|
| 26 |
+
s = sample_scenario("task_001", seed=42)
|
| 27 |
+
hist = gen_loss_history(s)
|
| 28 |
+
assert len(hist) == 20
|
| 29 |
+
|
| 30 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 31 |
+
def test_task_002_fallback(self) -> None:
|
| 32 |
+
s = sample_scenario("task_002", seed=42)
|
| 33 |
+
hist = gen_loss_history(s)
|
| 34 |
+
assert len(hist) == 20
|
| 35 |
+
|
| 36 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 37 |
+
def test_task_003_fallback(self) -> None:
|
| 38 |
+
s = sample_scenario("task_003", seed=42)
|
| 39 |
+
hist = gen_loss_history(s)
|
| 40 |
+
assert len(hist) == 20
|
| 41 |
+
|
| 42 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 43 |
+
def test_task_004_fallback(self) -> None:
|
| 44 |
+
s = sample_scenario("task_004", seed=42)
|
| 45 |
+
hist = gen_loss_history(s)
|
| 46 |
+
assert len(hist) == 20
|
| 47 |
+
|
| 48 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 49 |
+
def test_task_005_fallback(self) -> None:
|
| 50 |
+
s = sample_scenario("task_005", seed=42)
|
| 51 |
+
hist = gen_loss_history(s)
|
| 52 |
+
assert len(hist) == 20
|
| 53 |
+
|
| 54 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 55 |
+
def test_task_006_fallback(self) -> None:
|
| 56 |
+
s = sample_scenario("task_006", seed=42)
|
| 57 |
+
hist = gen_loss_history(s)
|
| 58 |
+
assert len(hist) == 20
|
| 59 |
+
|
| 60 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 61 |
+
def test_task_007_fallback(self) -> None:
|
| 62 |
+
s = sample_scenario("task_007", seed=42)
|
| 63 |
+
hist = gen_loss_history(s)
|
| 64 |
+
assert len(hist) == 20
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TestParametricFallbackValAcc:
|
| 68 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 69 |
+
def test_task_001_fallback(self) -> None:
|
| 70 |
+
s = sample_scenario("task_001", seed=42)
|
| 71 |
+
hist = gen_val_accuracy_history(s)
|
| 72 |
+
assert len(hist) == 20
|
| 73 |
+
|
| 74 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 75 |
+
def test_task_003_fallback(self) -> None:
|
| 76 |
+
s = sample_scenario("task_003", seed=42)
|
| 77 |
+
hist = gen_val_accuracy_history(s)
|
| 78 |
+
assert len(hist) == 20
|
| 79 |
+
|
| 80 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 81 |
+
def test_task_004_fallback(self) -> None:
|
| 82 |
+
s = sample_scenario("task_004", seed=42)
|
| 83 |
+
hist = gen_val_accuracy_history(s)
|
| 84 |
+
assert len(hist) == 20
|
| 85 |
+
|
| 86 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 87 |
+
def test_task_005_fallback(self) -> None:
|
| 88 |
+
s = sample_scenario("task_005", seed=42)
|
| 89 |
+
hist = gen_val_accuracy_history(s)
|
| 90 |
+
assert len(hist) == 20
|
| 91 |
+
|
| 92 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 93 |
+
def test_task_006_fallback(self) -> None:
|
| 94 |
+
s = sample_scenario("task_006", seed=42)
|
| 95 |
+
hist = gen_val_accuracy_history(s)
|
| 96 |
+
assert len(hist) == 20
|
| 97 |
+
|
| 98 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 99 |
+
def test_task_007_fallback(self) -> None:
|
| 100 |
+
s = sample_scenario("task_007", seed=42)
|
| 101 |
+
hist = gen_val_accuracy_history(s)
|
| 102 |
+
assert len(hist) == 20
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class TestParametricFallbackValLoss:
|
| 106 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 107 |
+
def test_task_001_fallback(self) -> None:
|
| 108 |
+
s = sample_scenario("task_001", seed=42)
|
| 109 |
+
hist = gen_val_loss_history(s)
|
| 110 |
+
assert len(hist) == 20
|
| 111 |
+
|
| 112 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 113 |
+
def test_task_004_fallback(self) -> None:
|
| 114 |
+
s = sample_scenario("task_004", seed=42)
|
| 115 |
+
hist = gen_val_loss_history(s)
|
| 116 |
+
assert len(hist) == 20
|
| 117 |
+
|
| 118 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 119 |
+
def test_task_005_fallback(self) -> None:
|
| 120 |
+
s = sample_scenario("task_005", seed=42)
|
| 121 |
+
hist = gen_val_loss_history(s)
|
| 122 |
+
assert len(hist) == 20
|
| 123 |
+
|
| 124 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 125 |
+
def test_task_006_fallback(self) -> None:
|
| 126 |
+
s = sample_scenario("task_006", seed=42)
|
| 127 |
+
hist = gen_val_loss_history(s)
|
| 128 |
+
assert len(hist) == 20
|
| 129 |
+
|
| 130 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 131 |
+
def test_task_007_fallback(self) -> None:
|
| 132 |
+
s = sample_scenario("task_007", seed=42)
|
| 133 |
+
hist = gen_val_loss_history(s)
|
| 134 |
+
assert len(hist) == 20
|
| 135 |
+
|
| 136 |
+
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
|
| 137 |
+
def test_fallback_default(self) -> None:
|
| 138 |
+
"""Test the final fallback path for unknown root cause."""
|
| 139 |
+
from ml_training_debugger.models import RootCauseDiagnosis
|
| 140 |
+
from ml_training_debugger.scenarios import ScenarioParams
|
| 141 |
+
|
| 142 |
+
# Use scheduler root cause but force fallback
|
| 143 |
+
s = ScenarioParams(
|
| 144 |
+
task_id="task_999",
|
| 145 |
+
root_cause=RootCauseDiagnosis.SCHEDULER_MISCONFIGURED,
|
| 146 |
+
seed=42,
|
| 147 |
+
)
|
| 148 |
+
hist = gen_val_loss_history(s)
|
| 149 |
+
assert len(hist) == 20
|