omkarrr88 committed on
Commit
45eee48
·
1 Parent(s): 43647d3

docker size reduced

Browse files
.coverage CHANGED
Binary files a/.coverage and b/.coverage differ
 
.dockerignore CHANGED
@@ -3,7 +3,8 @@ __pycache__/
3
  .git/
4
  .pytest_cache/
5
  tests/
6
- validation/
 
7
  *.md
8
  !README.md
9
  .claude/
@@ -11,3 +12,4 @@ run*.json
11
  htmlcov/
12
  .mypy_cache/
13
  .ruff_cache/
 
 
3
  .git/
4
  .pytest_cache/
5
  tests/
6
+ validation/*.py
7
+ validation/requirements.txt
8
  *.md
9
  !README.md
10
  .claude/
 
12
  htmlcov/
13
  .mypy_cache/
14
  .ruff_cache/
15
+ .env
Dockerfile CHANGED
@@ -1,22 +1,29 @@
1
- FROM python:3.12-slim
2
-
3
- WORKDIR /app
4
 
5
- # Install system deps (curl for healthcheck)
6
- RUN apt-get update && apt-get install -y --no-install-recommends curl && \
7
  rm -rf /var/lib/apt/lists/*
8
 
9
- # Install ALL Python deps + safe cleanup in ONE layer.
10
- # Docker layers are immutable — cleanup in a separate RUN saves nothing.
11
- # PyTorch CPU-only (~280MB wheel, ~460MB installed) is the minimum for real
12
- # torch.nn.Module, torch.autograd, and state_dict() support.
13
  COPY requirements.txt .
14
- RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu && \
15
- pip install --no-cache-dir -r requirements.txt && \
16
- # Remove non-essential torch components (safe — verified these don't break imports)
 
 
 
 
 
 
 
17
  rm -rf /usr/local/lib/python3.12/site-packages/torch/test \
18
  /usr/local/lib/python3.12/site-packages/torch/include \
19
  /usr/local/lib/python3.12/site-packages/torch/share \
 
 
 
 
 
20
  /usr/local/lib/python3.12/site-packages/torch/utils/benchmark \
21
  /usr/local/lib/python3.12/site-packages/torch/utils/bottleneck \
22
  /usr/local/lib/python3.12/site-packages/torch/utils/tensorboard \
@@ -24,18 +31,74 @@ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/wh
24
  /usr/local/lib/python3.12/site-packages/torch/lib/libtorchbind_test.so \
25
  /usr/local/lib/python3.12/site-packages/torch/lib/libjitbackend_test.so \
26
  /usr/local/lib/python3.12/site-packages/torch/lib/libbackend_with_compiler.so \
27
- /usr/local/lib/python3.12/site-packages/caffe2 2>/dev/null; \
28
- find /usr/local/lib/python3.12/site-packages/torch -name "*.pyi" -delete 2>/dev/null; \
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  find /usr/local/lib/python3.12/site-packages -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null; \
 
 
 
 
 
 
 
 
 
 
 
 
30
  true
31
 
32
- # Copy application code
 
 
 
 
 
 
 
 
 
 
 
33
  COPY ml_training_debugger/ ml_training_debugger/
34
  COPY server/ server/
35
  COPY openenv.yaml .
36
  COPY baseline_heuristic.py .
37
  COPY baseline_inference.py .
38
  COPY README.md .
 
39
 
40
  EXPOSE 7860
41
 
 
1
+ # ---- Stage 1: Builder — install + strip aggressively ----
2
+ FROM python:3.12-slim AS builder
 
3
 
4
+ RUN apt-get update && apt-get install -y --no-install-recommends binutils && \
 
5
  rm -rf /var/lib/apt/lists/*
6
 
7
+ WORKDIR /build
 
 
 
8
  COPY requirements.txt .
9
+
10
+ RUN pip install --no-cache-dir --no-compile \
11
+ torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu && \
12
+ pip install --no-cache-dir --no-compile -r requirements.txt && \
13
+ #
14
+ # === STRIP DEBUG SYMBOLS FROM ALL .so FILES (saves ~100-200MB) ===
15
+ find /usr/local/lib/python3.12/site-packages -name "*.so" -exec strip --strip-unneeded {} + 2>/dev/null; \
16
+ find /usr/local/lib/python3.12/site-packages -name "*.so.*" -exec strip --strip-unneeded {} + 2>/dev/null; \
17
+ #
18
+ # === TORCH CLEANUP ===
19
  rm -rf /usr/local/lib/python3.12/site-packages/torch/test \
20
  /usr/local/lib/python3.12/site-packages/torch/include \
21
  /usr/local/lib/python3.12/site-packages/torch/share \
22
+ /usr/local/lib/python3.12/site-packages/torch/bin/FileStore* \
23
+ /usr/local/lib/python3.12/site-packages/torch/bin/HashStore* \
24
+ /usr/local/lib/python3.12/site-packages/torch/bin/TCPStore* \
25
+ /usr/local/lib/python3.12/site-packages/torch/bin/protoc* \
26
+ /usr/local/lib/python3.12/site-packages/torch/bin/test_* \
27
  /usr/local/lib/python3.12/site-packages/torch/utils/benchmark \
28
  /usr/local/lib/python3.12/site-packages/torch/utils/bottleneck \
29
  /usr/local/lib/python3.12/site-packages/torch/utils/tensorboard \
 
31
  /usr/local/lib/python3.12/site-packages/torch/lib/libtorchbind_test.so \
32
  /usr/local/lib/python3.12/site-packages/torch/lib/libjitbackend_test.so \
33
  /usr/local/lib/python3.12/site-packages/torch/lib/libbackend_with_compiler.so \
34
+ /usr/local/lib/python3.12/site-packages/torch/lib/libaoti_custom_ops.so \
35
+ /usr/local/lib/python3.12/site-packages/torch/lib/libshm_windows \
36
+ /usr/local/lib/python3.12/site-packages/caffe2 \
37
+ #
38
+ # === BLOATED TRANSITIVE DEPS ===
39
+ /usr/local/lib/python3.12/site-packages/gradio \
40
+ /usr/local/lib/python3.12/site-packages/gradio_client \
41
+ /usr/local/lib/python3.12/site-packages/hf_gradio \
42
+ /usr/local/lib/python3.12/site-packages/pandas \
43
+ /usr/local/lib/python3.12/site-packages/PIL \
44
+ /usr/local/lib/python3.12/site-packages/Pillow* \
45
+ /usr/local/lib/python3.12/site-packages/pillow* \
46
+ /usr/local/lib/python3.12/site-packages/networkx \
47
+ /usr/local/lib/python3.12/site-packages/scipy \
48
+ /usr/local/lib/python3.12/site-packages/matplotlib \
49
+ /usr/local/lib/python3.12/site-packages/hf_xet \
50
+ /usr/local/lib/python3.12/site-packages/ffmpy \
51
+ /usr/local/lib/python3.12/site-packages/pydub \
52
+ /usr/local/lib/python3.12/site-packages/groovy \
53
+ /usr/local/lib/python3.12/site-packages/tomlkit \
54
+ /usr/local/lib/python3.12/site-packages/semantic_version* \
55
+ /usr/local/lib/python3.12/site-packages/safehttpx* \
56
+ /usr/local/lib/python3.12/site-packages/brotli* \
57
+ /usr/local/lib/python3.12/site-packages/Brotli* \
58
+ /usr/local/lib/python3.12/site-packages/pip \
59
+ /usr/local/lib/python3.12/site-packages/setuptools \
60
+ /usr/local/lib/python3.12/site-packages/docutils \
61
+ /usr/local/lib/python3.12/site-packages/cryptography \
62
+ /usr/local/lib/python3.12/site-packages/cryptography* \
63
+ /usr/local/lib/python3.12/site-packages/pytz 2>/dev/null; \
64
+ #
65
+ # === FILE-LEVEL CLEANUP ===
66
+ find /usr/local/lib/python3.12/site-packages -name "*.pyi" -delete 2>/dev/null; \
67
+ find /usr/local/lib/python3.12/site-packages -name "*.pyc" -delete 2>/dev/null; \
68
  find /usr/local/lib/python3.12/site-packages -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null; \
69
+ find /usr/local/lib/python3.12/site-packages -name "*.egg-info" -type d -exec rm -rf {} + 2>/dev/null; \
70
+ find /usr/local/lib/python3.12/site-packages -name "tests" -type d -exec rm -rf {} + 2>/dev/null; \
71
+ find /usr/local/lib/python3.12/site-packages -name "test" -type d -exec rm -rf {} + 2>/dev/null; \
72
+ # Remove stale dist-info for packages we already deleted
73
+ rm -rf /usr/local/lib/python3.12/site-packages/gradio*.dist-info \
74
+ /usr/local/lib/python3.12/site-packages/pandas*.dist-info \
75
+ /usr/local/lib/python3.12/site-packages/Pillow*.dist-info \
76
+ /usr/local/lib/python3.12/site-packages/hf_xet*.dist-info \
77
+ /usr/local/lib/python3.12/site-packages/Brotli*.dist-info \
78
+ /usr/local/lib/python3.12/site-packages/networkx*.dist-info \
79
+ /usr/local/lib/python3.12/site-packages/pip \
80
+ /usr/local/lib/python3.12/site-packages/pip*.dist-info 2>/dev/null; \
81
  true
82
 
83
+ # ---- Stage 2: Runtime — minimal clean image ----
84
+ FROM python:3.12-slim
85
+
86
+ WORKDIR /app
87
+
88
+ RUN apt-get update && apt-get install -y --no-install-recommends curl && \
89
+ rm -rf /var/lib/apt/lists/*
90
+
91
+ # Copy only what's needed from builder
92
+ COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
93
+ COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/uvicorn
94
+
95
  COPY ml_training_debugger/ ml_training_debugger/
96
  COPY server/ server/
97
  COPY openenv.yaml .
98
  COPY baseline_heuristic.py .
99
  COPY baseline_inference.py .
100
  COPY README.md .
101
+ COPY validation/reports/ validation/reports/
102
 
103
  EXPOSE 7860
104
 
tests/test_graders.py CHANGED
@@ -5,10 +5,12 @@ from __future__ import annotations
5
  import pytest
6
 
7
  from ml_training_debugger.graders import (
 
8
  grade_episode,
9
  grade_task_001,
10
  grade_task_003,
11
  grade_task_005,
 
12
  )
13
  from ml_training_debugger.models import EpisodeState
14
  from ml_training_debugger.scenarios import sample_scenario
@@ -166,3 +168,60 @@ class TestGradeEpisode:
166
  state = EpisodeState()
167
  score = grade_episode("task_999", state, scenario_001)
168
  assert score == 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import pytest
6
 
7
  from ml_training_debugger.graders import (
8
+ _submitted_diagnosis,
9
  grade_episode,
10
  grade_task_001,
11
  grade_task_003,
12
  grade_task_005,
13
+ grade_task_007,
14
  )
15
  from ml_training_debugger.models import EpisodeState
16
  from ml_training_debugger.scenarios import sample_scenario
 
168
  state = EpisodeState()
169
  score = grade_episode("task_999", state, scenario_001)
170
  assert score == 0.0
171
+
172
+
173
+ class TestGradeTask007:
174
+ def test_perfect_score(self):
175
+ scenario = sample_scenario("task_007", seed=42)
176
+ state = EpisodeState(
177
+ gradients_inspected=True,
178
+ data_inspected=True,
179
+ fix_action_taken=True,
180
+ restart_after_fix=True,
181
+ diagnosis_submitted=True,
182
+ actions_taken=[
183
+ "inspect_gradients",
184
+ "inspect_data_batch",
185
+ "modify_config",
186
+ "restart_run",
187
+ "mark_diagnosed:scheduler_misconfigured",
188
+ ],
189
+ )
190
+ score = grade_task_007(state, scenario)
191
+ assert score == 1.0
192
+
193
+ def test_wrong_diagnosis(self):
194
+ scenario = sample_scenario("task_007", seed=42)
195
+ state = EpisodeState(
196
+ diagnosis_submitted=True,
197
+ actions_taken=["mark_diagnosed:overfitting"],
198
+ )
199
+ score = grade_task_007(state, scenario)
200
+ assert score < 0.5
201
+
202
+ def test_score_in_range(self):
203
+ scenario = sample_scenario("task_007", seed=42)
204
+ state = EpisodeState()
205
+ score = grade_task_007(state, scenario)
206
+ assert 0.0 <= score <= 1.0
207
+
208
+
209
+ class TestSubmittedDiagnosis:
210
+ def test_finds_diagnosis(self):
211
+ state = EpisodeState(
212
+ actions_taken=["inspect_gradients", "mark_diagnosed:lr_too_high"],
213
+ )
214
+ assert _submitted_diagnosis(state) == "lr_too_high"
215
+
216
+ def test_no_diagnosis(self):
217
+ state = EpisodeState(actions_taken=["inspect_gradients"])
218
+ assert _submitted_diagnosis(state) is None
219
+
220
+ def test_latest_diagnosis(self):
221
+ state = EpisodeState(
222
+ actions_taken=[
223
+ "mark_diagnosed:overfitting",
224
+ "mark_diagnosed:lr_too_high",
225
+ ],
226
+ )
227
+ assert _submitted_diagnosis(state) == "lr_too_high"
tests/test_new_endpoints.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for new endpoints: curriculum, leaderboard, replay, validation-report."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+ from fastapi.testclient import TestClient
7
+
8
+ from server.app import app
9
+
10
+
11
+ @pytest.fixture
12
+ def client():
13
+ return TestClient(app)
14
+
15
+
16
+ class TestCurriculumEndpoint:
17
+ def test_returns_curriculum(self, client) -> None:
18
+ resp = client.get("/curriculum")
19
+ assert resp.status_code == 200
20
+ data = resp.json()
21
+ assert "curriculum" in data
22
+ assert "total_episodes" in data
23
+ assert data["total_episodes"] > 0
24
+
25
+ def test_curriculum_has_difficulty_levels(self, client) -> None:
26
+ resp = client.get("/curriculum")
27
+ curriculum = resp.json()["curriculum"]
28
+ levels = {entry["difficulty_level"] for entry in curriculum}
29
+ assert 1 in levels
30
+ assert 3 in levels
31
+ assert 5 in levels
32
+
33
+ def test_curriculum_covers_all_tasks(self, client) -> None:
34
+ resp = client.get("/curriculum")
35
+ curriculum = resp.json()["curriculum"]
36
+ task_ids = {entry["task_id"] for entry in curriculum}
37
+ assert "task_001" in task_ids
38
+ assert "task_007" in task_ids
39
+
40
+
41
+ class TestLeaderboardEndpoint:
42
+ def test_returns_leaderboard(self, client) -> None:
43
+ resp = client.get("/leaderboard")
44
+ assert resp.status_code == 200
45
+ data = resp.json()
46
+ assert "entries" in data
47
+ assert "total" in data
48
+
49
+ def test_leaderboard_after_baseline(self, client) -> None:
50
+ # Run baseline to populate scores
51
+ client.post("/baseline")
52
+ resp = client.get("/leaderboard")
53
+ data = resp.json()
54
+ assert data["total"] > 0
55
+
56
+
57
+ class TestReplayEndpoint:
58
+ def test_missing_episode(self, client) -> None:
59
+ resp = client.get("/replay/nonexistent_episode_123")
60
+ assert resp.status_code == 200
61
+ data = resp.json()
62
+ assert "error" in data
63
+
64
+ def test_replay_after_baseline(self, client) -> None:
65
+ # Run baseline to create episodes
66
+ client.post("/baseline")
67
+ resp = client.get("/replay/baseline_task_001")
68
+ data = resp.json()
69
+ # Should have episode data or error
70
+ assert "episode_id" in data or "error" in data
71
+
72
+
73
+ class TestValidationReportEndpoint:
74
+ def test_returns_real_report(self, client) -> None:
75
+ resp = client.get("/validation-report")
76
+ assert resp.status_code == 200
77
+ data = resp.json()
78
+ assert "results" in data
79
+ assert "summary" in data
80
+ assert data["summary"]["passed"] > 0
tests/test_real_training.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for real mini-training in pytorch_engine.py."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ml_training_debugger.pytorch_engine import (
8
+ SimpleCNN,
9
+ SimpleMLP,
10
+ _TRAINING_CACHE,
11
+ run_real_training,
12
+ )
13
+ from ml_training_debugger.scenarios import sample_scenario
14
+
15
+
16
+ class TestRunRealTraining:
17
+ def test_returns_20_epoch_curves(self) -> None:
18
+ s = sample_scenario("task_001", seed=42)
19
+ curves = run_real_training(s)
20
+ assert len(curves["loss_history"]) == 20
21
+ assert len(curves["val_loss_history"]) == 20
22
+ assert len(curves["val_acc_history"]) == 20
23
+
24
+ def test_all_values_are_floats(self) -> None:
25
+ s = sample_scenario("task_003", seed=42)
26
+ curves = run_real_training(s)
27
+ for key in ["loss_history", "val_loss_history", "val_acc_history"]:
28
+ for v in curves[key]:
29
+ assert isinstance(v, float), f"{key} has non-float: {type(v)}"
30
+
31
+ def test_caching_works(self) -> None:
32
+ _TRAINING_CACHE.clear()
33
+ s = sample_scenario("task_001", seed=42)
34
+ c1 = run_real_training(s)
35
+ c2 = run_real_training(s)
36
+ assert c1 is c2 # Same object reference = cached
37
+
38
+ def test_reproducible_across_calls(self) -> None:
39
+ _TRAINING_CACHE.clear()
40
+ s = sample_scenario("task_002", seed=42)
41
+ c1 = run_real_training(s)
42
+ _TRAINING_CACHE.clear()
43
+ c2 = run_real_training(s)
44
+ assert c1["loss_history"] == c2["loss_history"]
45
+ assert c1["val_acc_history"] == c2["val_acc_history"]
46
+
47
+ def test_different_seeds_different_curves(self) -> None:
48
+ s1 = sample_scenario("task_001", seed=42)
49
+ s2 = sample_scenario("task_001", seed=99)
50
+ c1 = run_real_training(s1)
51
+ c2 = run_real_training(s2)
52
+ assert c1["loss_history"] != c2["loss_history"]
53
+
54
+ def test_task_001_high_lr_instability(self) -> None:
55
+ s = sample_scenario("task_001", seed=42)
56
+ curves = run_real_training(s)
57
+ max_loss = max(v for v in curves["loss_history"] if v != float("inf"))
58
+ assert max_loss > 3.0 # High LR causes loss spikes
59
+
60
+ def test_task_002_vanishing_slow_learning(self) -> None:
61
+ s = sample_scenario("task_002", seed=42)
62
+ curves = run_real_training(s)
63
+ assert len(curves["loss_history"]) == 20
64
+
65
+ def test_task_003_data_leakage(self) -> None:
66
+ s = sample_scenario("task_003", seed=42)
67
+ curves = run_real_training(s)
68
+ # With leakage, val accuracy may be elevated
69
+ assert len(curves["val_acc_history"]) == 20
70
+
71
+ def test_task_004_overfitting(self) -> None:
72
+ s = sample_scenario("task_004", seed=42)
73
+ curves = run_real_training(s)
74
+ assert len(curves["loss_history"]) == 20
75
+
76
+ def test_task_005_batchnorm_eval(self) -> None:
77
+ s = sample_scenario("task_005", seed=42)
78
+ curves = run_real_training(s)
79
+ assert len(curves["loss_history"]) == 20
80
+
81
+ def test_task_006_code_bug(self) -> None:
82
+ s = sample_scenario("task_006", seed=42)
83
+ curves = run_real_training(s)
84
+ assert len(curves["loss_history"]) == 20
85
+
86
+ def test_task_007_scheduler(self) -> None:
87
+ s = sample_scenario("task_007", seed=42)
88
+ curves = run_real_training(s)
89
+ assert len(curves["loss_history"]) == 20
90
+
91
+ def test_mlp_architecture(self) -> None:
92
+ """Find a scenario that uses MLP and verify training works."""
93
+ for seed in range(1, 20):
94
+ s = sample_scenario("task_001", seed=seed)
95
+ if s.model_type == "mlp":
96
+ curves = run_real_training(s)
97
+ assert len(curves["loss_history"]) == 20
98
+ return
99
+ # If no MLP scenario was found for seeds 1-19, construct one directly
100
+ from ml_training_debugger.scenarios import ScenarioParams
101
+ from ml_training_debugger.models import RootCauseDiagnosis
102
+ s = ScenarioParams(
103
+ task_id="task_001",
104
+ root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
105
+ seed=999,
106
+ learning_rate=0.1,
107
+ model_type="mlp",
108
+ )
109
+ curves = run_real_training(s)
110
+ assert len(curves["loss_history"]) == 20
111
+
112
+
113
+ class TestSimpleMLP:
114
+ def test_is_nn_module(self) -> None:
115
+ model = SimpleMLP()
116
+ assert isinstance(model, torch.nn.Module)
117
+
118
+ def test_param_count(self) -> None:
119
+ model = SimpleMLP()
120
+ count = sum(p.numel() for p in model.parameters())
121
+ assert 10_000 < count < 500_000
122
+
123
+ def test_forward_pass(self) -> None:
124
+ model = SimpleMLP()
125
+ x = torch.randn(4, 3, 32, 32)
126
+ out = model(x)
127
+ assert out.shape == (4, 10)
128
+
129
+ def test_has_batchnorm(self) -> None:
130
+ model = SimpleMLP()
131
+ has_bn = any(
132
+ isinstance(m, torch.nn.BatchNorm1d)
133
+ for m in model.modules()
134
+ )
135
+ assert has_bn
tests/test_simulation_fallback.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for parametric fallback in simulation.py.
2
+
3
+ These test the fallback paths that run when real training is unavailable.
4
+ We force fallback by monkeypatching _get_real_curves to return None.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from unittest.mock import patch
10
+
11
+ from ml_training_debugger.scenarios import sample_scenario
12
+ from ml_training_debugger.simulation import (
13
+ gen_loss_history,
14
+ gen_val_accuracy_history,
15
+ gen_val_loss_history,
16
+ )
17
+
18
+
19
+ def _force_fallback(*args, **kwargs):
20
+ return None
21
+
22
+
23
+ class TestParametricFallbackLoss:
24
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
25
+ def test_task_001_fallback(self) -> None:
26
+ s = sample_scenario("task_001", seed=42)
27
+ hist = gen_loss_history(s)
28
+ assert len(hist) == 20
29
+
30
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
31
+ def test_task_002_fallback(self) -> None:
32
+ s = sample_scenario("task_002", seed=42)
33
+ hist = gen_loss_history(s)
34
+ assert len(hist) == 20
35
+
36
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
37
+ def test_task_003_fallback(self) -> None:
38
+ s = sample_scenario("task_003", seed=42)
39
+ hist = gen_loss_history(s)
40
+ assert len(hist) == 20
41
+
42
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
43
+ def test_task_004_fallback(self) -> None:
44
+ s = sample_scenario("task_004", seed=42)
45
+ hist = gen_loss_history(s)
46
+ assert len(hist) == 20
47
+
48
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
49
+ def test_task_005_fallback(self) -> None:
50
+ s = sample_scenario("task_005", seed=42)
51
+ hist = gen_loss_history(s)
52
+ assert len(hist) == 20
53
+
54
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
55
+ def test_task_006_fallback(self) -> None:
56
+ s = sample_scenario("task_006", seed=42)
57
+ hist = gen_loss_history(s)
58
+ assert len(hist) == 20
59
+
60
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
61
+ def test_task_007_fallback(self) -> None:
62
+ s = sample_scenario("task_007", seed=42)
63
+ hist = gen_loss_history(s)
64
+ assert len(hist) == 20
65
+
66
+
67
+ class TestParametricFallbackValAcc:
68
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
69
+ def test_task_001_fallback(self) -> None:
70
+ s = sample_scenario("task_001", seed=42)
71
+ hist = gen_val_accuracy_history(s)
72
+ assert len(hist) == 20
73
+
74
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
75
+ def test_task_003_fallback(self) -> None:
76
+ s = sample_scenario("task_003", seed=42)
77
+ hist = gen_val_accuracy_history(s)
78
+ assert len(hist) == 20
79
+
80
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
81
+ def test_task_004_fallback(self) -> None:
82
+ s = sample_scenario("task_004", seed=42)
83
+ hist = gen_val_accuracy_history(s)
84
+ assert len(hist) == 20
85
+
86
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
87
+ def test_task_005_fallback(self) -> None:
88
+ s = sample_scenario("task_005", seed=42)
89
+ hist = gen_val_accuracy_history(s)
90
+ assert len(hist) == 20
91
+
92
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
93
+ def test_task_006_fallback(self) -> None:
94
+ s = sample_scenario("task_006", seed=42)
95
+ hist = gen_val_accuracy_history(s)
96
+ assert len(hist) == 20
97
+
98
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
99
+ def test_task_007_fallback(self) -> None:
100
+ s = sample_scenario("task_007", seed=42)
101
+ hist = gen_val_accuracy_history(s)
102
+ assert len(hist) == 20
103
+
104
+
105
+ class TestParametricFallbackValLoss:
106
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
107
+ def test_task_001_fallback(self) -> None:
108
+ s = sample_scenario("task_001", seed=42)
109
+ hist = gen_val_loss_history(s)
110
+ assert len(hist) == 20
111
+
112
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
113
+ def test_task_004_fallback(self) -> None:
114
+ s = sample_scenario("task_004", seed=42)
115
+ hist = gen_val_loss_history(s)
116
+ assert len(hist) == 20
117
+
118
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
119
+ def test_task_005_fallback(self) -> None:
120
+ s = sample_scenario("task_005", seed=42)
121
+ hist = gen_val_loss_history(s)
122
+ assert len(hist) == 20
123
+
124
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
125
+ def test_task_006_fallback(self) -> None:
126
+ s = sample_scenario("task_006", seed=42)
127
+ hist = gen_val_loss_history(s)
128
+ assert len(hist) == 20
129
+
130
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
131
+ def test_task_007_fallback(self) -> None:
132
+ s = sample_scenario("task_007", seed=42)
133
+ hist = gen_val_loss_history(s)
134
+ assert len(hist) == 20
135
+
136
+ @patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
137
+ def test_fallback_default(self) -> None:
138
+ """Test the final fallback path for unknown root cause."""
139
+ from ml_training_debugger.models import RootCauseDiagnosis
140
+ from ml_training_debugger.scenarios import ScenarioParams
141
+
142
+ # Use scheduler root cause but force fallback
143
+ s = ScenarioParams(
144
+ task_id="task_999",
145
+ root_cause=RootCauseDiagnosis.SCHEDULER_MISCONFIGURED,
146
+ seed=42,
147
+ )
148
+ hist = gen_val_loss_history(s)
149
+ assert len(hist) == 20