File size: 1,375 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
from time_utils import time_fn
# Problem sizes shared by the benchmarks in this file.
B = 8  # leading dimension of the 3-D operands in the batched benchmark
T = 1024  # second dimension of the 3-D operands; the unbatched benchmark uses B * T rows
D = 512  # contracted dimension of the matmul; the right-hand operand is (D, D)
def time_batch_matmul():
    """Time a (B, T, D) x (D, D) matmul and the VJP w.r.t. each operand.

    Three cases are handed to ``time_fn``: the forward ``mx.matmul`` and the
    two vector-Jacobian products obtained from ``mx.vjp``, using a random
    cotangent with the same shape as the left operand.
    """
    mx.random.seed(3)

    batched_shape = (B, T, D)
    # Draw order is kept (lhs, rhs, cotangent) so the seeded values match.
    lhs = mx.random.uniform(shape=batched_shape)
    rhs = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=batched_shape)
    # Evaluate the inputs before any timing starts.
    mx.eval(lhs, rhs, cotangent)

    time_fn(mx.matmul, lhs, rhs)

    # NOTE(review): the nested function names are kept as-is because they are
    # passed to time_fn, which presumably uses them in its report -- confirm.
    def batch_vjp_first():
        # mx.vjp returns (outputs, vjps); pick the VJP for the first primal.
        _, grads = mx.vjp(mx.matmul, [lhs, rhs], [cotangent])
        return grads[0]

    time_fn(batch_vjp_first)

    def batch_vjp_second():
        # Same call, but pick the VJP for the second primal.
        _, grads = mx.vjp(mx.matmul, [lhs, rhs], [cotangent])
        return grads[1]

    time_fn(batch_vjp_second)
def time_unbatch_matmul():
    """Time a (B * T, D) x (D, D) matmul and its hand-written gradients.

    The gradients are spelled out as plain matmuls rather than obtained from
    ``mx.vjp``: ``c @ b^T`` for the left operand and ``a^T @ c`` for the
    right operand, where ``c`` is a random cotangent shaped like the left
    operand.
    """
    mx.random.seed(3)

    flat_shape = (B * T, D)
    # Draw order is kept (lhs, rhs, cotangent) so the seeded values match.
    lhs = mx.random.uniform(shape=flat_shape)
    rhs = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=flat_shape)
    # Evaluate the inputs before any timing starts.
    mx.eval(lhs, rhs, cotangent)

    time_fn(mx.matmul, lhs, rhs)

    # NOTE(review): the nested function names are kept as-is because they are
    # passed to time_fn, which presumably uses them in its report -- confirm.
    def unbatch_vjp_first():
        # The transpose is built inside the closure so it is part of the
        # timed work, as in the original expression.
        rhs_t = mx.transpose(rhs)
        return mx.matmul(cotangent, rhs_t)

    time_fn(unbatch_vjp_first)

    def unbatch_vjp_second():
        lhs_t = mx.transpose(lhs)
        return mx.matmul(lhs_t, cotangent)

    time_fn(unbatch_vjp_second)
if __name__ == "__main__":
    # Command-line entry point: pick the device, then run both benchmarks.
    cli = argparse.ArgumentParser("MLX benchmarks.")
    cli.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    options = cli.parse_args()

    # The CPU is used unless --gpu is given.
    mx.set_default_device(mx.gpu if options.gpu else mx.cpu)

    time_batch_matmul()
    time_unbatch_matmul()
|