File size: 3,622 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright © 2023 Apple Inc.

import os

# Use regular fp32 precision for tests (MLX_ENABLE_TF32=0 turns TF32 off).
# NOTE(review): both variables are assigned before `import mlx.core` further
# down, presumably so MLX observes them when it initializes — confirm before
# moving these lines below that import.
os.environ["MLX_ENABLE_TF32"] = "0"

# Do not abort on cache thrashing
os.environ["MLX_ENABLE_CACHE_THRASHING_CHECK"] = "0"

import platform
import unittest
from typing import Any, Callable, List, Tuple, Union

import mlx.core as mx
import numpy as np


class MLXTestRunner(unittest.TestProgram):
    """``unittest`` entry point that can drop tests listed in ``cuda_skip``.

    Filtering only happens when the selected device is the GPU while the
    Metal backend is unavailable (assumed to mean the CUDA backend); in every
    other configuration the discovered suite is run unchanged.
    """

    def createTests(self, *args, **kwargs):
        """Build ``self.test`` as usual, then filter it for non-Metal GPU runs.

        The device is taken from the ``DEVICE`` environment variable (looked
        up as an attribute of ``mlx.core``, e.g. ``"gpu"``) and falls back to
        ``mx.default_device()`` when the variable is unset.
        """
        super().createTests(*args, **kwargs)

        device = os.getenv("DEVICE", None)
        if device is not None:
            device = getattr(mx, device)
        else:
            device = mx.default_device()

        # GPU without Metal: assume the CUDA backend, which needs the skip list.
        if not (device == mx.gpu and not mx.metal.is_available()):
            return

        # Imported lazily: the skip list is only needed on this path.
        from cuda_skip import cuda_skip

        self.test = self._filter_suite(self.test, cuda_skip)

    @staticmethod
    def _filter_suite(suite, skip):
        """Return a flat ``TestSuite`` with the tests of *suite* not in *skip*.

        Nested suites are flattened depth-first, preserving order. A test is
        identified by the last two dot-separated components of its ``id()``
        (``"ClassName.test_method"``); tests whose identifier is contained in
        *skip* are reported on stdout and left out.
        """
        filtered = unittest.TestSuite()

        def visit(node):
            if isinstance(node, unittest.TestSuite):
                for child in node:
                    visit(child)
                return
            test_id = ".".join(node.id().split(".")[-2:])
            if test_id in skip:
                print(f"Skipping {test_id}")
            else:
                filtered.addTest(node)

        visit(suite)
        return filtered


class MLXTestCase(unittest.TestCase):
    """Base test case that can pin the MLX default device and compare arrays."""

    @property
    def is_apple_silicon(self):
        """True when running on an arm64 Darwin machine."""
        return platform.machine() == "arm64" and platform.system() == "Darwin"

    def setUp(self):
        """Remember the default device and switch to ``mx.<DEVICE>`` if set."""
        self.default = mx.default_device()
        device = os.getenv("DEVICE", None)
        if device is not None:
            device = getattr(mx, device)
            mx.set_default_device(device)

    def tearDown(self):
        """Restore the default device recorded in ``setUp``."""
        mx.set_default_device(self.default)

    def assertCmpNumpy(
        self,
        args: List[Union[Tuple[int], Any]],
        mx_fn: Callable[..., mx.array],
        np_fn: Callable[..., np.ndarray],
        atol=1e-2,
        rtol=1e-2,
        dtype=mx.float32,
        **kwargs,
    ):
        """Check that *mx_fn* and *np_fn* agree on the same inputs.

        Every tuple in *args* is treated as a shape request and is converted
        to ``mx.random.normal(shape, dtype=dtype)``; other entries are passed
        through unchanged. ``mx.array`` inputs are converted with
        ``np.array`` before calling *np_fn*, and *kwargs* are forwarded to
        both functions. The two results are compared with
        ``assertEqualArray`` using *atol* and *rtol*.
        """
        # NumPy has no bfloat16 (see message). Checked with a unittest
        # assertion rather than ``assert`` so it is not stripped under -O.
        self.assertNotEqual(dtype, mx.bfloat16, "numpy does not support bfloat16")
        inputs = [
            mx.random.normal(s, dtype=dtype) if isinstance(s, tuple) else s
            for s in args
        ]
        mx_res = mx_fn(*inputs, **kwargs)
        np_res = np_fn(
            *[np.array(a) if isinstance(a, mx.array) else a for a in inputs],
            **kwargs,
        )
        return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)

    def assertEqualArray(
        self,
        mx_res: mx.array,
        expected: mx.array,
        atol=1e-2,
        rtol=1e-2,
    ):
        """Assert that *mx_res* matches *expected* in shape, dtype and values.

        If exactly one argument is an ``mx.array``, the other one is first
        converted with ``mx.array(...)`` so that shapes and dtypes are
        compared within MLX. If neither is an ``mx.array``, shape and dtype
        are compared on the objects as given and the values are checked with
        ``np.testing.assert_allclose``; otherwise ``mx.allclose`` is used.
        """
        res_is_mx = isinstance(mx_res, mx.array)
        exp_is_mx = isinstance(expected, mx.array)
        # Normalize mixed inputs BEFORE the shape/dtype checks: otherwise a
        # NumPy dtype is compared with an MLX dtype (not expected to ever
        # compare equal) and objects without ``.shape`` raise AttributeError.
        if exp_is_mx and not res_is_mx:
            mx_res = mx.array(mx_res)
        elif res_is_mx and not exp_is_mx:
            expected = mx.array(expected)

        self.assertEqual(
            tuple(mx_res.shape),
            tuple(expected.shape),
            msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
        )
        self.assertEqual(
            mx_res.dtype,
            expected.dtype,
            msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
        )
        if not (res_is_mx or exp_is_mx):
            np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
            return
        self.assertTrue(
            mx.allclose(mx_res, expected, rtol=rtol, atol=atol),
            msg=f"values not close (atol={atol}, rtol={rtol})",
        )