File size: 13,852 Bytes
fd8c8b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 |
"""
Unit tests for functional API (bitlinear_python, greedy_ternary_decomposition, etc.)
These tests validate the correctness of the pure PyTorch reference implementations. The test cases are:
TestBitLinearPython (5 tests)
1. test_shape_correctness - Verifies output dimensions for 3D inputs
2. test_no_bias - Tests forward pass without bias term
3. test_ternary_constraint - Validates ternary weight values {-1, 0, +1}
4. test_gamma_scaling - Verifies gamma scaling is applied correctly
5. test_numerical_correctness - Compares against manual torch computation
TestGreedyTernaryDecomposition (4 tests)
1. test_decomposition_shape - Checks output tensor shapes
2. test_ternary_values - Ensures all decomposed weights are ternary
3. test_reconstruction_error - Validates error decreases with more components
4. test_single_component - Tests k=1 edge case
TestMultiTernaryLinearPython (2 tests)
1. test_shape_correctness - Verifies output shape
2. test_equivalence_to_sum - Confirms equivalence to summing individual operations
TestActivationQuant (2 tests)
1. test_quantization_range - Validates quantization behavior and output
2. test_absmax_scaling - Tests per-token absmax scaling
TestFunctionalIntegration (3 tests)
1. test_full_pipeline - End-to-end: decomposition → multi-ternary forward
2. test_bitlinear_with_activation_quant - Combines activation quantization with bitlinear
3. test_multi_ternary_end_to_end - Tests different k values with reconstruction validation
"""
import pytest
import torch
import torch.nn as nn
from bitlinear.functional import (
bitlinear_python,
greedy_ternary_decomposition,
multi_ternary_linear_python,
activation_quant,
)
class TestBitLinearPython:
    """Tests for bitlinear_python function."""

    def test_shape_correctness(self):
        """Test that output shape matches expected dimensions."""
        batch, seq, d_in, d_out = 32, 128, 512, 1024
        activations = torch.randn(batch, seq, d_in)
        weights = torch.randint(-1, 2, (d_out, d_in)).float()
        scales = torch.ones(d_out)
        offsets = torch.zeros(d_out)

        result = bitlinear_python(activations, weights, scales, offsets)

        assert result.shape == (batch, seq, d_out)

    def test_no_bias(self):
        """Test forward pass without bias."""
        batch, d_in, d_out = 16, 256, 512
        activations = torch.randn(batch, d_in)
        weights = torch.randint(-1, 2, (d_out, d_in)).float()

        result = bitlinear_python(activations, weights, torch.ones(d_out), bias=None)

        assert result.shape == (batch, d_out)
        assert not torch.isnan(result).any()

    def test_ternary_constraint(self):
        """Test that function works correctly with ternary weights {-1, 0, +1}."""
        activations = torch.randn(8, 64)
        weights = torch.randint(-1, 2, (128, 64)).float()

        # Sanity-check the fixture itself: only {-1, 0, +1} may appear.
        for value in torch.unique(weights).tolist():
            assert value in [-1.0, 0.0, 1.0]

        # Forward pass should produce a finite, correctly shaped result.
        result = bitlinear_python(activations, weights, torch.ones(128))
        assert result.shape == (8, 128)
        assert not torch.isnan(result).any()

    def test_gamma_scaling(self):
        """Test that gamma scaling is applied correctly."""
        activations = torch.randn(4, 32)
        weights = torch.randint(-1, 2, (64, 32)).float()
        # Random per-output scales in [0.5, 2.5).
        scales = torch.rand(64) * 2 + 0.5

        scaled = bitlinear_python(activations, weights, scales, bias=None)

        # Running with unit gamma and scaling afterwards must be equivalent.
        unscaled = bitlinear_python(activations, weights, torch.ones_like(scales), bias=None)
        manually_scaled = unscaled * scales.unsqueeze(0)

        assert torch.allclose(scaled, manually_scaled, atol=1e-5)

    def test_numerical_correctness(self):
        """Test numerical correctness against standard nn.Linear."""
        d_in, d_out = 128, 256
        activations = torch.randn(16, d_in)
        weights = torch.randint(-1, 2, (d_out, d_in)).float()
        scales = torch.ones(d_out)
        offsets = torch.randn(d_out)

        via_bitlinear = bitlinear_python(activations, weights, scales, offsets)
        # Reference computation spelled out with plain torch ops.
        via_torch = torch.matmul(activations, weights.t()) * scales.unsqueeze(0) + offsets

        # With unit gamma the two paths should agree to float precision.
        assert torch.allclose(via_bitlinear, via_torch, atol=1e-6)
class TestGreedyTernaryDecomposition:
    """Tests for greedy_ternary_decomposition function."""

    def test_decomposition_shape(self):
        """Test that decomposition returns correct shapes."""
        W = torch.randn(512, 768)
        k = 4
        W_ternary, gammas = greedy_ternary_decomposition(W, k)
        # One ternary matrix per component, one scale per (component, row).
        assert W_ternary.shape == (k, 512, 768)
        assert gammas.shape == (k, 512)

    def test_ternary_values(self):
        """Test that decomposed weights are ternary."""
        W = torch.randn(64, 128)
        k = 2
        W_ternary, gammas = greedy_ternary_decomposition(W, k)
        # Verify all values in W_ternary are in {-1, 0, +1}
        unique_values = torch.unique(W_ternary)
        assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist()), \
            f"Found non-ternary values: {unique_values.tolist()}"

    def test_reconstruction_error(self):
        """Test that reconstruction error decreases with more components.

        NOTE(review): greedy residual fitting only guarantees *non-increasing*
        error; a strictly decreasing sequence is not certain for every random
        draw. The RNG is therefore seeded so this test is deterministic and
        cannot flake across runs.
        """
        torch.manual_seed(0)  # pin the random W so the strict > checks are stable
        W = torch.randn(128, 256)
        errors = []
        for k in [1, 2, 4, 8]:
            W_ternary, gammas = greedy_ternary_decomposition(W, k)
            # Reconstruct: sum of gamma_i * W_i (per-row scale broadcast over columns)
            reconstruction = torch.zeros_like(W)
            for i in range(k):
                reconstruction += gammas[i].unsqueeze(1) * W_ternary[i]
            errors.append(torch.norm(W - reconstruction).item())
        # Error should decrease with more components
        assert errors[0] > errors[1], f"Error not decreasing: {errors[0]} vs {errors[1]}"
        assert errors[1] > errors[2], f"Error not decreasing: {errors[1]} vs {errors[2]}"
        assert errors[2] > errors[3], f"Error not decreasing: {errors[2]} vs {errors[3]}"

    def test_single_component(self):
        """Test k=1 case (single ternary quantization)."""
        W = torch.randn(32, 64)
        k = 1
        W_ternary, gammas = greedy_ternary_decomposition(W, k)
        assert W_ternary.shape == (1, 32, 64)
        assert gammas.shape == (1, 32)
        # Verify ternary values
        unique_values = torch.unique(W_ternary)
        assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist())
class TestMultiTernaryLinearPython:
    """Tests for multi_ternary_linear_python function."""

    def test_shape_correctness(self):
        """Test output shape for multi-ternary linear."""
        batch, d_in, d_out = 16, 128, 256
        n_components = 4
        activations = torch.randn(batch, d_in)
        components = torch.randint(-1, 2, (n_components, d_out, d_in)).float()
        scales = torch.rand(n_components, d_out)
        offsets = torch.randn(d_out)

        result = multi_ternary_linear_python(activations, components, scales, offsets)

        assert result.shape == (batch, d_out)

    def test_equivalence_to_sum(self):
        """Test that multi-ternary equals sum of individual ternary ops."""
        batch, d_in, d_out = 8, 64, 128
        n_components = 3
        activations = torch.randn(batch, d_in)
        components = torch.randint(-1, 2, (n_components, d_out, d_in)).float()
        scales = torch.rand(n_components, d_out)
        offsets = torch.randn(d_out)

        # Fused path: everything in a single call.
        fused = multi_ternary_linear_python(activations, components, scales, offsets)

        # Reference path: accumulate the bias-free per-component outputs,
        # then add the bias exactly once.
        accumulated = torch.zeros(batch, d_out)
        for idx in range(n_components):
            accumulated = accumulated + bitlinear_python(
                activations, components[idx], scales[idx], bias=None
            )
        accumulated = accumulated + offsets

        assert torch.allclose(fused, accumulated, atol=1e-5)
class TestActivationQuant:
    """Tests for activation quantization."""

    def test_quantization_range(self):
        """Test that quantized activations are in expected range."""
        # Amplify the input so quantization error is clearly observable.
        activations = torch.randn(16, 128, 512) * 10
        quantized = activation_quant(activations, bits=8)

        # Shape must be preserved by quantization.
        assert quantized.shape == activations.shape
        # Precision must actually be reduced (close but not bit-exact)...
        assert not torch.allclose(activations, quantized, atol=1e-6)
        # ...while every value stays finite.
        assert torch.isfinite(quantized).all()

    def test_absmax_scaling(self):
        """Test that absmax scaling is applied correctly."""
        # Two tokens with known per-token absmax: 4.0 and 10.0 respectively.
        activations = torch.tensor([
            [1.0, 2.0, 3.0, 4.0],
            [-5.0, -10.0, 5.0, 10.0],
        ])
        quantized = activation_quant(activations, bits=8)

        assert quantized.shape == (2, 4)
        assert torch.isfinite(quantized).all()

        # With 8 bits (127 levels) the round-trip should stay within ~10%
        # mean relative error of the original values.
        rel_err = torch.abs(activations - quantized) / (torch.abs(activations) + 1e-5)
        assert rel_err.mean() < 0.1
# Integration test
class TestFunctionalIntegration:
    """Integration tests combining multiple functional components."""

    def test_full_pipeline(self):
        """Test full pipeline: decomposition → multi-ternary forward."""
        d_in, d_out = 256, 512

        # Step 1: dense weights to be approximated.
        W_dense = torch.randn(d_out, d_in)

        # Step 2: greedy decomposition into k ternary components.
        n_components = 4
        components, scales = greedy_ternary_decomposition(W_dense, n_components)

        # Step 3: forward pass through the multi-ternary reference op.
        batch = 16
        activations = torch.randn(batch, d_in)
        offsets = torch.randn(d_out)
        approx_out = multi_ternary_linear_python(activations, components, scales, offsets)

        # Step 4: sanity checks on the result.
        assert approx_out.shape == (batch, d_out)
        assert torch.isfinite(approx_out).all()

        # The ternary approximation should land in the same ballpark as the
        # dense computation — similar, not identical, due to quantization.
        dense_out = torch.matmul(activations, W_dense.t()) + offsets
        rel_err = torch.norm(approx_out - dense_out) / torch.norm(dense_out)
        assert rel_err < 1.0  # Error should be reasonable

    def test_bitlinear_with_activation_quant(self):
        """Test combining bitlinear with activation quantization."""
        batch, d_in, d_out = 8, 128, 256

        activations = torch.randn(batch, d_in)
        weights = torch.randint(-1, 2, (d_out, d_in)).float()
        scales = torch.ones(d_out)

        # Quantize activations first, then run the ternary matmul.
        quantized = activation_quant(activations, bits=8)
        result = bitlinear_python(quantized, weights, scales)

        assert result.shape == (batch, d_out)
        assert torch.isfinite(result).all()

    def test_multi_ternary_end_to_end(self):
        """Test multi-ternary from weight decomposition to forward pass."""
        # Small weights keep the reconstruction numerically well-behaved.
        W = torch.randn(64, 128) * 0.1
        activations = torch.randn(4, 128)

        for n_components in [1, 2, 4]:
            components, scales = greedy_ternary_decomposition(W, n_components)
            result = multi_ternary_linear_python(activations, components, scales, bias=None)

            # Output must be valid for every k.
            assert result.shape == (4, 64)
            assert torch.isfinite(result).all()

            # Rebuild the effective weight matrix from the components...
            rebuilt = torch.zeros_like(W)
            for idx in range(n_components):
                rebuilt += scales[idx].unsqueeze(1) * components[idx]

            # ...and confirm the op matches a plain matmul against it.
            expected = torch.matmul(activations, rebuilt.t())
            assert torch.allclose(result, expected, atol=1e-4)
|