"""Unit tests for optimizer utilities"""
import pytest
import torch
import torch.nn as nn
from src.training.optim import build_optimizer
from argparse import Namespace
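
# These tests assume a builder with roughly the following interface (a sketch of
# the assumed contract, not the actual code in src/training/optim.py):
#
#   def build_optimizer(model: nn.Module, args: Namespace) -> torch.optim.Optimizer:
#       """Return an AdamW optimizer configured from args.learning_rate,
#       args.weight_decay, and args.adam_beta1/args.adam_beta2, with parameters
#       split into a weight-decay group and a no-decay group (biases and
#       normalization parameters)."""
#
# If the real signature or grouping rule differs, the fixtures and assertions
# below need to be adjusted accordingly.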


class TestBuildOptimizer:
    """Test optimizer builder"""

    @pytest.fixture
    def simple_model(self):
        """Simple model for testing"""
        return nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 10)
        )

    @pytest.fixture
    def args(self):
        """Default training arguments"""
        return Namespace(
            learning_rate=1e-3,
            weight_decay=0.01,
            adam_beta1=0.9,
            adam_beta2=0.999,
            optimizer='adamw'
        )

    def test_adamw_optimizer(self, simple_model, args):
        """Test AdamW optimizer creation"""
        optimizer = build_optimizer(simple_model, args)
        assert isinstance(optimizer, torch.optim.AdamW)
        assert optimizer.defaults['lr'] == 1e-3
        assert optimizer.defaults['weight_decay'] == 0.01
        assert optimizer.defaults['betas'] == (0.9, 0.999)
    def test_parameter_groups(self, args):
        """Test parameter grouping (no decay for bias/norm)"""
        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.LayerNorm(50),
            nn.Linear(50, 10)
        )
        optimizer = build_optimizer(model, args)

        # Should have 2 parameter groups: decay and no_decay
        assert len(optimizer.param_groups) == 2

        # Check weight decay settings
        decay_group = optimizer.param_groups[0]
        no_decay_group = optimizer.param_groups[1]

        # One group should carry weight decay, the other should not
        assert (decay_group['weight_decay'] == 0.01 and no_decay_group['weight_decay'] == 0.0) or \
               (decay_group['weight_decay'] == 0.0 and no_decay_group['weight_decay'] == 0.01)
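
    # For the model above, the conventional split (an assumption about the
    # grouping rule, not something this test pins down) would be:
    #   decay:    '0.weight', '2.weight'                     (Linear weights)
    #   no_decay: '0.bias', '1.weight', '1.bias', '2.bias'   (biases + LayerNorm)
    # Only the weight-decay values are asserted, so either ordering of the two
    # groups is accepted.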
    def test_zero_weight_decay(self, simple_model, args):
        """Test with zero weight decay"""
        args.weight_decay = 0.0
        optimizer = build_optimizer(simple_model, args)

        # The builder may collapse everything into a single group when there is
        # no decay, so only require that at least one valid group exists
        assert len(optimizer.param_groups) >= 1
        # Regardless of grouping, no group should apply weight decay
        assert all(group['weight_decay'] == 0.0 for group in optimizer.param_groups)
    def test_optimizer_step(self, simple_model, args):
        """Test that optimizer can step"""
        optimizer = build_optimizer(simple_model, args)

        # Forward pass
        x = torch.randn(4, 100)
        output = simple_model(x)
        loss = output.sum()

        # Backward
        loss.backward()

        # Get parameter before step
        first_param_before = list(simple_model.parameters())[0].clone()

        # Optimizer step
        optimizer.step()

        # Parameter should change
        first_param_after = list(simple_model.parameters())[0]
        assert not torch.equal(first_param_before, first_param_after)


class TestOptimizerIntegration:
    """Integration tests for optimizer"""

    def test_training_loop(self):
        """Test optimizer in simple training loop"""
        model = nn.Linear(10, 1)
        args = Namespace(
            learning_rate=0.01,
            weight_decay=0.0,
            adam_beta1=0.9,
            adam_beta2=0.999,
            optimizer='adamw'
        )
        optimizer = build_optimizer(model, args)

        # Training data
        x = torch.randn(100, 10)
        y = torch.randn(100, 1)

        initial_loss = None
        final_loss = None

        # Train for a few steps
        for i in range(10):
            optimizer.zero_grad()
            pred = model(x)
            loss = nn.MSELoss()(pred, y)
            if i == 0:
                initial_loss = loss.item()
            if i == 9:
                final_loss = loss.item()
            loss.backward()
            optimizer.step()

        # Loss should decrease (model should learn)
        assert final_loss < initial_loss
if __name__ == "__main__":
pytest.main([__file__, "-v"])