File size: 13,852 Bytes
fd8c8b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
"""

Unit tests for functional API (bitlinear_python, greedy_ternary_decomposition, etc.)



These tests validate the correctness of the pure PyTorch reference implementations. The test cases are as follows:



TestBitLinearPython (5 tests)

    1. test_shape_correctness - Verifies output dimensions for 3D inputs

    2. test_no_bias - Tests forward pass without bias term

    3. test_ternary_constraint - Validates ternary weight values {-1, 0, +1}

    4. test_gamma_scaling - Verifies gamma scaling is applied correctly

    5. test_numerical_correctness - Compares against manual torch computation

    

TestGreedyTernaryDecomposition (4 tests)

    1. test_decomposition_shape - Checks output tensor shapes

    2. test_ternary_values - Ensures all decomposed weights are ternary

    3. test_reconstruction_error - Validates error decreases with more components

    4. test_single_component - Tests k=1 edge case

    

TestMultiTernaryLinearPython (2 tests)

    1. test_shape_correctness - Verifies output shape

    2. test_equivalence_to_sum - Confirms equivalence to summing individual operations

    

TestActivationQuant (2 tests)

    1. test_quantization_range - Validates quantization behavior and output

    2. test_absmax_scaling - Tests per-token absmax scaling

    

TestFunctionalIntegration (3 tests)

    1. test_full_pipeline - End-to-end: decomposition → multi-ternary forward

    2. test_bitlinear_with_activation_quant - Combines activation quantization with bitlinear

    3. test_multi_ternary_end_to_end - Tests different k values with reconstruction validation

"""

import pytest
import torch
import torch.nn as nn

from bitlinear.functional import (
    bitlinear_python,
    greedy_ternary_decomposition,
    multi_ternary_linear_python,
    activation_quant,
)


class TestBitLinearPython:
    """Tests for bitlinear_python function."""

    def test_shape_correctness(self):
        """A 3D input must produce a 3D output with out_features last."""
        batch, seq, d_in, d_out = 32, 128, 512, 1024
        activations = torch.randn(batch, seq, d_in)
        weights = torch.randint(-1, 2, (d_out, d_in)).float()
        scales = torch.ones(d_out)
        offsets = torch.zeros(d_out)

        result = bitlinear_python(activations, weights, scales, offsets)

        assert result.shape == (batch, seq, d_out)

    def test_no_bias(self):
        """Passing bias=None must still yield a finite, well-shaped output."""
        batch, d_in, d_out = 16, 256, 512
        activations = torch.randn(batch, d_in)
        weights = torch.randint(-1, 2, (d_out, d_in)).float()
        scales = torch.ones(d_out)

        result = bitlinear_python(activations, weights, scales, bias=None)

        assert result.shape == (batch, d_out)
        assert not torch.isnan(result).any()

    def test_ternary_constraint(self):
        """The function must behave correctly on weights drawn from {-1, 0, +1}."""
        activations = torch.randn(8, 64)
        weights = torch.randint(-1, 2, (128, 64)).float()
        scales = torch.ones(128)

        # Sanity-check the fixture: the weights really are ternary.
        assert all(v in [-1.0, 0.0, 1.0] for v in torch.unique(weights).tolist())

        result = bitlinear_python(activations, weights, scales)
        assert result.shape == (8, 128)
        assert not torch.isnan(result).any()

    def test_gamma_scaling(self):
        """A gamma-scaled pass must equal an unscaled pass multiplied by gamma."""
        activations = torch.randn(4, 32)
        weights = torch.randint(-1, 2, (64, 32)).float()
        scales = torch.rand(64) * 2 + 0.5  # random per-row scales in [0.5, 2.5)

        scaled = bitlinear_python(activations, weights, scales, bias=None)

        # Run with gamma=1 and apply the scales by hand afterwards.
        unscaled = bitlinear_python(
            activations, weights, torch.ones_like(scales), bias=None
        )
        manually_scaled = unscaled * scales.unsqueeze(0)

        assert torch.allclose(scaled, manually_scaled, atol=1e-5)

    def test_numerical_correctness(self):
        """Result must match an explicit matmul-scale-bias computation."""
        d_in, d_out = 128, 256
        activations = torch.randn(16, d_in)
        weights = torch.randint(-1, 2, (d_out, d_in)).float()
        scales = torch.ones(d_out)
        offsets = torch.randn(d_out)

        result = bitlinear_python(activations, weights, scales, offsets)

        reference = torch.matmul(activations, weights.t()) * scales.unsqueeze(0) + offsets

        assert torch.allclose(result, reference, atol=1e-6)


class TestGreedyTernaryDecomposition:
    """Tests for greedy_ternary_decomposition function."""

    def test_decomposition_shape(self):
        """Decomposing an (out, in) matrix yields (k, out, in) and (k, out)."""
        dense = torch.randn(512, 768)
        components, scales = greedy_ternary_decomposition(dense, 4)

        assert components.shape == (4, 512, 768)
        assert scales.shape == (4, 512)

    def test_ternary_values(self):
        """Every decomposed component must contain only {-1, 0, +1}."""
        dense = torch.randn(64, 128)
        components, _ = greedy_ternary_decomposition(dense, 2)

        observed = torch.unique(components)
        assert all(v in [-1.0, 0.0, 1.0] for v in observed.tolist()), \
            f"Found non-ternary values: {observed.tolist()}"

    def test_reconstruction_error(self):
        """Reconstruction error must strictly shrink as k grows."""
        dense = torch.randn(128, 256)
        errors = []

        for k in [1, 2, 4, 8]:
            components, scales = greedy_ternary_decomposition(dense, k)

            # Rebuild the dense matrix as sum_i gamma_i * W_i.
            rebuilt = torch.zeros_like(dense)
            for i in range(k):
                rebuilt += scales[i].unsqueeze(1) * components[i]

            errors.append(torch.norm(dense - rebuilt).item())

        # Each additional component must reduce the residual norm.
        for worse, better in zip(errors, errors[1:]):
            assert worse > better, f"Error not decreasing: {worse} vs {better}"

    def test_single_component(self):
        """k=1 degenerates to a single ternary quantization."""
        dense = torch.randn(32, 64)
        components, scales = greedy_ternary_decomposition(dense, 1)

        assert components.shape == (1, 32, 64)
        assert scales.shape == (1, 32)

        # The lone component must still be ternary.
        observed = torch.unique(components)
        assert all(v in [-1.0, 0.0, 1.0] for v in observed.tolist())


class TestMultiTernaryLinearPython:
    """Tests for multi_ternary_linear_python function."""

    def test_shape_correctness(self):
        """Output must be (batch, out_features) regardless of k."""
        batch, d_in, d_out = 16, 128, 256
        k = 4

        activations = torch.randn(batch, d_in)
        components = torch.randint(-1, 2, (k, d_out, d_in)).float()
        scales = torch.rand(k, d_out)
        offsets = torch.randn(d_out)

        result = multi_ternary_linear_python(activations, components, scales, offsets)

        assert result.shape == (batch, d_out)

    def test_equivalence_to_sum(self):
        """One fused call must equal k separate bitlinear calls plus the bias."""
        batch, d_in, d_out = 8, 64, 128
        k = 3

        activations = torch.randn(batch, d_in)
        components = torch.randint(-1, 2, (k, d_out, d_in)).float()
        scales = torch.rand(k, d_out)
        offsets = torch.randn(d_out)

        fused = multi_ternary_linear_python(activations, components, scales, offsets)

        # Accumulate k bias-free single-component passes, then add the bias once.
        accumulated = torch.zeros(batch, d_out)
        for i in range(k):
            accumulated += bitlinear_python(activations, components[i], scales[i], bias=None)
        accumulated += offsets

        assert torch.allclose(fused, accumulated, atol=1e-5)


class TestActivationQuant:
    """Tests for activation quantization."""

    def test_quantization_range(self):
        """Quantization preserves shape, changes values, and stays finite."""
        activations = torch.randn(16, 128, 512) * 10  # deliberately wide range

        quantized = activation_quant(activations, bits=8)

        assert quantized.shape == activations.shape

        # Quantization must actually lose some precision.
        assert not torch.allclose(activations, quantized, atol=1e-6)

        assert torch.isfinite(quantized).all()

    def test_absmax_scaling(self):
        """Per-token absmax scaling keeps 8-bit quantization close to the input."""
        # Row absmax values are 4.0 and 10.0 respectively.
        activations = torch.tensor([
            [1.0, 2.0, 3.0, 4.0],
            [-5.0, -10.0, 5.0, 10.0],
        ])

        quantized = activation_quant(activations, bits=8)

        assert quantized.shape == (2, 4)
        assert torch.isfinite(quantized).all()

        # With 127 levels, the mean relative error should stay under 10%.
        rel_err = torch.abs(activations - quantized) / (torch.abs(activations) + 1e-5)
        assert rel_err.mean() < 0.1


# Integration test
class TestFunctionalIntegration:
    """Integration tests combining multiple functional components."""

    def test_full_pipeline(self):
        """End-to-end: dense weights -> greedy decomposition -> fused forward."""
        d_in, d_out = 256, 512
        dense = torch.randn(d_out, d_in)

        components, scales = greedy_ternary_decomposition(dense, 4)

        batch = 16
        activations = torch.randn(batch, d_in)
        offsets = torch.randn(d_out)

        result = multi_ternary_linear_python(activations, components, scales, offsets)

        assert result.shape == (batch, d_out)
        assert torch.isfinite(result).all()

        # The quantized path should approximate the dense linear layer.
        dense_result = torch.matmul(activations, dense.t()) + offsets

        # Similar but not identical, because of quantization loss.
        rel_err = torch.norm(result - dense_result) / torch.norm(dense_result)
        assert rel_err < 1.0  # loose bound: approximation, not equality

    def test_bitlinear_with_activation_quant(self):
        """Activation quantization composed with bitlinear stays well-formed."""
        batch, d_in, d_out = 8, 128, 256

        activations = torch.randn(batch, d_in)
        weights = torch.randint(-1, 2, (d_out, d_in)).float()
        scales = torch.ones(d_out)

        quantized = activation_quant(activations, bits=8)

        result = bitlinear_python(quantized, weights, scales)

        assert result.shape == (batch, d_out)
        assert torch.isfinite(result).all()

    def test_multi_ternary_end_to_end(self):
        """Forward pass must match matmul against the reconstructed weights."""
        dense = torch.randn(64, 128) * 0.1  # small weights for numerical stability
        activations = torch.randn(4, 128)

        for k in [1, 2, 4]:
            components, scales = greedy_ternary_decomposition(dense, k)
            result = multi_ternary_linear_python(activations, components, scales, bias=None)

            assert result.shape == (4, 64)
            assert torch.isfinite(result).all()

            # Rebuild the dense matrix from the decomposition...
            rebuilt = torch.zeros_like(dense)
            for i in range(k):
                rebuilt += scales[i].unsqueeze(1) * components[i]

            # ...and the forward pass must match it closely.
            expected = torch.matmul(activations, rebuilt.t())
            assert torch.allclose(result, expected, atol=1e-4)