"""
Long context benchmarks from 512 up to 32K tokens.

The 16K and 32K points are where KV cache compression matters most.
The optional first command-line argument selects the model
(a key of MODEL_PATHS; defaults to "mistral-7b").
"""
import torch
import json
import os
import sys
import time
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache

MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    # Fail with a readable message instead of a bare KeyError traceback.
    sys.exit(
        f"Unknown model {MODEL_NAME!r}; expected one of: "
        f"{', '.join(MODEL_PATHS)}"
    )

model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")

with open(f"{results_dir}/bit_allocation.json", encoding="utf-8") as f:
    raw = json.load(f)

# The JSON is keyed by stringified layer index, then stringified head index;
# convert to {layer_index: [bits for head 0, head 1, ...]}.
bit_alloc = {
    int(layer): [raw[layer][str(h)] for h in range(len(raw[layer]))]
    for layer in raw
}
num_layers = len(bit_alloc)
avg_bits = sum(b for heads in bit_alloc.values() for b in heads) / \
    sum(len(heads) for heads in bit_alloc.values())

print(f"Model: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")


def _kv_geometry(config):
    """Return (kv_heads, head_dim) read from a model config.

    Falls back to 8 KV heads x 128 head dim (the values this script
    previously hard-coded) when the config does not expose them.
    """
    kv_heads = getattr(config, "num_key_value_heads", None) or 8
    head_dim = getattr(config, "head_dim", None)
    if not head_dim:
        hidden = getattr(config, "hidden_size", None)
        heads = getattr(config, "num_attention_heads", None)
        head_dim = hidden // heads if hidden and heads else 128
    return kv_heads, head_dim


KV_HEADS, HEAD_DIM = _kv_geometry(model.config)


def measure_context(context_len: int):
    """Benchmark one context length and return a dict of measurements.

    Runs a prefill over `context_len` random token ids, records peak CUDA
    memory, sizes the resulting KV cache at FP16, at a uniform 1 byte per
    element, and as reported by MixedPrecisionKVCache, then times three
    further prefill passes.

    Args:
        context_len: number of tokens in the synthetic prompt.

    Returns:
        dict with keys context_len, peak_memory_gb, fp16_mb, uniform8_mb,
        mixed_precision_mb, compression_vs_fp16, compression_vs_8bit and
        prefill_ms.

    Raises:
        torch.cuda.OutOfMemoryError: if a forward pass does not fit on the GPU.
    """
    print(f"\n Context {context_len} tokens...")
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()

    # peak memory
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
    kv = out.past_key_values
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    # KV compression
    fp16_bytes = 0
    uniform8_bytes = 0
    compressed_bytes = 0
    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values
        fp16_bytes += k.numel() * 2 + v.numel() * 2   # 2 bytes per element
        uniform8_bytes += k.numel() + v.numel()       # 1 byte per element
        cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()
        # Drop per-layer references so they do not outlive the loop.
        del cache, k, v

    # Release the first pass's logits and KV cache before timing; otherwise
    # they stay resident during every timed pass and can cause an avoidable
    # OOM at long contexts.
    del out, kv

    # prefill speed (the pass above doubles as warm-up)
    times = []
    for _ in range(3):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            timed_out = model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
        # Free this pass's output so two are never alive at once.
        del timed_out
    prefill_ms = round(sum(times) / len(times) * 1000, 1)

    return {
        "context_len": context_len,
        "peak_memory_gb": round(peak_mem, 2),
        "fp16_mb": round(fp16_bytes / 1e6, 2),
        "uniform8_mb": round(uniform8_bytes / 1e6, 2),
        "mixed_precision_mb": round(compressed_bytes / 1e6, 2),
        "compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
        "compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
        "prefill_ms": prefill_ms,
    }


print("\n" + "="*60)
print("LONG CONTEXT BENCHMARK")
print("="*60)

results = []
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
    try:
        r = measure_context(ctx)
        results.append(r)
        print(f" ctx={ctx:6d} | "
              f"mem={r['peak_memory_gb']:.2f}GB | "
              f"FP16={r['fp16_mb']:.0f}MB | "
              f"Ours={r['mixed_precision_mb']:.0f}MB | "
              f"{r['compression_vs_fp16']}x | "
              f"prefill={r['prefill_ms']}ms")
    except torch.cuda.OutOfMemoryError:
        print(f" ctx={ctx:6d} | OOM — FP16 ran out of memory ✓")
        # Nothing can be measured here; record an analytic estimate of the
        # FP16 KV size instead: tokens x layers x kv_heads x head_dim x
        # 4 bytes (K and V at 2 bytes each).
        results.append({
            "context_len": ctx,
            "peak_memory_gb": "OOM",
            "fp16_mb": round(
                ctx * num_layers * KV_HEADS * HEAD_DIM * 4 / 1e6, 2
            ),
            "note": "FP16 OOM, compressed might fit"
        })
        # Longer contexts would also run out of memory, so stop the sweep.
        break

# save
out_path = f"{results_dir}/long_context_results.json"
with open(out_path, "w") as f:
    json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
print(f"\n✅ Saved to {out_path}")