File size: 786 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Copyright © 2023-2024 Apple Inc.

import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    """Benchmark ``fn`` and print its average time per call in milliseconds.

    ``fn(*args, **kwargs)`` is called 5 times as warmup and then 100 times
    under the timer. Every result is handed to ``mx.eval`` so that the work
    requested by the call is evaluated inside the measured region.

    Args:
        fn: Callable to benchmark.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``. The optional ``msg``
            keyword is consumed here (never forwarded) and, when truthy, is
            used as the label in the printed line instead of the function
            name.
    """
    msg = kwargs.pop("msg", None)
    # Callables such as functools.partial objects or callable instances have
    # no __name__; fall back to their repr instead of raising AttributeError.
    label = msg if msg else getattr(fn, "__name__", repr(fn))
    print(f"Timing {label} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        # mx.eval is called for its side effect only; there is nothing to keep.
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    # Average seconds per iteration, converted to milliseconds.
    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")


def measure_runtime(fn, **kwargs):
    """Return the average time of ``fn(**kwargs)`` in milliseconds.

    ``fn`` is called 5 times as warmup and then 100 times under the timer.
    The results of the calls are discarded.

    Args:
        fn: Callable to benchmark.
        **kwargs: Keyword arguments forwarded to every call of ``fn``.

    Returns:
        The mean elapsed time per timed call, in milliseconds, as a float.
    """
    # Warmup
    for _ in range(5):
        fn(**kwargs)

    iters = 100
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted and may yield skewed or negative
    # durations.
    tic = time.perf_counter()
    for _ in range(iters):
        fn(**kwargs)
    toc = time.perf_counter()
    return (toc - tic) * 1000 / iters