|
|
|
|
|
|
|
|
import argparse |
|
|
import math |
|
|
import os |
|
|
import time |
|
|
from functools import partial |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx.nn as nn |
|
|
|
|
|
|
|
|
def int_or_list(x):
    """Parse a command-line value as an int, or as a list of ints.

    Args:
        x: String such as ``"1"`` or ``"0,1"``.

    Returns:
        ``int(x)`` when the whole string parses as an integer; otherwise a
        list holding one int per comma-separated field.

    Raises:
        ValueError: If a comma-separated field is not a valid integer.
    """
    try:
        return int(x)
    except ValueError:
        # Not a single integer: interpret it as a comma-separated list.
        return list(map(int, x.split(",")))
|
|
|
|
|
|
|
|
def none_or_list(x):
    """Parse a comma-separated list of ints, mapping the empty string to None.

    Args:
        x: String such as ``"1,0"`` or ``""``.

    Returns:
        ``None`` for the empty string, otherwise a list of ints.

    Raises:
        ValueError: If a comma-separated field is not a valid integer.
    """
    if x == "":
        return None
    return list(map(int, x.split(",")))
|
|
|
|
|
|
|
|
def dtype_from_str(x):
    """Convert a dtype name given on the command line into an mlx dtype.

    Args:
        x: Attribute name of a dtype on ``mlx.core`` (e.g. ``"float16"``),
            or the empty string to select ``mx.float32``.

    Returns:
        The matching ``mx.Dtype`` object.

    Raises:
        ValueError: If ``x`` does not name an ``mx.Dtype``. This includes
            names that are not attributes of ``mlx.core`` at all, so that
            argparse (which reports ValueError from ``type=`` callables)
            can turn the failure into a usage error.
    """
    if x == "":
        return mx.float32

    # A default of None keeps an unknown name from escaping as AttributeError.
    dt = getattr(mx, x, None)
    if not isinstance(dt, mx.Dtype):
        raise ValueError(f"{x} is not an mlx dtype")
    return dt
|
|
|
|
|
|
|
|
def bench(f, *args):
    """Time 100 calls of ``f(*args)`` after 10 untimed warm-up calls.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to every call of ``f``.

    Returns:
        Elapsed time in seconds (float) for the 100 timed calls.
    """
    # Warm-up calls are excluded from the measurement.
    for _ in range(10):
        f(*args)

    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(100):
        f(*args)
    return time.perf_counter() - start
|
|
|
|
|
|
|
|
def matmul_square(x):
    """Right-multiply a running product by ``x`` ten times and evaluate it.

    Returns:
        The final product after it has been passed to ``mx.eval``.
    """
    product = x
    for _ in range(10):
        product = product @ x
    mx.eval(product)
    return product
|
|
|
|
|
|
|
|
def matmul(x, y):
    """Build ten independent ``x @ y`` products and evaluate them together."""
    products = [x @ y for _ in range(10)]
    mx.eval(products)
|
|
|
|
|
|
|
|
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
    """Build ten ``mx.quantized_matmul`` calls and evaluate them together.

    ``x``, ``w``, ``s`` and ``b`` are forwarded positionally; ``transpose``,
    ``group_size`` and ``bits`` are forwarded as keyword arguments.
    """
    outputs = [
        mx.quantized_matmul(
            x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
        )
        for _ in range(10)
    ]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
# Maps benchmark names to `_quant_matmul` with its configuration bound.
# Names follow "quant_matmul[_t]_<group_size>_<bits>", where "_t" marks
# transpose=True. The table is generated instead of hand-written so that a
# key can never disagree with the arguments bound to it. Iteration order
# (all non-transposed entries first, then group size, then bits) matches the
# previous literal.
quant_matmul = {
    f"quant_matmul{'_t' if transpose else ''}_{group_size}_{bits}": partial(
        _quant_matmul, transpose=transpose, group_size=group_size, bits=bits
    )
    for transpose in (False, True)
    for group_size in (32, 64, 128)
    for bits in (2, 4, 8)
}
|
|
|
|
|
|
|
|
def conv1d(x, y):
    """Build ten ``mx.conv1d(x, y)`` calls and evaluate them together."""
    outputs = [mx.conv1d(x, y) for _ in range(10)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def conv2d(x, y):
    """Build ten ``mx.conv2d(x, y)`` calls and evaluate them together."""
    outputs = [mx.conv2d(x, y) for _ in range(10)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def binary(op, x, y):
    """Chain the binary op ``mx.<op>`` 100 times and evaluate the result.

    Args:
        op: Name of a function on ``mlx.core`` (e.g. ``"add"``).
        x: Left operand used in every application.
        y: Initial right operand; each result becomes the next right operand.
    """
    fn = getattr(mx, op)
    acc = y
    for _ in range(100):
        acc = fn(x, acc)
    mx.eval(acc)
|
|
|
|
|
|
|
|
def reduction(op, axis, x):
    """Build 100 ``mx.<op>(x, axis=axis)`` calls and evaluate them together.

    Args:
        op: Name of a function on ``mlx.core`` (e.g. ``"sum"``).
        axis: Value forwarded as the ``axis`` keyword.
        x: Input array.
    """
    reduce_fn = getattr(mx, op)
    results = [reduce_fn(x, axis=axis) for _ in range(100)]
    mx.eval(results)
|
|
|
|
|
|
|
|
def sum_and_add(axis, x, y):
    """Alternate adding ``y`` and summing over ``axis``, 50 times.

    The chain starts from ``x`` summed over ``axis`` with ``keepdims=True``;
    only the final value is passed to ``mx.eval``.
    """
    total = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        total = (total + y).sum(axis=axis, keepdims=True)
    mx.eval(total)
|
|
|
|
|
|
|
|
def softmax(axis, x):
    """Build 100 hand-written softmaxes of ``x`` over ``axis`` and evaluate them.

    Each softmax subtracts the per-axis maximum before exponentiating and
    then divides by the per-axis sum of the exponentials.
    """
    outputs = []
    for _ in range(100):
        shifted = x - mx.max(x, axis=axis, keepdims=True)
        exps = mx.exp(shifted)
        outputs.append(exps / mx.sum(exps, axis=axis, keepdims=True))
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def softmax_fused(axis, x):
    """Build 100 ``mx.softmax(x, axis=axis)`` calls and evaluate them together."""
    outputs = [mx.softmax(x, axis=axis) for _ in range(100)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def relu(x):
    """Chain ``nn.relu`` 100 times starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.relu(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def leaky_relu(x: mx.array):
    """Chain ``nn.leaky_relu`` 100 times starting from ``x`` and evaluate it.

    NOTE: a second ``leaky_relu`` with the same body is defined later in
    this file; being later, that one is the binding in effect at call time.
    """
    out = x
    for _ in range(100):
        out = nn.leaky_relu(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def prelu(x: mx.array):
    """Chain ``nn.prelu`` 100 times starting from ``x`` and evaluate the result.

    A fresh ``mx.ones(1)`` is passed as the second argument on every
    application.
    """
    out = x
    for _ in range(100):
        out = nn.prelu(out, mx.ones(1))
    mx.eval(out)
|
|
|
|
|
|
|
|
def softplus(x: mx.array):
    """Chain ``nn.softplus`` 100 times starting from ``x`` and evaluate it.

    NOTE: a second ``softplus`` with the same body is defined later in this
    file; being later, that one is the binding in effect at call time.
    """
    out = x
    for _ in range(100):
        out = nn.softplus(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def mish(x: mx.array):
    """Chain ``nn.mish`` 100 times starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.mish(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def leaky_relu(x):
    """Chain ``nn.leaky_relu`` 100 times starting from ``x`` and evaluate it.

    NOTE: this repeats an earlier ``leaky_relu`` definition in this file
    with the same body; this later definition replaces the earlier binding.
    """
    out = x
    for _ in range(100):
        out = nn.leaky_relu(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def elu(x):
    """Chain ``nn.elu`` 100 times starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.elu(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def relu6(x):
    """Chain ``nn.relu6`` 100 times starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.relu6(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def softplus(x):
    """Chain ``nn.softplus`` 100 times starting from ``x`` and evaluate it.

    NOTE: this repeats an earlier ``softplus`` definition in this file with
    the same body; this later definition replaces the earlier binding.
    """
    out = x
    for _ in range(100):
        out = nn.softplus(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def celu(x):
    """Chain ``nn.celu`` 100 times starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.celu(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def log_sigmoid(x):
    """Chain ``nn.log_sigmoid`` 100 times starting from ``x`` and evaluate it."""
    out = x
    for _ in range(100):
        out = nn.log_sigmoid(out)
    mx.eval(out)
|
|
|
|
|
|
|
|
def scalar_mult(x):
    """Multiply ``x`` by the scalars 1/1, 1/2, ..., 1/100 in turn and evaluate."""
    out = x
    for divisor in range(1, 101):
        out = out * (1.0 / divisor)
    mx.eval(out)
|
|
|
|
|
|
|
|
def cross_entropy(targets, x):
    """Build 100 mean cross-entropy values and evaluate them together.

    Each value is the mean over ``logsumexp(x, axis=-1)`` minus the entry of
    ``x`` selected along the last axis by ``targets`` (reshaped to a column).
    """
    losses = []
    for _ in range(100):
        log_norm = mx.logsumexp(x, axis=-1, keepdims=True)
        picked = mx.take_along_axis(x, mx.reshape(targets, (-1, 1)), axis=-1)
        losses.append(mx.mean(log_norm - picked))
    mx.eval(losses)
|
|
|
|
|
|
|
|
def logsumexp(axis, x):
    """Build 100 ``mx.logsumexp(x, axis=axis)`` calls and evaluate them together."""
    results = [mx.logsumexp(x, axis=axis) for _ in range(100)]
    mx.eval(results)
|
|
|
|
|
|
|
|
def linear(w, b, x):
    """Build ten unfused linear layers ``x @ w.T + b`` and evaluate them together.

    The transpose of ``w`` is rebuilt for every one of the ten outputs.
    """
    outputs = [x @ mx.transpose(w, (1, 0)) + b for _ in range(10)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def linear_fused(w, b, x):
    """Build ten linear layers via ``mx.addmm`` and evaluate them together.

    Each call is ``mx.addmm(b, x, w_t)`` where ``w_t`` is ``w`` with its two
    axes swapped; the transpose is rebuilt for every output.
    """
    outputs = [mx.addmm(b, x, mx.transpose(w, (1, 0))) for _ in range(10)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def rope(x):
    """Build ten hand-written rotary-embedding-style transforms and evaluate them.

    The last two dimensions of ``x`` are taken as ``(N, D)``. Every iteration
    reshapes ``x`` to ``(-1, N, D)``, rebuilds the position/frequency angle
    table, rotates the even/odd feature pairs by those angles and stores the
    result reshaped back to ``(-1, N, D)``. All ten results are then passed
    to ``mx.eval``.

    NOTE(review): the frequency expression (``exp(k / log(10000 / (D//2 - 1)))``)
    is kept exactly as written; it does not look like the usual RoPE
    frequency schedule and is presumably only meant as representative work.
    Confirm before reusing it outside this benchmark.

    Raises:
        ZeroDivisionError: If ``D // 2 == 1`` (the divisor ``D // 2 - 1`` is 0).
    """
    *_, N, D = x.shape
    half = D // 2
    # Pure-Python scalar: computed once since it cannot change between
    # iterations. The array ops below stay inside the loop on purpose, as
    # they are the work being benchmarked.
    log_scale = math.log(10000 / (half - 1))

    ys = []
    for _ in range(10):
        x = mx.reshape(x, (-1, N, D))
        positions = mx.arange(N)
        freqs = mx.exp(mx.arange(0.0, half) / log_scale)
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)

        # Even-indexed and odd-indexed features form the rotated pairs.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        # Stack each pair on a new trailing axis, then flatten it back into D.
        y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
        ys.append(mx.reshape(y, (-1, N, D)))
    mx.eval(ys)
|
|
|
|
|
|
|
|
def concatenate(axis, x, y):
    """Build ten ``mx.concatenate([x, y], axis=axis)`` calls and evaluate them."""
    outputs = [mx.concatenate([x, y], axis=axis) for _ in range(10)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def cumsum(axis, x):
    """Build ten ``mx.cumsum(x, axis)`` calls and evaluate them together."""
    outputs = [mx.cumsum(x, axis) for _ in range(10)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def sort(axis, x):
    """Build ten ``mx.sort(x, axis)`` calls and evaluate them together."""
    outputs = [mx.sort(x, axis) for _ in range(10)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def topk(axis, x):
    """Build ten ``mx.topk`` calls along ``axis`` and evaluate them together.

    ``k`` is one third (integer division) of the size of ``x`` along ``axis``.
    """
    k = x.shape[axis] // 3
    outputs = [mx.topk(x, k, axis) for _ in range(10)]
    mx.eval(outputs)
|
|
|
|
|
|
|
|
def step_function(x):
    """Build 100 ``nn.step(x)`` applications and evaluate all of them.

    Previously the loop rebound a single name to ``nn.step(x)`` on every
    iteration and only that last result was passed to ``mx.eval``, so the
    other 99 applications were never requested. Collecting every result and
    evaluating the list makes all 100 applications part of the measurement,
    in the same way the softmax and reduction benchmarks in this file do.

    The op is applied to ``x`` each time rather than chained on its own
    output, so every application sees the benchmark's input dtype.
    NOTE(review): this assumes ``nn.step`` may return a different dtype than
    its input; confirm against the mlx.nn documentation.
    """
    ys = [nn.step(x) for _ in range(100)]
    mx.eval(ys)
|
|
|
|
|
|
|
|
def selu(x):
    """Chain ``nn.selu`` 100 times starting from ``x`` and evaluate the result.

    Previously each iteration computed ``nn.selu(x)`` from the original input
    and discarded the previous result, so only one application fed the value
    passed to ``mx.eval``. Feeding each output into the next application
    makes all 100 contribute, matching the other activation benchmarks in
    this file (``relu``, ``elu``, ``celu``, ...).
    """
    y = x
    for _ in range(100):
        y = nn.selu(y)
    mx.eval(y)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the input arrays requested on the
    # command line, run the selected benchmark through `bench`, and print the
    # elapsed time in seconds.
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")

    args = parser.parse_args()

    # With action="append" user-supplied values are appended after the
    # default entry, so drop the default once anything else is present.
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    mx.set_default_device(mx.cpu if args.cpu else mx.gpu)

    # One dtype per size; missing entries repeat the first dtype given.
    types = args.dtype or [mx.float32]
    if len(types) < len(args.size):
        types = types + [types[0]] * (len(args.size) - len(types))

    xs = [
        mx.random.normal(size).astype(dtype) for size, dtype in zip(args.size, types)
    ]
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = mx.transpose(xs[i], t)
    mx.eval(xs)
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    name = args.benchmark

    # Benchmarks called as f(x).
    unary_benchmarks = {
        "matmul_square": matmul_square,
        "relu": relu,
        "elu": elu,
        "relu6": relu6,
        "celu": celu,
        "log_sigmoid": log_sigmoid,
        "leaky_relu": leaky_relu,
        "prelu": prelu,
        "softplus": softplus,
        "mish": mish,
        "scalar_mul": scalar_mult,
        "rope": rope,
        "step": step_function,
        "selu": selu,
    }
    # Benchmarks called as f(axis, x).
    axis_benchmarks = {
        "softmax": softmax_fused if args.fused else softmax,
        "logsumexp": logsumexp,
        "sort": sort,
        "topk": topk,
    }
    # Benchmarks called as f(axis, *xs).
    axis_multi_benchmarks = {
        "concatenate": concatenate,
        "cumsum": cumsum,
        "sum_and_add": sum_and_add,
    }
    # Benchmarks called as f(*xs).
    multi_benchmarks = {
        "matmul": matmul,
        "linear": linear_fused if args.fused else linear,
        "conv1d": conv1d,
        "conv2d": conv2d,
    }
    # Benchmarks routed through `reduction(op, axis, x)`.
    reduction_benchmarks = {
        "sum_axis": ("sum", axis),
        "sum_all": ("sum", None),
        "argmax": ("argmax", axis),
    }
    # Benchmarks routed through `binary(op, *xs)`.
    binary_benchmarks = {"add": "add", "mul": "multiply"}

    if name in unary_benchmarks:
        elapsed = bench(unary_benchmarks[name], x)
    elif name in axis_benchmarks:
        elapsed = bench(axis_benchmarks[name], axis, x)
    elif name in axis_multi_benchmarks:
        elapsed = bench(axis_multi_benchmarks[name], axis, *xs)
    elif name in multi_benchmarks:
        elapsed = bench(multi_benchmarks[name], *xs)
    elif name in quant_matmul:
        # Membership test (rather than a name-prefix test) so that an
        # unrecognised quant_matmul_* name reaches "Unknown benchmark" below
        # instead of failing with a KeyError.
        elapsed = bench(quant_matmul[name], *xs)
    elif name in reduction_benchmarks:
        op, reduce_axis = reduction_benchmarks[name]
        elapsed = bench(reduction, op, reduce_axis, x)
    elif name in binary_benchmarks:
        elapsed = bench(binary, binary_benchmarks[name], *xs)
    elif name == "cross_entropy":
        # Check the array that is actually benchmarked. The previous check
        # read the leftover `size` loop variable, i.e. the last --size given,
        # which is not necessarily the shape of `x`.
        if len(x.shape) != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
        targets = mx.zeros((len(x),), dtype=mx.uint32)
        elapsed = bench(cross_entropy, targets, x)
    else:
        raise ValueError("Unknown benchmark")

    print(elapsed)
|
|
|