| """ |
| Full benchmark suite comparing: |
| 1. FP16 baseline |
| 2. Uniform 8-bit quantization |
| 3. Our mixed per-head quantization |
| Across: memory, speed, perplexity |
| """ |
import torch
import json
import os
import sys
import time
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache

MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")

with open(f"{results_dir}/bit_allocation.json") as f:
    bit_alloc_raw = json.load(f)
bit_alloc = {
    int(layer): [bit_alloc_raw[layer][str(h)] for h in range(len(bit_alloc_raw[layer]))]
    for layer in bit_alloc_raw
}
num_layers = len(bit_alloc)
avg_bits = sum(b for heads in bit_alloc.values() for b in heads) / \
    sum(len(heads) for heads in bit_alloc.values())
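# bit_allocation.json is keyed by stringified layer and head indices, e.g.
# {"0": {"0": 8, "1": 4, ...}, "1": {...}, ...} (bit values here are purely
# illustrative). After the conversion above, bit_alloc maps
# layer index (int) -> [bits for head 0, bits for head 1, ...].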
|
|
| print(f"Benchmarking: {MODEL_NAME}") |
| print(f"Avg bits: {avg_bits:.2f}") |
|
|
| |
| print("Loading model...") |
| tokenizer = AutoTokenizer.from_pretrained(model_path) |
| model = AutoModelForCausalLM.from_pretrained( |
| model_path, dtype=torch.float16, device_map="cuda" |
| ) |
| model.eval() |
| print(f"Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB") |
|
|
| |
def measure_kv_compression(context_len: int):
    """Compare KV-cache size under each scheme for one forward pass at context_len."""
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
    kv = out.past_key_values

    fp16_bytes = 0
    compressed_bytes = 0
    uniform8_bytes = 0

    for layer_idx in range(num_layers):
        # Cache API of recent transformers releases; older versions expose
        # past_key_values[layer_idx] as a (key, value) tuple instead.
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values

        # FP16 baseline: 2 bytes per element
        fp16_bytes += k.numel() * 2 + v.numel() * 2

        # Uniform 8-bit: 1 byte per element (scale/zero-point metadata not counted)
        uniform8_bytes += k.numel() + v.numel()

        # Our mixed per-head quantization: pack K/V at this layer's per-head
        # bit-widths and count the packed size
        cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()

    return {
        "context_len": context_len,
        "fp16_mb": round(fp16_bytes / 1e6, 2),
        "uniform8_mb": round(uniform8_bytes / 1e6, 2),
        "mixed_precision_mb": round(compressed_bytes / 1e6, 2),
        "compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
        "compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
    }

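# Note on measure_kv_compression: all three byte counts are cache payload only.
# The uniform 8-bit figure ignores scale/zero-point metadata, and the
# mixed-precision figure is whatever MixedPrecisionKVCache.memory_bytes()
# reports, so the comparison is only as complete as that accounting.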
def measure_perplexity(num_samples: int = 50):
    print(f" Computing perplexity on {num_samples} WikiText samples...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    texts = [t for t in dataset["text"] if len(t.strip()) > 100][:num_samples]

    total_loss = 0.0
    total_tokens = 0

    for text in texts:
        inputs = tokenizer(
            text, return_tensors="pt",
            max_length=512, truncation=True
        ).to("cuda")

        if inputs["input_ids"].shape[1] < 10:
            continue

        with torch.no_grad():
            out = model(**inputs, labels=inputs["input_ids"])
        loss = out.loss.item()

        # out.loss is the mean cross-entropy over the n-1 shifted targets,
        # so weight each sample by n-1 when aggregating.
        n = inputs["input_ids"].shape[1] - 1
        total_loss += loss * n
        total_tokens += n

    ppl = math.exp(total_loss / total_tokens)
    return round(ppl, 2)

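# Note on measure_perplexity: it scores the model as loaded above on
# up-to-512-token WikiText-2 test samples and reports
# exp(total_loss / total_tokens), i.e. a token-weighted perplexity.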
def measure_speed(context_len: int = 512, n_tokens: int = 100):
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()

    # Warm-up generation so one-off setup costs are not timed
    with torch.no_grad():
        _ = model.generate(
            input_ids, max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    torch.cuda.synchronize()
    t0 = time.time()
    with torch.no_grad():
        _ = model.generate(
            input_ids, max_new_tokens=n_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    torch.cuda.synchronize()
    elapsed = time.time() - t0
    # Note: elapsed includes one prefill of the prompt as well as decoding
    return round(n_tokens / elapsed, 1)

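# Optional variant (illustrative, not used by the benchmark below): measure_speed
# folds the one-off prompt prefill into the timed window. If a pure decode rate
# is wanted, one way is to time a short and a long generation and divide the
# extra tokens by the extra time, so the prefill cost cancels out.
def measure_decode_only_speed(context_len: int = 512, n_short: int = 10, n_long: int = 110):
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    elapsed = []
    for n in (n_short, n_long):
        torch.cuda.synchronize()
        t0 = time.time()
        with torch.no_grad():
            _ = model.generate(
                input_ids, max_new_tokens=n,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
        torch.cuda.synchronize()
        elapsed.append(time.time() - t0)
    # prefill time is common to both runs and cancels in the difference
    return round((n_long - n_short) / (elapsed[1] - elapsed[0]), 1)
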
def measure_peak_memory(context_len: int):
    torch.cuda.reset_peak_memory_stats()
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    with torch.no_grad():
        _ = model(input_ids, use_cache=True)
    torch.cuda.synchronize()
    return round(torch.cuda.max_memory_allocated() / 1e9, 2)

| print("\n" + "="*60) |
| print("1. KV CACHE COMPRESSION AT DIFFERENT CONTEXT LENGTHS") |
| print("="*60) |
|
|
| compression_results = [] |
| for ctx in [512, 1024, 2048, 4096, 8192]: |
| print(f" Context {ctx}...", end=" ", flush=True) |
| r = measure_kv_compression(ctx) |
| compression_results.append(r) |
| print(f"FP16={r['fp16_mb']}MB " |
| f"Uniform8={r['uniform8_mb']}MB " |
| f"Ours={r['mixed_precision_mb']}MB " |
| f"({r['compression_vs_fp16']}x vs FP16)") |
|
|
| print("\n" + "="*60) |
| print("2. PEAK GPU MEMORY AT DIFFERENT CONTEXT LENGTHS") |
| print("="*60) |
|
|
| memory_results = [] |
| for ctx in [1024, 4096, 8192]: |
| print(f" Context {ctx}...", end=" ", flush=True) |
| mem = measure_peak_memory(ctx) |
| memory_results.append({"context": ctx, "peak_memory_gb": mem}) |
| print(f"{mem} GB") |
|
|
| print("\n" + "="*60) |
| print("3. DECODE SPEED") |
| print("="*60) |
| print(" Measuring tokens/sec...", end=" ", flush=True) |
| speed = measure_speed() |
| print(f"{speed} tokens/sec") |
|
|
| print("\n" + "="*60) |
| print("4. PERPLEXITY (quality check)") |
| print("="*60) |
| perplexity = measure_perplexity(num_samples=50) |
| print(f" Perplexity: {perplexity}") |
|
|
| |
benchmark_results = {
    "model": MODEL_NAME,
    "avg_bits": round(avg_bits, 2),
    "compression": compression_results,
    "memory": memory_results,
    "decode_tokens_per_sec": speed,
    "perplexity": perplexity,
    "summary": {
        "fp16_8k_mb": next(r["fp16_mb"] for r in compression_results if r["context_len"] == 8192),
        "ours_8k_mb": next(r["mixed_precision_mb"] for r in compression_results if r["context_len"] == 8192),
        "compression_8k": next(r["compression_vs_fp16"] for r in compression_results if r["context_len"] == 8192),
    }
}

out_path = f"{results_dir}/benchmark_results.json"
with open(out_path, "w") as f:
    json.dump(benchmark_results, f, indent=2)

| print("\n" + "="*60) |
| print("SUMMARY") |
| print("="*60) |
| print(f"Model: {MODEL_NAME}") |
| print(f"Avg bits: {avg_bits:.2f}") |
| print(f"Perplexity: {perplexity}") |
| print(f"Speed: {speed} tokens/sec") |
| print(f"KV @ 8K ctx: {benchmark_results['summary']['fp16_8k_mb']}MB β {benchmark_results['summary']['ours_8k_mb']}MB ({benchmark_results['summary']['compression_8k']}x)") |
| print(f"\nβ
Saved to {out_path}") |
|
|