# ane-kan-runtime/scripts/benchmark_max_dispatch.py
# (Deployed to the ANE KAN runtime Space; verified commit 201cf4d.)
"""Benchmark: MAX vs CPU dispatch for KAN inference.
Usage:
python3 scripts/benchmark_max_dispatch.py
"""
import json
import time as timing_module
import numpy as np
import torch
import torch.nn as nn
from training.core.max_graph_bridge import max_available
class _SimpleKAN(nn.Module):
def __init__(self, d):
super().__init__()
self.spline_weight = nn.Parameter(torch.randn(d, 1, 8))
self.grid = nn.Parameter(torch.linspace(-2, 2, 12))
self.alive_ema = torch.ones(d, 1)
self.linear = nn.Linear(d, d)
def forward(self, x):
return self.linear(x)
def _measure_ms():
"""Return current time in milliseconds (high-resolution)."""
return timing_module.perf_counter() * 1000
def run_benchmark(n_iterations=100, batch_size=8, d_model=64, warmup=5):
    """Time eager CPU dispatch vs MAX-engine dispatch for a small KAN model.

    Args:
        n_iterations: number of timed forward passes per backend.
        batch_size: rows in the random benchmark input.
        d_model: model width (input dimension == output dimension).
        warmup: untimed iterations run before measuring each backend.

    Returns:
        dict with median and p99 latencies in ms for the CPU path, plus
        "max_ms"/"max_p99_ms"/"speedup" when MAX is available, otherwise a
        "max_unavailable" string explaining why.
    """
    model = _SimpleKAN(d_model)
    x = torch.randn(batch_size, d_model)
    times_cpu = []
    with torch.no_grad():
        # Fix: warm up the CPU path too. The MAX path below already runs
        # untimed warmup iterations; without a matching CPU warmup the
        # first-call setup cost lands in times_cpu and biases "speedup"
        # in MAX's favor.
        for _ in range(warmup):
            _ = model(x)
        for _ in range(n_iterations):
            t0 = _measure_ms()
            _ = model(x)
            times_cpu.append(_measure_ms() - t0)
    result = {
        "cpu_ms": float(np.median(times_cpu)),
        "cpu_p99_ms": float(np.percentile(times_cpu, 99)),
        "n_iterations": n_iterations,
        "batch_size": batch_size,
        "d_model": d_model,
    }
    if max_available():
        try:
            from training.core.max_graph_bridge import create_max_engine

            engine = create_max_engine(model, "default", (batch_size, d_model))
            x_np = x.numpy()
            # Untimed warmup, mirroring the CPU path above.
            for _ in range(warmup):
                engine.execute(x_np)
            times_max = []
            for _ in range(n_iterations):
                t0 = _measure_ms()
                _ = engine.execute(x_np)
                times_max.append(_measure_ms() - t0)
            result["max_ms"] = float(np.median(times_max))
            result["max_p99_ms"] = float(np.percentile(times_max, 99))
            # Guard against divide-by-zero on sub-microsecond medians.
            result["speedup"] = result["cpu_ms"] / max(result["max_ms"], 1e-6)
        except Exception as e:
            # Best-effort: a broken MAX install reports its error rather
            # than failing the whole benchmark run.
            result["max_unavailable"] = str(e)
    else:
        result["max_unavailable"] = "MAX not installed"
    return result
if __name__ == "__main__":
    # Emit the benchmark summary as pretty-printed JSON on stdout.
    report = run_benchmark()
    print(json.dumps(report, indent=2))