"""
Quick start example — compress KV cache in 10 lines.
"""
import torch
import json
import sys
import os


# Make the repository root importable when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from kernel.quant_cache import MixedPrecisionKVCache


# Fall back to CPU so the quick-start still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Synthetic K/V cache tensors: (batch=1, heads=8, seq_len=1024, head_dim=128), fp16.
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device=device)
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device=device)


# Mixed-precision bit widths, alternating 4-bit and 8-bit.
# NOTE(review): the 8 entries appear to map one-per-head (8 heads above) —
# confirm against MixedPrecisionKVCache's expected allocation granularity.
bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]


# Quantize and store both tensors in the compressed cache.
cache = MixedPrecisionKVCache(bit_alloc)
cache.store(k, v)


# Dequantize back to full tensors for comparison.
k_out, v_out = cache.retrieve()


# Report memory footprint and reconstruction error.
fp16_bytes = k.numel() * 2 * 2  # 2 bytes per fp16 element, x2 for K and V
quant_bytes = cache.memory_bytes()
print(f"FP16: {fp16_bytes/1024:.0f} KB")
print(f"Compressed: {quant_bytes/1024:.0f} KB")
print(f"Ratio: {fp16_bytes/quant_bytes:.2f}x")
print(f"K error: {(k - k_out).abs().mean():.6f}")
print(f"V error: {(v - v_out).abs().mean():.6f}")
|
|