| | |
| |
|
| | import mlx.core as mx |
| | from time_utils import time_fn |
| |
|
| | N = 1024 |
| | D = 1024 |
| | M = 1024 |
| | E = 32 |
| | I = 4 |
| |
|
| |
|
| | def gather_sort(x, indices): |
| | N, M = indices.shape |
| | indices = indices.flatten() |
| | order = mx.argsort(indices) |
| | inv_order = mx.argsort(order) |
| | return x.flatten(0, -3)[order // M], indices[order], inv_order |
| |
|
| |
|
| | def scatter_unsort(x, inv_order, shape=None): |
| | x = x[inv_order] |
| | if shape is not None: |
| | x = mx.unflatten(x, 0, shape) |
| | return x |
| |
|
| |
|
| | def gather_mm_simulate(x, w, indices): |
| | x, idx, inv_order = gather_sort(x, indices) |
| | for i in range(2): |
| | y = mx.concatenate( |
| | [ |
| | mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True) |
| | for i, j in enumerate(idx.tolist()) |
| | ], |
| | axis=0, |
| | ) |
| | x = y[:, None] |
| | x = scatter_unsort(x, inv_order, indices.shape) |
| | return x |
| |
|
| |
|
| | def time_gather_qmm(): |
| | x = mx.random.normal((N, 1, 1, D)) / 1024**0.5 |
| | w1 = mx.random.normal((E, M, D)) / 1024**0.5 |
| | w2 = mx.random.normal((E, D, M)) / 1024**0.5 |
| | w1 = mx.quantize(w1) |
| | w2 = mx.quantize(w2) |
| | indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32) |
| | sorted_indices = mx.sort(indices.flatten()).reshape(N, I) |
| | mx.eval(x, w1, w2, indices, sorted_indices) |
| |
|
| | def gather_mm(x, w1, w2, indices, sort): |
| | idx = indices |
| | inv_order = None |
| | if sort: |
| | x, idx, inv_order = gather_sort(x, indices) |
| | x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort) |
| | x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort) |
| | if sort: |
| | x = scatter_unsort(x, inv_order, indices.shape) |
| | return x |
| |
|
| | time_fn(gather_mm, x, w1, w2, indices, False) |
| | time_fn(gather_mm, x, w1, w2, sorted_indices, False) |
| | time_fn(gather_mm, x, w1, w2, indices, True) |
| |
|
| | x = mx.random.normal((N * I, D)) / 1024**0.5 |
| | w1 = mx.random.normal((M, D)) / 1024**0.5 |
| | w2 = mx.random.normal((D, M)) / 1024**0.5 |
| | w1 = mx.quantize(w1) |
| | w2 = mx.quantize(w2) |
| | mx.eval(x, w1, w2) |
| |
|
| | def equivalent_matmul(x, w1, w2): |
| | x = mx.quantized_matmul(x, *w1, transpose=True) |
| | x = mx.quantized_matmul(x, *w2, transpose=True) |
| | return x |
| |
|
| | time_fn(equivalent_matmul, x, w1, w2) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | time_gather_qmm() |
| |
|