"""
Long-context benchmarks, sweeping context length from 512 up to 32K tokens.
This is where KV cache compression matters most.
4 methods: FP16, Uniform 8-bit, Naive Per-Head, Triton True 4-bit
"""
import torch
import json
import os
import sys
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache
from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton
# ── config ──────────────────────────────────────────
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
"mistral-7b": "~/kv-hack/mistral-model",
"llama-3-8b": "~/kv-hack/llama-model",
}
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
with open(f"{results_dir}/bit_allocation.json") as f:
    raw = json.load(f)
bit_alloc = {
    int(l): [raw[l][str(h)] for h in range(len(raw[l]))]
    for l in raw
}
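# bit_alloc maps layer index -> list of per-head bit widths; the JSON keys are
# strings, hence the int()/str() conversions above.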
num_layers = len(bit_alloc)
avg_bits = sum(b for l in bit_alloc.values() for b in l) / \
    sum(len(l) for l in bit_alloc.values())
print(f"Model: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
def measure_context(context_len: int):
print(f"\n Context {context_len} tokens...")
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
# peak memory
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
out = model(input_ids, use_cache=True)
kv = out.past_key_values
torch.cuda.synchronize()
peak_mem = torch.cuda.max_memory_allocated() / 1e9
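    # Note: this peak includes model weights plus prefill activations and the
    # FP16 KV cache, since peak stats were only reset after the model loaded.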
    # KV compression: all 4 methods
    fp16_bytes = 0          # 2 bytes per element for K and V
    uniform8_bytes = 0      # 1 byte per element
    naive_real_bytes = 0    # actual GPU bytes used by the naive per-head cache
    triton_bytes = 0        # bytes reported by the Triton true-4-bit cache
    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values
        fp16_bytes += k.numel() * 2 + v.numel() * 2
        uniform8_bytes += k.numel() + v.numel()
        cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache_naive.store(k, v)
        naive_real_bytes += cache_naive.real_gpu_bytes()
        cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx])
        cache_triton.store(k, v)
        triton_bytes += cache_triton.memory_bytes()
    # prefill speed (3 runs average)
    times = []
    for _ in range(3):
        torch.cuda.synchronize()
        t0 = time.time()
        with torch.no_grad():
            _ = model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        times.append(time.time() - t0)
    prefill_ms = round(sum(times) / len(times) * 1000, 1)
    return {
        "context_len": context_len,
        "peak_memory_gb": round(peak_mem, 2),
        "fp16_mb": round(fp16_bytes / 1e6, 2),
        "uniform8_mb": round(uniform8_bytes / 1e6, 2),
        "naive_real_gpu_mb": round(naive_real_bytes / 1e6, 2),
        "triton_mb": round(triton_bytes / 1e6, 2),
        # compression = FP16 bytes / method bytes (higher is better)
        "naive_compression": round(fp16_bytes / naive_real_bytes, 2),
        "triton_compression": round(fp16_bytes / triton_bytes, 2),
        "prefill_ms": prefill_ms,
    }
# ── RUN ─────────────────────────────────────────────
print("\n" + "="*75)
print("LONG CONTEXT BENCHMARK β 4 METHODS")
print("="*75)
results = []
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
    try:
        r = measure_context(ctx)
        results.append(r)
        print(f" ctx={ctx:6d} | "
              f"mem={r['peak_memory_gb']:.2f}GB | "
              f"FP16={r['fp16_mb']:.0f}MB | "
              f"8bit={r['uniform8_mb']:.0f}MB | "
              f"Naive={r['naive_real_gpu_mb']:.0f}MB({r['naive_compression']}x) | "
              f"Triton={r['triton_mb']:.0f}MB({r['triton_compression']}x) | "
              f"prefill={r['prefill_ms']}ms")
    except torch.cuda.OutOfMemoryError:
        print(f" ctx={ctx:6d} | OOM at FP16 - compressed methods would fit")
        results.append({
            "context_len": ctx,
            "peak_memory_gb": "OOM",
            # estimated FP16 KV size: K and V (2) x 8 KV heads x head_dim 128
            # x 2 bytes, per token per layer (8 heads / dim 128 holds for both
            # Mistral-7B and Llama-3-8B)
            "fp16_mb": round(ctx * num_layers * 2 * 8 * 128 * 2 / 1e6, 2),
            "note": "FP16 OOM"
        })
        break
# save
out_path = f"{results_dir}/long_context_results.json"
with open(out_path, "w") as f:
json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
print(f"\nβ
Saved to {out_path}") |