grapheneaffiliates committed on
Commit
06e4588
·
verified ·
1 Parent(s): 26849c5

Upload python/benchmark_h4_vs_softmax.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. python/benchmark_h4_vs_softmax.py +373 -0
python/benchmark_h4_vs_softmax.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Benchmark: H4 geometric attention vs standard softmax attention.
3
+
4
+ Compares wall-clock time, peak memory, and attention score quality
5
+ at various context lengths to find the empirical crossover point
6
+ where H4's O(log t) chamber lookup beats softmax's O(t^2) matmul.
7
+
8
+ Now includes Rust-accelerated backend (h4_rust) when available.
9
+ """
10
+
11
+ import math
12
+ import time
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import numpy as np
17
+ import sys
18
+ import os
19
+
20
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
21
+
22
+ from h4_hybrid_attention import H4AttentionLayer
23
+ from utils.chamber_index import compute_chamber_ids
24
+
25
+ # Rust backend detection
26
+ try:
27
+ import h4_rust
28
+ RUST_AVAILABLE = True
29
+ except ImportError:
30
+ RUST_AVAILABLE = False
31
+
32
+
33
class SoftmaxAttentionLayer(nn.Module):
    """Standard multi-head scaled dot-product attention for comparison.

    Applies a causal (upper-triangular) mask so only past positions are
    attended, matching an autoregressive setting.

    Args:
        d_model: Input/output embedding dimension; must be divisible by n_heads.
        n_heads: Number of attention heads.
        d_value: Per-head value dimension (decoupled from d_model // n_heads).
        dropout: Dropout probability applied to the attention weights
            (inactive under ``eval()`` or when 0.0, so the default preserves
            previous behavior).
    """

    def __init__(self, d_model: int, n_heads: int = 8, d_value: int = 16, dropout: float = 0.0):
        super().__init__()
        # Fail fast on an invalid head configuration instead of silently
        # truncating d_head via integer division.
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.d_value = d_value
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.W_q = nn.Linear(d_model, self.d_head * n_heads, bias=False)
        self.W_k = nn.Linear(d_model, self.d_head * n_heads, bias=False)
        self.W_v = nn.Linear(d_model, d_value * n_heads, bias=False)
        self.W_out = nn.Linear(d_value * n_heads, d_model, bias=False)
        # BUG FIX: the original accepted `dropout` but never used it.
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Causal multi-head attention: (B, T, d_model) -> (B, T, d_model)."""
        B, T, D = x.shape
        # Project and reshape to (B, heads, T, head_dim).
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).permute(0, 2, 1, 3)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_head).permute(0, 2, 1, 3)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_value).permute(0, 2, 1, 3)

        # (B, H, T, T) scores; mask out future positions (strict upper triangle).
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn = self.attn_dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, V)
        # Merge heads back: (B, T, n_heads * d_value) -> (B, T, d_model).
        out = out.permute(0, 2, 1, 3).reshape(B, T, -1)
        return self.W_out(out)
62
+
63
+
64
def benchmark_forward_pass(layer, x, n_warmup=2, n_runs=5, **kwargs):
    """Time forward pass, return mean and std in milliseconds."""
    # Untimed warmup calls (lazy initialization, caches, autotuning).
    for _ in range(n_warmup):
        layer(x, **kwargs)

    samples = []
    for _ in range(n_runs):
        start = time.perf_counter()
        layer(x, **kwargs)
        samples.append((time.perf_counter() - start) * 1000)

    return np.mean(samples), np.std(samples)
77
+
78
+
79
def benchmark_rust_topk(keys_np, queries_np, k, n_warmup=2, n_runs=5):
    """
    Benchmark Rust h4_rust.query_topk on raw numpy arrays.
    Returns mean and std in milliseconds, or (None, None) if the Rust
    backend is unavailable.
    """
    if not RUST_AVAILABLE:
        return None, None

    # The Rust binding expects float64 input.
    keys64 = keys_np.astype(np.float64)
    queries64 = queries_np.astype(np.float64)

    # Untimed warmup calls.
    for _ in range(n_warmup):
        h4_rust.query_topk(keys64, queries64, k)

    elapsed = []
    for _ in range(n_runs):
        start = time.perf_counter()
        h4_rust.query_topk(keys64, queries64, k)
        elapsed.append((time.perf_counter() - start) * 1000)

    return np.mean(elapsed), np.std(elapsed)
102
+
103
+
104
def benchmark_numpy_topk(keys_np, queries_np, k, n_warmup=2, n_runs=5):
    """
    Benchmark pure-numpy brute-force top-k for comparison.
    Returns mean and std in milliseconds.
    """
    keys = keys_np.astype(np.float64)
    queries = queries_np.astype(np.float64)

    def _unit_rows(mat):
        # Row-normalize, guarding against division by (near-)zero norms.
        norms = np.linalg.norm(mat, axis=1, keepdims=True)
        norms[norms < 1e-12] = 1.0
        return mat / norms

    keys_normed = _unit_rows(keys)
    queries_normed = _unit_rows(queries)

    def _topk_once():
        # Deliberately brute-force: full dot-product matrix + full argsort.
        dots = queries_normed @ keys_normed.T
        return np.argsort(-dots, axis=1)[:, :k]

    # Untimed warmup iterations.
    for _ in range(n_warmup):
        _topk_once()

    elapsed = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _topk_once()
        elapsed.append((time.perf_counter() - start) * 1000)

    return np.mean(elapsed), np.std(elapsed)
135
+
136
+
137
def compare_attention_patterns(h4_layer, softmax_layer, x):
    """
    Compare attention score distributions between H4 and softmax.
    Returns correlation coefficient.
    """
    B, T, D = x.shape

    out_h4 = h4_layer(x, use_tree=False)
    out_sm = softmax_layer(x)

    flat_h4 = out_h4.detach().flatten()
    flat_sm = out_sm.detach().flatten()

    # A (near-)constant output makes Pearson correlation undefined; report 0.
    if flat_h4.std() < 1e-8 or flat_sm.std() < 1e-8:
        return 0.0

    return torch.corrcoef(torch.stack([flat_h4, flat_sm]))[0, 1].item()
155
+
156
+
157
def main():
    """Run the three-part H4 vs. softmax benchmark suite and print a report.

    Part 1 times full attention layers (softmax vs. H4 with and without the
    chamber tree). Part 2 times raw top-k queries (Rust vs. NumPy). Part 3
    times chamber-index computation (Rust vs. torch). Parts 2 and 3 are
    skipped when the optional ``h4_rust`` extension is not importable.
    """
    # Fixed seeds so repeated runs generate identical inputs.
    torch.manual_seed(42)
    np.random.seed(42)

    # Shared model / benchmark hyperparameters.
    d_model = 64
    n_heads = 8
    d_value = 16
    batch_size = 1
    top_k = 32

    # Part 1 uses the full H4 attention layer (Python tree), so keep lengths moderate
    layer_seq_lengths = [64, 128, 256, 512, 1024]

    # Part 2 tests raw Rust top-k at extended lengths
    rust_seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]

    print("=" * 100)
    print("H4 Geometric Attention vs Standard Softmax Attention -- Benchmark")
    print("=" * 100)
    print(f"d_model={d_model}, n_heads={n_heads}, d_value={d_value}, batch_size={batch_size}, top_k={top_k}")
    print(f"Rust backend (h4_rust): {'AVAILABLE' if RUST_AVAILABLE else 'NOT AVAILABLE (install with: cd rust && maturin develop --release)'}")
    print()

    # Create layers
    h4_layer = H4AttentionLayer(d_model, n_heads, d_value, top_k=top_k)
    softmax_layer = SoftmaxAttentionLayer(d_model, n_heads, d_value)

    # Eval mode: disables dropout/training-only behavior during timing.
    h4_layer.eval()
    softmax_layer.eval()

    # ============================================================
    # Part 1: Full attention layer benchmark (softmax vs H4)
    # ============================================================
    print("-" * 100)
    print("PART 1: Full Attention Layer Forward Pass (ms)")
    print("-" * 100)

    results = []

    header = f"{'seq_len':>8} | {'softmax_ms':>12} | {'h4_full_ms':>12} | {'h4_tree_ms':>12} | {'tree/full':>10} | {'corr':>8}"
    print(header)
    print("-" * len(header))

    for T in layer_seq_lengths:
        x = torch.randn(batch_size, T, d_model)

        with torch.no_grad():
            sm_mean, sm_std = benchmark_forward_pass(softmax_layer, x)
            h4_full_mean, h4_full_std = benchmark_forward_pass(h4_layer, x, use_tree=False)

            # Tree mode is slow in Python, so only time it beyond the
            # shortest length and with fewer runs.
            if T > 64:
                h4_tree_mean, h4_tree_std = benchmark_forward_pass(h4_layer, x, use_tree=True, n_runs=3)
            else:
                h4_tree_mean = h4_full_mean
                h4_tree_std = h4_full_std

            corr = compare_attention_patterns(h4_layer, softmax_layer, x)
        # Guard against division by ~0 for sub-millisecond timings.
        ratio = h4_tree_mean / max(h4_full_mean, 0.001)

        print(f"{T:8d} | {sm_mean:10.1f}+/-{sm_std:3.1f} | {h4_full_mean:10.1f}+/-{h4_full_std:3.1f} | {h4_tree_mean:10.1f}+/-{h4_tree_std:3.1f} | {ratio:10.3f} | {corr:8.4f}")

        results.append({
            'seq_len': T,
            'softmax_ms': sm_mean,
            'h4_full_ms': h4_full_mean,
            'h4_tree_ms': h4_tree_mean,
            'tree_vs_full_ratio': ratio,
            'output_correlation': corr,
        })

    # ============================================================
    # Part 2: Raw top-k benchmark (Rust vs NumPy)
    # ============================================================
    print()
    print("-" * 100)
    print("PART 2: Raw Top-k Query Benchmark — Rust h4_rust vs NumPy (ms)")
    print(" (One attention head: n_queries=64 queries against n_keys keys, k=32)")
    print("-" * 100)

    n_queries = 64
    k = 32

    if RUST_AVAILABLE:
        header2 = f"{'n_keys':>8} | {'numpy_ms':>12} | {'rust_ms':>12} | {'speedup':>10}"
        print(header2)
        print("-" * len(header2))

        rust_results = []
        for T in rust_seq_lengths:
            # 4-d vectors: H4 operates in R^4.
            keys_np = np.random.randn(T, 4).astype(np.float64)
            queries_np = np.random.randn(n_queries, 4).astype(np.float64)

            np_mean, np_std = benchmark_numpy_topk(keys_np, queries_np, k)
            rust_mean, rust_std = benchmark_rust_topk(keys_np, queries_np, k)

            speedup = np_mean / max(rust_mean, 0.001) if rust_mean else 0.0

            print(f"{T:8d} | {np_mean:10.3f}+/-{np_std:3.3f} | {rust_mean:10.3f}+/-{rust_std:3.3f} | {speedup:9.1f}x")

            rust_results.append({
                'n_keys': T,
                'numpy_ms': np_mean,
                'rust_ms': rust_mean,
                'speedup': speedup,
            })
    else:
        print(" [SKIPPED] Rust backend not available.")
        print(" Install with: cd rust && maturin develop --release")
        rust_results = []

    # ============================================================
    # Part 3: Chamber index computation benchmark
    # ============================================================
    print()
    print("-" * 100)
    print("PART 3: Chamber Index Computation — Rust vs NumPy (ms)")
    print("-" * 100)

    if RUST_AVAILABLE:
        roots = h4_rust.get_simple_roots()  # (4, 4) f64
        header3 = f"{'n_vectors':>10} | {'numpy_ms':>12} | {'rust_ms':>12} | {'speedup':>10}"
        print(header3)
        print("-" * len(header3))

        for n_vecs in [1000, 10000, 100000]:
            vecs = np.random.randn(n_vecs, 4).astype(np.float64)
            roots_torch = torch.from_numpy(roots).float()

            # NumPy/torch chamber IDs
            vecs_torch = torch.from_numpy(vecs).float()
            # Warmup
            for _ in range(2):
                _ = compute_chamber_ids(vecs_torch, roots_torch)

            times_np = []
            for _ in range(5):
                t0 = time.perf_counter()
                _ = compute_chamber_ids(vecs_torch, roots_torch)
                t1 = time.perf_counter()
                times_np.append((t1 - t0) * 1000)
            np_mean = np.mean(times_np)
            np_std_val = np.std(times_np)

            # Rust chamber IDs
            for _ in range(2):
                _ = h4_rust.chamber_indices(vecs, roots)

            times_rust = []
            for _ in range(5):
                t0 = time.perf_counter()
                _ = h4_rust.chamber_indices(vecs, roots)
                t1 = time.perf_counter()
                times_rust.append((t1 - t0) * 1000)
            rust_mean = np.mean(times_rust)
            rust_std_val = np.std(times_rust)

            speedup = np_mean / max(rust_mean, 0.001)
            print(f"{n_vecs:10d} | {np_mean:10.3f}+/-{np_std_val:3.3f} | {rust_mean:10.3f}+/-{rust_std_val:3.3f} | {speedup:9.1f}x")

            # Verify correctness: Rust and torch should agree
            # NOTE(review): ids_torch is computed but never compared against
            # ids_rust — presumably because the bit ordering differs (see
            # note below); confirm whether an element-wise check is intended.
            ids_torch = compute_chamber_ids(vecs_torch, roots_torch).numpy()
            ids_rust = h4_rust.chamber_indices(vecs, roots)
            # Note: bit ordering may differ, just check both produce valid 0-15 range
            assert ids_rust.min() >= 0 and ids_rust.max() <= 15, "Rust chamber IDs out of range"
    else:
        print(" [SKIPPED] Rust backend not available.")

    # ============================================================
    # Summary
    # ============================================================
    print()
    print("=" * 100)
    print("SUMMARY")
    print("=" * 100)

    # Scaling analysis from Part 1
    if len(results) >= 2:
        sm_times = [(r['seq_len'], r['softmax_ms']) for r in results]
        h4_times = [(r['seq_len'], r['h4_tree_ms']) for r in results]

        # Empirical scaling exponent: slope of log(time) vs log(seq_len)
        # between the first and last measured lengths.
        sm_exp = math.log(sm_times[-1][1] / max(sm_times[0][1], 0.01)) / math.log(sm_times[-1][0] / sm_times[0][0])
        h4_exp = math.log(h4_times[-1][1] / max(h4_times[0][1], 0.01)) / math.log(h4_times[-1][0] / h4_times[0][0])

        print(f" Softmax scaling exponent: ~{sm_exp:.2f} (expect ~2.0 for O(t^2))")
        print(f" H4 tree scaling exponent: ~{h4_exp:.2f} (expect ~0 for O(log t), higher due to Python overhead)")

        # First sequence length at which the H4 tree path beats softmax.
        crossover = None
        for r in results:
            if r['h4_tree_ms'] < r['softmax_ms']:
                crossover = r['seq_len']
                break

        if crossover:
            print(f" H4 tree becomes faster than softmax at seq_len={crossover}")
        else:
            print(" Softmax is faster at all tested layer-level lengths")
            print(" (H4 tree overhead dominates at small/medium lengths due to Python ChamberTree)")

    if RUST_AVAILABLE and rust_results:
        print()
        print(" Rust backend top-k performance:")
        for r in rust_results[:6]:
            print(f" n_keys={r['n_keys']:>6d}: Rust {r['rust_ms']:.3f}ms vs NumPy {r['numpy_ms']:.3f}ms ({r['speedup']:.1f}x)")
    elif not RUST_AVAILABLE:
        print()
        print(" Rust backend was NOT available for this run.")
        print(" To enable: cd rust && maturin develop --release")

    print()
    print(" Note: The Python ChamberTree has high constant factors.")
    print(" The Rust h4_rust backend shows raw computation speedups.")
    print(" Full Rust-accelerated attention layer is the next step.")
    print("=" * 100)
370
+
371
+
372
if __name__ == '__main__':
    # Entry point when run directly as a script.
    main()