File size: 13,675 Bytes
fd8c8b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 |
"""
Unit tests for BitLinear and MultiTernaryLinear layers.
These tests validate the nn.Module implementations and their compatibility with standard PyTorch workflows. The test cases are:
TestBitLinear (8 tests)
1. test_initialization - Verifies layer initializes with correct shapes
2. test_no_bias_initialization - Tests initialization without bias parameter
3. test_forward_shape - Validates output shape correctness
4. test_compatibility_with_nn_linear - Tests interface compatibility with nn.Linear
5. test_from_linear_conversion - Verifies conversion from nn.Linear to BitLinear
6. test_parameter_count - Validates parameter count calculation
7. test_weight_values_are_ternary - Ensures weights are in {-1, 0, +1}
8. test_gradient_flow - Tests gradient flow for QAT support
TestMultiTernaryLinear (5 tests)
1. test_initialization - Verifies k-component initialization
2. test_forward_shape - Tests forward pass output shape
3. test_k_components - Validates k-component tensor shapes
4. test_from_linear_conversion - Tests conversion with k parameter
5. test_better_approximation_with_more_k - Validates error decreases with larger k
TestConversionUtilities (3 tests)
1. test_convert_simple_model - Tests conversion of Sequential models
2. test_convert_nested_model - Tests conversion of nested module hierarchies
3. test_inplace_conversion - Tests in-place vs. copy conversion modes
TestLayerIntegration (3 tests)
1. test_in_transformer_block - Tests BitLinear in Transformer FFN block
2. test_training_step - Validates full training loop compatibility
3. test_save_and_load - Tests model serialization and deserialization
TestPerformanceComparison (2 tests - skipped)
1. test_memory_usage - Performance benchmark (run manually)
2. test_inference_speed - Performance benchmark (run manually)
"""
import pytest
import torch
import torch.nn as nn
from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
class TestBitLinear:
    """Unit tests covering construction, forward, conversion and gradients of BitLinear."""

    def test_initialization(self):
        """A default BitLinear exposes nn.Linear-style sizes plus ternary and scale tensors."""
        in_dim, out_dim = 512, 1024
        bl = BitLinear(in_dim, out_dim)
        assert (bl.in_features, bl.out_features) == (in_dim, out_dim)
        assert bl.bias is not None
        # Ternary matrix is laid out (out, in) like nn.Linear.weight; one scale per output row.
        assert bl.W_ternary.shape == (out_dim, in_dim)
        assert bl.gamma.shape == (out_dim,)

    def test_no_bias_initialization(self):
        """Passing bias=False leaves the bias attribute unset."""
        assert BitLinear(512, 1024, bias=False).bias is None

    def test_forward_shape(self):
        """Leading (batch, seq) dims pass through; only the feature dim changes."""
        batch, seq = 32, 128
        out = BitLinear(512, 1024)(torch.randn(batch, seq, 512))
        assert out.shape == (batch, seq, 1024)

    def test_compatibility_with_nn_linear(self):
        """BitLinear can stand in for nn.Linear as far as the call interface goes."""
        dense = nn.Linear(512, 512)
        ternary = BitLinear(512, 512)
        inputs = torch.randn(32, 512)
        # Values differ because of quantization; only the output shape must agree.
        assert dense(inputs).shape == ternary(inputs).shape

    def test_from_linear_conversion(self):
        """BitLinear.from_linear keeps the source layer's sizes and is callable."""
        converted = BitLinear.from_linear(nn.Linear(512, 1024))
        assert converted.in_features == 512
        assert converted.out_features == 1024
        # The converted layer must also run end to end.
        assert converted(torch.randn(16, 512)).shape == (16, 1024)

    def test_parameter_count(self):
        """Registered parameters are exactly W_ternary, gamma and bias."""
        d = 512
        bl = BitLinear(d, d, bias=True)
        # W_ternary contributes d*d; gamma and bias contribute d each.
        expected = d * d + 2 * d
        assert sum(param.numel() for param in bl.parameters()) == expected

    def test_weight_values_are_ternary(self):
        """Every stored weight entry is one of -1, 0 or +1."""
        allowed = {-1.0, 0.0, 1.0}
        observed = set(torch.unique(BitLinear(512, 512).W_ternary).tolist())
        assert observed <= allowed

    def test_gradient_flow(self):
        """Backprop reaches the input and the trainable tensors (required for QAT)."""
        bl = BitLinear(256, 128)
        inputs = torch.randn(8, 256, requires_grad=True)
        bl(inputs).sum().backward()
        for tensor in (inputs, bl.W_ternary, bl.gamma):
            assert tensor.grad is not None
class TestMultiTernaryLinear:
    """Tests for MultiTernaryLinear layer (k stacked ternary components)."""

    def test_initialization(self):
        """Test layer initialization with k components."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        assert layer.in_features == 512
        assert layer.out_features == 1024
        assert layer.k == 4
        # One (out, in) ternary matrix and one per-row scale vector per component.
        assert layer.W_ternary.shape == (4, 1024, 512)
        assert layer.gammas.shape == (4, 1024)

    def test_forward_shape(self):
        """Test forward pass shape."""
        layer = MultiTernaryLinear(512, 1024, k=4)
        x = torch.randn(32, 128, 512)
        output = layer(x)
        assert output.shape == (32, 128, 1024)

    def test_k_components(self):
        """Test that layer uses k ternary components."""
        layer = MultiTernaryLinear(512, 512, k=3)
        assert layer.W_ternary.shape == (3, 512, 512)
        assert layer.gammas.shape == (3, 512)

    def test_from_linear_conversion(self):
        """Test converting nn.Linear to MultiTernaryLinear."""
        linear = nn.Linear(512, 1024)
        multi_ternary = MultiTernaryLinear.from_linear(linear, k=4)
        assert multi_ternary.k == 4
        assert multi_ternary.in_features == 512
        assert multi_ternary.out_features == 1024

    def test_better_approximation_with_more_k(self):
        """Test that larger k gives a strictly smaller error versus the dense layer."""
        # Seed the RNG so the dense weights and inputs (and hence the errors)
        # are reproducible; an unseeded failure of a strict ordering check
        # cannot be replayed or debugged.
        torch.manual_seed(0)
        linear = nn.Linear(512, 512)
        x = torch.randn(16, 512)
        ks = [1, 2, 4]

        # Only forward values are compared, so skip building autograd graphs.
        with torch.no_grad():
            out_dense = linear(x)

        errors = []
        for k in ks:
            multi_ternary = MultiTernaryLinear.from_linear(linear, k=k)
            with torch.no_grad():
                out_ternary = multi_ternary(x)
            # Store plain floats so the failure message is readable.
            errors.append(torch.norm(out_dense - out_ternary).item())

        # Each additional ternary component should reduce the approximation error.
        assert errors[0] > errors[1] > errors[2], dict(zip(ks, errors))
class TestConversionUtilities:
    """Tests for model conversion utilities."""

    def test_convert_simple_model(self):
        """Test converting a simple Sequential model."""
        model = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
        )
        model_bitlinear = convert_linear_to_bitlinear(model, inplace=False)
        # Linear layers are replaced; other modules are kept as-is.
        assert isinstance(model_bitlinear[0], BitLinear)
        assert isinstance(model_bitlinear[2], BitLinear)
        assert isinstance(model_bitlinear[1], nn.ReLU)

    def test_convert_nested_model(self):
        """Test converting a nested model with submodules."""

        class NestedModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.Linear(256, 512)
                self.submodule = nn.Sequential(
                    nn.Linear(512, 512),
                    nn.ReLU(),
                )
                self.layer2 = nn.Linear(512, 128)

        model = NestedModel()
        model_bitlinear = convert_linear_to_bitlinear(model, inplace=False)
        # Every nn.Linear is replaced, at any nesting depth...
        assert isinstance(model_bitlinear.layer1, BitLinear)
        assert isinstance(model_bitlinear.submodule[0], BitLinear)
        assert isinstance(model_bitlinear.layer2, BitLinear)
        # ...while non-Linear modules inside the nested container are preserved.
        assert isinstance(model_bitlinear.submodule[1], nn.ReLU)
        # inplace=False must not leak changes into the original's nested children.
        assert isinstance(model.submodule[0], nn.Linear)

    def test_inplace_conversion(self):
        """Test in-place vs. copy conversion."""
        model = nn.Sequential(nn.Linear(256, 256))
        # inplace=False returns a distinct, converted copy.
        model_copy = convert_linear_to_bitlinear(model, inplace=False)
        assert model_copy is not model
        assert isinstance(model[0], nn.Linear)  # Original unchanged
        assert isinstance(model_copy[0], BitLinear)  # Copy converted

        # inplace=True converts and returns the very same object.
        model2 = nn.Sequential(nn.Linear(256, 256))
        model2_result = convert_linear_to_bitlinear(model2, inplace=True)
        assert model2_result is model2
        assert isinstance(model2[0], BitLinear)  # Original modified
class TestLayerIntegration:
    """Integration tests for layers in realistic scenarios."""

    def test_in_transformer_block(self):
        """Test BitLinear inside a Transformer feed-forward (FFN) block."""

        class TransformerFFN(nn.Module):
            def __init__(self, d_model=256, d_ff=1024):
                super().__init__()
                self.fc1 = BitLinear(d_model, d_ff)
                self.relu = nn.ReLU()
                self.fc2 = BitLinear(d_ff, d_model)
                self.dropout = nn.Dropout(0.1)

            def forward(self, x):
                return self.dropout(self.fc2(self.relu(self.fc1(x))))

        model = TransformerFFN()
        batch_size, seq_len, d_model = 8, 32, 256
        x = torch.randn(batch_size, seq_len, d_model)
        output = model(x)

        # The FFN block maps d_model -> d_ff -> d_model, preserving the shape.
        assert output.shape == (batch_size, seq_len, d_model)
        # Both projections store ternary weights.
        ternary = {-1.0, 0.0, 1.0}
        assert set(model.fc1.W_ternary.unique().tolist()).issubset(ternary)
        assert set(model.fc2.W_ternary.unique().tolist()).issubset(ternary)

    def test_training_step(self):
        """Test that layers work in a training loop."""
        model = nn.Sequential(
            BitLinear(128, 256),
            nn.ReLU(),
            BitLinear(256, 10),
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        x = torch.randn(16, 128)
        target = torch.randint(0, 10, (16,))
        loss = nn.functional.cross_entropy(model(x), target)

        optimizer.zero_grad()
        loss.backward()

        # Gradients must reach every BitLinear layer, not only the first one.
        for idx in (0, 2):
            assert model[idx].W_ternary.grad is not None
            assert model[idx].gamma.grad is not None

        # The optimizer must be able to consume those gradients without error.
        optimizer.step()
        assert torch.isfinite(loss)

    def test_save_and_load(self):
        """Test saving and loading models with BitLinear layers."""
        import os
        import tempfile

        def make_model():
            return nn.Sequential(
                BitLinear(128, 256),
                nn.ReLU(),
                BitLinear(256, 64),
            )

        model = make_model()

        # A temporary directory is removed even if torch.save/torch.load raise.
        # It also avoids re-opening a NamedTemporaryFile that is still held open,
        # which fails on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "model.pt")
            torch.save(model.state_dict(), temp_path)
            model_loaded = make_model()
            # weights_only=True restricts unpickling to tensors/containers. It is
            # the default in newer torch; passing it keeps behavior consistent
            # across versions (assumes BitLinear's state_dict holds only tensors).
            model_loaded.load_state_dict(torch.load(temp_path, weights_only=True))

        # Every BitLinear tensor survives the round trip.
        for idx in (0, 2):
            original, restored = model[idx], model_loaded[idx]
            assert torch.allclose(original.W_ternary, restored.W_ternary)
            assert torch.allclose(original.gamma, restored.gamma)
            assert torch.allclose(original.bias, restored.bias)

        # The restored model computes the same function.
        x = torch.randn(8, 128)
        with torch.no_grad():
            out1 = model(x)
            out2 = model_loaded(x)
        assert torch.allclose(out1, out2)
# Performance comparison tests
class TestPerformanceComparison:
    """Benchmarks comparing BitLinear against a dense nn.Linear (skipped by default)."""

    @pytest.mark.skip(reason="Performance test - run manually")
    def test_memory_usage(self):
        """Compare memory usage of BitLinear vs. nn.Linear.

        TODO: measure memory for large layers; BitLinear is expected to use
        significantly less memory than nn.Linear.
        """

    @pytest.mark.skip(reason="Performance test - run manually")
    def test_inference_speed(self):
        """Compare inference speed (when CUDA kernels are implemented).

        TODO: implement this benchmark.
        """
|