# turboquant/scripts/verify.py
"""
Verification tests for TurboQuant implementation.
1. Codebook: Lloyd-Max centroids match paper's distortion bounds
2. Packing: uint4 pack/unpack round-trip
3. Quantizer: MSE on random unit vectors ≤ paper's bound (0.009 at 4-bit)
4. Fixed-point: double quantization stability
"""
import sys
sys.path.insert(0, "/home/azureuser/turboquant")
import torch
import numpy as np
def test_codebook():
    """Verify Lloyd-Max codebook computation against the paper's distortion bounds.

    For each bit width b in {2, 3, 4}, the total d-dimensional MSE must satisfy
    D_mse <= (sqrt(3)*pi/2) / 4**b, with a 1% tolerance for numerical error.

    Raises:
        AssertionError: if any bit width's total MSE exceeds its bound.
    """
    from turboquant.codebook import compute_lloyd_max_codebook, compute_distortion

    print("=" * 60)
    print("TEST: Codebook computation")
    print("=" * 60)
    d = 128
    # Paper bounds: D_mse ≤ (√3·π/2) · (1/4^b)
    # Per-coordinate: D_mse / d = (√3·π / 2d) · (1/4^b)
    paper_total_mse = {2: 0.117, 3: 0.03, 4: 0.009}
    for bits in [2, 3, 4]:
        centroids, boundaries = compute_lloyd_max_codebook(d, bits)
        per_coord_mse = compute_distortion(d, bits, centroids, boundaries)
        total_mse = d * per_coord_mse
        bound = (np.sqrt(3) * np.pi / 2) * (1 / 4**bits)
        print(f"\n b={bits} ({2**bits} levels):")
        print(f" Centroids: {centroids[:4]} ... {centroids[-4:]}")
        print(f" Per-coord MSE: {per_coord_mse:.6e}")
        print(f" Total MSE (d×per): {total_mse:.6f}")
        print(f" Paper bound: {bound:.6f}")
        print(f" Paper table value: {paper_total_mse.get(bits, 'N/A')}")
        print(f" Within bound: {total_mse <= bound * 1.01}")  # 1% tolerance for numerics
        # BUG FIX: the original only printed the bound check, so an
        # out-of-bound codebook still reported PASS. Enforce it.
        assert total_mse <= bound * 1.01, (
            f"b={bits}: total MSE {total_mse:.6f} exceeds paper bound {bound:.6f}"
        )
    print("\n PASS: Codebook computation verified\n")
def test_packing():
    """Round-trip check: packing then unpacking uint4/uint2 tensors is lossless.

    Raises:
        AssertionError: if either round-trip fails to reproduce the input.
    """
    from turboquant.packing import pack_uint4, unpack_uint4, pack_uint2, unpack_uint2

    print("=" * 60)
    print("TEST: Bit packing round-trip")
    print("=" * 60)
    # Exercise both supported bit widths over the same tensor shape.
    cases = [
        ("uint4", 16, pack_uint4, unpack_uint4),
        ("uint2", 4, pack_uint2, unpack_uint2),
    ]
    for label, n_levels, pack_fn, unpack_fn in cases:
        values = torch.randint(0, n_levels, (4, 8, 128), dtype=torch.uint8)
        packed = pack_fn(values)
        restored = unpack_fn(packed)
        assert torch.equal(values, restored), f"{label} round-trip FAILED"
        print(f" {label}: {values.shape}{packed.shape}{restored.shape} ✓")
    print("\n PASS: Packing round-trip verified\n")
def test_quantizer_mse():
    """Verify quantize→dequantize MSE matches paper's theoretical bounds.

    Quantizes 10k random unit vectors on S^(d-1) at b in {2, 4} bits and
    checks the empirical mean reconstruction MSE against the paper's table
    values, with a 10% tolerance for sampling noise.

    Raises:
        AssertionError: if the empirical MSE exceeds the bound at any bit width.
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Quantizer MSE on random unit vectors")
    print("=" * 60)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dim = 128
    n_vectors = 10000
    paper_bounds = {2: 0.117, 4: 0.009}
    for bits in [2, 4]:
        quantizer = TurboQuantizer(dim=dim, bits=bits, device=device, seed=42)
        # Generate random unit vectors on S^(d-1)
        x = torch.randn(n_vectors, dim, device=device)
        x = x / x.norm(dim=-1, keepdim=True)
        x_bf16 = x.bfloat16()
        # Quantize and dequantize
        packed, norms = quantizer.quantize(x_bf16)
        x_recon = quantizer.dequantize(packed, norms)
        # Compute MSE (per-vector squared error, averaged over vectors)
        mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1).mean().item()
        bound = paper_bounds[bits]
        print(f"\n b={bits}:")
        print(f" Vectors tested: {n_vectors}")
        print(f" Empirical MSE: {mse:.6f}")
        print(f" Paper bound: {bound:.6f}")
        print(f" Ratio (emp/bnd): {mse/bound:.3f}")
        print(f" Within bound: {mse <= bound * 1.1}")  # 10% tolerance
        # Also check individual vector MSE distribution
        per_vec_mse = (x_bf16.float() - x_recon.float()).pow(2).sum(dim=-1)
        print(f" MSE p50/p95/max: {per_vec_mse.median():.6f} / "
              f"{per_vec_mse.quantile(0.95):.6f} / {per_vec_mse.max():.6f}")
        # BUG FIX: the original only printed the bound check, so a quantizer
        # violating the bound still reported PASS. Enforce it.
        assert mse <= bound * 1.1, (
            f"b={bits}: empirical MSE {mse:.6f} exceeds bound {bound:.6f} (+10%)"
        )
    print("\n PASS: MSE within theoretical bounds\n")
def test_quantizer_shapes():
    """Check that quantize/dequantize preserve KV-cache tensor shapes and dtypes.

    Also reports the memory footprint of the packed representation versus
    the BF16 original.

    Raises:
        AssertionError: on any shape or dtype mismatch after the round trip.
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Tensor shapes (simulating KV cache)")
    print("=" * 60)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    head_dim = 128
    quantizer = TurboQuantizer(dim=head_dim, bits=4, device=device, seed=0)
    # KV-cache layout: (batch, heads, seq_len, head_dim) = (2, 8, 1024, 128)
    kv = torch.randn(2, 8, 1024, head_dim, device=device, dtype=torch.bfloat16)
    packed, norms = quantizer.quantize(kv)
    recon = quantizer.dequantize(packed, norms)
    print(f" Input: {kv.shape} {kv.dtype}")
    print(f" Packed: {packed.shape} {packed.dtype}")
    print(f" Norms: {norms.shape} {norms.dtype}")
    print(f" Recon: {recon.shape} {recon.dtype}")
    print(f" Shape match: {kv.shape == recon.shape}")
    print(f" Dtype match: {kv.dtype == recon.dtype}")
    # Memory accounting: BF16 input is 2 bytes/elem; quantized form is
    # uint8 codes (1 byte/elem) plus BF16 norms (2 bytes/elem).
    original_bytes = kv.numel() * 2
    quant_bytes = packed.numel() * 1 + norms.numel() * 2
    ratio = original_bytes / quant_bytes
    print(f"\n Original: {original_bytes / 1024:.1f} KB")
    print(f" Quantized: {quant_bytes / 1024:.1f} KB")
    print(f" Compression: {ratio:.2f}x")
    assert kv.shape == recon.shape, "Shape mismatch!"
    assert kv.dtype == recon.dtype, "Dtype mismatch!"
    print("\n PASS: Shapes and dtypes correct\n")
def test_fixed_point():
    """Check that quantization acts as a fixed point on its own reconstructions.

    Quantizes, reconstructs, then quantizes the reconstruction again; the two
    sets of packed codes should match (differences are tolerated and reported
    as a warning, since FP rounding at cell boundaries can flip a code).
    """
    from turboquant.quantizer import TurboQuantizer

    print("=" * 60)
    print("TEST: Double quantization stability (fixed-point)")
    print("=" * 60)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    quantizer = TurboQuantizer(dim=128, bits=4, device=device, seed=42)
    x = torch.randn(100, 128, device=device, dtype=torch.bfloat16)
    # Round 1: quantize the raw vectors.
    packed1, norms1 = quantizer.quantize(x)
    recon1 = quantizer.dequantize(packed1, norms1)
    # Round 2: quantize the reconstruction itself.
    packed2, norms2 = quantizer.quantize(recon1)
    recon2 = quantizer.dequantize(packed2, norms2)
    indices_match = torch.equal(packed1, packed2)
    recon_diff = (recon1.float() - recon2.float()).abs().max().item()
    norm_diff = (norms1.float() - norms2.float()).abs().max().item()
    print(f" Packed indices identical: {indices_match}")
    print(f" Max reconstruction diff: {recon_diff:.2e}")
    print(f" Norm diff (max): {norm_diff:.2e}")
    if not indices_match:
        # Non-fatal by design: boundary-adjacent values may round differently.
        n_diff = (packed1 != packed2).sum().item()
        print(f" WARNING: {n_diff} packed bytes differ (FP rounding at boundaries)")
    print("\n PASS: Double quantization stable\n")
if __name__ == "__main__":
test_codebook()
test_packing()
test_quantizer_mse()
test_quantizer_shapes()
test_fixed_point()
print("=" * 60)
print("ALL TESTS PASSED")
print("=" * 60)