| """ |
| Benchmark: H4 geometric attention vs standard softmax attention. |
| |
| Compares wall-clock time, peak memory, and attention score quality |
| at various context lengths to find the empirical crossover point |
| where H4's O(log t) chamber lookup beats softmax's O(t^2) matmul. |
| |
| Now includes Rust-accelerated backend (h4_rust) when available. |
| """ |
|
|
| import math |
| import time |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from h4_hybrid_attention import H4AttentionLayer |
| from utils.chamber_index import compute_chamber_ids |
|
|
| |
| try: |
| import h4_rust |
| RUST_AVAILABLE = True |
| except ImportError: |
| RUST_AVAILABLE = False |
|
|
|
|
| class SoftmaxAttentionLayer(nn.Module): |
| """Standard multi-head scaled dot-product attention for comparison.""" |
|
|
| def __init__(self, d_model: int, n_heads: int = 8, d_value: int = 16, dropout: float = 0.0): |
| super().__init__() |
| self.n_heads = n_heads |
| self.d_head = d_model // n_heads |
| self.d_value = d_value |
| self.scale = 1.0 / math.sqrt(self.d_head) |
|
|
| self.W_q = nn.Linear(d_model, self.d_head * n_heads, bias=False) |
| self.W_k = nn.Linear(d_model, self.d_head * n_heads, bias=False) |
| self.W_v = nn.Linear(d_model, d_value * n_heads, bias=False) |
| self.W_out = nn.Linear(d_value * n_heads, d_model, bias=False) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| B, T, D = x.shape |
| Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).permute(0, 2, 1, 3) |
| K = self.W_k(x).view(B, T, self.n_heads, self.d_head).permute(0, 2, 1, 3) |
| V = self.W_v(x).view(B, T, self.n_heads, self.d_value).permute(0, 2, 1, 3) |
|
|
| scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale |
| mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1) |
| scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf')) |
|
|
| attn = F.softmax(scores, dim=-1) |
| out = torch.matmul(attn, V) |
| out = out.permute(0, 2, 1, 3).reshape(B, T, -1) |
| return self.W_out(out) |
|
|
|
|
| def benchmark_forward_pass(layer, x, n_warmup=2, n_runs=5, **kwargs): |
| """Time forward pass, return mean and std in milliseconds.""" |
| for _ in range(n_warmup): |
| _ = layer(x, **kwargs) |
|
|
| times = [] |
| for _ in range(n_runs): |
| t0 = time.perf_counter() |
| _ = layer(x, **kwargs) |
| t1 = time.perf_counter() |
| times.append((t1 - t0) * 1000) |
|
|
| return np.mean(times), np.std(times) |
|
|
|
|
| def benchmark_rust_topk(keys_np, queries_np, k, n_warmup=2, n_runs=5): |
| """ |
| Benchmark Rust h4_rust.query_topk on raw numpy arrays. |
| Returns mean and std in milliseconds. |
| """ |
| if not RUST_AVAILABLE: |
| return None, None |
|
|
| keys = keys_np.astype(np.float64) |
| queries = queries_np.astype(np.float64) |
|
|
| |
| for _ in range(n_warmup): |
| _ = h4_rust.query_topk(keys, queries, k) |
|
|
| times = [] |
| for _ in range(n_runs): |
| t0 = time.perf_counter() |
| _ = h4_rust.query_topk(keys, queries, k) |
| t1 = time.perf_counter() |
| times.append((t1 - t0) * 1000) |
|
|
| return np.mean(times), np.std(times) |
|
|
|
|
| def benchmark_numpy_topk(keys_np, queries_np, k, n_warmup=2, n_runs=5): |
| """ |
| Benchmark pure-numpy brute-force top-k for comparison. |
| Returns mean and std in milliseconds. |
| """ |
| keys = keys_np.astype(np.float64) |
| queries = queries_np.astype(np.float64) |
|
|
| |
| k_norms = np.linalg.norm(keys, axis=1, keepdims=True) |
| k_norms[k_norms < 1e-12] = 1.0 |
| keys_normed = keys / k_norms |
|
|
| q_norms = np.linalg.norm(queries, axis=1, keepdims=True) |
| q_norms[q_norms < 1e-12] = 1.0 |
| queries_normed = queries / q_norms |
|
|
| |
| for _ in range(n_warmup): |
| dots = queries_normed @ keys_normed.T |
| _ = np.argsort(-dots, axis=1)[:, :k] |
|
|
| times = [] |
| for _ in range(n_runs): |
| t0 = time.perf_counter() |
| dots = queries_normed @ keys_normed.T |
| _ = np.argsort(-dots, axis=1)[:, :k] |
| t1 = time.perf_counter() |
| times.append((t1 - t0) * 1000) |
|
|
| return np.mean(times), np.std(times) |
|
|
|
|
| def compare_attention_patterns(h4_layer, softmax_layer, x): |
| """ |
| Compare attention score distributions between H4 and softmax. |
| Returns correlation coefficient. |
| """ |
| B, T, D = x.shape |
|
|
| h4_out = h4_layer(x, use_tree=False) |
| softmax_out = softmax_layer(x) |
|
|
| h4_flat = h4_out.detach().flatten() |
| sm_flat = softmax_out.detach().flatten() |
|
|
| if h4_flat.std() < 1e-8 or sm_flat.std() < 1e-8: |
| return 0.0 |
|
|
| corr = torch.corrcoef(torch.stack([h4_flat, sm_flat]))[0, 1].item() |
| return corr |
|
|
|
|
| def main(): |
| torch.manual_seed(42) |
| np.random.seed(42) |
|
|
| d_model = 64 |
| n_heads = 8 |
| d_value = 16 |
| batch_size = 1 |
| top_k = 32 |
|
|
| |
| layer_seq_lengths = [64, 128, 256, 512, 1024] |
|
|
| |
| rust_seq_lengths = [512, 1024, 2048, 4096, 8192, 16384] |
|
|
| print("=" * 100) |
| print("H4 Geometric Attention vs Standard Softmax Attention -- Benchmark") |
| print("=" * 100) |
| print(f"d_model={d_model}, n_heads={n_heads}, d_value={d_value}, batch_size={batch_size}, top_k={top_k}") |
| print(f"Rust backend (h4_rust): {'AVAILABLE' if RUST_AVAILABLE else 'NOT AVAILABLE (install with: cd rust && maturin develop --release)'}") |
| print() |
|
|
| |
| h4_layer = H4AttentionLayer(d_model, n_heads, d_value, top_k=top_k) |
| softmax_layer = SoftmaxAttentionLayer(d_model, n_heads, d_value) |
|
|
| h4_layer.eval() |
| softmax_layer.eval() |
|
|
| |
| |
| |
| print("-" * 100) |
| print("PART 1: Full Attention Layer Forward Pass (ms)") |
| print("-" * 100) |
|
|
| results = [] |
|
|
| header = f"{'seq_len':>8} | {'softmax_ms':>12} | {'h4_full_ms':>12} | {'h4_tree_ms':>12} | {'tree/full':>10} | {'corr':>8}" |
| print(header) |
| print("-" * len(header)) |
|
|
| for T in layer_seq_lengths: |
| x = torch.randn(batch_size, T, d_model) |
|
|
| with torch.no_grad(): |
| sm_mean, sm_std = benchmark_forward_pass(softmax_layer, x) |
| h4_full_mean, h4_full_std = benchmark_forward_pass(h4_layer, x, use_tree=False) |
|
|
| if T > 64: |
| h4_tree_mean, h4_tree_std = benchmark_forward_pass(h4_layer, x, use_tree=True, n_runs=3) |
| else: |
| h4_tree_mean = h4_full_mean |
| h4_tree_std = h4_full_std |
|
|
| corr = compare_attention_patterns(h4_layer, softmax_layer, x) |
| ratio = h4_tree_mean / max(h4_full_mean, 0.001) |
|
|
| print(f"{T:8d} | {sm_mean:10.1f}+/-{sm_std:3.1f} | {h4_full_mean:10.1f}+/-{h4_full_std:3.1f} | {h4_tree_mean:10.1f}+/-{h4_tree_std:3.1f} | {ratio:10.3f} | {corr:8.4f}") |
|
|
| results.append({ |
| 'seq_len': T, |
| 'softmax_ms': sm_mean, |
| 'h4_full_ms': h4_full_mean, |
| 'h4_tree_ms': h4_tree_mean, |
| 'tree_vs_full_ratio': ratio, |
| 'output_correlation': corr, |
| }) |
|
|
| |
| |
| |
| print() |
| print("-" * 100) |
| print("PART 2: Raw Top-k Query Benchmark — Rust h4_rust vs NumPy (ms)") |
| print(" (One attention head: n_queries=64 queries against n_keys keys, k=32)") |
| print("-" * 100) |
|
|
| n_queries = 64 |
| k = 32 |
|
|
| if RUST_AVAILABLE: |
| header2 = f"{'n_keys':>8} | {'numpy_ms':>12} | {'rust_ms':>12} | {'speedup':>10}" |
| print(header2) |
| print("-" * len(header2)) |
|
|
| rust_results = [] |
| for T in rust_seq_lengths: |
| keys_np = np.random.randn(T, 4).astype(np.float64) |
| queries_np = np.random.randn(n_queries, 4).astype(np.float64) |
|
|
| np_mean, np_std = benchmark_numpy_topk(keys_np, queries_np, k) |
| rust_mean, rust_std = benchmark_rust_topk(keys_np, queries_np, k) |
|
|
| speedup = np_mean / max(rust_mean, 0.001) if rust_mean else 0.0 |
|
|
| print(f"{T:8d} | {np_mean:10.3f}+/-{np_std:3.3f} | {rust_mean:10.3f}+/-{rust_std:3.3f} | {speedup:9.1f}x") |
|
|
| rust_results.append({ |
| 'n_keys': T, |
| 'numpy_ms': np_mean, |
| 'rust_ms': rust_mean, |
| 'speedup': speedup, |
| }) |
| else: |
| print(" [SKIPPED] Rust backend not available.") |
| print(" Install with: cd rust && maturin develop --release") |
| rust_results = [] |
|
|
| |
| |
| |
| print() |
| print("-" * 100) |
| print("PART 3: Chamber Index Computation — Rust vs NumPy (ms)") |
| print("-" * 100) |
|
|
| if RUST_AVAILABLE: |
| roots = h4_rust.get_simple_roots() |
| header3 = f"{'n_vectors':>10} | {'numpy_ms':>12} | {'rust_ms':>12} | {'speedup':>10}" |
| print(header3) |
| print("-" * len(header3)) |
|
|
| for n_vecs in [1000, 10000, 100000]: |
| vecs = np.random.randn(n_vecs, 4).astype(np.float64) |
| roots_torch = torch.from_numpy(roots).float() |
|
|
| |
| vecs_torch = torch.from_numpy(vecs).float() |
| |
| for _ in range(2): |
| _ = compute_chamber_ids(vecs_torch, roots_torch) |
|
|
| times_np = [] |
| for _ in range(5): |
| t0 = time.perf_counter() |
| _ = compute_chamber_ids(vecs_torch, roots_torch) |
| t1 = time.perf_counter() |
| times_np.append((t1 - t0) * 1000) |
| np_mean = np.mean(times_np) |
| np_std_val = np.std(times_np) |
|
|
| |
| for _ in range(2): |
| _ = h4_rust.chamber_indices(vecs, roots) |
|
|
| times_rust = [] |
| for _ in range(5): |
| t0 = time.perf_counter() |
| _ = h4_rust.chamber_indices(vecs, roots) |
| t1 = time.perf_counter() |
| times_rust.append((t1 - t0) * 1000) |
| rust_mean = np.mean(times_rust) |
| rust_std_val = np.std(times_rust) |
|
|
| speedup = np_mean / max(rust_mean, 0.001) |
| print(f"{n_vecs:10d} | {np_mean:10.3f}+/-{np_std_val:3.3f} | {rust_mean:10.3f}+/-{rust_std_val:3.3f} | {speedup:9.1f}x") |
|
|
| |
| ids_torch = compute_chamber_ids(vecs_torch, roots_torch).numpy() |
| ids_rust = h4_rust.chamber_indices(vecs, roots) |
| |
| assert ids_rust.min() >= 0 and ids_rust.max() <= 15, "Rust chamber IDs out of range" |
| else: |
| print(" [SKIPPED] Rust backend not available.") |
|
|
| |
| |
| |
| print() |
| print("=" * 100) |
| print("SUMMARY") |
| print("=" * 100) |
|
|
| |
| if len(results) >= 2: |
| sm_times = [(r['seq_len'], r['softmax_ms']) for r in results] |
| h4_times = [(r['seq_len'], r['h4_tree_ms']) for r in results] |
|
|
| sm_exp = math.log(sm_times[-1][1] / max(sm_times[0][1], 0.01)) / math.log(sm_times[-1][0] / sm_times[0][0]) |
| h4_exp = math.log(h4_times[-1][1] / max(h4_times[0][1], 0.01)) / math.log(h4_times[-1][0] / h4_times[0][0]) |
|
|
| print(f" Softmax scaling exponent: ~{sm_exp:.2f} (expect ~2.0 for O(t^2))") |
| print(f" H4 tree scaling exponent: ~{h4_exp:.2f} (expect ~0 for O(log t), higher due to Python overhead)") |
|
|
| crossover = None |
| for r in results: |
| if r['h4_tree_ms'] < r['softmax_ms']: |
| crossover = r['seq_len'] |
| break |
|
|
| if crossover: |
| print(f" H4 tree becomes faster than softmax at seq_len={crossover}") |
| else: |
| print(" Softmax is faster at all tested layer-level lengths") |
| print(" (H4 tree overhead dominates at small/medium lengths due to Python ChamberTree)") |
|
|
| if RUST_AVAILABLE and rust_results: |
| print() |
| print(" Rust backend top-k performance:") |
| for r in rust_results[:6]: |
| print(f" n_keys={r['n_keys']:>6d}: Rust {r['rust_ms']:.3f}ms vs NumPy {r['numpy_ms']:.3f}ms ({r['speedup']:.1f}x)") |
| elif not RUST_AVAILABLE: |
| print() |
| print(" Rust backend was NOT available for this run.") |
| print(" To enable: cd rust && maturin develop --release") |
|
|
| print() |
| print(" Note: The Python ChamberTree has high constant factors.") |
| print(" The Rust h4_rust backend shows raw computation speedups.") |
| print(" Full Rust-accelerated attention layer is the next step.") |
| print("=" * 100) |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|