import torch
import json
import os
import sys
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
# ── config ──────────────────────────────────────────
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
os.makedirs(results_dir, exist_ok=True)
# ────────────────────────────────────────────────────
print(f"Running calibration for: {MODEL_NAME}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.float16,  # recent transformers kwarg; older releases use torch_dtype
    device_map="cuda",
)
model.eval()
# load calibration dataset
print("Loading calibration data...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
texts = [t for t in dataset["text"] if len(t.strip()) > 200][:256]
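# note: only the first num_samples (32 by default) of these 256 texts are
# actually consumed per measurement in get_kv_error() below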
def quantize_tensor(x, bits):
    """Quantize a tensor to the given bit width, then dequantize back."""
    if bits == 16:
        return x  # FP16 baseline: pass through unchanged
    # asymmetric (affine) quantization with per-row min/max over the last dim
    qmin, qmax = 0, 2**bits - 1
    xmin = x.amin(dim=-1, keepdim=True)
    xmax = x.amax(dim=-1, keepdim=True)
    scale = (xmax - xmin).clamp(min=1e-8) / qmax
    x_q = ((x - xmin) / scale).round().clamp(qmin, qmax)
    return x_q * scale + xmin
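# Worked example (illustrative numbers): for a row spanning [0.0, 1.0] at
# 4 bits, qmax = 15 and scale = 1/15 ≈ 0.0667; x = 0.5 maps to round(7.5) = 8
# and dequantizes to 8/15 ≈ 0.533, so the worst-case round-trip error per
# element is about scale / 2.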
def get_kv_error(layer_idx, head_idx, bits, num_samples=32):
    """Measure reconstruction error when quantizing a specific head's KV."""
    errors = []
    for text in texts[:num_samples]:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to("cuda")
        if inputs["input_ids"].shape[1] < 32:
            continue  # skip sequences too short to give a stable estimate
        with torch.no_grad():
            outputs = model(
                **inputs,
                output_attentions=False,
                use_cache=True,
            )
        kv_cache = outputs.past_key_values
        k = kv_cache.layers[layer_idx].keys  # [1, kv_heads, seq, head_dim]
        v = kv_cache.layers[layer_idx].values
        k_head = k[0, head_idx]
        v_head = v[0, head_idx]
        k_q = quantize_tensor(k_head, bits)
        v_q = quantize_tensor(v_head, bits)
        k_err = (k_head - k_q).pow(2).mean().item()
        v_err = (v_head - v_q).pow(2).mean().item()
        errors.append(k_err + v_err)  # combined K+V MSE for this sample
    return sum(errors) / len(errors) if errors else float('inf')
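# Note: each call above re-runs full forward passes over the same texts, so
# the sweep below costs num_layers * num_heads * 3 bit-widths * num_samples
# passes. A faster variant would run the model once per text, keep the cached
# past_key_values, and score every layer/head/bit-width from those tensors.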
# get model dimensions
print("Detecting model dimensions...")
with torch.no_grad():
    dummy = tokenizer("hello", return_tensors="pt").to("cuda")
    out = model(**dummy, use_cache=True)
    kv_cache = out.past_key_values
# recent transformers Cache objects expose per-layer .keys / .values tensors
num_layers = len(kv_cache.layers)
num_heads = kv_cache.layers[0].keys.shape[1]  # KV heads (GQA-aware), not attention heads
print(f"num_layers: {num_layers}, num_heads: {num_heads}")
print(f"Model: {num_layers} layers, {num_heads} heads per layer")
print("Running per-head sensitivity analysis...")
print("This will take ~15-20 minutes. Grab a coffee β")
sensitivity_map = {}
bit_allocation = {}
for layer_idx in tqdm(range(num_layers), desc="Layers"):
    sensitivity_map[layer_idx] = {}
    bit_allocation[layer_idx] = {}
    for head_idx in range(num_heads):
        err_2bit = get_kv_error(layer_idx, head_idx, 2, num_samples=32)
        err_4bit = get_kv_error(layer_idx, head_idx, 4, num_samples=32)
        err_8bit = get_kv_error(layer_idx, head_idx, 8, num_samples=32)
        sensitivity_map[layer_idx][head_idx] = {
            "2bit": round(err_2bit, 6),
            "4bit": round(err_4bit, 6),
            "8bit": round(err_8bit, 6),
        }
        # allocate 4 bits when the 4-bit error falls below a fixed 0.05
        # threshold, else fall back to 8 bits; the 2-bit error is recorded
        # in the sensitivity map but not used for allocation
        if err_4bit < 0.05:
            optimal_bits = 4
        else:
            optimal_bits = 8
        bit_allocation[layer_idx][head_idx] = optimal_bits
# summary
all_bits = [bit_allocation[l][h] for l in bit_allocation for h in bit_allocation[l]]
avg_bits = sum(all_bits) / len(all_bits)
dist = {2: all_bits.count(2), 4: all_bits.count(4), 8: all_bits.count(8)}
compression = 16 / avg_bits
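# worked example (illustrative, not measured): if half the heads land at
# 4-bit and half at 8-bit, avg_bits = 6.0 and compression = 16 / 6 ≈ 2.7x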
print(f"\nβ
Calibration complete!")
print(f"Bit distribution: {dist}")
print(f"Average bits: {avg_bits:.2f}")
print(f"Compression vs FP16: {compression:.1f}x")
# save
with open(f"{results_dir}/sensitivity_map.json", "w") as f:
json.dump(sensitivity_map, f, indent=2)
with open(f"{results_dir}/bit_allocation.json", "w") as f:
json.dump(bit_allocation, f, indent=2)
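# note: json.dump stringifies the integer layer/head keys, so these files
# read back as {"0": {"0": 8, ...}, ...} and consumers must int() the keys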
print(f"β
Saved to {results_dir}/") |
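# ── usage sketch (assumption: how a downstream decoder might consume the
# allocation; not part of this calibration run) ─────────────────────────
#
#   alloc = json.load(open(f"{results_dir}/bit_allocation.json"))
#   for layer_idx, layer in enumerate(kv_cache.layers):
#       for h in range(num_heads):
#           bits = alloc[str(layer_idx)][str(h)]  # string keys after JSON
#           layer.keys[0, h] = quantize_tensor(layer.keys[0, h], bits)
#           layer.values[0, h] = quantize_tensor(layer.values[0, h], bits)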