| | |
| |
|
| | import itertools |
| | import unittest |
| |
|
| | import mlx.core as mx |
| | import mlx_tests |
| | import numpy as np |
| |
|
try:
    import torch

    has_torch = True
except ImportError:
    # PyTorch is optional: tests that compare against torch are skipped
    # (via ``unittest.skipIf(not has_torch, ...)``) when it is missing.
    has_torch = False
| |
|
| |
|
class TestFFT(mlx_tests.MLXTestCase):
    """Compare ``mx.fft`` routines with NumPy and, when installed, PyTorch."""

    def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
        """Assert that ``op_mx`` agrees with the NumPy reference ``op_np``.

        Args:
            op_mx: mlx function under test.
            op_np: NumPy reference function.
            a_np: NumPy input; ``op_mx`` receives ``mx.array(a_np)``.
            atol: Absolute tolerance for ``np.testing.assert_allclose``.
            rtol: Relative tolerance for ``np.testing.assert_allclose``.
            **kwargs: Keyword arguments forwarded unchanged to both functions.

        Raises:
            AssertionError: If the two outputs are not close.
        """
        out_np = op_np(a_np, **kwargs)
        out_mx = op_mx(mx.array(a_np), **kwargs)
        np.testing.assert_allclose(out_np, out_mx, atol=atol, rtol=rtol)

    @staticmethod
    def _random_complex(*shape):
        """Return ``r + 1j * i`` where ``r`` and ``i`` are float32 ``np.random.rand``."""
        r = np.random.rand(*shape).astype(np.float32)
        i = np.random.rand(*shape).astype(np.float32)
        return r + 1j * i

    @staticmethod
    def _size_atol(num):
        """Absolute tolerance used by the length sweeps for a length-``num`` FFT."""
        # Longer transforms get a looser bound (presumably to absorb accumulated
        # float32 rounding error).
        return 1e-4 if num < 1025 else 1e-3

    def _check_shift(self, mx_op, np_op):
        """Compare a shift op with its NumPy reference over several layouts."""
        # 1-D input, default axes.
        r = np.random.rand(100).astype(np.float32)
        self.check_mx_np(mx_op, np_op, r)

        # Explicit axes on an even-length 2-D input, including a negative axis.
        r = np.random.rand(4, 6).astype(np.float32)
        for axes in ([0], [1], [0, 1], [-1]):
            self.check_mx_np(mx_op, np_op, r, axes=axes)

        # Odd lengths, where fftshift and ifftshift are not the same permutation.
        r = np.random.rand(5, 7).astype(np.float32)
        self.check_mx_np(mx_op, np_op, r)
        self.check_mx_np(mx_op, np_op, r, axes=[0])

        # Complex input.
        self.check_mx_np(mx_op, np_op, self._random_complex(8, 8))

    def test_fft(self):
        """1-D complex and real transforms, with truncation, padding and axes."""
        # Complex input at its natural length.
        a_np = self._random_complex(100)
        self.check_mx_np(mx.fft.fft, np.fft.fft, a_np)

        # Truncation (n < length) and zero-padding (n > length).
        a_np = self._random_complex(100)
        for n in (80, 120):
            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=n)

        # Each axis of a 2-D input.
        a_np = self._random_complex(100, 100)
        for axis in (0, 1):
            self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=axis)

        # Real-input forward transform.
        a_np = np.random.rand(100).astype(np.float32)
        self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np)
        for n in (80, 120):
            self.check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=n)

        # Inverse complex transform of a 2-D input.
        a_np = self._random_complex(100, 100)
        self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np)
        for n in (80, 120):
            self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=n)

        # Inverse real transform of a half spectrum.
        x = np.fft.rfft(np.real(a_np))
        self.check_mx_np(mx.fft.irfft, np.fft.irfft, x)

    def test_fftn(self):
        """2-D and N-D transforms over combinations of axes and output sizes."""
        r = np.random.randn(8, 8, 8).astype(np.float32)
        i = np.random.randn(8, 8, 8).astype(np.float32)
        a = r + 1j * i

        axes = [None, (1, 2), (2, 1), (0, 2)]
        shapes = [None, (10, 5), (5, 10)]
        ops = [
            "fft2",
            "ifft2",
            "rfft2",
            "irfft2",
            "fftn",
            "ifftn",
            "rfftn",
            "irfftn",
        ]
        # Real inverses consume the half spectrum of the matching forward op.
        forward_of = {"irfft2": np.fft.rfft2, "irfftn": np.fft.rfftn}

        for op, ax, s in itertools.product(ops, axes, shapes):
            # An explicit ``s`` is only exercised together with explicit axes.
            if ax is None and s is not None:
                continue
            if op in ("rfft2", "rfftn"):
                x = r
            elif op in forward_of:
                x = np.ascontiguousarray(forward_of[op](r, axes=ax, s=s))
            else:
                x = a
            self.check_mx_np(
                getattr(mx.fft, op), getattr(np.fft, op), x, axes=ax, s=s
            )

    def _run_ffts(self, shape, atol=1e-4, rtol=1e-4):
        """Check fft, ifft, rfft and irfft on seeded random data of ``shape``.

        Reseeds NumPy's global RNG with 9, so every call sees the same data for a
        given shape. Transforms use NumPy's default axis (the last one).
        """
        np.random.seed(9)

        r = np.random.rand(*shape).astype(np.float32)
        i = np.random.rand(*shape).astype(np.float32)
        a_np = r + 1j * i
        tol = dict(atol=atol, rtol=rtol)
        self.check_mx_np(mx.fft.fft, np.fft.fft, a_np, **tol)
        self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, **tol)

        self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, **tol)

        # irfft with the original length and with the default length
        # 2 * (m - 1), which differs from shape[-1] when shape[-1] is odd.
        ia_np = np.fft.rfft(r)
        self.check_mx_np(mx.fft.irfft, np.fft.irfft, ia_np, n=shape[-1], **tol)
        self.check_mx_np(mx.fft.irfft, np.fft.irfft, ia_np, **tol)

    def test_fft_shared_mem(self):
        """Sweep batch sizes over a hand-picked set of transform lengths."""
        nums = np.concatenate(
            [
                # Every length from 2 through 13.
                np.arange(2, 14),
                # Powers of two from 16 through 4096.
                [2**k for k in range(4, 13)],
                # Composites of the small primes 2, 3, 7, 11 and 13.
                [3 * 3 * 3, 3 * 11, 11 * 13 * 2, 7 * 4 * 13 * 11, 13 * 13 * 11],
                # Lengths with a larger prime factor (17, 23, 29, 1153, 991).
                [17, 23, 29, 17 * 8 * 3, 23 * 2, 1153, 1982],
                # More primes and a prime square.
                [47, 83, 17 * 17],
                # Larger smooth composites (3^5*13, 3^6*5, 3^4*7^2, 4*7*11*13).
                [3159, 3645, 3969, 4004],
            ]
        )
        for batch_size in (1, 3, 32):
            for num in nums:
                # subTest names the failing size in the report.
                with self.subTest(batch_size=batch_size, num=int(num)):
                    self._run_ffts((batch_size, num), atol=self._size_atol(num))

    @unittest.skip("Too slow for CI but useful for local testing.")
    def test_fft_exhaustive(self):
        """Check every length from 2 through 4096 at several batch sizes."""
        for batch_size in (1, 3, 32):
            for num in range(2, 4097):
                # subTest reports the failing size instead of printing progress.
                with self.subTest(batch_size=batch_size, num=num):
                    self._run_ffts((batch_size, num), atol=self._size_atol(num))

    def test_fft_big_powers_of_two(self):
        """Power-of-two lengths 2**12 through 2**19."""
        # The largest lengths get a looser absolute tolerance.
        for k in range(12, 17):
            self._run_ffts((3, 2**k), atol=1e-3)

        for k in range(17, 20):
            self._run_ffts((3, 2**k), atol=1e-2)

    def test_fft_large_numbers(self):
        """Single-batch transforms at assorted lengths above 1000."""
        numbers = [
            1037,
            18247,
            1259 * 11,
            7883,
            3**8,
            3109,
            4006,
        ]
        for large_num in numbers:
            self._run_ffts((1, large_num), atol=1e-3)

    def test_fft_contiguity(self):
        """FFTs of strided, broadcast and tiled (non-contiguous) inputs."""
        a_np = self._random_complex(4, 8)
        a_mx = mx.array(a_np)

        # Every other column.
        out_mx = mx.fft.fft(a_mx[:, ::2])
        out_np = np.fft.fft(a_np[:, ::2])
        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)

        # Every other row.
        out_mx = mx.fft.fft(a_mx[::2])
        out_np = np.fft.fft(a_np[::2])
        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)

        # A transposed, reshaped input broadcast along a new last axis.
        out_mx = mx.broadcast_to(mx.reshape(mx.transpose(a_mx), (4, 8, 1)), (4, 8, 16))
        out_np = np.broadcast_to(np.reshape(np.transpose(a_np), (4, 8, 1)), (4, 8, 16))
        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)

        out2_mx = mx.fft.fft(mx.abs(out_mx) + 4)
        out2_np = np.fft.fft(np.abs(out_np) + 4)
        np.testing.assert_allclose(out2_np, out2_mx, atol=1e-5, rtol=1e-5)

        # One row tiled down the first axis.
        b_np = np.array([[0, 1, 2, 3]])
        tiled_mx = mx.tile(mx.reshape(mx.array(b_np), (1, 4)), (4, 1))
        tiled_np = np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))
        out_mx = mx.abs(mx.fft.fft(tiled_mx))
        out_np = np.abs(np.fft.fft(tiled_np))
        np.testing.assert_allclose(out_np, out_mx, atol=1e-5, rtol=1e-5)

    def test_fft_into_ifft(self):
        """Feed an mlx FFT result straight into ``mx.fft.ifft`` (length 8193)."""
        n_fft = 8193
        mx.random.seed(0)

        segment = mx.random.normal(shape=[1, n_fft]) + 1j * mx.random.normal(
            shape=(1, n_fft)
        )
        segment = mx.fft.fft(segment, n=n_fft)
        r = mx.fft.ifft(segment, n=n_fft)
        # NumPy converts the mlx array itself for the reference inverse.
        r_np = np.fft.ifft(segment, n=n_fft)
        # assert_allclose (unlike assertTrue(np.allclose(...))) reports mismatches.
        np.testing.assert_allclose(r, r_np, atol=1e-5, rtol=1e-5)

    def test_fft_throws(self):
        """``mx.fft.irfftn`` on a 0-d array raises ValueError."""
        x = mx.array(3.0)
        with self.assertRaises(ValueError):
            mx.fft.irfftn(x)

    def test_fftshift(self):
        """``mx.fft.fftshift`` matches ``np.fft.fftshift``."""
        self._check_shift(mx.fft.fftshift, np.fft.fftshift)

    def test_ifftshift(self):
        """``mx.fft.ifftshift`` matches ``np.fft.ifftshift``."""
        self._check_shift(mx.fft.ifftshift, np.fft.ifftshift)

    def test_fftshift_errors(self):
        """Out-of-range axes raise ValueError; an empty array passes through."""
        x = mx.array(np.random.rand(4, 4).astype(np.float32))
        for bad_axes in ([2], [-3]):
            with self.assertRaises(ValueError):
                mx.fft.fftshift(x, axes=bad_axes)

        x = mx.array([])
        self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))

    @unittest.skipIf(not has_torch, "requires PyTorch")
    def test_fft_grads(self):
        """Compare values and gradients of ``sum(|fftn(x)|**2)`` with torch.fft."""
        real = [True, False]
        inverse = [True, False]
        axes = [
            (-1,),
            (-2, -1),
        ]
        shapes = [
            (4, 4),
            (2, 4),
            (2, 7),
            (7, 7),
        ]

        # Both tables are keyed by (real, inverse).
        mxffts = {
            (True, True): mx.fft.irfftn,
            (True, False): mx.fft.rfftn,
            (False, True): mx.fft.ifftn,
            (False, False): mx.fft.fftn,
        }
        tffts = {
            (True, True): torch.fft.irfftn,
            (True, False): torch.fft.rfftn,
            (False, True): torch.fft.ifftn,
            (False, False): torch.fft.fftn,
        }

        for r, i, ax, sh in itertools.product(real, inverse, axes, shapes):
            # ``ax`` must reach both transforms, otherwise the axes sweep runs
            # the same default-axes transform twice. The op and axes are bound
            # as defaults so each closure captures this iteration's values.
            def f(x, op=mxffts[r, i], ax=ax):
                return (mx.abs(op(x, axes=ax)) ** 2).sum()

            def g(x, op=tffts[r, i], dim=ax):
                return (torch.abs(op(x, dim=dim)) ** 2).sum()

            # Only the forward real transform takes real input; the others get
            # complex input made by viewing trailing float pairs as complex64.
            if r and not i:
                x = mx.random.normal(sh)
            else:
                x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze()
            fx = f(x)
            gx = g(torch.tensor(x))
            self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4)

            dfdx = mx.grad(f)(x)
            dgdx = torch.func.grad(g)(torch.tensor(x))
            self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4)
| |
|
| |
|
if __name__ == "__main__":
    # Only when executed as a script: start the runner provided by mlx_tests.
    mlx_tests.MLXTestRunner()
| |
|