File size: 2,234 Bytes
201cf4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Benchmark: MAX vs CPU dispatch for KAN inference.

Usage:
    python3 scripts/benchmark_max_dispatch.py
"""
import json
import time as timing_module

import numpy as np
import torch
import torch.nn as nn

from training.core.max_graph_bridge import max_available


class _SimpleKAN(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.spline_weight = nn.Parameter(torch.randn(d, 1, 8))
        self.grid = nn.Parameter(torch.linspace(-2, 2, 12))
        self.alive_ema = torch.ones(d, 1)
        self.linear = nn.Linear(d, d)

    def forward(self, x):
        return self.linear(x)


def _measure_ms():
    """Return current time in milliseconds (high-resolution)."""
    return timing_module.perf_counter() * 1000


def run_benchmark(n_iterations=100, batch_size=8, d_model=64):
    model = _SimpleKAN(d_model)
    x = torch.randn(batch_size, d_model)
    times_cpu = []
    with torch.no_grad():
        for _ in range(n_iterations):
            t0 = _measure_ms()
            _ = model(x)
            times_cpu.append(_measure_ms() - t0)

    result = {
        "cpu_ms": float(np.median(times_cpu)),
        "cpu_p99_ms": float(np.percentile(times_cpu, 99)),
        "n_iterations": n_iterations,
        "batch_size": batch_size,
        "d_model": d_model,
    }

    if max_available():
        try:
            from training.core.max_graph_bridge import create_max_engine
            engine = create_max_engine(model, "default", (batch_size, d_model))
            x_np = x.numpy()
            for _ in range(5):
                engine.execute(x_np)
            times_max = []
            for _ in range(n_iterations):
                t0 = _measure_ms()
                _ = engine.execute(x_np)
                times_max.append(_measure_ms() - t0)
            result["max_ms"] = float(np.median(times_max))
            result["max_p99_ms"] = float(np.percentile(times_max, 99))
            result["speedup"] = result["cpu_ms"] / max(result["max_ms"], 1e-6)
        except Exception as e:
            result["max_unavailable"] = str(e)
    else:
        result["max_unavailable"] = "MAX not installed"

    return result


if __name__ == "__main__":
    # Run with default settings and emit the result as indented JSON.
    report = run_benchmark()
    print(json.dumps(report, indent=2))