File size: 7,166 Bytes
d4ec3e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""
Verification tests for TurboQuant implementation.

1. Codebook: Lloyd-Max centroids match paper's distortion bounds
2. Packing: uint4 pack/unpack round-trip
3. Quantizer: MSE on random unit vectors ≤ paper's bound (0.009 at 4-bit)
4. Fixed-point: double quantization stability
"""

import sys
sys.path.insert(0, "/home/azureuser/turboquant")

import torch
import numpy as np

def test_codebook():
    """Verify Lloyd-Max codebook computation against the paper's distortion bounds.

    For each bit width b, computes the per-coordinate Lloyd-Max MSE and
    checks that the total MSE (d × per-coordinate) stays within the
    theoretical bound D_mse ≤ (√3·π/2) · 4^-b, with 1% numeric tolerance.

    Raises:
        AssertionError: if any bit width exceeds its distortion bound.
    """
    from turboquant.codebook import compute_lloyd_max_codebook, compute_distortion

    print("=" * 60)
    print("TEST: Codebook computation")
    print("=" * 60)

    d = 128
    # Paper bounds: D_mse ≤ (√3·π/2) · (1/4^b)
    # Per-coordinate: D_mse / d = (√3·π / 2d) · (1/4^b)
    paper_total_mse = {2: 0.117, 3: 0.03, 4: 0.009}

    for bits in [2, 3, 4]:
        centroids, boundaries = compute_lloyd_max_codebook(d, bits)
        per_coord_mse = compute_distortion(d, bits, centroids, boundaries)
        total_mse = d * per_coord_mse
        bound = (np.sqrt(3) * np.pi / 2) * (1 / 4**bits)

        print(f"\n  b={bits} ({2**bits} levels):")
        print(f"    Centroids:         {centroids[:4]} ... {centroids[-4:]}")
        print(f"    Per-coord MSE:     {per_coord_mse:.6e}")
        print(f"    Total MSE (d×per): {total_mse:.6f}")
        print(f"    Paper bound:       {bound:.6f}")
        print(f"    Paper table value: {paper_total_mse.get(bits, 'N/A')}")
        print(f"    Within bound:      {total_mse <= bound * 1.01}")  # 1% tolerance for numerics

        # BUG FIX: the bound was only printed, so a violation still reported
        # PASS. Enforce it so the test actually fails on regression.
        assert total_mse <= bound * 1.01, (
            f"b={bits}: total MSE {total_mse:.6f} exceeds bound {bound:.6f}"
        )

    print("\n  PASS: Codebook computation verified\n")


def test_packing():
    """Verify uint4 and uint2 pack/unpack round-trip.

    Packs random code tensors at both bit widths and asserts the
    unpacked result is bit-exact against the input.
    """
    from turboquant.packing import pack_uint4, unpack_uint4, pack_uint2, unpack_uint2

    print("=" * 60)
    print("TEST: Bit packing round-trip")
    print("=" * 60)

    # (label, number of levels, pack fn, unpack fn) — one row per bit width.
    cases = [
        ("uint4", 16, pack_uint4, unpack_uint4),
        ("uint2", 4, pack_uint2, unpack_uint2),
    ]
    for label, n_levels, pack_fn, unpack_fn in cases:
        codes = torch.randint(0, n_levels, (4, 8, 128), dtype=torch.uint8)
        packed = pack_fn(codes)
        restored = unpack_fn(packed)
        assert torch.equal(codes, restored), f"{label} round-trip FAILED"
        print(f"  {label}: {codes.shape}{packed.shape}{restored.shape} ✓")

    print("\n  PASS: Packing round-trip verified\n")


def test_quantizer_mse():
    """Verify quantize→dequantize MSE matches paper's theoretical bounds.

    Draws random unit vectors on S^(d-1), round-trips them through the
    quantizer at 2 and 4 bits, and asserts that the mean per-vector
    squared reconstruction error stays within the paper's bound
    (10% tolerance for sampling noise and BF16 rounding).

    Raises:
        AssertionError: if the empirical MSE exceeds the bound at any bit width.
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Quantizer MSE on random unit vectors")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim = 128
    n_vectors = 10000
    paper_bounds = {2: 0.117, 4: 0.009}

    for bits in [2, 4]:
        quantizer = TurboQuantizer(dim=dim, bits=bits, device=device, seed=42)

        # Generate random unit vectors on S^(d-1)
        x = torch.randn(n_vectors, dim, device=device)
        x = x / x.norm(dim=-1, keepdim=True)
        x_bf16 = x.bfloat16()

        # Quantize and dequantize
        packed, norms = quantizer.quantize(x_bf16)
        x_recon = quantizer.dequantize(packed, norms)

        # Per-vector squared L2 error; compute once and reuse for both the
        # mean MSE and the distribution stats (was computed twice before).
        per_vec_mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1)
        mse = per_vec_mse.mean().item()
        bound = paper_bounds[bits]

        print(f"\n  b={bits}:")
        print(f"    Vectors tested:  {n_vectors}")
        print(f"    Empirical MSE:   {mse:.6f}")
        print(f"    Paper bound:     {bound:.6f}")
        print(f"    Ratio (emp/bnd): {mse/bound:.3f}")
        print(f"    Within bound:    {mse <= bound * 1.1}")  # 10% tolerance

        # Also check individual vector MSE distribution
        print(f"    MSE p50/p95/max: {per_vec_mse.median():.6f} / "
              f"{per_vec_mse.quantile(0.95):.6f} / {per_vec_mse.max():.6f}")

        # BUG FIX: the bound was only printed, so a violation still reported
        # PASS. Enforce it so the test actually fails on regression.
        assert mse <= bound * 1.1, (
            f"b={bits}: empirical MSE {mse:.6f} exceeds bound {bound:.6f}"
        )

    print("\n  PASS: MSE within theoretical bounds\n")


def test_quantizer_shapes():
    """Verify correct tensor shapes through quantize/dequantize.

    Round-trips a 4-D KV-cache-shaped BF16 tensor and asserts shape and
    dtype are preserved; also reports the achieved compression ratio.
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Tensor shapes (simulating KV cache)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    head_dim = 128
    quantizer = TurboQuantizer(dim=head_dim, bits=4, device=device, seed=0)

    # Simulate KV cache tensor: (batch, heads, seq_len, head_dim)
    cache = torch.randn(2, 8, 1024, head_dim, device=device, dtype=torch.bfloat16)

    packed, norms = quantizer.quantize(cache)
    restored = quantizer.dequantize(packed, norms)

    print(f"  Input:  {cache.shape} {cache.dtype}")
    print(f"  Packed: {packed.shape} {packed.dtype}")
    print(f"  Norms:  {norms.shape} {norms.dtype}")
    print(f"  Recon:  {restored.shape} {restored.dtype}")
    print(f"  Shape match: {cache.shape == restored.shape}")
    print(f"  Dtype match: {cache.dtype == restored.dtype}")

    # Memory savings: BF16 input (2 bytes/elem) vs uint8 codes + BF16 norms.
    bytes_before = cache.numel() * 2
    bytes_after = packed.numel() + norms.numel() * 2
    print(f"\n  Original:    {bytes_before / 1024:.1f} KB")
    print(f"  Quantized:   {bytes_after / 1024:.1f} KB")
    print(f"  Compression: {bytes_before / bytes_after:.2f}x")

    assert cache.shape == restored.shape, "Shape mismatch!"
    assert cache.dtype == restored.dtype, "Dtype mismatch!"
    print("\n  PASS: Shapes and dtypes correct\n")


def test_fixed_point():
    """Verify that quantize→dequantize→requantize→dequantize is stable.

    A quantizer reconstruction should be (near) a fixed point: quantizing
    it again should reproduce the same packed indices. Mismatches are
    reported as a warning (FP rounding can flip codes at cell boundaries)
    rather than a hard failure.
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Double quantization stability (fixed-point)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    quantizer = TurboQuantizer(dim=128, bits=4, device=device, seed=42)

    original = torch.randn(100, 128, device=device, dtype=torch.bfloat16)

    # Round-trip once, then round-trip the reconstruction itself.
    codes_a, norms_a = quantizer.quantize(original)
    recon_a = quantizer.dequantize(codes_a, norms_a)
    codes_b, norms_b = quantizer.quantize(recon_a)
    recon_b = quantizer.dequantize(codes_b, norms_b)

    # A true fixed point reproduces the packed indices exactly.
    same_codes = torch.equal(codes_a, codes_b)
    max_recon_gap = (recon_a.float() - recon_b.float()).abs().max().item()
    max_norm_gap = (norms_a.float() - norms_b.float()).abs().max().item()

    print(f"  Packed indices identical: {same_codes}")
    print(f"  Max reconstruction diff:  {max_recon_gap:.2e}")
    print(f"  Norm diff (max):          {max_norm_gap:.2e}")

    if not same_codes:
        mismatched = (codes_a != codes_b).sum().item()
        print(f"  WARNING: {mismatched} packed bytes differ (FP rounding at boundaries)")

    print("\n  PASS: Double quantization stable\n")


if __name__ == "__main__":
    # Run every verification test in order; any failed assert aborts the run.
    for run_test in (
        test_codebook,
        test_packing,
        test_quantizer_mse,
        test_quantizer_shapes,
        test_fixed_point,
    ):
        run_test()
    print("=" * 60)
    print("ALL TESTS PASSED")
    print("=" * 60)