File size: 2,687 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
from time_utils import time_fn
def time_add():
    """Time element-wise addition for several operand layouts.

    All variants use a (32, 1024, 1024) left operand and time:
    contiguous + contiguous, transposed + contiguous, and two broadcast
    cases whose right operand has shape (1024,) and (1, 1024, 1).
    """
    lhs = mx.random.uniform(shape=(32, 1024, 1024))
    rhs = mx.random.uniform(shape=(32, 1024, 1024))
    mx.eval(lhs, rhs)
    time_fn(mx.add, lhs, rhs)

    # Swap the last two axes so the left operand is a transposed view.
    lhs_t = mx.transpose(lhs, [0, 2, 1])
    mx.eval(lhs_t)

    # The wrappers below all just call mx.add; they are separate, distinctly
    # named functions so each variant is its own time_fn run (presumably
    # time_fn labels runs by function name -- confirm in time_utils).
    def transpose_add(x, y):
        return mx.add(x, y)

    time_fn(transpose_add, lhs_t, rhs)

    # A 1-D right operand, broadcast along the last axis.
    row = mx.random.uniform(shape=(1024,))
    mx.eval(row)

    def slice_add(x, y):
        return mx.add(x, y)

    time_fn(slice_add, lhs, row)

    # The same values reshaped so they broadcast along the middle axis.
    column = mx.reshape(row, (1, 1024, 1))
    mx.eval(column)

    def mid_slice_add(x, y):
        return mx.add(x, y)

    time_fn(mid_slice_add, lhs, column)
def time_matmul():
    """Time a (1024, 1024) @ (1024, 1024) matrix multiplication."""
    square = (1024, 1024)
    lhs = mx.random.uniform(shape=square)
    rhs = mx.random.uniform(shape=square)
    mx.eval(lhs, rhs)
    time_fn(mx.matmul, lhs, rhs)
def time_maximum():
    """Time the element-wise maximum of two (32, 1024, 1024) random arrays."""
    dims = (32, 1024, 1024)
    x = mx.random.uniform(shape=dims)
    y = mx.random.uniform(shape=dims)
    mx.eval(x, y)
    time_fn(mx.maximum, x, y)
def time_max():
    """Time a max-reduction over axis 0 of an input that contains NaNs."""
    data = mx.random.uniform(shape=(32, 1024, 1024))
    # Fill the whole row data[1, 1, :] with NaN so the reduction sees
    # non-finite values.
    data[1, 1] = mx.nan
    mx.eval(data)
    time_fn(mx.max, data, 0)
def time_min():
    """Time a min-reduction over axis 0 of an input that contains NaNs."""
    data = mx.random.uniform(shape=(32, 1024, 1024))
    # Fill the whole row data[1, 1, :] with NaN so the reduction sees
    # non-finite values.
    data[1, 1] = mx.nan
    mx.eval(data)
    time_fn(mx.min, data, 0)
def time_negative():
    """Time unary negation of a (10000, 1000) uniform random array."""
    a = mx.random.uniform(shape=(10000, 1000))
    # Evaluate the input once, up front, as every benchmark in this file does,
    # so creating it is not part of the work handed to time_fn.
    mx.eval(a)

    # Named wrapper so unary minus can be passed to time_fn as a callable.
    def negative(a):
        return -a

    time_fn(negative, a)
def time_exp():
    """Time element-wise exp over a (1000, 100) random array."""
    inputs = mx.random.uniform(shape=(1000, 100))
    mx.eval(inputs)
    time_fn(mx.exp, inputs)
def time_logsumexp():
    """Time logsumexp reduced over the last axis of a (64, 10, 10000) array."""
    values = mx.random.uniform(shape=(64, 10, 10000))
    mx.eval(values)
    time_fn(mx.logsumexp, values, axis=-1)
def time_take():
    """Time 20 separate row gathers (10 random row indices each).

    The gathers read from a (10000, 500) random array along axis 0.
    """
    a = mx.random.uniform(shape=(10000, 500))
    ids = mx.random.randint(low=0, high=10000, shape=(20, 10))
    # Iterating the (20, 10) index array yields 20 index vectors; flatten
    # each one to 1-D.
    ids = [mx.reshape(idx, (-1,)) for idx in ids]
    # Evaluate the source array as well as the indices, consistent with the
    # other benchmarks, so generating `a` is not folded into the first timed
    # call.
    mx.eval(a, ids)

    def random_take():
        return [mx.take(a, idx, 0) for idx in ids]

    time_fn(random_take)
def time_reshape_transposed():
    """Time flattening a transposed view of a (256, 256, 128) array."""
    source = mx.random.uniform(shape=(256, 256, 128))
    mx.eval(source)

    def reshape_transposed():
        # Swapping the first two axes gives a non-contiguous layout, so the
        # flatten presumably needs a copy -- that is what this measures.
        swapped = mx.transpose(source, (1, 0, 2))
        return mx.reshape(swapped, (-1,))

    time_fn(reshape_transposed)
if __name__ == "__main__":
    # The first positional parameter of ArgumentParser is `prog` (the program
    # name shown in usage output), so the summary must go to `description`.
    parser = argparse.ArgumentParser(description="MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()

    # None of the benchmarks pass an explicit device, so they all run on
    # whatever default device is set here.
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    time_add()
    time_matmul()
    time_min()
    time_max()
    time_maximum()
    time_exp()
    time_negative()
    time_logsumexp()
    time_take()
    time_reshape_transposed()
|