"""
Verification tests for TurboQuant implementation.

1. Codebook: Lloyd-Max centroids match paper's distortion bounds
2. Packing: uint4 pack/unpack round-trip
3. Quantizer: MSE on random unit vectors ≤ paper's bound (0.009 at 4-bit)
4. Fixed-point: double quantization stability

Each test prints its diagnostics first and asserts afterwards, so a failing
run still shows the measured numbers that led to the failure.
"""

import sys
sys.path.insert(0, "/home/azureuser/turboquant")

import torch
import numpy as np


def test_codebook() -> None:
    """Verify Lloyd-Max codebook computation and distortion bounds.

    For b in {2, 3, 4} the codebook is computed for d = 128, the
    per-coordinate distortion is scaled by d to a total MSE, and that total
    is compared against the bound (√3·π/2) · (1/4^b) with a 1% tolerance.

    Raises:
        AssertionError: if the total MSE exceeds the tolerated bound.
    """
    from turboquant.codebook import compute_lloyd_max_codebook, compute_distortion

    print("=" * 60)
    print("TEST: Codebook computation")
    print("=" * 60)

    d = 128

    # Paper bounds: D_mse ≤ (√3·π/2) · (1/4^b)
    # Per-coordinate: D_mse / d = (√3·π / 2d) · (1/4^b)
    # Values below are printed for reference only; the assertion uses the
    # closed-form bound.
    paper_total_mse = {2: 0.117, 3: 0.03, 4: 0.009}

    for bits in [2, 3, 4]:
        centroids, boundaries = compute_lloyd_max_codebook(d, bits)
        per_coord_mse = compute_distortion(d, bits, centroids, boundaries)
        total_mse = d * per_coord_mse
        bound = (np.sqrt(3) * np.pi / 2) * (1 / 4**bits)
        within_bound = total_mse <= bound * 1.01  # 1% tolerance for numerics

        print(f"\n b={bits} ({2**bits} levels):")
        print(f" Centroids: {centroids[:4]} ... {centroids[-4:]}")
        print(f" Per-coord MSE: {per_coord_mse:.6e}")
        print(f" Total MSE (d×per): {total_mse:.6f}")
        print(f" Paper bound: {bound:.6f}")
        print(f" Paper table value: {paper_total_mse.get(bits, 'N/A')}")
        print(f" Within bound: {within_bound}")

        # Previously this was only printed, so the test could never fail.
        assert within_bound, (
            f"b={bits}: total MSE {total_mse:.6f} exceeds paper bound {bound:.6f}"
        )

    print("\n PASS: Codebook computation verified\n")


def test_packing() -> None:
    """Verify uint4 and uint2 pack/unpack round-trip.

    Random index tensors of shape (4, 8, 128) are packed and unpacked; the
    result must be element-wise equal to the input.

    Raises:
        AssertionError: if either round-trip does not reproduce its input.
    """
    from turboquant.packing import pack_uint4, unpack_uint4, pack_uint2, unpack_uint2

    print("=" * 60)
    print("TEST: Bit packing round-trip")
    print("=" * 60)

    # uint4: values in [0, 16)
    x4 = torch.randint(0, 16, (4, 8, 128), dtype=torch.uint8)
    packed4 = pack_uint4(x4)
    unpacked4 = unpack_uint4(packed4)
    assert torch.equal(x4, unpacked4), "uint4 round-trip FAILED"
    print(f" uint4: {x4.shape} → {packed4.shape} → {unpacked4.shape} ✓")

    # uint2: values in [0, 4)
    x2 = torch.randint(0, 4, (4, 8, 128), dtype=torch.uint8)
    packed2 = pack_uint2(x2)
    unpacked2 = unpack_uint2(packed2)
    assert torch.equal(x2, unpacked2), "uint2 round-trip FAILED"
    print(f" uint2: {x2.shape} → {packed2.shape} → {unpacked2.shape} ✓")

    print("\n PASS: Packing round-trip verified\n")


def test_quantizer_mse() -> None:
    """Verify quantize→dequantize MSE matches paper's theoretical bounds.

    For b in {2, 4}, 10000 random unit vectors of dimension 128 are cast to
    bfloat16, quantized and dequantized. The mean squared reconstruction
    error (summed over coordinates) must stay within 10% of the paper value.

    Raises:
        AssertionError: if the empirical MSE exceeds the tolerated bound.
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Quantizer MSE on random unit vectors")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim = 128
    n_vectors = 10000
    paper_bounds = {2: 0.117, 4: 0.009}

    for bits in [2, 4]:
        quantizer = TurboQuantizer(dim=dim, bits=bits, device=device, seed=42)

        # Generate random unit vectors on S^(d-1)
        x = torch.randn(n_vectors, dim, device=device)
        x = x / x.norm(dim=-1, keepdim=True)
        x_bf16 = x.bfloat16()

        # Quantize and dequantize
        packed, norms = quantizer.quantize(x_bf16)
        x_recon = quantizer.dequantize(packed, norms)

        # Compute MSE
        mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1).mean().item()
        bound = paper_bounds[bits]
        within_bound = mse <= bound * 1.1  # 10% tolerance

        print(f"\n b={bits}:")
        print(f" Vectors tested: {n_vectors}")
        print(f" Empirical MSE: {mse:.6f}")
        print(f" Paper bound: {bound:.6f}")
        print(f" Ratio (emp/bnd): {mse/bound:.3f}")
        print(f" Within bound: {within_bound}")

        # Also check individual vector MSE distribution
        per_vec_mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1)
        print(f" MSE p50/p95/max: {per_vec_mse.median():.6f} / "
              f"{per_vec_mse.quantile(0.95):.6f} / {per_vec_mse.max():.6f}")

        # Assert only after all diagnostics are printed. Previously the
        # comparison was printed but never enforced.
        assert within_bound, (
            f"b={bits}: empirical MSE {mse:.6f} exceeds "
            f"paper bound {bound:.6f} (+10% tolerance)"
        )

    print("\n PASS: MSE within theoretical bounds\n")


def test_quantizer_shapes() -> None:
    """Verify correct tensor shapes through quantize/dequantize.

    A (batch, heads, seq_len, head_dim) bfloat16 tensor is quantized at
    4 bits and dequantized; the reconstruction must have the same shape and
    dtype as the input. The achieved compression ratio is printed.

    Raises:
        AssertionError: on a shape or dtype mismatch.
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Tensor shapes (simulating KV cache)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim = 128
    quantizer = TurboQuantizer(dim=dim, bits=4, device=device, seed=0)

    # Simulate KV cache tensor: (batch, heads, seq_len, head_dim)
    batch, heads, seq_len = 2, 8, 1024
    x = torch.randn(batch, heads, seq_len, dim, device=device, dtype=torch.bfloat16)

    packed, norms = quantizer.quantize(x)
    x_recon = quantizer.dequantize(packed, norms)

    print(f" Input: {x.shape} {x.dtype}")
    print(f" Packed: {packed.shape} {packed.dtype}")
    print(f" Norms: {norms.shape} {norms.dtype}")
    print(f" Recon: {x_recon.shape} {x_recon.dtype}")
    print(f" Shape match: {x.shape == x_recon.shape}")
    print(f" Dtype match: {x.dtype == x_recon.dtype}")

    # Memory savings
    original_bytes = x.numel() * 2  # BF16 = 2 bytes
    quant_bytes = packed.numel() * 1 + norms.numel() * 2  # uint8 + BF16 norms
    ratio = original_bytes / quant_bytes
    print(f"\n Original: {original_bytes / 1024:.1f} KB")
    print(f" Quantized: {quant_bytes / 1024:.1f} KB")
    print(f" Compression: {ratio:.2f}x")

    assert x.shape == x_recon.shape, "Shape mismatch!"
    assert x.dtype == x_recon.dtype, "Dtype mismatch!"

    print("\n PASS: Shapes and dtypes correct\n")


def test_fixed_point() -> None:
    """Verify that quantize→dequantize→requantize→dequantize is stable.

    The reconstruction of a first quantization round is quantized again and
    the two rounds are compared. Differences in packed indices are reported
    as a warning rather than a failure, since values lying on a decision
    boundary may flip under floating-point rounding.

    Raises:
        AssertionError: if the reconstruction difference is NaN or infinite.
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Double quantization stability (fixed-point)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    quantizer = TurboQuantizer(dim=128, bits=4, device=device, seed=42)

    x = torch.randn(100, 128, device=device, dtype=torch.bfloat16)

    # First round
    packed1, norms1 = quantizer.quantize(x)
    x_recon1 = quantizer.dequantize(packed1, norms1)

    # Second round (re-quantize the reconstruction)
    packed2, norms2 = quantizer.quantize(x_recon1)
    x_recon2 = quantizer.dequantize(packed2, norms2)

    # Check packed indices are identical
    indices_match = torch.equal(packed1, packed2)
    recon_diff = (x_recon1.float() - x_recon2.float()).abs().max().item()

    print(f" Packed indices identical: {indices_match}")
    print(f" Max reconstruction diff: {recon_diff:.2e}")
    print(f" Norm diff (max): {(norms1.float() - norms2.float()).abs().max().item():.2e}")

    if not indices_match:
        n_diff = (packed1 != packed2).sum().item()
        print(f" WARNING: {n_diff} packed bytes differ (FP rounding at boundaries)")

    # A NaN/Inf difference means a reconstruction is broken; that must not be
    # reported as a pass.
    assert np.isfinite(recon_diff), "Reconstruction diff is not finite"

    print("\n PASS: Double quantization stable\n")


if __name__ == "__main__":
    test_codebook()
    test_packing()
    test_quantizer_mse()
    test_quantizer_shapes()
    test_fixed_point()

    print("=" * 60)
    print("ALL TESTS PASSED")
    print("=" * 60)