"""Benchmark mx.fast.scaled_dot_product_attention against an equivalent
attention built from MLX primitives, with and without a boolean mask."""

import mlx.core as mx
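# `time_fn` (from the repo's local benchmark helpers) is assumed to time
# repeated calls of the function it is given and print the result.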
from time_utils import time_fn

L = 16384  # key/value sequence length
H = 32  # number of query heads
H_k = H // 4  # number of key/value heads (GQA: 4 query heads per KV head)
D = 128  # head dimension of queries and keys
V = 128  # head dimension of values
dtype = mx.float16
loops = 10  # attention steps chained per timed call


def upproject(x, w):
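    """Project the attention output (last dim V) back to the query head
    dim D via w of shape (D, V), so it can be fed back in as the next
    query. No-op when w is None (i.e. V == D)."""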
    if w is None:
        return x
    else:
        return x @ w.T


def attention(q, k, v, mask=None, w=None):
    def _sdpa(q, k, v):
        # Unfused grouped-query attention: fold the query heads into
        # (KV heads, queries per KV head) so each group shares one K/V head.
        B, Hq, Lq, Dk = q.shape
        _, Hk, S, _ = k.shape
        _, _, _, Dv = v.shape
        q = q.reshape(B, Hk, Hq // Hk, Lq, Dk)
        k = k[:, :, None, :, :]
        v = v[:, :, None, :, :]
        # Unscaled scores, matching scale=1.0 in the fused benchmark.
        s = q @ k.transpose(0, 1, 2, 4, 3)
        if mask is not None:
            m = mx.broadcast_to(mask, (B, Hq, Lq, S)).reshape(B, Hk, Hq // Hk, Lq, S)
            s = mx.where(m, s, mx.finfo(s.dtype).min)
        # Softmax in float32 for numerical stability, then cast back.
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        o = p @ v
        return o.reshape(B, Hq, Lq, Dv)

    # Chain `loops` attention steps back to back per timed call.
    for _ in range(loops):
        q = _sdpa(q, k, v)
        q = upproject(q, w)
    return q


def sdpa(q, k, v, mask=None, w=None):
    # Same loop as `attention`, but using the fused fast kernel.
    for _ in range(loops):
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        q = upproject(q, w)
    return q


def make_inputs():
    # Decoding-style shapes: a single query token attending to an L-long
    # key/value cache with H_k grouped KV heads.
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    return q, k, v, w


def time_self_attention_primitives():
    q, k, v, w = make_inputs()
    mx.eval(q, k, v, w)
    time_fn(attention, q, k, v, w=w)


def time_self_attention_sdpa():
    q, k, v, w = make_inputs()
    mx.eval(q, k, v, w)
    time_fn(sdpa, q, k, v, w=w)


def time_self_attention_sdpa_with_mask():
    q, k, v, w = make_inputs()
    # Boolean mask that keeps only the first half of the keys.
    mask = mx.full((L,), True)
    mask[L // 2 :] = False
    mx.eval(q, k, v, mask, w)

    def sdpa_mask(*args):
        return sdpa(*args, mask=mask, w=w)

    def attention_mask(*args):
        return attention(*args, mask=mask, w=w)

    time_fn(attention_mask, q, k, v)
    time_fn(sdpa_mask, q, k, v)


if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
    time_self_attention_sdpa_with_mask()