import torch
import json
import os
import sys

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm

# ── config ──────────────────────────────────────────
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    sys.exit(f"Unknown model '{MODEL_NAME}'; expected one of: {', '.join(MODEL_PATHS)}")
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
os.makedirs(results_dir, exist_ok=True)
# ────────────────────────────────────────────────────

print(f"Running calibration for: {MODEL_NAME}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.float16,  # `dtype=` needs a recent transformers; older releases use `torch_dtype=`
    device_map="cuda",
)
model.eval()

# load calibration dataset
print("Loading calibration data...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
texts = [t for t in dataset["text"] if len(t.strip()) > 200][:256]


def quantize_tensor(x, bits):
    """Simulate quantization: round-trip x through `bits`-bit asymmetric
    min-max quantization (per last dim) and return the dequantized tensor."""
    if bits == 16:
        return x
    qmin, qmax = 0, 2**bits - 1
    xmin = x.amin(dim=-1, keepdim=True)
    xmax = x.amax(dim=-1, keepdim=True)
    scale = (xmax - xmin).clamp(min=1e-8) / qmax
    x_q = ((x - xmin) / scale).round().clamp(qmin, qmax)
    return x_q * scale + xmin


def get_kv_error(layer_idx, head_idx, bits, num_samples=32):
    """Measure mean reconstruction MSE when quantizing one head's K and V.

    Note: this runs a fresh forward pass per sample for every
    (layer, head, bits) query; computing past_key_values once per text and
    reusing it across heads and bit-widths would make calibration much faster.
    """
    errors = []
    for text in texts[:num_samples]:
        inputs = tokenizer(
            text, return_tensors="pt", max_length=512, truncation=True
        ).to("cuda")
        if inputs["input_ids"].shape[1] < 32:
            continue  # skip sequences too short to be informative
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=False, use_cache=True)
        kv_cache = outputs.past_key_values
        k = kv_cache.layers[layer_idx].keys  # [1, kv_heads, seq, head_dim]
        v = kv_cache.layers[layer_idx].values
        k_head = k[0, head_idx]
        v_head = v[0, head_idx]
        k_q = quantize_tensor(k_head, bits)
        v_q = quantize_tensor(v_head, bits)
        k_err = (k_head - k_q).pow(2).mean().item()
        v_err = (v_head - v_q).pow(2).mean().item()
        errors.append(k_err + v_err)
    return sum(errors) / len(errors) if errors else float("inf")


# get model dimensions from a dummy forward pass
print("Detecting model dimensions...")
with torch.no_grad():
    dummy = tokenizer("hello", return_tensors="pt").to("cuda")
    out = model(**dummy, use_cache=True)
kv_cache = out.past_key_values
num_layers = len(kv_cache.layers)
num_heads = kv_cache.layers[0].keys.shape[1]  # KV heads (fewer than query heads under GQA)
print(f"Model: {num_layers} layers, {num_heads} KV heads per layer")

print("Running per-head sensitivity analysis...")
print("This will take ~15-20 minutes. Grab a coffee ☕")

sensitivity_map = {}
bit_allocation = {}
for layer_idx in tqdm(range(num_layers), desc="Layers"):
    sensitivity_map[layer_idx] = {}
    bit_allocation[layer_idx] = {}
    for head_idx in range(num_heads):
        err_2bit = get_kv_error(layer_idx, head_idx, 2, num_samples=32)
        err_4bit = get_kv_error(layer_idx, head_idx, 4, num_samples=32)
        err_8bit = get_kv_error(layer_idx, head_idx, 8, num_samples=32)
        sensitivity_map[layer_idx][head_idx] = {
            "2bit": round(err_2bit, 6),
            "4bit": round(err_4bit, 6),
            "8bit": round(err_8bit, 6),
        }
        # fixed threshold: allocate 4-bit where the 4-bit MSE stays below
        # 0.05, 8-bit for the more sensitive heads (2-bit error is measured
        # for the sensitivity map but never allocated here)
        bit_allocation[layer_idx][head_idx] = 4 if err_4bit < 0.05 else 8

# summary
all_bits = [bit_allocation[l][h] for l in bit_allocation for h in bit_allocation[l]]
avg_bits = sum(all_bits) / len(all_bits)
dist = {2: all_bits.count(2), 4: all_bits.count(4), 8: all_bits.count(8)}
compression = 16 / avg_bits

print("\n✅ Calibration complete!")
print(f"Bit distribution: {dist}")
print(f"Average bits: {avg_bits:.2f}")
print(f"Compression vs FP16: {compression:.1f}x")

# save
with open(f"{results_dir}/sensitivity_map.json", "w") as f:
    json.dump(sensitivity_map, f, indent=2)
with open(f"{results_dir}/bit_allocation.json", "w") as f:
    json.dump(bit_allocation, f, indent=2)
print(f"✅ Saved to {results_dir}/")
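
# ── usage sketch ─────────────────────────────────────
# A minimal sketch of how the saved allocation might be consumed at
# inference: after a prefill pass, fake-quantize each head's K/V in the
# cache to its allocated bit-width, then keep decoding from it.
# `apply_bit_allocation` is an illustrative helper (not a transformers
# API), and this assumes the same `cache.layers[...].keys/.values`
# layout used above. It simulates the precision loss only; actual memory
# savings would require packed integer storage for the cache tensors.
def apply_bit_allocation(cache, allocation):
    for layer_idx, layer in enumerate(cache.layers):
        for head_idx in range(layer.keys.shape[1]):
            bits = allocation[str(layer_idx)][str(head_idx)]  # JSON keys are strings
            layer.keys[0, head_idx] = quantize_tensor(layer.keys[0, head_idx], bits)
            layer.values[0, head_idx] = quantize_tensor(layer.values[0, head_idx], bits)
    return cache


with open(f"{results_dir}/bit_allocation.json") as f:
    allocation = json.load(f)

with torch.no_grad():
    demo = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
    prefill = model(**demo, use_cache=True)
    cache = apply_bit_allocation(prefill.past_key_values, allocation)
    # decode one greedy token from the quantized cache as a smoke test
    next_token = prefill.logits[:, -1].argmax(dim=-1, keepdim=True)
    model(input_ids=next_token, past_key_values=cache, use_cache=True)
print("Smoke test: decoded one step from the quantized cache")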