# Source: h4-polytopic-attention / python/benchmark_h4_vs_softmax.py
# Uploaded by grapheneaffiliates with huggingface_hub (commit 06e4588, verified).
"""
Benchmark: H4 geometric attention vs standard softmax attention.
Compares wall-clock time, peak memory, and attention score quality
at various context lengths to find the empirical crossover point
where H4's O(log t) chamber lookup beats softmax's O(t^2) matmul.
Now includes Rust-accelerated backend (h4_rust) when available.
"""
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from h4_hybrid_attention import H4AttentionLayer
from utils.chamber_index import compute_chamber_ids
# Rust backend detection: h4_rust is an optional compiled extension, so the
# rest of the script checks RUST_AVAILABLE before touching it.
try:
    import h4_rust
except ImportError:
    RUST_AVAILABLE = False
else:
    RUST_AVAILABLE = True
class SoftmaxAttentionLayer(nn.Module):
    """Causal multi-head scaled dot-product attention used as the softmax baseline.

    Args:
        d_model: Width of the input and output features.
        n_heads: Number of heads; each query/key head has ``d_model // n_heads``
            features.
        d_value: Per-head value width.
        dropout: Accepted for signature compatibility; it is not applied.
    """

    def __init__(self, d_model: int, n_heads: int = 8, d_value: int = 16, dropout: float = 0.0):
        super().__init__()
        head_dim = d_model // n_heads
        qk_width = head_dim * n_heads
        v_width = d_value * n_heads

        self.n_heads = n_heads
        self.d_head = head_dim
        self.d_value = d_value
        self.scale = 1.0 / math.sqrt(head_dim)

        # Creation order (q, k, v, out) is kept so seeded initialisation is unchanged.
        self.W_q = nn.Linear(d_model, qk_width, bias=False)
        self.W_k = nn.Linear(d_model, qk_width, bias=False)
        self.W_v = nn.Linear(d_model, v_width, bias=False)
        self.W_out = nn.Linear(v_width, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape (batch, seq, d_model); returns the same shape."""
        batch, seq_len, _ = x.shape

        def split_heads(t: torch.Tensor, width: int) -> torch.Tensor:
            # (batch, seq, heads * width) -> (batch, heads, seq, width)
            return t.view(batch, seq_len, self.n_heads, width).transpose(1, 2)

        q = split_heads(self.W_q(x), self.d_head)
        k = split_heads(self.W_k(x), self.d_head)
        v = split_heads(self.W_v(x), self.d_value)

        logits = (q @ k.transpose(-2, -1)) * self.scale
        # Position i may only attend to positions j <= i; True marks the future.
        future = torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool).triu(1)
        logits.masked_fill_(future[None, None], float('-inf'))
        weights = logits.softmax(dim=-1)

        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, -1)
        return self.W_out(merged)
def benchmark_forward_pass(layer, x, n_warmup=2, n_runs=5, **kwargs):
    """Time ``layer(x, **kwargs)`` and return ``(mean_ms, std_ms)``.

    The layer is called ``n_warmup`` times without timing, then ``n_runs``
    timed calls are measured with ``time.perf_counter``. Extra keyword
    arguments are forwarded to every call.
    """
    for _ in range(n_warmup):
        _ = layer(x, **kwargs)

    elapsed_ms = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _ = layer(x, **kwargs)
        elapsed_ms.append(1000 * (time.perf_counter() - start))

    samples = np.asarray(elapsed_ms)
    return samples.mean(), samples.std()
def benchmark_rust_topk(keys_np, queries_np, k, n_warmup=2, n_runs=5):
    """Time ``h4_rust.query_topk`` on float64 copies of the inputs.

    Returns:
        ``(mean_ms, std_ms)`` over ``n_runs`` timed calls (after ``n_warmup``
        untimed calls), or ``(None, None)`` when the Rust extension could not
        be imported.
    """
    if RUST_AVAILABLE:
        keys = keys_np.astype(np.float64)
        queries = queries_np.astype(np.float64)

        for _ in range(n_warmup):
            _ = h4_rust.query_topk(keys, queries, k)

        elapsed_ms = []
        for _ in range(n_runs):
            start = time.perf_counter()
            _ = h4_rust.query_topk(keys, queries, k)
            elapsed_ms.append(1000 * (time.perf_counter() - start))

        samples = np.asarray(elapsed_ms)
        return samples.mean(), samples.std()
    return None, None
def _numpy_topk_indices(queries_normed, keys_normed, k):
    """Return, per query row, indices of the ``k`` most similar keys, best first.

    ``k`` is clamped to ``[0, n_keys]``. Selection uses ``np.argpartition``
    (linear time) and only the ``k`` selected columns are sorted, so each
    query costs O(n_keys + k log k) instead of a full O(n_keys log n_keys) sort.
    """
    dots = queries_normed @ keys_normed.T
    n_keys = dots.shape[1]
    k = max(0, min(int(k), n_keys))
    if k == 0:
        return np.empty((dots.shape[0], 0), dtype=np.intp)
    if k == n_keys:
        # Every key is selected, so a full sort is already the cheapest option.
        return np.argsort(-dots, axis=1)
    # Unordered k largest dots per row, then order just those k columns.
    candidates = np.argpartition(-dots, k - 1, axis=1)[:, :k]
    order = np.argsort(-np.take_along_axis(dots, candidates, axis=1), axis=1)
    return np.take_along_axis(candidates, order, axis=1)


def benchmark_numpy_topk(keys_np, queries_np, k, n_warmup=2, n_runs=5):
    """Benchmark a pure-NumPy brute-force top-k cosine-similarity search.

    Keys and queries are cast to float64 and L2-normalised once, outside the
    timed region. Each timed run computes all query-key dot products and
    selects the ``k`` best keys per query, ordered by descending similarity.
    Selection uses a partial sort rather than a full argsort per row, so the
    baseline is a fair comparison for a dedicated top-k kernel.

    Returns:
        ``(mean_ms, std_ms)`` over ``n_runs`` timed runs (after ``n_warmup``
        untimed runs).
    """
    def normalise(rows):
        rows = rows.astype(np.float64)
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        # Leave (near-)zero rows unscaled instead of dividing by ~0.
        norms[norms < 1e-12] = 1.0
        return rows / norms

    keys_normed = normalise(keys_np)
    queries_normed = normalise(queries_np)

    for _ in range(n_warmup):
        _numpy_topk_indices(queries_normed, keys_normed, k)

    times = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        _numpy_topk_indices(queries_normed, keys_normed, k)
        times.append((time.perf_counter() - t0) * 1000)
    return np.mean(times), np.std(times)
def compare_attention_patterns(h4_layer, softmax_layer, x, use_tree=False):
    """Return the Pearson correlation between the H4 and softmax layer outputs.

    Despite the name, this compares the two layers' *outputs* for the same
    input ``x`` (flattened elementwise), not their internal attention-score
    matrices. The H4 layer is called as ``h4_layer(x, use_tree=use_tree)``
    (full attention by default) and the softmax layer as ``softmax_layer(x)``.
    Both forward passes run under ``torch.no_grad()``.

    NOTE(review): when the two layers are initialised independently (as in
    ``main``), their projection weights differ, so a correlation near 0 is
    expected regardless of how well H4 approximates softmax attention.

    Returns:
        The correlation coefficient as a Python float, or ``0.0`` when it is
        undefined (fewer than two elements, or either output is constant).
    """
    with torch.no_grad():
        h4_flat = h4_layer(x, use_tree=use_tree).detach().flatten()
        sm_flat = softmax_layer(x).detach().flatten()
    # Correlation is undefined for a single sample or a constant signal.
    if h4_flat.numel() < 2 or h4_flat.std() < 1e-8 or sm_flat.std() < 1e-8:
        return 0.0
    return torch.corrcoef(torch.stack([h4_flat, sm_flat]))[0, 1].item()
def _time_ms(fn, n_warmup=2, n_runs=5):
    """Return ``(mean_ms, std_ms)`` over ``n_runs`` timed calls to ``fn()``.

    ``fn`` is first called ``n_warmup`` times without timing.
    """
    for _ in range(n_warmup):
        fn()
    samples = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000)
    return np.mean(samples), np.std(samples)


def _scaling_exponent(points):
    """Least-squares slope of log(time) against log(seq_len).

    ``points`` is an iterable of ``(seq_len, time_ms)``. Non-positive entries
    are ignored. Returns ``None`` when fewer than two distinct lengths remain.
    """
    usable = [(n, t) for n, t in points if n > 0 and t > 0]
    if len({n for n, _ in usable}) < 2:
        return None
    lengths = np.log([n for n, _ in usable])
    times = np.log([t for _, t in usable])
    slope, _ = np.polyfit(lengths, times, 1)
    return float(slope)


def _run_layer_benchmark(h4_layer, softmax_layer, seq_lengths, batch_size, d_model, min_tree_len):
    """Part 1: time softmax, full-H4 and tree-H4 forward passes for each length.

    The H4 tree path is only timed when ``seq_len >= min_tree_len``. For
    shorter lengths the full-H4 timing is shown in the tree column, but the
    row is recorded with ``tree_measured=False`` so the summary does not treat
    it as a tree measurement.

    Returns:
        One result dict per sequence length.
    """
    print("-" * 100)
    print("PART 1: Full Attention Layer Forward Pass (ms)")
    print("-" * 100)
    header = (
        f"{'seq_len':>8} | {'softmax_ms':>12} | {'h4_full_ms':>12} | "
        f"{'h4_tree_ms':>12} | {'tree/full':>10} | {'corr':>8}"
    )
    print(header)
    print("-" * len(header))

    results = []
    for T in seq_lengths:
        x = torch.randn(batch_size, T, d_model)
        tree_measured = T >= min_tree_len
        with torch.no_grad():
            sm_mean, sm_std = benchmark_forward_pass(softmax_layer, x)
            h4_full_mean, h4_full_std = benchmark_forward_pass(h4_layer, x, use_tree=False)
            if tree_measured:
                h4_tree_mean, h4_tree_std = benchmark_forward_pass(h4_layer, x, use_tree=True, n_runs=3)
            else:
                h4_tree_mean, h4_tree_std = h4_full_mean, h4_full_std
            corr = compare_attention_patterns(h4_layer, softmax_layer, x)
        ratio = h4_tree_mean / max(h4_full_mean, 0.001)
        note = "" if tree_measured else "  (tree not timed; full-H4 time shown)"
        print(
            f"{T:8d} | {sm_mean:10.1f}+/-{sm_std:3.1f} | {h4_full_mean:10.1f}+/-{h4_full_std:3.1f} | "
            f"{h4_tree_mean:10.1f}+/-{h4_tree_std:3.1f} | {ratio:10.3f} | {corr:8.4f}{note}"
        )
        results.append({
            'seq_len': T,
            'softmax_ms': sm_mean,
            'h4_full_ms': h4_full_mean,
            'h4_tree_ms': h4_tree_mean,
            'tree_vs_full_ratio': ratio,
            'output_correlation': corr,
            'tree_measured': tree_measured,
        })
    return results


def _run_rust_topk_benchmark(seq_lengths, n_queries, k):
    """Part 2: time raw top-k queries (NumPy brute force vs ``h4_rust``).

    Keys and queries are random 4-D vectors drawn from NumPy's global RNG.
    Prints a skip message and returns ``[]`` when the Rust backend is missing.

    Returns:
        One result dict per key count.
    """
    print()
    print("-" * 100)
    print("PART 2: Raw Top-k Query Benchmark — Rust h4_rust vs NumPy (ms)")
    print(f" (One attention head: n_queries={n_queries} queries against n_keys keys, k={k})")
    print("-" * 100)
    if not RUST_AVAILABLE:
        print(" [SKIPPED] Rust backend not available.")
        print(" Install with: cd rust && maturin develop --release")
        return []

    header = f"{'n_keys':>8} | {'numpy_ms':>12} | {'rust_ms':>12} | {'speedup':>10}"
    print(header)
    print("-" * len(header))
    rust_results = []
    for T in seq_lengths:
        keys_np = np.random.randn(T, 4).astype(np.float64)
        queries_np = np.random.randn(n_queries, 4).astype(np.float64)
        np_mean, np_std = benchmark_numpy_topk(keys_np, queries_np, k)
        # Rust is available here, so benchmark_rust_topk returns real numbers.
        rust_mean, rust_std = benchmark_rust_topk(keys_np, queries_np, k)
        speedup = np_mean / max(rust_mean, 0.001)
        print(f"{T:8d} | {np_mean:10.3f}+/-{np_std:3.3f} | {rust_mean:10.3f}+/-{rust_std:3.3f} | {speedup:9.1f}x")
        rust_results.append({
            'n_keys': T,
            'numpy_ms': np_mean,
            'rust_ms': rust_mean,
            'speedup': speedup,
        })
    return rust_results


def _run_chamber_benchmark(vector_counts, n_warmup=2, n_runs=5):
    """Part 3: time chamber-index computation (PyTorch ``compute_chamber_ids`` vs ``h4_rust``).

    After timing, both implementations' IDs are range-checked. They are not
    compared element-wise because the two may order the root bits differently.

    Raises:
        RuntimeError: if either implementation produces an ID outside 0..15.
    """
    print()
    print("-" * 100)
    print("PART 3: Chamber Index Computation — Rust vs PyTorch (ms)")
    print("-" * 100)
    if not RUST_AVAILABLE:
        print(" [SKIPPED] Rust backend not available.")
        return

    roots = h4_rust.get_simple_roots()  # (4, 4) f64
    roots_torch = torch.from_numpy(roots).float()  # loop-invariant, built once
    header = f"{'n_vectors':>10} | {'torch_ms':>12} | {'rust_ms':>12} | {'speedup':>10}"
    print(header)
    print("-" * len(header))
    for n_vecs in vector_counts:
        vecs = np.random.randn(n_vecs, 4).astype(np.float64)
        vecs_torch = torch.from_numpy(vecs).float()

        torch_mean, torch_std = _time_ms(lambda: compute_chamber_ids(vecs_torch, roots_torch), n_warmup, n_runs)
        rust_mean, rust_std = _time_ms(lambda: h4_rust.chamber_indices(vecs, roots), n_warmup, n_runs)
        speedup = torch_mean / max(rust_mean, 0.001)
        print(f"{n_vecs:10d} | {torch_mean:10.3f}+/-{torch_std:3.3f} | {rust_mean:10.3f}+/-{rust_std:3.3f} | {speedup:9.1f}x")

        # Sanity check only: bit ordering may differ between the implementations,
        # so just verify both stay in the expected 0..15 range. An explicit raise
        # is used because `assert` is stripped under `python -O`.
        ids_torch = np.asarray(compute_chamber_ids(vecs_torch, roots_torch).numpy())
        ids_rust = np.asarray(h4_rust.chamber_indices(vecs, roots))
        for label, ids in (("Torch", ids_torch), ("Rust", ids_rust)):
            if ids.size and (ids.min() < 0 or ids.max() > 15):
                raise RuntimeError(f"{label} chamber IDs out of range")


def _print_summary(results, rust_results):
    """Print scaling exponents, the H4/softmax crossover, and Rust top-k timings.

    Only rows where the H4 tree path was actually timed are used for the
    tree exponent and the crossover search.
    """
    print()
    print("=" * 100)
    print("SUMMARY")
    print("=" * 100)

    tree_rows = [r for r in results if r.get('tree_measured', True)]
    sm_exp = _scaling_exponent((r['seq_len'], r['softmax_ms']) for r in results)
    h4_exp = _scaling_exponent((r['seq_len'], r['h4_tree_ms']) for r in tree_rows)
    if sm_exp is not None:
        print(f" Softmax scaling exponent: ~{sm_exp:.2f} (expect ~2.0 for O(t^2))")
    if h4_exp is not None:
        # T queries x O(log T) lookup each => O(T log T) per layer, exponent ~1.
        print(f" H4 tree scaling exponent: ~{h4_exp:.2f} (expect ~1 for O(t log t), higher due to Python overhead)")

    if tree_rows:
        crossover = next(
            (r['seq_len'] for r in tree_rows if r['h4_tree_ms'] < r['softmax_ms']),
            None,
        )
        if crossover is not None:
            print(f" H4 tree becomes faster than softmax at seq_len={crossover}")
        else:
            print(" Softmax is faster at all tested layer-level lengths")
            print(" (H4 tree overhead dominates at small/medium lengths due to Python ChamberTree)")

    if RUST_AVAILABLE and rust_results:
        print()
        print(" Rust backend top-k performance:")
        for r in rust_results:
            print(f" n_keys={r['n_keys']:>6d}: Rust {r['rust_ms']:.3f}ms vs NumPy {r['numpy_ms']:.3f}ms ({r['speedup']:.1f}x)")
    elif not RUST_AVAILABLE:
        print()
        print(" Rust backend was NOT available for this run.")
        print(" To enable: cd rust && maturin develop --release")

    print()
    print(" Note: The Python ChamberTree has high constant factors.")
    print(" The Rust h4_rust backend shows raw computation speedups.")
    print(" Full Rust-accelerated attention layer is the next step.")
    print("=" * 100)


def main():
    """Run all benchmark parts and print a summary.

    Part 1 times full attention layers (softmax, full H4, tree H4). Part 2
    times raw top-k queries (NumPy vs Rust). Part 3 times chamber-index
    computation (PyTorch vs Rust). Parts 2 and 3 are skipped when ``h4_rust``
    is not importable. Torch and NumPy global RNGs are seeded with 42.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    d_model = 64
    n_heads = 8
    d_value = 16
    batch_size = 1
    top_k = 32
    # Part 1 uses the full H4 attention layer (Python tree), so keep lengths moderate
    layer_seq_lengths = [64, 128, 256, 512, 1024]
    # The H4 tree path is only timed for sequence lengths >= this value.
    min_tree_len = 65
    # Part 2 tests raw Rust top-k at extended lengths
    rust_seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]
    n_queries = 64
    chamber_vector_counts = [1000, 10000, 100000]

    print("=" * 100)
    print("H4 Geometric Attention vs Standard Softmax Attention -- Benchmark")
    print("=" * 100)
    print(f"d_model={d_model}, n_heads={n_heads}, d_value={d_value}, batch_size={batch_size}, top_k={top_k}")
    print(f"Rust backend (h4_rust): {'AVAILABLE' if RUST_AVAILABLE else 'NOT AVAILABLE (install with: cd rust && maturin develop --release)'}")
    print()

    # Create layers (same order as before so seeded initialisation is unchanged)
    h4_layer = H4AttentionLayer(d_model, n_heads, d_value, top_k=top_k)
    softmax_layer = SoftmaxAttentionLayer(d_model, n_heads, d_value)
    h4_layer.eval()
    softmax_layer.eval()

    results = _run_layer_benchmark(
        h4_layer, softmax_layer, layer_seq_lengths, batch_size, d_model, min_tree_len
    )
    rust_results = _run_rust_topk_benchmark(rust_seq_lengths, n_queries, top_k)
    _run_chamber_benchmark(chamber_vector_counts)
    _print_summary(results, rust_results)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()