import json
import os
import sys

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Choose the model from the first CLI argument; default to Mistral-7B.
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    sys.exit(f"Unknown model '{MODEL_NAME}'; expected one of: {', '.join(MODEL_PATHS)}")
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
os.makedirs(results_dir, exist_ok=True)

print(f"Running calibration for: {MODEL_NAME}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.float16,
    device_map="cuda",
)
model.eval()

print("Loading calibration data...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
# Keep only reasonably long passages; 256 texts is plenty for 32-sample estimates.
texts = [t for t in dataset["text"] if len(t.strip()) > 200][:256]


def quantize_tensor(x, bits):
    """Uniform asymmetric quantization along the last dim, then dequantize back."""
    if bits == 16:
        return x  # FP16 baseline: no-op
    qmin, qmax = 0, 2**bits - 1
    # Per-row (last-dim) min/max gives each token its own scale and zero point.
    xmin = x.amin(dim=-1, keepdim=True)
    xmax = x.amax(dim=-1, keepdim=True)
    scale = (xmax - xmin).clamp(min=1e-8) / qmax
    x_q = ((x - xmin) / scale).round().clamp(qmin, qmax)
    return x_q * scale + xmin
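
# Illustrative usage (hypothetical tensor, not part of the calibration flow):
# for a [seq_len, head_dim] slice, min/max are taken over head_dim, so each
# token gets its own scale and zero point.
#
#     k_head = torch.randn(512, 128, device="cuda")
#     k_4bit = quantize_tensor(k_head, 4)   # same shape, dequantized values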


def get_kv_error(layer_idx, head_idx, bits, num_samples=32):
    """Measure reconstruction error when quantizing a specific head's KV."""
    errors = []

    for text in texts[:num_samples]:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to("cuda")

        # Skip very short sequences; their KV statistics are too noisy.
        if inputs["input_ids"].shape[1] < 32:
            continue

        with torch.no_grad():
            outputs = model(
                **inputs,
                output_attentions=False,
                use_cache=True,
            )

        # Requires a recent transformers version where past_key_values is a
        # Cache object exposing per-layer .keys/.values tensors.
        kv_cache = outputs.past_key_values
        k = kv_cache.layers[layer_idx].keys
        v = kv_cache.layers[layer_idx].values

        # Batch size is 1, so [0, head_idx] selects a [seq_len, head_dim] slice.
        k_head = k[0, head_idx]
        v_head = v[0, head_idx]

        k_q = quantize_tensor(k_head, bits)
        v_q = quantize_tensor(v_head, bits)

        k_err = (k_head - k_q).pow(2).mean().item()
        v_err = (v_head - v_q).pow(2).mean().item()
        errors.append(k_err + v_err)

    return sum(errors) / len(errors) if errors else float("inf")
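

# A cheaper variant (a sketch, not used by the sweep below): one forward pass
# per sample scores every head of a layer at once, instead of re-running the
# model for each (head, bit-width) pair. Error semantics match get_kv_error.
def get_kv_error_all_heads(layer_idx, bits, num_samples=32):
    """Per-head mean K+V reconstruction MSE for one layer at one bit-width."""
    totals, count = None, 0
    for text in texts[:num_samples]:
        inputs = tokenizer(
            text, return_tensors="pt", max_length=512, truncation=True
        ).to("cuda")
        if inputs["input_ids"].shape[1] < 32:
            continue
        with torch.no_grad():
            out = model(**inputs, use_cache=True)
        layer = out.past_key_values.layers[layer_idx]
        k, v = layer.keys[0], layer.values[0]  # [num_heads, seq_len, head_dim]
        k_err = (k - quantize_tensor(k, bits)).pow(2).mean(dim=(1, 2))
        v_err = (v - quantize_tensor(v, bits)).pow(2).mean(dim=(1, 2))
        err = k_err + v_err  # one MSE per head
        totals = err if totals is None else totals + err
        count += 1
    return (totals / count).tolist() if count else None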


# Probe one tiny forward pass to read the KV-cache geometry. Note this counts
# KV heads, which for GQA models is smaller than the attention-head count.
print("Detecting model dimensions...")
with torch.no_grad():
    dummy = tokenizer("hello", return_tensors="pt").to("cuda")
    out = model(**dummy, use_cache=True)
kv_cache = out.past_key_values
num_layers = len(kv_cache.layers)
num_heads = kv_cache.layers[0].keys.shape[1]

print(f"Model: {num_layers} layers, {num_heads} KV heads per layer")
print("Running per-head sensitivity analysis...")
print("This will take ~15-20 minutes. Grab a coffee ☕")
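
# Cost note: each (layer, head, bit-width) triple below re-runs up to 32
# forward passes, i.e. num_layers * num_heads * 3 * 32 prefills in total;
# get_kv_error_all_heads above sketches a per-layer shortcut.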
sensitivity_map = {}
bit_allocation = {}

for layer_idx in tqdm(range(num_layers), desc="Layers"):
    sensitivity_map[layer_idx] = {}
    bit_allocation[layer_idx] = {}

    for head_idx in range(num_heads):
        err_2bit = get_kv_error(layer_idx, head_idx, 2, num_samples=32)
        err_4bit = get_kv_error(layer_idx, head_idx, 4, num_samples=32)
        err_8bit = get_kv_error(layer_idx, head_idx, 8, num_samples=32)

        sensitivity_map[layer_idx][head_idx] = {
            "2bit": round(err_2bit, 6),
            "4bit": round(err_4bit, 6),
            "8bit": round(err_8bit, 6),
        }

        # Greedily pick the cheapest bit-width whose error clears the
        # (heuristic) 0.05 MSE threshold; fall back to 8 bits otherwise.
        if err_2bit < 0.05:
            optimal_bits = 2
        elif err_4bit < 0.05:
            optimal_bits = 4
        else:
            optimal_bits = 8

        bit_allocation[layer_idx][head_idx] = optimal_bits

# Summarize the allocation and the implied compression vs. the FP16 cache.
all_bits = [bit_allocation[l][h] for l in bit_allocation for h in bit_allocation[l]]
avg_bits = sum(all_bits) / len(all_bits)
dist = {2: all_bits.count(2), 4: all_bits.count(4), 8: all_bits.count(8)}
compression = 16 / avg_bits

print("\n✅ Calibration complete!")
print(f"Bit distribution: {dist}")
print(f"Average bits: {avg_bits:.2f}")
print(f"Compression vs FP16: {compression:.1f}x")

with open(f"{results_dir}/sensitivity_map.json", "w") as f:
    json.dump(sensitivity_map, f, indent=2)

with open(f"{results_dir}/bit_allocation.json", "w") as f:
    json.dump(bit_allocation, f, indent=2)

print(f"✅ Saved to {results_dir}/")
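
# Note: json.dump stringifies the integer layer/head keys. A consumer would
# reload them like this (illustrative):
#
#     with open(f"{results_dir}/bit_allocation.json") as f:
#         alloc = {int(l): {int(h): b for h, b in heads.items()}
#                  for l, heads in json.load(f).items()}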