Spaces:
Build error
Build error
File size: 2,234 Bytes
201cf4d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | """Benchmark: MAX vs CPU dispatch for KAN inference.
Usage:
python3 scripts/benchmark_max_dispatch.py
"""
import json
import time as timing_module
import numpy as np
import torch
import torch.nn as nn
from training.core.max_graph_bridge import max_available
class _SimpleKAN(nn.Module):
    """Tiny stand-in model for the dispatch benchmark.

    It carries KAN-style attributes (``spline_weight``, ``grid``,
    ``alive_ema``), but ``forward`` only applies a dense ``d -> d`` linear
    layer; the spline attributes are never used in the forward pass.
    NOTE(review): presumably those attributes exist so the MAX bridge sees a
    KAN-like module — confirm against ``create_max_engine``.

    Args:
        d: Feature width of both the input and the output.
    """

    def __init__(self, d: int) -> None:
        super().__init__()
        # Keep this registration order: it fixes the order of random draws
        # and the order of entries in parameters() / state_dict().
        spline_shape = (d, 1, 8)
        self.spline_weight = nn.Parameter(torch.randn(spline_shape))
        self.grid = nn.Parameter(torch.linspace(-2, 2, steps=12))
        # Plain tensor attribute, not a registered buffer: it is absent from
        # state_dict() and is not moved by .to().
        self.alive_ema = torch.ones((d, 1))
        self.linear = nn.Linear(in_features=d, out_features=d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.linear(x)``; the spline attributes are ignored."""
        projected = self.linear(x)
        return projected
def _measure_ms():
    """Read the high-resolution performance counter, scaled to milliseconds.

    The absolute value has an undefined reference point (as with
    ``time.perf_counter``); only the difference between two readings is
    meaningful.
    """
    seconds = timing_module.perf_counter()
    return seconds * 1000
def _time_calls(fn, n_iterations, n_warmup):
    """Call ``fn`` ``n_warmup`` times untimed, then return ``n_iterations`` per-call latencies in ms."""
    # Warm-up calls absorb one-off first-call costs so they do not land in
    # the timed samples.
    for _ in range(n_warmup):
        fn()
    times = []
    for _ in range(n_iterations):
        t0 = _measure_ms()
        fn()
        times.append(_measure_ms() - t0)
    return times


def run_benchmark(n_iterations=100, batch_size=8, d_model=64, n_warmup=5):
    """Time forward passes of a small stand-in model on CPU and, if possible, MAX.

    Both backends receive the same ``n_warmup`` untimed calls before
    ``n_iterations`` timed calls, so the CPU/MAX comparison is symmetric.

    Args:
        n_iterations: Number of timed calls per backend; must be >= 1.
        batch_size: Rows in the random input batch.
        d_model: Feature width of the model and of the input.
        n_warmup: Untimed calls made before timing each backend; must be >= 0.

    Returns:
        dict with ``cpu_ms`` / ``cpu_p99_ms`` (median / 99th-percentile
        latency in ms) and the run configuration. It also holds either
        ``max_ms``, ``max_p99_ms`` and ``speedup`` (``cpu_ms / max_ms``), or a
        ``max_unavailable`` string explaining why MAX was skipped.

    Raises:
        ValueError: if ``n_iterations < 1`` or ``n_warmup < 0``.
    """
    # np.percentile cannot summarize an empty sample list, so reject it up front.
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")
    if n_warmup < 0:
        raise ValueError(f"n_warmup must be >= 0, got {n_warmup}")

    model = _SimpleKAN(d_model)
    model.eval()  # inference benchmark: use inference-mode module behaviour
    x = torch.randn(batch_size, d_model)

    with torch.no_grad():
        times_cpu = _time_calls(lambda: model(x), n_iterations, n_warmup)

    result = {
        "cpu_ms": float(np.median(times_cpu)),
        "cpu_p99_ms": float(np.percentile(times_cpu, 99)),
        "n_iterations": n_iterations,
        "batch_size": batch_size,
        "d_model": d_model,
        "n_warmup": n_warmup,
    }

    if not max_available():
        result["max_unavailable"] = "MAX not installed"
        return result

    try:
        from training.core.max_graph_bridge import create_max_engine

        engine = create_max_engine(model, "default", (batch_size, d_model))
        x_np = x.numpy()
        times_max = _time_calls(lambda: engine.execute(x_np), n_iterations, n_warmup)
    except Exception as e:
        # Best-effort: any MAX failure is reported in the result, not raised.
        # Include the type, since some exceptions have an empty message.
        result["max_unavailable"] = f"{type(e).__name__}: {e}"
        return result

    result["max_ms"] = float(np.median(times_max))
    result["max_p99_ms"] = float(np.percentile(times_max, 99))
    # Guard against division by zero on an implausibly fast measurement.
    result["speedup"] = result["cpu_ms"] / max(result["max_ms"], 1e-6)
    return result
if __name__ == "__main__":
    # Run with default settings and print the result dict as indented JSON.
    report = run_benchmark()
    print(json.dumps(report, indent=2))
|