"""
Unit tests for BitLinear and MultiTernaryLinear layers.

These tests validate the nn.Module implementations and their compatibility
with standard PyTorch workflows. The test cases are:

TestBitLinear (8 tests)
    1. test_initialization - Verifies the layer initializes with correct shapes
    2. test_no_bias_initialization - Tests initialization without a bias parameter
    3. test_forward_shape - Validates output shape correctness
    4. test_compatibility_with_nn_linear - Tests interface compatibility with nn.Linear
    5. test_from_linear_conversion - Verifies conversion from nn.Linear to BitLinear
    6. test_parameter_count - Validates the parameter count
    7. test_weight_values_are_ternary - Ensures weights are in {-1, 0, +1}
    8. test_gradient_flow - Tests gradient flow for QAT support

TestMultiTernaryLinear (5 tests)
    1. test_initialization - Verifies k-component initialization
    2. test_forward_shape - Tests forward pass output shape
    3. test_k_components - Validates k-component tensor shapes
    4. test_from_linear_conversion - Tests conversion with the k parameter
    5. test_better_approximation_with_more_k - Validates that error decreases with larger k

TestConversionUtilities (3 tests)
    1. test_convert_simple_model - Tests conversion of Sequential models
    2. test_convert_nested_model - Tests conversion of nested module hierarchies
    3. test_inplace_conversion - Tests in-place vs. copy conversion modes

TestLayerIntegration (3 tests)
    1. test_in_transformer_block - Tests BitLinear in a Transformer FFN block
    2. test_training_step - Validates full training loop compatibility
    3. test_save_and_load - Tests model serialization and deserialization

TestPerformanceComparison (2 tests - skipped)
    1. test_memory_usage - Performance benchmark (run manually)
    2. test_inference_speed - Performance benchmark (run manually)
"""

import pytest
import torch
import torch.nn as nn

from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear

# The only values a ternary weight tensor may contain.
_TERNARY_VALUES = {-1.0, 0.0, 1.0}


def _assert_ternary(weights: torch.Tensor) -> None:
    """Assert that every element of ``weights`` is -1, 0, or +1."""
    unique_values = set(torch.unique(weights).tolist())
    assert unique_values.issubset(_TERNARY_VALUES), (
        f"non-ternary values found: {sorted(unique_values - _TERNARY_VALUES)}"
    )


class TestBitLinear:
    """Tests for BitLinear layer."""

    def test_initialization(self):
        """Test that layer initializes correctly."""
        layer = BitLinear(512, 1024)
        assert layer.in_features == 512
        assert layer.out_features == 1024
        assert layer.bias is not None
        assert layer.W_ternary.shape == (1024, 512)
        assert layer.gamma.shape == (1024,)

    def test_no_bias_initialization(self):
        """Test initialization without bias."""
        layer = BitLinear(512, 1024, bias=False)
        assert layer.bias is None

    def test_forward_shape(self):
        """Test forward pass produces correct output shape."""
        layer = BitLinear(512, 1024)
        x = torch.randn(32, 128, 512)
        output = layer(x)
        assert output.shape == (32, 128, 1024)

    def test_compatibility_with_nn_linear(self):
        """Test that BitLinear can replace nn.Linear in terms of interface."""
        linear = nn.Linear(512, 512)
        bitlinear = BitLinear(512, 512)
        x = torch.randn(32, 512)
        out_linear = linear(x)
        out_bitlinear = bitlinear(x)
        # Shapes should match (values will differ due to quantization).
        assert out_linear.shape == out_bitlinear.shape

    def test_from_linear_conversion(self):
        """Test converting nn.Linear to BitLinear."""
        linear = nn.Linear(512, 1024)
        bitlinear = BitLinear.from_linear(linear)
        assert bitlinear.in_features == 512
        assert bitlinear.out_features == 1024

        # The converted layer must still run a forward pass.
        x = torch.randn(16, 512)
        output = bitlinear(x)
        assert output.shape == (16, 1024)

    def test_parameter_count(self):
        """Test that parameter count is correct."""
        layer = BitLinear(512, 512, bias=True)
        # W_ternary: 512*512, gamma: 512, bias: 512
        expected_params = 512 * 512 + 512 + 512
        actual_params = sum(p.numel() for p in layer.parameters())
        assert actual_params == expected_params

    def test_weight_values_are_ternary(self):
        """Test that stored weights are ternary {-1, 0, +1}."""
        layer = BitLinear(512, 512)
        _assert_ternary(layer.W_ternary)

    def test_gradient_flow(self):
        """Test that gradients flow correctly (for QAT)."""
        layer = BitLinear(256, 128)
        x = torch.randn(8, 256, requires_grad=True)
        output = layer(x)
        loss = output.sum()
        loss.backward()

        # The input must receive gradients...
        assert x.grad is not None
        # ...and so must the trainable parameters.
        assert layer.W_ternary.grad is not None
        assert layer.gamma.grad is not None


class TestMultiTernaryLinear:
    """Tests for MultiTernaryLinear layer."""

    def test_initialization(self):
        """Test layer initialization with k components."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        assert layer.in_features == 512
        assert layer.out_features == 1024
        assert layer.k == 4
        assert layer.W_ternary.shape == (4, 1024, 512)
        assert layer.gammas.shape == (4, 1024)

    def test_forward_shape(self):
        """Test forward pass shape."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        x = torch.randn(32, 128, 512)
        output = layer(x)
        assert output.shape == (32, 128, 1024)

    def test_k_components(self):
        """Test that layer uses k ternary components."""
        layer = MultiTernaryLinear(512, 512, k=3)
        assert layer.W_ternary.shape == (3, 512, 512)
        assert layer.gammas.shape == (3, 512)

    def test_from_linear_conversion(self):
        """Test converting nn.Linear to MultiTernaryLinear."""
        linear = nn.Linear(512, 1024)
        multi_ternary = MultiTernaryLinear.from_linear(linear, k=4)
        assert multi_ternary.k == 4
        assert multi_ternary.in_features == 512
        assert multi_ternary.out_features == 1024

    def test_better_approximation_with_more_k(self):
        """Test that larger k provides a better approximation of the dense layer."""
        # Seed the RNG so the dense weights and inputs (and therefore the strict
        # ordering asserted below) are reproducible instead of flaky.
        torch.manual_seed(0)
        linear = nn.Linear(512, 512)
        x = torch.randn(16, 512)
        with torch.no_grad():
            out_dense = linear(x)

        # Norm of the output difference for increasing numbers of components.
        errors = []
        for k in (1, 2, 4):
            multi_ternary = MultiTernaryLinear.from_linear(linear, k=k)
            with torch.no_grad():
                out_ternary = multi_ternary(x)
            errors.append(torch.norm(out_dense - out_ternary).item())

        # Error should decrease as k grows.
        assert errors[0] > errors[1] > errors[2], f"errors for k=1, 2, 4: {errors}"


class TestConversionUtilities:
    """Tests for model conversion utilities."""

    def test_convert_simple_model(self):
        """Test converting a simple Sequential model."""
        model = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
        )
        model_bitlinear = convert_linear_to_bitlinear(model, inplace=False)

        # Linear layers are replaced; other modules are left alone.
        assert isinstance(model_bitlinear[0], BitLinear)
        assert isinstance(model_bitlinear[2], BitLinear)
        assert isinstance(model_bitlinear[1], nn.ReLU)

    def test_convert_nested_model(self):
        """Test converting a nested model with submodules."""

        class NestedModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.Linear(256, 512)
                self.submodule = nn.Sequential(
                    nn.Linear(512, 512),
                    nn.ReLU(),
                )
                self.layer2 = nn.Linear(512, 128)

        model = NestedModel()
        model_bitlinear = convert_linear_to_bitlinear(model, inplace=False)

        # Conversion must recurse into child modules.
        assert isinstance(model_bitlinear.layer1, BitLinear)
        assert isinstance(model_bitlinear.submodule[0], BitLinear)
        assert isinstance(model_bitlinear.layer2, BitLinear)

    def test_inplace_conversion(self):
        """Test in-place vs. copy conversion."""
        model = nn.Sequential(nn.Linear(256, 256))

        # inplace=False must return a converted copy and leave the original untouched.
        model_copy = convert_linear_to_bitlinear(model, inplace=False)
        assert model_copy is not model
        assert isinstance(model[0], nn.Linear)  # Original unchanged
        assert isinstance(model_copy[0], BitLinear)  # Copy converted

        # inplace=True must modify and return the very same object.
        model2 = nn.Sequential(nn.Linear(256, 256))
        model2_result = convert_linear_to_bitlinear(model2, inplace=True)
        assert model2_result is model2
        assert isinstance(model2[0], BitLinear)  # Original modified


class TestLayerIntegration:
    """Integration tests for layers in realistic scenarios."""

    def test_in_transformer_block(self):
        """Test BitLinear in a Transformer feed-forward (FFN) block."""

        class TransformerFFN(nn.Module):
            """Simplified Transformer FFN: fc1 -> ReLU -> fc2 -> dropout."""

            def __init__(self, d_model=256, d_ff=1024):
                super().__init__()
                self.fc1 = BitLinear(d_model, d_ff)
                self.relu = nn.ReLU()
                self.fc2 = BitLinear(d_ff, d_model)
                self.dropout = nn.Dropout(0.1)

            def forward(self, x):
                return self.dropout(self.fc2(self.relu(self.fc1(x))))

        model = TransformerFFN()

        batch_size, seq_len, d_model = 8, 32, 256
        x = torch.randn(batch_size, seq_len, d_model)
        output = model(x)

        assert output.shape == (batch_size, seq_len, d_model)
        _assert_ternary(model.fc1.W_ternary)
        _assert_ternary(model.fc2.W_ternary)

    def test_training_step(self):
        """Test that layers work in a training loop."""
        model = nn.Sequential(
            BitLinear(128, 256),
            nn.ReLU(),
            BitLinear(256, 10),
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        x = torch.randn(16, 128)
        output = model(x)

        target = torch.randint(0, 10, (16,))
        loss = nn.functional.cross_entropy(output, target)

        optimizer.zero_grad()
        loss.backward()

        # Both BitLinear layers must receive gradients.
        for layer in (model[0], model[2]):
            assert layer.W_ternary.grad is not None
            assert layer.gamma.grad is not None

        optimizer.step()

        assert torch.isfinite(loss)

    def test_save_and_load(self, tmp_path):
        """Test saving and loading models with BitLinear layers.

        Uses pytest's ``tmp_path`` fixture so pytest owns the checkpoint's
        lifecycle; the previous NamedTemporaryFile/unlink pair leaked the file
        when ``torch.save`` raised, because the save sat outside try/finally.
        """

        def build_model():
            return nn.Sequential(
                BitLinear(128, 256),
                nn.ReLU(),
                BitLinear(256, 64),
            )

        model = build_model()
        checkpoint = tmp_path / "model.pt"
        torch.save(model.state_dict(), checkpoint)

        # A freshly built model must match the original after loading the weights.
        model_loaded = build_model()
        model_loaded.load_state_dict(torch.load(checkpoint))

        for idx in (0, 2):  # indices of the BitLinear layers
            assert torch.allclose(model[idx].W_ternary, model_loaded[idx].W_ternary)
            assert torch.allclose(model[idx].gamma, model_loaded[idx].gamma)

        # Identical weights must produce identical outputs.
        x = torch.randn(8, 128)
        with torch.no_grad():
            out1 = model(x)
            out2 = model_loaded(x)
        assert torch.allclose(out1, out2)


class TestPerformanceComparison:
    """Tests comparing BitLinear to standard nn.Linear."""

    @pytest.mark.skip(reason="Performance test - run manually")
    def test_memory_usage(self):
        """Compare memory usage of BitLinear vs. nn.Linear."""
        # TODO: Measure memory for large layers; BitLinear should use
        # significantly less memory than nn.Linear.
        pass

    @pytest.mark.skip(reason="Performance test - run manually")
    def test_inference_speed(self):
        """Compare inference speed (when CUDA kernels are implemented)."""
        # TODO: Implement test.
        pass