"""Quick start example — compress a KV cache in ~10 lines.

Builds one layer's worth of synthetic key/value tensors, stores them in a
MixedPrecisionKVCache with per-head bit allocation, then reports the memory
saving and the round-trip quantization error.
"""
import os
import sys

import torch

# Make the repository root importable when running this file directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from kernel.quant_cache import MixedPrecisionKVCache

# Fall back to CPU so the quick start runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# simulate one layer of KV cache
# batch=1, heads=8, seq=1024, head_dim=128
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device=device)
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device=device)

# define bit allocation per head (from calibration)
# 4=compress aggressively, 8=keep quality
bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]

# compress
cache = MixedPrecisionKVCache(bit_alloc)
cache.store(k, v)

# retrieve
k_out, v_out = cache.retrieve()

# measure: numel * 2 bytes/fp16 element * 2 tensors (K and V)
fp16_bytes = k.numel() * 2 * 2
quant_bytes = cache.memory_bytes()

print(f"FP16: {fp16_bytes/1024:.0f} KB")
print(f"Compressed: {quant_bytes/1024:.0f} KB")
print(f"Ratio: {fp16_bytes/quant_bytes:.2f}x")
print(f"K error: {(k - k_out).abs().mean():.6f}")
print(f"V error: {(v - v_out).abs().mean():.6f}")