| import argparse | |
| import math | |
| import mlx.core as mx | |
| from time_utils import time_fn | |
# Benchmark configuration (grouped-query attention, decode-style: 1 query token).
L = 16384  # key/value sequence length
H = 32  # number of query heads
H_k = H // 4  # number of key/value heads (4 query heads share each KV head)
D = 128  # query/key head dimension
V = 128  # value head dimension
dtype = mx.float16
loops = 10  # attention iterations chained inside each timed call
def upproject(x, w):
    """Apply the output projection ``x @ w.T`` when a weight is given.

    Returns ``x`` unchanged when ``w`` is None (identity, no projection
    needed because the value dim already matches the head dim).
    """
    return x if w is None else x @ w.T
def attention(q, k, v, mask=None, w=None):
    """Reference (unfused) attention benchmark.

    Runs `loops` iterations of a hand-written grouped-query SDPA built from
    matmul/softmax primitives, feeding each output back in as the next
    query, with an optional up-projection (`w`) after every step so the
    output keeps the query head dim.  Mirrors `sdpa` below, which uses the
    fused kernel with the same scale (none) and mask.
    """
    def _sdpa(q, k, v):
        # Grouped-query attention: Hq query heads share Hk key/value heads.
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        _, _, _, V = v.shape
        # Fold query heads into Hk groups of Hq // Hk so each group lines up
        # with its shared key/value head via broadcasting.
        q = q.reshape(B, Hk, Hq // Hk, L, D)
        k = k[:, :, None, :, :]
        v = v[:, :, None, :, :]
        # Raw scores — intentionally no 1/sqrt(D) scaling, matching the
        # fused path's scale=1.0 so the two benchmarks compute the same thing.
        s = q @ k.transpose(0, 1, 2, 4, 3)
        if mask is not None:
            # `mask` broadcasts over (B, Hq, L, S): True = keep, False = mask
            # out with the dtype's most negative value (-> ~0 after softmax).
            m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
            s = mx.where(m, s, mx.finfo(s.dtype).min)
        # Softmax in float32 for numerical stability, then cast back.
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        o = p @ v
        return o.reshape(B, Hq, L, V)
    for i in range(loops):
        q = _sdpa(q, k, v)
        q = upproject(q, w)
    return q
def sdpa(q, k, v, mask=None, w=None):
    """Fused-kernel counterpart of `attention`.

    Chains `loops` calls to mx.fast.scaled_dot_product_attention (scale=1.0,
    matching the unscaled reference path), feeding each result back in as the
    next query after an optional up-projection through `w`.
    """
    for _ in range(loops):
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        q = upproject(out, w)
    return q
def time_self_attention_primitives():
    """Time the hand-written `attention` path on fixed random fp16 inputs."""
    mx.random.seed(3)
    queries = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    keys = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    values = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    # An up-projection weight is only needed when the value dim differs from D.
    proj = None if V == D else mx.random.uniform(shape=(D, V)).astype(dtype)
    # Force materialization so setup cost stays out of the timed region.
    mx.eval(queries, keys, values, proj)
    time_fn(attention, queries, keys, values, w=proj)
def time_self_attention_sdpa():
    """Time the fused `sdpa` path on the same fixed random fp16 inputs."""
    mx.random.seed(3)
    queries = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    keys = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    values = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    # An up-projection weight is only needed when the value dim differs from D.
    proj = None if V == D else mx.random.uniform(shape=(D, V)).astype(dtype)
    # Force materialization so setup cost stays out of the timed region.
    mx.eval(queries, keys, values, proj)
    time_fn(sdpa, queries, keys, values, w=proj)
def time_self_attention_sdpa_with_mask():
    """Time both attention paths with a boolean mask over the key positions.

    The mask keeps the first half of the L key positions and hides the rest;
    the reference path is timed first, then the fused path.
    """
    mx.random.seed(3)
    queries = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    keys = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    values = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    proj = None if V == D else mx.random.uniform(shape=(D, V)).astype(dtype)
    # Boolean key mask: keep positions [0, L/2), mask out [L/2, L).
    mask = mx.full((L,), True)
    mask[L // 2 :] = False
    mx.eval(queries, keys, values, mask, proj)

    def masked_sdpa(*args):
        return sdpa(*args, mask=mask, w=proj)

    def masked_attention(*args):
        return attention(*args, mask=mask, w=proj)

    time_fn(masked_attention, queries, keys, values)
    time_fn(masked_sdpa, queries, keys, values)
| if __name__ == "__main__": | |
| time_self_attention_sdpa() | |
| time_self_attention_primitives() | |
| time_self_attention_sdpa_with_mask() | |