| | |
| |
|
| | import argparse |
| | import math |
| | import random |
| |
|
| | import mlx.core as mx |
| | from time_utils import time_fn |
| |
|
| |
|
def bench_gelu():
    """Time eager vs. compiled GELU on fixed and variable input shapes.

    Each configuration is passed to ``time_fn`` with a descriptive ``msg``:
    eager, ``mx.compile`` and (for variable shapes)
    ``mx.compile(..., shapeless=True)``.
    """

    def gelu(x):
        # Exact (erf-based) GELU: x * Phi(x), with Phi the standard normal CDF.
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    x = mx.random.uniform(shape=(1000, 1024))

    def gen_fixed_fun(fun):
        # Benchmark body: apply ``fun`` ten times, feeding back its output.
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fixed_fun(gelu), x, msg="fixed gelu")
    time_fn(gen_fixed_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")

    def randint():
        # Number of leading rows of ``x`` to keep, in [1, x.shape[0]] inclusive.
        return random.randint(1, x.shape[0])

    def gen_variable_fun(fun):
        # Benchmark body: slice ``x`` to a random row count (so its shape
        # changes between calls) while ``y`` keeps its full shape, then
        # apply ``fun`` ten times to each.
        def bench_fun(x, y):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
                y = fun(y)
            return x, y

        return bench_fun

    y = mx.random.uniform(shape=(1000, 1024))
    # Reseed before every variable-shape run so all three variants see the
    # same sequence of slice lengths; otherwise the timings are not
    # comparable (matches bench_layernorm).
    random.seed(0)
    time_fn(gen_variable_fun(gelu), x, y, msg="variable gelu")
    random.seed(0)
    time_fn(gen_variable_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
    random.seed(0)
    time_fn(
        gen_variable_fun(mx.compile(gelu, shapeless=True)),
        x,
        y,
        msg="shapeless variable gelu",
    )
| |
|
| |
|
def bench_layernorm():
    """Time eager vs. compiled layer norm on fixed and variable input shapes.

    Every configuration is handed to ``time_fn`` with a descriptive ``msg``.
    Python's ``random`` module is reseeded before each variable-shape run so
    the three variants slice the input to the same sequence of row counts.
    """
    # Affine parameters for a 4096-wide last axis, evaluated up front so
    # producing them is not part of any timed call.
    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    mx.eval(weight, bias)

    def layernorm(x):
        # Normalize over the last axis in float32, cast back to float16,
        # then apply the elementwise affine transform.
        h = x.astype(mx.float32)
        mu = mx.mean(h, axis=-1, keepdims=True)
        sigma2 = mx.var(h, axis=-1, keepdims=True)
        normed = ((h - mu) * mx.rsqrt(sigma2 + 1e-4)).astype(mx.float16)
        return weight * normed + bias

    x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)

    def repeat_ten(fun):
        # Benchmark body: feed ``fun`` its own output ten times.
        def run(inp):
            for _ in range(10):
                inp = fun(inp)
            return inp

        return run

    def sliced_repeat_ten(fun):
        # Like repeat_ten, but first keep a random number of leading rows
        # (1..x.shape[0] inclusive) so the input shape varies between calls.
        def run(inp):
            inp = inp[: random.randint(1, x.shape[0])]
            for _ in range(10):
                inp = fun(inp)
            return inp

        return run

    time_fn(repeat_ten(layernorm), x, msg="fixed layernorm")
    time_fn(repeat_ten(mx.compile(layernorm)), x, msg="compiled fixed layernorm")

    # Factories (not prebuilt wrappers) so each variant is constructed right
    # before its own run, exactly as in a straight-line sequence of calls.
    variants = (
        ("variable layernorm", lambda: layernorm),
        ("compiled variable layernorm", lambda: mx.compile(layernorm)),
        (
            "shapeless variable layernorm",
            lambda: mx.compile(layernorm, shapeless=True),
        ),
    )
    for label, make_fun in variants:
        random.seed(0)
        time_fn(sliced_repeat_ten(make_fun()), x, msg=label)
| |
|
| |
|
if __name__ == "__main__":
    # ArgumentParser's first positional parameter is ``prog``; pass the text
    # as ``description`` so it shows up in --help instead of replacing the
    # program name in the usage line.
    parser = argparse.ArgumentParser(description="Compile benchmarks.")
    # No options are defined; parsing still provides --help and rejects
    # unexpected command-line arguments.
    parser.parse_args()

    bench_gelu()
    bench_layernorm()
| |
|