File size: 4,839 Bytes
45eee48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Tests for real mini-training in pytorch_engine.py."""

from __future__ import annotations

import torch

from ml_training_debugger.pytorch_engine import (
    SimpleCNN,
    SimpleMLP,
    _TRAINING_CACHE,
    run_real_training,
)
from ml_training_debugger.scenarios import sample_scenario


class TestRunRealTraining:
    """Exercise ``run_real_training`` on scenarios built by ``sample_scenario``.

    The tests index the returned curves by ``"loss_history"``,
    ``"val_loss_history"`` and ``"val_acc_history"``.
    """

    def test_returns_20_epoch_curves(self) -> None:
        """Each of the three history sequences has exactly 20 entries."""
        s = sample_scenario("task_001", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20
        assert len(curves["val_loss_history"]) == 20
        assert len(curves["val_acc_history"]) == 20

    def test_all_values_are_floats(self) -> None:
        """Every entry of every history sequence is a Python ``float``."""
        s = sample_scenario("task_003", seed=42)
        curves = run_real_training(s)
        for key in ["loss_history", "val_loss_history", "val_acc_history"]:
            for v in curves[key]:
                assert isinstance(v, float), f"{key} has non-float: {type(v)}"

    def test_caching_works(self) -> None:
        """Two calls with the same scenario return the very same object."""
        # Start from an empty cache so the first call cannot be a stale hit.
        _TRAINING_CACHE.clear()
        s = sample_scenario("task_001", seed=42)
        c1 = run_real_training(s)
        c2 = run_real_training(s)
        assert c1 is c2  # Same object reference = cached

    def test_reproducible_across_calls(self) -> None:
        """With the cache cleared between calls, the curves are still equal."""
        _TRAINING_CACHE.clear()
        s = sample_scenario("task_002", seed=42)
        c1 = run_real_training(s)
        # Clear again so the second call cannot simply return the first result.
        _TRAINING_CACHE.clear()
        c2 = run_real_training(s)
        assert c1["loss_history"] == c2["loss_history"]
        assert c1["val_acc_history"] == c2["val_acc_history"]

    def test_different_seeds_different_curves(self) -> None:
        """Scenarios sampled with seeds 42 and 99 give different loss curves."""
        s1 = sample_scenario("task_001", seed=42)
        s2 = sample_scenario("task_001", seed=99)
        c1 = run_real_training(s1)
        c2 = run_real_training(s2)
        assert c1["loss_history"] != c2["loss_history"]

    def test_task_001_high_lr_instability(self) -> None:
        """The largest finite training loss for task_001 exceeds 3.0."""
        import math

        s = sample_scenario("task_001", seed=42)
        curves = run_real_training(s)
        # Keep only finite losses. Excluding float("inf") alone would let NaN
        # through, and max() over a sequence containing NaN depends on element
        # order, which would make this assertion flaky.
        finite_losses = [v for v in curves["loss_history"] if math.isfinite(v)]
        # Fail with a readable message instead of max() raising ValueError on
        # an empty sequence when no finite loss was recorded.
        assert finite_losses, "loss_history contains no finite values"
        assert max(finite_losses) > 3.0  # High LR causes loss spikes

    def test_task_002_vanishing_slow_learning(self) -> None:
        """Smoke test: task_002 training yields 20 loss values."""
        s = sample_scenario("task_002", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_003_data_leakage(self) -> None:
        """Smoke test: task_003 training yields 20 validation-accuracy values."""
        s = sample_scenario("task_003", seed=42)
        curves = run_real_training(s)
        # With leakage, val accuracy may be elevated
        assert len(curves["val_acc_history"]) == 20

    def test_task_004_overfitting(self) -> None:
        """Smoke test: task_004 training yields 20 loss values."""
        s = sample_scenario("task_004", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_005_batchnorm_eval(self) -> None:
        """Smoke test: task_005 training yields 20 loss values."""
        s = sample_scenario("task_005", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_006_code_bug(self) -> None:
        """Smoke test: task_006 training yields 20 loss values."""
        s = sample_scenario("task_006", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_007_scheduler(self) -> None:
        """Smoke test: task_007 training yields 20 loss values."""
        s = sample_scenario("task_007", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_mlp_architecture(self) -> None:
        """Find a scenario that uses MLP and verify training works.

        Seeds 1 through 19 are tried first; if none of them samples an
        ``"mlp"`` model, a ``ScenarioParams`` with ``model_type="mlp"`` is
        built by hand so the MLP path is still exercised.
        """
        for seed in range(1, 20):
            s = sample_scenario("task_001", seed=seed)
            if s.model_type == "mlp":
                curves = run_real_training(s)
                assert len(curves["loss_history"]) == 20
                return
        # If no MLP found in seeds 1-19, test directly
        from ml_training_debugger.scenarios import ScenarioParams
        from ml_training_debugger.models import RootCauseDiagnosis
        s = ScenarioParams(
            task_id="task_001",
            root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
            seed=999,
            learning_rate=0.1,
            model_type="mlp",
        )
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20


class TestSimpleMLP:
    """Structural checks on a default-constructed ``SimpleMLP``."""

    def test_is_nn_module(self) -> None:
        """A fresh instance is a ``torch.nn.Module``."""
        assert isinstance(SimpleMLP(), torch.nn.Module)

    def test_param_count(self) -> None:
        """The total parameter element count is strictly between 10k and 500k."""
        net = SimpleMLP()
        total = 0
        for param in net.parameters():
            total += param.numel()
        assert total > 10_000
        assert total < 500_000

    def test_forward_pass(self) -> None:
        """A random (4, 3, 32, 32) batch maps to an output of shape (4, 10)."""
        # Build the model before drawing the input, as the original does.
        net = SimpleMLP()
        batch = torch.randn(4, 3, 32, 32)
        logits = net(batch)
        assert logits.shape == (4, 10)

    def test_has_batchnorm(self) -> None:
        """At least one submodule is a ``torch.nn.BatchNorm1d``."""
        net = SimpleMLP()
        bn_layers = [
            layer
            for layer in net.modules()
            if isinstance(layer, torch.nn.BatchNorm1d)
        ]
        assert len(bn_layers) > 0