pytorch-training-debugger / tests /test_real_training.py
omkarrr88
docker size reduced
45eee48
"""Tests for real mini-training in pytorch_engine.py."""
from __future__ import annotations
import torch
from ml_training_debugger.pytorch_engine import (
SimpleCNN,
SimpleMLP,
_TRAINING_CACHE,
run_real_training,
)
from ml_training_debugger.scenarios import sample_scenario
class TestRunRealTraining:
    """End-to-end checks for ``run_real_training`` on sampled scenarios.

    Some tests clear the module-level ``_TRAINING_CACHE`` first so that they
    exercise a fresh training run instead of a previously cached result.
    """

    def test_returns_20_epoch_curves(self) -> None:
        """Every returned curve contains exactly 20 entries."""
        s = sample_scenario("task_001", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20
        assert len(curves["val_loss_history"]) == 20
        assert len(curves["val_acc_history"]) == 20

    def test_all_values_are_floats(self) -> None:
        """All curve entries are Python ``float`` objects."""
        s = sample_scenario("task_003", seed=42)
        curves = run_real_training(s)
        for key in ["loss_history", "val_loss_history", "val_acc_history"]:
            for v in curves[key]:
                assert isinstance(v, float), f"{key} has non-float: {type(v)}"

    def test_caching_works(self) -> None:
        """A repeated call with the same scenario returns the same object."""
        _TRAINING_CACHE.clear()
        s = sample_scenario("task_001", seed=42)
        c1 = run_real_training(s)
        c2 = run_real_training(s)
        assert c1 is c2  # Same object reference = cached

    def test_reproducible_across_calls(self) -> None:
        """Two uncached runs of the same scenario produce identical curves."""
        _TRAINING_CACHE.clear()
        s = sample_scenario("task_002", seed=42)
        c1 = run_real_training(s)
        # Clear again so the second call really retrains instead of hitting
        # the cache.
        _TRAINING_CACHE.clear()
        c2 = run_real_training(s)
        assert c1["loss_history"] == c2["loss_history"]
        assert c1["val_acc_history"] == c2["val_acc_history"]

    def test_different_seeds_different_curves(self) -> None:
        """Scenarios sampled with different seeds yield different losses."""
        s1 = sample_scenario("task_001", seed=42)
        s2 = sample_scenario("task_001", seed=99)
        c1 = run_real_training(s1)
        c2 = run_real_training(s2)
        assert c1["loss_history"] != c2["loss_history"]

    def test_task_001_high_lr_instability(self) -> None:
        """The task_001 (high-LR) run shows a finite loss spike above 3.0.

        Non-finite entries (inf, -inf, NaN) are excluded before taking the
        maximum. Every comparison with NaN is False, so a NaN that ``max``
        sees first would be returned as the maximum. The threshold check
        would then fail even if large finite spikes are present.
        """
        import math

        s = sample_scenario("task_001", seed=42)
        curves = run_real_training(s)
        finite_losses = [v for v in curves["loss_history"] if math.isfinite(v)]
        assert finite_losses, "loss_history contains no finite values"
        assert max(finite_losses) > 3.0  # High LR causes loss spikes

    def test_task_002_vanishing_slow_learning(self) -> None:
        """task_002 smoke test: training yields a 20-entry loss curve."""
        s = sample_scenario("task_002", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_003_data_leakage(self) -> None:
        """task_003 smoke test: training yields a 20-entry val-acc curve."""
        s = sample_scenario("task_003", seed=42)
        curves = run_real_training(s)
        # With leakage, val accuracy may be elevated
        assert len(curves["val_acc_history"]) == 20

    def test_task_004_overfitting(self) -> None:
        """task_004 smoke test: training yields a 20-entry loss curve."""
        s = sample_scenario("task_004", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_005_batchnorm_eval(self) -> None:
        """task_005 smoke test: training yields a 20-entry loss curve."""
        s = sample_scenario("task_005", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_006_code_bug(self) -> None:
        """task_006 smoke test: training yields a 20-entry loss curve."""
        s = sample_scenario("task_006", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_task_007_scheduler(self) -> None:
        """task_007 smoke test: training yields a 20-entry loss curve."""
        s = sample_scenario("task_007", seed=42)
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20

    def test_mlp_architecture(self) -> None:
        """Find a scenario that uses MLP and verify training works.

        Seeds 1-19 of task_001 are searched for an MLP scenario. If none is
        found, an MLP scenario is constructed directly.
        """
        for seed in range(1, 20):
            s = sample_scenario("task_001", seed=seed)
            if s.model_type == "mlp":
                curves = run_real_training(s)
                assert len(curves["loss_history"]) == 20
                return
        # If no MLP found in seeds 1-19, test directly
        from ml_training_debugger.scenarios import ScenarioParams
        from ml_training_debugger.models import RootCauseDiagnosis
        s = ScenarioParams(
            task_id="task_001",
            root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
            seed=999,
            learning_rate=0.1,
            model_type="mlp",
        )
        curves = run_real_training(s)
        assert len(curves["loss_history"]) == 20
class TestSimpleMLP:
    """Structural checks on the ``SimpleMLP`` model."""

    def test_is_nn_module(self) -> None:
        """A freshly built SimpleMLP is a ``torch.nn.Module`` instance."""
        net = SimpleMLP()
        assert isinstance(net, torch.nn.Module)

    def test_param_count(self) -> None:
        """Total parameter count lies strictly between 10k and 500k."""
        net = SimpleMLP()
        total_params = 0
        for param in net.parameters():
            total_params += param.numel()
        assert total_params > 10_000 and total_params < 500_000

    def test_forward_pass(self) -> None:
        """A (4, 3, 32, 32) input batch maps to (4, 10) outputs."""
        # Build the model before drawing the input so the RNG is consumed
        # in the same order as before.
        net = SimpleMLP()
        batch = torch.randn(4, 3, 32, 32)
        logits = net(batch)
        assert logits.shape == (4, 10)

    def test_has_batchnorm(self) -> None:
        """At least one submodule is a ``torch.nn.BatchNorm1d`` layer."""
        net = SimpleMLP()
        bn_layers = [
            layer
            for layer in net.modules()
            if isinstance(layer, torch.nn.BatchNorm1d)
        ]
        assert len(bn_layers) > 0