|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx_distributed_tests |
|
|
import mlx_tests |
|
|
|
|
|
|
|
|
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): |
|
|
@classmethod |
|
|
def setUpClass(cls): |
|
|
world = mx.distributed.init(strict=True, backend="ring") |
|
|
|
|
|
def test_groups(self): |
|
|
world = mx.distributed.init() |
|
|
self.assertEqual(world.size(), 8) |
|
|
self.assertTrue(0 <= world.rank() < 8) |
|
|
|
|
|
world2 = mx.distributed.init() |
|
|
self.assertEqual(world.size(), world2.size()) |
|
|
self.assertEqual(world.rank(), world2.rank()) |
|
|
|
|
|
with self.assertRaises(RuntimeError): |
|
|
sub = world.split(world.rank() % 2) |
|
|
|
|
|
def test_all_reduce(self): |
|
|
world = mx.distributed.init() |
|
|
dtypes = [ |
|
|
(mx.int8, 0), |
|
|
(mx.uint8, 0), |
|
|
(mx.int16, 0), |
|
|
(mx.uint16, 0), |
|
|
(mx.int32, 0), |
|
|
(mx.uint32, 0), |
|
|
(mx.float32, 1e-6), |
|
|
(mx.float16, 5e-3), |
|
|
(mx.bfloat16, 1e-1), |
|
|
(mx.complex64, 1e-6), |
|
|
] |
|
|
sizes = [ |
|
|
(7,), |
|
|
(10,), |
|
|
(1024,), |
|
|
(1024, 1024), |
|
|
] |
|
|
key = mx.random.key(0) |
|
|
reductions = ["min", "max", "sum"] |
|
|
|
|
|
for dt, rtol in dtypes: |
|
|
for sh in sizes: |
|
|
x = ( |
|
|
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10 |
|
|
).astype(dt) |
|
|
|
|
|
|
|
|
y = mx.distributed.all_sum(x[world.rank()]) |
|
|
z = x.sum(0) |
|
|
maxrelerror = (y - z).abs() |
|
|
if rtol > 0: |
|
|
maxrelerror /= z.abs() |
|
|
maxrelerror = maxrelerror.max() |
|
|
self.assertLessEqual(maxrelerror, rtol) |
|
|
|
|
|
|
|
|
y = mx.distributed.all_max(x[world.rank()]) |
|
|
z = x.max(0) |
|
|
self.assertTrue(mx.all(y == z)) |
|
|
|
|
|
|
|
|
y = mx.distributed.all_min(x[world.rank()]) |
|
|
z = x.min(0) |
|
|
self.assertTrue(mx.all(y == z)) |
|
|
|
|
|
def test_all_gather(self): |
|
|
world = mx.distributed.init() |
|
|
dtypes = [ |
|
|
mx.int8, |
|
|
mx.uint8, |
|
|
mx.int16, |
|
|
mx.uint16, |
|
|
mx.int32, |
|
|
mx.uint32, |
|
|
mx.float32, |
|
|
mx.complex64, |
|
|
] |
|
|
for dt in dtypes: |
|
|
x = mx.ones((2, 2, 4), dtype=dt) |
|
|
y = mx.distributed.all_gather(x) |
|
|
self.assertEqual(y.shape, (world.size() * 2, 2, 4)) |
|
|
self.assertTrue(mx.all(y == 1)) |
|
|
|
|
|
def test_send_recv(self): |
|
|
world = mx.distributed.init() |
|
|
dtypes = [ |
|
|
mx.int8, |
|
|
mx.uint8, |
|
|
mx.int16, |
|
|
mx.uint16, |
|
|
mx.int32, |
|
|
mx.uint32, |
|
|
mx.float32, |
|
|
mx.float16, |
|
|
mx.bfloat16, |
|
|
mx.complex64, |
|
|
] |
|
|
sizes = [ |
|
|
(7,), |
|
|
(10,), |
|
|
(1024,), |
|
|
(1024, 1024), |
|
|
] |
|
|
key = mx.random.key(0) |
|
|
right = (world.rank() + 1) % world.size() |
|
|
left = (world.rank() + world.size() - 1) % world.size() |
|
|
for dt in dtypes: |
|
|
for sh in sizes: |
|
|
x = ( |
|
|
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10 |
|
|
).astype(dt) |
|
|
if world.rank() % 2 == 0: |
|
|
y = mx.distributed.send(x[world.rank()], right) |
|
|
z = mx.distributed.recv_like(y, left) |
|
|
mx.eval(y, z) |
|
|
else: |
|
|
z = mx.distributed.recv_like(x[world.rank()], left) |
|
|
y = mx.distributed.send(x[world.rank()], right) |
|
|
mx.eval(z, y) |
|
|
self.assertTrue(mx.all(y == x[world.rank()])) |
|
|
self.assertTrue(mx.all(z == x[left])) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
mlx_tests.MLXTestRunner() |
|
|
|