| """ |
| Verification tests for TurboQuant implementation. |
| |
| 1. Codebook: Lloyd-Max centroids match paper's distortion bounds |
| 2. Packing: uint4 pack/unpack round-trip |
| 3. Quantizer: MSE on random unit vectors ≤ paper's bound (0.009 at 4-bit) |
| 4. Fixed-point: double quantization stability |
| """ |
|
|
| import sys |
| sys.path.insert(0, "/home/azureuser/turboquant") |
|
|
| import torch |
| import numpy as np |
|
|
def test_codebook():
    """Verify Lloyd-Max codebook computation and distortion bounds.

    For each bit width b in {2, 3, 4}, computes the Lloyd-Max codebook,
    derives total MSE as d x per-coordinate MSE, and checks it against
    the theoretical bound (sqrt(3)*pi/2) * 4**-b.

    Raises:
        AssertionError: if the total MSE exceeds the bound (with 1% slack).
    """
    from turboquant.codebook import compute_lloyd_max_codebook, compute_distortion

    print("=" * 60)
    print("TEST: Codebook computation")
    print("=" * 60)

    d = 128  # per-vector dimension used throughout this script

    # Reference total-MSE values from the paper's table (for display only).
    paper_total_mse = {2: 0.117, 3: 0.03, 4: 0.009}

    for bits in [2, 3, 4]:
        centroids, boundaries = compute_lloyd_max_codebook(d, bits)
        per_coord_mse = compute_distortion(d, bits, centroids, boundaries)
        total_mse = d * per_coord_mse
        # Theoretical distortion bound: (sqrt(3)*pi/2) * 4^{-b}.
        bound = (np.sqrt(3) * np.pi / 2) * (1 / 4**bits)

        # 1% slack tolerates numerical error in the Lloyd-Max iteration.
        within = total_mse <= bound * 1.01

        print(f"\n b={bits} ({2**bits} levels):")
        print(f" Centroids: {centroids[:4]} ... {centroids[-4:]}")
        print(f" Per-coord MSE: {per_coord_mse:.6e}")
        print(f" Total MSE (d×per): {total_mse:.6f}")
        print(f" Paper bound: {bound:.6f}")
        print(f" Paper table value: {paper_total_mse.get(bits, 'N/A')}")
        print(f" Within bound: {within}")

        # BUG FIX: previously PASS was printed unconditionally, even when
        # the distortion bound check above failed. Now enforce it.
        assert within, (
            f"b={bits}: total MSE {total_mse:.6f} exceeds bound {bound:.6f}"
        )

    print("\n PASS: Codebook computation verified\n")
|
|
|
|
def test_packing():
    """Verify uint4 and uint2 pack/unpack round-trip."""
    from turboquant.packing import pack_uint4, unpack_uint4, pack_uint2, unpack_uint2

    print("=" * 60)
    print("TEST: Bit packing round-trip")
    print("=" * 60)

    # 4-bit path: values in [0, 16), two codes per packed byte.
    codes4 = torch.randint(0, 16, (4, 8, 128), dtype=torch.uint8)
    compressed4 = pack_uint4(codes4)
    restored4 = unpack_uint4(compressed4)
    assert torch.equal(codes4, restored4), "uint4 round-trip FAILED"
    print(f" uint4: {codes4.shape} → {compressed4.shape} → {restored4.shape} ✓")

    # 2-bit path: values in [0, 4), four codes per packed byte.
    codes2 = torch.randint(0, 4, (4, 8, 128), dtype=torch.uint8)
    compressed2 = pack_uint2(codes2)
    restored2 = unpack_uint2(compressed2)
    assert torch.equal(codes2, restored2), "uint2 round-trip FAILED"
    print(f" uint2: {codes2.shape} → {compressed2.shape} → {restored2.shape} ✓")

    print("\n PASS: Packing round-trip verified\n")
|
|
|
|
def test_quantizer_mse():
    """Verify quantize→dequantize MSE matches paper's theoretical bounds.

    Quantizes random unit vectors (the regime the paper's bounds assume),
    reconstructs them, and checks the mean squared reconstruction error
    against the paper's per-bit-width bound.

    Raises:
        AssertionError: if empirical MSE exceeds the bound (with 10% slack
            for finite-sample noise).
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Quantizer MSE on random unit vectors")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim = 128
    n_vectors = 10000
    paper_bounds = {2: 0.117, 4: 0.009}

    for bits in [2, 4]:
        quantizer = TurboQuantizer(dim=dim, bits=bits, device=device, seed=42)

        # Random unit vectors; normalize so norms are exactly 1.
        x = torch.randn(n_vectors, dim, device=device)
        x = x / x.norm(dim=-1, keepdim=True)
        x_bf16 = x.bfloat16()

        packed, norms = quantizer.quantize(x_bf16)
        x_recon = quantizer.dequantize(packed, norms)

        # Per-vector squared error, computed ONCE and reused below
        # (the original rebuilt the full error tensor a second time
        # just for the percentile statistics).
        per_vec_mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1)
        mse = per_vec_mse.mean().item()
        bound = paper_bounds[bits]

        # 10% slack tolerates finite-sample noise in the empirical mean.
        within = mse <= bound * 1.1

        print(f"\n b={bits}:")
        print(f" Vectors tested: {n_vectors}")
        print(f" Empirical MSE: {mse:.6f}")
        print(f" Paper bound: {bound:.6f}")
        print(f" Ratio (emp/bnd): {mse/bound:.3f}")
        print(f" Within bound: {within}")

        print(f" MSE p50/p95/max: {per_vec_mse.median():.6f} / "
              f"{per_vec_mse.quantile(0.95):.6f} / {per_vec_mse.max():.6f}")

        # BUG FIX: previously PASS was printed even when the bound check
        # above was False. Now enforce it.
        assert within, (
            f"b={bits}: empirical MSE {mse:.6f} > bound {bound:.6f} * 1.1"
        )

    print("\n PASS: MSE within theoretical bounds\n")
|
|
|
|
def test_quantizer_shapes():
    """Verify correct tensor shapes through quantize/dequantize."""
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Tensor shapes (simulating KV cache)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim = 128
    quantizer = TurboQuantizer(dim=dim, bits=4, device=device, seed=0)

    # KV-cache-shaped input: (batch, heads, seq_len, head_dim) in bf16.
    batch, heads, seq_len = 2, 8, 1024
    kv = torch.randn(batch, heads, seq_len, dim, device=device, dtype=torch.bfloat16)

    packed, norms = quantizer.quantize(kv)
    restored = quantizer.dequantize(packed, norms)

    print(f" Input: {kv.shape} {kv.dtype}")
    print(f" Packed: {packed.shape} {packed.dtype}")
    print(f" Norms: {norms.shape} {norms.dtype}")
    print(f" Recon: {restored.shape} {restored.dtype}")
    print(f" Shape match: {kv.shape == restored.shape}")
    print(f" Dtype match: {kv.dtype == restored.dtype}")

    # Compression accounting: bf16 input is 2 bytes/element; the packed
    # codes are 1 byte/element plus 2 bytes per stored norm.
    original_bytes = kv.numel() * 2
    quant_bytes = packed.numel() * 1 + norms.numel() * 2
    ratio = original_bytes / quant_bytes
    print(f"\n Original: {original_bytes / 1024:.1f} KB")
    print(f" Quantized: {quant_bytes / 1024:.1f} KB")
    print(f" Compression: {ratio:.2f}x")

    assert kv.shape == restored.shape, "Shape mismatch!"
    assert kv.dtype == restored.dtype, "Dtype mismatch!"
    print("\n PASS: Shapes and dtypes correct\n")
|
|
|
|
def test_fixed_point():
    """Verify that quantize→dequantize→requantize→dequantize is stable."""
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Double quantization stability (fixed-point)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    quantizer = TurboQuantizer(dim=128, bits=4, device=device, seed=42)

    original = torch.randn(100, 128, device=device, dtype=torch.bfloat16)

    # First pass: quantize the raw input and reconstruct.
    codes_a, norms_a = quantizer.quantize(original)
    recon_a = quantizer.dequantize(codes_a, norms_a)

    # Second pass: re-quantize the reconstruction. If quantization is a
    # fixed point, the codes and reconstruction should not drift.
    codes_b, norms_b = quantizer.quantize(recon_a)
    recon_b = quantizer.dequantize(codes_b, norms_b)

    indices_match = torch.equal(codes_a, codes_b)
    recon_diff = (recon_a.float() - recon_b.float()).abs().max().item()

    print(f" Packed indices identical: {indices_match}")
    print(f" Max reconstruction diff: {recon_diff:.2e}")
    print(f" Norm diff (max): {(norms_a.float() - norms_b.float()).abs().max().item():.2e}")

    # A few differing bytes are tolerated: values that land exactly on a
    # decision boundary can flip under FP rounding. Warn, don't fail.
    if not indices_match:
        n_diff = (codes_a != codes_b).sum().item()
        print(f" WARNING: {n_diff} packed bytes differ (FP rounding at boundaries)")

    print("\n PASS: Double quantization stable\n")
|
|
|
|
if __name__ == "__main__":
    # Run every verification in order; any failure raises and aborts.
    for check in (
        test_codebook,
        test_packing,
        test_quantizer_mse,
        test_quantizer_shapes,
        test_fixed_point,
    ):
        check()
    print("=" * 60)
    print("ALL TESTS PASSED")
    print("=" * 60)
|
|