| | |
| |
|
| | import time |
| |
|
| | import mlx.core as mx |
| | import numpy as np |
| |
|
| |
|
| | def timeit(fn, its=100, args=[]): |
| | for _ in range(5): |
| | fn(*args) |
| | tic = time.perf_counter() |
| | for _ in range(its): |
| | fn(*args) |
| | toc = time.perf_counter() |
| | return 1e3 * (toc - tic) / its |
| |
|
| |
|
| | def time_little_einsum_path(): |
| | subscripts = "ik,kj->ij" |
| | x = mx.ones((32, 32)) |
| | y = mx.ones((32, 32)) |
| | mx_time = timeit(mx.einsum_path, args=(subscripts, x, y)) |
| |
|
| | x = np.array(x) |
| | y = np.array(y) |
| | np_time = timeit(np.einsum_path, args=(subscripts, x, y)) |
| | print("Timing little einsum path...") |
| | print(f"MLX ... {mx_time:.3f} ms") |
| | print(f"NumPy... {np_time:.3f} ms") |
| |
|
| |
|
| | def time_big_einsum_path(): |
| | chars = list("abcdefgh") |
| | char_to_dim = {c: v for v, c in enumerate(chars)} |
| |
|
| | num_inputs = 10 |
| | inputs = [] |
| | subscripts = [] |
| | for _ in range(num_inputs): |
| | subscript = np.random.choice(chars, size=5, replace=False).tolist() |
| | subscripts.append("".join(subscript)) |
| | inputs.append(np.ones(list(char_to_dim[c] for c in subscript))) |
| | subscripts = ",".join(subscripts) |
| |
|
| | np_time = timeit(np.einsum_path, args=(subscripts, *inputs)) |
| |
|
| | inputs = [mx.array(x) for x in inputs] |
| | mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs)) |
| | print("Timing big einsum path...") |
| | print(f"MLX ... {mx_time:.3f} ms") |
| | print(f"NumPy... {np_time:.3f} ms") |
| |
|
| |
|
| | def time_attention(): |
| | def regular_attention(x): |
| | |
| | queries, keys, values = x, x, x |
| | scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1) |
| | scores = mx.softmax(scores, axis=-1) |
| | output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2) |
| | mx.eval(output) |
| |
|
| | def einsum_attention(x): |
| | |
| | queries, keys, values = x, x, x |
| | scores = mx.einsum("itjk,iujk->ijtu", queries, keys) |
| | scores = mx.softmax(scores, axis=-1) |
| | output = mx.einsum("ijtu,iujk->itjk", scores, values) |
| | mx.eval(output) |
| |
|
| | x = mx.random.uniform(shape=(8, 512, 32, 128)) |
| |
|
| | regular_time = timeit(regular_attention, args=(x,)) |
| | ein_time = timeit(einsum_attention, args=(x,)) |
| | print("Timing einsum attention...") |
| | print(f"Regular ... {regular_time:.3f} ms") |
| | print(f"Einsum ... {ein_time:.3f} ms") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | time_little_einsum_path() |
| | time_big_einsum_path() |
| | time_attention() |
| |
|