File size: 1,916 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse

import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime

# Select the non-interactive Agg backend before pyplot is imported; this
# script only writes figures to PNG files via plt.savefig.
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def had(x):
    """Apply ``mx.hadamard_transform`` to ``x`` and force its evaluation.

    The result is passed to ``mx.eval`` and then discarded; nothing is
    returned.
    """
    mx.eval(mx.hadamard_transform(x))


def copy(x):
    """Compute ``x + 1.0`` and force its evaluation.

    The result is passed to ``mx.eval`` and then discarded; nothing is
    returned.
    """
    mx.eval(x + 1.0)


def run(dtype):
    """Benchmark ``had`` and ``copy`` over a range of sizes and plot the result.

    For each function and each multiplier ``m`` in (1, 12, 20, 28), a random
    array of shape ``(2**26 // n, n)`` with ``n = m * 2**k`` (``k`` from 7 to
    13, skipping ``n > 2**15``) is timed with ``measure_runtime``. The
    bandwidth for every ``n`` is printed and stored under one of the series
    ``"copy"``, ``"had_2^k"`` (``m == 1``) or ``"had_m*2^k"``. All series are
    then drawn as a scatter plot saved to ``bench_<dtype.__name__>.png``.

    Args:
        dtype: NumPy scalar type of the input data (the script entry point
            passes ``np.float16`` or ``np.float32``). Its ``__name__`` is used
            in the plot title and in the output file name.
    """
    system_size = 2**26
    max_n = 2**15
    bytes_per_gb = 1e9
    ms_per_s = 1e3
    # The factor of two presumably accounts for one read plus one write of
    # every element -- TODO confirm against the kernels being measured.
    bytes_per_element = np.dtype(dtype).itemsize * 2

    outputs = {}
    for test_fn in (had, copy):
        for m in [1, 12, 20, 28]:
            if test_fn is copy:
                key = "copy"
            elif m == 1:
                key = "had_2^k"
            else:
                key = "had_m*2^k"
            series = outputs.setdefault(key, {})
            for k in range(7, 14):
                n = m * 2**k
                if n > max_n:
                    continue
                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                x = mx.array(x_np)
                # NOTE(review): assumes measure_runtime returns milliseconds,
                # as the variable name suggests -- confirm in time_utils.
                runtime_ms = measure_runtime(test_fn, x=x)
                # system_size // n rounds down when n is not a power of two,
                # so count the elements actually allocated instead of the
                # nominal system_size to avoid overstating the bandwidth.
                bytes_moved = x_np.size * bytes_per_element
                bandwidth_gb = bytes_moved / runtime_ms * ms_per_s / bytes_per_gb
                print(n, bandwidth_gb)
                series[n] = bandwidth_gb

    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, output in outputs.items():
        # Materialize the dict views so scatter receives plain sequences.
        plt.scatter(
            list(output), list(output.values()), color=colors[key], label=key
        )
    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype.__name__}.png")
    # Clear the current figure so a later call starts from an empty plot.
    plt.clf()


if __name__ == "__main__":
    # Command line: pass --fp16 to benchmark float16 instead of float32.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fp16", action="store_true")
    use_half = cli.parse_args().fp16
    if use_half:
        run(np.float16)
    else:
        run(np.float32)