""" Full benchmark suite comparing: 1. FP16 baseline 2. Uniform 8-bit quantization 3. Naive mixed per-head (uint8 storage — not truly packed) 4. Triton mixed per-head (truly packed 4-bit) Across: memory, speed, perplexity """ import torch import json import os import sys import time import math from transformers import AutoTokenizer, AutoModelForCausalLM from datasets import load_dataset sys.path.append(os.path.expanduser("~/kv-hack")) from kernel.quant_cache import MixedPrecisionKVCache from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton # ── config ────────────────────────────────────────── MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b" MODEL_PATHS = { "mistral-7b": "~/kv-hack/mistral-model", "llama-3-8b": "~/kv-hack/llama-model", } model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME]) results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}") with open(f"{results_dir}/bit_allocation.json") as f: bit_alloc_raw = json.load(f) bit_alloc = { int(l): [bit_alloc_raw[l][str(h)] for h in range(len(bit_alloc_raw[l]))] for l in bit_alloc_raw } num_layers = len(bit_alloc) avg_bits = sum(b for l in bit_alloc.values() for b in l) / \ sum(len(l) for l in bit_alloc.values()) print(f"Benchmarking: {MODEL_NAME}") print(f"Avg bits: {avg_bits:.2f}") print(f"Theoretical compression: {16/avg_bits:.2f}x") print("Loading model...") tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained( model_path, dtype=torch.float16, device_map="cuda" ) model.eval() print(f"Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB") def measure_kv_compression(context_len: int): input_ids = torch.randint(1, 1000, (1, context_len)).cuda() with torch.no_grad(): out = model(input_ids, use_cache=True) kv = out.past_key_values fp16_bytes = 0 uniform8_bytes = 0 naive_real_bytes = 0 # actual GPU bytes for naive (uint8) naive_theo_bytes = 0 # theoretical packed size for naive triton_bytes = 0 # actual GPU bytes for triton (truly packed) for layer_idx in range(num_layers): k = kv.layers[layer_idx].keys v = kv.layers[layer_idx].values # FP16 baseline fp16_bytes += k.numel() * 2 + v.numel() * 2 # uniform 8-bit (1 byte per element) uniform8_bytes += k.numel() + v.numel() # naive mixed precision cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx]) cache_naive.store(k, v) naive_real_bytes += cache_naive.real_gpu_bytes() # actual GPU naive_theo_bytes += cache_naive.memory_bytes() # theoretical # triton true 4-bit cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx]) cache_triton.store(k, v) triton_bytes += cache_triton.memory_bytes() # actual GPU (truly packed) return { "context_len": context_len, "fp16_mb": round(fp16_bytes / 1e6, 2), "uniform8_mb": round(uniform8_bytes / 1e6, 2), "naive_real_gpu_mb": round(naive_real_bytes / 1e6, 2), "naive_theoretical_mb": round(naive_theo_bytes / 1e6, 2), "triton_mb": round(triton_bytes / 1e6, 2), "naive_real_compression": round(fp16_bytes / naive_real_bytes, 2), "naive_theo_compression": round(fp16_bytes / naive_theo_bytes, 2), "triton_compression_vs_fp16": round(fp16_bytes / triton_bytes, 2), "triton_compression_vs_8bit": round(uniform8_bytes / triton_bytes, 2), "triton_compression_vs_naive": round(naive_real_bytes / triton_bytes, 2), } def measure_perplexity(num_samples: int = 50): print(f" Computing perplexity on {num_samples} WikiText samples...") dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") texts = [t for t in dataset["text"] if len(t.strip()) > 100][:num_samples] 
    total_loss = 0
    total_tokens = 0
    for text in texts:
        inputs = tokenizer(
            text, return_tensors="pt", max_length=512, truncation=True
        ).to("cuda")
        if inputs["input_ids"].shape[1] < 10:
            continue
        with torch.no_grad():
            out = model(**inputs, labels=inputs["input_ids"])
        loss = out.loss.item()
        n = inputs["input_ids"].shape[1]
        total_loss += loss * n
        total_tokens += n
    return round(math.exp(total_loss / total_tokens), 2)


def measure_speed(context_len: int = 512, n_tokens: int = 100):
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    # warmup
    with torch.no_grad():
        _ = model.generate(
            input_ids, max_new_tokens=10, do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    torch.cuda.synchronize()
    t0 = time.time()
    with torch.no_grad():
        _ = model.generate(
            input_ids, max_new_tokens=n_tokens, do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    torch.cuda.synchronize()
    return round(n_tokens / (time.time() - t0), 1)


def measure_peak_memory(context_len: int):
    torch.cuda.reset_peak_memory_stats()
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    with torch.no_grad():
        _ = model(input_ids, use_cache=True)
    torch.cuda.synchronize()
    return round(torch.cuda.max_memory_allocated() / 1e9, 2)


# ── RUN ALL BENCHMARKS ───────────────────────────────
print("\n" + "="*75)
print("1. KV CACHE COMPRESSION AT DIFFERENT CONTEXT LENGTHS")
print("="*75)
compression_results = []
for ctx in [512, 1024, 2048, 4096, 8192]:
    print(f"  Context {ctx}...", end=" ", flush=True)
    r = measure_kv_compression(ctx)
    compression_results.append(r)
    print(f"FP16={r['fp16_mb']}MB | "
          f"8bit={r['uniform8_mb']}MB | "
          f"Naive(actual)={r['naive_real_gpu_mb']}MB({r['naive_real_compression']}x) | "
          f"Triton={r['triton_mb']}MB({r['triton_compression_vs_fp16']}x)")

print("\n" + "="*75)
print("2. PEAK GPU MEMORY AT DIFFERENT CONTEXT LENGTHS")
print("="*75)
memory_results = []
for ctx in [1024, 4096, 8192]:
    print(f"  Context {ctx}...", end=" ", flush=True)
    mem = measure_peak_memory(ctx)
    memory_results.append({"context": ctx, "peak_memory_gb": mem})
    print(f"{mem} GB")

print("\n" + "="*75)
print("3. DECODE SPEED")
print("="*75)
print("  Measuring tokens/sec...", end=" ", flush=True)
speed = measure_speed()
print(f"{speed} tokens/sec")

print("\n" + "="*75)
print("4. PERPLEXITY (quality check)")
print("="*75)
perplexity = measure_perplexity(num_samples=50)
print(f"  Perplexity: {perplexity}")

# ── SAVE ─────────────────────────────────────────────
r8k = next(r for r in compression_results if r["context_len"] == 8192)
benchmark_results = {
    "model": MODEL_NAME,
    "avg_bits": round(avg_bits, 2),
    "compression": compression_results,
    "memory": memory_results,
    "decode_tokens_per_sec": speed,
    "perplexity": perplexity,
    "summary": {
        "fp16_8k_mb": r8k["fp16_mb"],
        "uniform8_8k_mb": r8k["uniform8_mb"],
        "naive_real_8k_mb": r8k["naive_real_gpu_mb"],
        "naive_theoretical_8k_mb": r8k["naive_theoretical_mb"],
        "triton_8k_mb": r8k["triton_mb"],
        "naive_real_compression_8k": r8k["naive_real_compression"],
        "naive_theo_compression_8k": r8k["naive_theo_compression"],
        "triton_compression_8k": r8k["triton_compression_vs_fp16"],
        "triton_vs_naive_8k": r8k["triton_compression_vs_naive"],
        "triton_vs_8bit_8k": r8k["triton_compression_vs_8bit"],
    },
}

out_path = f"{results_dir}/benchmark_results.json"
with open(out_path, "w") as f:
    json.dump(benchmark_results, f, indent=2)

print("\n" + "="*75)
print("SUMMARY")
print("="*75)
print(f"Model: {MODEL_NAME}")
print(f"Avg bits per head: {avg_bits:.2f}")
print(f"Perplexity: {perplexity}")
print(f"Decode speed: {speed} tokens/sec")
print()
print("KV Cache at 8K context:")
print(f"  FP16 baseline:                {r8k['fp16_mb']} MB (1.00x)")
print(f"  Uniform 8-bit:                {r8k['uniform8_mb']} MB (2.00x)")
print(f"  Naive per-head (actual GPU):  {r8k['naive_real_gpu_mb']} MB ({r8k['naive_real_compression']}x) ← uint8 storage")
print(f"  Naive per-head (theoretical): {r8k['naive_theoretical_mb']} MB ({r8k['naive_theo_compression']}x) ← if truly packed")
print(f"  Triton true 4-bit:            {r8k['triton_mb']} MB ({r8k['triton_compression_vs_fp16']}x) ← actual GPU")
print(f"  Triton vs Naive:              {r8k['triton_compression_vs_naive']}x smaller on GPU")
print(f"  Triton vs 8-bit:              {r8k['triton_compression_vs_8bit']}x smaller")
print(f"\n✅ Saved to {out_path}")
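
# ── optional sanity check (illustrative sketch) ──────
# Compares the measured Triton compression at 8K context against the
# theoretical 16/avg_bits figure printed at startup; a noticeable gap would
# suggest per-head packing overhead (scales, zero-points, padding).
# `theoretical_x` and `measured_x` are names introduced here for illustration;
# everything else comes from values already computed above.
theoretical_x = 16 / avg_bits
measured_x = r8k["triton_compression_vs_fp16"]
print(f"Theoretical vs measured compression at 8K: {theoretical_x:.2f}x vs {measured_x}x")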