| | |
| |
|
| | """ |
| | Run with: |
| | mpirun -n 2 python /path/to/distributed_bench.py |
| | """ |
| |
|
| | import time |
| |
|
| | import mlx.core as mx |
| |
|
| |
|
| | def time_fn(fn, *args, **kwargs): |
| | msg = kwargs.pop("msg", None) |
| | world = mx.distributed.init() |
| | if world.rank() == 0: |
| | if msg: |
| | print(f"Timing {msg} ...", end=" ") |
| | else: |
| | print(f"Timing {fn.__name__} ...", end=" ") |
| |
|
| | |
| | for _ in range(5): |
| | mx.eval(fn(*args, **kwargs)) |
| |
|
| | num_iters = 100 |
| | tic = time.perf_counter() |
| | for _ in range(num_iters): |
| | x = mx.eval(fn(*args, **kwargs)) |
| | toc = time.perf_counter() |
| |
|
| | msec = 1e3 * (toc - tic) / num_iters |
| | if world.rank() == 0: |
| | print(f"{msec:.5f} msec") |
| |
|
| |
|
| | def time_all_sum(): |
| | shape = (4096,) |
| | x = mx.random.uniform(shape=shape) |
| | mx.eval(x) |
| |
|
| | def sine(x): |
| | for _ in range(20): |
| | x = mx.sin(x) |
| | return x |
| |
|
| | time_fn(sine, x) |
| |
|
| | def all_sum_plain(x): |
| | for _ in range(20): |
| | x = mx.distributed.all_sum(x) |
| | return x |
| |
|
| | time_fn(all_sum_plain, x) |
| |
|
| | def all_sum_with_sine(x): |
| | for _ in range(20): |
| | x = mx.sin(x) |
| | x = mx.distributed.all_sum(x) |
| | return x |
| |
|
| | time_fn(all_sum_with_sine, x) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | time_all_sum() |
| |
|