# kv-cache-compression/scripts/calibrate.py
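# Per-head KV-cache quantization sensitivity calibration.
# Sweeps every (layer, head) pair, measures the fake-quantization error of its
# K/V cache at 2/4/8 bits on WikiText-2 samples, and writes a sensitivity map
# plus a per-head bit allocation to ~/kv-hack/results/<model>/.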
import torch
import json
import os
import sys
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
# ── config ──────────────────────────────────────────
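# Usage: python calibrate.py [mistral-7b|llama-3-8b]  (defaults to mistral-7b)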
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
os.makedirs(results_dir, exist_ok=True)
# ────────────────────────────────────────────────────
print(f"Running calibration for: {MODEL_NAME}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.float16,
    device_map="cuda"
)
model.eval()
# load calibration dataset
print("Loading calibration data...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
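# keep up to 256 passages longer than 200 characters as calibration texts;
# get_kv_error() below subsamples 32 of them per head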
texts = [t for t in dataset["text"] if len(t.strip()) > 200][:256]
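# quantize_tensor() applies asymmetric min-max quantization over the last
# dimension and immediately dequantizes, so callers can measure the fp16
# round-trip error a given bit-width would introduce.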
def quantize_tensor(x, bits):
    """Quantize a tensor to the given bit-width, then dequantize it back."""
    if bits == 16:
        return x
    qmin, qmax = 0, 2**bits - 1
    xmin = x.amin(dim=-1, keepdim=True)
    xmax = x.amax(dim=-1, keepdim=True)
    scale = (xmax - xmin).clamp(min=1e-8) / qmax
    x_q = ((x - xmin) / scale).round().clamp(qmin, qmax)
    return x_q * scale + xmin
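# get_kv_error() runs a forward pass with use_cache=True, slices one attention
# head's K/V out of the requested layer, fake-quantizes it, and returns the mean
# squared reconstruction error over the sampled texts.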
def get_kv_error(layer_idx, head_idx, bits, num_samples=32):
    """Measure the reconstruction error from quantizing a specific head's KV cache."""
    errors = []
    for text in texts[:num_samples]:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to("cuda")
        if inputs["input_ids"].shape[1] < 32:
            continue  # skip passages too short to give a meaningful cache
        with torch.no_grad():
            outputs = model(
                **inputs,
                output_attentions=False,
                use_cache=True
            )
        kv_cache = outputs.past_key_values
        # per-layer key/value tensors from the transformers Cache object
        # (assumes the recent .layers[i].keys/.values layout; older releases
        # expose tuple-style caches instead)
        k = kv_cache.layers[layer_idx].keys  # [1, kv_heads, seq, head_dim]
        v = kv_cache.layers[layer_idx].values
        k_head = k[0, head_idx]
        v_head = v[0, head_idx]
        k_q = quantize_tensor(k_head, bits)
        v_q = quantize_tensor(v_head, bits)
        k_err = (k_head - k_q).pow(2).mean().item()
        v_err = (v_head - v_q).pow(2).mean().item()
        errors.append(k_err + v_err)
    return sum(errors) / len(errors) if errors else float('inf')
# get model dimensions
print("Detecting model dimensions...")
with torch.no_grad():
    dummy = tokenizer("hello", return_tensors="pt").to("cuda")
    out = model(**dummy, use_cache=True)
kv_cache = out.past_key_values
num_layers = len(kv_cache.layers)
num_heads = kv_cache.layers[0].keys.shape[1]
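# NOTE: shape[1] of the cached keys is the number of KV heads; for GQA models
# such as Mistral-7B this is num_key_value_heads, not the attention-head count.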
print(f"num_layers: {num_layers}, num_heads: {num_heads}")
print(f"Model: {num_layers} layers, {num_heads} heads per layer")
print("Running per-head sensitivity analysis...")
print("This will take ~15-20 minutes. Grab a coffee β˜•")
sensitivity_map = {}
bit_allocation = {}
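# Sensitivity sweep: for every (layer, head) pair, record the fake-quantization
# MSE at 2/4/8 bits, then assign 4 bits to low-error heads and 8 bits to the rest.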
for layer_idx in tqdm(range(num_layers), desc="Layers"):
    sensitivity_map[layer_idx] = {}
    bit_allocation[layer_idx] = {}
    for head_idx in range(num_heads):
        err_2bit = get_kv_error(layer_idx, head_idx, 2, num_samples=32)
        err_4bit = get_kv_error(layer_idx, head_idx, 4, num_samples=32)
        err_8bit = get_kv_error(layer_idx, head_idx, 8, num_samples=32)
        sensitivity_map[layer_idx][head_idx] = {
            "2bit": round(err_2bit, 6),
            "4bit": round(err_4bit, 6),
            "8bit": round(err_8bit, 6),
        }
        # use 4-bit when the 4-bit reconstruction error is below a fixed
        # threshold; fall back to 8-bit for high-sensitivity heads
        if err_4bit < 0.05:
            optimal_bits = 4
        else:
            optimal_bits = 8
        bit_allocation[layer_idx][head_idx] = optimal_bits
# summary
all_bits = [bit_allocation[l][h] for l in bit_allocation for h in bit_allocation[l]]
avg_bits = sum(all_bits) / len(all_bits)
dist = {2: all_bits.count(2), 4: all_bits.count(4), 8: all_bits.count(8)}
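# compression ratio is relative to the uncompressed fp16 cache (16 bits/value)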
compression = 16 / avg_bits
print(f"\nβœ… Calibration complete!")
print(f"Bit distribution: {dist}")
print(f"Average bits: {avg_bits:.2f}")
print(f"Compression vs FP16: {compression:.1f}x")
# save
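# sensitivity_map.json: per-head MSE at each bit-width;
# bit_allocation.json: chosen bit-width per head
# (json.dump stringifies the integer layer/head keys)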
with open(f"{results_dir}/sensitivity_map.json", "w") as f:
json.dump(sensitivity_map, f, indent=2)
with open(f"{results_dir}/bit_allocation.json", "w") as f:
json.dump(bit_allocation, f, indent=2)
print(f"βœ… Saved to {results_dir}/")