File size: 1,311 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright © 2024 Apple Inc.

"""
Run with:
    mpirun -n 2 python /path/to/distributed_bench.py
"""

import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    """Time ``fn(*args, **kwargs)`` and print the mean latency from rank 0.

    The callable is invoked 5 times untimed as a warmup, then 100 times
    between two ``time.perf_counter`` readings. MLX evaluates lazily, so each
    result is handed to ``mx.eval`` to force the computation inside the loop.
    Every process runs both loops; only the process whose ``world.rank()`` is
    0 prints anything.

    Args:
        fn: Callable to benchmark.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``. The optional ``msg``
            key is removed first and, when truthy, is used as the printed
            label instead of the callable's name.

    Returns:
        None. The result is printed as ``"<mean> msec"`` on the label's line.
    """
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    is_root = world.rank() == 0

    if is_root:
        # Fall back to repr() for callables without __name__ (e.g. partials).
        label = msg if msg else getattr(fn, "__name__", repr(fn))
        # Flush so the label shows up while the benchmark is running; without
        # it the text can stay buffered until the newline printed at the end.
        print(f"Timing {label} ...", end=" ", flush=True)

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        # mx.eval's return value is not needed; it is called only to force
        # the evaluation of fn's result.
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    # Mean wall-clock time per iteration, converted from seconds to msec.
    msec = 1e3 * (toc - tic) / num_iters
    if is_root:
        print(f"{msec:.5f} msec")


def time_all_sum():
    """Benchmark chained ``sin``, chained ``all_sum``, and the two interleaved.

    A 4096-element array drawn with ``mx.random.uniform`` is evaluated once
    and then fed to three workloads, each applying its operation 20 times in
    sequence. The workloads are timed with ``time_fn`` in the order: sine
    only, ``all_sum`` only, sine followed by ``all_sum``.
    """
    depth = 20
    sample = mx.random.uniform(shape=(4096,))
    mx.eval(sample)

    # The nested function names double as the labels printed by time_fn.
    def sine(arr):
        for _ in range(depth):
            arr = mx.sin(arr)
        return arr

    def all_sum_plain(arr):
        for _ in range(depth):
            arr = mx.distributed.all_sum(arr)
        return arr

    def all_sum_with_sine(arr):
        for _ in range(depth):
            arr = mx.distributed.all_sum(mx.sin(arr))
        return arr

    for workload in (sine, all_sum_plain, all_sum_with_sine):
        time_fn(workload, sample)


# Run the benchmark only when this file is executed as a script (see the
# mpirun command in the module docstring), not when it is imported.
if __name__ == "__main__":
    time_all_sum()