"""Test real PyTorch model instantiation and fault injection."""
from __future__ import annotations
import torch
import torch.nn as nn
from ml_training_debugger.pytorch_engine import (
SimpleCNN,
create_model_and_inject_fault,
extract_gradient_stats,
extract_model_modes,
extract_weight_stats,
)
from ml_training_debugger.scenarios import sample_scenario
class TestSimpleCNN:
    """Sanity checks on the bare SimpleCNN architecture (no fault injected)."""

    def test_is_nn_module(self):
        net = SimpleCNN()
        assert isinstance(net, nn.Module)

    def test_param_count(self):
        net = SimpleCNN()
        # Total number of elements across every registered parameter tensor;
        # the architecture is expected to land around ~50K parameters.
        n_params = 0
        for param in net.parameters():
            n_params += param.numel()
        assert 30_000 < n_params < 100_000

    def test_forward_pass(self):
        net = SimpleCNN()
        # presumably (batch, channels, height, width) — a batch of 2 inputs.
        inputs = torch.randn(2, 3, 32, 32)
        output = net(inputs)
        # One output row per input, 10 values per row.
        assert tuple(output.shape) == (2, 10)
class TestFaultInjection:
    """End-to-end checks that create_model_and_inject_fault applies each fault.

    Every test samples its scenario with a fixed ``seed=42``.
    """

    def test_task_001_exploding_gradients(self):
        scenario = sample_scenario("task_001", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        assert len(stats) > 0
        # At least some layers should have elevated gradients; on failure,
        # report every layer's mean norm to make the miss diagnosable.
        assert any(s.mean_norm > 1.0 for s in stats), [
            (s.layer_name, s.mean_norm) for s in stats
        ]

    def test_task_005_eval_mode(self):
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        # The task_005 fault is expected to have called model.eval().
        assert not model.training

    def test_task_005_gradients_not_exploding(self):
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        # Without this guard the loop below would pass vacuously on an
        # empty result.
        assert len(stats) > 0
        # ALL layers must have is_exploding=False.
        for s in stats:
            assert not s.is_exploding, f"Layer {s.layer_name} should not be exploding"
class TestExtractGradientStats:
    """Structural contract of the per-layer records from extract_gradient_stats."""

    def test_returns_gradient_stats(self):
        scenario = sample_scenario("task_001", seed=42)
        model, _fault_info = create_model_and_inject_fault(scenario)
        records = extract_gradient_stats(model, scenario)

        expected_layer_count = 4  # conv1, conv2, conv3, fc
        assert len(records) == expected_layer_count

        for record in records:
            # mean_norm must be a plain Python float (not e.g. a tensor).
            assert isinstance(record.mean_norm, float)
            # norm_history must be a list holding exactly 5 entries.
            history = record.norm_history
            assert isinstance(history, list)
            assert len(history) == 5
class TestExtractWeightStats:
    """Type contract of the per-layer records from extract_weight_stats."""

    def test_returns_weight_stats(self):
        scenario = sample_scenario("task_001", seed=42)
        model, _fault_info = create_model_and_inject_fault(scenario)
        weight_records = extract_weight_stats(model)
        # At least one record must be produced.
        assert len(weight_records) > 0
        for record in weight_records:
            # weight_norm is a plain float; has_nan is a real bool flag.
            assert isinstance(record.weight_norm, float)
            assert isinstance(record.has_nan, bool)
class TestExtractModelModes:
    """extract_model_modes should report the mode set via train()/eval()."""

    @staticmethod
    def _assert_all_modes(model: nn.Module, expected: str) -> None:
        """Assert extract_model_modes(model) is non-empty and every value is expected."""
        modes = extract_model_modes(model)
        # all() is True for an empty mapping, so require at least one entry;
        # otherwise the check below could pass without inspecting anything.
        assert modes, "extract_model_modes returned no entries"
        wrong = {k: v for k, v in modes.items() if v != expected}
        assert not wrong, f"expected every mode to be {expected!r}, got {wrong}"

    def test_train_mode(self):
        model = SimpleCNN()
        model.train()
        self._assert_all_modes(model, "train")

    def test_eval_mode(self):
        model = SimpleCNN()
        model.eval()
        self._assert_all_modes(model, "eval")
class TestTask005RedHerrings:
    """Test Task 5 red herring injection — conv1 near-vanishing, spike layer elevated.

    The spike layer is taken from ``scenario.red_herring_spike_layer``
    (e.g. fc). Every test samples task_005 with a fixed ``seed=42``.
    """

    @staticmethod
    def _layer(stats, name):
        """Return the stats entry whose layer_name equals name, failing clearly if absent."""
        # A bare next() would raise StopIteration, hiding which layer was missing.
        match = next((s for s in stats if s.layer_name == name), None)
        assert match is not None, (
            f"no gradient stats for layer {name!r}; "
            f"got {[s.layer_name for s in stats]}"
        )
        return match

    def test_conv1_near_vanishing_red_herring(self):
        """When the spike layer is not conv1, conv1 should show a near-vanishing gradient."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        conv1 = self._layer(stats, "conv1")
        if scenario.red_herring_spike_layer == "conv1":
            # conv1 carries the spike instead, so the near-vanishing pattern
            # is not expected and there is nothing to check here.
            return
        # conv1 should be near-vanishing but not flagged is_vanishing
        # (original note: 0.0003 > 1e-6 — TODO confirm values in pytorch_engine).
        assert conv1.mean_norm < 0.01
        assert not conv1.is_vanishing

    def test_fc_spike_not_exploding(self):
        """Spike layer has an elevated gradient but is_exploding=False (mean < 10.0)."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        spike_layer = self._layer(stats, scenario.red_herring_spike_layer)
        assert not spike_layer.is_exploding
        # Should have a non-zero norm from the spike.
        assert spike_layer.mean_norm > 0

    def test_all_layers_not_exploding(self):
        """All layers is_exploding=False — this gates gradients_were_normal."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        # Guard against a vacuous pass if no stats are returned.
        assert len(stats) > 0
        for s in stats:
            assert not s.is_exploding, f"{s.layer_name} should not be exploding"
class TestVanishingGradientInjection:
    """Checks on the task_002 (vanishing gradient) scenario after fault injection."""

    @staticmethod
    def _inject_task_002():
        """Sample task_002 with seed=42 and return (scenario, faulted model)."""
        scenario = sample_scenario("task_002", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        return scenario, model

    def test_task_002_vanishing(self):
        scenario, model = self._inject_task_002()
        layer_stats = extract_gradient_stats(model, scenario)
        # At least one layer must be flagged as vanishing.
        assert any(layer.is_vanishing for layer in layer_stats)

    def test_task_002_model_in_train_mode(self):
        _, model = self._inject_task_002()
        # This fault must not switch the model into eval mode.
        assert model.training
class TestCodeBugFaultInjection:
    """Test code bug fault injection (task_006) — the model itself should be normal."""

    def test_task_006_model_trains_normally(self):
        scenario = sample_scenario("task_006", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        assert model.training  # Should be in train mode
        stats = extract_gradient_stats(model, scenario)
        # `not any(...)` is True for an empty list, so require stats first.
        assert len(stats) > 0
        # No exploding gradients — the bug is in code only. Note that only
        # is_exploding is asserted here; vanishing is not checked.
        exploding = [s.layer_name for s in stats if s.is_exploding]
        assert not exploding, f"unexpected exploding layers: {exploding}"
class TestDataLeakageFaultInjection:
    """Test data leakage scenario (task_003) — the model itself should be normal."""

    def test_task_003_normal_model(self):
        scenario = sample_scenario("task_003", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        assert model.training
        stats = extract_gradient_stats(model, scenario)
        # `not any(...)` is True for an empty list, so require stats first.
        assert len(stats) > 0
        exploding = [s.layer_name for s in stats if s.is_exploding]
        assert not exploding, f"unexpected exploding layers: {exploding}"