File size: 1,075 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import time

import mlx.core as mx

# Initialize MLX's distributed backend once at import time and keep only this
# process's rank; the benchmarks below use it so that only rank 0 prints.
rank = mx.distributed.init().rank()


def timeit(fn, a, warmup=5, its=10):
    """Return the mean wall-clock time of one ``fn(a)`` evaluation in ms.

    Every call result is passed to ``mx.eval`` so the computation ``fn``
    builds is actually executed inside the measured region.

    Args:
        fn: Callable taking ``a`` and returning something ``mx.eval`` accepts.
        a: Input passed unchanged to ``fn`` on every call.
        warmup: Number of untimed calls made before measuring (default 5).
        its: Number of timed calls to average over (default 10).

    Returns:
        Average milliseconds per ``fn(a)`` evaluation over ``its`` calls.

    Raises:
        ValueError: If ``warmup`` is negative or ``its`` is less than 1.
    """
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")
    if its < 1:
        # Guards the division below; previously this surfaced as a
        # ZeroDivisionError with no hint about the cause.
        raise ValueError(f"its must be >= 1, got {its}")

    # Untimed runs first, presumably so one-off costs of the first
    # evaluations are excluded from the measurement.
    for _ in range(warmup):
        mx.eval(fn(a))

    tic = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    toc = time.perf_counter()
    return 1000 * (toc - tic) / its


def all_reduce_benchmark():
    """Benchmark ``mx.distributed.all_sum`` on a 5x5 int32 array of ones.

    One timed evaluation chains ``steps`` all-sum operations, so the reported
    figure is the measured time divided by ``steps``. Only rank 0 prints.
    """
    steps = 100

    def chained_all_sum(arr):
        # Each step consumes the previous step's result, so the all_sums
        # form a dependent chain rather than independent operations.
        # NOTE(review): with more than two processes the values grow every
        # step and may wrap in int32 -- irrelevant for timing, but not a
        # meaningful numeric result.
        for _ in range(steps):
            arr = mx.distributed.all_sum(arr) - 1
        return arr

    per_step_ms = timeit(chained_all_sum, mx.ones((5, 5), mx.int32)) / steps
    if rank == 0:
        print(f"All Reduce: time per iteration {per_step_ms:.6f} (ms)")


def all_gather_benchmark():
    """Benchmark ``mx.distributed.all_gather`` on a 5x5 int32 array of ones.

    One timed evaluation chains ``its_per_eval`` all-gather operations, so the
    reported figure is the measured time divided by ``its_per_eval``. Only
    rank 0 prints.
    """
    a = mx.ones((5, 5), mx.int32)
    its_per_eval = 100
    rows = a.shape[0]

    def fn(x):
        for _ in range(its_per_eval):
            # all_gather concatenates the per-process arrays along axis 0.
            # Slicing back to the first ``rows`` rows restores the (5, 5)
            # input shape, so every step gathers the same payload. The
            # previous ``[0]`` indexing dropped a dimension on each step,
            # so most timed steps moved far less data than the first.
            x = mx.distributed.all_gather(x)[:rows]
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All gather: time per iteration {ms:.6f} (ms)")


if __name__ == "__main__":
    # Run each collective benchmark in turn; every process must execute the
    # same sequence so the collectives line up across ranks.
    for benchmark in (all_reduce_benchmark, all_gather_benchmark):
        benchmark()