| """ |
| Per-Head Mixed-Precision KV Cache |
| Using PyTorch for correctness, Triton optimization later. |
| """ |
|
|
| import torch |
| import json |
| import os |
|
|
|
|
def quantize_head(x: torch.Tensor, bits: int):
    """Asymmetric affine quantization of a [seq, head_dim] head to `bits` bits.

    Uses a single (scale, zero-point) pair per head, derived from the head's
    min/max. Returns (q, scale, zero_point) with q stored as uint8.
    """
    x = x.float()
    x_min = x.min()
    x_max = x.max()

    if bits == 8:
        qmax = 255.0
    elif bits == 4:
        qmax = 15.0
    else:
        raise ValueError(f"Unsupported bits: {bits}")

    # Clamp guards against a constant head (x_max == x_min) giving a zero scale.
    scale = (x_max - x_min).clamp(min=1e-8) / qmax
    zp = x_min

    q = ((x - zp) / scale).round().clamp(0, qmax).to(torch.uint8)
    return q, scale, zp


def dequantize_head(q, scale, zp, bits, original_shape):
    """Dequantize a head back to float16.

    `bits` is unused here (both widths are stored as uint8) but kept so
    call sites stay uniform with quantize_head.
    """
    x = q.float() * scale + zp
    return x.to(torch.float16).view(original_shape)
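
# Rough error expectation for the round trip (used to set the self-test
# thresholds below): with a per-head value range R, the quantization step is
# R / qmax and uniform rounding gives a mean absolute error of about step / 4.
# For randn inputs (R is typically ~9 over 512x128 samples), that is roughly
# 9/255/4 ≈ 0.009 at 8 bits and 9/15/4 ≈ 0.15 at 4 bits.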


class MixedPrecisionKVCache:
    """
    Stores quantized K and V for all heads of one layer.

    bit_alloc: list of ints, one entry per head (4 or 8).
    """

    def __init__(self, bit_alloc: list):
        self.bit_alloc = bit_alloc
        self.k_cache = []
        self.v_cache = []

    def store(self, k: torch.Tensor, v: torch.Tensor):
        """Quantize and store k, v of shape [batch, num_heads, seq, head_dim].

        Only batch size 1 is supported; the cache is overwritten on each call.
        """
        assert k.shape[0] == 1 and v.shape[0] == 1, "only batch size 1 is supported"
        assert k.shape[1] == len(self.bit_alloc), "need one bit-width per head"
        self.k_cache = []
        self.v_cache = []
        for h in range(k.shape[1]):
            bits = self.bit_alloc[h]
            k_head = k[0, h]
            v_head = v[0, h]
            kq, ks, kz = quantize_head(k_head, bits)
            vq, vs, vz = quantize_head(v_head, bits)
            self.k_cache.append((kq, ks, kz, k_head.shape, bits))
            self.v_cache.append((vq, vs, vz, v_head.shape, bits))

    def retrieve(self):
        """Dequantize all heads; returns (k, v) as [1, heads, seq, head_dim] float16."""
        ks = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.k_cache]
        vs = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.v_cache]
        k = torch.stack(ks, dim=0).unsqueeze(0)
        v = torch.stack(vs, dim=0).unsqueeze(0)
        return k, v
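
    # Hypothetical caller sketch (q_states is an assumed [1, heads, q_len,
    # head_dim] query tensor, not defined in this module): the dequantized
    # tensors drop straight into PyTorch attention, e.g.
    #   k, v = cache.retrieve()
    #   out = torch.nn.functional.scaled_dot_product_attention(q_states, k, v)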

    def memory_bytes(self):
        """Theoretical footprint if 4-bit heads were packed two values per byte.

        4-bit heads are currently stored one value per uint8 (see
        real_gpu_bytes); the +8 per head covers one float32 scale and
        one float32 zero-point.
        """
        total = 0
        for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
            if bits == 4:
                total += q.numel() // 2 + 8
            else:
                total += q.numel() + 8
        return total

    def real_gpu_bytes(self):
        """Bytes the stored tensors actually occupy: one uint8 per element
        plus 8 bytes of scale/zero-point per head, regardless of bit width,
        since 4-bit values are not packed yet."""
        total = 0
        for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
            total += q.numel() + 8
        return total
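

# The cache stores 4-bit heads one value per uint8; memory_bytes() only
# reports what true packing would save. Below is a minimal sketch of that
# packing (two 4-bit values per byte). pack_uint4/unpack_uint4 are
# illustrative helpers, not used by MixedPrecisionKVCache, and assume the
# input holds values in 0..15 with an even number of elements.
def pack_uint4(q: torch.Tensor) -> torch.Tensor:
    """Pack a uint8 tensor of 4-bit values into half as many bytes."""
    flat = q.flatten()
    assert flat.numel() % 2 == 0, "pad to an even length before packing"
    # Even elements go to the low nibble, odd elements to the high nibble.
    return (flat[0::2] | (flat[1::2] << 4)).contiguous()


def unpack_uint4(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Inverse of pack_uint4; returns a flat uint8 tensor of `numel` values."""
    out = torch.empty(numel, dtype=torch.uint8, device=packed.device)
    out[0::2] = packed & 0x0F
    out[1::2] = packed >> 4
    return out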


if __name__ == "__main__":
    print("Testing MixedPrecisionKVCache...")
    torch.manual_seed(42)
    # Fall back to CPU so the self-test also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)
    v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)
| print("\n--- 8-bit only ---") |
| bit_alloc = [8] * 8 |
| cache = MixedPrecisionKVCache(bit_alloc) |
| cache.store(k, v) |
| k_out, v_out = cache.retrieve() |
| k_err = (k - k_out).abs().mean().item() |
| v_err = (v - v_out).abs().mean().item() |
| print(f"K error: {k_err:.6f} V error: {v_err:.6f}") |
| assert k_err < 0.01, f"8-bit K error too high: {k_err}" |
| print("✅ 8-bit passed!") |
|
|

    print("\n--- 4-bit only ---")
    bit_alloc = [4] * 8
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f}  V error: {v_err:.6f}")
    # Threshold sits just above the expected ~0.15 mean rounding error.
    assert k_err < 0.2, f"4-bit K error too high: {k_err}"
    assert v_err < 0.2, f"4-bit V error too high: {v_err}"
    print("✅ 4-bit passed!")

    print("\n--- Mixed 4/8-bit ---")
    bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f}  V error: {v_err:.6f}")
    # Mixed error should land between the all-4-bit and all-8-bit cases.
    assert k_err < 0.2 and v_err < 0.2, "mixed-precision error too high"

    # FP16 baseline: 2 bytes per element, times two for K and V.
    fp16_bytes = k.numel() * 2 * 2
    quant_bytes = cache.memory_bytes()
    print(f"\nFP16 memory:  {fp16_bytes/1024:.1f} KB")
    print(f"Quant memory: {quant_bytes/1024:.1f} KB")
    print(f"Compression:  {fp16_bytes/quant_bytes:.2f}x")
| print("\n✅ All tests passed!") |