""" Long context benchmarks at 16K and 32K. This is where KV cache compression matters most. 4 methods: FP16, Uniform 8-bit, Naive Per-Head, Triton True 4-bit """ import torch import json import os import sys import time from transformers import AutoTokenizer, AutoModelForCausalLM sys.path.append(os.path.expanduser("~/kv-hack")) from kernel.quant_cache import MixedPrecisionKVCache from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton # ── config ────────────────────────────────────────── MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b" MODEL_PATHS = { "mistral-7b": "~/kv-hack/mistral-model", "llama-3-8b": "~/kv-hack/llama-model", } model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME]) results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}") with open(f"{results_dir}/bit_allocation.json") as f: raw = json.load(f) bit_alloc = { int(l): [raw[l][str(h)] for h in range(len(raw[l]))] for l in raw } num_layers = len(bit_alloc) avg_bits = sum(b for l in bit_alloc.values() for b in l) / \ sum(len(l) for l in bit_alloc.values()) print(f"Model: {MODEL_NAME}") print(f"Avg bits: {avg_bits:.2f}") print("Loading model...") tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained( model_path, dtype=torch.float16, device_map="cuda" ) model.eval() print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB") def measure_context(context_len: int): print(f"\n Context {context_len} tokens...") input_ids = torch.randint(1, 1000, (1, context_len)).cuda() # peak memory torch.cuda.reset_peak_memory_stats() with torch.no_grad(): out = model(input_ids, use_cache=True) kv = out.past_key_values torch.cuda.synchronize() peak_mem = torch.cuda.max_memory_allocated() / 1e9 # KV compression — all 4 methods fp16_bytes = 0 uniform8_bytes = 0 naive_real_bytes = 0 triton_bytes = 0 for layer_idx in range(num_layers): k = kv.layers[layer_idx].keys v = kv.layers[layer_idx].values fp16_bytes += k.numel() * 2 + v.numel() * 2 uniform8_bytes += k.numel() + v.numel() cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx]) cache_naive.store(k, v) naive_real_bytes += cache_naive.real_gpu_bytes() cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx]) cache_triton.store(k, v) triton_bytes += cache_triton.memory_bytes() # prefill speed (3 runs average) times = [] for _ in range(3): torch.cuda.synchronize() t0 = time.time() with torch.no_grad(): _ = model(input_ids, use_cache=True) torch.cuda.synchronize() times.append(time.time() - t0) prefill_ms = round(sum(times) / len(times) * 1000, 1) return { "context_len": context_len, "peak_memory_gb": round(peak_mem, 2), "fp16_mb": round(fp16_bytes / 1e6, 2), "uniform8_mb": round(uniform8_bytes / 1e6, 2), "naive_real_gpu_mb": round(naive_real_bytes / 1e6, 2), "triton_mb": round(triton_bytes / 1e6, 2), "naive_compression": round(fp16_bytes / naive_real_bytes, 2), "triton_compression": round(fp16_bytes / triton_bytes, 2), "prefill_ms": prefill_ms, } # ── RUN ────────────────────────────────────────────── print("\n" + "="*75) print("LONG CONTEXT BENCHMARK — 4 METHODS") print("="*75) results = [] for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]: try: r = measure_context(ctx) results.append(r) print(f" ctx={ctx:6d} | " f"mem={r['peak_memory_gb']:.2f}GB | " f"FP16={r['fp16_mb']:.0f}MB | " f"8bit={r['uniform8_mb']:.0f}MB | " f"Naive={r['naive_real_gpu_mb']:.0f}MB({r['naive_compression']}x) | " f"Triton={r['triton_mb']:.0f}MB({r['triton_compression']}x) | " f"prefill={r['prefill_ms']}ms") except 
torch.cuda.OutOfMemoryError: print(f" ctx={ctx:6d} | OOM at FP16 — compressed methods would fit ✓") results.append({ "context_len": ctx, "peak_memory_gb": "OOM", "fp16_mb": round(ctx * num_layers * 2 * 8 * 128 * 2 / 1e6, 2), "note": "FP16 OOM" }) break # save out_path = f"{results_dir}/long_context_results.json" with open(out_path, "w") as f: json.dump({"model": MODEL_NAME, "results": results}, f, indent=2) print(f"\n✅ Saved to {out_path}")
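
# Usage sketch (the script filename below is hypothetical; adjust to wherever
# this file lives). It assumes results/<model>/bit_allocation.json has already
# been produced by the per-head bit allocation step, e.g.
# {"0": {"0": 8, "1": 4, ...}, "1": {...}, ...}:
#
#   python long_context_bench.py mistral-7b
#   python long_context_bench.py llama-3-8b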