File size: 4,839 Bytes
45eee48 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | """Tests for real mini-training in pytorch_engine.py."""
from __future__ import annotations
import torch
from ml_training_debugger.pytorch_engine import (
SimpleCNN,
SimpleMLP,
_TRAINING_CACHE,
run_real_training,
)
from ml_training_debugger.scenarios import sample_scenario
class TestRunRealTraining:
    """Tests for ``run_real_training`` on scenarios from ``sample_scenario``."""

    def test_returns_20_epoch_curves(self) -> None:
        """Every returned curve holds exactly 20 per-epoch entries."""
        s = sample_scenario("task_001", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20
        assert len(curves["val_loss_history"]) == 20
        assert len(curves["val_acc_history"]) == 20

    def test_all_values_are_floats(self) -> None:
        """All curve entries are plain Python floats (not tensors/ints)."""
        s = sample_scenario("task_003", seed=42)
        curves = run_real_training(s)
        for key in ["loss_history", "val_loss_history", "val_acc_history"]:
            for v in curves[key]:
                assert isinstance(v, float), f"{key} has non-float: {type(v)}"

    def test_caching_works(self) -> None:
        """A repeated call with the same scenario returns the cached object."""
        _TRAINING_CACHE.clear()
        s = sample_scenario("task_001", seed=42)
        c1 = run_real_training(s)
        c2 = run_real_training(s)
        assert c1 is c2  # Same object reference = cached

    def test_reproducible_across_calls(self) -> None:
        """Recomputing after clearing the cache yields identical curves."""
        _TRAINING_CACHE.clear()
        s = sample_scenario("task_002", seed=42)
        c1 = run_real_training(s)
        # Clear again so the second call really retrains instead of hitting
        # the cache.
        _TRAINING_CACHE.clear()
        c2 = run_real_training(s)
        assert c1["loss_history"] == c2["loss_history"]
        assert c1["val_acc_history"] == c2["val_acc_history"]

    def test_different_seeds_different_curves(self) -> None:
        """Different sampling seeds produce different loss curves."""
        s1 = sample_scenario("task_001", seed=42)
        s2 = sample_scenario("task_001", seed=99)
        c1 = run_real_training(s1)
        c2 = run_real_training(s2)
        assert c1["loss_history"] != c2["loss_history"]

    def test_task_001_high_lr_instability(self) -> None:
        """A too-high learning rate should produce a finite loss spike > 3.0."""
        import math

        s = sample_scenario("task_001", seed=42)
        curves = run_real_training(s)
        # Drop inf AND NaN. A NaN would make max() order-dependent, because
        # every comparison with NaN is False, and it could mask a real spike.
        finite_losses = [v for v in curves["loss_history"] if math.isfinite(v)]
        assert finite_losses, "every recorded training loss is non-finite"
        assert max(finite_losses) > 3.0  # High LR causes loss spikes

    def test_task_002_vanishing_slow_learning(self) -> None:
        """task_002 training completes all 20 epochs."""
        s = sample_scenario("task_002", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_003_data_leakage(self) -> None:
        """task_003 training completes all 20 epochs of validation accuracy."""
        s = sample_scenario("task_003", seed=42)
        curves = run_real_training(s)
        # With leakage, val accuracy may be elevated
        assert len(curves["val_acc_history"]) == 20

    def test_task_004_overfitting(self) -> None:
        """task_004 training completes all 20 epochs."""
        s = sample_scenario("task_004", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_005_batchnorm_eval(self) -> None:
        """task_005 training completes all 20 epochs."""
        s = sample_scenario("task_005", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_006_code_bug(self) -> None:
        """task_006 training completes all 20 epochs."""
        s = sample_scenario("task_006", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_007_scheduler(self) -> None:
        """task_007 training completes all 20 epochs."""
        s = sample_scenario("task_007", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_mlp_architecture(self) -> None:
        """Find a scenario that uses MLP and verify training works."""
        # Probe 20 seeds (1..20 inclusive) for a sampled MLP scenario.
        for seed in range(1, 21):
            s = sample_scenario("task_001", seed=seed)
            if s.model_type == "mlp":
                curves = run_real_training(s)
                assert len(curves["loss_history"]) == 20
                return
        # If no MLP found in 20 seeds, test directly
        from ml_training_debugger.scenarios import ScenarioParams
        from ml_training_debugger.models import RootCauseDiagnosis

        s = ScenarioParams(
            task_id="task_001",
            root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
            seed=999,
            learning_rate=0.1,
            model_type="mlp",
        )
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20
class TestSimpleMLP:
    """Structural checks on the ``SimpleMLP`` model."""

    def test_is_nn_module(self) -> None:
        """SimpleMLP is a torch ``nn.Module`` subclass instance."""
        net = SimpleMLP()
        assert isinstance(net, torch.nn.Module)

    def test_param_count(self) -> None:
        """Total parameter count lies strictly between 10k and 500k."""
        net = SimpleMLP()
        lower, upper = 10_000, 500_000
        n_params = sum(param.numel() for param in net.parameters())
        assert lower < n_params < upper

    def test_forward_pass(self) -> None:
        """A batch of 3x32x32 inputs maps to one 10-way output row per sample."""
        net = SimpleMLP()
        batch = 4
        logits = net(torch.randn(batch, 3, 32, 32))
        assert logits.shape == (batch, 10)

    def test_has_batchnorm(self) -> None:
        """At least one ``BatchNorm1d`` layer appears among the submodules."""
        net = SimpleMLP()
        bn_layers = [
            layer
            for layer in net.modules()
            if isinstance(layer, torch.nn.BatchNorm1d)
        ]
        assert bn_layers
|