harshithsaiv committed on
Commit
1a0124b
·
1 Parent(s): 0f6e4c1

feat: adding benchmark for longer context

Browse files
Files changed (1) hide show
  1. benchmark_long_context.py +124 -0
benchmark_long_context.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Long context benchmarks: sweeps 512 to 32K tokens, with 16K and 32K as the targets.
3
+ This is where KV cache compression matters most.
4
+ """
5
+ import torch
6
+ import json
7
+ import os
8
+ import sys
9
+ import time
10
+ import math
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM
12
+ from datasets import load_dataset
13
+
14
+ sys.path.append(os.path.expanduser("~/kv-hack"))
15
+ from kernel.quant_cache import MixedPrecisionKVCache
16
+
17
# --- Configuration ----------------------------------------------------------
# The model is selected by the first CLI argument; Mistral-7B is the default.
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    # Exit with a usable message instead of a bare KeyError traceback.
    sys.exit(
        f"Unknown model {MODEL_NAME!r}; expected one of: "
        f"{', '.join(sorted(MODEL_PATHS))}"
    )
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")

# --- Bit allocation ---------------------------------------------------------
# bit_allocation.json maps a layer index (string key) to a mapping of
# "0", "1", ... (string keys, presumably attention-head indices) to a bit
# width. It is normalised here to {layer index (int): [bits, ...]}.
with open(f"{results_dir}/bit_allocation.json", encoding="utf-8") as f:
    raw = json.load(f)
bit_alloc = {
    int(layer_key): [
        raw[layer_key][str(head)] for head in range(len(raw[layer_key]))
    ]
    for layer_key in raw
}
num_layers = len(bit_alloc)

total_entries = sum(len(bits) for bits in bit_alloc.values())
if total_entries == 0:
    # Without any entries the average below would divide by zero.
    sys.exit(f"{results_dir}/bit_allocation.json contains no bit allocations")
avg_bits = sum(b for bits in bit_alloc.values() for b in bits) / total_entries

print(f"Model: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")

# --- Model ------------------------------------------------------------------
# Weights are loaded in FP16 directly onto the GPU and switched to eval mode.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
45
+
46
def _kv_cache_sizes(kv):
    """Return ``(fp16_bytes, uniform8_bytes, compressed_bytes)`` for ``kv``.

    Walks the first ``num_layers`` layers of the cache, counting two bytes
    per element for FP16 and one byte per element for uniform 8-bit. The
    mixed-precision size is whatever ``MixedPrecisionKVCache.memory_bytes()``
    reports after storing that layer's keys and values.
    """
    fp16_bytes = 0
    uniform8_bytes = 0
    compressed_bytes = 0
    for layer_idx in range(num_layers):
        layer = kv.layers[layer_idx]
        k, v = layer.keys, layer.values
        elements = k.numel() + v.numel()
        fp16_bytes += 2 * elements
        uniform8_bytes += elements
        cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()
    return fp16_bytes, uniform8_bytes, compressed_bytes


def measure_context(context_len: int, *, timing_runs: int = 3):
    """Benchmark a single-sequence prefill of ``context_len`` random tokens.

    Args:
        context_len: Number of tokens in the synthetic prompt.
        timing_runs: How many timed prefill passes to average (default 3).

    Returns:
        A dict with the context length, peak allocated CUDA memory in GB
        during the first forward pass (this counts everything resident on
        the device, model weights included), the KV cache size in MB for
        FP16 / uniform 8-bit / mixed precision, the two compression ratios,
        and the mean prefill latency in milliseconds.

    Raises:
        ValueError: If ``timing_runs`` is less than 1.
        torch.cuda.OutOfMemoryError: Propagated when a forward pass does not
            fit on the GPU.
    """
    if timing_runs < 1:
        raise ValueError("timing_runs must be >= 1")

    print(f"\n Context {context_len} tokens...")
    # Create the token ids on the GPU directly rather than copying from CPU.
    input_ids = torch.randint(1, 1000, (1, context_len), device="cuda")

    # Peak memory of one prefill with the KV cache enabled.
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
    kv = out.past_key_values
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    # KV compression
    fp16_bytes, uniform8_bytes, compressed_bytes = _kv_cache_sizes(kv)

    # Release the first pass (logits + KV cache) before timing; otherwise it
    # stays resident during the timed passes and can trigger an avoidable OOM
    # at long context lengths.
    del out, kv

    # Prefill speed. perf_counter is monotonic and high resolution, unlike
    # time.time(), which can jump if the wall clock is adjusted.
    elapsed = []
    for _ in range(timing_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            # Result is intentionally not bound so it is freed immediately
            # instead of overlapping with the next pass.
            model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        elapsed.append(time.perf_counter() - start)
    prefill_ms = round(sum(elapsed) / len(elapsed) * 1000, 1)

    return {
        "context_len": context_len,
        "peak_memory_gb": round(peak_mem, 2),
        "fp16_mb": round(fp16_bytes / 1e6, 2),
        "uniform8_mb": round(uniform8_bytes / 1e6, 2),
        "mixed_precision_mb": round(compressed_bytes / 1e6, 2),
        "compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
        "compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
        "prefill_ms": prefill_ms,
    }
93
+
94
print("\n" + "="*60)
print("LONG CONTEXT BENCHMARK")
print("="*60)

# Sweep increasing context lengths and stop at the first one that does not
# fit in GPU memory.
results = []
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
    try:
        r = measure_context(ctx)
    except torch.cuda.OutOfMemoryError:
        print(f" ctx={ctx:6d} | OOM — FP16 ran out of memory ✓")
        # Nothing could be measured at this length, so record an analytical
        # FP16 KV-cache size instead: tokens * layers * kv_heads * head_dim
        # * 2 (K and V) * 2 bytes. The geometry is read from the model config
        # rather than hard-coded; it is 8 heads x 128 dims for both models
        # in MODEL_PATHS.
        cfg = model.config
        kv_heads = (
            getattr(cfg, "num_key_value_heads", None)
            or cfg.num_attention_heads
        )
        head_dim = (
            getattr(cfg, "head_dim", None)
            or cfg.hidden_size // cfg.num_attention_heads
        )
        results.append({
            "context_len": ctx,
            "peak_memory_gb": "OOM",
            # Rounded like the measured entries for a consistent JSON file.
            "fp16_mb": round(
                ctx * num_layers * kv_heads * head_dim * 4 / 1e6, 2
            ),
            "note": "FP16 OOM, compressed might fit"
        })
        break
    results.append(r)
    print(f" ctx={ctx:6d} | "
          f"mem={r['peak_memory_gb']:.2f}GB | "
          f"FP16={r['fp16_mb']:.0f}MB | "
          f"Ours={r['mixed_precision_mb']:.0f}MB | "
          f"{r['compression_vs_fp16']}x | "
          f"prefill={r['prefill_ms']}ms")

# Save all collected rows (including a trailing OOM row, if any) as JSON.
out_path = f"{results_dir}/long_context_results.json"
with open(out_path, "w", encoding="utf-8") as f:
    json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
print(f"\n✅ Saved to {out_path}")