| """Tests for real mini-training in pytorch_engine.py.""" |
|
|
| from __future__ import annotations |
|
|
| import torch |
|
|
| from ml_training_debugger.pytorch_engine import ( |
| SimpleCNN, |
| SimpleMLP, |
| _TRAINING_CACHE, |
| run_real_training, |
| ) |
| from ml_training_debugger.scenarios import sample_scenario |
|
|
|
|
class TestRunRealTraining:
    """Exercise ``run_real_training`` on scenarios built by ``sample_scenario``.

    The tests cover the length and element type of the returned history
    lists, reuse of results through ``_TRAINING_CACHE``, reproducibility for
    a fixed seed, and one smoke run per task id.
    """

    # Number of entries each history list is expected to hold (one per epoch).
    EXPECTED_EPOCHS = 20

    # Keys of the per-epoch series read from the returned curves mapping.
    HISTORY_KEYS = ("loss_history", "val_loss_history", "val_acc_history")

    def _assert_history_length(
        self,
        task_id: str,
        key: str = "loss_history",
        *,
        seed: int = 42,
    ) -> None:
        """Train the scenario for ``task_id`` and check one history's length.

        Args:
            task_id: Task identifier forwarded to ``sample_scenario``.
            key: Which history list of the returned curves to measure.
            seed: Seed forwarded to ``sample_scenario``.
        """
        scenario = sample_scenario(task_id, seed=seed)
        curves = run_real_training(scenario)
        assert len(curves[key]) == self.EXPECTED_EPOCHS

    def test_returns_20_epoch_curves(self) -> None:
        """Every history list for task_001 has ``EXPECTED_EPOCHS`` entries."""
        scenario = sample_scenario("task_001", seed=42)
        curves = run_real_training(scenario)
        assert len(curves["loss_history"]) == self.EXPECTED_EPOCHS
        assert len(curves["val_loss_history"]) == self.EXPECTED_EPOCHS
        assert len(curves["val_acc_history"]) == self.EXPECTED_EPOCHS

    def test_all_values_are_floats(self) -> None:
        """Every element of every history list is a Python ``float``."""
        scenario = sample_scenario("task_003", seed=42)
        curves = run_real_training(scenario)
        for key in self.HISTORY_KEYS:
            for v in curves[key]:
                assert isinstance(v, float), f"{key} has non-float: {type(v)}"

    def test_caching_works(self) -> None:
        """Two calls with the same scenario return the very same object."""
        _TRAINING_CACHE.clear()
        scenario = sample_scenario("task_001", seed=42)
        first = run_real_training(scenario)
        second = run_real_training(scenario)
        # Identity, not equality: the second call must hand back the first result.
        assert first is second

    def test_reproducible_across_calls(self) -> None:
        """The same scenario yields equal curves after the cache is cleared."""
        _TRAINING_CACHE.clear()
        scenario = sample_scenario("task_002", seed=42)
        first = run_real_training(scenario)
        # Empty the cache so the second result cannot simply be the cached one.
        _TRAINING_CACHE.clear()
        second = run_real_training(scenario)
        assert first["loss_history"] == second["loss_history"]
        assert first["val_acc_history"] == second["val_acc_history"]

    def test_different_seeds_different_curves(self) -> None:
        """Scenarios sampled with different seeds give different loss curves."""
        curves_a = run_real_training(sample_scenario("task_001", seed=42))
        curves_b = run_real_training(sample_scenario("task_001", seed=99))
        assert curves_a["loss_history"] != curves_b["loss_history"]

    def test_task_001_high_lr_instability(self) -> None:
        """The largest finite training loss for task_001 exceeds 3.0."""
        import math

        scenario = sample_scenario("task_001", seed=42)
        losses = run_real_training(scenario)["loss_history"]
        # Drop NaN as well as +/-inf: NaN compares false against everything,
        # so leaving it in would make max() depend on element order.
        finite_losses = [v for v in losses if math.isfinite(v)]
        # Fail with a readable assertion instead of max() raising ValueError
        # on an empty sequence when no finite value is left.
        assert finite_losses, "loss_history has no finite values to inspect"
        assert max(finite_losses) > 3.0

    def test_task_002_vanishing_slow_learning(self) -> None:
        """Smoke run for task_002: the loss history has the expected length."""
        self._assert_history_length("task_002")

    def test_task_003_data_leakage(self) -> None:
        """Smoke run for task_003: the val-accuracy history has the expected length."""
        self._assert_history_length("task_003", "val_acc_history")

    def test_task_004_overfitting(self) -> None:
        """Smoke run for task_004: the loss history has the expected length."""
        self._assert_history_length("task_004")

    def test_task_005_batchnorm_eval(self) -> None:
        """Smoke run for task_005: the loss history has the expected length."""
        self._assert_history_length("task_005")

    def test_task_006_code_bug(self) -> None:
        """Smoke run for task_006: the loss history has the expected length."""
        self._assert_history_length("task_006")

    def test_task_007_scheduler(self) -> None:
        """Smoke run for task_007: the loss history has the expected length."""
        self._assert_history_length("task_007")

    def test_mlp_architecture(self) -> None:
        """Find a scenario that uses MLP and verify training works."""
        for seed in range(1, 20):
            scenario = sample_scenario("task_001", seed=seed)
            if scenario.model_type == "mlp":
                break
        else:
            # No sampled seed produced an MLP scenario, so build one by hand.
            from ml_training_debugger.scenarios import ScenarioParams
            from ml_training_debugger.models import RootCauseDiagnosis

            scenario = ScenarioParams(
                task_id="task_001",
                root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
                seed=999,
                learning_rate=0.1,
                model_type="mlp",
            )
        curves = run_real_training(scenario)
        assert len(curves["loss_history"]) == self.EXPECTED_EPOCHS
|
|
|
|
class TestSimpleMLP:
    """Structural checks on a default-constructed ``SimpleMLP``."""

    def test_is_nn_module(self) -> None:
        """A default instance is a ``torch.nn.Module``."""
        assert isinstance(SimpleMLP(), torch.nn.Module)

    def test_param_count(self) -> None:
        """The total parameter element count lies strictly between 10k and 500k."""
        total = 0
        for param in SimpleMLP().parameters():
            total += param.numel()
        assert total > 10_000 and total < 500_000

    def test_forward_pass(self) -> None:
        """A (4, 3, 32, 32) random input produces an output of shape (4, 10)."""
        net = SimpleMLP()
        result = net(torch.randn(4, 3, 32, 32))
        assert result.shape == (4, 10)

    def test_has_batchnorm(self) -> None:
        """At least one submodule is a ``torch.nn.BatchNorm1d``."""
        batchnorm_layers = [
            module
            for module in SimpleMLP().modules()
            if isinstance(module, torch.nn.BatchNorm1d)
        ]
        assert batchnorm_layers
|
|