File size: 3,788 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright © 2023 Apple Inc.

import unittest

import mlx.core as mx
import mlx_tests


# Don't inherit from MLXTestCase to avoid call to setUp
class TestDefaultDevice(unittest.TestCase):
    """Check the initial default device without going through MLXTestCase.setUp."""

    def test_mlx_default_device(self):
        device = mx.default_device()
        if mx.is_available(mx.gpu):
            # With a GPU available, the default device is expected to be GPU 0.
            self.assertEqual(device, mx.Device(mx.gpu))
            self.assertEqual(str(device), "Device(gpu, 0)")
            # Comparison with the bare device type must work in both directions.
            self.assertEqual(device, mx.gpu)
            self.assertEqual(mx.gpu, device)
        else:
            # Without a GPU the default is the CPU. Compare Device with Device
            # and type with type, rather than a type against a Device.
            self.assertEqual(device, mx.Device(mx.cpu))
            self.assertEqual(device.type, mx.cpu)
            # Selecting an unavailable GPU as the default must be rejected.
            with self.assertRaises(ValueError):
                mx.set_default_device(mx.gpu)


class TestDevice(mlx_tests.MLXTestCase):
    """Tests for selecting the default device and running ops on a device."""

    def test_device(self):
        device = mx.default_device()
        try:
            # Setting the default from an explicit Device object.
            cpu = mx.Device(mx.cpu)
            mx.set_default_device(cpu)
            self.assertEqual(mx.default_device(), cpu)
            self.assertEqual(str(cpu), "Device(cpu, 0)")

            # Setting the default from the bare device type.
            mx.set_default_device(mx.cpu)
            self.assertEqual(mx.default_device(), mx.cpu)
            self.assertEqual(cpu, mx.cpu)
            self.assertEqual(mx.cpu, cpu)
        finally:
            # Restore the device even if an assertion above fails, so the CPU
            # default does not leak into tests that run afterwards.
            mx.set_default_device(device)

    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
    def test_device_context(self):
        default = mx.default_device()
        # Pick whichever device is NOT the current default.
        diff = mx.cpu if default == mx.gpu else mx.gpu
        self.assertNotEqual(default, diff)
        with mx.stream(diff):
            a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
            mx.eval(a)
            # Inside the context the default device is switched...
            self.assertEqual(mx.default_device(), diff)
        # ...and it is restored on exit.
        self.assertEqual(mx.default_device(), default)

    def test_op_on_device(self):
        x = mx.array(1.0)
        y = mx.array(1.0)

        a = mx.add(x, y, stream=None)
        b = mx.add(x, y, stream=mx.default_device())
        self.assertEqual(a.item(), b.item())
        b = mx.add(x, y, stream=mx.cpu)
        self.assertEqual(a.item(), b.item())

        # Use the same backend-agnostic GPU check as the rest of this module
        # (not Metal-specific), so non-Metal GPU backends are also exercised.
        if mx.is_available(mx.gpu):
            b = mx.add(x, y, stream=mx.gpu)
            self.assertEqual(a.item(), b.item())


class TestStream(mlx_tests.MLXTestCase):
    """Tests for default and newly created streams on each device."""

    def test_stream(self):
        # A fresh stream on the default device is distinct from its default
        # stream, and both report the default device.
        default_s = mx.default_stream(mx.default_device())
        self.assertEqual(default_s.device, mx.default_device())

        fresh_s = mx.new_stream(mx.default_device())
        self.assertEqual(fresh_s.device, mx.default_device())
        self.assertNotEqual(default_s, fresh_s)

        has_gpu = mx.is_available(mx.gpu)

        # Default streams: GPU only when a GPU exists, CPU always.
        if not has_gpu:
            with self.assertRaises(ValueError):
                mx.default_stream(mx.gpu)
        else:
            self.assertEqual(mx.default_stream(mx.gpu).device, mx.gpu)

        self.assertEqual(mx.default_stream(mx.cpu).device, mx.cpu)

        # New streams follow the same availability rules.
        self.assertEqual(mx.new_stream(mx.cpu).device, mx.cpu)

        if not has_gpu:
            with self.assertRaises(ValueError):
                mx.new_stream(mx.gpu)
        else:
            self.assertEqual(mx.new_stream(mx.gpu).device, mx.gpu)

    def test_op_on_stream(self):
        lhs = mx.array(1.0)
        rhs = mx.array(1.0)

        reference = mx.add(lhs, rhs, stream=mx.default_stream(mx.default_device()))

        def check_on(stream):
            # The same addition must yield the same scalar on any stream.
            result = mx.add(lhs, rhs, stream=stream)
            self.assertEqual(reference.item(), result.item())

        if mx.is_available(mx.gpu):
            check_on(mx.default_stream(mx.gpu))
            check_on(mx.new_stream(mx.gpu))

        check_on(mx.default_stream(mx.cpu))
        check_on(mx.new_stream(mx.cpu))


if __name__ == "__main__":
    # Run this module's tests through the project's shared MLX test runner.
    run_tests = mlx_tests.MLXTestRunner
    run_tests()