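# Tests for bfloat16 support: each op is run through MLX with bfloat16 inputs
# and checked against a NumPy reference (and a PyTorch bfloat16 reference when
# torch is available).
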
import math
import unittest
from itertools import permutations

import mlx.core as mx
import mlx_tests
import numpy as np

try:
    import torch

    has_torch = True
except ImportError as e:
    has_torch = False


class TestBF16(mlx_tests.MLXTestCase):
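    # Helper: apply the per-framework transforms to np_args, run the reference
    # and MLX ops, and assert that the results agree within `atol`.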
    def __test_ops(
        self,
        ref_op,
        mlx_op,
        np_args,
        ref_transform=lambda x: x,
        mlx_transform=lambda x: mx.array(x),
        atol=1e-5,
    ):
        ref_args = map(ref_transform, np_args)
        mlx_args = map(mlx_transform, np_args)

        r_ref = ref_op(*ref_args)
        r_mlx = mlx_op(*mlx_args)

        self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol))

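    # Shared driver: evaluate `op` in MLX bfloat16 and compare it against a
    # NumPy reference (with values round-tripped through bfloat16) and, when
    # torch is installed, a PyTorch bfloat16 reference.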
    def __default_test(
        self,
        op,
        np_args,
        simple_transform=lambda x: x,
        atol_np=1e-3,
        atol_torch=1e-5,
        np_kwargs=dict(),
        mlx_kwargs=dict(),
        torch_kwargs=dict(),
        torch_op=None,
    ):
        with self.subTest(reference="numpy"):

            def np_transform(x):
                x_mx_bf16 = mx.array(x).astype(mx.bfloat16)
                x_mx_fp32 = x_mx_bf16.astype(mx.float32)
                return np.asarray(x_mx_fp32)

            def mlx_fn(*args):
                out_bf16 = getattr(mx, op)(*args, **mlx_kwargs)
                return np.asarray(out_bf16.astype(mx.float32))

            def np_fn(*args):
                out_fp32 = getattr(np, op)(*args, **np_kwargs)
                return np_transform(out_fp32)

            ref_op = np_fn
            mlx_op = mlx_fn

            ref_transform = lambda x: simple_transform(np_transform(x))
            mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16))

            self.__test_ops(
                ref_op,
                mlx_op,
                np_args,
                ref_transform=ref_transform,
                mlx_transform=mlx_transform,
                atol=atol_np,
            )

        if has_torch:
            with self.subTest(reference="torch"):
                torch_op = op if torch_op is None else torch_op

                def torch_fn(*args):
                    out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs)
                    return out_bf16.to(torch.float32).numpy()

                ref_op = torch_fn
                ref_transform = lambda x: simple_transform(
                    torch.from_numpy(x).to(torch.bfloat16)
                )
                self.__test_ops(
                    ref_op,
                    mlx_op,
                    np_args,
                    ref_transform=ref_transform,
                    mlx_transform=mlx_transform,
                    atol=atol_torch,
                )

    def test_unary_ops(self):
        x = np.random.rand(18, 28, 38)
        for op in ["abs", "exp", "log", "square", "sqrt"]:
            with self.subTest(op=op):
                np_args = (x.astype(np.float32),)
                self.__default_test(op, np_args)

    def test_binary_ops(self):
        x = np.random.rand(18, 28, 38)
        y = np.random.rand(18, 28, 38)
        for op in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
            with self.subTest(op=op):
                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                self.__default_test(op, np_args, simple_transform=lambda x: x)
                self.__default_test(op, np_args, simple_transform=lambda x: x[:1])
                self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1])

    def test_reduction_ops(self):
        x = np.random.rand(18, 28, 38).astype(np.float32)

        for op in ("min", "max"):
            with self.subTest(op=op):
                for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
                    with self.subTest(axes=axes):
                        np_args = (x.astype(np.float32),)
                        self.__default_test(
                            op,
                            np_args,
                            np_kwargs={"axis": axes},
                            mlx_kwargs={"axis": axes},
                            torch_kwargs={"dim": axes},
                            torch_op="a" + op,
                        )

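    # Round-trip the input through bfloat16 so NumPy and MLX operate on
    # identical values and agree on the arg indices they pick.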
    def test_arg_reduction_ops(self):
        data = np.random.rand(10, 12, 13).astype(np.float32)
        x = mx.array(data).astype(mx.bfloat16)
        data = np.asarray(x.astype(mx.float32))

        for op in ["argmin", "argmax"]:
            for axis in range(3):
                for kd in [True, False]:
                    a = getattr(mx, op)(x, axis, kd)
                    b = getattr(np, op)(data, axis, keepdims=kd)
                    a = a.astype(mx.float32)
                    self.assertEqual(a.tolist(), b.tolist())

        for op in ["argmin", "argmax"]:
            a = getattr(mx, op)(x, keepdims=True)
            b = getattr(np, op)(data, keepdims=True)
            a = a.astype(mx.float32)
            self.assertEqual(a.tolist(), b.tolist())
            a = getattr(mx, op)(x)
            b = getattr(np, op)(data)
            a = a.astype(mx.float32)
            self.assertEqual(a.item(), b)

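    # bfloat16 matmul is exercised on the GPU backend only.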
    def test_blas_ops(self):
        if mx.default_device() != mx.gpu:
            return

        def test_blas(shape_x, shape_y):
            np.random.seed(42)
            with self.subTest(shape_x=shape_x, shape_y=shape_y):
                x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x)
                y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y)

                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                op = "matmul"

                self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3)

        for shape_x, shape_y in [
            [(32, 32), (32, 32)],
            [(23, 57), (57, 1)],
            [(1, 3), (3, 128)],
            [(8, 128, 768), (768, 16)],
        ]:
            test_blas(shape_x, shape_y)

    @unittest.skipIf(not has_torch, "requires PyTorch")
    def test_conversion(self):
        a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
        a_mx = mx.array(a_torch)
        expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16)
        self.assertEqual(a_mx.dtype, mx.bfloat16)
        self.assertTrue(mx.array_equal(a_mx, expected))


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()