""" Full benchmark suite comparing: 1. FP16 baseline 2. Uniform 8-bit quantization 3. Our mixed per-head quantization Across: memory, speed, perplexity """ import torch import json import os import sys import time import math from transformers import AutoTokenizer, AutoModelForCausalLM from datasets import load_dataset sys.path.append(os.path.expanduser("~/kv-hack")) from kernel.quant_cache import MixedPrecisionKVCache # ── config ────────────────────────────────────────── MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b" MODEL_PATHS = { "mistral-7b": "~/kv-hack/mistral-model", "llama-3-8b": "~/kv-hack/llama-model", } model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME]) results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}") # load bit allocation with open(f"{results_dir}/bit_allocation.json") as f: bit_alloc_raw = json.load(f) bit_alloc = { int(l): [bit_alloc_raw[l][str(h)] for h in range(len(bit_alloc_raw[l]))] for l in bit_alloc_raw } num_layers = len(bit_alloc) avg_bits = sum(b for l in bit_alloc.values() for b in l) / \ sum(len(l) for l in bit_alloc.values()) print(f"Benchmarking: {MODEL_NAME}") print(f"Avg bits: {avg_bits:.2f}") # ── load model ────────────────────────────────────── print("Loading model...") tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained( model_path, dtype=torch.float16, device_map="cuda" ) model.eval() print(f"Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB") # ── helper: compute KV compression at given context ── def measure_kv_compression(context_len: int): input_ids = torch.randint(1, 1000, (1, context_len)).cuda() with torch.no_grad(): out = model(input_ids, use_cache=True) kv = out.past_key_values fp16_bytes = 0 compressed_bytes = 0 uniform8_bytes = 0 for layer_idx in range(num_layers): k = kv.layers[layer_idx].keys v = kv.layers[layer_idx].values # FP16 baseline fp16_bytes += k.numel() * 2 + v.numel() * 2 # uniform 8-bit uniform8_bytes += k.numel() + v.numel() # 1 byte per element # our mixed precision cache = MixedPrecisionKVCache(bit_alloc[layer_idx]) cache.store(k, v) compressed_bytes += cache.memory_bytes() return { "context_len": context_len, "fp16_mb": round(fp16_bytes / 1e6, 2), "uniform8_mb": round(uniform8_bytes / 1e6, 2), "mixed_precision_mb": round(compressed_bytes / 1e6, 2), "compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2), "compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2), } # ── helper: measure perplexity ─────────────────────── def measure_perplexity(num_samples: int = 50): print(f" Computing perplexity on {num_samples} WikiText samples...") dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") texts = [t for t in dataset["text"] if len(t.strip()) > 100][:num_samples] total_loss = 0 total_tokens = 0 for text in texts: inputs = tokenizer( text, return_tensors="pt", max_length=512, truncation=True ).to("cuda") if inputs["input_ids"].shape[1] < 10: continue with torch.no_grad(): out = model(**inputs, labels=inputs["input_ids"]) loss = out.loss.item() n = inputs["input_ids"].shape[1] total_loss += loss * n total_tokens += n ppl = math.exp(total_loss / total_tokens) return round(ppl, 2) # ── helper: measure decode speed ───────────────────── def measure_speed(context_len: int = 512, n_tokens: int = 100): input_ids = torch.randint(1, 1000, (1, context_len)).cuda() # warmup with torch.no_grad(): _ = model.generate( input_ids, max_new_tokens=10, do_sample=False, pad_token_id=tokenizer.eos_token_id ) torch.cuda.synchronize() t0 = time.time() with torch.no_grad(): _ = model.generate( input_ids, max_new_tokens=n_tokens, do_sample=False, pad_token_id=tokenizer.eos_token_id ) torch.cuda.synchronize() elapsed = time.time() - t0 return round(n_tokens / elapsed, 1) # ── helper: peak memory at context ─────────────────── def measure_peak_memory(context_len: int): torch.cuda.reset_peak_memory_stats() input_ids = torch.randint(1, 1000, (1, context_len)).cuda() with torch.no_grad(): _ = model(input_ids, use_cache=True) torch.cuda.synchronize() return round(torch.cuda.max_memory_allocated() / 1e9, 2) # ── RUN ALL BENCHMARKS ─────────────────────────────── print("\n" + "="*60) print("1. KV CACHE COMPRESSION AT DIFFERENT CONTEXT LENGTHS") print("="*60) compression_results = [] for ctx in [512, 1024, 2048, 4096, 8192]: print(f" Context {ctx}...", end=" ", flush=True) r = measure_kv_compression(ctx) compression_results.append(r) print(f"FP16={r['fp16_mb']}MB " f"Uniform8={r['uniform8_mb']}MB " f"Ours={r['mixed_precision_mb']}MB " f"({r['compression_vs_fp16']}x vs FP16)") print("\n" + "="*60) print("2. PEAK GPU MEMORY AT DIFFERENT CONTEXT LENGTHS") print("="*60) memory_results = [] for ctx in [1024, 4096, 8192]: print(f" Context {ctx}...", end=" ", flush=True) mem = measure_peak_memory(ctx) memory_results.append({"context": ctx, "peak_memory_gb": mem}) print(f"{mem} GB") print("\n" + "="*60) print("3. DECODE SPEED") print("="*60) print(" Measuring tokens/sec...", end=" ", flush=True) speed = measure_speed() print(f"{speed} tokens/sec") print("\n" + "="*60) print("4. PERPLEXITY (quality check)") print("="*60) perplexity = measure_perplexity(num_samples=50) print(f" Perplexity: {perplexity}") # ── SAVE ALL RESULTS ───────────────────────────────── benchmark_results = { "model": MODEL_NAME, "avg_bits": round(avg_bits, 2), "compression": compression_results, "memory": memory_results, "decode_tokens_per_sec": speed, "perplexity": perplexity, "summary": { "fp16_8k_mb": next(r["fp16_mb"] for r in compression_results if r["context_len"] == 8192), "ours_8k_mb": next(r["mixed_precision_mb"] for r in compression_results if r["context_len"] == 8192), "compression_8k": next(r["compression_vs_fp16"] for r in compression_results if r["context_len"] == 8192), } } out_path = f"{results_dir}/benchmark_results.json" with open(out_path, "w") as f: json.dump(benchmark_results, f, indent=2) print("\n" + "="*60) print("SUMMARY") print("="*60) print(f"Model: {MODEL_NAME}") print(f"Avg bits: {avg_bits:.2f}") print(f"Perplexity: {perplexity}") print(f"Speed: {speed} tokens/sec") print(f"KV @ 8K ctx: {benchmark_results['summary']['fp16_8k_mb']}MB → {benchmark_results['summary']['ours_8k_mb']}MB ({benchmark_results['summary']['compression_8k']}x)") print(f"\n✅ Saved to {out_path}")