# Test-suite bootstrap for MLX.
#
# The two MLX_* environment variables below must be assigned BEFORE
# `mlx.core` is imported — presumably the backend reads them at import
# time (TODO confirm), which is why `os` is imported first and the
# remaining imports come after the assignments.
import os

# Disable TF32 math; NOTE(review): presumably this keeps float32 results
# precise enough for the numeric comparisons in these tests — confirm.
os.environ["MLX_ENABLE_TF32"] = "0"

# Disable the cache-thrashing check while tests run.
os.environ["MLX_ENABLE_CACHE_THRASHING_CHECK"] = "0"

import platform
import unittest
from typing import Any, Callable, List, Tuple, Union

import mlx.core as mx
import numpy as np
|
| |
|
class MLXTestRunner(unittest.TestProgram):
    """Test program that filters out known-unsupported tests when the
    CUDA backend is in use (GPU device without Metal available)."""

    # NOTE: a pass-through __init__ that only called super().__init__ was
    # removed — the inherited constructor is used directly.

    def createTests(self, *args, **kwargs):
        """Build the suite, then drop tests listed in ``cuda_skip``.

        Filtering is applied only when the target device is the GPU and
        Metal is unavailable (i.e. the CUDA backend); otherwise the suite
        built by the superclass is left untouched.
        """
        super().createTests(*args, **kwargs)

        # Resolve the target device: honor the DEVICE env var when set
        # (e.g. "cpu" or "gpu"), otherwise use the current default.
        device = os.getenv("DEVICE", None)
        if device is not None:
            device = getattr(mx, device)
        else:
            device = mx.default_device()

        # Only the CUDA backend (GPU without Metal) needs filtering.
        if not (device == mx.gpu and not mx.metal.is_available()):
            return

        # Imported lazily: the skip list is only needed for CUDA runs.
        from cuda_skip import cuda_skip

        filtered_suite = unittest.TestSuite()

        def filter_and_add(t):
            # Recurse into nested suites; add leaf tests unless skipped.
            if isinstance(t, unittest.TestSuite):
                for sub_t in t:
                    filter_and_add(sub_t)
            else:
                # Key on "TestClass.test_method" (the last two components
                # of the test id) to match entries in cuda_skip.
                t_id = ".".join(t.id().split(".")[-2:])
                if t_id in cuda_skip:
                    print(f"Skipping {t_id}")
                else:
                    filtered_suite.addTest(t)

        filter_and_add(self.test)
        self.test = filtered_suite
|
| |
|
class MLXTestCase(unittest.TestCase):
    """Base test case that honors the DEVICE env var and provides
    helpers for comparing MLX results against numpy references."""

    @property
    def is_apple_silicon(self):
        # True on arm64 macOS hosts.
        return platform.machine() == "arm64" and platform.system() == "Darwin"

    def setUp(self):
        # Remember the current default device so tearDown can restore it,
        # then switch to the device named by DEVICE (e.g. "cpu", "gpu").
        self.default = mx.default_device()
        device = os.getenv("DEVICE", None)
        if device is not None:
            device = getattr(mx, device)
            mx.set_default_device(device)

    def tearDown(self):
        # Restore whatever device was the default before setUp ran.
        mx.set_default_device(self.default)

    def assertCmpNumpy(
        self,
        args: List[Union[Tuple[int], Any]],
        mx_fn: Callable[..., mx.array],
        np_fn: Callable[..., np.ndarray],  # fixed: np.array is a function, not a type
        atol=1e-2,
        rtol=1e-2,
        dtype=mx.float32,
        **kwargs,
    ):
        """Run ``mx_fn`` and ``np_fn`` on the same inputs and compare.

        Each entry of ``args`` that is a tuple is treated as a shape and
        replaced by a random normal array of that shape and ``dtype``;
        any other entry is passed through unchanged (mx.array entries are
        converted to numpy for the ``np_fn`` call).

        Raises via assertEqualArray on shape/dtype/value mismatch.
        """
        assert dtype != mx.bfloat16, "numpy does not support bfloat16"
        # Fixed: use builtin `tuple` — `isinstance` with typing.Tuple is
        # deprecated and the builtin is the idiomatic check.
        args = [
            mx.random.normal(s, dtype=dtype) if isinstance(s, tuple) else s
            for s in args
        ]
        mx_res = mx_fn(*args, **kwargs)
        np_res = np_fn(
            *[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs
        )
        return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)

    def assertEqualArray(
        self,
        mx_res: mx.array,
        expected: mx.array,
        atol=1e-2,
        rtol=1e-2,
    ):
        """Assert that two arrays match in shape, dtype, and values.

        NOTE(review): both inputs must expose ``.shape`` and ``.dtype``
        (mx or numpy arrays) — the checks below run before any coercion,
        so plain scalars/lists would fail on attribute access.
        """
        self.assertEqual(
            tuple(mx_res.shape),
            tuple(expected.shape),
            msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
        )
        self.assertEqual(
            mx_res.dtype,
            expected.dtype,
            msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
        )
        if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):
            # Neither side is an mx.array: compare directly with numpy.
            np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
            return
        elif not isinstance(mx_res, mx.array):
            mx_res = mx.array(mx_res)
        elif not isinstance(expected, mx.array):
            expected = mx.array(expected)
        self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
| |
|