# kv-cache-compression/kernel/quant_cache.py
"""
Per-Head Mixed-Precision KV Cache
Using PyTorch for correctness, Triton optimization later.
"""
import torch
def quantize_head(x: torch.Tensor, bits: int):
"""Quantize [seq, head_dim] tensor to given bits."""
x = x.float()
x_min = x.min()
x_max = x.max()
if bits == 8:
qmax = 255.0
elif bits == 4:
qmax = 15.0
else:
raise ValueError(f"Unsupported bits: {bits}")
    # Clamp avoids division by zero when the tensor is constant.
    scale = (x_max - x_min).clamp(min=1e-8) / qmax
zp = x_min
q = ((x - zp) / scale).round().clamp(0, qmax).to(torch.uint8)
return q, scale, zp
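
# Note on accuracy: the rounding error per element of min-max quantization is
# uniform in [-scale/2, scale/2], so the expected mean |error| is roughly
# scale / 4. The self-test at the bottom of the file checks against this bound.
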
def dequantize_head(q, scale, zp, bits, original_shape):
"""Dequantize back to float16."""
x = q.float() * scale + zp
return x.to(torch.float16).view(original_shape)
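
# --- Optional: true 4-bit packing -----------------------------------------
# MixedPrecisionKVCache below stores 4-bit values unpacked, one per uint8, so
# its real GPU footprint matches the 8-bit path. A minimal sketch of nibble
# packing (two 4-bit values per byte) is shown here for reference; these
# helpers are illustrative and are NOT used by the cache.

def pack_4bit(q: torch.Tensor) -> torch.Tensor:
    """Pack a uint8 tensor of 4-bit values (0..15) two per byte."""
    q = q.flatten()
    if q.numel() % 2 == 1:
        q = torch.cat([q, q.new_zeros(1)])  # pad to an even count
    return (q[0::2] << 4) | q[1::2]


def unpack_4bit(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Inverse of pack_4bit; `numel` is the original element count."""
    nibbles = torch.stack([packed >> 4, packed & 0x0F], dim=1)
    return nibbles.flatten()[:numel]
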
class MixedPrecisionKVCache:
"""
Stores quantized K and V for all heads in one layer.
bit_alloc: list of ints, one per head (4 or 8)
"""
def __init__(self, bit_alloc: list):
self.bit_alloc = bit_alloc
self.k_cache = []
self.v_cache = []
def store(self, k: torch.Tensor, v: torch.Tensor):
"""k, v: [batch, num_heads, seq, head_dim]"""
self.k_cache = []
self.v_cache = []
for h in range(k.shape[1]):
bits = self.bit_alloc[h]
k_head = k[0, h]
v_head = v[0, h]
kq, ks, kz = quantize_head(k_head, bits)
vq, vs, vz = quantize_head(v_head, bits)
self.k_cache.append((kq, ks, kz, k_head.shape, bits))
self.v_cache.append((vq, vs, vz, v_head.shape, bits))
def retrieve(self):
"""Dequantize all heads, return [1, heads, seq, head_dim] float16."""
        ks = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.k_cache]
        vs = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.v_cache]
k = torch.stack(ks, dim=0).unsqueeze(0)
v = torch.stack(vs, dim=0).unsqueeze(0)
return k, v
def memory_bytes(self):
"""Theoretical memory — 4-bit stored as uint8 (not truly packed)."""
total = 0
for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
if bits == 4:
total += q.numel() // 2 + 8 # theoretical packed size
else:
total += q.numel() + 8
return total
def real_gpu_bytes(self):
"""Actual GPU memory used by tensors."""
total = 0
for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
total += q.numel() + 8 # actual bytes on GPU (uint8 for 4-bit = wasteful)
return total
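
# --- Example: one attention step against the cache -------------------------
# A minimal sketch (an assumption about usage, not part of the benchmarked
# path): quantize K/V once after prefill, then dequantize on each decode
# step. A fused kernel would dequantize inside the attention op instead.

def attend_with_cache(q: torch.Tensor, cache: MixedPrecisionKVCache) -> torch.Tensor:
    """q: [1, num_heads, q_len, head_dim] float16, same device as the cache."""
    import torch.nn.functional as F
    k_deq, v_deq = cache.retrieve()
    return F.scaled_dot_product_attention(q, k_deq, v_deq)
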
if __name__ == "__main__":
print("Testing MixedPrecisionKVCache...")
torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)
    v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)
# test 8-bit only first
print("\n--- 8-bit only ---")
bit_alloc = [8] * 8
cache = MixedPrecisionKVCache(bit_alloc)
cache.store(k, v)
k_out, v_out = cache.retrieve()
k_err = (k - k_out).abs().mean().item()
v_err = (v - v_out).abs().mean().item()
print(f"K error: {k_err:.6f} V error: {v_err:.6f}")
    assert k_err < 0.01, f"8-bit K error too high: {k_err}"
    assert v_err < 0.01, f"8-bit V error too high: {v_err}"
print("✅ 8-bit passed!")
# test 4-bit only
print("\n--- 4-bit only ---")
bit_alloc = [4] * 8
cache = MixedPrecisionKVCache(bit_alloc)
cache.store(k, v)
k_out, v_out = cache.retrieve()
k_err = (k - k_out).abs().mean().item()
v_err = (v - v_out).abs().mean().item()
print(f"K error: {k_err:.6f} V error: {v_err:.6f}")
    assert k_err < 0.2, f"4-bit K error too high: {k_err}"
    assert v_err < 0.2, f"4-bit V error too high: {v_err}"
print("✅ 4-bit passed!")
# test mixed
print("\n--- Mixed 4/8-bit ---")
bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
cache = MixedPrecisionKVCache(bit_alloc)
cache.store(k, v)
k_out, v_out = cache.retrieve()
k_err = (k - k_out).abs().mean().item()
v_err = (v - v_out).abs().mean().item()
print(f"K error: {k_err:.6f} V error: {v_err:.6f}")
    fp16_bytes = k.numel() * 2 * 2  # 2 bytes per fp16 element, for both K and V
quant_bytes = cache.memory_bytes()
print(f"\nFP16 memory: {fp16_bytes/1024:.1f} KB")
print(f"Quant memory: {quant_bytes/1024:.1f} KB")
print(f"Compression: {fp16_bytes/quant_bytes:.2f}x")
print("\n✅ All tests passed!")