File size: 2,687 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright © 2023 Apple Inc.

import argparse

import mlx.core as mx
from time_utils import time_fn


def time_add():
    """Benchmark ``mx.add`` on several right/left operand layouts.

    Four cases are timed, in this order: two (32, 1024, 1024) arrays, a
    transposed left operand, a (1024,) right operand, and that right
    operand reshaped to (1, 1024, 1).
    """
    full_shape = (32, 1024, 1024)
    lhs = mx.random.uniform(shape=full_shape)
    rhs = mx.random.uniform(shape=full_shape)
    mx.eval(lhs, rhs)
    time_fn(mx.add, lhs, rhs)

    # The wrappers below only forward to mx.add; each case is handed to
    # time_fn as a distinctly named function (presumably so the report can
    # tell the cases apart -- confirm in time_utils).
    def transpose_add(x, y):
        return mx.add(x, y)

    def slice_add(x, y):
        return mx.add(x, y)

    def mid_slice_add(x, y):
        return mx.add(x, y)

    # Case 2: left operand with its last two axes swapped.
    lhs_swapped = mx.transpose(lhs, [0, 2, 1])
    mx.eval(lhs_swapped)
    time_fn(transpose_add, lhs_swapped, rhs)

    # Case 3: one-dimensional right operand of length 1024.
    rhs = mx.random.uniform(shape=(1024,))
    mx.eval(rhs)
    time_fn(slice_add, lhs, rhs)

    # Case 4: the same values laid out along the middle axis.
    rhs = mx.reshape(rhs, (1, 1024, 1))
    mx.eval(rhs)
    time_fn(mid_slice_add, lhs, rhs)


def time_matmul():
    """Time ``mx.matmul`` on two 1024 x 1024 uniform random matrices."""
    square = (1024, 1024)
    lhs = mx.random.uniform(shape=square)
    rhs = mx.random.uniform(shape=square)
    # Evaluate both inputs before handing them to the timer.
    mx.eval(lhs, rhs)
    time_fn(mx.matmul, lhs, rhs)


def time_maximum():
    """Time ``mx.maximum`` on two (32, 1024, 1024) uniform random arrays."""
    dims = (32, 1024, 1024)
    first = mx.random.uniform(shape=dims)
    second = mx.random.uniform(shape=dims)
    # Evaluate both inputs before handing them to the timer.
    mx.eval(first, second)
    time_fn(mx.maximum, first, second)


def time_max():
    """Time ``mx.max`` on a (32, 1024, 1024) array, passing 0 as the axis."""
    data = mx.random.uniform(shape=(32, 1024, 1024))
    # Overwrite the [1, 1] slice with NaN before timing (presumably to
    # include NaN handling in the measured reduction -- confirm intent).
    data[1, 1] = mx.nan
    mx.eval(data)

    reduce_axis = 0
    time_fn(mx.max, data, reduce_axis)


def time_min():
    """Time ``mx.min`` on a (32, 1024, 1024) array, passing 0 as the axis."""
    data = mx.random.uniform(shape=(32, 1024, 1024))
    # Overwrite the [1, 1] slice with NaN before timing (presumably to
    # include NaN handling in the measured reduction -- confirm intent).
    data[1, 1] = mx.nan
    mx.eval(data)

    reduce_axis = 0
    time_fn(mx.min, data, reduce_axis)


def time_negative():
    """Time unary negation of a (10000, 1000) uniform random array."""
    a = mx.random.uniform(shape=(10000, 1000))
    # A single evaluation is enough: nothing modifies ``a`` between here
    # and the timed call.
    mx.eval(a)

    # Wrapped in a named function so the ``-`` operator can be passed to
    # time_fn as a callable.
    def negative(a):
        return -a

    time_fn(negative, a)


def time_exp():
    """Time ``mx.exp`` on a (1000, 100) uniform random array."""
    values = mx.random.uniform(shape=(1000, 100))
    # Evaluate the input before handing it to the timer.
    mx.eval(values)
    time_fn(mx.exp, values)


def time_logsumexp():
    """Time ``mx.logsumexp`` with ``axis=-1`` on a (64, 10, 10000) array."""
    values = mx.random.uniform(shape=(64, 10, 10000))
    # Evaluate the input before handing it to the timer.
    mx.eval(values)
    time_fn(mx.logsumexp, values, axis=-1)


def time_take():
    """Time 20 ``mx.take`` gathers of 10 random rows from a (10000, 500) array."""
    # NOTE(review): unlike the index list below, ``table`` is not passed to
    # mx.eval before timing -- confirm this is intended.
    table = mx.random.uniform(shape=(10000, 500))
    index_matrix = mx.random.randint(low=0, high=10000, shape=(20, 10))

    # Iterate over the first axis of the index matrix and flatten each piece.
    index_rows = [mx.reshape(row, (-1,)) for row in index_matrix]
    mx.eval(index_rows)

    def random_take():
        return [mx.take(table, row, 0) for row in index_rows]

    time_fn(random_take)


def time_reshape_transposed():
    """Time flattening a (256, 256, 128) array after swapping its first two axes."""
    source = mx.random.uniform(shape=(256, 256, 128))
    mx.eval(source)

    def reshape_transposed():
        swapped = mx.transpose(source, (1, 0, 2))
        return mx.reshape(swapped, (-1,))

    time_fn(reshape_transposed)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()

    # Run on the CPU unless --gpu was given.
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    # Benchmarks run one after another in exactly this order.
    for benchmark in (
        time_add,
        time_matmul,
        time_min,
        time_max,
        time_maximum,
        time_exp,
        time_negative,
        time_logsumexp,
        time_take,
        time_reshape_transposed,
    ):
        benchmark()