# kv-cache-compression/scripts/benchmark_long_context.py
"""
Long context benchmarks at 16K and 32K.
This is where KV cache compression matters most.
"""
import torch
import json
import os
import sys
import time
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache
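# MixedPrecisionKVCache is the project-local quantized cache; this script only
# relies on its store(keys, values) and memory_bytes() interface.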
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
"mistral-7b": "~/kv-hack/mistral-model",
"llama-3-8b": "~/kv-hack/llama-model",
}
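# Fail fast with a readable error if an unknown model name is passed on the CLI.
if MODEL_NAME not in MODEL_PATHS:
    sys.exit(f"Unknown model '{MODEL_NAME}'; expected one of {list(MODEL_PATHS)}")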
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
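# bit_allocation.json is expected to map layer index -> head index -> bit width,
# with string keys at both levels, e.g. {"0": {"0": 8, "1": 4, ...}, ...}.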
with open(f"{results_dir}/bit_allocation.json") as f:
raw = json.load(f)
bit_alloc = {
int(l): [raw[l][str(h)] for h in range(len(raw[l]))]
for l in raw
}
num_layers = len(bit_alloc)
avg_bits = sum(b for heads in bit_alloc.values() for b in heads) / sum(
    len(heads) for heads in bit_alloc.values()
)
print(f"Model: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
def measure_context(context_len: int):
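    """Prefill once at context_len; report peak memory, KV cache sizes under each scheme, and prefill latency."""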
print(f"\n Context {context_len} tokens...")
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
# peak memory
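    # Peak includes model weights plus prefill activations and the FP16 KV cache.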
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
out = model(input_ids, use_cache=True)
kv = out.past_key_values
torch.cuda.synchronize()
peak_mem = torch.cuda.max_memory_allocated() / 1e9
# KV compression
fp16_bytes = 0
uniform8_bytes = 0
compressed_bytes = 0
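    # past_key_values is assumed to follow the newer transformers Cache layout
    # (kv.layers[i].keys / .values); older versions expose per-layer (key, value)
    # tuples instead.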
for layer_idx in range(num_layers):
k = kv.layers[layer_idx].keys
v = kv.layers[layer_idx].values
fp16_bytes += k.numel() * 2 + v.numel() * 2
uniform8_bytes += k.numel() + v.numel()
cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
cache.store(k, v)
compressed_bytes += cache.memory_bytes()
# prefill speed
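    # The memory-measurement forward above doubles as a warmup, so all three timed
    # runs below hit a warm model.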
times = []
for _ in range(3):
torch.cuda.synchronize()
        t0 = time.perf_counter()  # monotonic timer; preferable to time.time() for benchmarks
        with torch.no_grad():
            _ = model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
prefill_ms = round(sum(times) / len(times) * 1000, 1)
return {
"context_len": context_len,
"peak_memory_gb": round(peak_mem, 2),
"fp16_mb": round(fp16_bytes / 1e6, 2),
"uniform8_mb": round(uniform8_bytes / 1e6, 2),
"mixed_precision_mb": round(compressed_bytes / 1e6, 2),
"compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
"compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
"prefill_ms": prefill_ms,
}
print("\n" + "="*60)
print("LONG CONTEXT BENCHMARK")
print("="*60)
results = []
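# Sweep increasing context lengths; stop at the first FP16 forward pass that OOMs.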
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
try:
r = measure_context(ctx)
results.append(r)
print(f" ctx={ctx:6d} | "
f"mem={r['peak_memory_gb']:.2f}GB | "
f"FP16={r['fp16_mb']:.0f}MB | "
f"Ours={r['mixed_precision_mb']:.0f}MB | "
f"{r['compression_vs_fp16']}x | "
f"prefill={r['prefill_ms']}ms")
except torch.cuda.OutOfMemoryError:
print(f" ctx={ctx:6d} | OOM — FP16 ran out of memory ✓")
        # The FP16 forward OOMed before producing a cache, so only an estimated
        # FP16 KV size is recorded; no compressed measurement is possible here.
        # The estimate assumes 8 KV heads x head_dim 128 x 2 bytes x 2 tensors
        # (K and V), which matches Mistral-7B and Llama-3-8B.
        results.append({
            "context_len": ctx,
            "peak_memory_gb": "OOM",
            "fp16_mb": ctx * num_layers * 8 * 128 * 4 / 1e6,
            "note": "FP16 OOM; the mixed-precision cache might still fit",
        })
break
# save
out_path = f"{results_dir}/long_context_results.json"
with open(out_path, "w") as f:
json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
print(f"\n✅ Saved to {out_path}")