File size: 3,347 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright © 2023-2024 Apple Inc.

import unittest

import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np

try:
    import torch
    import torch.nn.functional as F

    has_torch = True
except ImportError as e:
    has_torch = False


class TestUpsample(mlx_tests.MLXTestCase):
    """Compare ``mlx.nn.Upsample`` against ``torch.nn.functional.interpolate``."""

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_upsample(self):
        """Check output shape and values of nn.Upsample against PyTorch.

        Covers the "nearest", "linear" and "cubic" modes, with and without
        ``align_corners`` (except for "nearest"), over several input sizes
        and scale factors.
        """
        # MLX upsample mode -> equivalent PyTorch 2-D interpolation mode.
        mode_to_torch = {
            "nearest": "nearest",
            "linear": "bilinear",
            "cubic": "bicubic",
        }

        def run_upsample(
            N,
            C,
            idim,
            scale_factor,
            mode,
            align_corner,
            dtype="float32",
            atol=1e-5,
        ):
            """Upsample one random (N, iH, iW, C) input in both libraries and compare."""
            with self.subTest(
                N=N,
                C=C,
                idim=idim,
                scale_factor=scale_factor,
                mode=mode,
                align_corner=align_corner,
                dtype=dtype,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype)

                # The MLX input is built as (N, H, W, C); it is transposed to
                # (N, C, H, W) for PyTorch. torch.from_numpy already yields a
                # CPU tensor, so no device move is needed.
                in_mx = mx.array(in_np)
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))

                out_mx = nn.Upsample(
                    scale_factor=scale_factor,
                    mode=mode,
                    align_corners=align_corner,
                )(in_mx)
                out_pt = F.interpolate(
                    in_pt,
                    scale_factor=scale_factor,
                    mode=mode_to_torch[mode],
                    # PyTorch only accepts align_corners for interpolating
                    # modes, so it must be None for "nearest".
                    align_corners=align_corner if mode != "nearest" else None,
                )
                # Back to (N, H, W, C) so both outputs share one layout.
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
                # Convert explicitly so shape is a tuple and values are evaluated.
                out_mx_np = np.array(out_mx)

                self.assertEqual(out_pt.shape, out_mx_np.shape)
                # Same criterion as np.allclose(out_pt, out_mx, atol=atol), but
                # reports the mismatching elements on failure.
                np.testing.assert_allclose(
                    out_pt, out_mx_np, rtol=1e-5, atol=atol, equal_nan=False
                )

        for dtype in ("float32",):
            for N, C in ((1, 1), (2, 3)):
                # Only test cases in which target sizes are integers;
                # otherwise mlx and torch differ numerically because they
                # select source indices differently.
                for idim, scale_factor in (
                    ((2, 2), (1.0, 1.0)),
                    ((2, 2), (1.5, 1.5)),
                    ((2, 2), (2.0, 2.0)),
                    ((4, 4), (0.5, 0.5)),
                    ((7, 7), (2.0, 2.0)),
                    ((10, 10), (0.2, 0.2)),
                    ((10, 10), (0.3, 0.3)),
                    ((11, 21), (3.0, 3.0)),
                    ((11, 21), (3.0, 2.0)),
                ):
                    for mode in ("cubic", "linear", "nearest"):
                        for align_corner in (False, True):
                            # align_corners is meaningless for "nearest".
                            if mode == "nearest" and align_corner:
                                continue
                            run_upsample(
                                N,
                                C,
                                idim,
                                scale_factor,
                                mode,
                                align_corner,
                                dtype=dtype,
                            )


if __name__ == "__main__":
    # When executed as a script, run this module's tests via the
    # MLXTestRunner helper from the project-local mlx_tests module.
    mlx_tests.MLXTestRunner()