compvis / kornia /testing /__init__.py
Dexter's picture
Upload folder using huggingface_hub
36c95ba verified
"""The testing package contains testing-specific utilities."""
import contextlib
import importlib
from abc import ABC, abstractmethod
from copy import deepcopy
from itertools import product
from typing import Any, Optional
import torch
__all__ = ['tensor_to_gradcheck_var', 'create_eye_batch', 'xla_is_available', 'assert_close']
def xla_is_available() -> bool:
    """Return whether `torch_xla` is available in the system.

    Returns:
        ``True`` if an import spec for ``torch_xla`` can be located, ``False`` otherwise.
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    return importlib.util.find_spec("torch_xla") is not None
# TODO: Isn't this function duplicated with eye_like?
def create_eye_batch(batch_size, eye_size, device=None, dtype=None):
    """Create a batch of identity matrices of shape (batch_size, eye_size, eye_size).

    Args:
        batch_size: number of identity matrices in the batch.
        eye_size: number of rows/columns of each identity matrix.
        device: device forwarded to ``torch.eye``.
        dtype: dtype forwarded to ``torch.eye``.

    Returns:
        A single identity matrix expanded (not copied) along a new leading batch dimension.
    """
    identity = torch.eye(eye_size, device=device, dtype=dtype)
    # ``expand`` only broadcasts the leading dimension; every batch entry
    # refers to the same underlying identity matrix.
    return identity.unsqueeze(0).expand(batch_size, eye_size, eye_size)
def create_random_homography(batch_size, eye_size, std_val=1e-3, device=None, dtype=None):
    """Create a batch of random homographies of shape (batch_size, eye_size, eye_size).

    Each matrix is the identity plus noise drawn uniformly from ``[-std_val, std_val]``.

    Args:
        batch_size: number of matrices in the batch.
        eye_size: number of rows/columns of each matrix.
        std_val: half-width of the uniform noise interval.
        device: device on which the result is created (torch default when ``None``).
        dtype: floating point dtype of the result (torch default when ``None``).

    Returns:
        The perturbed identity matrices.
    """
    # ``torch.empty`` replaces the legacy ``torch.FloatTensor`` constructor, which
    # is tied to CPU/float32 and cannot follow the requested device or dtype.
    noise = torch.empty(batch_size, eye_size, eye_size, device=device, dtype=dtype)
    noise.uniform_(-std_val, std_val)
    eye = create_eye_batch(batch_size, eye_size, device=device, dtype=dtype)
    return eye + noise
def tensor_to_gradcheck_var(tensor, dtype=torch.float64, requires_grad=True):
    """Prepare a tensor so it can be fed to a gradient check.

    `gradcheck` needs 64-bit floating point and requires gradient.

    Args:
        tensor: the tensor to convert; its ``requires_grad`` flag is set in place.
        dtype: the type the tensor is converted to.
        requires_grad: value given to the ``requires_grad`` flag.

    Returns:
        The result of converting the flagged tensor to ``dtype``.

    Raises:
        AssertionError: if ``tensor`` is not a tensor.
    """
    if torch.is_tensor(tensor):
        # The flag is set on the input itself before the type conversion.
        flagged = tensor.requires_grad_(requires_grad)
        return flagged.type(dtype)
    raise AssertionError(type(tensor))
def dict_to(data: dict, device: torch.device, dtype: torch.dtype) -> dict:
    """Return a new dict whose tensor values are moved to ``device`` and cast to ``dtype``.

    Args:
        data: mapping whose values may or may not be tensors.
        device: target device for tensor values.
        dtype: target dtype for tensor values.

    Returns:
        A new dict with the same keys; non-tensor values are passed through unchanged.
    """
    return {
        name: (item.to(device, dtype) if isinstance(item, torch.Tensor) else item)
        for name, item in data.items()
    }
def compute_patch_error(x, y, h, w):
    """Compute the mean absolute error between cropped patches of ``x`` and ``y``.

    The last two dimensions of ``|x - y|`` are cropped before averaging.

    Args:
        x: first tensor.
        y: second tensor, subtracted from ``x``.
        h: value used to derive the crop along the second-to-last dimension
            (presumably the patch height -- confirm against callers).
        w: value used to derive the crop along the last dimension
            (presumably the patch width -- confirm against callers).

    Returns:
        The mean of the cropped absolute difference.
    """
    # NOTE: ``-h // 4`` parses as ``(-h) // 4``, i.e. floor division of the
    # negated size, so the crop is only symmetric when the size is a multiple of 4.
    rows = slice(h // 4, -h // 4)
    cols = slice(w // 4, -w // 4)
    abs_diff = (x - y).abs()
    return abs_diff[..., rows, cols].mean()
def check_is_tensor(obj):
    """Validate that the supplied object is a ``torch.Tensor``.

    Args:
        obj: the object to check.

    Raises:
        TypeError: if ``obj`` is not a ``torch.Tensor``.
    """
    if isinstance(obj, torch.Tensor):
        return
    raise TypeError(f"Input type is not a torch.Tensor. Got {type(obj)}")
def create_rectified_fundamental_matrix(batch_size):
    """Create a batch of rectified fundamental matrices of shape Bx3x3.

    Args:
        batch_size: number of copies of the matrix in the batch.

    Returns:
        A tensor holding ``batch_size`` copies of the same fixed 3x3 matrix.
    """
    single = torch.tensor(
        [
            [0.0, 0.0, 0.0],
            [0.0, 0.0, -1.0],
            [0.0, 1.0, 0.0],
        ]
    )
    # ``repeat`` materialises one independent copy per batch entry.
    return single.unsqueeze(0).repeat(batch_size, 1, 1)
def create_random_fundamental_matrix(batch_size, std_val=1e-3):
    """Create a batch of random fundamental matrices of shape Bx3x3.

    The rectified fundamental matrix is multiplied on the left by the transpose
    of one random homography and on the right by another random homography.

    Args:
        batch_size: number of matrices in the batch.
        std_val: noise magnitude forwarded to ``create_random_homography``.

    Returns:
        The batch of products ``H_left^T @ F_rect @ H_right``.
    """
    rectified = create_rectified_fundamental_matrix(batch_size)
    # The left homography is drawn before the right one, keeping the order in
    # which random numbers are consumed.
    left = create_random_homography(batch_size, 3, std_val)
    right = create_random_homography(batch_size, 3, std_val)
    left_transposed = left.transpose(1, 2)
    return torch.matmul(torch.matmul(left_transposed, rectified), right)
class BaseTester(ABC):
    """Abstract interface listing the test methods a concrete tester must implement.

    Every method is abstract, so a subclass can only be instantiated once it
    overrides all of them.
    """

    @abstractmethod
    def test_smoke(self) -> None:
        """Smoke test; must be overridden by subclasses."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_exception(self) -> None:
        """Exception test; must be overridden by subclasses."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_cardinality(self) -> None:
        """Cardinality test; must be overridden by subclasses."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_jit(self) -> None:
        """JIT test; must be overridden by subclasses."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_gradcheck(self) -> None:
        """Gradient-check test; must be overridden by subclasses."""
        raise NotImplementedError("Implement a stupid routine.")

    @abstractmethod
    def test_module(self) -> None:
        """Module test; must be overridden by subclasses."""
        raise NotImplementedError("Implement a stupid routine.")
def cartesian_product_of_parameters(**possible_parameters):
    """Yield one dict per combination in the cartesian product of the given values.

    Args:
        **possible_parameters: for each parameter name, an iterable of candidate values.

    Yields:
        A dict mapping every parameter name to one of its candidate values.
    """
    names = list(possible_parameters)
    candidates = possible_parameters.values()
    for combination in product(*candidates):
        yield dict(zip(names, combination))
def default_with_one_parameter_changed(*, default=None, **possible_parameters):
    """Yield copies of ``default``, each with exactly one parameter overridden.

    For every parameter name and every candidate value of that parameter, a deep
    copy of ``default`` is produced in which only that parameter is replaced.

    Args:
        default: dict with the baseline parameters. ``None`` (the default) is
            treated as an empty dict.
        **possible_parameters: for each parameter name, an iterable of candidate values.

    Yields:
        A new dict equal to ``default`` except for one overridden parameter.

    Raises:
        AssertionError: if ``default`` is not a dict. Because this is a
            generator, the check runs when iteration starts.
    """
    # A ``None`` sentinel avoids sharing one mutable dict object as the default argument.
    if default is None:
        default = {}
    if not isinstance(default, dict):
        raise AssertionError(f"default should be a dict not a {type(default)}")

    for parameter_name, possible_values in possible_parameters.items():
        for v in possible_values:
            # Deep copy so that yielded dicts never alias the baseline or each other.
            param_set = deepcopy(default)
            param_set[parameter_name] = v
            yield param_set
def _get_precision(device: torch.device, dtype: torch.dtype) -> float:
if 'xla' in device.type:
return 1e-2
if dtype == torch.float16:
return 1e-3
return 1e-4
def _get_precision_by_name(
device: torch.device, device_target: str, tol_val: float, tol_val_default: float = 1e-4
) -> float:
if device_target not in ['cpu', 'cuda', 'xla']:
raise ValueError(f"Invalid device name: {device_target}.")
if device_target in device.type:
return tol_val
return tol_val_default
try:
    # torch.testing.assert_close is only available for torch>=1.9
    from torch.testing import assert_close as _assert_close  # type: ignore
except ImportError:
    _assert_close = None  # type: ignore

if _assert_close is not None:
    try:
        # ``torch.testing._core`` is a private module and may be absent from a
        # given torch release. Its absence must not disable the public
        # ``assert_close`` path, so it is imported separately and treated as optional.
        from torch.testing._core import _get_default_tolerance  # type: ignore
    except ImportError:
        _get_default_tolerance = None  # type: ignore

    def assert_close(
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Assert that ``actual`` and ``expected`` are close.

        Thin wrapper around ``torch.testing.assert_close`` that by default ignores
        strides and treats NaNs as equal.

        Args:
            actual: tensor under test.
            expected: reference tensor.
            rtol: relative tolerance; resolved from the inputs when both
                tolerances are omitted.
            atol: absolute tolerance; resolved from the inputs when both
                tolerances are omitted.
            **kwargs: extra keyword arguments forwarded to
                ``torch.testing.assert_close``. ``check_stride`` and ``equal_nan``
                may be passed to override the wrapper defaults.

        Raises:
            AssertionError: propagated from ``torch.testing.assert_close`` when
                the comparison fails.
        """
        if rtol is None and atol is None and _get_default_tolerance is not None:
            # Best effort: if the private helper fails, leave both tolerances as
            # ``None`` and let ``torch.testing.assert_close`` choose its own.
            with contextlib.suppress(Exception):
                rtol, atol = _get_default_tolerance(actual, expected)

        # ``setdefault`` keeps the wrapper defaults while avoiding a duplicate
        # keyword ``TypeError`` when a caller passes these options explicitly.
        kwargs.setdefault("check_stride", False)
        kwargs.setdefault("equal_nan", True)
        return _assert_close(actual, expected, rtol=rtol, atol=atol, **kwargs)

else:
    # Partial backport of torch.testing.assert_close for torch<1.9
    # TODO: remove this branch if kornia relies on torch>=1.9
    from torch.testing import assert_allclose as _assert_allclose

    class UsageError(Exception):
        """Raised when the comparison helper is called with invalid arguments."""

    def assert_close(
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Assert that ``actual`` and ``expected`` are close.

        Backport built on ``torch.testing.assert_allclose``.

        Args:
            actual: tensor under test.
            expected: reference tensor.
            rtol: relative tolerance forwarded to ``assert_allclose``.
            atol: absolute tolerance forwarded to ``assert_allclose``.
            **kwargs: extra keyword arguments forwarded to ``assert_allclose``.

        Raises:
            UsageError: if ``assert_allclose`` raises ``ValueError``.
        """
        try:
            return _assert_allclose(actual, expected, rtol=rtol, atol=atol, **kwargs)
        except ValueError as error:
            raise UsageError(str(error)) from error