File size: 6,337 Bytes
e2f8b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f58e42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""Test real PyTorch model instantiation and fault injection."""

from __future__ import annotations

import torch
import torch.nn as nn

from ml_training_debugger.pytorch_engine import (
    SimpleCNN,
    create_model_and_inject_fault,
    extract_gradient_stats,
    extract_model_modes,
    extract_weight_stats,
)
from ml_training_debugger.scenarios import sample_scenario


class TestSimpleCNN:
    """Sanity checks on a freshly constructed ``SimpleCNN``."""

    def test_is_nn_module(self):
        """A new ``SimpleCNN`` instance is an ``nn.Module``."""
        net = SimpleCNN()
        assert isinstance(net, nn.Module)

    def test_param_count(self):
        """Total parameter element count lies strictly between 30K and 100K."""
        net = SimpleCNN()
        total_params = 0
        for param in net.parameters():
            total_params += param.numel()
        # ~50K params
        assert 30_000 < total_params
        assert total_params < 100_000

    def test_forward_pass(self):
        """A random ``(2, 3, 32, 32)`` input yields an output of shape ``(2, 10)``."""
        net = SimpleCNN()
        batch = torch.randn(2, 3, 32, 32)
        output = net(batch)
        expected_shape = (2, 10)
        assert output.shape == expected_shape


class TestFaultInjection:
    """Check that injected faults are visible on the model and its gradient stats."""

    def test_task_001_exploding_gradients(self):
        """task_001 yields at least one layer whose ``mean_norm`` exceeds 1.0."""
        scenario = sample_scenario("task_001", seed=42)
        # The second return value is not used by this test.
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        assert len(stats) > 0
        # At least some layers should have elevated gradients
        any_high = any(s.mean_norm > 1.0 for s in stats)
        assert any_high

    def test_task_005_eval_mode(self):
        """task_005 returns a model whose ``training`` flag is False."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        # The model is expected to have been left in eval mode.
        assert not model.training

    def test_task_005_gradients_not_exploding(self):
        """No layer reported for task_005 has ``is_exploding`` set."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        # Guard against a vacuous pass: with no stats the loop below would
        # never execute a single assertion.
        assert len(stats) > 0, "expected gradient stats for at least one layer"
        # ALL layers must have is_exploding=False
        for s in stats:
            assert not s.is_exploding, f"Layer {s.layer_name} should not be exploding"


class TestExtractGradientStats:
    """Count and type checks on the output of ``extract_gradient_stats``."""

    def test_returns_gradient_stats(self):
        """task_001 yields four entries with a float mean norm and a 5-item history list."""
        task_scenario = sample_scenario("task_001", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        layer_stats = extract_gradient_stats(faulty_model, task_scenario)
        # conv1, conv2, conv3, fc
        expected_layers = 4
        assert len(layer_stats) == expected_layers
        for entry in layer_stats:
            assert isinstance(entry.mean_norm, float)
            assert isinstance(entry.norm_history, list)
            assert len(entry.norm_history) == 5


class TestExtractWeightStats:
    """Type checks on the output of ``extract_weight_stats``."""

    def test_returns_weight_stats(self):
        """task_001 yields a non-empty result with a float norm and a bool NaN flag per entry."""
        task_scenario = sample_scenario("task_001", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        weight_stats = extract_weight_stats(faulty_model)
        assert len(weight_stats) > 0
        for entry in weight_stats:
            assert isinstance(entry.weight_norm, float)
            assert isinstance(entry.has_nan, bool)


class TestExtractModelModes:
    """Check the mode strings reported by ``extract_model_modes``."""

    def test_train_mode(self):
        """After ``train()``, every reported mode value equals ``"train"``."""
        net = SimpleCNN()
        net.train()
        reported = extract_model_modes(net)
        for mode in reported.values():
            assert mode == "train"

    def test_eval_mode(self):
        """After ``eval()``, every reported mode value equals ``"eval"``."""
        net = SimpleCNN()
        net.eval()
        reported = extract_model_modes(net)
        for mode in reported.values():
            assert mode == "eval"


class TestTask005RedHerrings:
    """Test Task 5 red herring injection — conv1 near-vanishing, FC spike."""

    def test_conv1_near_vanishing_red_herring(self):
        """Unless conv1 is the spike layer, conv1 has a tiny but not vanishing norm."""
        task_scenario = sample_scenario("task_005", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        layer_stats = extract_gradient_stats(faulty_model, task_scenario)

        conv1_stats = next(
            entry for entry in layer_stats if entry.layer_name == "conv1"
        )
        # Nothing is asserted when conv1 itself is the red-herring spike layer.
        if task_scenario.red_herring_spike_layer == "conv1":
            return
        # conv1 should be near-vanishing (but not is_vanishing since 0.0003 > 1e-6)
        assert conv1_stats.mean_norm < 0.01
        assert not conv1_stats.is_vanishing  # 0.0003 > 1e-6

    def test_fc_spike_not_exploding(self):
        """The spike layer has a positive norm but is not flagged as exploding."""
        task_scenario = sample_scenario("task_005", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        layer_stats = extract_gradient_stats(faulty_model, task_scenario)

        spike_stats = next(
            entry
            for entry in layer_stats
            if entry.layer_name == task_scenario.red_herring_spike_layer
        )
        assert not spike_stats.is_exploding
        # Should have non-trivial norm from the spike
        assert spike_stats.mean_norm > 0

    def test_all_layers_not_exploding(self):
        """All layers is_exploding=False — this gates gradients_were_normal."""
        task_scenario = sample_scenario("task_005", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        layer_stats = extract_gradient_stats(faulty_model, task_scenario)
        for entry in layer_stats:
            assert not entry.is_exploding, f"{entry.layer_name} should not be exploding"


class TestVanishingGradientInjection:
    """Test vanishing gradient fault injection produces correct stats."""

    def test_task_002_vanishing(self):
        """task_002 flags at least one layer as vanishing."""
        task_scenario = sample_scenario("task_002", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        layer_stats = extract_gradient_stats(faulty_model, task_scenario)
        # Deeper layers should have vanishing gradients
        found_vanishing = any(entry.is_vanishing for entry in layer_stats)
        assert found_vanishing

    def test_task_002_model_in_train_mode(self):
        """task_002 returns a model whose ``training`` flag is set."""
        task_scenario = sample_scenario("task_002", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        assert faulty_model.training


class TestCodeBugFaultInjection:
    """Test code bug fault injection — model should be normal."""

    def test_task_006_model_trains_normally(self):
        """task_006 returns a model in train mode with no layer flagged as exploding."""
        task_scenario = sample_scenario("task_006", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        # Should be in train mode
        assert faulty_model.training
        layer_stats = extract_gradient_stats(faulty_model, task_scenario)
        # No exploding/vanishing — bug is in code only
        none_exploding = all(not entry.is_exploding for entry in layer_stats)
        assert none_exploding


class TestDataLeakageFaultInjection:
    """Test data leakage scenario — model should be normal."""

    def test_task_003_normal_model(self):
        """task_003 returns a model in train mode with no layer flagged as exploding."""
        task_scenario = sample_scenario("task_003", seed=42)
        faulty_model, _ = create_model_and_inject_fault(task_scenario)
        assert faulty_model.training
        layer_stats = extract_gradient_stats(faulty_model, task_scenario)
        none_exploding = all(not entry.is_exploding for entry in layer_stats)
        assert none_exploding