File size: 6,166 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""The testing package contains testing-specific utilities."""
import contextlib
import importlib
from abc import ABC, abstractmethod
from copy import deepcopy
from itertools import product
from typing import Any, Optional

import torch

__all__ = [
    'tensor_to_gradcheck_var',
    'create_eye_batch',
    'xla_is_available',
    'assert_close',
]


def xla_is_available() -> bool:
    """Return whether `torch_xla` is available in the system.

    Only checks that the ``torch_xla`` package can be located on the import
    path; the package itself is not imported.
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule is loaded, so import it explicitly before calling into it.
    import importlib.util

    return importlib.util.find_spec("torch_xla") is not None


# TODO: Isn't this function duplicated with eye_like?
def create_eye_batch(batch_size, eye_size, device=None, dtype=None):
    """Create a batch of identity matrices of shape BxNxN, with N = ``eye_size``.

    Args:
        batch_size: number of matrices B in the batch.
        eye_size: side length N of each identity matrix.
        device: device of the result; ``None`` lets ``torch.eye`` choose.
        dtype: dtype of the result; ``None`` lets ``torch.eye`` choose.

    Returns:
        A tensor of shape ``(batch_size, eye_size, eye_size)``. It is an expanded
        view of a single identity matrix, so every batch entry shares the same
        memory; call ``.clone()`` before modifying it in place.
    """
    eye = torch.eye(eye_size, device=device, dtype=dtype)
    # Add a leading batch axis of size 1, then broadcast it without copying.
    return eye.unsqueeze(0).expand(batch_size, -1, -1)


def create_random_homography(batch_size, eye_size, std_val=1e-3):
    """Create a batch of random near-identity homographies of shape BxNxN.

    Each matrix is the N x N identity (N = ``eye_size``) plus independent noise
    drawn uniformly from ``[-std_val, std_val)``. Despite its name, ``std_val``
    is the half-width of that uniform range, not a standard deviation.

    Args:
        batch_size: number of matrices B in the batch.
        eye_size: side length N of each matrix.
        std_val: half-width of the uniform perturbation range.

    Returns:
        A tensor of shape ``(batch_size, eye_size, eye_size)``. With torch's
        default float32 dtype the result is float32.
    """
    # ``torch.empty`` replaces the legacy ``torch.FloatTensor(*sizes)`` constructor;
    # the uninitialised storage is fully overwritten by ``uniform_`` immediately.
    noise = torch.empty(batch_size, eye_size, eye_size, dtype=torch.float32).uniform_(-std_val, std_val)
    # The (N, N) identity broadcasts against the (B, N, N) noise.
    return torch.eye(eye_size) + noise


def tensor_to_gradcheck_var(tensor, dtype=torch.float64, requires_grad=True):
    """Prepare ``tensor`` for use as an input to ``torch.autograd.gradcheck``.

    ``gradcheck`` wants double-precision inputs that require gradients, which is
    what the defaults produce.

    Note: ``requires_grad`` is applied in place to the *input* tensor before the
    dtype conversion, so the caller's tensor is modified as a side effect.

    Raises:
        AssertionError: if ``tensor`` is not a tensor (the message is its type).
    """
    if torch.is_tensor(tensor):
        tracked = tensor.requires_grad_(requires_grad)
        return tracked.type(dtype)
    raise AssertionError(type(tensor))


def dict_to(data: dict, device: torch.device, dtype: torch.dtype) -> dict:
    """Return a shallow copy of ``data`` with every tensor value moved and cast.

    Tensor values are converted with ``.to(device, dtype)``; all other values are
    carried over unchanged. The input dict itself is not modified.
    """
    return {
        key: (val.to(device, dtype) if isinstance(val, torch.Tensor) else val)
        for key, val in data.items()
    }


def compute_patch_error(x, y, h, w):
    """Return the mean absolute difference of ``x`` and ``y`` over a central crop.

    The crop acts on the last two dimensions and drops roughly a quarter of the
    ``h`` / ``w`` extent from each side. The trailing bounds are ``(-h) // 4``
    and ``(-w) // 4`` (floor division of the negated size, which is how Python
    parses ``-h // 4``). When a size is not divisible by 4, this trims one more
    row/column at the end than at the start.
    """
    row_lo, row_hi = h // 4, (-h) // 4
    col_lo, col_hi = w // 4, (-w) // 4
    abs_diff = (x - y).abs()
    return abs_diff[..., row_lo:row_hi, col_lo:col_hi].mean()


def check_is_tensor(obj):
    """Raise ``TypeError`` unless ``obj`` is a ``torch.Tensor``; otherwise return ``None``."""
    if isinstance(obj, torch.Tensor):
        return
    raise TypeError(f"Input type is not a torch.Tensor. Got {type(obj)}")


def create_rectified_fundamental_matrix(batch_size):
    """Return ``batch_size`` copies of a fixed rectified fundamental matrix, shape Bx3x3.

    Every entry equals ``[[0, 0, 0], [0, 0, -1], [0, 1, 0]]``. The batch is built
    with ``repeat``, so entries are independent copies, not a shared view.
    """
    rows = [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, -1.0],
        [0.0, 1.0, 0.0],
    ]
    single = torch.tensor(rows).unsqueeze(0)
    return single.repeat(batch_size, 1, 1)


def create_random_fundamental_matrix(batch_size, std_val=1e-3):
    """Create a batch of random fundamental matrices of shape Bx3x3.

    Computes ``H_left^T @ F_rect @ H_right``, where ``F_rect`` comes from
    ``create_rectified_fundamental_matrix``. ``H_left`` and ``H_right`` are
    independent random 3x3 homographies from ``create_random_homography``,
    perturbed by ``std_val``.
    """
    rectified = create_rectified_fundamental_matrix(batch_size)
    # The left homography is sampled first to keep the RNG draw order stable.
    left = create_random_homography(batch_size, 3, std_val)
    right = create_random_homography(batch_size, 3, std_val)
    left_t = left.transpose(1, 2)
    return torch.matmul(torch.matmul(left_t, rectified), right)


class BaseTester(ABC):
    """Abstract base listing the test hooks a concrete tester class must define.

    All six methods are abstract, so ``ABC`` refuses to instantiate any subclass
    that leaves one of them unimplemented. Calling a base implementation directly
    raises ``NotImplementedError``.
    """

    @abstractmethod
    def test_smoke(self):
        """Smoke-test hook; must be overridden."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_exception(self):
        """Exception-test hook; must be overridden."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_cardinality(self):
        """Cardinality-test hook; must be overridden."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_jit(self):
        """JIT-test hook; must be overridden."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_gradcheck(self):
        """Gradcheck-test hook; must be overridden."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_module(self):
        """Module-test hook; must be overridden."""
        raise NotImplementedError("Implement a stupid routine.")


def cartesian_product_of_parameters(**possible_parameters):
    """Yield one ``{name: value}`` dict per combination of the given value lists.

    Keyword order fixes both the key order of every yielded dict and the
    iteration order: the last keyword varies fastest, as in ``itertools.product``.
    With no keywords, a single empty dict is yielded.
    """
    names = tuple(possible_parameters)
    value_lists = (possible_parameters[name] for name in names)
    for combination in product(*value_lists):
        yield dict(zip(names, combination))


def default_with_one_parameter_changed(*, default=None, **possible_parameters):
    """Yield variants of ``default`` in which exactly one parameter is overridden.

    For every keyword ``name=values`` and every ``v`` in ``values``, a deep copy of
    ``default`` is yielded with ``name`` set to ``v``. Variants come out in keyword
    order, then value order. ``default`` itself is never modified.

    Args:
        default: base parameter dict; ``None`` (the default) means an empty dict.
        **possible_parameters: parameter name -> iterable of values to try.

    Raises:
        AssertionError: if ``default`` is not a dict. Because this is a generator,
            the error surfaces on the first ``next()`` rather than at call time.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default argument.
    if default is None:
        default = {}
    if not isinstance(default, dict):
        raise AssertionError(f"default should be a dict not a {type(default)}")

    for parameter_name, possible_values in possible_parameters.items():
        for value in possible_values:
            # Deep copy so nested values in ``default`` are never shared or mutated.
            param_set = deepcopy(default)
            param_set[parameter_name] = value
            yield param_set


def _get_precision(device: torch.device, dtype: torch.dtype) -> float:
    if 'xla' in device.type:
        return 1e-2
    if dtype == torch.float16:
        return 1e-3
    return 1e-4


def _get_precision_by_name(
    device: torch.device, device_target: str, tol_val: float, tol_val_default: float = 1e-4
) -> float:
    if device_target not in ['cpu', 'cuda', 'xla']:
        raise ValueError(f"Invalid device name: {device_target}.")

    if device_target in device.type:
        return tol_val

    return tol_val_default


try:
    # torch.testing.assert_close is only available for torch>=1.9
    from torch.testing import assert_close as _assert_close  # type: ignore

    # NOTE(review): ``torch.testing._core`` is a private module. If it is missing
    # from the installed torch, the ImportError sends us to the fallback branch
    # below even though ``assert_close`` itself imported fine — confirm intended.
    from torch.testing._core import _get_default_tolerance  # type: ignore

    def assert_close(
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Assert that ``actual`` and ``expected`` are close (wraps ``torch.testing.assert_close``).

        When neither ``rtol`` nor ``atol`` is given, tolerances come from torch's
        private ``_get_default_tolerance`` helper. If that helper raises, both stay
        ``None`` and ``torch.testing.assert_close`` applies its own defaults.

        By default strides are not compared and NaNs compare equal. Pass
        ``check_stride=`` or ``equal_nan=`` explicitly to override. All other
        keyword arguments are forwarded unchanged.
        """
        if rtol is None and atol is None:
            with contextlib.suppress(Exception):
                rtol, atol = _get_default_tolerance(actual, expected)

        # Defaults rather than hard-coded arguments: hard-coding them made an
        # explicit ``check_stride=`` / ``equal_nan=`` from the caller raise
        # "got multiple values for keyword argument".
        kwargs.setdefault("check_stride", False)
        kwargs.setdefault("equal_nan", True)
        return _assert_close(actual, expected, rtol=rtol, atol=atol, **kwargs)

except ImportError:
    # Partial backport of torch.testing.assert_close for torch<1.9
    # TODO: remove this branch if kornia relies on torch>=1.9
    from torch.testing import assert_allclose as _assert_allclose

    class UsageError(Exception):
        """Raised by the fallback ``assert_close`` when ``assert_allclose`` rejects its arguments."""

    def assert_close(
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Assert closeness via ``torch.testing.assert_allclose`` (fallback implementation).

        ``rtol``, ``atol`` and any extra keyword arguments are forwarded as-is.
        A ``ValueError`` from ``assert_allclose`` is re-raised as ``UsageError``,
        chained to the original error.
        """
        try:
            return _assert_allclose(actual, expected, rtol=rtol, atol=atol, **kwargs)
        except ValueError as error:
            raise UsageError(str(error)) from error