"""
Per-Head Mixed-Precision KV Cache

Using PyTorch for correctness, Triton optimization later.
"""
import torch
import json
import os


def quantize_head(x: torch.Tensor, bits: int):
    """Asymmetrically quantize a [seq, head_dim] tensor to uint8 codes.

    Args:
        x: float tensor; the quantization range is its global min/max.
        bits: 4 or 8. 4-bit codes are still stored one per uint8
            (not packed two-per-byte).

    Returns:
        (q, scale, zp) such that x ~= q * scale + zp, where q is uint8,
        and scale/zp are float scalar tensors.

    Raises:
        ValueError: if bits is not 4 or 8.
    """
    x = x.float()
    x_min = x.min()
    x_max = x.max()
    if bits == 8:
        qmax = 255.0
    elif bits == 4:
        qmax = 15.0
    else:
        raise ValueError(f"Unsupported bits: {bits}")
    # clamp keeps scale nonzero when the tensor is constant (x_max == x_min)
    scale = (x_max - x_min).clamp(min=1e-8) / qmax
    zp = x_min
    q = ((x - zp) / scale).round().clamp(0, qmax).to(torch.uint8)
    return q, scale, zp


def dequantize_head(q, scale, zp, bits, original_shape):
    """Dequantize uint8 codes back to a float16 tensor of original_shape.

    `bits` is unused (codes are stored unpacked, so 4- and 8-bit decode
    identically); it is kept for interface symmetry with quantize_head.
    """
    x = q.float() * scale + zp
    return x.to(torch.float16).view(original_shape)


class MixedPrecisionKVCache:
    """
    Stores quantized K and V for all heads in one layer.

    bit_alloc: list of ints, one per head (4 or 8)
    """

    def __init__(self, bit_alloc: list):
        self.bit_alloc = bit_alloc
        # Per-head tuples: (q_codes, scale, zero_point, orig_shape, bits)
        self.k_cache = []
        self.v_cache = []

    def store(self, k: torch.Tensor, v: torch.Tensor):
        """Quantize and cache k, v of shape [batch, num_heads, seq, head_dim].

        Only batch size 1 is supported. Previously a batch > 1 was
        silently truncated to batch index 0; now it raises instead.

        Raises:
            ValueError: if batch != 1 or num_heads != len(bit_alloc).
        """
        if k.shape[0] != 1 or v.shape[0] != 1:
            raise ValueError(
                f"Only batch size 1 is supported, got k batch {k.shape[0]}"
            )
        if k.shape[1] != len(self.bit_alloc):
            raise ValueError(
                f"bit_alloc has {len(self.bit_alloc)} entries "
                f"but k has {k.shape[1]} heads"
            )
        self.k_cache = []
        self.v_cache = []
        for h in range(k.shape[1]):
            bits = self.bit_alloc[h]
            k_head = k[0, h]
            v_head = v[0, h]
            kq, ks, kz = quantize_head(k_head, bits)
            vq, vs, vz = quantize_head(v_head, bits)
            self.k_cache.append((kq, ks, kz, k_head.shape, bits))
            self.v_cache.append((vq, vs, vz, v_head.shape, bits))

    def retrieve(self):
        """Dequantize all heads, return [1, heads, seq, head_dim] float16."""
        ks = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.k_cache]
        vs = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.v_cache]
        k = torch.stack(ks, dim=0).unsqueeze(0)
        v = torch.stack(vs, dim=0).unsqueeze(0)
        return k, v

    def memory_bytes(self):
        """Theoretical memory — 4-bit stored as uint8 (not truly packed)."""
        total = 0
        for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
            if bits == 4:
                # theoretical packed size (2 codes/byte) + scale/zp (2 float32)
                total += q.numel() // 2 + 8
            else:
                total += q.numel() + 8
        return total

    def real_gpu_bytes(self):
        """Actual GPU memory used by tensors."""
        total = 0
        for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
            # actual bytes on GPU (uint8 for 4-bit = wasteful)
            total += q.numel() + 8
        return total


if __name__ == "__main__":
    print("Testing MixedPrecisionKVCache...")
    torch.manual_seed(42)

    # Fall back to CPU so the self-test runs anywhere
    # (device was hard-coded to "cuda", crashing on CPU-only machines).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)
    v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)

    # test 8-bit only first
    print("\n--- 8-bit only ---")
    bit_alloc = [8] * 8
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f} V error: {v_err:.6f}")
    assert k_err < 0.01, f"8-bit K error too high: {k_err}"
    print("āœ… 8-bit passed!")

    # test 4-bit only
    print("\n--- 4-bit only ---")
    bit_alloc = [4] * 8
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f} V error: {v_err:.6f}")
    # 4-bit on random data ~0.14, real KV data is lower
    assert k_err < 0.2, f"4-bit K error too high: {k_err}"
    print("āœ… 4-bit passed!")

    # test mixed
    print("\n--- Mixed 4/8-bit ---")
    bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f} V error: {v_err:.6f}")

    fp16_bytes = k.numel() * 2 * 2  # k and v, 2 bytes per fp16 element
    quant_bytes = cache.memory_bytes()
    print(f"\nFP16 memory: {fp16_bytes/1024:.1f} KB")
    print(f"Quant memory: {quant_bytes/1024:.1f} KB")
    print(f"Compression: {fp16_bytes/quant_bytes:.2f}x")
    print("\nāœ… All tests passed!")