|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx_tests |
|
|
|
|
|
|
|
|
|
|
|
class TestDefaultDevice(unittest.TestCase):
    """Checks on the default device reported by ``mx.default_device()``."""

    def test_mlx_default_device(self):
        """The default device is the GPU when one is available, else the CPU.

        When no GPU is available, asking for the GPU as the default device
        must raise ``ValueError``.
        """
        device = mx.default_device()
        if mx.is_available(mx.gpu):
            self.assertEqual(device, mx.Device(mx.gpu))
            self.assertEqual(str(device), "Device(gpu, 0)")
            # A Device must compare equal to a bare device type in both
            # operand orders.
            self.assertEqual(device, mx.gpu)
            self.assertEqual(mx.gpu, device)
        else:
            # Compare like with like, mirroring the GPU branch: the Device
            # against a Device, and its device type against a device type.
            self.assertEqual(device, mx.Device(mx.cpu))
            self.assertEqual(device.type, mx.cpu)
            with self.assertRaises(ValueError):
                mx.set_default_device(mx.gpu)
|
|
|
|
|
|
|
|
class TestDevice(mlx_tests.MLXTestCase):
    """Setting and querying the default device, and running ops on a device."""

    def test_device(self):
        """``set_default_device`` accepts a ``Device`` or a bare device type.

        The default device in effect when the test starts is restored before
        the test returns, whether or not the assertions pass.
        """
        device = mx.default_device()
        cpu = mx.Device(mx.cpu)
        # try/finally so a failing assertion cannot leave the CPU installed
        # as the process-wide default for whatever runs next.
        try:
            mx.set_default_device(cpu)
            self.assertEqual(mx.default_device(), cpu)
            self.assertEqual(str(cpu), "Device(cpu, 0)")

            mx.set_default_device(mx.cpu)
            self.assertEqual(mx.default_device(), mx.cpu)
            self.assertEqual(cpu, mx.cpu)
            self.assertEqual(mx.cpu, cpu)
        finally:
            mx.set_default_device(device)

    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
    def test_device_context(self):
        """Inside ``mx.stream(d)`` the default device is ``d``; it is restored after."""
        default = mx.default_device()
        # Pick whichever of cpu/gpu is not the current default.
        diff = mx.cpu if default == mx.gpu else mx.gpu
        self.assertNotEqual(default, diff)
        with mx.stream(diff):
            a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
            mx.eval(a)
            self.assertEqual(mx.default_device(), diff)
        self.assertEqual(mx.default_device(), default)

    def test_op_on_device(self):
        """An op gives the same value whichever device its ``stream`` names."""
        x = mx.array(1.0)
        y = mx.array(1.0)

        a = mx.add(x, y, stream=None)
        b = mx.add(x, y, stream=mx.default_device())
        self.assertEqual(a.item(), b.item())
        b = mx.add(x, y, stream=mx.cpu)
        self.assertEqual(a.item(), b.item())

        # Gate on GPU availability the same way as the other GPU checks in
        # this file, rather than through a Metal-specific query.
        if mx.is_available(mx.gpu):
            b = mx.add(x, y, stream=mx.gpu)
            self.assertEqual(a.item(), b.item())
|
|
|
|
|
|
|
|
class TestStream(mlx_tests.MLXTestCase):
    """Creating streams on devices and running ops on them."""

    def test_stream(self):
        """Default and new streams report the device they belong to.

        A GPU stream can be obtained only when a GPU is available; otherwise
        requesting one raises ``ValueError``.
        """
        here = mx.default_device()

        first = mx.default_stream(here)
        self.assertEqual(first.device, here)

        second = mx.new_stream(here)
        self.assertEqual(second.device, here)
        self.assertNotEqual(first, second)

        gpu_ok = mx.is_available(mx.gpu)
        for make_stream in (mx.default_stream, mx.new_stream):
            if gpu_ok:
                self.assertEqual(make_stream(mx.gpu).device, mx.gpu)
            else:
                with self.assertRaises(ValueError):
                    make_stream(mx.gpu)
            self.assertEqual(make_stream(mx.cpu).device, mx.cpu)

    def test_op_on_stream(self):
        """An op gives the same value on default and freshly created streams."""
        x = mx.array(1.0)
        y = mx.array(1.0)
        reference = mx.add(x, y, stream=mx.default_stream(mx.default_device()))

        # GPU first (when present), then CPU; the default stream before a new one.
        targets = [mx.gpu, mx.cpu] if mx.is_available(mx.gpu) else [mx.cpu]
        for target in targets:
            for make_stream in (mx.default_stream, mx.new_stream):
                result = mx.add(x, y, stream=make_stream(target))
                self.assertEqual(reference.item(), result.item())
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests through the project's MLXTestRunner helper.
    mlx_tests.MLXTestRunner()
|
|
|