File size: 637 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    """Time repeated ``nn.RoPE`` applications on two float16 inputs.

    Two cases are handed to ``time_fn``, in this order:

    * ``rope_vec``: input of shape ``(1, 32, 1, 128)``; every application is
      called with ``offset=100``.
    * ``rope_mat``: input of shape ``(1, 32, 1024, 128)``; every application
      is called without an explicit offset.

    Each timed function feeds its input through the same ``nn.RoPE(64)``
    layer 32 times in a row, passing each result on as the next input.
    """
    num_applications = 32
    rotary = nn.RoPE(64)

    # vec: evaluate the input up front, before it is handed to the timer.
    vec_input = mx.random.uniform(shape=(1, 32, 1, 128))
    vec_input = vec_input.astype(mx.float16)
    mx.eval(vec_input)

    def rope_vec(x):
        out = x
        for _ in range(num_applications):
            out = rotary(out, offset=100)
        return out

    time_fn(rope_vec, vec_input)

    # matrix: same procedure with a 1024-long third axis and no offset.
    mat_input = mx.random.uniform(shape=(1, 32, 1024, 128))
    mat_input = mat_input.astype(mx.float16)
    mx.eval(mat_input)

    def rope_mat(x):
        out = x
        for _ in range(num_applications):
            out = rotary(out)
        return out

    time_fn(rope_mat, mat_input)


# Run the RoPE benchmark only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    time_rope()