File size: 2,391 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

# Problem sizes used by time_gather_qmm below.
N = 1024  # leading axis of the input x and of the index array
D = 1024  # last axis of x and of w1 (before quantization)
M = 1024  # second axis of w1 / last axis of w2 (before quantization)
E = 32  # leading axis of w1 and w2; indices are drawn from [0, E)
I = 4  # number of indices drawn per input row


def gather_sort(x, indices):
    """Reorder the rows of ``x`` so rows that use the same index are adjacent.

    ``indices`` is flattened and sorted. Flattened position ``p`` belongs to
    row ``p // per_row`` of ``x`` (``per_row`` being the size of the last axis
    of ``indices``), so each row of ``x`` is gathered once per index it uses.

    Args:
        x: Input array. All axes except the last two are merged into a single
            row axis before gathering.
        indices: Integer array with at least one axis. Its last axis holds the
            indices selected for one row of ``x``; any number of leading axes
            is accepted.

    Returns:
        A tuple ``(x_sorted, indices_sorted, inv_order)``. ``inv_order`` is
        the inverse of the sorting permutation, i.e. indexing the sorted
        result with it restores the original flattened order.
    """
    # Only the last-axis size is needed. Taking it this way avoids shadowing
    # the module-level constants and does not force ``indices`` to be 2-D.
    *_, per_row = indices.shape
    flat_indices = indices.flatten()
    order = mx.argsort(flat_indices)
    # argsort of a permutation yields its inverse permutation.
    inv_order = mx.argsort(order)
    rows = x.flatten(0, -3)
    return rows[order // per_row], flat_indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    """Index ``x`` with ``inv_order`` and optionally unflatten its first axis.

    Args:
        x: Array (or any indexable object) to reorder.
        inv_order: Index applied as ``x[inv_order]``.
        shape: If given, the leading axis of the reordered result is expanded
            to this shape with ``mx.unflatten``. If ``None`` the reordered
            result is returned as is.

    Returns:
        The reordered (and possibly unflattened) result.
    """
    restored = x[inv_order]
    if shape is None:
        return restored
    return mx.unflatten(restored, 0, shape)


def gather_mm_simulate(x, w, indices):
    """Emulate two chained gathered quantized matmuls one row at a time.

    The rows of ``x`` are sorted by index with ``gather_sort``. Then, twice,
    every row is multiplied on its own by the quantized matrix its index
    selects and the per-row results are concatenated. Finally the original
    order and the shape of ``indices`` are restored with ``scatter_unsort``.

    Args:
        x: Input array, as accepted by ``gather_sort``.
        w: Sequence whose items 0, 1 and 2 are each indexed by the selected
            index and passed to ``mx.quantized_matmul`` as its second, third
            and fourth positional arguments (presumably the tuple returned by
            ``mx.quantize`` for a stack of matrices -- confirm with callers).
        indices: Integer array selecting, per row of ``x``, which entries of
            ``w`` to use.

    Returns:
        The result of the second pass, reordered to the original order with
        its leading axis unflattened to ``indices.shape``.
    """
    x, idx, inv_order = gather_sort(x, indices)
    # The sorted indices do not change between the two passes, so copy them
    # to a Python list once instead of once per pass.
    selected = idx.tolist()
    # NOTE(review): the same ``w`` is applied on both passes, so the output
    # size of the first pass has to match the input size ``w`` expects.
    for _ in range(2):
        per_row = [
            mx.quantized_matmul(x[row], w[0][j], w[1][j], w[2][j], transpose=True)
            for row, j in enumerate(selected)
        ]
        # Re-insert a singleton axis so the next pass can index rows as x[row].
        x = mx.concatenate(per_row, axis=0)[:, None]
    return scatter_unsort(x, inv_order, indices.shape)


def time_gather_qmm():
    """Time chained ``mx.gather_qmm`` calls against plain quantized matmuls.

    Using ``time_fn``, this measures two chained ``mx.gather_qmm`` calls in
    three configurations: random indices, pre-sorted indices (without telling
    the kernel they are sorted), and random indices that are sorted inside the
    timed function with ``gather_sort`` / ``scatter_unsort``. It then measures
    two chained ``mx.quantized_matmul`` calls on an input with ``N * I`` rows
    and a single pair of weight matrices as a reference.
    """
    scale = 1024**0.5

    # Inputs for the gathered case: E stacked weight matrices per layer.
    inputs = mx.random.normal((N, 1, 1, D)) / scale
    first_w = mx.random.normal((E, M, D)) / scale
    second_w = mx.random.normal((E, D, M)) / scale
    first_w = mx.quantize(first_w)
    second_w = mx.quantize(second_w)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    # Materialize everything up front so setup is not part of the timings.
    mx.eval(inputs, first_w, second_w, indices, sorted_indices)

    # The inner function names are kept as they are: time_fn receives the
    # function object and may report its name.
    def gather_mm(x, w1, w2, indices, sort):
        idx, inv_order = indices, None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        for w in (w1, w2):
            x = mx.gather_qmm(
                x, *w, transpose=True, rhs_indices=idx, sorted_indices=sort
            )
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    runs = (
        (indices, False),
        (sorted_indices, False),
        (indices, True),
    )
    for run_indices, sort in runs:
        time_fn(gather_mm, inputs, first_w, second_w, run_indices, sort)

    # Reference: the same number of rows through one pair of weight matrices.
    inputs = mx.random.normal((N * I, D)) / scale
    first_w = mx.random.normal((M, D)) / scale
    second_w = mx.random.normal((D, M)) / scale
    first_w = mx.quantize(first_w)
    second_w = mx.quantize(second_w)
    mx.eval(inputs, first_w, second_w)

    def equivalent_matmul(x, w1, w2):
        for w in (w1, w2):
            x = mx.quantized_matmul(x, *w, transpose=True)
        return x

    time_fn(equivalent_matmul, inputs, first_w, second_w)


# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    time_gather_qmm()