import math
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np

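# PyTorch serves as the reference implementation; every test below is skipped
# when it is not installed.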
|
| | try: |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | has_torch = True |
| | except ImportError as e: |
| | has_torch = False |
| |
|
| |
|
class TestConvTranspose(mlx_tests.MLXTestCase):
    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D(self):
        def run_conv_transpose_1D(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            output_padding=0,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
                    np_dtype
                )

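                # MLX uses channels-last layouts: inputs are (N, L, C) and weights
                # (O, kL, C); torch.conv_transpose1d expects (N, C, L) inputs and
                # (C_in, C_out, kL) weights, hence the transposes below.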
                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

                out_mx = mx.conv_transpose1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose_1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

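            # Grouped case (currently only groups=1 is exercised)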
            N, C, O = (4, 32, 64)
            for iH, kH, stride, padding in (
                (1, 1, 1, 0),
                (3, 3, 1, 0),
                (31, 5, 5, 2),
            ):
                for group in (1,):
                    run_conv_transpose_1D(
                        N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype
                    )

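        # Strided inputs: transposed (non-contiguous) arrays should match torch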
        for tpose_in, tpose_wt in (
            ((0, 2, 1), (0, 1, 2)),
            ((0, 2, 1), (0, 2, 1)),
        ):
            with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_mx_t = mx.transpose(in_mx, tpose_in)
                wt_mx_t = mx.transpose(wt_mx, tpose_wt)
                out_mx = mx.conv_transpose1d(in_mx_t, wt_mx_t)

                in_pt = torch.from_numpy(in_np.transpose(tpose_in).transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(tpose_wt).transpose(2, 0, 1))

                out_pt = torch.conv_transpose1d(in_pt, wt_pt)
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1D_grad(self):
        def run_conv_transpose1D_grad(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)

                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)).requires_grad_(True)
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)).requires_grad_(True)

                out_pt = F.conv_transpose1d(
                    in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation
                )

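                # Backpropagate a ones cotangent through torch to get reference
                # gradients for the input and the weights.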
                out_pt.retain_grad()
                out_pt.sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 1))

                def f(a, b):
                    return mx.conv_transpose1d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

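                # mx.vjp(f, primals, cotangents) returns (outputs, vjps); only the
                # vjps (gradients w.r.t. input and weights) are needed here.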
                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv_transpose1D_grad(
                        N, C, O, iH, kH, stride, padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D(self):
        def run_conv_transpose2D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
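                # Scale the random inputs with the fan-in so outputs stay O(1)
                # and the atol comparison below remains meaningful.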
                scale = 1.0 / math.sqrt(kH * kW * C)
                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).to("cpu")

                out_mx = mx.conv_transpose2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
                ):
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

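            # Grouped case (currently only groups=1 is exercised)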
            N, C, O = (4, 32, 64)
            for idim, kdim, stride, padding in (
                ((1, 1), (1, 1), (1, 1), (0, 0)),
                ((3, 3), (3, 1), (1, 1), (0, 0)),
                ((31, 31), (5, 5), (5, 5), (2, 2)),
            ):
                for group in (1,):
                    run_conv_transpose2D(
                        N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2D_grad(self):
        def run_conv_transpose2D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C * O)

                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)).requires_grad_(
                    True
                )

                out_pt = F.conv_transpose2d(
                    in_pt, wt_pt, stride=stride, padding=padding, dilation=dilation
                )

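                # Torch reference gradients via a ones cotangent.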
                out_pt.retain_grad()
                out_pt.sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 3, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 1))

                def f(a, b):
                    return mx.conv_transpose2d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                ):
                    run_conv_transpose2D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D(self):
        def run_conv_transpose3D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)
                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))

                out_mx = mx.conv_transpose3d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (2, 8, 16),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
                    ((15, 15, 15), (3, 3, 3), (3, 3, 3), (2, 2, 2)),
                ):
                    run_conv_transpose3D(
                        N, C, O, idim, kdim, stride, padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3D_grad(self):
        def run_conv_transpose3D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1, 1),
            groups=1,
            dtype="float32",
            atol=1e-4,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                scale = 1.0 / math.sqrt(kD * kH * kW * C * O)

                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)).requires_grad_(
                    True
                )
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)).requires_grad_(
                    True
                )

                out_pt = F.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

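                # Torch reference gradients via a ones cotangent.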
                out_pt.retain_grad()
                out_pt.sum().backward()

                pt_grad_in = in_pt.grad.permute(0, 2, 3, 4, 1).numpy()
                pt_grad_wt = wt_pt.grad.permute(1, 2, 3, 4, 0).numpy()

                ct_mx = mx.array(out_pt.grad.numpy().transpose(0, 2, 3, 4, 1))

                def f(a, b):
                    return mx.conv_transpose3d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [in_mx, wt_mx],
                    [ct_mx],
                )

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (2, 4, 8), (2, 8, 16)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
                    ((7, 7, 7), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
                    ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
                    ((7, 7, 7), (5, 5, 5), (3, 3, 3), (2, 2, 2), (3, 2, 2)),
                    ((8, 8, 8), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
                ):
                    run_conv_transpose3D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1d_output_padding(self):
        def run_conv_transpose_1d_output_padding(
            N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))

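                # output_padding grows one side of the output; torch requires it
                # to be smaller than the stride (or the dilation), as in the
                # cases below.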
                out_mx = mx.conv_transpose1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )

                out_pt = torch.conv_transpose1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for iH, kH, stride, padding, output_padding in (
                    (3, 2, 2, 0, 1),
                    (5, 3, 2, 1, 0),
                    (7, 4, 3, 1, 2),
                ):
                    run_conv_transpose_1d_output_padding(
                        N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_2d_output_padding(self):
        def run_conv_transpose_2d_output_padding(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            output_padding,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))

                out_mx = mx.conv_transpose2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )

                out_pt = torch.conv_transpose2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for idim, kdim, stride, padding, output_padding in (
                    ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
                    ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
                    ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
                ):
                    run_conv_transpose_2d_output_padding(
                        N,
                        C,
                        O,
                        idim,
                        kdim,
                        stride,
                        padding,
                        output_padding,
                        dtype=dtype,
                    )

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_3d_output_padding(self):
        def run_conv_transpose_3d_output_padding(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            output_padding,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iD, iH, iW = idim
                kD, kH, kW = kdim
                in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(
                    np_dtype
                )
                wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(
                    np_dtype
                )

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))

                out_mx = mx.conv_transpose3d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.conv_transpose3d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
                for idim, kdim, stride, padding, output_padding in (
                    ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
                    ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
                    ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
                ):
                    run_conv_transpose_3d_output_padding(
                        N,
                        C,
                        O,
                        idim,
                        kdim,
                        stride,
                        padding,
                        output_padding,
                        dtype=dtype,
                    )


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()