|
|
|
|
|
|
|
|
""" |
|
|
Run with:

    mpirun -n 2 python /path/to/distributed_bench.py
|
|
""" |
|
|
|
|
|
import time |
|
|
|
|
|
import mlx.core as mx |
|
|
|
|
|
|
|
|
def time_fn(fn, *args, **kwargs):
    """Benchmark ``fn(*args, **kwargs)`` and print the mean time per call.

    The callable is invoked 5 times untimed as a warmup and then 100 times
    between two ``time.perf_counter`` readings. The result of every call is
    passed to ``mx.eval`` so that it is evaluated inside the measured region.
    Every process executes the calls; only the process whose rank in the
    group returned by ``mx.distributed.init()`` is 0 prints anything.

    Args:
        fn: Callable to benchmark. Its return value is handed to ``mx.eval``.
        *args: Positional arguments forwarded to ``fn`` on every call.
        **kwargs: Keyword arguments forwarded to ``fn`` on every call. An
            optional ``msg`` entry is removed first and, when truthy, is
            printed as the label instead of the callable's name.

    Returns:
        None. The mean time per call, in milliseconds, is printed by rank 0.
    """
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    is_root = world.rank() == 0

    if is_root:
        # Fall back to repr() for callables that have no __name__ (for
        # example functools.partial objects) so that rank 0 cannot raise
        # here while the remaining ranks go on to run fn.
        label = msg if msg else getattr(fn, "__name__", repr(fn))
        # The label ends with a space rather than a newline; flush so it is
        # visible while the measurement is still running.
        print(f"Timing {label} ...", end=" ", flush=True)

    # Warmup: untimed calls so one-off setup cost is kept out of the result.
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    # Seconds for the whole loop -> milliseconds per call.
    msec = 1e3 * (toc - tic) / num_iters
    if is_root:
        print(f"{msec:.5f} msec")
|
|
|
|
|
|
|
|
def time_all_sum():
    """Time chains of ``mx.sin``, ``mx.distributed.all_sum`` and both mixed.

    A uniform random array of shape ``(4096,)`` is created and evaluated
    once. Three workloads, each applying 20 chained steps to that array, are
    then passed to ``time_fn`` in this order:

    * ``sine``: ``mx.sin`` only.
    * ``all_sum_plain``: ``mx.distributed.all_sum`` only.
    * ``all_sum_with_sine``: ``mx.sin`` followed by
      ``mx.distributed.all_sum`` in every step.
    """
    num_steps = 20

    data = mx.random.uniform(shape=(4096,))
    mx.eval(data)

    # The nested function names double as the labels that time_fn prints,
    # so they must stay exactly as they are.
    def sine(x):
        for _ in range(num_steps):
            x = mx.sin(x)
        return x

    def all_sum_plain(x):
        for _ in range(num_steps):
            x = mx.distributed.all_sum(x)
        return x

    def all_sum_with_sine(x):
        for _ in range(num_steps):
            x = mx.distributed.all_sum(mx.sin(x))
        return x

    # Time the workloads one after another on the same input array.
    for workload in (sine, all_sum_plain, all_sum_with_sine):
        time_fn(workload, data)
|
|
|
|
|
|
|
|
# Script entry point: run the benchmark only when this file is executed
# directly (e.g. through mpirun), not when it is imported.
if __name__ == "__main__":
    time_all_sum()
|
|
|