|
|
"""
|
|
|
Unit tests for functional API (bitlinear_python, greedy_ternary_decomposition, etc.)
|
|
|
|
|
|
These tests validate the correctness of the pure PyTorch reference implementations. The test cases are:
|
|
|
|
|
|
TestBitLinearPython (5 tests)
|
|
|
1. test_shape_correctness - Verifies output dimensions for 3D inputs
|
|
|
2. test_no_bias - Tests forward pass without bias term
|
|
|
3. test_ternary_constraint - Validates ternary weight values {-1, 0, +1}
|
|
|
4. test_gamma_scaling - Verifies gamma scaling is applied correctly
|
|
|
5. test_numerical_correctness - Compares against manual torch computation
|
|
|
|
|
|
TestGreedyTernaryDecomposition (4 tests)
|
|
|
1. test_decomposition_shape - Checks output tensor shapes
|
|
|
2. test_ternary_values - Ensures all decomposed weights are ternary
|
|
|
3. test_reconstruction_error - Validates error decreases with more components
|
|
|
4. test_single_component - Tests k=1 edge case
|
|
|
|
|
|
TestMultiTernaryLinearPython (2 tests)
|
|
|
1. test_shape_correctness - Verifies output shape
|
|
|
2. test_equivalence_to_sum - Confirms equivalence to summing individual operations
|
|
|
|
|
|
TestActivationQuant (2 tests)
|
|
|
1. test_quantization_range - Checks that quantization preserves shape, changes values, and produces finite output
|
|
|
2. test_absmax_scaling - Tests per-token absmax scaling
|
|
|
|
|
|
TestFunctionalIntegration (3 tests)
|
|
|
1. test_full_pipeline - End-to-end: decomposition → multi-ternary forward
|
|
|
2. test_bitlinear_with_activation_quant - Combines activation quantization with bitlinear
|
|
|
3. test_multi_ternary_end_to_end - Tests different k values with reconstruction validation
|
|
|
"""
|
|
|
|
|
|
import pytest
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
from bitlinear.functional import (
|
|
|
bitlinear_python,
|
|
|
greedy_ternary_decomposition,
|
|
|
multi_ternary_linear_python,
|
|
|
activation_quant,
|
|
|
)
|
|
|
|
|
|
|
|
|
class TestBitLinearPython:
    """Tests for bitlinear_python function."""

    def setup_method(self):
        """Seed the global RNG so every test sees the same random tensors."""
        # Unseeded randomness makes tolerance-based failures impossible to reproduce.
        torch.manual_seed(0)

    def test_shape_correctness(self):
        """Test that output shape matches expected dimensions for a 3D input."""
        batch_size, seq_len, in_features, out_features = 32, 128, 512, 1024
        x = torch.randn(batch_size, seq_len, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)
        bias = torch.zeros(out_features)

        output = bitlinear_python(x, W_ternary, gamma, bias)

        assert output.shape == (batch_size, seq_len, out_features)

    def test_no_bias(self):
        """Test forward pass without bias."""
        batch_size, in_features, out_features = 16, 256, 512
        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)

        output = bitlinear_python(x, W_ternary, gamma, bias=None)

        assert output.shape == (batch_size, out_features)
        assert not torch.isnan(output).any()

    def test_ternary_constraint(self):
        """Test that function works correctly with ternary weights {-1, 0, +1}."""
        x = torch.randn(8, 64)
        W_ternary = torch.randint(-1, 2, (128, 64)).float()
        gamma = torch.ones(128)

        # Precondition on the fixture itself: randint(-1, 2) draws from {-1, 0, 1}.
        unique_values = torch.unique(W_ternary)
        assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist())

        output = bitlinear_python(x, W_ternary, gamma)

        assert output.shape == (8, 128)
        assert not torch.isnan(output).any()

    def test_gamma_scaling(self):
        """Test that gamma scaling is applied correctly."""
        x = torch.randn(4, 32)
        W_ternary = torch.randint(-1, 2, (64, 32)).float()
        gamma = torch.rand(64) * 2 + 0.5  # strictly positive scales in [0.5, 2.5)

        output_with_gamma = bitlinear_python(x, W_ternary, gamma, bias=None)

        # Applying gamma after an unscaled (gamma = 1) forward must give the same result.
        gamma_ones = torch.ones_like(gamma)
        output_no_gamma = bitlinear_python(x, W_ternary, gamma_ones, bias=None)
        output_manual_scale = output_no_gamma * gamma.unsqueeze(0)

        assert torch.allclose(output_with_gamma, output_manual_scale, atol=1e-5)

    def test_numerical_correctness(self):
        """Test numerical correctness against a manual matmul and nn.functional.linear."""
        in_features, out_features = 128, 256
        x = torch.randn(16, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)
        bias = torch.randn(out_features)

        output_bitlinear = bitlinear_python(x, W_ternary, gamma, bias)

        # Reference 1: explicit y = (x @ W^T) * gamma + bias.
        output_manual = torch.matmul(x, W_ternary.t()) * gamma.unsqueeze(0) + bias
        assert torch.allclose(output_bitlinear, output_manual, atol=1e-6)

        # Reference 2: the standard linear op, with gamma folded into each
        # output row of the weight matrix.
        output_linear = nn.functional.linear(x, W_ternary * gamma.unsqueeze(1), bias)
        assert torch.allclose(output_bitlinear, output_linear, atol=1e-5)
|
|
|
|
|
|
|
|
|
class TestGreedyTernaryDecomposition:
    """Tests for greedy_ternary_decomposition function."""

    def setup_method(self):
        """Seed the global RNG so every test sees the same random tensors."""
        # Unseeded randomness makes tolerance-based failures impossible to reproduce.
        torch.manual_seed(0)

    @staticmethod
    def _reconstruct(W_ternary, gammas):
        """Rebuild the dense matrix as sum_i gammas[i][:, None] * W_ternary[i]."""
        # gammas (k, out) -> (k, out, 1) broadcasts each row scale along its row,
        # then the k components are summed.
        return (gammas.unsqueeze(-1) * W_ternary).sum(dim=0)

    def test_decomposition_shape(self):
        """Test that decomposition returns correct shapes."""
        W = torch.randn(512, 768)
        k = 4
        W_ternary, gammas = greedy_ternary_decomposition(W, k)

        assert W_ternary.shape == (k, 512, 768)
        assert gammas.shape == (k, 512)

    def test_ternary_values(self):
        """Test that decomposed weights are ternary."""
        W = torch.randn(64, 128)
        k = 2
        W_ternary, _gammas = greedy_ternary_decomposition(W, k)

        unique_values = torch.unique(W_ternary)
        assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist()), \
            f"Found non-ternary values: {unique_values.tolist()}"

    def test_reconstruction_error(self):
        """Test that reconstruction error strictly decreases with more components."""
        W = torch.randn(128, 256)
        ks = [1, 2, 4, 8]

        errors = []
        for k in ks:
            W_ternary, gammas = greedy_ternary_decomposition(W, k)
            errors.append(torch.norm(W - self._reconstruct(W_ternary, gammas)).item())

        # Compare each k with the previous one; the message names the failing step.
        for prev, curr, k in zip(errors, errors[1:], ks[1:]):
            assert prev > curr, f"Error not decreasing at k={k}: {prev} vs {curr}"

    def test_single_component(self):
        """Test k=1 case (single ternary quantization)."""
        W = torch.randn(32, 64)
        k = 1
        W_ternary, gammas = greedy_ternary_decomposition(W, k)

        assert W_ternary.shape == (1, 32, 64)
        assert gammas.shape == (1, 32)

        unique_values = torch.unique(W_ternary)
        assert all(v in [-1.0, 0.0, 1.0] for v in unique_values.tolist())
|
|
|
|
|
|
|
|
|
class TestMultiTernaryLinearPython:
    """Tests for multi_ternary_linear_python function."""

    def setup_method(self):
        """Seed the global RNG so every test sees the same random tensors."""
        # Unseeded randomness makes tolerance-based failures impossible to reproduce.
        torch.manual_seed(0)

    def test_shape_correctness(self):
        """Test output shape for multi-ternary linear."""
        batch_size, in_features, out_features = 16, 128, 256
        k = 4

        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (k, out_features, in_features)).float()
        gammas = torch.rand(k, out_features)
        bias = torch.randn(out_features)

        output = multi_ternary_linear_python(x, W_ternary, gammas, bias)

        assert output.shape == (batch_size, out_features)

    def test_equivalence_to_sum(self):
        """Test that multi-ternary equals the sum of individual ternary ops plus bias."""
        batch_size, in_features, out_features = 8, 64, 128
        k = 3

        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (k, out_features, in_features)).float()
        gammas = torch.rand(k, out_features)
        bias = torch.randn(out_features)

        output_multi = multi_ternary_linear_python(x, W_ternary, gammas, bias)

        # Iterating the stacked tensors yields one (W_i, gamma_i) component at a time.
        # The bias is added once, after summing the bias-free components.
        output_sum = sum(
            bitlinear_python(x, W_i, gamma_i, bias=None)
            for W_i, gamma_i in zip(W_ternary, gammas)
        ) + bias

        assert torch.allclose(output_multi, output_sum, atol=1e-5)
|
|
|
|
|
|
|
|
|
class TestActivationQuant:
    """Tests for activation quantization."""

    def setup_method(self):
        """Seed the global RNG so every test sees the same random tensors."""
        # Unseeded randomness makes tolerance-based failures impossible to reproduce.
        torch.manual_seed(0)

    def test_quantization_range(self):
        """Test that quantization preserves shape, changes values, and stays finite."""
        x = torch.randn(16, 128, 512) * 10
        bits = 8

        x_quant = activation_quant(x, bits=bits)

        assert x_quant.shape == x.shape
        # An 8-bit grid cannot represent these values exactly, so the output must differ.
        assert not torch.allclose(x, x_quant, atol=1e-6)
        assert torch.isfinite(x_quant).all()

    def test_absmax_scaling(self):
        """Test that per-token absmax scaling is applied correctly."""
        x = torch.tensor([
            [1.0, 2.0, 3.0, 4.0],
            [-5.0, -10.0, 5.0, 10.0],
        ])

        x_quant = activation_quant(x, bits=8)

        assert x_quant.shape == (2, 4)
        assert torch.isfinite(x_quant).all()

        # 8-bit dequantized values should track the inputs closely.
        relative_error = torch.abs(x - x_quant) / (torch.abs(x) + 1e-5)
        assert relative_error.mean() < 0.1

        # Per-token scaling: each row gets its own scale, so a small-magnitude
        # row must not be rounded to zero because a large-magnitude row is
        # present. With one per-tensor scale (absmax 100) every entry of row 0
        # would quantize to 0.
        x_mixed = torch.tensor([
            [0.01, 0.02, 0.03, 0.04],
            [-100.0, 100.0, 50.0, -50.0],
        ])
        x_mixed_quant = activation_quant(x_mixed, bits=8)
        small_row_error = torch.abs(x_mixed[0] - x_mixed_quant[0]) / torch.abs(x_mixed[0])
        assert small_row_error.mean() < 0.1
|
|
|
|
|
|
|
|
|
|
|
|
class TestFunctionalIntegration:
    """Integration tests combining multiple functional components."""

    def setup_method(self):
        """Seed the global RNG so every test sees the same random tensors."""
        # Unseeded randomness makes tolerance-based failures impossible to reproduce.
        torch.manual_seed(0)

    def test_full_pipeline(self):
        """Test full pipeline: decomposition → multi-ternary forward."""
        in_features, out_features = 256, 512
        W_dense = torch.randn(out_features, in_features)

        k = 4
        W_ternary, gammas = greedy_ternary_decomposition(W_dense, k)

        batch_size = 16
        x = torch.randn(batch_size, in_features)
        bias = torch.randn(out_features)

        output = multi_ternary_linear_python(x, W_ternary, gammas, bias)

        assert output.shape == (batch_size, out_features)
        assert torch.isfinite(output).all()

        # The k-term approximation should land closer to the dense result than
        # the dense result's own norm (a loose sanity bound).
        output_dense = torch.matmul(x, W_dense.t()) + bias
        relative_error = torch.norm(output - output_dense) / torch.norm(output_dense)
        assert relative_error < 1.0

    def test_bitlinear_with_activation_quant(self):
        """Test combining bitlinear with activation quantization."""
        batch_size, in_features, out_features = 8, 128, 256

        x = torch.randn(batch_size, in_features)
        W_ternary = torch.randint(-1, 2, (out_features, in_features)).float()
        gamma = torch.ones(out_features)

        x_quant = activation_quant(x, bits=8)
        output = bitlinear_python(x_quant, W_ternary, gamma)

        assert output.shape == (batch_size, out_features)
        assert torch.isfinite(output).all()

    def test_multi_ternary_end_to_end(self):
        """Test multi-ternary from weight decomposition to forward pass."""
        W = torch.randn(64, 128) * 0.1
        x = torch.randn(4, 128)

        for k in [1, 2, 4]:
            W_ternary, gammas = greedy_ternary_decomposition(W, k)
            output = multi_ternary_linear_python(x, W_ternary, gammas, bias=None)

            assert output.shape == (4, 64)
            assert torch.isfinite(output).all()

            # The forward pass must equal a dense matmul with the reconstructed
            # weight sum_i gammas[i][:, None] * W_ternary[i].
            W_reconstructed = (gammas.unsqueeze(-1) * W_ternary).sum(dim=0)
            output_expected = torch.matmul(x, W_reconstructed.t())

            assert torch.allclose(output, output_expected, atol=1e-4), f"Mismatch for k={k}"
|
|
|
|