Spaces:
Build error
Build error
| """Benchmark: MAX vs CPU dispatch for KAN inference. | |
| Usage: | |
| python3 scripts/benchmark_max_dispatch.py | |
| """ | |
| import json | |
| import time as timing_module | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from training.core.max_graph_bridge import max_available | |
class _SimpleKAN(nn.Module):
    """Small stand-in module used as the benchmark workload.

    It carries ``spline_weight``, ``grid`` and ``alive_ema`` tensors,
    presumably so the MAX graph bridge sees a KAN-like layout (TODO confirm
    against ``create_max_engine``). ``forward`` only applies the dense
    ``linear`` layer; the spline tensors are never read here.

    Args:
        d: Feature width; ``linear`` maps ``d -> d``.
    """

    def __init__(self, d):
        super(_SimpleKAN, self).__init__()
        width = d
        # Creation order is significant: it fixes the RNG draws and the order
        # of named_parameters(), so keep spline_weight -> grid -> linear.
        self.spline_weight = nn.Parameter(torch.randn(width, 1, 8))
        self.grid = nn.Parameter(torch.linspace(-2, 2, 12))
        # Plain tensor attribute, not a registered buffer: it is absent from
        # state_dict() and is not moved by .to(device).
        self.alive_ema = torch.ones(width, 1)
        self.linear = nn.Linear(width, width)

    def forward(self, x):
        """Apply the dense projection ``linear`` to ``x`` and return it."""
        projected = self.linear(x)
        return projected
def _measure_ms():
    """Read the high-resolution ``perf_counter`` clock, scaled to milliseconds.

    Only differences between two readings are meaningful; the absolute value
    has an undefined reference point.
    """
    seconds = timing_module.perf_counter()
    return seconds * 1000
def _time_calls_ms(fn, n_iterations, n_warmup):
    """Return per-call wall-clock latencies (ms) for ``n_iterations`` calls to ``fn``.

    ``fn`` is first invoked ``n_warmup`` times without timing, so one-off
    first-call costs are excluded from the samples. Any exception raised by
    ``fn`` propagates to the caller.
    """
    for _ in range(n_warmup):
        fn()
    samples = []
    for _ in range(n_iterations):
        t0 = _measure_ms()
        fn()
        samples.append(_measure_ms() - t0)
    return samples


def run_benchmark(n_iterations=100, batch_size=8, d_model=64, n_warmup=5):
    """Time CPU (PyTorch) inference of ``_SimpleKAN`` and, if available, MAX.

    Both backends receive the same number of untimed warm-up calls before
    timing, so the CPU and MAX numbers are directly comparable.

    Args:
        n_iterations: Number of timed calls per backend. Must be >= 1.
        batch_size: Leading dimension of the random input tensor.
        d_model: Feature width of the model and the input.
        n_warmup: Untimed calls per backend before timing starts. Must be
            >= 0. Defaults to 5, the warm-up count the MAX path always used.

    Returns:
        dict with "cpu_ms" (median), "cpu_p99_ms", "n_iterations",
        "batch_size", "d_model", and either "max_ms" / "max_p99_ms" /
        "speedup" (cpu_ms / max_ms) or "max_unavailable" (reason string).

    Raises:
        ValueError: if ``n_iterations < 1`` or ``n_warmup < 0``.
    """
    # Median/percentile of an empty sample list is undefined (numpy errors or
    # returns nan), so reject it up front with a clear message.
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")
    if n_warmup < 0:
        raise ValueError(f"n_warmup must be >= 0, got {n_warmup}")

    model = _SimpleKAN(d_model)
    model.eval()  # inference benchmark; no mode-dependent layers here, but correct in general
    x = torch.randn(batch_size, d_model)

    # The CPU path now gets the same warm-up as the MAX path. Previously only
    # MAX was warmed up, so first-call overhead skewed the CPU samples (p99 especially).
    with torch.no_grad():
        times_cpu = _time_calls_ms(lambda: model(x), n_iterations, n_warmup)

    result = {
        "cpu_ms": float(np.median(times_cpu)),
        "cpu_p99_ms": float(np.percentile(times_cpu, 99)),
        "n_iterations": n_iterations,
        "batch_size": batch_size,
        "d_model": d_model,
    }

    if max_available():
        try:
            from training.core.max_graph_bridge import create_max_engine

            engine = create_max_engine(model, "default", (batch_size, d_model))
            x_np = x.numpy()
            times_max = _time_calls_ms(
                lambda: engine.execute(x_np), n_iterations, n_warmup
            )
            result["max_ms"] = float(np.median(times_max))
            result["max_p99_ms"] = float(np.percentile(times_max, 99))
            # Floor the denominator so a ~0 ms MAX median cannot divide by zero.
            result["speedup"] = result["cpu_ms"] / max(result["max_ms"], 1e-6)
        except Exception as e:
            # Best-effort: any bridge/engine failure is reported, not raised.
            # Include the type name because str(e) alone can be empty.
            result["max_unavailable"] = f"{type(e).__name__}: {e}"
    else:
        result["max_unavailable"] = "MAX not installed"
    return result
if __name__ == "__main__":
    # Run with default settings and print the summary as indented JSON.
    summary = run_benchmark()
    print(json.dumps(summary, indent=2))