File size: 2,509 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
import numpy as np


def timeit(fn, its=100, args=(), warmup=5):
    """Return the mean wall-clock time of ``fn(*args)`` in milliseconds.

    ``fn`` is first called ``warmup`` times without timing, then ``its``
    times under ``time.perf_counter``; only the second batch is measured.

    Args:
        fn: Callable to benchmark.
        its: Number of timed calls.
        args: Positional arguments passed unchanged to every call of ``fn``.
            A tuple default avoids the shared-mutable-default pitfall.
        warmup: Number of untimed calls made before timing starts.

    Returns:
        Average time per timed call, in milliseconds.
    """
    # Untimed warm-up calls keep one-off first-call costs out of the result.
    for _ in range(warmup):
        fn(*args)
    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / its


def time_little_einsum_path():
    """Benchmark ``einsum_path`` for a single small matrix product.

    Times MLX, then NumPy, computing the contraction path for
    ``"ik,kj->ij"`` on two 32x32 all-ones operands, and prints the
    per-call averages in milliseconds.
    """
    expr = "ik,kj->ij"
    shape = (32, 32)
    mx_operands = (mx.ones(shape), mx.ones(shape))
    mx_time = timeit(mx.einsum_path, args=(expr, *mx_operands))

    # NumPy receives the same operands, converted from the MLX arrays.
    np_operands = tuple(np.array(a) for a in mx_operands)
    np_time = timeit(np.einsum_path, args=(expr, *np_operands))

    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_big_einsum_path():
    """Benchmark ``einsum_path`` on a random ten-operand contraction.

    Builds ten all-ones operands. Each operand is labelled by 5 distinct
    letters drawn at random from ``"abcdefgh"``. Every letter maps to one
    fixed dimension size, so a label shared between operands always has the
    same length. The subscripts use implicit mode (no ``->``) and are
    re-drawn on every call, since the global NumPy RNG is used unseeded.

    NumPy is timed first on the NumPy operands, then MLX on converted copies.
    The per-call averages are printed in milliseconds.
    """
    chars = list("abcdefgh")
    # Sizes start at 1. A plain enumerate() would map "a" to 0, making every
    # operand that carries it an empty array and collapsing the
    # dimension-size products used to cost contraction paths to zero.
    char_to_dim = {c: v + 1 for v, c in enumerate(chars)}

    num_inputs = 10
    inputs = []
    subscripts = []
    for _ in range(num_inputs):
        subscript = np.random.choice(chars, size=5, replace=False).tolist()
        subscripts.append("".join(subscript))
        inputs.append(np.ones([char_to_dim[c] for c in subscript]))
    subscripts = ",".join(subscripts)

    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))

    inputs = [mx.array(x) for x in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_attention():
    """Compare matmul-based and einsum-based self-attention in MLX.

    Each variant uses one input tensor as queries, keys and values. It
    applies an unscaled softmax over the last axis of the
    [batch, heads, seq, seq] scores and forces evaluation with ``mx.eval``,
    so the lazy graph is actually computed inside the timed call. The
    per-call averages are printed in milliseconds.
    """

    def regular_attention(x):
        # x: [batch, sequence, num_heads, head_dim], used as Q, K and V.
        q = x.transpose(0, 2, 1, 3)  # [batch, heads, seq, dim]
        k_t = x.transpose(0, 2, 3, 1)  # [batch, heads, dim, seq]
        v = x.transpose(0, 2, 1, 3)  # [batch, heads, seq, dim]
        probs = mx.softmax(q @ k_t, axis=-1)
        # Move heads back behind the sequence axis before evaluating.
        mx.eval((probs @ v).swapaxes(1, 2))

    def einsum_attention(x):
        # Same layout as above; i=batch, t/u=sequence, j=heads, k=head_dim.
        probs = mx.softmax(mx.einsum("itjk,iujk->ijtu", x, x), axis=-1)
        mx.eval(mx.einsum("ijtu,iujk->itjk", probs, x))

    x = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))
    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")


if __name__ == "__main__":
    # Run every benchmark in a fixed order when executed as a script.
    for bench in (time_little_einsum_path, time_big_einsum_path, time_attention):
        bench()