|
|
|
|
|
|
|
|
import argparse |
|
|
import math |
|
|
import random |
|
|
|
|
|
import mlx.core as mx |
|
|
from time_utils import time_fn |
|
|
|
|
|
|
|
|
def bench_gelu():
    """Time an erf-based GELU eagerly and under ``mx.compile``.

    Two scenarios are reported through ``time_fn``:

    * fixed shapes -- ten chained GELU applications on a (1000, 1024) input,
      eager versus compiled;
    * variable shapes -- the first input is sliced to a random number of rows
      on every call while a second (1000, 1024) input keeps its shape,
      comparing eager, compiled and ``shapeless=True`` compiled variants.

    The Python RNG is reseeded before each variable-shape run so that every
    variant starts from the same sequence of random row counts.
    """

    def gelu(x):
        # Exact GELU: x * Phi(x), where Phi is the standard normal CDF
        # written in terms of erf.
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    x = mx.random.uniform(shape=(1000, 1024))

    def gen_fixed(fun):
        """Return a callable that applies ``fun`` ten times in a chain."""

        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fixed(gelu), x, msg="fixed gelu")
    time_fn(gen_fixed(mx.compile(gelu)), x, msg="compiled fixed gelu")

    def randint():
        # Row count in [1, x.shape[0]]; random.randint includes both ends.
        return random.randint(1, x.shape[0])

    def gen_variable(fun):
        """Return a callable that chains ``fun`` on a randomly sliced ``x``
        and on the full-size ``y``."""

        def bench_fun(x, y):
            # A fresh random leading dimension on every call.
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
                y = fun(y)
            return x, y

        return bench_fun

    y = mx.random.uniform(shape=(1000, 1024))

    # Reseed before each run so all three variants start from the same
    # sequence of slice sizes.
    random.seed(0)
    time_fn(gen_variable(gelu), x, y, msg="variable gelu")
    random.seed(0)
    time_fn(gen_variable(mx.compile(gelu)), x, y, msg="compiled variable gelu")
    random.seed(0)
    time_fn(
        gen_variable(mx.compile(gelu, shapeless=True)),
        x,
        y,
        msg="shapeless variable gelu",
    )
|
|
|
|
|
|
|
|
def bench_layernorm():
    """Time a hand-written layer norm eagerly and under ``mx.compile``.

    The fixed-shape runs chain the layer norm ten times on a (1000, 4096)
    float16 input. The variable-shape runs first slice the input to a random
    number of rows and compare eager, compiled and ``shapeless=True`` compiled
    variants. The Python RNG is reset to seed 0 before each variable-shape
    run.
    """
    hidden = 4096

    weight = mx.random.uniform(shape=(hidden,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(hidden,)).astype(mx.float16)
    # Evaluate the parameters up front, before any timed call.
    mx.eval(weight, bias)

    def layernorm(x):
        # Statistics and normalisation are computed in float32; the affine
        # step runs in float16.
        x32 = x.astype(mx.float32)
        mu = mx.mean(x32, axis=-1, keepdims=True)
        sigma2 = mx.var(x32, axis=-1, keepdims=True)
        normed = ((x32 - mu) * mx.rsqrt(sigma2 + 1e-4)).astype(mx.float16)
        return weight * normed + bias

    x = mx.random.uniform(shape=(1000, hidden)).astype(mx.float16)
    max_rows = x.shape[0]

    def chain_fixed(fun):
        """Return a callable applying ``fun`` ten times to its input."""

        def run(inp):
            out = inp
            for _ in range(10):
                out = fun(out)
            return out

        return run

    time_fn(chain_fixed(layernorm), x, msg="fixed layernorm")
    time_fn(chain_fixed(mx.compile(layernorm)), x, msg="compiled fixed layernorm")

    def chain_variable(fun):
        """Return a callable that slices its input to a random number of rows
        before applying ``fun`` ten times."""

        def run(inp):
            # Row count drawn from [1, max_rows], both ends inclusive.
            out = inp[: random.randint(1, max_rows)]
            for _ in range(10):
                out = fun(out)
            return out

        return run

    random.seed(0)
    time_fn(chain_variable(layernorm), x, msg="variable layernorm")

    random.seed(0)
    time_fn(
        chain_variable(mx.compile(layernorm)),
        x,
        msg="compiled variable layernorm",
    )

    random.seed(0)
    time_fn(
        chain_variable(mx.compile(layernorm, shapeless=True)),
        x,
        msg="shapeless variable layernorm",
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The text is a description, not a program name: ArgumentParser's first
    # positional parameter is ``prog``, so it must be passed by keyword to
    # keep the usage line correct.
    parser = argparse.ArgumentParser(description="Compile benchmarks.")
    # No options are defined; parsing still provides --help and rejects
    # unexpected command-line arguments.
    parser.parse_args()

    bench_gelu()
    bench_layernorm()
|
|
|