# Source: BitLinear/tests/test_functional.py
# (uploaded by krisaujla via huggingface_hub, commit fd8c8b9)
"""
Unit tests for functional API (bitlinear_python, greedy_ternary_decomposition, etc.)
These tests validate the correctness of the pure PyTorch reference implementations. The test cases are:
TestBitLinearPython (5 tests)
1. test_shape_correctness - Verifies output dimensions for 3D inputs
2. test_no_bias - Tests forward pass without bias term
3. test_ternary_constraint - Validates ternary weight values {-1, 0, +1}
4. test_gamma_scaling - Verifies gamma scaling is applied correctly
5. test_numerical_correctness - Compares against manual torch computation
TestGreedyTernaryDecomposition (4 tests)
1. test_decomposition_shape - Checks output tensor shapes
2. test_ternary_values - Ensures all decomposed weights are ternary
3. test_reconstruction_error - Validates error decreases with more components
4. test_single_component - Tests k=1 edge case
TestMultiTernaryLinearPython (2 tests)
1. test_shape_correctness - Verifies output shape
2. test_equivalence_to_sum - Confirms equivalence to summing individual operations
TestActivationQuant (2 tests)
1. test_quantization_range - Validates quantization behavior and output
2. test_absmax_scaling - Tests per-token absmax scaling
TestFunctionalIntegration (3 tests)
1. test_full_pipeline - End-to-end: decomposition → multi-ternary forward
2. test_bitlinear_with_activation_quant - Combines activation quantization with bitlinear
3. test_multi_ternary_end_to_end - Tests different k values with reconstruction validation
"""
import pytest
import torch
import torch.nn as nn
from bitlinear.functional import (
bitlinear_python,
greedy_ternary_decomposition,
multi_ternary_linear_python,
activation_quant,
)
class TestBitLinearPython:
    """Tests for bitlinear_python function.

    Inputs and ternary weights are drawn at random; ``setup_method`` seeds
    torch's global RNG before every test so any failure is reproducible.
    """

    def setup_method(self, method=None):
        """Seed torch's global RNG before each test method."""
        torch.manual_seed(0)

    def test_shape_correctness(self):
        """Test that output shape matches expected dimensions for 3D inputs."""
        batch_size, seq_len, in_features, out_features = 32, 128, 512, 1024
        x = torch.randn(batch_size, seq_len, in_features)
        # randint's upper bound is exclusive, so entries come from {-1, 0, 1}.
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)
        bias = torch.zeros(out_features)
        output = bitlinear_python(x, W_ternary, gamma, bias)
        assert output.shape == (batch_size, seq_len, out_features)

    def test_no_bias(self):
        """Test forward pass without bias."""
        batch_size, in_features, out_features = 16, 256, 512
        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)
        output = bitlinear_python(x, W_ternary, gamma, bias=None)
        assert output.shape == (batch_size, out_features)
        # isfinite rejects +/-inf as well as NaN.
        assert torch.isfinite(output).all()

    def test_ternary_constraint(self):
        """Test that the function accepts ternary weights {-1, 0, +1}."""
        x = torch.randn(8, 64)
        W_ternary = torch.randint(-1, 2, (128, 64)).float()
        gamma = torch.ones(128)
        # Sanity-check the fixture itself: only {-1, 0, +1} may appear.
        unique_values = set(torch.unique(W_ternary).tolist())
        assert unique_values <= {-1.0, 0.0, 1.0}
        # The forward pass (default bias) must be well-formed and finite.
        output = bitlinear_python(x, W_ternary, gamma)
        assert output.shape == (8, 128)
        assert torch.isfinite(output).all()

    def test_gamma_scaling(self):
        """Test that gamma scales each output feature independently."""
        x = torch.randn(4, 32)
        W_ternary = torch.randint(-1, 2, (64, 32)).float()
        gamma = torch.rand(64) * 2 + 0.5  # Random scales in [0.5, 2.5)
        output_with_gamma = bitlinear_python(x, W_ternary, gamma, bias=None)
        # Same computation with unit gamma, then scaled per output feature
        # by hand; both paths must agree.
        output_unit_gamma = bitlinear_python(
            x, W_ternary, torch.ones_like(gamma), bias=None
        )
        output_manual_scale = output_unit_gamma * gamma.unsqueeze(0)
        assert torch.allclose(output_with_gamma, output_manual_scale, atol=1e-5)

    def test_numerical_correctness(self):
        """Test numerical correctness against an explicit matmul reference."""
        in_features, out_features = 128, 256
        x = torch.randn(16, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)
        bias = torch.randn(out_features)
        output_bitlinear = bitlinear_python(x, W_ternary, gamma, bias)
        # Reference: y = (x @ W^T) * gamma + bias, computed directly.
        output_manual = torch.matmul(x, W_ternary.t()) * gamma.unsqueeze(0) + bias
        # Equal up to float32 rounding (allclose also applies its default rtol).
        assert torch.allclose(output_bitlinear, output_manual, atol=1e-6)
class TestGreedyTernaryDecomposition:
    """Tests for greedy_ternary_decomposition function.

    ``setup_method`` seeds torch's global RNG before each test. The strict
    error-monotonicity checks below depend on the random weight matrix, so
    seeding makes any failure reproducible.
    """

    def setup_method(self, method=None):
        """Seed torch's global RNG before each test method."""
        torch.manual_seed(0)

    @staticmethod
    def _assert_ternary(W_ternary):
        """Assert every entry of ``W_ternary`` is -1, 0 or +1."""
        unique_values = set(torch.unique(W_ternary).tolist())
        assert unique_values <= {-1.0, 0.0, 1.0}, (
            f"Found non-ternary values: {sorted(unique_values)}"
        )

    @staticmethod
    def _reconstruct(W_ternary, gammas):
        """Return ``sum_i gammas[i].unsqueeze(1) * W_ternary[i]``.

        Expects ``gammas`` of shape (k, out_features) and ``W_ternary`` of
        shape (k, out_features, in_features), as test_decomposition_shape
        asserts; each row of each component is scaled by its gamma.
        """
        return (gammas.unsqueeze(-1) * W_ternary).sum(dim=0)

    def test_decomposition_shape(self):
        """Test that decomposition returns correct shapes."""
        W = torch.randn(512, 768)
        k = 4
        W_ternary, gammas = greedy_ternary_decomposition(W, k)
        assert W_ternary.shape == (k, 512, 768)
        assert gammas.shape == (k, 512)

    def test_ternary_values(self):
        """Test that decomposed weights are ternary."""
        W = torch.randn(64, 128)
        k = 2
        W_ternary, _ = greedy_ternary_decomposition(W, k)
        self._assert_ternary(W_ternary)

    def test_reconstruction_error(self):
        """Test that reconstruction error strictly decreases as k grows."""
        W = torch.randn(128, 256)
        ks = [1, 2, 4, 8]
        errors = []
        for k in ks:
            W_ternary, gammas = greedy_ternary_decomposition(W, k)
            reconstruction = self._reconstruct(W_ternary, gammas)
            # Frobenius norm of the residual.
            errors.append(torch.norm(W - reconstruction).item())
        # Every larger k must beat the previous one.
        for k_prev, k_next, err_prev, err_next in zip(ks, ks[1:], errors, errors[1:]):
            assert err_prev > err_next, (
                f"Error not decreasing: k={k_prev} -> {err_prev}, "
                f"k={k_next} -> {err_next}"
            )

    def test_single_component(self):
        """Test k=1 case (single ternary quantization)."""
        W = torch.randn(32, 64)
        k = 1
        W_ternary, gammas = greedy_ternary_decomposition(W, k)
        assert W_ternary.shape == (1, 32, 64)
        assert gammas.shape == (1, 32)
        self._assert_ternary(W_ternary)
class TestMultiTernaryLinearPython:
    """Checks for multi_ternary_linear_python, the k-component ternary layer."""

    def test_shape_correctness(self):
        """A (batch, in) input through k ternary components yields (batch, out)."""
        batch_size, in_features, out_features = 16, 128, 256
        k = 4
        inputs = torch.randn(batch_size, in_features)
        ternary_stack = torch.randint(-1, 2, (k, out_features, in_features)).float()
        scales = torch.rand(k, out_features)
        offset = torch.randn(out_features)

        result = multi_ternary_linear_python(inputs, ternary_stack, scales, offset)

        assert result.shape == (batch_size, out_features)

    def test_equivalence_to_sum(self):
        """One fused call equals k bitlinear calls summed, with bias added once."""
        batch_size, in_features, out_features = 8, 64, 128
        k = 3
        inputs = torch.randn(batch_size, in_features)
        ternary_stack = torch.randint(-1, 2, (k, out_features, in_features)).float()
        scales = torch.rand(k, out_features)
        offset = torch.randn(out_features)

        fused = multi_ternary_linear_python(inputs, ternary_stack, scales, offset)

        # Per-component outputs carry no bias; the bias is added a single
        # time afterwards so it is not counted k times.
        per_component = (
            bitlinear_python(inputs, W_i, gamma_i, bias=None)
            for W_i, gamma_i in zip(ternary_stack, scales)
        )
        unfused = sum(per_component) + offset

        assert torch.allclose(fused, unfused, atol=1e-5)
class TestActivationQuant:
    """Checks for activation_quant."""

    def test_quantization_range(self):
        """Quantized output keeps the input's shape, differs from it, and is finite."""
        x = torch.randn(16, 128, 512) * 10  # wide dynamic range
        n_bits = 8

        x_q = activation_quant(x, bits=n_bits)

        assert x_q.shape == x.shape
        # Quantization must actually change some values.
        assert not torch.allclose(x, x_q, atol=1e-6)
        # No NaN or inf may be produced.
        assert torch.isfinite(x_q).all()

    def test_absmax_scaling(self):
        """8-bit quantization of small per-token inputs stays close to the input."""
        # Two tokens (rows) whose absolute maxima are 4.0 and 10.0.
        x = torch.tensor([
            [1.0, 2.0, 3.0, 4.0],
            [-5.0, -10.0, 5.0, 10.0],
        ])

        x_q = activation_quant(x, bits=8)

        assert x_q.shape == (2, 4)
        assert torch.isfinite(x_q).all()
        # Mean element-wise relative error must stay under 10%; the small
        # epsilon keeps the denominator away from zero.
        eps = 1e-5
        rel_err = (x - x_q).abs() / (x.abs() + eps)
        assert rel_err.mean() < 0.1
# Integration test
class TestFunctionalIntegration:
    """Integration tests combining multiple functional components.

    ``setup_method`` seeds torch's global RNG before each test. The
    approximation-quality thresholds below depend on the random data, so
    seeding makes any failure reproducible.
    """

    def setup_method(self, method=None):
        """Seed torch's global RNG before each test method."""
        torch.manual_seed(0)

    def test_full_pipeline(self):
        """Test full pipeline: decomposition → multi-ternary forward."""
        # 1. Create dense weights.
        in_features, out_features = 256, 512
        W_dense = torch.randn(out_features, in_features)
        # 2. Apply greedy decomposition.
        k = 4
        W_ternary, gammas = greedy_ternary_decomposition(W_dense, k)
        # 3. Run the multi-ternary forward pass.
        batch_size = 16
        x = torch.randn(batch_size, in_features)
        bias = torch.randn(out_features)
        output = multi_ternary_linear_python(x, W_ternary, gammas, bias)
        # 4. Verify output shape and basic well-formedness.
        assert output.shape == (batch_size, out_features)
        assert torch.isfinite(output).all()
        # The decomposed layer should approximate the dense one, though not
        # exactly, since the weights were replaced by k ternary components.
        output_dense = torch.matmul(x, W_dense.t()) + bias
        relative_error = (
            torch.norm(output - output_dense) / torch.norm(output_dense)
        ).item()
        assert relative_error < 1.0, f"Relative error too large: {relative_error:.4f}"

    def test_bitlinear_with_activation_quant(self):
        """Test combining bitlinear with activation quantization."""
        batch_size, in_features, out_features = 8, 128, 256
        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)
        # Quantize activations, then run the ternary forward pass on them.
        x_quant = activation_quant(x, bits=8)
        output = bitlinear_python(x_quant, W_ternary, gamma)
        assert output.shape == (batch_size, out_features)
        assert torch.isfinite(output).all()

    def test_multi_ternary_end_to_end(self):
        """Test multi-ternary from weight decomposition to forward pass."""
        # Simulate a small layer with weights scaled down by 0.1.
        W = torch.randn(64, 128) * 0.1
        x = torch.randn(4, 128)
        for k in [1, 2, 4]:
            W_ternary, gammas = greedy_ternary_decomposition(W, k)
            output = multi_ternary_linear_python(x, W_ternary, gammas, bias=None)
            assert output.shape == (4, 64)
            assert torch.isfinite(output).all()
            # Rebuild the effective dense weight,
            # sum_i gammas[i].unsqueeze(1) * W_ternary[i]. The result is cast
            # to W's dtype so the matmul below sees matching dtypes.
            W_reconstructed = (gammas.unsqueeze(-1) * W_ternary).sum(dim=0).to(W.dtype)
            output_expected = torch.matmul(x, W_reconstructed.t())
            assert torch.allclose(output, output_expected, atol=1e-4), (
                f"Forward pass disagrees with reconstructed weights for k={k}"
            )