BitLinear / tests /test_quantization.py
krisaujla's picture
Upload folder using huggingface_hub
fd8c8b9 verified
raw
history blame
11 kB
"""
Unit tests for quantization utilities.
These tests validate the ternary quantization, scaling, and packing functions. The test cases are:
TestAbsmaxScale (3 tests)
1. test_global_scale - Tests global absmax scaling computation
2. test_per_channel_scale - Tests per-channel (per-row) absmax scaling
3. test_zero_tensor - Validates behavior with zero tensors (numerical stability)
TestTernaryQuantize (3 tests)
1. test_quantization_values - Ensures output contains only {-1, 0, +1}
2. test_sign_preservation - Validates sign preservation for large values
3. test_threshold_behavior - Tests threshold-based zero assignment
TestWeightToTernary (3 tests)
1. test_output_shapes - Verifies correct output tensor shapes
2. test_per_channel_vs_global - Tests per-channel vs. global scaling modes
3. test_reconstruction_quality - Validates reconstruction error is reasonable
TestActivationQuantization (2 tests)
1. test_quantization_range - Tests 8-bit quantization range
2. test_per_token_scaling - Validates per-token vs. global scaling
TestDequantization (1 test)
1. test_dequantize_inverse - Tests quantize → dequantize inverse operation
TestBase3Packing (3 tests)
1. test_pack_unpack_roundtrip - Validates pack → unpack recovers original
2. test_memory_efficiency - Tests ~20x compression achievement
3. test_packing_with_padding - Tests padding for non-multiple-of-5 dimensions
TestCompressionUtilities (2 tests)
1. test_compression_ratio_calculation - Tests compression ratio computation
2. test_memory_savings_estimation - Validates memory savings estimation
TestQuantizationIntegration (2 tests)
1. test_full_quantization_pipeline - Tests dense → ternary → packed → unpacked
2. test_quantization_preserves_functionality - Validates quantized layer outputs
"""
import pytest
import torch
from bitlinear.quantization import (
absmax_scale,
ternary_quantize,
weight_to_ternary,
quantize_activations_absmax,
dequantize_scale,
)
from bitlinear.packing import (
pack_ternary_base3,
unpack_ternary_base3,
compute_compression_ratio,
estimate_memory_savings,
)
class TestAbsmaxScale:
    """Checks for ``absmax_scale`` in global and per-row modes."""

    def test_global_scale(self):
        """With ``dim=None`` the scale equals the largest magnitude overall."""
        weights = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
        largest_magnitude = torch.tensor(6.0)
        assert torch.isclose(absmax_scale(weights, dim=None), largest_magnitude)

    def test_per_channel_scale(self):
        """With ``dim=1`` every row reports its own largest magnitude."""
        weights = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
        row_scales = absmax_scale(weights, dim=1)
        assert torch.allclose(row_scales, torch.tensor([3.0, 6.0]))

    def test_zero_tensor(self):
        """An all-zero input must give a tiny but strictly positive scale."""
        zero_scale = absmax_scale(torch.zeros(10, 10), dim=None)
        # Strictly positive so that dividing by the scale stays finite;
        # still small enough to show it is only a numerical floor.
        assert zero_scale > 0
        assert zero_scale < 1e-4
class TestTernaryQuantize:
    """Checks for ``ternary_quantize``."""

    def test_quantization_values(self):
        """Every quantized entry must be one of -1, 0 or +1."""
        quantized = ternary_quantize(torch.randn(100, 100))
        observed = set(torch.unique(quantized).tolist())
        assert observed <= {-1.0, 0.0, 1.0}

    def test_sign_preservation(self):
        """Entries with large magnitude must keep their sign."""
        # The +/-10 and +/-8 entries dominate the tiny +/-0.01 ones, so they
        # sit well above any threshold derived from the tensor's magnitude.
        weights = torch.tensor([[10.0, -10.0, 0.01], [-8.0, 8.0, -0.01]])
        quantized = ternary_quantize(weights)
        expected_by_position = (
            ((0, 0), 1.0),   # large positive
            ((0, 1), -1.0),  # large negative
            ((1, 0), -1.0),  # large negative
            ((1, 1), 1.0),   # large positive
        )
        for (row, col), expected in expected_by_position:
            assert quantized[row, col] == expected

    def test_threshold_behavior(self):
        """Small-magnitude entries must be mapped to zero."""
        weights = torch.tensor([[10.0, 0.1, -10.0], [0.2, -0.2, 5.0]])
        quantized = ternary_quantize(weights)
        # Which entries are zeroed depends on the threshold rule; the test
        # only requires that at least one of them is.
        assert (quantized == 0.0).any()
class TestWeightToTernary:
    """Tests for the ``weight_to_ternary`` function."""

    def test_output_shapes(self):
        """Per-channel mode returns a same-shaped tensor and one scale per row."""
        W = torch.randn(512, 768)
        W_ternary, gamma = weight_to_ternary(W, per_channel=True)
        assert W_ternary.shape == (512, 768)
        assert gamma.shape == (512,)

    def test_per_channel_vs_global(self):
        """Per-channel scaling gives a vector of scales, global gives a scalar."""
        W = torch.randn(512, 768)
        # Only the scales are inspected, so the ternary tensors are discarded.
        _, gamma_pc = weight_to_ternary(W, per_channel=True)
        _, gamma_g = weight_to_ternary(W, per_channel=False)
        assert gamma_pc.shape == (512,)
        assert gamma_g.shape == torch.Size([])  # 0-d (scalar) tensor

    def test_reconstruction_quality(self):
        """Test that ``W_ternary * gamma`` approximates ``W``."""
        # A locally seeded generator keeps this threshold test reproducible
        # without touching the global RNG state seen by other tests.
        rng = torch.Generator().manual_seed(0)
        W = torch.randn(512, 768, generator=rng)
        W_ternary, gamma = weight_to_ternary(W, per_channel=True)
        # unsqueeze(1) lets the per-row scale broadcast across the columns.
        W_reconstructed = W_ternary * gamma.unsqueeze(1)
        error = torch.norm(W - W_reconstructed) / torch.norm(W)
        # Ternary quantization is lossy, so the bound is loose: a relative
        # error of exactly 1.0 is what an all-zero reconstruction would give,
        # hence the result only has to beat that trivial baseline.
        assert error < 1.0
class TestActivationQuantization:
    """Checks for ``quantize_activations_absmax``."""

    def test_quantization_range(self):
        """Quantized activations must stay within roughly the input's range."""
        activations = torch.randn(16, 32, 512)
        quantized = quantize_activations_absmax(activations, bits=8, per_token=True)
        peak_out = quantized.abs().max()
        peak_in = activations.abs().max()
        # 10% headroom over the input's largest magnitude.
        assert peak_out <= peak_in * 1.1

    def test_per_token_scaling(self):
        """Both per-token and global modes must run and keep the input shape."""
        activations = torch.randn(16, 32, 512)
        per_token_result = quantize_activations_absmax(
            activations, bits=8, per_token=True
        )
        global_result = quantize_activations_absmax(
            activations, bits=8, per_token=False
        )
        # Completing without an exception is itself part of the check.
        assert per_token_result.shape == activations.shape
        assert global_result.shape == activations.shape
class TestDequantization:
    """Checks for ``dequantize_scale``."""

    def test_dequantize_inverse(self):
        """Dequantizing must match multiplying the ternary weights by the scale."""
        weights = torch.randn(512, 768)
        ternary, scale = weight_to_ternary(weights, per_channel=True)
        restored = dequantize_scale(ternary, scale)
        # Reference result: give the scale a trailing axis so it broadcasts
        # across the columns of the ternary tensor.
        reference = ternary * scale.unsqueeze(1)
        assert torch.allclose(restored, reference)
class TestBase3Packing:
    """Tests for the base-3 packing utilities."""

    def test_pack_unpack_roundtrip(self):
        """Test that pack → unpack recovers the original ternary weights."""
        W_ternary = torch.randint(-1, 2, (512, 768)).float()
        packed, shape = pack_ternary_base3(W_ternary)
        W_unpacked = unpack_ternary_base3(packed, shape)
        # allclose broadcasts its arguments, so pin the shape explicitly.
        assert W_unpacked.shape == W_ternary.shape
        assert torch.allclose(W_ternary, W_unpacked)

    def test_memory_efficiency(self):
        """Test that packing achieves the expected compression."""
        W_ternary = torch.randint(-1, 2, (512, 768)).float()
        # Byte counts come from the tensors' real dtypes rather than from
        # hard-coded sizes, so the ratio stays honest if the packed dtype
        # ever changes.
        original_size = W_ternary.numel() * W_ternary.element_size()
        packed, _ = pack_ternary_base3(W_ternary)
        packed_size = packed.numel() * packed.element_size()
        compression = original_size / packed_size
        # Should achieve ~20x compression (32 bits → 1.6 bits)
        assert compression > 15  # Allow some overhead

    def test_packing_with_padding(self):
        """Test packing when dimensions are not multiples of 5."""
        # Several awkward sizes to make sure padding is handled correctly.
        for size in [(13, 17), (100, 203), (7, 11)]:
            W_ternary = torch.randint(-1, 2, size).float()
            packed, shape = pack_ternary_base3(W_ternary)
            W_unpacked = unpack_ternary_base3(packed, shape)
            # The message names the failing size, since all sizes share one test.
            assert W_unpacked.shape == W_ternary.shape, (
                f"shape mismatch after roundtrip for size {size}"
            )
            assert torch.allclose(W_ternary, W_unpacked), (
                f"value mismatch after roundtrip for size {size}"
            )
class TestCompressionUtilities:
    """Checks for the compression-ratio and memory-estimation helpers."""

    def test_compression_ratio_calculation(self):
        """1024 units compressed to 51 should report a ratio close to 20."""
        deviation = abs(compute_compression_ratio(1024, 51) - 20.0)
        assert deviation < 0.5

    def test_memory_savings_estimation(self):
        """The estimate must expose the expected keys and a high ratio."""
        report = estimate_memory_savings(768, 3072, num_layers=12)
        required_keys = (
            'float32_bytes',
            'packed_bytes',
            'savings_bytes',
            'compression_ratio',
        )
        for key in required_keys:
            assert key in report
        assert report['compression_ratio'] > 15
class TestQuantizationIntegration:
    """End-to-end checks for the quantization pipeline."""

    def test_full_quantization_pipeline(self):
        """Run dense → ternary → packed → unpacked and compare the results."""
        dense_weights = torch.randn(128, 256)

        # Quantize; the scale is not needed for the packing roundtrip.
        ternary, _scale = weight_to_ternary(dense_weights, per_channel=True)

        # Pack to base-3 and unpack again.
        packed, packed_shape = pack_ternary_base3(ternary)
        restored = unpack_ternary_base3(packed, packed_shape)

        # The roundtrip must reproduce the ternary tensor, and the result
        # must still contain ternary values only.
        assert torch.allclose(ternary, restored)
        assert set(restored.unique().tolist()) <= {-1.0, 0.0, 1.0}

    def test_quantization_preserves_functionality(self):
        """A BitLinear built from a dense layer should track its outputs."""
        from bitlinear import BitLinear
        import torch.nn as nn

        # Dense reference layer and its output on a random batch.
        reference_layer = nn.Linear(256, 128)
        batch = torch.randn(16, 256)
        reference_out = reference_layer(batch)

        # Quantized counterpart evaluated on the same batch.
        quantized_layer = BitLinear.from_linear(reference_layer)
        quantized_out = quantized_layer(batch)

        assert reference_out.shape == quantized_out.shape

        # Outputs should be similar but not identical: take the off-diagonal
        # entry of the 2x2 correlation matrix of the two flattened outputs.
        paired = torch.stack([reference_out.flatten(), quantized_out.flatten()])
        correlation = torch.corrcoef(paired)[0, 1]
        assert correlation > 0.5  # Should have reasonable correlation