File size: 1,375 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright © 2023 Apple Inc.

import argparse

import mlx.core as mx
from time_utils import time_fn

# Problem sizes shared by the benchmarks in this file.
# time_batch_matmul builds (B, T, D) operands, time_unbatch_matmul builds the
# flattened (B * T, D) equivalent, and both multiply by a (D, D) matrix.
B = 8  # leading dimension of the 3-D operands
T = 1024  # middle dimension of the 3-D operands
D = 512  # trailing dimension; also both dimensions of the square right operand


def time_batch_matmul():
    """Time a 3-D x 2-D matmul and the VJP with respect to each operand.

    The left operand and the cotangent have shape (B, T, D); the right
    operand has shape (D, D). Three measurements are handed to ``time_fn``
    in this order: the forward ``mx.matmul``, the VJP entry for the first
    operand, and the VJP entry for the second operand. Both VJPs are
    obtained through ``mx.vjp``.
    """
    mx.random.seed(3)
    batched_shape = (B, T, D)
    # Draw order (left, right, cotangent) is kept so the seeded values match.
    lhs = mx.random.uniform(shape=batched_shape)
    rhs = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=batched_shape)
    # Evaluate the inputs up front, before anything is timed.
    mx.eval(lhs, rhs, cotangent)

    # NOTE(review): the closure names are kept as-is because time_fn may
    # report the name of the callable it receives -- confirm in time_utils.
    def batch_vjp_first():
        vjps = mx.vjp(mx.matmul, [lhs, rhs], [cotangent])[1]
        return vjps[0]

    def batch_vjp_second():
        vjps = mx.vjp(mx.matmul, [lhs, rhs], [cotangent])[1]
        return vjps[1]

    time_fn(mx.matmul, lhs, rhs)
    time_fn(batch_vjp_first)
    time_fn(batch_vjp_second)


def time_unbatch_matmul():
    """Time a 2-D x 2-D matmul and hand-written products for each operand.

    The left operand and ``cotangent`` have shape (B * T, D); the right
    operand has shape (D, D). Three measurements are handed to ``time_fn``
    in this order: the forward ``mx.matmul``, ``cotangent @ rhs^T``, and
    ``lhs^T @ cotangent``. The last two are written out explicitly with
    ``mx.transpose`` and ``mx.matmul`` rather than going through ``mx.vjp``.
    """
    mx.random.seed(3)
    flat_shape = (B * T, D)
    # Draw order (left, right, cotangent) is kept so the seeded values match.
    lhs = mx.random.uniform(shape=flat_shape)
    rhs = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=flat_shape)
    # Evaluate the inputs up front, before anything is timed.
    mx.eval(lhs, rhs, cotangent)

    # NOTE(review): the closure names are kept as-is because time_fn may
    # report the name of the callable it receives -- confirm in time_utils.
    def unbatch_vjp_first():
        # The transpose stays inside the closure so it is part of the timed work.
        rhs_t = mx.transpose(rhs)
        return mx.matmul(cotangent, rhs_t)

    def unbatch_vjp_second():
        lhs_t = mx.transpose(lhs)
        return mx.matmul(lhs_t, cotangent)

    time_fn(mx.matmul, lhs, rhs)
    time_fn(unbatch_vjp_first)
    time_fn(unbatch_vjp_second)


if __name__ == "__main__":
    # Script entry point: choose the default device from the command line,
    # then run the batched benchmark followed by the unbatched one.
    cli = argparse.ArgumentParser("MLX benchmarks.")
    cli.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    use_gpu = cli.parse_args().gpu

    # Without --gpu the CPU is selected explicitly rather than left to default.
    mx.set_default_device(mx.gpu if use_gpu else mx.cpu)

    for benchmark in (time_batch_matmul, time_unbatch_matmul):
        benchmark()