"""Benchmark: MAX vs CPU dispatch for KAN inference. Usage: python3 scripts/benchmark_max_dispatch.py """ import json import time as timing_module import numpy as np import torch import torch.nn as nn from training.core.max_graph_bridge import max_available class _SimpleKAN(nn.Module): def __init__(self, d): super().__init__() self.spline_weight = nn.Parameter(torch.randn(d, 1, 8)) self.grid = nn.Parameter(torch.linspace(-2, 2, 12)) self.alive_ema = torch.ones(d, 1) self.linear = nn.Linear(d, d) def forward(self, x): return self.linear(x) def _measure_ms(): """Return current time in milliseconds (high-resolution).""" return timing_module.perf_counter() * 1000 def run_benchmark(n_iterations=100, batch_size=8, d_model=64): model = _SimpleKAN(d_model) x = torch.randn(batch_size, d_model) times_cpu = [] with torch.no_grad(): for _ in range(n_iterations): t0 = _measure_ms() _ = model(x) times_cpu.append(_measure_ms() - t0) result = { "cpu_ms": float(np.median(times_cpu)), "cpu_p99_ms": float(np.percentile(times_cpu, 99)), "n_iterations": n_iterations, "batch_size": batch_size, "d_model": d_model, } if max_available(): try: from training.core.max_graph_bridge import create_max_engine engine = create_max_engine(model, "default", (batch_size, d_model)) x_np = x.numpy() for _ in range(5): engine.execute(x_np) times_max = [] for _ in range(n_iterations): t0 = _measure_ms() _ = engine.execute(x_np) times_max.append(_measure_ms() - t0) result["max_ms"] = float(np.median(times_max)) result["max_p99_ms"] = float(np.percentile(times_max, 99)) result["speedup"] = result["cpu_ms"] / max(result["max_ms"], 1e-6) except Exception as e: result["max_unavailable"] = str(e) else: result["max_unavailable"] = "MAX not installed" return result if __name__ == "__main__": print(json.dumps(run_benchmark(), indent=2))