# Copyright © 2023 Apple Inc.

import math
import unittest
from itertools import permutations

import mlx.core as mx
import mlx_tests
import numpy as np

try:
    import torch

    has_torch = True
except ImportError as e:
    has_torch = False

class TestBF16(mlx_tests.MLXTestCase):
    def __test_ops(
        self,
        ref_op,  # Function that outputs array_like
        mlx_op,  # Function that outputs array_like
        np_args,  # Numpy arguments
        ref_transform=lambda x: x,
        mlx_transform=lambda x: mx.array(x),
        atol=1e-5,
    ):
        ref_args = map(ref_transform, np_args)
        mlx_args = map(mlx_transform, np_args)

        r_ref = ref_op(*ref_args)
        r_mlx = mlx_op(*mlx_args)

        self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol))
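    # Runs `op` in MLX bfloat16 and compares the result against a NumPy
    # float32 reference (with data rounded through bfloat16) and, when
    # PyTorch is available, a native torch.bfloat16 reference.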
    def __default_test(
        self,
        op,
        np_args,
        simple_transform=lambda x: x,
        atol_np=1e-3,
        atol_torch=1e-5,
        np_kwargs=dict(),
        mlx_kwargs=dict(),
        torch_kwargs=dict(),
        torch_op=None,
    ):
        with self.subTest(reference="numpy"):

            def np_transform(x):
                x_mx_bf16 = mx.array(x).astype(mx.bfloat16)
                x_mx_fp32 = x_mx_bf16.astype(mx.float32)
                return np.asarray(x_mx_fp32)

            def mlx_fn(*args):
                out_bf16 = getattr(mx, op)(*args, **mlx_kwargs)
                return np.asarray(out_bf16.astype(mx.float32))

            def np_fn(*args):
                out_fp32 = getattr(np, op)(*args, **np_kwargs)
                return np_transform(out_fp32)

            ref_op = np_fn
            mlx_op = mlx_fn
            ref_transform = lambda x: simple_transform(np_transform(x))
            mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16))

            self.__test_ops(
                ref_op,
                mlx_op,
                np_args,
                ref_transform=ref_transform,
                mlx_transform=mlx_transform,
                atol=atol_np,
            )

        if has_torch:
            with self.subTest(reference="torch"):
                torch_op = op if torch_op is None else torch_op

                def torch_fn(*args):
                    out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs)
                    return out_bf16.to(torch.float32).numpy()

                ref_op = torch_fn
                ref_transform = lambda x: simple_transform(
                    torch.from_numpy(x).to(torch.bfloat16)
                )

                self.__test_ops(
                    ref_op,
                    mlx_op,
                    np_args,
                    ref_transform=ref_transform,
                    mlx_transform=mlx_transform,
                    atol=atol_torch,
                )
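    # Elementwise unary ops on a random tensor should match the references
    # within the default tolerances.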
    def test_unary_ops(self):
        x = np.random.rand(18, 28, 38)
        for op in ["abs", "exp", "log", "square", "sqrt"]:
            with self.subTest(op=op):
                np_args = (x.astype(np.float32),)
                self.__default_test(op, np_args)
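    # Elementwise binary ops; the sliced transforms (x[:1], x[:, :1]) also
    # exercise broadcasting between the two operands.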
    def test_binary_ops(self):
        x = np.random.rand(18, 28, 38)
        y = np.random.rand(18, 28, 38)
        for op in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
            with self.subTest(op=op):
                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                self.__default_test(op, np_args, simple_transform=lambda x: x)
                self.__default_test(op, np_args, simple_transform=lambda x: x[:1])
                self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1])
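    # min/max reductions over single and multiple axes; the torch reference
    # uses the amin/amax variants, which take a `dim` argument.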
    def test_reduction_ops(self):
        x = np.random.rand(18, 28, 38).astype(np.float32)

        for op in ("min", "max"):
            with self.subTest(op=op):
                for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
                    with self.subTest(axes=axes):
                        np_args = (x.astype(np.float32),)
                        self.__default_test(
                            op,
                            np_args,
                            np_kwargs={"axis": axes},
                            mlx_kwargs={"axis": axes},
                            torch_kwargs={"dim": axes},
                            torch_op="a" + op,
                        )
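    # argmin/argmax indices computed on bfloat16 inputs should match NumPy
    # on the same data rounded through bfloat16, with and without keepdims.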
    def test_arg_reduction_ops(self):
        data = np.random.rand(10, 12, 13).astype(np.float32)
        x = mx.array(data).astype(mx.bfloat16)
        data = np.asarray(x.astype(mx.float32))

        for op in ["argmin", "argmax"]:
            for axis in range(3):
                for kd in [True, False]:
                    a = getattr(mx, op)(x, axis, kd)
                    b = getattr(np, op)(data, axis, keepdims=kd)
                    a = a.astype(mx.float32)
                    self.assertEqual(a.tolist(), b.tolist())

        for op in ["argmin", "argmax"]:
            a = getattr(mx, op)(x, keepdims=True)
            b = getattr(np, op)(data, keepdims=True)
            a = a.astype(mx.float32)
            self.assertEqual(a.tolist(), b.tolist())

            a = getattr(mx, op)(x)
            b = getattr(np, op)(data)
            a = a.astype(mx.float32)
            self.assertEqual(a.item(), b)
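    # bfloat16 matmul for a few shapes; only checked on the GPU backend
    # (the test returns early on other devices).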
    def test_blas_ops(self):
        if mx.default_device() != mx.gpu:
            return

        def test_blas(shape_x, shape_y):
            np.random.seed(42)
            with self.subTest(shape_x=shape_x, shape_y=shape_y):
                x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x)
                y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y)

                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                op = "matmul"
                self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3)

        for shape_x, shape_y in [
            [(32, 32), (32, 32)],
            [(23, 57), (57, 1)],
            [(1, 3), (3, 128)],
            [(8, 128, 768), (768, 16)],
        ]:
            test_blas(shape_x, shape_y)
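    # A torch.bfloat16 tensor should convert to an mx.bfloat16 array with
    # the same values.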
    @unittest.skipIf(not has_torch, "requires PyTorch")
    def test_conversion(self):
        a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
        a_mx = mx.array(a_torch)
        expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16)
        self.assertEqual(a_mx.dtype, mx.bfloat16)
        self.assertTrue(mx.array_equal(a_mx, expected))

if __name__ == "__main__":
    mlx_tests.MLXTestRunner()