File size: 2,683 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright © 2023-2024 Apple Inc.

import argparse
import math
import random

import mlx.core as mx
from time_utils import time_fn


def bench_gelu():
    """Time an erf-based GELU, eagerly and wrapped in ``mx.compile``.

    Two scenarios are timed with ``time_fn``:

    * fixed: the same ``(1000, 1024)`` input is passed on every call;
    * variable: the first input is sliced to a random number of leading rows
      on every call, while a second ``(1000, 1024)`` input keeps its shape.
      This scenario is run eagerly, with ``mx.compile`` and with
      ``mx.compile(..., shapeless=True)``.

    In every scenario the function under test is applied 10 times in a row
    per timed call.
    """

    def gelu(x):
        # Exact (erf-based) GELU: x * Phi(x).
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    x = mx.random.uniform(shape=(1000, 1024))

    def gen_fixed(fun):
        """Wrap ``fun`` so one call chains it 10 times on a fixed-shape input."""

        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fixed(gelu), x, msg="fixed gelu")
    time_fn(gen_fixed(mx.compile(gelu)), x, msg="compiled fixed gelu")

    def randint():
        # Random row count in [1, 1000]; closes over the full-size ``x``.
        return random.randint(1, x.shape[0])

    def gen_variable(fun):
        """Wrap ``fun`` so each call sees a randomly sized ``x`` and a fixed ``y``."""

        def bench_fun(x, y):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
                y = fun(y)
            return x, y

        return bench_fun

    y = mx.random.uniform(shape=(1000, 1024))

    # Re-seed before every variant so each one is timed on the same sequence
    # of slice lengths (mirrors what bench_layernorm does); otherwise the
    # eager / compiled / shapeless numbers are measured on different shapes.
    random.seed(0)
    time_fn(gen_variable(gelu), x, y, msg="variable gelu")
    random.seed(0)
    time_fn(gen_variable(mx.compile(gelu)), x, y, msg="compiled variable gelu")
    random.seed(0)
    time_fn(
        gen_variable(mx.compile(gelu, shapeless=True)),
        x,
        y,
        msg="shapeless variable gelu",
    )


def bench_layernorm():
    """Time a hand-written layer norm, eagerly and wrapped in ``mx.compile``.

    The input is a ``(1000, 4096)`` float16 array. It is timed once with a
    fixed shape and once sliced to a random number of leading rows per call.
    The variable-shape case is run eagerly, compiled, and compiled with
    ``shapeless=True``; the Python RNG is re-seeded with 0 before each of
    those runs. Every timed call applies the function 10 times in a row.
    """
    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    mx.eval(weight, bias)

    def layernorm(x):
        # Statistics are computed in float32, the result is cast back to
        # float16 before the affine transform.
        x32 = x.astype(mx.float32)
        mu = mx.mean(x32, axis=-1, keepdims=True)
        sigma2 = mx.var(x32, axis=-1, keepdims=True)
        normed = ((x32 - mu) * mx.rsqrt(sigma2 + 1e-4)).astype(mx.float16)
        return weight * normed + bias

    x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
    max_rows = x.shape[0]

    def repeat_fixed(fun):
        def run(inp):
            out = inp
            for _ in range(10):
                out = fun(out)
            return out

        return run

    time_fn(repeat_fixed(layernorm), x, msg="fixed layernorm")
    time_fn(repeat_fixed(mx.compile(layernorm)), x, msg="compiled fixed layernorm")

    def repeat_variable(fun):
        def run(inp):
            # Keep a random number (1..max_rows) of leading rows.
            out = inp[: random.randint(1, max_rows)]
            for _ in range(10):
                out = fun(out)
            return out

        return run

    # Each entry builds the function under test lazily, right before its run.
    variants = (
        (lambda f: f, "variable layernorm"),
        (lambda f: mx.compile(f), "compiled variable layernorm"),
        (lambda f: mx.compile(f, shapeless=True), "shapeless variable layernorm"),
    )
    for wrap, label in variants:
        random.seed(0)
        time_fn(repeat_variable(wrap(layernorm)), x, msg=label)


if __name__ == "__main__":
    # The first positional parameter of ArgumentParser is ``prog``, so the
    # text must be passed as ``description`` to show up as help text rather
    # than as the program name in the usage line.
    parser = argparse.ArgumentParser(description="Compile benchmarks.")
    # No options are defined; parsing still handles ``--help`` and rejects
    # unexpected command-line arguments.
    parser.parse_args()

    bench_gelu()
    bench_layernorm()