| | |
| |
|
| | import unittest |
| |
|
| | import mlx.core as mx |
| | import mlx_distributed_tests |
| |
|
| |
|
class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
    """Distributed-communication tests pinned to the MPI backend.

    The assertions hard-code a world of 8 processes (see ``test_groups``), so
    the suite has to be launched with exactly 8 ranks.
    """

    @classmethod
    def setUpClass(cls):
        # Initialise the MPI backend once for the whole class. The returned
        # group is not kept: every test obtains it again through
        # ``mx.distributed.init()``.
        # NOTE(review): ``strict=True`` presumably makes initialisation fail
        # loudly when MPI is unavailable instead of falling back to a
        # single-process group -- confirm against the MLX documentation.
        mx.distributed.init(strict=True, backend="mpi")

    def test_groups(self):
        world = mx.distributed.init()
        self.assertEqual(world.size(), 8)
        self.assertTrue(0 <= world.rank() < 8)

        # A second init() must describe the same group (same size and rank).
        world2 = mx.distributed.init()
        self.assertEqual(world.size(), world2.size())
        self.assertEqual(world.rank(), world2.rank())

        # Split by parity: two groups of 4, and the rank inside each group
        # follows the order of the world ranks.
        sub = world.split(world.rank() % 2)
        self.assertEqual(sub.size(), 4)
        self.assertEqual(sub.rank(), world.rank() // 2)

        # Split into consecutive pairs: four groups of 2.
        sub = world.split(world.rank() // 2)
        self.assertEqual(sub.size(), 2)

    def test_all_reduce(self):
        world = mx.distributed.init()
        # (dtype, tolerance) pairs. A tolerance of 0 requires an exact match
        # (absolute error of 0); a positive tolerance is applied to the error
        # relative to the reference value.
        dtypes = [
            (mx.int8, 0),
            (mx.uint8, 0),
            (mx.int16, 0),
            (mx.uint16, 0),
            (mx.int32, 0),
            (mx.uint32, 0),
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
            (mx.complex64, 1e-6),
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        # The same key is used on every rank: the test relies on all ranks
        # generating the identical full array ``x`` so that the local
        # reduction over axis 0 can serve as the reference result.
        key = mx.random.key(0)
        group = world.split(world.rank() % 2)

        for dt, rtol in dtypes:
            for sh in sizes:
                for g in [world, group]:
                    # One slice per rank of the group along the first axis.
                    x = (
                        mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10
                    ).astype(dt)

                    # All sum
                    y = mx.distributed.all_sum(x[g.rank()], group=g)
                    z = x.sum(0)
                    maxrelerror = (y - z).abs()
                    if rtol > 0:
                        maxrelerror /= z.abs()
                    maxrelerror = maxrelerror.max()
                    self.assertLessEqual(maxrelerror, rtol)

                    # All max
                    y = mx.distributed.all_max(x[g.rank()], group=g)
                    z = x.max(0)
                    self.assertTrue(mx.all(y == z))

                    # All min
                    y = mx.distributed.all_min(x[g.rank()], group=g)
                    z = x.min(0)
                    self.assertTrue(mx.all(y == z))

    def test_all_gather(self):
        world = mx.distributed.init()
        dtypes = [
            mx.int8,
            mx.uint8,
            mx.int16,
            mx.uint16,
            mx.int32,
            mx.uint32,
            mx.float32,
            mx.complex64,
        ]
        # Gather over the default group: the result is expected to be
        # world.size() times larger along the first axis.
        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_gather(x)
            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
            self.assertTrue(mx.all(y == 1))

        # Same check restricted to a parity sub-group.
        sub = world.split(world.rank() % 2)
        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_gather(x, group=sub)
            self.assertEqual(y.shape, (sub.size() * 2, 2, 4))
            self.assertTrue(mx.all(y == 1))

    def test_mixed(self):
        # Combine collectives over two different sub-groups of the 8 ranks.
        #
        # Colors passed to split(), per world rank:
        # - world: 0 1 2 3 4 5 6 7
        # - sub_1: 0 0 1 1 2 2 3 3   (rank // 2 -> consecutive pairs)
        # - sub_2: 0 1 0 1 0 1 0 1   (rank % 2  -> even ranks / odd ranks)
        #
        # Summing the world rank over each pair gives 1, 5, 9, 13; gathering
        # those sums over the parity group should stack one row per pair.

        world = mx.distributed.init()
        sub_1 = world.split(world.rank() // 2)
        sub_2 = world.split(world.rank() % 2)

        x = mx.ones((1, 8)) * world.rank()
        y = mx.distributed.all_sum(x, group=sub_1)
        z = mx.distributed.all_gather(y, group=sub_2)
        # Column of per-pair sums; it broadcasts against z in the comparison.
        z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)

        self.assertTrue(mx.all(z == z_target))

    def test_send_recv(self):
        world = mx.distributed.init()
        # Pair up consecutive ranks; inside a pair the neighbor is the other
        # member, and rank 0 of the pair sends first.
        pairs = world.split(world.rank() // 2)
        neighbor = (pairs.rank() + 1) % 2
        send = pairs.rank() == 0

        # Ping-pong: on each step one side sends twice its current value and
        # the other side adopts what it receives, then the roles swap.
        x = mx.ones(10)
        for _ in range(10):
            if send:
                mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
            else:
                x = mx.distributed.recv_like(x, neighbor, group=pairs)
                mx.eval(x)
            send = not send

        # Step k (0-based) transmits 2 ** (k + 1). Pair rank 0 receives last on
        # step 9 (1024); pair rank 1 receives last on step 8 (512).
        self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))

        # Check send/recv and an unrelated computation in the same eval.
        # After an even number of swaps ``send`` is back to its initial value.
        y = mx.ones((5, 5)) + mx.array(2.0)
        if send:
            x = mx.distributed.send(2 * x, neighbor, group=pairs)
        else:
            x = mx.distributed.recv_like(x, neighbor, group=pairs)
        mx.eval(y, x)
| |
|
| |
|
# Run the suite through unittest's runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
| |
|