""" Unit tests for functional API (bitlinear_python, greedy_ternary_decomposition, etc.) These tests are here to validate the correctness of the pure PyTorch reference implementations. Here are the following test cases: TestBitLinearPython (5 tests) 1. test_shape_correctness - Verifies output dimensions for 3D inputs 2. test_no_bias - Tests forward pass without bias term 3. test_ternary_constraint - Validates ternary weight values {-1, 0, +1} 4. test_gamma_scaling - Verifies gamma scaling is applied correctly 5. test_numerical_correctness - Compares against manual torch computation TestGreedyTernaryDecomposition (4 tests) 1. test_decomposition_shape - Checks output tensor shapes 2. test_ternary_values - Ensures all decomposed weights are ternary 3. test_reconstruction_error - Validates error decreases with more components 4. test_single_component - Tests k=1 edge case TestMultiTernaryLinearPython (2 tests) 1. test_shape_correctness - Verifies output shape 2. test_equivalence_to_sum - Confirms equivalence to summing individual operations TestActivationQuant (2 tests) 1. test_quantization_range - Validates quantization behavior and output 2. test_absmax_scaling - Tests per-token absmax scaling TestFunctionalIntegration (3 tests) 1. test_full_pipeline - End-to-end: decomposition → multi-ternary forward 2. test_bitlinear_with_activation_quant - Combines activation quantization with bitlinear 3. test_multi_ternary_end_to_end - Tests different k values with reconstruction validation """ import pytest import torch import torch.nn as nn from bitlinear.functional import ( bitlinear_python, greedy_ternary_decomposition, multi_ternary_linear_python, activation_quant, ) class TestBitLinearPython: """Tests for bitlinear_python function.""" def test_shape_correctness(self): """Test that output shape matches expected dimensions.""" batch_size, seq_len, in_features, out_features = 32, 128, 512, 1024 x = torch.randn(batch_size, seq_len, in_features) W_ternary = torch.randint(-1, 2, (out_features, in_features)).float() gamma = torch.ones(out_features) bias = torch.zeros(out_features) output = bitlinear_python(x, W_ternary, gamma, bias) assert output.shape == (batch_size, seq_len, out_features) def test_no_bias(self): """Test forward pass without bias.""" batch_size, in_features, out_features = 16, 256, 512 x = torch.randn(batch_size, in_features) W_ternary = torch.randint(-1, 2, (out_features, in_features)).float() gamma = torch.ones(out_features) output = bitlinear_python(x, W_ternary, gamma, bias=None) assert output.shape == (batch_size, out_features) assert not torch.isnan(output).any() def test_ternary_constraint(self): """Test that function works correctly with ternary weights {-1, 0, +1}.""" x = torch.randn(8, 64) W_ternary = torch.randint(-1, 2, (128, 64)).float() gamma = torch.ones(128) # Verify W_ternary contains only {-1, 0, +1} unique_values = torch.unique(W_ternary) assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist()) # Check output correctness output = bitlinear_python(x, W_ternary, gamma) assert output.shape == (8, 128) assert not torch.isnan(output).any() def test_gamma_scaling(self): """Test that gamma scaling is applied correctly.""" x = torch.randn(4, 32) W_ternary = torch.randint(-1, 2, (64, 32)).float() gamma = torch.rand(64) * 2 + 0.5 # Random scales between 0.5 and 2.5 # Compute output with gamma output_with_gamma = bitlinear_python(x, W_ternary, gamma, bias=None) # Compute output with gamma=1 and manually scale gamma_ones = torch.ones_like(gamma) output_no_gamma = bitlinear_python(x, W_ternary, gamma_ones, bias=None) output_manual_scale = output_no_gamma * gamma.unsqueeze(0) # Should be equivalent assert torch.allclose(output_with_gamma, output_manual_scale, atol=1e-5) def test_numerical_correctness(self): """Test numerical correctness against standard nn.Linear.""" in_features, out_features = 128, 256 x = torch.randn(16, in_features) W_ternary = torch.randint(-1, 2, (out_features, in_features)).float() gamma = torch.ones(out_features) bias = torch.randn(out_features) # Compute with bitlinear_python output_bitlinear = bitlinear_python(x, W_ternary, gamma, bias) # Compute manually with torch operations output_manual = torch.matmul(x, W_ternary.t()) * gamma.unsqueeze(0) + bias # Should match exactly assert torch.allclose(output_bitlinear, output_manual, atol=1e-6) class TestGreedyTernaryDecomposition: """Tests for greedy_ternary_decomposition function.""" def test_decomposition_shape(self): """Test that decomposition returns correct shapes.""" W = torch.randn(512, 768) k = 4 W_ternary, gammas = greedy_ternary_decomposition(W, k) assert W_ternary.shape == (k, 512, 768) assert gammas.shape == (k, 512) def test_ternary_values(self): """Test that decomposed weights are ternary.""" W = torch.randn(64, 128) k = 2 W_ternary, gammas = greedy_ternary_decomposition(W, k) # Verify all values in W_ternary are in {-1, 0, +1} unique_values = torch.unique(W_ternary) assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist()), \ f"Found non-ternary values: {unique_values.tolist()}" def test_reconstruction_error(self): """Test that reconstruction error decreases with more components.""" W = torch.randn(128, 256) errors = [] for k in [1, 2, 4, 8]: W_ternary, gammas = greedy_ternary_decomposition(W, k) # Reconstruct: sum of gamma_i * W_i reconstruction = torch.zeros_like(W) for i in range(k): reconstruction += gammas[i].unsqueeze(1) * W_ternary[i] error = torch.norm(W - reconstruction).item() errors.append(error) # Error should decrease with more components assert errors[0] > errors[1], f"Error not decreasing: {errors[0]} vs {errors[1]}" assert errors[1] > errors[2], f"Error not decreasing: {errors[1]} vs {errors[2]}" assert errors[2] > errors[3], f"Error not decreasing: {errors[2]} vs {errors[3]}" def test_single_component(self): """Test k=1 case (single ternary quantization).""" W = torch.randn(32, 64) k = 1 W_ternary, gammas = greedy_ternary_decomposition(W, k) assert W_ternary.shape == (1, 32, 64) assert gammas.shape == (1, 32) # Verify ternary values unique_values = torch.unique(W_ternary) assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist()) class TestMultiTernaryLinearPython: """Tests for multi_ternary_linear_python function.""" def test_shape_correctness(self): """Test output shape for multi-ternary linear.""" batch_size, in_features, out_features = 16, 128, 256 k = 4 x = torch.randn(batch_size, in_features) W_ternary = torch.randint(-1, 2, (k, out_features, in_features)).float() gammas = torch.rand(k, out_features) bias = torch.randn(out_features) output = multi_ternary_linear_python(x, W_ternary, gammas, bias) assert output.shape == (batch_size, out_features) def test_equivalence_to_sum(self): """Test that multi-ternary equals sum of individual ternary ops.""" batch_size, in_features, out_features = 8, 64, 128 k = 3 x = torch.randn(batch_size, in_features) W_ternary = torch.randint(-1, 2, (k, out_features, in_features)).float() gammas = torch.rand(k, out_features) bias = torch.randn(out_features) # Compute multi-ternary in one call output_multi = multi_ternary_linear_python(x, W_ternary, gammas, bias) # Compute sum of k separate bitlinear_python calls output_sum = torch.zeros(batch_size, out_features) for i in range(k): output_sum += bitlinear_python(x, W_ternary[i], gammas[i], bias=None) output_sum += bias # Add bias once at the end # Verify they match assert torch.allclose(output_multi, output_sum, atol=1e-5) class TestActivationQuant: """Tests for activation quantization.""" def test_quantization_range(self): """Test that quantized activations are in expected range.""" x = torch.randn(16, 128, 512) * 10 # Large range bits = 8 x_quant = activation_quant(x, bits=bits) # Output should have same shape assert x_quant.shape == x.shape # Check that quantization reduces precision (should be close but not exact) assert not torch.allclose(x, x_quant, atol=1e-6) # Quantized values should still be in reasonable range assert torch.isfinite(x_quant).all() def test_absmax_scaling(self): """Test that absmax scaling is applied correctly.""" # Create input with known range per token x = torch.tensor([ [1.0, 2.0, 3.0, 4.0], [-5.0, -10.0, 5.0, 10.0], ]) x_quant = activation_quant(x, bits=8) # Should preserve relative magnitudes within each token # First token: max is 4.0 # Second token: max is 10.0 assert x_quant.shape == (2, 4) assert torch.isfinite(x_quant).all() # The quantized values should be close to original for 8-bit # (127 levels provide good precision) relative_error = torch.abs(x - x_quant) / (torch.abs(x) + 1e-5) assert relative_error.mean() < 0.1 # Less than 10% average error # Integration test class TestFunctionalIntegration: """Integration tests combining multiple functional components.""" def test_full_pipeline(self): """Test full pipeline: decomposition → multi-ternary forward.""" # 1. Create dense weights in_features, out_features = 256, 512 W_dense = torch.randn(out_features, in_features) # 2. Apply greedy decomposition k = 4 W_ternary, gammas = greedy_ternary_decomposition(W_dense, k) # 3. Run multi_ternary_linear_python batch_size = 16 x = torch.randn(batch_size, in_features) bias = torch.randn(out_features) output = multi_ternary_linear_python(x, W_ternary, gammas, bias) # 4. Verify output shape and basic correctness assert output.shape == (batch_size, out_features) assert torch.isfinite(output).all() # Compare with dense operation to verify reasonable approximation output_dense = torch.matmul(x, W_dense.t()) + bias # They should be similar but not identical (due to quantization) relative_error = torch.norm(output - output_dense) / torch.norm(output_dense) assert relative_error < 1.0 # Error should be reasonable def test_bitlinear_with_activation_quant(self): """Test combining bitlinear with activation quantization.""" batch_size, in_features, out_features = 8, 128, 256 # Create inputs x = torch.randn(batch_size, in_features) W_ternary = torch.randint(-1, 2, (out_features, in_features)).float() gamma = torch.ones(out_features) # Quantize activations x_quant = activation_quant(x, bits=8) # Forward pass output = bitlinear_python(x_quant, W_ternary, gamma) # Check output assert output.shape == (batch_size, out_features) assert torch.isfinite(output).all() def test_multi_ternary_end_to_end(self): """Test multi-ternary from weight decomposition to forward pass.""" # Simulate a small layer W = torch.randn(64, 128) * 0.1 # Small weights for numerical stability x = torch.randn(4, 128) # Decompose with different k values for k in [1, 2, 4]: W_ternary, gammas = greedy_ternary_decomposition(W, k) output = multi_ternary_linear_python(x, W_ternary, gammas, bias=None) # Check output is valid assert output.shape == (4, 64) assert torch.isfinite(output).all() # Verify reconstruction quality W_reconstructed = torch.zeros_like(W) for i in range(k): W_reconstructed += gammas[i].unsqueeze(1) * W_ternary[i] # Compute expected output with reconstructed weights output_expected = torch.matmul(x, W_reconstructed.t()) # Should match closely assert torch.allclose(output, output_expected, atol=1e-4)