# kv-cache-compression/benchmark_long_context.py
"""
Long context benchmarks at 16K and 32K.
This is where KV cache compression matters most.
4 methods: FP16, Uniform 8-bit, Naive Per-Head, Triton True 4-bit
"""
import torch
import json
import os
import sys
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache
from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton
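# MixedPrecisionKVCache is the naive per-head variant; MixedPrecisionKVCacheTriton
# packs sub-byte values with a Triton kernel for a true 4-bit footprint.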
# ── config ──────────────────────────────────────────
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
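# bit_allocation.json (per-layer, per-head bit widths) is assumed to be
# produced by an earlier profiling step; open() fails fast if it is missing.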
with open(f"{results_dir}/bit_allocation.json") as f:
    raw = json.load(f)
# JSON keys are strings (layer -> head -> bits); convert to
# {layer_int: [bits_per_head, ...]} for positional indexing below.
bit_alloc = {
    int(l): [raw[l][str(h)] for h in range(len(raw[l]))]
    for l in raw
}
num_layers = len(bit_alloc)
avg_bits = (sum(b for heads in bit_alloc.values() for b in heads)
            / sum(len(heads) for heads in bit_alloc.values()))
print(f"Model: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
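# NOTE: the `dtype=` kwarg and the `past_key_values.layers[i].keys/.values`
# accessors used below assume a recent transformers release; older versions
# spell these `torch_dtype=` and `past_key_values[i][0]/[1]`.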
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
def measure_context(context_len: int):
    print(f"\n Context {context_len} tokens...")
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    # peak memory
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
    kv = out.past_key_values
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated() / 1e9
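    # peak_mem is the "honest" number: weights + activations + the fp16 KV cache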
    # KV compression: all 4 methods
    fp16_bytes = 0
    uniform8_bytes = 0
    naive_real_bytes = 0
    triton_bytes = 0
    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values
        fp16_bytes += k.numel() * 2 + v.numel() * 2  # 2 bytes per fp16 element
        uniform8_bytes += k.numel() + v.numel()      # 1 byte/element (analytic)
        cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache_naive.store(k, v)
        naive_real_bytes += cache_naive.real_gpu_bytes()  # actually allocated
        cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx])
        cache_triton.store(k, v)
        triton_bytes += cache_triton.memory_bytes()
    # prefill speed (3 runs average)
    times = []
    for _ in range(3):
        torch.cuda.synchronize()
        t0 = time.time()
        with torch.no_grad():
            _ = model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        times.append(time.time() - t0)
    prefill_ms = round(sum(times) / len(times) * 1000, 1)
    return {
        "context_len": context_len,
        "peak_memory_gb": round(peak_mem, 2),
        "fp16_mb": round(fp16_bytes / 1e6, 2),
        "uniform8_mb": round(uniform8_bytes / 1e6, 2),
        "naive_real_gpu_mb": round(naive_real_bytes / 1e6, 2),
        "triton_mb": round(triton_bytes / 1e6, 2),
        "naive_compression": round(fp16_bytes / naive_real_bytes, 2),
        "triton_compression": round(fp16_bytes / triton_bytes, 2),
        "prefill_ms": prefill_ms,
    }
# ── RUN ──────────────────────────────────────────────
print("\n" + "="*75)
print("LONG CONTEXT BENCHMARK β€” 4 METHODS")
print("="*75)
results = []
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
    try:
        r = measure_context(ctx)
        results.append(r)
        print(f" ctx={ctx:6d} | "
              f"mem={r['peak_memory_gb']:.2f}GB | "
              f"FP16={r['fp16_mb']:.0f}MB | "
              f"8bit={r['uniform8_mb']:.0f}MB | "
              f"Naive={r['naive_real_gpu_mb']:.0f}MB({r['naive_compression']}x) | "
              f"Triton={r['triton_mb']:.0f}MB({r['triton_compression']}x) | "
              f"prefill={r['prefill_ms']}ms")
    except torch.cuda.OutOfMemoryError:
        print(f" ctx={ctx:6d} | OOM at FP16; compressed methods would fit ✓")
        # Estimate what the fp16 KV cache would have been:
        # ctx tokens * layers * 2 (K and V) * 8 KV heads * 128 head dim * 2 bytes
        # (both Mistral-7B and Llama-3-8B use GQA with 8 KV heads, head_dim 128).
        results.append({
            "context_len": ctx,
            "peak_memory_gb": "OOM",
            "fp16_mb": round(ctx * num_layers * 2 * 8 * 128 * 2 / 1e6, 2),
            "note": "FP16 OOM",
        })
        break
# save
out_path = f"{results_dir}/long_context_results.json"
with open(out_path, "w") as f:
    json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
print(f"\nβœ… Saved to {out_path}")