| | |
| |
|
| | import mlx.core as mx |
| | import mlx.nn as nn |
| | from time_utils import time_fn |
| |
|
| |
|
def rms_norm(x, w, eps):
    """Reference RMS normalization over the last axis.

    The input is upcast to float32, scaled by the reciprocal square root of
    the mean of its squares (taken over the last axis, plus ``eps``), and cast
    back to the dtype it arrived in. When ``w`` is given, the result is then
    multiplied by it.

    Args:
        x: Array to normalize; must expose ``dtype``, ``astype`` and ``square``.
        w: Optional scale multiplied into the normalized output, or ``None``
            to skip scaling.
        eps: Value added to the mean of squares before the reciprocal sqrt.

    Returns:
        The normalized array, multiplied by ``w`` when ``w`` is not ``None``.
    """
    input_dtype = x.dtype

    # Do the reduction in float32, whatever precision the caller passed in.
    x_f32 = x.astype(mx.float32)
    mean_of_squares = x_f32.square().mean(-1, keepdims=True)
    inv_rms = mx.rsqrt(mean_of_squares + eps)

    # Cast back to the caller's dtype *before* applying the weight.
    normalized = (x_f32 * inv_rms).astype(input_dtype)
    if w is None:
        return normalized
    return normalized * w
| |
|
| |
|
def time_rms_norm():
    """Benchmark gradients of the reference ``rms_norm`` vs ``mx.fast.rms_norm``.

    Two passes are run, each handing four workloads to ``time_fn``:

    1. With a weight vector: gradients w.r.t. both the input and the weight.
    2. With ``None`` as the weight: gradient w.r.t. the input only.

    In each pass the reference and fast gradient functions are submitted
    first as returned by ``mx.grad`` and then wrapped in ``mx.compile``.
    Every workload chains 32 gradient calls, feeding each result back in as
    the next input.
    """
    eps = 1e-5
    n_iters = 32
    x_shape = (8, 1024, 4096)
    w_shape = (4096,)

    def rand_f16(shape):
        # Uniform random array of the given shape, cast to float16.
        return mx.random.uniform(shape=shape).astype(mx.float16)

    # ------------------------------------------------------------------
    # Pass 1: weighted RMS norm, differentiate w.r.t. input and weight.
    # ------------------------------------------------------------------
    def weighted_ref_loss(x, w, y):
        return (rms_norm(x, w, eps) * y).sum()

    def weighted_fast_loss(x, w, y):
        return (mx.fast.rms_norm(x, w, eps) * y).sum()

    grad_ref = mx.grad(weighted_ref_loss, argnums=(0, 1))
    grad_fast = mx.grad(weighted_fast_loss, argnums=(0, 1))

    x = rand_f16(x_shape)
    w = rand_f16(w_shape)
    y = rand_f16(x_shape)
    # Materialize the inputs before any workload is handed to time_fn.
    mx.eval(x, w, y)

    # NOTE(review): the helper keeps the name ``rms_norm_loop`` in both passes
    # because the callable itself is handed to time_fn, which may report its
    # name -- confirm against time_utils before renaming.
    def rms_norm_loop(g, x, w):
        # ``y`` is read from the enclosing scope; the gradients of one step
        # become the inputs of the next.
        cur_x, cur_w = x, w
        for _ in range(n_iters):
            cur_x, cur_w = g(cur_x, cur_w, y)
        return cur_x, cur_w

    time_fn(rms_norm_loop, grad_ref, x, w)
    time_fn(rms_norm_loop, grad_fast, x, w)
    compiled = mx.compile(grad_ref)
    time_fn(rms_norm_loop, compiled, x, w)
    compiled = mx.compile(grad_fast)
    time_fn(rms_norm_loop, compiled, x, w)

    # ------------------------------------------------------------------
    # Pass 2: no weight, differentiate w.r.t. the input only.
    # ------------------------------------------------------------------
    def unweighted_ref_loss(x, y):
        return (rms_norm(x, None, eps) * y).sum()

    def unweighted_fast_loss(x, y):
        return (mx.fast.rms_norm(x, None, eps) * y).sum()

    grad_ref = mx.grad(unweighted_ref_loss, argnums=(0,))
    grad_fast = mx.grad(unweighted_fast_loss, argnums=(0,))

    # Fresh inputs for this pass. ``w`` is never read by the unweighted
    # losses; it is still drawn and evaluated so the sequence of random draws
    # and evaluated arrays matches the weighted pass.
    x = rand_f16(x_shape)
    w = rand_f16(w_shape)
    y = rand_f16(x_shape)
    mx.eval(x, w, y)

    def rms_norm_loop(g, x):
        # Same chaining as the weighted pass, with the input as the only
        # carried value.
        cur_x = x
        for _ in range(n_iters):
            cur_x = g(cur_x, y)
        return cur_x

    time_fn(rms_norm_loop, grad_ref, x)
    time_fn(rms_norm_loop, grad_fast, x)
    compiled = mx.compile(grad_ref)
    time_fn(rms_norm_loop, compiled, x)
    compiled = mx.compile(grad_fast)
    time_fn(rms_norm_loop, compiled, x)
|
| |
|
# Script entry point: run the benchmark on direct execution only, not on import.
if __name__ == "__main__":
    time_rms_norm()
| |
|