"""Unit tests for optimizer utilities"""
import pytest
import torch
import torch.nn as nn
from src.training.optim import build_optimizer
from argparse import Namespace


class TestBuildOptimizer:
    """Test optimizer builder"""

    @pytest.fixture
    def simple_model(self):
        """Simple model for testing"""
        return nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 10),
        )

    @pytest.fixture
    def args(self):
        """Default training arguments"""
        return Namespace(
            learning_rate=1e-3,
            weight_decay=0.01,
            adam_beta1=0.9,
            adam_beta2=0.999,
            optimizer='adamw',
        )

    def test_adamw_optimizer(self, simple_model, args):
        """Test AdamW optimizer creation"""
        optimizer = build_optimizer(simple_model, args)
        assert isinstance(optimizer, torch.optim.AdamW)
        assert optimizer.defaults['lr'] == 1e-3
        assert optimizer.defaults['weight_decay'] == 0.01
        assert optimizer.defaults['betas'] == (0.9, 0.999)

    def test_parameter_groups(self, args):
        """Test parameter grouping (no decay for bias/norm)"""
        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.LayerNorm(50),
            nn.Linear(50, 10),
        )
        optimizer = build_optimizer(model, args)
        # Should have 2 parameter groups: decay and no_decay
        assert len(optimizer.param_groups) == 2
        # Check weight decay settings: one group decays, the other doesn't,
        # regardless of the order in which the groups are returned
        decays = sorted(group['weight_decay'] for group in optimizer.param_groups)
        assert decays == [0.0, 0.01]
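
    def test_bias_and_norm_in_no_decay_group(self, args):
        """Sketch: bias and LayerNorm parameters should land in the no-decay group.

        Assumes build_optimizer implements the grouping rule in the docstring
        above (no decay for bias/norm); otherwise relies only on public
        PyTorch APIs.
        """
        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.LayerNorm(50),
        )
        optimizer = build_optimizer(model, args)
        no_decay_group = next(
            g for g in optimizer.param_groups if g['weight_decay'] == 0.0
        )
        no_decay_ids = {id(p) for p in no_decay_group['params']}
        for name, param in model.named_parameters():
            # Module index 1 is the LayerNorm, so '1.*' covers norm params
            if name.endswith('.bias') or name.startswith('1.'):
                assert id(param) in no_decay_ids, f"{name} should not be decayed"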

    def test_zero_weight_decay(self, simple_model, args):
        """Test with zero weight decay"""
        args.weight_decay = 0.0
        optimizer = build_optimizer(simple_model, args)
        # Grouping is an implementation detail here; with no decay the builder
        # may collapse to a single group, so only require that groups exist
        # and that none of them applies weight decay
        assert len(optimizer.param_groups) >= 1
        assert all(g['weight_decay'] == 0.0 for g in optimizer.param_groups)

    def test_optimizer_step(self, simple_model, args):
        """Test that optimizer can step"""
        optimizer = build_optimizer(simple_model, args)
        # Forward pass
        x = torch.randn(4, 100)
        output = simple_model(x)
        loss = output.sum()
        # Backward
        loss.backward()
        # Snapshot a parameter before the step
        first_param_before = next(simple_model.parameters()).clone()
        # Optimizer step
        optimizer.step()
        # Parameter should change
        first_param_after = next(simple_model.parameters())
        assert not torch.equal(first_param_before, first_param_after)
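
    def test_all_parameters_registered(self, simple_model, args):
        """Sketch: every model parameter should appear in exactly one param group.

        Assumes only that build_optimizer(model, args) returns a standard
        torch.optim optimizer built over the model's parameters.
        """
        optimizer = build_optimizer(simple_model, args)
        registered = [id(p) for g in optimizer.param_groups for p in g['params']]
        model_ids = {id(p) for p in simple_model.parameters()}
        # No parameter is duplicated across groups...
        assert len(registered) == len(set(registered))
        # ...and the groups cover the model exactly
        assert set(registered) == model_ids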


class TestOptimizerIntegration:
    """Integration tests for optimizer"""

    def test_training_loop(self):
        """Test optimizer in simple training loop"""
        torch.manual_seed(0)  # keep the random regression data deterministic
        model = nn.Linear(10, 1)
        args = Namespace(
            learning_rate=0.01,
            weight_decay=0.0,
            adam_beta1=0.9,
            adam_beta2=0.999,
            optimizer='adamw',
        )
        optimizer = build_optimizer(model, args)
        # Training data
        x = torch.randn(100, 10)
        y = torch.randn(100, 1)
        loss_fn = nn.MSELoss()
        initial_loss = None
        final_loss = None
        # Train for a few steps
        for i in range(10):
            optimizer.zero_grad()
            pred = model(x)
            loss = loss_fn(pred, y)
            if i == 0:
                initial_loss = loss.item()
            if i == 9:
                final_loss = loss.item()
            loss.backward()
            optimizer.step()
        # Loss should decrease (model should learn)
        assert final_loss < initial_loss
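
    def test_state_dict_roundtrip(self):
        """Sketch: optimizer state should survive a state_dict save/load cycle.

        Uses only standard torch.optim APIs; assumes nothing about
        build_optimizer beyond the signature exercised above.
        """
        model = nn.Linear(10, 1)
        args = Namespace(
            learning_rate=0.01,
            weight_decay=0.0,
            adam_beta1=0.9,
            adam_beta2=0.999,
            optimizer='adamw',
        )
        optimizer = build_optimizer(model, args)
        # Take one step so AdamW materializes per-parameter state (exp_avg etc.)
        model(torch.randn(4, 10)).sum().backward()
        optimizer.step()
        state = optimizer.state_dict()
        # A fresh optimizer over the same model should accept the saved state
        restored = build_optimizer(model, args)
        restored.load_state_dict(state)
        assert restored.state_dict()['param_groups'] == state['param_groups']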


if __name__ == "__main__":
    pytest.main([__file__, "-v"])