| """ |
| Long context benchmarks at 16K and 32K. |
| This is where KV cache compression matters most. |
| 4 methods: FP16, Uniform 8-bit, Naive Per-Head, Triton True 4-bit |
| """ |
import torch
import json
import os
import sys
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache
from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton
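
# The project-local caches are assumed to share a small interface, inferred
# from their use below: construct with the per-head bit list for one layer,
# fill via .store(k, v), then read the footprint via .real_gpu_bytes()
# (naive, unpacked storage) or .memory_bytes() (Triton, true sub-byte packing).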

MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")

with open(f"{results_dir}/bit_allocation.json") as f:
    raw = json.load(f)
bit_alloc = {
    int(l): [raw[l][str(h)] for h in range(len(raw[l]))]
    for l in raw
}
num_layers = len(bit_alloc)
avg_bits = (
    sum(b for heads in bit_alloc.values() for b in heads)
    / sum(len(heads) for heads in bit_alloc.values())
)
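# Expected bit_allocation.json shape (inferred from the parsing above; the
# head counts and bit values here are purely illustrative):
#   {"0": {"0": 8, "1": 4, "2": 4, ...}, "1": {...}, ...}
# i.e. layer index -> {head index -> bits for that head}.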


print(f"Model: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")


def measure_context(context_len: int):
    print(f"\n Context {context_len} tokens...")
    # Random token ids suffice here: we measure memory and latency, not quality.
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()

    # One prefill pass to materialize the FP16 KV cache; record peak memory.
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
        kv = out.past_key_values
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    # Per-layer cache size accounting for each method.
    fp16_bytes = 0
    uniform8_bytes = 0
    naive_real_bytes = 0
    triton_bytes = 0

    for layer_idx in range(num_layers):
        # Newer transformers Cache API: per-layer .keys/.values tensors.
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values

        # FP16 costs 2 bytes/element; the uniform 8-bit baseline is counted
        # at 1 byte/element (scale/zero-point overhead ignored).
        fp16_bytes += k.numel() * 2 + v.numel() * 2
        uniform8_bytes += k.numel() + v.numel()

        # Naive per-head cache: report what it actually occupies on the GPU.
        cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache_naive.store(k, v)
        naive_real_bytes += cache_naive.real_gpu_bytes()

        # Triton cache: true sub-byte packed storage.
        cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx])
        cache_triton.store(k, v)
        triton_bytes += cache_triton.memory_bytes()

    # Prefill latency: mean of 3 timed runs (the pass above doubles as warmup).
    times = []
    for _ in range(3):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            _ = model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
    prefill_ms = round(sum(times) / len(times) * 1000, 1)

    return {
        "context_len": context_len,
        "peak_memory_gb": round(peak_mem, 2),
        "fp16_mb": round(fp16_bytes / 1e6, 2),
        "uniform8_mb": round(uniform8_bytes / 1e6, 2),
        "naive_real_gpu_mb": round(naive_real_bytes / 1e6, 2),
        "triton_mb": round(triton_bytes / 1e6, 2),
        "naive_compression": round(fp16_bytes / naive_real_bytes, 2),
        "triton_compression": round(fp16_bytes / triton_bytes, 2),
        "prefill_ms": prefill_ms,
    }
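

# Sanity-check sketch (illustrative, not used in the measurements above):
# the theoretical packed K+V size for one layer under a per-head bit
# allocation, ignoring scale/zero-point metadata. head_dim=128 is an
# assumption that holds for Mistral-7B and Llama-3-8B; the real caches
# will report slightly more.
def theoretical_packed_mb(context_len, bits_per_head, head_dim=128):
    total_bits = 2 * context_len * head_dim * sum(bits_per_head)  # K and V
    return total_bits / 8 / 1e6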


print("\n" + "="*75)
print("LONG CONTEXT BENCHMARK - 4 METHODS")
print("="*75)

results = []
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
    try:
        r = measure_context(ctx)
        results.append(r)
        print(f" ctx={ctx:6d} | "
              f"mem={r['peak_memory_gb']:.2f}GB | "
              f"FP16={r['fp16_mb']:.0f}MB | "
              f"8bit={r['uniform8_mb']:.0f}MB | "
              f"Naive={r['naive_real_gpu_mb']:.0f}MB({r['naive_compression']}x) | "
              f"Triton={r['triton_mb']:.0f}MB({r['triton_compression']}x) | "
              f"prefill={r['prefill_ms']}ms")
    except torch.cuda.OutOfMemoryError:
        print(f" ctx={ctx:6d} | OOM at FP16 -> compressed methods would fit")
        # Estimated FP16 KV size: ctx * layers * 2 tensors (K and V) * 8 KV
        # heads * 128 head dim * 2 bytes; 8 heads and head dim 128 hold for
        # both Mistral-7B and Llama-3-8B.
        results.append({
            "context_len": ctx,
            "peak_memory_gb": "OOM",
            "fp16_mb": round(ctx * num_layers * 2 * 8 * 128 * 2 / 1e6, 2),
            "note": "FP16 OOM",
        })
        break

out_path = f"{results_dir}/long_context_results.json"
with open(out_path, "w") as f:
    json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
print(f"\nSaved to {out_path}")