File size: 2,520 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright © 2023-2024 Apple Inc.

from functools import partial

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def layer_norm(x, w, b, eps):
    """Reference (unfused) layer normalization over the last axis.

    The mean, variance, normalization and optional affine transform are all
    computed in float32. The result is then cast back to the dtype of ``x``.
    Without that cast, the float32 upcast leaks into the output: a caller that
    feeds the output back in as the next input (as the benchmark loops do)
    would silently run in float32 after the first call, not in the dtype
    under test.

    Args:
        x: Input array, normalized along its last axis.
        w: Optional scale multiplied into the normalized values; ``None``
            skips it.
        b: Optional shift added after scaling; ``None`` skips it.
        eps: Small constant added to the variance before ``rsqrt``.

    Returns:
        The normalized array, with the same dtype as the input ``x``.
    """
    ot = x.dtype
    x = x.astype(mx.float32)
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    y = (x - mu) * mx.rsqrt(v + eps)
    if w is not None:
        y = y * w
    if b is not None:
        y = y + b
    # Restore the caller's dtype; the statistics above were computed in float32.
    return y.astype(ot)


def time_layer_norm(N, dt, L=1024, batch=8, iters=32, eps=1e-5):
    """Time the reference ``layer_norm`` against ``mx.fast.layer_norm``.

    Inputs have shape ``(batch, L, N)`` and are cast to ``dt``. Two groups of
    measurements are run through ``time_fn``:

    * an affine group (with weight and bias), covering the forward pass and
      the gradient w.r.t. ``(x, w, b)``;
    * a non-affine group, covering the gradient w.r.t. ``x`` only.

    Every gradient is timed both eagerly and wrapped in ``mx.compile``.

    Args:
        N: Size of the last (normalized) axis.
        dt: MLX dtype the random inputs are cast to.
        L: Size of the middle axis. Defaults to 1024.
        batch: Size of the leading axis. Defaults to 8.
        iters: Number of chained calls inside each timed function.
            Defaults to 32.
        eps: Epsilon passed to both implementations. Defaults to 1e-5.
    """
    shape = (batch, L, N)
    _time_layer_norm_affine(shape, dt, iters, eps)
    _time_layer_norm_grad_x(shape, dt, iters, eps)


def _time_layer_norm_affine(shape, dt, iters, eps):
    """Time forward and (x, w, b)-gradient passes with weight and bias."""
    x = mx.random.uniform(shape=shape).astype(dt)
    w = mx.random.uniform(shape=shape[-1:]).astype(dt)
    b = mx.random.uniform(shape=shape[-1:]).astype(dt)
    y = mx.random.uniform(shape=shape).astype(dt)
    mx.eval(x, w, b, y)

    # NOTE(review): the loop helpers keep their original names in case
    # time_fn uses the function name to label its output — confirm in
    # time_utils.
    def layer_norm_loop(f, x, w, b):
        # Each call consumes the previous output, chaining `iters` calls.
        for _ in range(iters):
            x = f(x, w, b)
        return x

    def layer_norm_grad_loop(g, x, w, b, y):
        # Each call consumes the previous gradients, chaining `iters` calls.
        gx, gw, gb = x, w, b
        for _ in range(iters):
            gx, gw, gb = g(gx, gw, gb, y)
        return gx, gw, gb

    def ref_loss(x, w, b, y):
        return (layer_norm(x, w, b, eps) * y).sum()

    def fast_loss(x, w, b, y):
        return (mx.fast.layer_norm(x, w, b, eps) * y).sum()

    time_fn(layer_norm_loop, partial(layer_norm, eps=eps), x, w, b)
    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=eps), x, w, b)

    g1 = mx.grad(ref_loss, argnums=(0, 1, 2))
    g2 = mx.grad(fast_loss, argnums=(0, 1, 2))

    # `y` is passed explicitly instead of being read from an enclosing scope.
    time_fn(layer_norm_grad_loop, g1, x, w, b, y)
    time_fn(layer_norm_grad_loop, g2, x, w, b, y)
    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b, y)
    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b, y)


def _time_layer_norm_grad_x(shape, dt, iters, eps):
    """Time x-only gradients of both implementations, without weight or bias."""
    x = mx.random.uniform(shape=shape).astype(dt)
    y = mx.random.uniform(shape=shape).astype(dt)
    mx.eval(x, y)

    def layer_norm_grad_x_loop(g, x, y):
        # Each call consumes the previous gradient, chaining `iters` calls.
        gx = x
        for _ in range(iters):
            gx = g(gx, y)
        return gx

    def ref_loss(x, y):
        return (layer_norm(x, None, None, eps) * y).sum()

    def fast_loss(x, y):
        return (mx.fast.layer_norm(x, None, None, eps) * y).sum()

    g1 = mx.grad(ref_loss, argnums=(0,))
    g2 = mx.grad(fast_loss, argnums=(0,))

    time_fn(layer_norm_grad_x_loop, g1, x, y)
    time_fn(layer_norm_grad_x_loop, g2, x, y)
    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x, y)
    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x, y)


if __name__ == "__main__":
    # Sweep every dtype against each size of the normalized axis.
    dtypes = (mx.float32, mx.float16, mx.bfloat16)
    sizes = (1024, 2048, 4096, 8192, 8192 + 1024)
    for dtype in dtypes:
        for size in sizes:
            print(dtype, size)
            time_layer_norm(size, dtype)