|
|
|
|
|
|
|
|
import argparse |
|
|
|
|
|
import mlx.core as mx |
|
|
from time_utils import time_fn |
|
|
|
|
|
|
|
|
def time_add():
    """Time ``mx.add`` for several right/left operand layouts.

    Four cases are handed to ``time_fn`` in this order:

    1. two ``(32, 1024, 1024)`` operands,
    2. the left operand with its last two axes transposed,
    3. a ``(1024,)`` right operand,
    4. that same right operand reshaped to ``(1, 1024, 1)``.
    """
    full_shape = (32, 1024, 1024)
    lhs = mx.random.uniform(shape=full_shape)
    rhs = mx.random.uniform(shape=full_shape)
    # Force evaluation of both operands before they are handed to time_fn.
    mx.eval(lhs, rhs)
    time_fn(mx.add, lhs, rhs)

    # Case 2: swap the last two axes of the left operand.
    lhs_t = mx.transpose(lhs, [0, 2, 1])
    mx.eval(lhs_t)

    # NOTE(review): the wrappers below only differ by name; presumably
    # time_fn labels its output with the function name -- confirm.
    def transpose_add(x, y):
        return mx.add(x, y)

    time_fn(transpose_add, lhs_t, rhs)

    # Case 3: one-dimensional right operand.
    rhs = mx.random.uniform(shape=(1024,))
    mx.eval(rhs)

    def slice_add(x, y):
        return mx.add(x, y)

    time_fn(slice_add, lhs, rhs)

    # Case 4: the same values laid out along the middle axis.
    rhs = mx.reshape(rhs, (1, 1024, 1))
    mx.eval(rhs)

    def mid_slice_add(x, y):
        return mx.add(x, y)

    time_fn(mid_slice_add, lhs, rhs)
|
|
|
|
|
|
|
|
def time_matmul():
    """Time ``mx.matmul`` on two ``(1024, 1024)`` random matrices."""
    dims = (1024, 1024)
    lhs = mx.random.uniform(shape=dims)
    rhs = mx.random.uniform(shape=dims)
    # Force evaluation of both operands before they are handed to time_fn.
    mx.eval(lhs, rhs)
    time_fn(mx.matmul, lhs, rhs)
|
|
|
|
|
|
|
|
def time_maximum():
    """Time ``mx.maximum`` on two ``(32, 1024, 1024)`` random arrays."""
    dims = (32, 1024, 1024)
    lhs = mx.random.uniform(shape=dims)
    rhs = mx.random.uniform(shape=dims)
    # Force evaluation of both operands before they are handed to time_fn.
    mx.eval(lhs, rhs)
    time_fn(mx.maximum, lhs, rhs)
|
|
|
|
|
|
|
|
def time_max():
    """Time ``mx.max`` along axis 0 of an array that contains NaN entries."""
    data = mx.random.uniform(shape=(32, 1024, 1024))
    # Overwrite the [1, 1] slice with NaN so the reduction input is not
    # NaN-free.
    data[1, 1] = mx.nan
    mx.eval(data)

    reduce_axis = 0
    time_fn(mx.max, data, reduce_axis)
|
|
|
|
|
|
|
|
def time_min():
    """Time ``mx.min`` along axis 0 of an array that contains NaN entries."""
    data = mx.random.uniform(shape=(32, 1024, 1024))
    # Overwrite the [1, 1] slice with NaN so the reduction input is not
    # NaN-free.
    data[1, 1] = mx.nan
    mx.eval(data)

    reduce_axis = 0
    time_fn(mx.min, data, reduce_axis)
|
|
|
|
|
|
|
|
def time_negative():
    """Time unary negation (``-a``) on a ``(10000, 1000)`` random array."""
    a = mx.random.uniform(shape=(10000, 1000))
    # Evaluate the input once up front. The original evaluated the same,
    # unchanged array a second time after defining the helper, which was
    # redundant.
    mx.eval(a)

    def negative(a):
        return -a

    time_fn(negative, a)
|
|
|
|
|
|
|
|
def time_exp():
    """Time ``mx.exp`` on a ``(1000, 100)`` random array."""
    values = mx.random.uniform(shape=(1000, 100))
    # Force evaluation of the input before it is handed to time_fn.
    mx.eval(values)
    time_fn(mx.exp, values)
|
|
|
|
|
|
|
|
def time_logsumexp():
    """Time ``mx.logsumexp`` over the last axis of a ``(64, 10, 10000)`` array."""
    values = mx.random.uniform(shape=(64, 10, 10000))
    # Force evaluation of the input before it is handed to time_fn.
    mx.eval(values)
    # ``axis`` is passed through time_fn as a keyword argument.
    time_fn(mx.logsumexp, values, axis=-1)
|
|
|
|
|
|
|
|
def time_take():
    """Time ``mx.take`` along axis 0 of a ``(10000, 500)`` random array.

    A ``(20, 10)`` array of random indices in ``[0, 10000)`` is split along
    its first axis into flat index vectors; the timed function performs one
    ``mx.take`` per index vector and returns the list of results.
    """
    a = mx.random.uniform(shape=(10000, 500))
    ids = mx.random.randint(low=0, high=10000, shape=(20, 10))
    ids = [mx.reshape(idx, (-1,)) for idx in ids]
    # Evaluate the source array together with the indices. Previously only
    # ``ids`` was evaluated, so the lazily-defined ``a`` was first computed
    # inside time_fn -- unlike every other benchmark in this file, which
    # evaluates its inputs before timing.
    mx.eval(a, ids)

    def random_take():
        return [mx.take(a, idx, 0) for idx in ids]

    time_fn(random_take)
|
|
|
|
|
|
|
|
def time_reshape_transposed():
    """Time flattening a ``(256, 256, 128)`` array after swapping its first two axes."""
    x = mx.random.uniform(shape=(256, 256, 128))
    # Force evaluation of the input before timing.
    mx.eval(x)

    def reshape_transposed():
        # Swap axes 0 and 1, then collapse everything into one dimension.
        swapped = mx.transpose(x, (1, 0, 2))
        return mx.reshape(swapped, (-1,))

    time_fn(reshape_transposed)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command line: a single optional ``--gpu`` switch selects the device.
    cli = argparse.ArgumentParser("MLX benchmarks.")
    cli.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    options = cli.parse_args()

    # Default to the CPU unless --gpu was given.
    mx.set_default_device(mx.gpu if options.gpu else mx.cpu)

    # Run every benchmark once, in this fixed order.
    benchmarks = (
        time_add,
        time_matmul,
        time_min,
        time_max,
        time_maximum,
        time_exp,
        time_negative,
        time_logsumexp,
        time_take,
        time_reshape_transposed,
    )
    for run_benchmark in benchmarks:
        run_benchmark()
|
|
|