"""
Verification tests for TurboQuant implementation.
1. Codebook: Lloyd-Max centroids match paper's distortion bounds
2. Packing: uint4 pack/unpack round-trip
3. Quantizer: MSE on random unit vectors ≤ paper's bound (0.009 at 4-bit)
4. Fixed-point: double quantization stability
"""
import sys
sys.path.insert(0, "/home/azureuser/turboquant")
import torch
import numpy as np
def test_codebook():
    """Verify Lloyd-Max codebook computation against the paper's distortion bounds.

    For each bit width b in {2, 3, 4}, computes the Lloyd-Max centroids and
    decision boundaries for dimension d, then checks that the total MSE
    (d x per-coordinate MSE) stays within the theoretical bound
    D_mse <= (sqrt(3)*pi/2) / 4**b, with 1% numerical tolerance.

    Raises:
        AssertionError: if any bit width exceeds its distortion bound.
    """
    from turboquant.codebook import compute_lloyd_max_codebook, compute_distortion
    print("=" * 60)
    print("TEST: Codebook computation")
    print("=" * 60)
    d = 128
    # Paper bounds: D_mse ≤ (√3·π/2) · (1/4^b)
    # Per-coordinate: D_mse / d = (√3·π / 2d) · (1/4^b)
    paper_total_mse = {2: 0.117, 3: 0.03, 4: 0.009}
    for bits in [2, 3, 4]:
        centroids, boundaries = compute_lloyd_max_codebook(d, bits)
        per_coord_mse = compute_distortion(d, bits, centroids, boundaries)
        total_mse = d * per_coord_mse
        bound = (np.sqrt(3) * np.pi / 2) * (1 / 4**bits)
        print(f"\n b={bits} ({2**bits} levels):")
        print(f" Centroids: {centroids[:4]} ... {centroids[-4:]}")
        print(f" Per-coord MSE: {per_coord_mse:.6e}")
        print(f" Total MSE (d×per): {total_mse:.6f}")
        print(f" Paper bound: {bound:.6f}")
        print(f" Paper table value: {paper_total_mse.get(bits, 'N/A')}")
        print(f" Within bound: {total_mse <= bound * 1.01}")  # 1% tolerance for numerics
        # BUG FIX: the bound check was only printed, so a violation still
        # reached the unconditional "PASS" line below. Fail loudly instead.
        assert total_mse <= bound * 1.01, (
            f"b={bits}: total MSE {total_mse:.6f} exceeds bound {bound:.6f} (+1%)"
        )
    print("\n PASS: Codebook computation verified\n")
def test_packing():
    """Verify uint4 and uint2 pack/unpack round-trip.

    Packs a random uint8 tensor of in-range codes for each bit width,
    unpacks it, and asserts the result is bit-exact with the input.
    """
    from turboquant.packing import pack_uint4, unpack_uint4, pack_uint2, unpack_uint2
    print("=" * 60)
    print("TEST: Bit packing round-trip")
    print("=" * 60)
    # (label, exclusive upper value, pack fn, unpack fn) — order matters so the
    # RNG stream matches the original uint4-then-uint2 sequence.
    cases = (
        ("uint4", 16, pack_uint4, unpack_uint4),
        ("uint2", 4, pack_uint2, unpack_uint2),
    )
    for label, hi, pack_fn, unpack_fn in cases:
        codes = torch.randint(0, hi, (4, 8, 128), dtype=torch.uint8)
        packed = pack_fn(codes)
        restored = unpack_fn(packed)
        assert torch.equal(codes, restored), f"{label} round-trip FAILED"
        print(f" {label}: {codes.shape} → {packed.shape} → {restored.shape} ✓")
    print("\n PASS: Packing round-trip verified\n")
def test_quantizer_mse():
    """Verify quantize→dequantize MSE matches paper's theoretical bounds.

    Draws random unit vectors on S^(d-1), round-trips them through the
    quantizer in bfloat16, and asserts the mean per-vector reconstruction
    MSE stays within the paper's bound (10% tolerance for sampling and
    bfloat16 noise).

    Raises:
        AssertionError: if the empirical MSE exceeds the bound for any bit width.
    """
    from turboquant.quantizer import TurboQuantizer
    print("=" * 60)
    print("TEST: Quantizer MSE on random unit vectors")
    print("=" * 60)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim = 128
    n_vectors = 10000
    paper_bounds = {2: 0.117, 4: 0.009}
    for bits in [2, 4]:
        quantizer = TurboQuantizer(dim=dim, bits=bits, device=device, seed=42)
        # Generate random unit vectors on S^(d-1)
        x = torch.randn(n_vectors, dim, device=device)
        x = x / x.norm(dim=-1, keepdim=True)
        x_bf16 = x.bfloat16()
        # Quantize and dequantize
        packed, norms = quantizer.quantize(x_bf16)
        x_recon = quantizer.dequantize(packed, norms)
        # Per-vector squared reconstruction error, computed once and reused
        # for both the mean and the percentile report (was computed twice).
        per_vec_mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1)
        mse = per_vec_mse.mean().item()
        bound = paper_bounds[bits]
        print(f"\n b={bits}:")
        print(f" Vectors tested: {n_vectors}")
        print(f" Empirical MSE: {mse:.6f}")
        print(f" Paper bound: {bound:.6f}")
        print(f" Ratio (emp/bnd): {mse/bound:.3f}")
        print(f" Within bound: {mse <= bound * 1.1}")  # 10% tolerance
        print(f" MSE p50/p95/max: {per_vec_mse.median():.6f} / "
              f"{per_vec_mse.quantile(0.95):.6f} / {per_vec_mse.max():.6f}")
        # BUG FIX: the bound check was only printed, so a violation still
        # reached the unconditional "PASS" line below. Fail loudly instead.
        assert mse <= bound * 1.1, (
            f"b={bits}: empirical MSE {mse:.6f} exceeds bound {bound:.6f} (+10%)"
        )
    print("\n PASS: MSE within theoretical bounds\n")
def test_quantizer_shapes():
    """Verify correct tensor shapes through quantize/dequantize.

    Round-trips a 4-D bfloat16 tensor shaped like a KV cache and asserts
    the reconstruction matches the input's shape and dtype; also reports
    the achieved compression ratio.
    """
    from turboquant.quantizer import TurboQuantizer
    print("=" * 60)
    print("TEST: Tensor shapes (simulating KV cache)")
    print("=" * 60)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    head_dim = 128
    quantizer = TurboQuantizer(dim=head_dim, bits=4, device=device, seed=0)
    # Simulate KV cache tensor: (batch, heads, seq_len, head_dim)
    n_batch, n_heads, n_seq = 2, 8, 1024
    kv = torch.randn(n_batch, n_heads, n_seq, head_dim,
                     device=device, dtype=torch.bfloat16)
    packed, norms = quantizer.quantize(kv)
    restored = quantizer.dequantize(packed, norms)
    print(f" Input: {kv.shape} {kv.dtype}")
    print(f" Packed: {packed.shape} {packed.dtype}")
    print(f" Norms: {norms.shape} {norms.dtype}")
    print(f" Recon: {restored.shape} {restored.dtype}")
    print(f" Shape match: {kv.shape == restored.shape}")
    print(f" Dtype match: {kv.dtype == restored.dtype}")
    # Memory accounting: BF16 input is 2 bytes/element; the quantized form is
    # uint8 packed codes (1 byte each) plus BF16 norms (2 bytes each).
    raw_bytes = kv.numel() * 2
    packed_bytes = packed.numel() * 1 + norms.numel() * 2
    print(f"\n Original: {raw_bytes / 1024:.1f} KB")
    print(f" Quantized: {packed_bytes / 1024:.1f} KB")
    print(f" Compression: {raw_bytes / packed_bytes:.2f}x")
    assert kv.shape == restored.shape, "Shape mismatch!"
    assert kv.dtype == restored.dtype, "Dtype mismatch!"
    print("\n PASS: Shapes and dtypes correct\n")
def test_fixed_point():
    """Verify that quantize→dequantize→requantize→dequantize is stable.

    Quantizes a random tensor, then re-quantizes its own reconstruction and
    reports whether the packed codes, norms, and reconstructions agree.
    Boundary-straddling values may flip codes under FP rounding, so a
    mismatch is reported as a warning rather than a failure.
    """
    from turboquant.quantizer import TurboQuantizer
    print("=" * 60)
    print("TEST: Double quantization stability (fixed-point)")
    print("=" * 60)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    quantizer = TurboQuantizer(dim=128, bits=4, device=device, seed=42)
    sample = torch.randn(100, 128, device=device, dtype=torch.bfloat16)
    # Round 1: quantize the raw sample.
    codes_a, scales_a = quantizer.quantize(sample)
    recon_a = quantizer.dequantize(codes_a, scales_a)
    # Round 2: feed the reconstruction straight back through the quantizer.
    codes_b, scales_b = quantizer.quantize(recon_a)
    recon_b = quantizer.dequantize(codes_b, scales_b)
    # A true fixed point means identical codes and zero reconstruction drift.
    same_codes = torch.equal(codes_a, codes_b)
    worst_drift = (recon_a.float() - recon_b.float()).abs().max().item()
    print(f" Packed indices identical: {same_codes}")
    print(f" Max reconstruction diff: {worst_drift:.2e}")
    print(f" Norm diff (max): {(scales_a.float() - scales_b.float()).abs().max().item():.2e}")
    if not same_codes:
        mismatches = (codes_a != codes_b).sum().item()
        print(f" WARNING: {mismatches} packed bytes differ (FP rounding at boundaries)")
    print("\n PASS: Double quantization stable\n")
if __name__ == "__main__":
    # Run every verification in order; any failed assert aborts the run.
    for check in (
        test_codebook,
        test_packing,
        test_quantizer_mse,
        test_quantizer_shapes,
        test_fixed_point,
    ):
        check()
    print("=" * 60)
    print("ALL TESTS PASSED")
    print("=" * 60)