"""
Per-Head Mixed-Precision KV Cache
Using PyTorch for correctness, Triton optimization later.
"""
import torch


def quantize_head(x: torch.Tensor, bits: int):
    """Asymmetric min-max quantization of a [seq, head_dim] tensor to `bits`."""
    x = x.float()
    x_min = x.min()
    x_max = x.max()
    if bits == 8:
        qmax = 255.0
    elif bits == 4:
        qmax = 15.0
    else:
        raise ValueError(f"Unsupported bits: {bits}")
    # scale maps the float range [x_min, x_max] onto the integer range [0, qmax];
    # the clamp avoids division by zero for constant tensors
    scale = (x_max - x_min).clamp(min=1e-8) / qmax
    zp = x_min
    q = ((x - zp) / scale).round().clamp(0, qmax).to(torch.uint8)
    return q, scale, zp


def dequantize_head(q, scale, zp, bits, original_shape):
    """Dequantize back to float16. `bits` is unused but kept for call symmetry."""
    x = q.float() * scale + zp
    return x.to(torch.float16).view(original_shape)
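

# Example round trip (a sketch; on unit-normal data the 8-bit mean absolute
# error lands comfortably under the 0.01 bound asserted in the tests below):
#
#   x = torch.randn(512, 128)
#   q, scale, zp = quantize_head(x, bits=8)
#   x_hat = dequantize_head(q, scale, zp, 8, x.shape)
#   err = (x - x_hat.float()).abs().mean()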


class MixedPrecisionKVCache:
    """
    Stores quantized K and V for all heads in one layer.

    bit_alloc: list of ints, one per head (4 or 8).
    """

    def __init__(self, bit_alloc: list):
        self.bit_alloc = bit_alloc
        self.k_cache = []
        self.v_cache = []

    def store(self, k: torch.Tensor, v: torch.Tensor):
        """k, v: [batch, num_heads, seq, head_dim]; only batch size 1 is handled."""
        assert k.shape[0] == 1 and v.shape[0] == 1, "only batch size 1 is supported"
        self.k_cache = []
        self.v_cache = []
        for h in range(k.shape[1]):
            bits = self.bit_alloc[h]
            k_head = k[0, h]  # [seq, head_dim]
            v_head = v[0, h]
            kq, ks, kz = quantize_head(k_head, bits)
            vq, vs, vz = quantize_head(v_head, bits)
            # each entry: (quantized values, scale, zero-point, shape, bit width)
            self.k_cache.append((kq, ks, kz, k_head.shape, bits))
            self.v_cache.append((vq, vs, vz, v_head.shape, bits))

    def retrieve(self):
        """Dequantize all heads, return [1, heads, seq, head_dim] float16."""
        ks = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.k_cache]
        vs = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.v_cache]
        k = torch.stack(ks, dim=0).unsqueeze(0)
        v = torch.stack(vs, dim=0).unsqueeze(0)
        return k, v

    def memory_bytes(self):
        """Theoretical memory: 4-bit counted as if packed two-per-byte, plus
        8 bytes per head for the float32 scale and zero-point."""
        total = 0
        for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
            if bits == 4:
                total += q.numel() // 2 + 8  # theoretical packed size
            else:
                total += q.numel() + 8
        return total

    def real_gpu_bytes(self):
        """Actual bytes held by the cache tensors: every value occupies a full
        uint8, so unpacked 4-bit heads cost the same as 8-bit heads."""
        total = 0
        for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
            total += q.numel() + 8
        return total
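

# memory_bytes() above reports the packed 4-bit size, while the cache actually
# stores one uint8 per 4-bit value (which is what real_gpu_bytes() measures).
# A minimal sketch of true nibble packing that would close that gap; pack_4bit
# and unpack_4bit are hypothetical helpers, not used by the cache above, and
# they assume an even number of elements:

def pack_4bit(q: torch.Tensor) -> torch.Tensor:
    """Pack a uint8 tensor of 4-bit values (0..15) two per byte."""
    flat = q.flatten()
    # even-index values fill the low nibble, odd-index values the high nibble
    return (flat[0::2] | (flat[1::2] << 4)).contiguous()


def unpack_4bit(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Unpack two-per-byte nibbles back to one uint8 value per element."""
    out = torch.empty(numel, dtype=torch.uint8, device=packed.device)
    out[0::2] = packed & 0x0F
    out[1::2] = packed >> 4
    return out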


if __name__ == "__main__":
    print("Testing MixedPrecisionKVCache...")
    torch.manual_seed(42)
    # fall back to CPU so the test also runs without a GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)
    v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)

    # test 8-bit only first
    print("\n--- 8-bit only ---")
    bit_alloc = [8] * 8
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f}  V error: {v_err:.6f}")
    assert k_err < 0.01, f"8-bit K error too high: {k_err}"
    print("✅ 8-bit passed!")

    # test 4-bit only
    print("\n--- 4-bit only ---")
    bit_alloc = [4] * 8
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f}  V error: {v_err:.6f}")
    # 4-bit on random normal data gives roughly 0.14 mean error; real KV data is lower
    assert k_err < 0.2, f"4-bit K error too high: {k_err}"
    print("✅ 4-bit passed!")

    # test mixed
    print("\n--- Mixed 4/8-bit ---")
    bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f}  V error: {v_err:.6f}")
    fp16_bytes = k.numel() * 2 * 2  # K and V, 2 bytes per fp16 element
    quant_bytes = cache.memory_bytes()
    print(f"\nFP16 memory: {fp16_bytes/1024:.1f} KB")
    print(f"Quant memory: {quant_bytes/1024:.1f} KB")
    print(f"Compression: {fp16_bytes/quant_bytes:.2f}x")
    print("\n✅ All tests passed!")