|
|
"""
|
|
|
Unit tests for BitLinear and MultiTernaryLinear layers.
|
|
|
|
|
|
These tests validate the nn.Module implementations and their compatibility with standard PyTorch workflows. The test cases are:
|
|
|
|
|
|
TestBitLinear (8 tests)
|
|
|
1. test_initialization - Verifies layer initializes with correct shapes
|
|
|
2. test_no_bias_initialization - Tests initialization without bias parameter
|
|
|
3. test_forward_shape - Validates output shape correctness
|
|
|
4. test_compatibility_with_nn_linear - Tests interface compatibility with nn.Linear
|
|
|
5. test_from_linear_conversion - Verifies conversion from nn.Linear to BitLinear
|
|
|
6. test_parameter_count - Validates parameter count calculation
|
|
|
7. test_weight_values_are_ternary - Ensures weights are in {-1, 0, +1}
|
|
|
8. test_gradient_flow - Tests gradient flow for QAT support
|
|
|
|
|
|
TestMultiTernaryLinear (5 tests)
|
|
|
1. test_initialization - Verifies k-component initialization
|
|
|
2. test_forward_shape - Tests forward pass output shape
|
|
|
3. test_k_components - Validates k-component tensor shapes
|
|
|
4. test_from_linear_conversion - Tests conversion with k parameter
|
|
|
5. test_better_approximation_with_more_k - Validates error decreases with larger k
|
|
|
|
|
|
TestConversionUtilities (3 tests)
|
|
|
1. test_convert_simple_model - Tests conversion of Sequential models
|
|
|
2. test_convert_nested_model - Tests conversion of nested module hierarchies
|
|
|
3. test_inplace_conversion - Tests in-place vs. copy conversion modes
|
|
|
|
|
|
TestLayerIntegration (3 tests)
|
|
|
1. test_in_transformer_block - Tests BitLinear in Transformer FFN block
|
|
|
2. test_training_step - Validates full training loop compatibility
|
|
|
3. test_save_and_load - Tests model serialization and deserialization
|
|
|
|
|
|
TestPerformanceComparison (2 tests - skipped)
|
|
|
1. test_memory_usage - Performance benchmark (run manually)
|
|
|
2. test_inference_speed - Performance benchmark (run manually)
|
|
|
"""
|
|
|
|
|
|
import pytest
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
|
|
|
|
|
|
|
|
|
class TestBitLinear:
    """Unit tests covering the public behavior of the BitLinear layer."""

    def test_initialization(self):
        """A new layer reports the requested dimensions and parameter shapes."""
        bl = BitLinear(512, 1024)
        assert bl.in_features == 512
        assert bl.out_features == 1024
        assert bl.bias is not None
        assert bl.W_ternary.shape == (1024, 512)
        assert bl.gamma.shape == (1024,)

    def test_no_bias_initialization(self):
        """Constructing with bias=False leaves the bias unset."""
        bl = BitLinear(512, 1024, bias=False)
        assert bl.bias is None

    def test_forward_shape(self):
        """The forward pass maps (..., in_features) to (..., out_features)."""
        bl = BitLinear(512, 1024)
        batch = torch.randn(32, 128, 512)
        result = bl(batch)
        assert result.shape == (32, 128, 1024)

    def test_compatibility_with_nn_linear(self):
        """BitLinear is shape-compatible with nn.Linear at the interface level."""
        dense = nn.Linear(512, 512)
        ternary = BitLinear(512, 512)

        batch = torch.randn(32, 512)
        dense_out = dense(batch)
        ternary_out = ternary(batch)

        # Shapes must agree; values differ because the weights are quantized.
        assert dense_out.shape == ternary_out.shape

    def test_from_linear_conversion(self):
        """from_linear preserves dimensions and yields a working layer."""
        dense = nn.Linear(512, 1024)
        converted = BitLinear.from_linear(dense)

        assert converted.in_features == 512
        assert converted.out_features == 1024

        # The converted layer must still run a forward pass.
        batch = torch.randn(16, 512)
        assert converted(batch).shape == (16, 1024)

    def test_parameter_count(self):
        """Parameter total covers ternary weights, gamma, and bias."""
        bl = BitLinear(512, 512, bias=True)

        # 512*512 ternary weights + 512 gamma scales + 512 bias entries.
        expected = 512 * 512 + 512 + 512
        total = sum(p.numel() for p in bl.parameters())
        assert total == expected

    def test_weight_values_are_ternary(self):
        """Stored weights only take values from {-1, 0, +1}."""
        bl = BitLinear(512, 512)
        observed = set(torch.unique(bl.W_ternary).tolist())
        assert observed.issubset({-1.0, 0.0, 1.0})

    def test_gradient_flow(self):
        """Backprop reaches the input and the layer's parameters (QAT)."""
        bl = BitLinear(256, 128)
        batch = torch.randn(8, 256, requires_grad=True)
        bl(batch).sum().backward()

        # The input must receive a gradient...
        assert batch.grad is not None
        # ...and so must the quantized weight and its scale.
        assert bl.W_ternary.grad is not None
        assert bl.gamma.grad is not None
|
|
|
|
|
|
|
|
|
class TestMultiTernaryLinear:
    """Tests for MultiTernaryLinear layer."""

    def test_initialization(self):
        """Test layer initialization with k ternary components."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        assert layer.in_features == 512
        assert layer.out_features == 1024
        assert layer.k == 4
        # One (out, in) ternary matrix and one gamma vector per component.
        assert layer.W_ternary.shape == (4, 1024, 512)
        assert layer.gammas.shape == (4, 1024)

    def test_forward_shape(self):
        """Test forward pass output shape."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        x = torch.randn(32, 128, 512)
        output = layer(x)
        assert output.shape == (32, 128, 1024)

    def test_k_components(self):
        """Test that the layer allocates exactly k ternary components."""
        layer = MultiTernaryLinear(512, 512, k=3)
        assert layer.W_ternary.shape == (3, 512, 512)
        assert layer.gammas.shape == (3, 512)

    def test_from_linear_conversion(self):
        """Test converting nn.Linear to MultiTernaryLinear."""
        linear = nn.Linear(512, 1024)
        multi_ternary = MultiTernaryLinear.from_linear(linear, k=4)
        assert multi_ternary.k == 4
        assert multi_ternary.in_features == 512
        assert multi_ternary.out_features == 1024

    def test_better_approximation_with_more_k(self):
        """Test that larger k approximates the dense layer more closely."""
        # Seed the RNG: the strict ordering asserted below holds for
        # typical weight/input draws but is not guaranteed for every
        # random draw, so an unseeded run could be flaky.
        torch.manual_seed(0)
        linear = nn.Linear(512, 512)
        x = torch.randn(16, 512)
        out_dense = linear(x)

        errors = []
        for k in [1, 2, 4]:
            multi_ternary = MultiTernaryLinear.from_linear(linear, k=k)
            out_ternary = multi_ternary(x)
            # .item() turns the 0-d norm tensor into a plain float so the
            # chained comparison below is ordinary Python, not tensor ops.
            errors.append(torch.norm(out_dense - out_ternary).item())

        # Approximation error should strictly decrease as k grows.
        assert errors[0] > errors[1] > errors[2]
|
|
|
|
|
|
|
|
|
class TestConversionUtilities:
    """Tests for the helpers that swap nn.Linear layers for BitLinear."""

    def test_convert_simple_model(self):
        """Every Linear in a flat Sequential is replaced; others are kept."""
        original = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
        )

        converted = convert_linear_to_bitlinear(original, inplace=False)

        # Both Linear layers become BitLinear; the ReLU is untouched.
        assert isinstance(converted[0], BitLinear)
        assert isinstance(converted[2], BitLinear)
        assert isinstance(converted[1], nn.ReLU)

    def test_convert_nested_model(self):
        """Conversion recurses into nested submodule hierarchies."""
        class NestedModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.Linear(256, 512)
                self.submodule = nn.Sequential(
                    nn.Linear(512, 512),
                    nn.ReLU(),
                )
                self.layer2 = nn.Linear(512, 128)

        converted = convert_linear_to_bitlinear(NestedModel(), inplace=False)

        # Top-level layers and the Linear inside the submodule all convert.
        assert isinstance(converted.layer1, BitLinear)
        assert isinstance(converted.submodule[0], BitLinear)
        assert isinstance(converted.layer2, BitLinear)

    def test_inplace_conversion(self):
        """inplace=False copies the model; inplace=True mutates it."""
        source = nn.Sequential(nn.Linear(256, 256))

        # Copy mode: a distinct object is returned and the original
        # keeps its nn.Linear.
        copied = convert_linear_to_bitlinear(source, inplace=False)
        assert id(source) != id(copied)
        assert isinstance(source[0], nn.Linear)
        assert isinstance(copied[0], BitLinear)

        # In-place mode: the very same object comes back, now converted.
        other = nn.Sequential(nn.Linear(256, 256))
        returned = convert_linear_to_bitlinear(other, inplace=True)
        assert id(other) == id(returned)
        assert isinstance(other[0], BitLinear)
|
|
|
|
|
|
|
|
|
class TestLayerIntegration:
    """Integration tests for layers in realistic scenarios."""

    def test_in_transformer_block(self):
        """Test BitLinear as the feed-forward sublayer of a Transformer block."""

        class TransformerFFN(nn.Module):
            """Two-layer FFN (expand -> ReLU -> project -> dropout)."""

            def __init__(self, d_model=256, d_ff=1024):
                super().__init__()
                self.fc1 = BitLinear(d_model, d_ff)
                self.relu = nn.ReLU()
                self.fc2 = BitLinear(d_ff, d_model)
                self.dropout = nn.Dropout(0.1)

            def forward(self, x):
                return self.dropout(self.fc2(self.relu(self.fc1(x))))

        model = TransformerFFN()

        batch_size, seq_len, d_model = 8, 32, 256
        x = torch.randn(batch_size, seq_len, d_model)
        output = model(x)

        # The FFN preserves the model dimension.
        assert output.shape == (batch_size, seq_len, d_model)

        # Both layers must still hold ternary weights after the forward pass.
        assert set(model.fc1.W_ternary.unique().tolist()).issubset({-1.0, 0.0, 1.0})
        assert set(model.fc2.W_ternary.unique().tolist()).issubset({-1.0, 0.0, 1.0})

    def test_training_step(self):
        """Test that layers work through a complete optimize step."""
        model = nn.Sequential(
            BitLinear(128, 256),
            nn.ReLU(),
            BitLinear(256, 10),
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        x = torch.randn(16, 128)
        output = model(x)

        target = torch.randint(0, 10, (16,))
        loss = nn.functional.cross_entropy(output, target)

        optimizer.zero_grad()
        loss.backward()

        # Gradients must reach the quantized parameters before stepping.
        assert model[0].W_ternary.grad is not None
        assert model[0].gamma.grad is not None

        optimizer.step()

        # .item() makes the finiteness check an ordinary Python bool
        # instead of relying on 0-d tensor truthiness.
        assert torch.isfinite(loss).item()

    def test_save_and_load(self):
        """Test round-tripping a BitLinear model through a state_dict file."""
        import tempfile
        import os

        model = nn.Sequential(
            BitLinear(128, 256),
            nn.ReLU(),
            BitLinear(256, 64),
        )

        # Reserve the path first, then write after the handle is closed:
        # saving into a still-open NamedTemporaryFile fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as f:
            temp_path = f.name
        torch.save(model.state_dict(), temp_path)

        try:
            model_loaded = nn.Sequential(
                BitLinear(128, 256),
                nn.ReLU(),
                BitLinear(256, 64),
            )
            # weights_only=True: a state_dict needs no arbitrary unpickling,
            # and this is the safe default from PyTorch 2.6 onward.
            model_loaded.load_state_dict(
                torch.load(temp_path, weights_only=True)
            )

            # All quantized tensors survive the round trip exactly.
            assert torch.allclose(model[0].W_ternary, model_loaded[0].W_ternary)
            assert torch.allclose(model[0].gamma, model_loaded[0].gamma)
            assert torch.allclose(model[2].W_ternary, model_loaded[2].W_ternary)
            assert torch.allclose(model[2].gamma, model_loaded[2].gamma)

            # Original and reloaded models must agree on a forward pass.
            x = torch.randn(8, 128)
            with torch.no_grad():
                out1 = model(x)
                out2 = model_loaded(x)
            assert torch.allclose(out1, out2)
        finally:
            # Always remove the temp file, even if an assertion fails.
            os.unlink(temp_path)
|
|
|
|
|
|
|
|
|
|
|
|
class TestPerformanceComparison:
    """Benchmarks pitting BitLinear against a standard nn.Linear."""

    @pytest.mark.skip("Performance test - run manually")
    def test_memory_usage(self):
        """Measure the memory footprint of BitLinear relative to nn.Linear."""
        # Placeholder benchmark: run by hand; not part of the CI suite.
        pass

    @pytest.mark.skip("Performance test - run manually")
    def test_inference_speed(self):
        """Measure inference latency (once CUDA kernels are implemented)."""
        # Placeholder benchmark: run by hand; not part of the CI suite.
        pass
|
|
|
|