harshithsaiv commited on
Commit
c32a0aa
Β·
1 Parent(s): c0919f1

feat: Implementing benchmark

Browse files
Files changed (1) hide show
  1. benchmark.py +216 -0
benchmark.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full benchmark suite comparing:
3
+ 1. FP16 baseline
4
+ 2. Uniform 8-bit quantization
5
+ 3. Our mixed per-head quantization
6
+ Across: memory, speed, perplexity
7
+ """
8
+ import torch
9
+ import json
10
+ import os
11
+ import sys
12
+ import time
13
+ import math
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM
15
+ from datasets import load_dataset
16
+
17
+ sys.path.append(os.path.expanduser("~/kv-hack"))
18
+ from kernel.quant_cache import MixedPrecisionKVCache
19
+
20
+ # ── config ──────────────────────────────────────────
21
+ MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
22
+ MODEL_PATHS = {
23
+ "mistral-7b": "~/kv-hack/mistral-model",
24
+ "llama-3-8b": "~/kv-hack/llama-model",
25
+ }
26
+ model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
27
+ results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
28
+
29
+ # load bit allocation
30
+ with open(f"{results_dir}/bit_allocation.json") as f:
31
+ bit_alloc_raw = json.load(f)
32
+ bit_alloc = {
33
+ int(l): [bit_alloc_raw[l][str(h)]
34
+ for h in range(len(bit_alloc_raw[l]))]
35
+ for l in bit_alloc_raw
36
+ }
37
+ num_layers = len(bit_alloc)
38
+ avg_bits = sum(b for l in bit_alloc.values() for b in l) / \
39
+ sum(len(l) for l in bit_alloc.values())
40
+
41
+ print(f"Benchmarking: {MODEL_NAME}")
42
+ print(f"Avg bits: {avg_bits:.2f}")
43
+
44
+ # ── load model ──────────────────────────────────────
45
+ print("Loading model...")
46
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
47
+ model = AutoModelForCausalLM.from_pretrained(
48
+ model_path, dtype=torch.float16, device_map="cuda"
49
+ )
50
+ model.eval()
51
+ print(f"Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
52
+
53
+ # ── helper: compute KV compression at given context ──
54
+ def measure_kv_compression(context_len: int):
55
+ input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
56
+ with torch.no_grad():
57
+ out = model(input_ids, use_cache=True)
58
+ kv = out.past_key_values
59
+
60
+ fp16_bytes = 0
61
+ compressed_bytes = 0
62
+ uniform8_bytes = 0
63
+
64
+ for layer_idx in range(num_layers):
65
+ k = kv.layers[layer_idx].keys
66
+ v = kv.layers[layer_idx].values
67
+
68
+ # FP16 baseline
69
+ fp16_bytes += k.numel() * 2 + v.numel() * 2
70
+
71
+ # uniform 8-bit
72
+ uniform8_bytes += k.numel() + v.numel() # 1 byte per element
73
+
74
+ # our mixed precision
75
+ cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
76
+ cache.store(k, v)
77
+ compressed_bytes += cache.memory_bytes()
78
+
79
+ return {
80
+ "context_len": context_len,
81
+ "fp16_mb": round(fp16_bytes / 1e6, 2),
82
+ "uniform8_mb": round(uniform8_bytes / 1e6, 2),
83
+ "mixed_precision_mb": round(compressed_bytes / 1e6, 2),
84
+ "compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
85
+ "compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
86
+ }
87
+
88
+ # ── helper: measure perplexity ───────────────────────
89
+ def measure_perplexity(num_samples: int = 50):
90
+ print(f" Computing perplexity on {num_samples} WikiText samples...")
91
+ dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
92
+ texts = [t for t in dataset["text"] if len(t.strip()) > 100][:num_samples]
93
+
94
+ total_loss = 0
95
+ total_tokens = 0
96
+
97
+ for text in texts:
98
+ inputs = tokenizer(
99
+ text, return_tensors="pt",
100
+ max_length=512, truncation=True
101
+ ).to("cuda")
102
+
103
+ if inputs["input_ids"].shape[1] < 10:
104
+ continue
105
+
106
+ with torch.no_grad():
107
+ out = model(**inputs, labels=inputs["input_ids"])
108
+ loss = out.loss.item()
109
+
110
+ n = inputs["input_ids"].shape[1]
111
+ total_loss += loss * n
112
+ total_tokens += n
113
+
114
+ ppl = math.exp(total_loss / total_tokens)
115
+ return round(ppl, 2)
116
+
117
+ # ── helper: measure decode speed ─────────────────────
118
+ def measure_speed(context_len: int = 512, n_tokens: int = 100):
119
+ input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
120
+
121
+ # warmup
122
+ with torch.no_grad():
123
+ _ = model.generate(
124
+ input_ids, max_new_tokens=10,
125
+ do_sample=False,
126
+ pad_token_id=tokenizer.eos_token_id
127
+ )
128
+
129
+ torch.cuda.synchronize()
130
+ t0 = time.time()
131
+ with torch.no_grad():
132
+ _ = model.generate(
133
+ input_ids, max_new_tokens=n_tokens,
134
+ do_sample=False,
135
+ pad_token_id=tokenizer.eos_token_id
136
+ )
137
+ torch.cuda.synchronize()
138
+ elapsed = time.time() - t0
139
+ return round(n_tokens / elapsed, 1)
140
+
141
+ # ── helper: peak memory at context ───────────────────
142
+ def measure_peak_memory(context_len: int):
143
+ torch.cuda.reset_peak_memory_stats()
144
+ input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
145
+ with torch.no_grad():
146
+ _ = model(input_ids, use_cache=True)
147
+ torch.cuda.synchronize()
148
+ return round(torch.cuda.max_memory_allocated() / 1e9, 2)
149
+
150
+ # ── RUN ALL BENCHMARKS ───────────────────────────────
151
+ print("\n" + "="*60)
152
+ print("1. KV CACHE COMPRESSION AT DIFFERENT CONTEXT LENGTHS")
153
+ print("="*60)
154
+
155
+ compression_results = []
156
+ for ctx in [512, 1024, 2048, 4096, 8192]:
157
+ print(f" Context {ctx}...", end=" ", flush=True)
158
+ r = measure_kv_compression(ctx)
159
+ compression_results.append(r)
160
+ print(f"FP16={r['fp16_mb']}MB "
161
+ f"Uniform8={r['uniform8_mb']}MB "
162
+ f"Ours={r['mixed_precision_mb']}MB "
163
+ f"({r['compression_vs_fp16']}x vs FP16)")
164
+
165
+ print("\n" + "="*60)
166
+ print("2. PEAK GPU MEMORY AT DIFFERENT CONTEXT LENGTHS")
167
+ print("="*60)
168
+
169
+ memory_results = []
170
+ for ctx in [1024, 4096, 8192]:
171
+ print(f" Context {ctx}...", end=" ", flush=True)
172
+ mem = measure_peak_memory(ctx)
173
+ memory_results.append({"context": ctx, "peak_memory_gb": mem})
174
+ print(f"{mem} GB")
175
+
176
+ print("\n" + "="*60)
177
+ print("3. DECODE SPEED")
178
+ print("="*60)
179
+ print(" Measuring tokens/sec...", end=" ", flush=True)
180
+ speed = measure_speed()
181
+ print(f"{speed} tokens/sec")
182
+
183
+ print("\n" + "="*60)
184
+ print("4. PERPLEXITY (quality check)")
185
+ print("="*60)
186
+ perplexity = measure_perplexity(num_samples=50)
187
+ print(f" Perplexity: {perplexity}")
188
+
189
+ # ── SAVE ALL RESULTS ─────────────────────────────────
190
+ benchmark_results = {
191
+ "model": MODEL_NAME,
192
+ "avg_bits": round(avg_bits, 2),
193
+ "compression": compression_results,
194
+ "memory": memory_results,
195
+ "decode_tokens_per_sec": speed,
196
+ "perplexity": perplexity,
197
+ "summary": {
198
+ "fp16_8k_mb": next(r["fp16_mb"] for r in compression_results if r["context_len"] == 8192),
199
+ "ours_8k_mb": next(r["mixed_precision_mb"] for r in compression_results if r["context_len"] == 8192),
200
+ "compression_8k": next(r["compression_vs_fp16"] for r in compression_results if r["context_len"] == 8192),
201
+ }
202
+ }
203
+
204
+ out_path = f"{results_dir}/benchmark_results.json"
205
+ with open(out_path, "w") as f:
206
+ json.dump(benchmark_results, f, indent=2)
207
+
208
+ print("\n" + "="*60)
209
+ print("SUMMARY")
210
+ print("="*60)
211
+ print(f"Model: {MODEL_NAME}")
212
+ print(f"Avg bits: {avg_bits:.2f}")
213
+ print(f"Perplexity: {perplexity}")
214
+ print(f"Speed: {speed} tokens/sec")
215
+ print(f"KV @ 8K ctx: {benchmark_results['summary']['fp16_8k_mb']}MB β†’ {benchmark_results['summary']['ours_8k_mb']}MB ({benchmark_results['summary']['compression_8k']}x)")
216
+ print(f"\nβœ… Saved to {out_path}")