|
|
from typing import Any, Dict, Optional, Tuple, Type |
|
|
|
|
|
import pytest |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.autograd import gradcheck |
|
|
|
|
|
import kornia |
|
|
import kornia.testing as utils |
|
|
from kornia.augmentation import ( |
|
|
CenterCrop, |
|
|
ColorJitter, |
|
|
Denormalize, |
|
|
Normalize, |
|
|
PadTo, |
|
|
RandomBoxBlur, |
|
|
RandomChannelShuffle, |
|
|
RandomCrop, |
|
|
RandomElasticTransform, |
|
|
RandomEqualize, |
|
|
RandomErasing, |
|
|
RandomFisheye, |
|
|
RandomGaussianBlur, |
|
|
RandomGaussianNoise, |
|
|
RandomGrayscale, |
|
|
RandomHorizontalFlip, |
|
|
RandomInvert, |
|
|
RandomResizedCrop, |
|
|
RandomRotation, |
|
|
RandomThinPlateSpline, |
|
|
RandomVerticalFlip, |
|
|
) |
|
|
from kornia.augmentation.base import AugmentationBase2D |
|
|
from kornia.constants import pi, Resample |
|
|
from kornia.geometry import transform_points |
|
|
from kornia.testing import assert_close, BaseTester, default_with_one_parameter_changed |
|
|
from kornia.utils import create_meshgrid |
|
|
from kornia.utils.helpers import _torch_inverse_cast |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.usefixtures("device", "dtype")
class CommonTests(BaseTester):
    """Reusable test-suite for 2D augmentations deriving from ``AugmentationBase2D``.

    Subclasses set ``_augmentation_cls`` (the augmentation under test) and
    ``_default_param_set`` (constructor kwargs used by the parameter-agnostic
    tests), override the ``param_set`` fixture, and replace the tests below that
    raise ``NotImplementedError``. The ``_test_*_implementation`` helpers contain
    the shared assertion logic and can be called with augmentation-specific data.
    """

    # Fixtures copied onto ``self`` by ``auto_injector_fixture`` so that the helpers
    # can read ``self.device`` / ``self.dtype``.
    fixture_names = ("device", "dtype")

    # Augmentation class under test; every subclass must set it.
    _augmentation_cls: Optional[Type[AugmentationBase2D]] = None

    # Constructor kwargs for the tests that do not vary parameters. Subclasses
    # rebind this attribute; nothing in this class mutates it.
    _default_param_set: Dict[str, Any] = {}

    @pytest.fixture(autouse=True)
    def auto_injector_fixture(self, request):
        """Store the value of every fixture named in ``fixture_names`` on ``self``."""
        for fixture_name in self.fixture_names:
            setattr(self, fixture_name, request.getfixturevalue(fixture_name))

    @pytest.fixture(scope="class")
    def param_set(self, request):
        """Constructor kwargs fed to ``test_smoke``; subclasses must override this fixture."""
        raise NotImplementedError("param_set must be overridden in subclasses")

    # ------------------------------------------------------------------ tests

    def test_smoke(self, param_set):
        self._test_smoke_implementation(params=param_set)
        self._test_smoke_call_implementation(params=param_set)
        self._test_smoke_return_transform_implementation(params=param_set)

    @pytest.mark.parametrize(
        "input_shape,expected_output_shape",
        [((4, 5), (1, 1, 4, 5)), ((3, 4, 5), (1, 3, 4, 5)), ((2, 3, 4, 5), (2, 3, 4, 5))],
    )
    def test_cardinality(self, input_shape, expected_output_shape):
        self._test_cardinality_implementation(
            input_shape=input_shape, expected_output_shape=expected_output_shape, params=self._default_param_set
        )

    def test_random_p_0(self):
        self._test_random_p_0_implementation(params=self._default_param_set)

    def test_random_p_0_return_transform(self):
        self._test_random_p_0_return_transform_implementation(params=self._default_param_set)

    def test_random_p_1(self):
        # Needs augmentation-specific expected values, so subclasses must override it.
        raise NotImplementedError("Implement a stupid routine.")

    def test_random_p_1_return_transform(self):
        raise NotImplementedError("Implement a stupid routine.")

    def test_inverse_coordinate_check(self):
        self._test_inverse_coordinate_check_implementation(params=self._default_param_set)

    def test_exception(self):
        raise NotImplementedError("Implement a stupid routine.")

    def test_batch(self):
        raise NotImplementedError("Implement a stupid routine.")

    @pytest.mark.skip(reason="turn off all jit for a while")
    def test_jit(self):
        raise NotImplementedError("Implement a stupid routine.")

    def test_module(self):
        self._test_module_implementation(params=self._default_param_set)

    def test_gradcheck(self):
        self._test_gradcheck_implementation(params=self._default_param_set)

    # ---------------------------------------------------------------- helpers

    def _create_augmentation_from_params(self, **params):
        """Instantiate ``_augmentation_cls`` with ``params`` as keyword arguments.

        Raises:
            TypeError: if the subclass has not set ``_augmentation_cls``.
        """
        if self._augmentation_cls is None:
            # Fail with an actionable message instead of "'NoneType' object is not callable".
            raise TypeError(f"{type(self).__name__}._augmentation_cls is not set")
        return self._augmentation_cls(**params)

    def _test_smoke_implementation(self, params):
        """Exercise the low-level API with ``return_transform=False``.

        Checks that the class and the instance are ``AugmentationBase2D`` subclasses, that
        ``generate_parameters`` returns a dict, that ``compute_transformation`` returns one
        3x3 matrix per batch element, and that ``apply_transform`` keeps the batch size.
        """
        assert issubclass(
            self._augmentation_cls, AugmentationBase2D
        ), f"{self._augmentation_cls} is not a subclass of AugmentationBase2D"

        augmentation = self._create_augmentation_from_params(**params, return_transform=False)
        assert issubclass(
            type(augmentation), AugmentationBase2D
        ), f"{type(augmentation)} is not a subclass of AugmentationBase2D"

        batch_shape = (4, 3, 5, 6)
        generated_params = augmentation.generate_parameters(batch_shape)
        assert isinstance(generated_params, dict)

        expected_transformation_shape = torch.Size((batch_shape[0], 3, 3))
        test_input = torch.ones(batch_shape, device=self.device, dtype=self.dtype)
        transformation = augmentation.compute_transformation(test_input, generated_params)
        assert transformation.shape == expected_transformation_shape

        output = augmentation.apply_transform(test_input, generated_params, transformation)
        assert output.shape[0] == batch_shape[0]

    def _test_smoke_call_implementation(self, params):
        """Exercise ``__call__`` on an augmentation built with ``return_transform=False``.

        Covers overriding ``return_transform`` per call and passing an ``(input, transform)``
        tuple: with ``return_transform=True`` the returned matrix must equal
        ``transformation @ test_transform``; without it the given transform is passed through.
        """
        batch_shape = (4, 3, 5, 6)
        expected_transformation_shape = torch.Size((batch_shape[0], 3, 3))
        test_input = torch.ones(batch_shape, device=self.device, dtype=self.dtype)
        augmentation = self._create_augmentation_from_params(**params, return_transform=False)
        generated_params = augmentation.generate_parameters(batch_shape)
        test_transform = torch.rand(expected_transformation_shape, device=self.device, dtype=self.dtype)

        output = augmentation(test_input, params=generated_params)
        assert output.shape[0] == batch_shape[0]

        output, transformation = augmentation(test_input, params=generated_params, return_transform=True)
        assert output.shape[0] == batch_shape[0]
        assert transformation.shape == expected_transformation_shape

        output, final_transformation = augmentation(
            (test_input, test_transform), params=generated_params, return_transform=True
        )
        assert output.shape[0] == batch_shape[0]
        assert final_transformation.shape == expected_transformation_shape
        assert_close(final_transformation, transformation @ test_transform, atol=1e-4, rtol=1e-4)

        output, transformation = augmentation((test_input, test_transform), params=generated_params)
        assert output.shape[0] == batch_shape[0]
        assert transformation.shape == expected_transformation_shape
        assert (transformation == test_transform).all()

    def _test_smoke_return_transform_implementation(self, params):
        """Exercise ``__call__`` on an augmentation built with ``return_transform=True``.

        Every call returns ``(output, transform)``. When an ``(input, transform)`` tuple is
        passed, the returned matrix must equal ``transformation @ test_transform``.
        """
        batch_shape = (4, 3, 5, 6)
        expected_transformation_shape = torch.Size((batch_shape[0], 3, 3))
        test_input = torch.ones(batch_shape, device=self.device, dtype=self.dtype)
        augmentation = self._create_augmentation_from_params(**params, return_transform=True)
        generated_params = augmentation.generate_parameters(batch_shape)
        test_transform = torch.rand(expected_transformation_shape, device=self.device, dtype=self.dtype)

        output, transformation = augmentation(test_input, params=generated_params)
        assert output.shape[0] == batch_shape[0]
        assert transformation.shape == expected_transformation_shape

        output, final_transformation = augmentation((test_input, test_transform), params=generated_params)
        assert output.shape[0] == batch_shape[0]
        assert final_transformation.shape == expected_transformation_shape
        assert_close(final_transformation, transformation @ test_transform, atol=1e-4, rtol=1e-4)

        output, final_transformation = augmentation(
            (test_input, test_transform), params=generated_params, return_transform=True
        )
        assert output.shape[0] == batch_shape[0]
        assert final_transformation.shape == expected_transformation_shape
        assert_close(final_transformation, transformation @ test_transform, atol=1e-4, rtol=1e-4)

    def _test_cardinality_implementation(self, input_shape, expected_output_shape, params):
        """Check output rank and shape for p=0 and p=1.

        With p=0 the output is the input shape left-padded with ones to 4D. With p=1 it
        must equal ``expected_output_shape``.
        """
        augmentation = self._create_augmentation_from_params(**params, p=0.0)
        test_input = torch.rand(input_shape, device=self.device, dtype=self.dtype)
        output = augmentation(test_input)
        assert len(output.shape) == 4
        assert output.shape == torch.Size((1,) * (4 - len(input_shape)) + tuple(input_shape))

        augmentation = self._create_augmentation_from_params(**params, p=1.0)
        test_input = torch.rand(input_shape, device=self.device, dtype=self.dtype)
        output = augmentation(test_input)
        assert len(output.shape) == 4
        assert output.shape == expected_output_shape

    def _test_random_p_0_implementation(self, params):
        """With p=0 the output must be exactly equal to the input."""
        augmentation = self._create_augmentation_from_params(**params, p=0.0, return_transform=False)
        test_input = torch.rand((2, 3, 4, 5), device=self.device, dtype=self.dtype)
        output = augmentation(test_input)
        assert (output == test_input).all()

    def _test_random_p_0_return_transform_implementation(self, params):
        """With p=0 the output equals the input and every returned transform is the identity."""
        augmentation = self._create_augmentation_from_params(**params, p=0.0, return_transform=True)
        expected_transformation_shape = torch.Size((2, 3, 3))
        test_input = torch.rand((2, 3, 4, 5), device=self.device, dtype=self.dtype)
        output, transformation = augmentation(test_input)

        assert (output == test_input).all()
        assert transformation.shape == expected_transformation_shape
        assert (transformation == kornia.eye_like(3, transformation)).all()

    def _test_random_p_1_implementation(self, input_tensor, expected_output, params):
        """With p=1 compare the output against ``expected_output`` (atol = rtol = 1e-4).

        ``input_tensor`` and ``expected_output`` are moved to ``self.device`` / ``self.dtype``
        before use.
        """
        augmentation = self._create_augmentation_from_params(**params, p=1.0, return_transform=False)
        output = augmentation(input_tensor.to(self.device).to(self.dtype))

        assert output.shape == expected_output.shape
        assert_close(output, expected_output.to(self.device).to(self.dtype), atol=1e-4, rtol=1e-4)

    def _test_random_p_1_return_transform_implementation(
        self, input_tensor, expected_output, expected_transformation, params
    ):
        """Like ``_test_random_p_1_implementation`` but also checks the returned transform."""
        augmentation = self._create_augmentation_from_params(**params, p=1.0, return_transform=True)
        output, transformation = augmentation(input_tensor.to(self.device).to(self.dtype))

        assert output.shape == expected_output.shape
        assert_close(output, expected_output.to(self.device).to(self.dtype), atol=1e-4, rtol=1e-4)

        assert transformation.shape == expected_transformation.shape
        assert_close(transformation, expected_transformation.to(self.device).to(self.dtype), atol=1e-4, rtol=1e-4)

    def _test_module_implementation(self, params):
        """Applying the augmentation twice by hand matches ``nn.Sequential(aug, aug)``.

        Both paths use the same seed. The sequential path must reproduce the second output
        and the composed transform ``transform2 @ transform1``.
        """
        augmentation = self._create_augmentation_from_params(**params, p=0.5, return_transform=True)
        augmentation_sequence = nn.Sequential(augmentation, augmentation)
        input_tensor = torch.rand(3, 5, 5, device=self.device, dtype=self.dtype)

        torch.manual_seed(42)
        out1, transform1 = augmentation(input_tensor)
        out2, transform2 = augmentation(out1)
        transform = transform2 @ transform1

        torch.manual_seed(42)
        out_sequence, transform_sequence = augmentation_sequence(input_tensor)

        assert out2.shape == out_sequence.shape
        assert transform.shape == transform_sequence.shape
        assert_close(out2, out_sequence, atol=1e-4, rtol=1e-4)
        assert_close(transform, transform_sequence, atol=1e-4, rtol=1e-4)

    def _test_inverse_coordinate_check_implementation(self, params):
        """Check that the returned transform maps input pixels onto output pixels.

        A white rectangle is drawn on a black image and the augmentation is applied with
        p=1. Every output pixel location is mapped back through the inverse transform. For
        output pixels whose channel-0 value exceeds 0.9999, the output and input values at
        the corresponding (rounded) locations must match. The test is skipped when the
        transform is the identity.
        """
        torch.manual_seed(42)

        input_tensor = torch.zeros((1, 3, 50, 100), device=self.device, dtype=self.dtype)
        input_tensor[:, :, 20:30, 40:60] = 1.0

        augmentation = self._create_augmentation_from_params(**params, p=1.0, return_transform=True)
        output, transform = augmentation(input_tensor)

        if (transform == kornia.eye_like(3, transform)).all():
            pytest.skip("Test not relevant for intensity augmentations.")

        indices = create_meshgrid(
            height=output.shape[-2], width=output.shape[-1], normalized_coordinates=False, device=self.device
        )
        output_indices = indices.reshape((1, -1, 2)).to(dtype=self.dtype)
        input_indices = transform_points(_torch_inverse_cast(transform.to(self.dtype)), output_indices)

        # Points are (x, y); index the image as [..., y, x].
        output_indices = output_indices.round().long().squeeze(0)
        input_indices = input_indices.round().long().squeeze(0)
        output_values = output[0, 0, output_indices[:, 1], output_indices[:, 0]]
        value_mask = output_values > 0.9999

        output_values = output[0, :, output_indices[:, 1][value_mask], output_indices[:, 0][value_mask]]
        input_values = input_tensor[0, :, input_indices[:, 1][value_mask], input_indices[:, 0][value_mask]]

        assert_close(output_values, input_values, atol=1e-4, rtol=1e-4)

    def _test_gradcheck_implementation(self, params):
        """Run ``torch.autograd.gradcheck`` on the augmentation with p=1."""
        input_tensor = torch.rand((3, 5, 5), device=self.device, dtype=self.dtype)
        input_tensor = utils.tensor_to_gradcheck_var(input_tensor)
        assert gradcheck(
            self._create_augmentation_from_params(**params, p=1.0, return_transform=False),
            (input_tensor,),
            raise_exception=True,
        )
|
|
|
|
|
|
|
|
class TestRandomEqualizeAlternative(CommonTests):
    """CommonTests specialisation for ``RandomEqualize``, which has no tunable parameters."""

    possible_params: Dict["str", Tuple] = {}

    _augmentation_cls = RandomEqualize
    _default_param_set: Dict["str", Any] = {}

    @pytest.fixture(params=[_default_param_set], scope="class")
    def param_set(self, request):
        return request.param

    def _ramp(self, batch_size, channels):
        """Images whose every row is 0, 0.05, ..., 0.95; shape (batch_size, channels, 20, 20)."""
        row = torch.arange(20.0, device=self.device, dtype=self.dtype) / 20
        return row.repeat(batch_size, channels, 20, 1)

    def _equalized_ramp(self, batch_size, channels):
        """Reference output for ``_ramp``: 13 increasing levels followed by 7 entries equal to 1."""
        levels = [0.0000, 0.07843, 0.15686, 0.2353, 0.3137, 0.3922, 0.4706, 0.5490, 0.6275, 0.7059, 0.7843]
        levels += [0.8627, 0.9412]
        levels += [1.0000] * 7
        return torch.tensor(levels, device=self.device, dtype=self.dtype).repeat(batch_size, channels, 20, 1)

    def test_random_p_1(self):
        self._test_random_p_1_implementation(
            input_tensor=self._ramp(1, 2),
            expected_output=self._equalized_ramp(1, 2),
            params={},
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        tiny_image = torch.rand(1, 1, 3, 4, device=self.device, dtype=self.dtype)
        # The test expects this small 3x4 image to come back unchanged, with an identity transform.
        self._test_random_p_1_return_transform_implementation(
            input_tensor=tiny_image,
            expected_output=tiny_image,
            expected_transformation=kornia.eye_like(3, tiny_image),
            params={},
        )

    def test_batch(self):
        ramp_batch = self._ramp(2, 3)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=ramp_batch,
            expected_output=self._equalized_ramp(2, 3),
            expected_transformation=kornia.eye_like(3, ramp_batch),
            params={},
        )

    def test_exception(self):
        # NOTE(review): ``(1, 3, 4, 5) * 200`` repeats the *tuple*, requesting an 800-dimensional
        # shape rather than a (1, 3, 4, 5) tensor scaled by 200. Confirm which failure this test
        # is meant to exercise before changing it.
        repeated_shape = (1, 3, 4, 5) * 200
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(p=1.0)(
                torch.ones(repeated_shape, device=self.device, dtype=self.dtype)
            )
|
|
|
|
|
|
|
|
class TestCenterCropAlternative(CommonTests):
    """CommonTests specialisation for ``CenterCrop``."""

    # Candidate values passed to ``default_with_one_parameter_changed`` to build ``param_set``.
    possible_params: Dict["str", Tuple] = {
        "size": (2, (2, 2)),
        "resample": (0, Resample.BILINEAR.name, Resample.BILINEAR),
        "align_corners": (False, True),
    }

    _augmentation_cls = CenterCrop
    _default_param_set: Dict["str", Any] = {"size": (2, 2), "align_corners": True}

    @pytest.fixture(
        params=default_with_one_parameter_changed(default=_default_param_set, **possible_params), scope="class"
    )
    def param_set(self, request):
        return request.param

    def _crop_source(self):
        """The (1, 3, 4) reference image shared by the single-image tests."""
        return torch.tensor(
            [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 0.0, 0.1, 0.2]]], device=self.device, dtype=self.dtype
        )

    @pytest.mark.parametrize(
        "input_shape,expected_output_shape",
        [((4, 5), (1, 1, 2, 3)), ((3, 4, 5), (1, 3, 2, 3)), ((2, 3, 4, 5), (2, 3, 2, 3))],
    )
    def test_cardinality(self, input_shape, expected_output_shape):
        crop_params = {"size": (2, 3), "align_corners": True}
        self._test_cardinality_implementation(
            input_shape=input_shape, expected_output_shape=expected_output_shape, params=crop_params
        )

    @pytest.mark.xfail(reason="size=(1,2) results in RuntimeError: solve_cpu: For batch 0: U(3,3) is zero, singular U.")
    def test_random_p_1(self):
        torch.manual_seed(42)
        source = self._crop_source()
        centre = torch.tensor([[[[0.6, 0.7]]]], device=self.device, dtype=self.dtype)
        self._test_random_p_1_implementation(
            input_tensor=source,
            expected_output=centre,
            params={"size": (1, 2), "align_corners": True, "resample": 0},
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        source = self._crop_source()
        # A 3x2 crop of the 3x4 image keeps columns 1..2, i.e. a translation of -1 along x.
        centre = torch.tensor([[[[0.2, 0.3], [0.6, 0.7], [0.0, 0.1]]]], device=self.device, dtype=self.dtype)
        shift = torch.tensor(
            [[[1.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_return_transform_implementation(
            input_tensor=source,
            expected_output=centre,
            expected_transformation=shift,
            params={"size": (3, 2), "align_corners": True, "resample": 0},
        )

    def test_batch(self):
        torch.manual_seed(42)
        batch = torch.rand((2, 3, 4, 4), device=self.device, dtype=self.dtype)
        # A 2x2 crop of a 4x4 image keeps rows/columns 1..2: translation (-1, -1) per sample.
        shift = torch.tensor(
            [[[1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        ).repeat(2, 1, 1)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=batch,
            expected_output=batch[:, :, 1:3, 1:3],
            expected_transformation=shift,
            params={"size": (2, 2), "align_corners": True, "resample": 0},
        )

    @pytest.mark.xfail(reason="No input validation is implemented yet.")
    def test_exception(self):
        # Arguments of the wrong type.
        for bad_kwargs in ({"size": 0.0}, {"size": 2, "align_corners": 0}, {"size": 2, "resample": True}):
            with pytest.raises(TypeError):
                self._create_augmentation_from_params(**bad_kwargs)

        # Sizes with a negative component.
        for bad_size in (-1, (-1, 2), (2, -1)):
            with pytest.raises(ValueError):
                self._create_augmentation_from_params(size=bad_size)
|
|
|
|
|
|
|
|
class TestRandomHorizontalFlipAlternative(CommonTests):
    """CommonTests specialisation for ``RandomHorizontalFlip``, which has no tunable parameters."""

    possible_params: Dict["str", Tuple] = {}

    _augmentation_cls = RandomHorizontalFlip
    _default_param_set: Dict["str", Any] = {}

    @pytest.fixture(params=[_default_param_set], scope="class")
    def param_set(self, request):
        return request.param

    def _hflip_source(self):
        """3x3 reference image of shape (1, 3, 3)."""
        return torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )

    def _hflip_target(self):
        """``_hflip_source`` mirrored left-right, with a leading batch axis: shape (1, 1, 3, 3)."""
        return torch.tensor(
            [[[[0.3, 0.2, 0.1], [0.6, 0.5, 0.4], [0.9, 0.8, 0.7]]]], device=self.device, dtype=self.dtype
        )

    def _hflip_matrix(self):
        """Homography x -> 2 - x for a width-3 image, shape (1, 3, 3)."""
        return torch.tensor(
            [[[-1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        )

    def test_random_p_1(self):
        torch.manual_seed(42)
        self._test_random_p_1_implementation(
            input_tensor=self._hflip_source(), expected_output=self._hflip_target(), params={}
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=self._hflip_source(),
            expected_output=self._hflip_target(),
            expected_transformation=self._hflip_matrix(),
            params={},
        )

    def test_batch(self):
        torch.manual_seed(12)
        two_samples = (2, 1, 1, 1)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=self._hflip_source().unsqueeze(0).repeat(two_samples),
            expected_output=self._hflip_target().repeat(two_samples),
            expected_transformation=self._hflip_matrix().repeat((2, 1, 1)),
            params={},
        )

    @pytest.mark.skip(reason="No special parameters to validate.")
    def test_exception(self):
        pass
|
|
|
|
|
|
|
|
class TestRandomVerticalFlipAlternative(CommonTests):
    """CommonTests specialisation for ``RandomVerticalFlip``, which has no tunable parameters."""

    possible_params: Dict["str", Tuple] = {}

    _augmentation_cls = RandomVerticalFlip
    _default_param_set: Dict["str", Any] = {}

    @pytest.fixture(params=[_default_param_set], scope="class")
    def param_set(self, request):
        return request.param

    def _vflip_source(self):
        """3x3 reference image of shape (1, 3, 3)."""
        return torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )

    def _vflip_target(self):
        """``_vflip_source`` mirrored top-bottom, with a leading batch axis: shape (1, 1, 3, 3)."""
        return torch.tensor(
            [[[[0.7, 0.8, 0.9], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]]]], device=self.device, dtype=self.dtype
        )

    def _vflip_matrix(self):
        """Homography y -> 2 - y for a height-3 image, shape (1, 3, 3)."""
        return torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        )

    def test_random_p_1(self):
        torch.manual_seed(42)
        self._test_random_p_1_implementation(
            input_tensor=self._vflip_source(), expected_output=self._vflip_target(), params={}
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=self._vflip_source(),
            expected_output=self._vflip_target(),
            expected_transformation=self._vflip_matrix(),
            params={},
        )

    def test_batch(self):
        torch.manual_seed(12)
        two_samples = (2, 1, 1, 1)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=self._vflip_source().unsqueeze(0).repeat(two_samples),
            expected_output=self._vflip_target().repeat(two_samples),
            expected_transformation=self._vflip_matrix().repeat((2, 1, 1)),
            params={},
        )

    @pytest.mark.skip(reason="No special parameters to validate.")
    def test_exception(self):
        pass
|
|
|
|
|
|
|
|
class TestRandomRotationAlternative(CommonTests):
    """CommonTests specialisation for ``RandomRotation``."""

    # Candidate values passed to ``default_with_one_parameter_changed`` to build ``param_set``.
    possible_params: Dict["str", Tuple] = {
        "degrees": (0.0, (-360.0, 360.0), [0.0, 0.0], torch.tensor((-180.0, 180))),
        "resample": (0, Resample.BILINEAR.name, Resample.BILINEAR),
        "align_corners": (False, True),
    }

    _augmentation_cls = RandomRotation
    _default_param_set: Dict["str", Any] = {"degrees": (30.0, 30.0), "align_corners": True}

    @pytest.fixture(
        params=default_with_one_parameter_changed(default=_default_param_set, **possible_params), scope="class"
    )
    def param_set(self, request):
        return request.param

    def _rotation_source(self):
        """3x3 reference image of shape (1, 3, 3)."""
        return torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )

    def test_random_p_1(self):
        torch.manual_seed(42)
        source = self._rotation_source()
        # Fixed +90 degrees: the expected image is the source turned counter-clockwise.
        turned_ccw = torch.tensor(
            [[[[0.3, 0.6, 0.9], [0.2, 0.5, 0.8], [0.1, 0.4, 0.7]]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_implementation(
            input_tensor=source,
            expected_output=turned_ccw,
            params={"degrees": (90.0, 90.0), "align_corners": True},
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        source = self._rotation_source()
        # Fixed -90 degrees: the expected image is the source turned clockwise.
        turned_cw = torch.tensor(
            [[[[0.7, 0.4, 0.1], [0.8, 0.5, 0.2], [0.9, 0.6, 0.3]]]], device=self.device, dtype=self.dtype
        )
        rotation_matrix = torch.tensor(
            [[[0.0, -1.0, 2.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_return_transform_implementation(
            input_tensor=source,
            expected_output=turned_cw,
            expected_transformation=rotation_matrix,
            params={"degrees": (-90.0, -90.0), "align_corners": True},
        )

    def test_batch(self):
        torch.manual_seed(12)
        pair = self._rotation_source().unsqueeze(0).repeat((2, 1, 1, 1))
        # A full turn must leave the batch untouched and report identity matrices.
        self._test_random_p_1_return_transform_implementation(
            input_tensor=pair,
            expected_output=pair,
            expected_transformation=kornia.eye_like(3, pair),
            params={"degrees": (-360.0, -360.0), "align_corners": True},
        )

    @pytest.mark.xfail(reason="No input validation is implemented yet.")
    def test_exception(self):
        # Arguments of the wrong type.
        wrong_types = ({"degrees": ""}, {"degrees": (3, 3), "align_corners": 0}, {"degrees": (3, 3), "resample": True})
        for bad_kwargs in wrong_types:
            with pytest.raises(TypeError):
                self._create_augmentation_from_params(**bad_kwargs)

        # Degree ranges outside [-360, 360] or with min > max.
        for bad_degrees in (-361.0, (-361.0, 360.0), (-360.0, 361.0), (360.0, -360.0)):
            with pytest.raises(ValueError):
                self._create_augmentation_from_params(degrees=bad_degrees)
|
|
|
|
|
|
|
|
class TestRandomGrayscaleAlternative(CommonTests):
    """CommonTests specialisation for ``RandomGrayscale``, which has no tunable parameters."""

    possible_params: Dict["str", Tuple] = {}

    _augmentation_cls = RandomGrayscale
    _default_param_set: Dict["str", Any] = {}

    @pytest.fixture(params=[_default_param_set], scope="class")
    def param_set(self, request):
        return request.param

    def _rgb_source(self, batch_size):
        """A 3x4 plane copied into 3 channels; shape (batch_size, 3, 3, 4)."""
        plane = torch.tensor(
            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 0.0, 0.1, 0.2]], device=self.device, dtype=self.dtype
        )
        return plane.repeat(batch_size, 3, 1, 1)

    def _luma(self, rgb):
        """Weighted channel sum (0.299, 0.587, 0.114), broadcast back to 3 identical channels."""
        weights = torch.tensor([0.299, 0.587, 0.114], device=self.device, dtype=self.dtype).view(1, 3, 1, 1)
        weighted = rgb * weights
        return weighted.sum(dim=1, keepdim=True).repeat(1, 3, 1, 1)

    @pytest.mark.parametrize(
        "input_shape,expected_output_shape", [((3, 4, 5), (1, 3, 4, 5)), ((2, 3, 4, 5), (2, 3, 4, 5))]
    )
    def test_cardinality(self, input_shape, expected_output_shape):
        self._test_cardinality_implementation(
            input_shape=input_shape, expected_output_shape=expected_output_shape, params=self._default_param_set
        )

    def test_random_p_1(self):
        torch.manual_seed(42)
        rgb = self._rgb_source(1)
        self._test_random_p_1_implementation(input_tensor=rgb, expected_output=self._luma(rgb), params={})

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        rgb = self._rgb_source(1)
        gray = self._luma(rgb)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=rgb,
            expected_output=gray,
            expected_transformation=kornia.eye_like(3, rgb),
            params={},
        )

    def test_batch(self):
        torch.manual_seed(42)
        rgb = self._rgb_source(2)
        gray = self._luma(rgb)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=rgb,
            expected_output=gray,
            expected_transformation=kornia.eye_like(3, rgb),
            params={},
        )

    @pytest.mark.xfail(reason="No input validation is implemented yet when p=0.")
    def test_exception(self):
        torch.manual_seed(42)
        # Inputs with 1 and 4 channels are expected to be rejected.
        for probability, shape in ((0.0, (1, 1, 4, 5)), (1.0, (1, 4, 4, 5))):
            with pytest.raises(ValueError):
                self._create_augmentation_from_params(p=probability)(
                    torch.rand(shape, device=self.device, dtype=self.dtype)
                )
|
|
|
|
|
|
|
|
class TestRandomHorizontalFlip:
    """Tests for :class:`RandomHorizontalFlip`: repr, output/transform values, inverse, batching, gradients."""

    # Checks the string representation; xfail because printed precision may differ across platforms.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        f = RandomHorizontalFlip(p=0.5)
        expected_repr = "RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
        assert str(f) == expected_repr

    def test_random_hflip(self, device, dtype):
        """p=1 flips and returns the flip matrix; p=0 is a no-op with identity; inverse undoes the flip."""
        f = RandomHorizontalFlip(p=1.0, return_transform=True)
        f1 = RandomHorizontalFlip(p=0.0, return_transform=True)
        f2 = RandomHorizontalFlip(p=1.0)
        f3 = RandomHorizontalFlip(p=0.0)

        # Un-batched 3 x 4 image.
        img = torch.tensor(
            [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 2.0]], device=device, dtype=dtype
        )
        expected = torch.tensor(
            [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [2.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype
        )
        # x' = -x + (W - 1) with W = 4; y is unchanged.
        expected_transform = torch.tensor(
            [[-1.0, 0.0, 3.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device, dtype=dtype
        )
        identity = torch.eye(3, device=device, dtype=dtype)

        # p is 0 or 1 here, so each forward is deterministic and one call per module suffices.
        out, transform = f(img)
        assert (out == expected).all()
        assert (transform == expected_transform).all()

        out, transform = f1(img)
        assert (out == img).all()
        assert (transform == identity).all()

        assert (f2(img) == expected).all()
        assert (f3(img) == img).all()

        # inverse() undoes the flip for p=1 and leaves the input untouched for p=0.
        assert (f.inverse(expected) == img).all()
        assert (f1.inverse(expected) == expected).all()
        assert (f2.inverse(expected) == img).all()
        assert (f3.inverse(expected) == expected).all()

    def test_batch_random_hflip(self, device, dtype):
        """Batched (5 x 3 x 3 x 3) flip returns one flip matrix (p=1) or identity (p=0) per sample."""
        f = RandomHorizontalFlip(p=1.0, return_transform=True)
        f1 = RandomHorizontalFlip(p=0.0, return_transform=True)

        img = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype)
        expected = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 1.0, 0.0]]]], device=device, dtype=dtype)
        # x' = -x + (W - 1) with W = 3.
        expected_transform = torch.tensor(
            [[[-1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )

        img = img.repeat(5, 3, 1, 1)
        expected = expected.repeat(5, 3, 1, 1)
        expected_transform = expected_transform.repeat(5, 1, 1)
        identity = torch.eye(3, device=device, dtype=dtype).repeat(5, 1, 1)

        out, transform = f(img)
        assert (out == expected).all()
        assert (transform == expected_transform).all()

        out, transform = f1(img)
        assert (out == img).all()
        assert (transform == identity).all()

        assert (f.inverse(expected) == img).all()
        assert (f1.inverse(expected) == expected).all()

    def test_same_on_batch(self, device, dtype):
        """With same_on_batch=True both samples receive the same decision; inverse restores the input."""
        f = RandomHorizontalFlip(p=0.5, same_on_batch=True)
        img = torch.eye(3, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 1, 1)
        res = f(img)
        assert (res[0] == res[1]).all()
        assert (f.inverse(res) == img).all()

    def test_sequential(self, device, dtype):
        """Two chained flips cancel out; returned transforms compose (only when tracked)."""
        f = nn.Sequential(
            RandomHorizontalFlip(p=1.0, return_transform=True), RandomHorizontalFlip(p=1.0, return_transform=True)
        )
        f1 = nn.Sequential(RandomHorizontalFlip(p=1.0, return_transform=True), RandomHorizontalFlip(p=1.0))

        img = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype)
        expected_transform = torch.tensor(
            [[[-1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )
        expected_transform_1 = expected_transform @ expected_transform

        out, transform = f(img)
        assert (out == img).all()
        assert (transform == expected_transform_1).all()

        out, transform = f1(img)
        assert (out == img).all()
        assert (transform == expected_transform).all()

    def test_random_hflip_coord_check(self, device, dtype):
        """Pixels moved by the returned transform land exactly where the flipped image puts them."""
        f = RandomHorizontalFlip(p=1.0, return_transform=True)

        img = torch.tensor(
            [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 4
        # Homogeneous (x, y, 1) coordinates of every pixel of the 3 x 4 image: 1 x 3 x 12.
        input_coordinates = torch.tensor(
            [
                [
                    [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],  # x
                    [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],  # y
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],  # homogeneous 1
                ]
            ],
            device=device,
            dtype=dtype,
        )
        expected_output = torch.tensor(
            [[[[4.0, 3.0, 2.0, 1.0], [8.0, 7.0, 6.0, 5.0], [12.0, 11.0, 10.0, 9.0]]]], device=device, dtype=dtype
        )

        output, transform = f(img)
        result_coordinates = transform @ input_coordinates
        # Round to integer pixel indices so the coordinates can be used for indexing.
        input_coordinates = input_coordinates.round().long()
        result_coordinates = result_coordinates.round().long()

        assert output.shape == expected_output.shape
        assert (output == expected_output).all()

        # Transformed coordinates must stay inside the image bounds.
        xs, ys = result_coordinates[0, 0, :], result_coordinates[0, 1, :]
        assert torch.logical_and(xs >= 0, xs < img.shape[-1]).all()
        assert torch.logical_and(ys >= 0, ys < img.shape[-2]).all()

        # The value at each transformed coordinate must equal the value at the original coordinate.
        assert (
            output[..., result_coordinates[0, 1, :], result_coordinates[0, 0, :]]
            == img[..., input_coordinates[0, 1, :], input_coordinates[0, 0, :]]
        ).all()

    def test_gradcheck(self, device, dtype):
        img = torch.rand((3, 3), device=device, dtype=dtype)
        img = utils.tensor_to_gradcheck_var(img)
        assert gradcheck(RandomHorizontalFlip(p=1.0), (img,), raise_exception=True)
|
|
|
|
|
|
|
|
class TestRandomVerticalFlip:
    """Tests for :class:`RandomVerticalFlip`: repr, output/transform values, inverse, batching, gradients."""

    # Checks the string representation; xfail because printed precision may differ across platforms.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        f = RandomVerticalFlip(p=0.5)
        expected_repr = "RandomVerticalFlip(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
        assert str(f) == expected_repr

    def test_random_vflip(self, device, dtype):
        """p=1 flips and returns the flip matrix; p=0 is a no-op with identity; inverse undoes the flip."""
        f = RandomVerticalFlip(p=1.0, return_transform=True)
        f1 = RandomVerticalFlip(p=0.0, return_transform=True)
        f2 = RandomVerticalFlip(p=1.0)
        f3 = RandomVerticalFlip(p=0.0)

        img = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype)
        expected = torch.tensor([[[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]], device=device, dtype=dtype)
        # y' = -y + (H - 1) with H = 3; x is unchanged.
        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )
        identity = torch.eye(3, device=device, dtype=dtype).unsqueeze(0)

        # p is 0 or 1 here, so each forward is deterministic and one call per module suffices.
        out, transform = f(img)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(transform, expected_transform, atol=1e-4, rtol=1e-4)

        out, transform = f1(img)
        assert_close(out, img, atol=1e-4, rtol=1e-4)
        assert_close(transform, identity, atol=1e-4, rtol=1e-4)

        assert_close(f2(img), expected, atol=1e-4, rtol=1e-4)
        assert_close(f3(img), img, atol=1e-4, rtol=1e-4)

        # inverse() undoes the flip for p=1 and leaves the input untouched for p=0.
        assert_close(f.inverse(expected), img, atol=1e-4, rtol=1e-4)
        assert_close(f1.inverse(img), img, atol=1e-4, rtol=1e-4)
        assert_close(f2.inverse(expected), img, atol=1e-4, rtol=1e-4)
        assert_close(f3.inverse(img), img, atol=1e-4, rtol=1e-4)

    def test_batch_random_vflip(self, device, dtype):
        """Batched (5 x 3 x 3 x 3) flip returns one flip matrix (p=1) or identity (p=0) per sample."""
        f = RandomVerticalFlip(p=1.0, return_transform=True)
        f1 = RandomVerticalFlip(p=0.0, return_transform=True)

        img = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype)
        expected = torch.tensor([[[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]], device=device, dtype=dtype)
        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )

        img = img.repeat(5, 3, 1, 1)
        expected = expected.repeat(5, 3, 1, 1)
        expected_transform = expected_transform.repeat(5, 1, 1)
        identity = torch.eye(3, device=device, dtype=dtype).repeat(5, 1, 1)

        out, transform = f(img)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(transform, expected_transform, atol=1e-4, rtol=1e-4)

        out, transform = f1(img)
        assert_close(out, img, atol=1e-4, rtol=1e-4)
        assert_close(transform, identity, atol=1e-4, rtol=1e-4)

        assert_close(f.inverse(expected), img, atol=1e-4, rtol=1e-4)
        assert_close(f1.inverse(img), img, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With same_on_batch=True both samples receive the same decision; inverse restores the input."""
        f = RandomVerticalFlip(p=0.5, same_on_batch=True)
        img = torch.eye(3, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 1, 1)
        res = f(img)
        assert (res[0] == res[1]).all()
        assert (f.inverse(res) == img).all()

    def test_sequential(self, device, dtype):
        """Two chained flips cancel out; returned transforms compose (only when tracked)."""
        f = nn.Sequential(
            RandomVerticalFlip(p=1.0, return_transform=True), RandomVerticalFlip(p=1.0, return_transform=True)
        )
        f1 = nn.Sequential(RandomVerticalFlip(p=1.0, return_transform=True), RandomVerticalFlip(p=1.0))

        img = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype)
        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )
        expected_transform_1 = expected_transform @ expected_transform

        out, transform = f(img)
        assert_close(out, img, atol=1e-4, rtol=1e-4)
        assert_close(transform, expected_transform_1, atol=1e-4, rtol=1e-4)

        out, transform = f1(img)
        assert_close(out, img, atol=1e-4, rtol=1e-4)
        assert_close(transform, expected_transform, atol=1e-4, rtol=1e-4)

    def test_random_vflip_coord_check(self, device, dtype):
        """Pixels moved by the returned transform land exactly where the flipped image puts them."""
        f = RandomVerticalFlip(p=1.0, return_transform=True)

        img = torch.tensor(
            [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 4
        # Homogeneous (x, y, 1) coordinates of every pixel of the 3 x 4 image: 1 x 3 x 12.
        input_coordinates = torch.tensor(
            [
                [
                    [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],  # x
                    [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],  # y
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],  # homogeneous 1
                ]
            ],
            device=device,
            dtype=dtype,
        )
        expected_output = torch.tensor(
            [[[[9.0, 10.0, 11.0, 12.0], [5.0, 6.0, 7.0, 8.0], [1.0, 2.0, 3.0, 4.0]]]], device=device, dtype=dtype
        )

        output, transform = f(img)
        result_coordinates = transform @ input_coordinates
        # Round to integer pixel indices so the coordinates can be used for indexing.
        input_coordinates = input_coordinates.round().long()
        result_coordinates = result_coordinates.round().long()

        assert output.shape == expected_output.shape
        assert (output == expected_output).all()

        # Transformed coordinates must stay inside the image bounds.
        xs, ys = result_coordinates[0, 0, :], result_coordinates[0, 1, :]
        assert torch.logical_and(xs >= 0, xs < img.shape[-1]).all()
        assert torch.logical_and(ys >= 0, ys < img.shape[-2]).all()

        # The value at each transformed coordinate must equal the value at the original coordinate.
        assert (
            output[..., result_coordinates[0, 1, :], result_coordinates[0, 0, :]]
            == img[..., input_coordinates[0, 1, :], input_coordinates[0, 0, :]]
        ).all()

    def test_gradcheck(self, device, dtype):
        # Mirrors TestRandomHorizontalFlip.test_gradcheck so both flip directions are gradient-checked.
        img = torch.rand((3, 3), device=device, dtype=dtype)
        img = utils.tensor_to_gradcheck_var(img)
        assert gradcheck(RandomVerticalFlip(p=1.0), (img,), raise_exception=True)
|
|
|
|
|
|
|
|
class TestColorJitter:
    """Tests for :class:`ColorJitter`: repr, identity defaults, seeded per-factor outputs and gradients."""

    # Checks the string representation; xfail because printed precision may differ across platforms.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        f = ColorJitter(brightness=0.5, contrast=0.3, saturation=[0.2, 1.2], hue=0.1)
        expected_repr = (
            "ColorJitter(brightness=tensor([0.5000, 1.5000]), contrast=tensor([0.7000, 1.3000]), "
            "saturation=tensor([0.2000, 1.2000]), hue=tensor([-0.1000, 0.1000]), "
            "p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(f) == expected_repr

    def test_color_jitter(self, device, dtype):
        """Default ColorJitter() leaves a single image unchanged and reports an identity transform."""
        f = ColorJitter()
        f1 = ColorJitter(return_transform=True)

        img = torch.rand(3, 5, 5, device=device, dtype=dtype).unsqueeze(0)  # 1 x 3 x 5 x 5
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0)  # 1 x 3 x 3

        assert_close(f(img), img, atol=1e-4, rtol=1e-5)
        out, transform = f1(img)
        assert_close(out, img, atol=1e-4, rtol=1e-5)
        assert_close(transform, expected_transform, atol=1e-4, rtol=1e-4)

    def test_color_jitter_batch(self, device, dtype):
        """Default ColorJitter() leaves a batch unchanged and reports one identity per sample."""
        f = ColorJitter()
        f1 = ColorJitter(return_transform=True)

        img = torch.rand(2, 3, 5, 5, device=device, dtype=dtype)
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand((2, 3, 3))

        assert_close(f(img), img, atol=1e-4, rtol=1e-5)
        out, transform = f1(img)
        assert_close(out, img, atol=1e-4, rtol=1e-5)
        assert_close(transform, expected_transform, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With same_on_batch=True identical samples stay identical after jitter."""
        f = ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, same_on_batch=True)
        # Build the input on the requested device/dtype so the fixtures are actually exercised.
        img = torch.eye(3, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 3, 1, 1)
        res = f(img)
        assert (res[0] == res[1]).all()

    # ------------------------------------------------------------------ shared inputs

    @staticmethod
    def _get_uniform_channels_input(device, dtype):
        """Return a 2 x 3 x 3 x 3 batch in which every sample and channel holds the same 3 x 3 plane."""
        plane = torch.tensor([[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]], device=device, dtype=dtype)
        return plane.repeat(2, 3, 1, 1)

    @staticmethod
    def _get_color_input(device, dtype):
        """Return a 2 x 3 x 3 x 3 batch holding the same 3-channel image twice."""
        image = torch.tensor(
            [
                [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]],
                [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
                [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, 0.5]],
            ],
            device=device,
            dtype=dtype,
        )
        return image.repeat(2, 1, 1, 1)

    # ------------------------------------------------------------------ brightness

    def _get_expected_brightness(self, device, dtype):
        """Seed-42 reference output; within each sample all three channels are identical."""
        per_sample = torch.tensor(
            [
                [[0.2529, 0.3529, 0.4529], [0.7529, 0.6529, 0.5529], [0.8529, 0.9529, 1.0000]],
                [[0.2660, 0.3660, 0.4660], [0.7660, 0.6660, 0.5660], [0.8660, 0.9660, 1.0000]],
            ],
            device=device,
            dtype=dtype,
        )
        return per_sample.unsqueeze(1).repeat(1, 3, 1, 1)

    def test_random_brightness(self, device, dtype):
        torch.manual_seed(42)
        f = ColorJitter(brightness=0.2)
        img = self._get_uniform_channels_input(device, dtype)
        expected = self._get_expected_brightness(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-4)

    def test_random_brightness_tuple(self, device, dtype):
        # An explicit (0.8, 1.2) range must behave exactly like the scalar 0.2.
        torch.manual_seed(42)
        f = ColorJitter(brightness=(0.8, 1.2))
        img = self._get_uniform_channels_input(device, dtype)
        expected = self._get_expected_brightness(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-4)

    # ------------------------------------------------------------------ contrast

    def _get_expected_contrast(self, device, dtype):
        """Seed-42 reference output; within each sample all three channels are identical."""
        per_sample = torch.tensor(
            [
                [[0.0953, 0.1906, 0.2859], [0.5719, 0.4766, 0.3813], [0.6672, 0.7625, 0.9531]],
                [[0.1184, 0.2367, 0.3551], [0.7102, 0.5919, 0.4735], [0.8286, 0.9470, 1.0000]],
            ],
            device=device,
            dtype=dtype,
        )
        return per_sample.unsqueeze(1).repeat(1, 3, 1, 1)

    def test_random_contrast(self, device, dtype):
        torch.manual_seed(42)
        f = ColorJitter(contrast=0.2)
        img = self._get_uniform_channels_input(device, dtype)
        expected = self._get_expected_contrast(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-5)

    def test_random_contrast_list(self, device, dtype):
        # An explicit [0.8, 1.2] range must behave exactly like the scalar 0.2.
        torch.manual_seed(42)
        f = ColorJitter(contrast=[0.8, 1.2])
        img = self._get_uniform_channels_input(device, dtype)
        expected = self._get_expected_contrast(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-5)

    # ------------------------------------------------------------------ saturation

    def _get_expected_saturation(self, device, dtype):
        """Seed-42 reference output for the 3-channel color input."""
        return torch.tensor(
            [
                [
                    [[0.1876, 0.2584, 0.3389], [0.6292, 0.5000, 0.4000], [0.7097, 0.8000, 1.0000]],
                    [[1.0000, 0.5292, 0.6097], [0.6292, 0.3195, 0.2195], [0.8000, 0.1682, 0.2779]],
                    [[0.6389, 0.8000, 0.7000], [0.9000, 0.3195, 0.2195], [0.8000, 0.4389, 0.5487]],
                ],
                [
                    [[0.0000, 0.1295, 0.2530], [0.5648, 0.5000, 0.4000], [0.6883, 0.8000, 1.0000]],
                    [[1.0000, 0.4648, 0.5883], [0.5648, 0.2765, 0.1765], [0.8000, 0.0178, 0.1060]],
                    [[0.5556, 0.8000, 0.7000], [0.9000, 0.2765, 0.1765], [0.8000, 0.3530, 0.4413]],
                ],
            ],
            device=device,
            dtype=dtype,
        )

    def test_random_saturation(self, device, dtype):
        torch.manual_seed(42)
        f = ColorJitter(saturation=0.2)
        img = self._get_color_input(device, dtype)
        expected = self._get_expected_saturation(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-4)

    def test_random_saturation_tensor(self, device, dtype):
        # A 0-dim tensor factor must behave exactly like the float 0.2.
        torch.manual_seed(42)
        f = ColorJitter(saturation=torch.tensor(0.2))
        img = self._get_color_input(device, dtype)
        expected = self._get_expected_saturation(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-4)

    def test_random_saturation_tuple(self, device, dtype):
        # An explicit (0.8, 1.2) range must behave exactly like the scalar 0.2.
        torch.manual_seed(42)
        f = ColorJitter(saturation=(0.8, 1.2))
        img = self._get_color_input(device, dtype)
        expected = self._get_expected_saturation(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-4)

    # ------------------------------------------------------------------ hue

    def _get_expected_hue(self, device, dtype):
        """Seed-42 reference output for the 3-channel color input."""
        return torch.tensor(
            [
                [
                    [[0.1000, 0.2000, 0.3000], [0.6000, 0.5000, 0.4000], [0.7000, 0.8000, 1.0000]],
                    [[1.0000, 0.5251, 0.6167], [0.6126, 0.3000, 0.2000], [0.8000, 0.1000, 0.2000]],
                    [[0.5623, 0.8000, 0.7000], [0.9000, 0.3084, 0.2084], [0.7958, 0.4293, 0.5335]],
                ],
                [
                    [[0.1000, 0.2000, 0.3000], [0.6116, 0.5000, 0.4000], [0.7000, 0.8000, 1.0000]],
                    [[1.0000, 0.4769, 0.5846], [0.6000, 0.3077, 0.2077], [0.7961, 0.1000, 0.2000]],
                    [[0.6347, 0.8000, 0.7000], [0.9000, 0.3000, 0.2000], [0.8000, 0.3730, 0.4692]],
                ],
            ],
            device=device,
            dtype=dtype,
        )

    def test_random_hue(self, device, dtype):
        torch.manual_seed(42)
        f = ColorJitter(hue=0.1 / pi.item())
        img = self._get_color_input(device, dtype)
        expected = self._get_expected_hue(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-4)

    def test_random_hue_list(self, device, dtype):
        # Range bounds given as 0-dim tensors (``pi`` is a tensor) rather than floats.
        torch.manual_seed(42)
        f = ColorJitter(hue=[-0.1 / pi, 0.1 / pi])
        img = self._get_color_input(device, dtype)
        expected = self._get_expected_hue(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-4)

    def test_random_hue_list_batch(self, device, dtype):
        # Range bounds given as plain floats.
        torch.manual_seed(42)
        f = ColorJitter(hue=[-0.1 / pi.item(), 0.1 / pi.item()])
        img = self._get_color_input(device, dtype)
        expected = self._get_expected_hue(device, dtype)
        assert_close(f(img), expected, atol=1e-4, rtol=1e-4)

    # ------------------------------------------------------------------ composition / gradients

    def test_sequential(self, device, dtype):
        """Two chained default jitters leave the image unchanged and return an identity transform."""
        f = nn.Sequential(ColorJitter(return_transform=True), ColorJitter(return_transform=True))

        img = torch.rand(3, 5, 5, device=device, dtype=dtype).unsqueeze(0)
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0)

        out, transform = f(img)
        assert_close(out, img, atol=1e-4, rtol=1e-5)
        assert_close(transform, expected_transform, atol=1e-4, rtol=1e-5)

    def test_color_jitter_batch_sequential(self, device, dtype):
        """Batched version of ``test_sequential``: one identity transform per sample."""
        f = nn.Sequential(ColorJitter(return_transform=True), ColorJitter(return_transform=True))

        img = torch.rand(2, 3, 5, 5, device=device, dtype=dtype)
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand((2, 3, 3))

        out, transform = f(img)
        assert_close(out, img, atol=1e-4, rtol=1e-5)
        assert_close(transform, expected_transform, atol=1e-4, rtol=1e-5)

    def test_gradcheck(self, device, dtype):
        img = torch.rand((3, 5, 5), device=device, dtype=dtype).unsqueeze(0)
        img = utils.tensor_to_gradcheck_var(img)
        assert gradcheck(ColorJitter(p=1.0), (img,), raise_exception=True)
|
|
|
|
|
|
|
|
class TestRectangleRandomErasing:
    """Tests for :class:`RandomErasing` with rectangular regions: shapes, no-op, same_on_batch, gradients."""

    # Degenerate (fixed) ranges, so each parametrization exercises one extreme of the sampler.
    _SCALE_RANGES = [(0.001, 0.001), (1.0, 1.0)]
    _RATIO_RANGES = [(0.1, 0.1), (10.0, 10.0)]

    @pytest.mark.parametrize("erase_scale_range", _SCALE_RANGES)
    @pytest.mark.parametrize("aspect_ratio_range", _RATIO_RANGES)
    @pytest.mark.parametrize("batch_shape", [(1, 4, 8, 15), (2, 3, 11, 7)])
    def test_random_rectangle_erasing_shape(self, batch_shape, erase_scale_range, aspect_ratio_range):
        images = torch.rand(batch_shape)
        eraser = RandomErasing(erase_scale_range, aspect_ratio_range, p=1.0)
        # Erasing always applied (p=1) must not change the tensor shape.
        assert eraser(images).shape == batch_shape

    @pytest.mark.parametrize("erase_scale_range", _SCALE_RANGES)
    @pytest.mark.parametrize("aspect_ratio_range", _RATIO_RANGES)
    @pytest.mark.parametrize("batch_shape", [(1, 4, 8, 15), (2, 3, 11, 7)])
    def test_no_rectangle_erasing_shape(self, batch_shape, erase_scale_range, aspect_ratio_range):
        images = torch.rand(batch_shape)
        eraser = RandomErasing(erase_scale_range, aspect_ratio_range, p=0.0)
        # With p=0 the output must equal the input exactly.
        assert eraser(images).equal(images)

    @pytest.mark.parametrize("erase_scale_range", _SCALE_RANGES)
    @pytest.mark.parametrize("aspect_ratio_range", _RATIO_RANGES)
    @pytest.mark.parametrize("shape", [(3, 11, 7)])
    def test_same_on_batch(self, shape, erase_scale_range, aspect_ratio_range):
        eraser = RandomErasing(erase_scale_range, aspect_ratio_range, same_on_batch=True, p=0.5)
        images = torch.rand(shape)[None].repeat(2, 1, 1, 1)
        erased = eraser(images)
        assert (erased[0] == erased[1]).all()

    def test_gradcheck(self, device, dtype):
        shape = (2, 3, 11, 7)
        eraser = RandomErasing((0.2, 0.4), (0.3, 0.5), p=1.0)
        # Sample the erase parameters once so gradcheck's repeated forwards use the same rectangles.
        fixed_params = eraser.forward_parameters(shape)

        images = utils.tensor_to_gradcheck_var(torch.rand(shape, device=device, dtype=dtype))
        assert gradcheck(eraser, (images, fixed_params), raise_exception=True)
|
|
|
|
|
|
|
|
class TestRandomGrayscale: |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
def test_smoke(self):
    """The default RandomGrayscale repr shows p=0.1; xfail because printed precision may vary by platform."""
    augmentation = RandomGrayscale()
    expected_repr = "RandomGrayscale(p=0.1, p_batch=1.0, same_on_batch=False, return_transform=False)"
    assert str(augmentation) == expected_repr
|
|
|
|
|
def test_random_grayscale(self, device, dtype):
    """The transform returned alongside the output is a 1 x 3 x 3 identity."""
    augmentation = RandomGrayscale(return_transform=True)
    image = torch.rand(3, 5, 5, device=device, dtype=dtype)
    identity = torch.eye(3, device=device, dtype=dtype)[None]

    returned_transform = augmentation(image)[1]
    assert_close(returned_transform, identity, atol=1e-4, rtol=1e-4)
|
|
|
|
|
def test_same_on_batch(self, device, dtype):
    """With same_on_batch=True two identical samples remain identical."""
    augmentation = RandomGrayscale(p=0.5, same_on_batch=True)
    batch = torch.eye(3, device=device, dtype=dtype)[None, None].repeat(2, 3, 1, 1)  # 2 x 3 x 3 x 3
    out = augmentation(batch)
    assert (out[0] == out[1]).all()
|
|
|
|
|
def test_opencv_true(self, device, dtype):
    """With p=1 a 1 x 3 x 5 x 5 image becomes the reference gray plane replicated over 3 channels.

    The reference plane matches 0.299 * c0 + 0.587 * c1 + 0.114 * c2, e.g. for the first pixel
    0.299 * 0.3944633 + 0.587 * 0.5414237 + 0.114 * 0.2869580 ~= 0.4684734.
    """
    c0 = [
        [0.3944633, 0.8597369, 0.1670904, 0.2825457, 0.0953912],
        [0.1251704, 0.8020709, 0.8933256, 0.9170977, 0.1497008],
        [0.2711633, 0.1111478, 0.0783281, 0.2771807, 0.5487481],
        [0.0086008, 0.8288748, 0.9647092, 0.8922020, 0.7614344],
        [0.2898048, 0.1282895, 0.7621747, 0.5657831, 0.9918593],
    ]
    c1 = [
        [0.5414237, 0.9962701, 0.8947155, 0.5900949, 0.9483274],
        [0.0468036, 0.3933847, 0.8046577, 0.3640994, 0.0632100],
        [0.6171775, 0.8624780, 0.4126036, 0.7600935, 0.7279997],
        [0.4237089, 0.5365476, 0.5591233, 0.1523191, 0.1382165],
        [0.8932794, 0.8517839, 0.7152701, 0.8983801, 0.5905426],
    ]
    c2 = [
        [0.2869580, 0.4700376, 0.2743714, 0.8135023, 0.2229074],
        [0.9306560, 0.3734594, 0.4566821, 0.7599275, 0.7557513],
        [0.7415742, 0.6115875, 0.3317572, 0.0379378, 0.1315770],
        [0.8692724, 0.0809556, 0.7767404, 0.8742208, 0.1522012],
        [0.7708948, 0.4509611, 0.0481175, 0.2358997, 0.6900532],
    ]
    gray = [
        [0.4684734, 0.8954562, 0.6064363, 0.5236061, 0.6106016],
        [0.1709944, 0.5133104, 0.7915002, 0.5745703, 0.1680204],
        [0.5279005, 0.6092287, 0.3034387, 0.5333768, 0.6064113],
        [0.3503858, 0.5720159, 0.7052018, 0.4558409, 0.3261529],
        [0.6988886, 0.5897652, 0.6532392, 0.7234108, 0.7218805],
    ]

    data = torch.tensor([[c0, c1, c2]], device=device, dtype=dtype)  # 1 x 3 x 5 x 5
    expected = torch.tensor([[gray] * 3], device=device, dtype=dtype)  # 1 x 3 x 5 x 5

    converted = RandomGrayscale(p=1.0)(data)
    assert_close(converted, expected, atol=1e-4, rtol=1e-4)
|
|
|
|
|
def test_opencv_false(self, device, dtype):
    """With p=0 a 1 x 3 x 5 x 5 image passes through unchanged."""
    c0 = [
        [0.3944633, 0.8597369, 0.1670904, 0.2825457, 0.0953912],
        [0.1251704, 0.8020709, 0.8933256, 0.9170977, 0.1497008],
        [0.2711633, 0.1111478, 0.0783281, 0.2771807, 0.5487481],
        [0.0086008, 0.8288748, 0.9647092, 0.8922020, 0.7614344],
        [0.2898048, 0.1282895, 0.7621747, 0.5657831, 0.9918593],
    ]
    c1 = [
        [0.5414237, 0.9962701, 0.8947155, 0.5900949, 0.9483274],
        [0.0468036, 0.3933847, 0.8046577, 0.3640994, 0.0632100],
        [0.6171775, 0.8624780, 0.4126036, 0.7600935, 0.7279997],
        [0.4237089, 0.5365476, 0.5591233, 0.1523191, 0.1382165],
        [0.8932794, 0.8517839, 0.7152701, 0.8983801, 0.5905426],
    ]
    c2 = [
        [0.2869580, 0.4700376, 0.2743714, 0.8135023, 0.2229074],
        [0.9306560, 0.3734594, 0.4566821, 0.7599275, 0.7557513],
        [0.7415742, 0.6115875, 0.3317572, 0.0379378, 0.1315770],
        [0.8692724, 0.0809556, 0.7767404, 0.8742208, 0.1522012],
        [0.7708948, 0.4509611, 0.0481175, 0.2358997, 0.6900532],
    ]
    data = torch.tensor([[c0, c1, c2]], device=device, dtype=dtype)  # 1 x 3 x 5 x 5

    untouched = RandomGrayscale(p=0.0)(data)
    assert_close(untouched, data, atol=1e-4, rtol=1e-4)
|
|
|
|
|
def test_opencv_true_batch(self, device, dtype):
    """With p=1 every sample of a 4 x 3 x 5 x 5 batch becomes the reference gray plane on 3 channels."""
    c0 = [
        [0.3944633, 0.8597369, 0.1670904, 0.2825457, 0.0953912],
        [0.1251704, 0.8020709, 0.8933256, 0.9170977, 0.1497008],
        [0.2711633, 0.1111478, 0.0783281, 0.2771807, 0.5487481],
        [0.0086008, 0.8288748, 0.9647092, 0.8922020, 0.7614344],
        [0.2898048, 0.1282895, 0.7621747, 0.5657831, 0.9918593],
    ]
    c1 = [
        [0.5414237, 0.9962701, 0.8947155, 0.5900949, 0.9483274],
        [0.0468036, 0.3933847, 0.8046577, 0.3640994, 0.0632100],
        [0.6171775, 0.8624780, 0.4126036, 0.7600935, 0.7279997],
        [0.4237089, 0.5365476, 0.5591233, 0.1523191, 0.1382165],
        [0.8932794, 0.8517839, 0.7152701, 0.8983801, 0.5905426],
    ]
    c2 = [
        [0.2869580, 0.4700376, 0.2743714, 0.8135023, 0.2229074],
        [0.9306560, 0.3734594, 0.4566821, 0.7599275, 0.7557513],
        [0.7415742, 0.6115875, 0.3317572, 0.0379378, 0.1315770],
        [0.8692724, 0.0809556, 0.7767404, 0.8742208, 0.1522012],
        [0.7708948, 0.4509611, 0.0481175, 0.2358997, 0.6900532],
    ]
    gray = [
        [0.4684734, 0.8954562, 0.6064363, 0.5236061, 0.6106016],
        [0.1709944, 0.5133104, 0.7915002, 0.5745703, 0.1680204],
        [0.5279005, 0.6092287, 0.3034387, 0.5333768, 0.6064113],
        [0.3503858, 0.5720159, 0.7052018, 0.4558409, 0.3261529],
        [0.6988886, 0.5897652, 0.6532392, 0.7234108, 0.7218805],
    ]

    data = torch.tensor([c0, c1, c2], device=device, dtype=dtype)[None].repeat(4, 1, 1, 1)
    expected = torch.tensor([gray] * 3, device=device, dtype=dtype)[None].repeat(4, 1, 1, 1)

    converted = RandomGrayscale(p=1.0)(data)
    assert_close(converted, expected, atol=1e-4, rtol=1e-4)
|
|
|
|
|
def test_opencv_false_batch(self, device, dtype):
    """With ``p=0.0`` a batch of RGB images passes through RandomGrayscale unchanged.

    ``expected`` is an independent copy taken before the augmentation runs, so an
    implementation that modified its input in place cannot make the comparison
    succeed trivially (as it would if ``expected`` aliased ``data``).
    """
    data = torch.tensor(
        [
            [
                [0.3944633, 0.8597369, 0.1670904, 0.2825457, 0.0953912],
                [0.1251704, 0.8020709, 0.8933256, 0.9170977, 0.1497008],
                [0.2711633, 0.1111478, 0.0783281, 0.2771807, 0.5487481],
                [0.0086008, 0.8288748, 0.9647092, 0.8922020, 0.7614344],
                [0.2898048, 0.1282895, 0.7621747, 0.5657831, 0.9918593],
            ],
            [
                [0.5414237, 0.9962701, 0.8947155, 0.5900949, 0.9483274],
                [0.0468036, 0.3933847, 0.8046577, 0.3640994, 0.0632100],
                [0.6171775, 0.8624780, 0.4126036, 0.7600935, 0.7279997],
                [0.4237089, 0.5365476, 0.5591233, 0.1523191, 0.1382165],
                [0.8932794, 0.8517839, 0.7152701, 0.8983801, 0.5905426],
            ],
            [
                [0.2869580, 0.4700376, 0.2743714, 0.8135023, 0.2229074],
                [0.9306560, 0.3734594, 0.4566821, 0.7599275, 0.7557513],
                [0.7415742, 0.6115875, 0.3317572, 0.0379378, 0.1315770],
                [0.8692724, 0.0809556, 0.7767404, 0.8742208, 0.1522012],
                [0.7708948, 0.4509611, 0.0481175, 0.2358997, 0.6900532],
            ],
        ],
        device=device,
        dtype=dtype,
    )
    # (3, 5, 5) -> (4, 3, 5, 5): a batch of four identical images.
    data = data.unsqueeze(0).repeat(4, 1, 1, 1)

    # Snapshot of the input; deliberately not an alias of ``data``.
    expected = data.clone()

    img_gray = RandomGrayscale(p=0.0)(data)
    assert_close(img_gray, expected, atol=1e-4, rtol=1e-4)
|
|
|
|
|
def test_random_grayscale_sequential_batch(self, device, dtype):
    """Two chained no-op (``p=0.0``) RandomGrayscale modules return the input and an identity transform."""
    f = nn.Sequential(RandomGrayscale(p=0.0, return_transform=True), RandomGrayscale(p=0.0, return_transform=True))

    img = torch.rand(2, 3, 5, 5, device=device, dtype=dtype)
    # Independent copy, so an in-place modification of ``img`` cannot hide a failure.
    expected = img.clone()
    # Built directly on ``device``; one identity 3x3 matrix per batch element.
    expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand((2, 3, 3))

    # Run the pipeline once and check both returned elements.
    out, transform = f(img)
    assert_close(out, expected, atol=1e-4, rtol=1e-4)
    assert_close(transform, expected_transform, atol=1e-4, rtol=1e-4)
|
|
|
|
|
def test_gradcheck(self, device, dtype):
    """RandomGrayscale is differentiable both when applied (p=1) and when skipped (p=0)."""
    img = utils.tensor_to_gradcheck_var(torch.rand((3, 5, 5), device=device, dtype=dtype))
    for prob in (1.0, 0.0):
        assert gradcheck(RandomGrayscale(p=prob), (img,), raise_exception=True)
|
|
|
|
|
|
|
|
class TestCenterCrop:
    """Shape, transform-matrix, crop-mode and gradient checks for ``CenterCrop``."""

    def test_no_transform(self, device, dtype):
        """Both cropping modes yield a 2x2 crop; the resample mode can be inverted back to 4x4."""
        img = torch.rand(1, 2, 4, 4, device=device, dtype=dtype)
        assert CenterCrop(2)(img).shape == (1, 2, 2, 2)

        resample_crop = CenterCrop(2, cropping_mode="resample")
        cropped = resample_crop(img)
        assert cropped.shape == (1, 2, 2, 2)
        assert resample_crop.inverse(cropped).shape == (1, 2, 4, 4)

    def test_transform(self, device, dtype):
        """With ``return_transform=True`` the result is an (image, 3x3 matrix) pair."""
        img = torch.rand(1, 2, 5, 4, device=device, dtype=dtype)
        result = CenterCrop(2, return_transform=True)(img)
        assert len(result) == 2
        cropped, matrix = result
        assert cropped.shape == (1, 2, 2, 2)
        assert matrix.shape == (1, 3, 3)

        resample_crop = CenterCrop(2, cropping_mode="resample", return_transform=True)
        result = resample_crop(img)
        assert result[0].shape == (1, 2, 2, 2)
        assert result[1].shape == (1, 3, 3)
        # ``inverse`` is given the whole (image, matrix) tuple.
        assert resample_crop.inverse(result).shape == (1, 2, 5, 4)

    def test_no_transform_tuple(self, device, dtype):
        """A (height, width) tuple size produces a rectangular crop."""
        img = torch.rand(1, 2, 5, 4, device=device, dtype=dtype)
        assert CenterCrop((3, 4))(img).shape == (1, 2, 3, 4)

        resample_crop = CenterCrop((3, 4), cropping_mode="resample")
        cropped = resample_crop(img)
        assert cropped.shape == (1, 2, 3, 4)
        assert resample_crop.inverse(cropped).shape == (1, 2, 5, 4)

    def test_crop_modes(self, device, dtype):
        """The slice mode reproduces the resample-mode output when given the same parameters."""
        torch.manual_seed(0)
        img = torch.rand(1, 3, 5, 5, device=device, dtype=dtype)

        resample_op = CenterCrop(size=(2, 2), cropping_mode='resample')
        out = resample_op(img)

        slice_op = CenterCrop(size=(2, 2), cropping_mode='slice')
        assert_close(out, slice_op(img, resample_op._params))

    def test_gradcheck(self, device, dtype):
        """CenterCrop is differentiable with respect to its input."""
        img = utils.tensor_to_gradcheck_var(torch.rand(1, 2, 3, 4, device=device, dtype=dtype))
        assert gradcheck(CenterCrop(3), (img,), raise_exception=True)
|
|
|
|
|
|
|
|
class TestRandomRotation:
    """Seeded golden-value, batching, chaining and gradient tests for ``RandomRotation``."""

    # NOTE(review): this seeds the global RNG once, when the class body executes
    # (i.e. at import time); every seeded test below re-seeds explicitly anyway.
    torch.manual_seed(0)

    # 4x4 input pattern shared by several tests (nested lists, so each test builds
    # its own tensor on the requested device/dtype).
    _PATTERN = [[1.0, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 1.0, 2.0]]

    # Golden values shared by the single-image and batched tests (seed 0).
    _FIRST_ROTATED = [
        [0.9824, 0.0088, 0.0000, 1.9649],
        [0.0000, 0.0029, 0.0000, 0.0176],
        [0.0029, 1.0000, 1.9883, 0.0000],
        [0.0000, 0.0088, 1.0117, 1.9649],
    ]
    _SECOND_ROTATED = [
        [0.1322, 0.0000, 0.7570, 0.2644],
        [0.3785, 0.0000, 0.4166, 0.0000],
        [0.0000, 0.6309, 1.5910, 1.2371],
        [0.0000, 0.1444, 0.3177, 0.6499],
    ]
    _FIRST_MATRIX = [[1.0000, -0.0059, 0.0088], [0.0059, 1.0000, -0.0088], [0.0000, 0.0000, 1.0000]]
    _SECOND_MATRIX = [[0.9125, 0.4090, -0.4823], [-0.4090, 0.9125, 0.7446], [0.0000, 0.0000, 1.0000]]

    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        """The repr lists the symmetric degree range and the sampling options."""
        aug = RandomRotation(degrees=45.5)
        expected_repr = (
            "RandomRotation(degrees=tensor([-45.5000, 45.5000]), interpolation=BILINEAR, p=0.5, "
            "p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(aug) == expected_repr

    def test_random_rotation(self, device, dtype):
        """Two consecutive seeded rotations of a single 4x4 image match the golden values."""
        torch.manual_seed(0)  # for reproducibility

        with_matrix = RandomRotation(degrees=45.0, return_transform=True, p=1.0)
        image_only = RandomRotation(degrees=45.0, p=1.0)

        img = torch.tensor(self._PATTERN, device=device, dtype=dtype)
        expected = torch.tensor([[self._FIRST_ROTATED]], device=device, dtype=dtype)
        expected_transform = torch.tensor([self._FIRST_MATRIX], device=device, dtype=dtype)
        expected_2 = torch.tensor([[self._SECOND_ROTATED]], device=device, dtype=dtype)

        # Call order matters: both modules draw from the same seeded RNG stream.
        out, mat = with_matrix(img)
        assert_close(out, expected, rtol=1e-6, atol=1e-4)
        assert_close(mat, expected_transform, rtol=1e-6, atol=1e-4)
        assert_close(image_only(img), expected_2, rtol=1e-6, atol=1e-4)

    def test_batch_random_rotation(self, device, dtype):
        """A seeded batch of two identical images gets one rotation per element."""
        torch.manual_seed(0)  # for reproducibility

        aug = RandomRotation(degrees=45.0, return_transform=True, p=1.0)

        img = torch.tensor([[self._PATTERN]], device=device, dtype=dtype)
        expected = torch.tensor([[self._FIRST_ROTATED], [self._SECOND_ROTATED]], device=device, dtype=dtype)
        expected_transform = torch.tensor([self._FIRST_MATRIX, self._SECOND_MATRIX], device=device, dtype=dtype)

        batch = img.repeat(2, 1, 1, 1)
        out, mat = aug(batch)
        assert_close(out, expected, rtol=1e-4, atol=1e-4)
        assert_close(mat, expected_transform, rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """``same_on_batch=True`` rotates every batch element identically."""
        aug = RandomRotation(degrees=40, same_on_batch=True)
        batch = torch.eye(6, device=device, dtype=dtype)[None, None].repeat(2, 3, 1, 1)
        rotated = aug(batch)
        assert (rotated[0] == rotated[1]).all()

    def test_sequential(self, device, dtype):
        """Chained rotations compose their transforms, including when the last stage returns no matrix."""
        torch.manual_seed(0)  # for reproducibility

        chained = nn.Sequential(
            RandomRotation(torch.tensor([-45.0, 90]), return_transform=True, p=1.0),
            RandomRotation(10.4, return_transform=True, p=1.0),
        )
        chained_partial = nn.Sequential(
            RandomRotation(torch.tensor([-45.0, 90]), return_transform=True, p=1.0), RandomRotation(10.4, p=1.0)
        )

        img = torch.tensor(self._PATTERN, device=device, dtype=dtype)
        expected = torch.tensor(
            [
                [
                    [
                        [0.1314, 0.1050, 0.6649, 0.2628],
                        [0.3234, 0.0202, 0.4256, 0.1671],
                        [0.0525, 0.5976, 1.5199, 1.1306],
                        [0.0000, 0.1453, 0.3224, 0.5796],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )
        expected_transform = torch.tensor(
            [[[0.8864, 0.4629, -0.5240], [-0.4629, 0.8864, 0.8647], [0.0000, 0.0000, 1.0000]]],
            device=device,
            dtype=dtype,
        )
        expected_transform_2 = torch.tensor(
            [[[0.8381, -0.5455, 1.0610], [0.5455, 0.8381, -0.5754], [0.0000, 0.0000, 1.0000]]],
            device=device,
            dtype=dtype,
        )

        # Call order matters: both pipelines draw from the same seeded RNG stream.
        out, mat = chained(img)
        _, mat_2 = chained_partial(img)
        assert_close(out, expected, rtol=1e-4, atol=1e-4)
        assert_close(mat, expected_transform, rtol=1e-4, atol=1e-4)
        assert_close(mat_2, expected_transform_2, rtol=1e-4, atol=1e-4)

    def test_gradcheck(self, device, dtype):
        """RandomRotation with a fixed 15 degree angle is differentiable."""
        torch.manual_seed(0)  # for reproducibility

        img = utils.tensor_to_gradcheck_var(torch.rand((3, 3), device=device, dtype=dtype))
        assert gradcheck(RandomRotation(degrees=(15.0, 15.0), p=1.0), (img,), raise_exception=True)
|
|
|
|
|
|
|
|
class TestRandomCrop:
    """Seeded golden-value, padding, inverse, crop-mode and scripting tests for ``RandomCrop``."""

    @staticmethod
    def _grid(device, dtype):
        """Return the (1, 3, 3) test image whose pixels are 0..8 in row-major order."""
        return torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        """The repr lists every constructor option."""
        f = RandomCrop(size=(2, 3), padding=(0, 1), fill=10, pad_if_needed=False, p=1.0)
        expected_repr = (
            "RandomCrop(crop_size=(2, 3), padding=(0, 1), fill=10, pad_if_needed=False, padding_mode=constant, "
            "resample=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(f) == expected_repr

    def test_no_padding(self, device, dtype):
        """Without padding a 2x3 window is cut from the 3x3 input; 2-D and 4-D inputs agree."""
        torch.manual_seed(0)
        inp = self._grid(device, dtype)[None]
        expected = torch.tensor([[[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        rc = RandomCrop(size=(2, 3), padding=None, align_corners=True, p=1.0)
        out = rc(inp)

        torch.manual_seed(0)
        out2 = rc(inp.squeeze())

        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(out2, expected, atol=1e-4, rtol=1e-4)

        # The resample mode must agree and its inverse zero-fills the cropped-away area.
        torch.manual_seed(0)
        inversed = torch.tensor([[[[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        aug = RandomCrop(size=(2, 3), padding=None, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_no_padding_batch(self, device, dtype):
        """Each batch element gets its own crop window."""
        torch.manual_seed(42)
        batch_size = 2
        inp = self._grid(device, dtype).repeat(batch_size, 1, 1, 1)
        expected = torch.tensor(
            [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]], [[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), padding=None, align_corners=True, p=1.0)
        out = rc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(42)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [0.0, 0.0, 0.0]]],
                [[[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomCrop(size=(2, 3), padding=None, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """``same_on_batch=True`` crops every batch element identically."""
        f = RandomCrop(size=(2, 3), padding=1, same_on_batch=True, align_corners=True, p=1.0)
        batch = torch.eye(3, device=device, dtype=dtype)[None, None].repeat(2, 3, 1, 1)
        res = f(batch)
        assert (res[0] == res[1]).all()

    def test_padding(self, device, dtype):
        """Reflect padding by one pixel before cropping."""
        torch.manual_seed(42)
        inp = self._grid(device, dtype)[None]
        expected = torch.tensor([[[[7.0, 8.0, 7.0], [4.0, 5.0, 4.0]]]], device=device, dtype=dtype)
        rc = RandomCrop(size=(2, 3), padding=1, padding_mode='reflect', align_corners=True, p=1.0)
        out = rc(inp)

        torch.manual_seed(42)
        out2 = rc(inp.squeeze())

        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(out2, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(42)
        inversed = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        aug = RandomCrop(
            size=(2, 3), padding=1, padding_mode='reflect', align_corners=True, p=1.0, cropping_mode="resample"
        )
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_padding_batch_1(self, device, dtype):
        """Constant (zero) padding by one pixel on a batch."""
        torch.manual_seed(42)
        batch_size = 2
        inp = self._grid(device, dtype).repeat(batch_size, 1, 1, 1)
        expected = torch.tensor(
            [[[[1.0, 2.0, 0.0], [4.0, 5.0, 0.0]]], [[[7.0, 8.0, 0.0], [0.0, 0.0, 0.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), padding=1, align_corners=True, p=1.0)
        out = rc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(42)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 0.0, 0.0]]],
                [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 7.0, 8.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomCrop(size=(2, 3), padding=1, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_padding_batch_2(self, device, dtype):
        """Two-element padding tuple with a custom fill value."""
        torch.manual_seed(42)
        batch_size = 2
        inp = self._grid(device, dtype).repeat(batch_size, 1, 1, 1)
        expected = torch.tensor(
            [[[[1.0, 2.0, 10.0], [4.0, 5.0, 10.0]]], [[[4.0, 5.0, 10.0], [7.0, 8.0, 10.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), padding=(0, 1), fill=10, align_corners=True, p=1.0)
        out = rc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(42)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 0.0, 0.0]]],
                [[[0.0, 0.0, 0.0], [0.0, 4.0, 5.0], [0.0, 7.0, 8.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomCrop(size=(2, 3), padding=(0, 1), fill=10, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_padding_batch_3(self, device, dtype):
        """Four-element padding tuple with a custom fill value."""
        torch.manual_seed(0)
        batch_size = 2
        inp = self._grid(device, dtype).repeat(batch_size, 1, 1, 1)
        expected = torch.tensor(
            [[[[8.0, 8.0, 8.0], [8.0, 0.0, 1.0]]], [[[8.0, 8.0, 8.0], [1.0, 2.0, 8.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), padding=(0, 1, 2, 3), fill=8, align_corners=True, p=1.0)
        out = rc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
                [[[0.0, 1.0, 2.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomCrop(size=(2, 3), padding=(0, 1, 2, 3), fill=8, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_padding_no_forward(self, device, dtype):
        """With ``p=0.0`` the image (and any incoming transform) pass through untouched."""
        torch.manual_seed(0)
        inp = torch.tensor([[[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        trans = torch.tensor([[[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

        # Without return_transform.
        rc = RandomCrop(size=(2, 3), padding=(0, 1, 2, 3), fill=9, align_corners=True, p=0.0)
        out = rc(inp)
        assert_close(out, inp, atol=1e-4, rtol=1e-4)

        out = rc((inp, trans))
        assert_close(out[0], inp, atol=1e-4, rtol=1e-4)
        assert_close(out[1], trans, atol=1e-4, rtol=1e-4)

        # With return_transform.
        rc = RandomCrop(size=(2, 3), padding=(0, 1, 2, 3), fill=9, align_corners=True, p=0.0, return_transform=True)
        out = rc(inp)
        assert_close(out[0], inp, atol=1e-4, rtol=1e-4)

        out = rc((inp, trans))
        assert_close(out[0], inp, atol=1e-4, rtol=1e-4)
        assert_close(out[1], trans, atol=1e-4, rtol=1e-4)

    def test_pad_if_needed(self, device, dtype):
        """An input smaller than the crop is padded with ``fill`` before cropping."""
        torch.manual_seed(0)
        batch_size = 2
        inp = torch.tensor([[[0.0, 1.0, 2.0]]], device=device, dtype=dtype).repeat(batch_size, 1, 1, 1)
        expected = torch.tensor(
            [[[[9.0, 9.0, 9.0], [0.0, 1.0, 2.0]]], [[[9.0, 9.0, 9.0], [0.0, 1.0, 2.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), pad_if_needed=True, fill=9, align_corners=True, p=1.0)
        out = rc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor([[[[0.0, 1.0, 2.0]]], [[[0.0, 1.0, 2.0]]]], device=device, dtype=dtype)
        aug = RandomCrop(size=(2, 3), pad_if_needed=True, fill=9, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_crop_modes(self, device, dtype):
        """The slice mode reproduces the resample-mode output when given the same parameters."""
        torch.manual_seed(0)
        img = self._grid(device, dtype)

        resample_op = RandomCrop(size=(2, 2), cropping_mode='resample')
        out = resample_op(img)

        slice_op = RandomCrop(size=(2, 2), cropping_mode='slice')
        assert_close(out, slice_op(img, resample_op._params))

    def test_gradcheck(self, device, dtype):
        """RandomCrop is differentiable with respect to its input."""
        torch.manual_seed(0)  # for reproducibility
        inp = utils.tensor_to_gradcheck_var(torch.rand((3, 3, 3), device=device, dtype=dtype))
        assert gradcheck(RandomCrop(size=(3, 3), p=1.0), (inp,), raise_exception=True)

    @pytest.mark.skip("Need to fix Union type")
    def test_jit(self, device, dtype):
        """A scripted RandomCrop forward crops a 3x3 window out of a constant image."""
        # Make sure that the forward is scriptable.
        op = RandomCrop(size=(3, 3), p=1.0).forward
        op_script = torch.jit.script(op)
        img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)

        actual = op_script(img)
        # The input is constant, so any 3x3 crop is a 3x3 block of ones. The volumetric
        # ``center_crop3d`` (which also requires an explicit ``size``) does not apply here.
        expected = torch.ones(1, 1, 3, 3, device=device, dtype=dtype)
        assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    @pytest.mark.skip("Need to fix Union type")
    def test_jit_trace(self, device, dtype):
        """Tracing the scripted forward gives the same result as the eager forward on a constant image."""
        # Script the forward.
        op = RandomCrop(size=(3, 3), p=1.0).forward
        op_script = torch.jit.script(op)

        img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)

        # Trace the scripted op, then run it on the same example input.
        op_trace = torch.jit.trace(op_script, (img,))

        actual = op_trace(img)
        expected = op(img)
        assert_close(actual, expected, atol=1e-4, rtol=1e-4)
|
|
|
|
|
|
|
|
class TestRandomResizedCrop:
    """Seeded golden-value, inverse, crop-mode and gradient tests for ``RandomResizedCrop``."""

    @staticmethod
    def _grid(device, dtype):
        """Return the (1, 3, 3) test image whose pixels are 0..8 in row-major order."""
        return torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        """The repr lists size, scale, ratio and the sampling options."""
        rrc = RandomResizedCrop(size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0))
        expected_repr = (
            "RandomResizedCrop(size=(2, 3), scale=tensor([1., 1.]), ratio=tensor([1., 1.]), "
            "interpolation=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(rrc) == expected_repr

    def test_no_resize(self, device, dtype):
        """Seed-0 golden output; the resample mode inverts back to the full input."""
        torch.manual_seed(0)
        inp = self._grid(device, dtype)
        expected = torch.tensor([[[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)

        rrc = RandomResizedCrop(size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0))
        assert_close(rrc(inp), expected, rtol=1e-4, atol=1e-4)

        torch.manual_seed(0)
        aug = RandomResizedCrop(size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0), cropping_mode="resample")
        cropped = aug(inp)
        assert_close(cropped, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(cropped), inp[None], atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """``same_on_batch=True`` yields identical crops (and inverses) for identical inputs."""
        rrc = RandomResizedCrop(size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0), same_on_batch=True)
        batch = self._grid(device, dtype).repeat(2, 1, 1, 1)
        cropped = rrc(batch)
        assert (cropped[0] == cropped[1]).all()

        torch.manual_seed(0)
        aug = RandomResizedCrop(
            size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0), same_on_batch=True, cropping_mode="resample"
        )
        restored = aug.inverse(aug(batch))
        assert (restored[0] == restored[1]).all()

    def test_crop_scale_ratio(self, device, dtype):
        """Seed-0 golden output for a crop with scale 3 and aspect ratio 2."""
        torch.manual_seed(0)  # for reproducibility
        inp = self._grid(device, dtype)
        expected = torch.tensor(
            [[[[1.0, 1.5, 2.0], [4.0, 4.5, 5.0], [7.0, 7.5, 8.0]]]],
            device=device,
            dtype=dtype,
        )
        rrc = RandomResizedCrop(size=(3, 3), scale=(3.0, 3.0), ratio=(2.0, 2.0))
        assert_close(rrc(inp), expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor([[[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        aug = RandomResizedCrop(size=(3, 3), scale=(3.0, 3.0), ratio=(2.0, 2.0), cropping_mode="resample")
        cropped = aug(inp)
        assert_close(cropped, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(cropped), inversed, atol=1e-4, rtol=1e-4)

    def test_crop_size_greater_than_input(self, device, dtype):
        """An output size larger than the input upsamples the crop to 4x4."""
        torch.manual_seed(0)  # for reproducibility
        inp = self._grid(device, dtype)
        exp = torch.tensor(
            [
                [
                    [
                        [1.0000, 1.3333, 1.6667, 2.0000],
                        [3.0000, 3.3333, 3.6667, 4.0000],
                        [5.0000, 5.3333, 5.6667, 6.0000],
                        [7.0000, 7.3333, 7.6667, 8.0000],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )
        rrc = RandomResizedCrop(size=(4, 4), scale=(3.0, 3.0), ratio=(2.0, 2.0))
        upsampled = rrc(inp)
        assert upsampled.shape == torch.Size([1, 1, 4, 4])
        assert_close(upsampled, exp, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor([[[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        aug = RandomResizedCrop(size=(4, 4), scale=(3.0, 3.0), ratio=(2.0, 2.0), cropping_mode="resample")
        upsampled = aug(inp)
        assert_close(upsampled, exp, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(upsampled), inversed, atol=1e-4, rtol=1e-4)

    def test_crop_scale_ratio_batch(self, device, dtype):
        """Each batch element receives its own crop window."""
        torch.manual_seed(0)
        batch_size = 2
        inp = self._grid(device, dtype).repeat(batch_size, 1, 1, 1)
        expected = torch.tensor(
            [
                [[[1.0, 1.5, 2.0], [4.0, 4.5, 5.0], [7.0, 7.5, 8.0]]],
                [[[0.0, 0.5, 1.0], [3.0, 3.5, 4.0], [6.0, 6.5, 7.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        rrc = RandomResizedCrop(size=(3, 3), scale=(3.0, 3.0), ratio=(2.0, 2.0))
        assert_close(rrc(inp), expected, rtol=1e-4, atol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 7.0, 8.0]]],
                [[[0.0, 1.0, 0.0], [3.0, 4.0, 0.0], [6.0, 7.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomResizedCrop(size=(3, 3), scale=(3.0, 3.0), ratio=(2.0, 2.0), cropping_mode="resample")
        cropped = aug(inp)
        assert_close(cropped, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(cropped), inversed, atol=1e-4, rtol=1e-4)

    def test_crop_modes(self, device, dtype):
        """The slice mode reproduces the resample-mode output when given the same parameters."""
        torch.manual_seed(0)
        img = self._grid(device, dtype)

        resample_op = RandomResizedCrop(size=(4, 4), cropping_mode='resample')
        out = resample_op(img)

        slice_op = RandomResizedCrop(size=(4, 4), cropping_mode='slice')
        assert_close(out, slice_op(img, resample_op._params))

    def test_gradcheck(self, device, dtype):
        """RandomResizedCrop is differentiable with respect to its input."""
        torch.manual_seed(0)  # for reproducibility
        inp = utils.tensor_to_gradcheck_var(torch.rand((1, 3, 3), device=device, dtype=dtype))
        aug = RandomResizedCrop(size=(3, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0))
        assert gradcheck(aug, (inp,), raise_exception=True)
|
|
|
|
|
|
|
|
class TestRandomEqualize:
    """Tests for ``RandomEqualize`` using horizontal-ramp images built by ``build_input``."""

    # Per-column values the tests expect after equalizing the default 20-pixel ramp.
    _EQUALIZED_ROW = [
        0.0000, 0.07843, 0.15686, 0.2353, 0.3137, 0.3922, 0.4706, 0.5490, 0.6275, 0.7059,
        0.7843, 0.8627, 0.9412, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
    ]

    @pytest.mark.xfail(reason="might fail under windows OS due to printing precision.")
    def test_smoke(self, device, dtype):
        """The repr only reports the sampling options."""
        aug = RandomEqualize(p=0.5)
        expected_repr = "RandomEqualize(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
        assert str(aug) == expected_repr

    def test_random_equalize(self, device, dtype):
        """Single-image equalization."""
        self._check_equalize(1, device, dtype)

    def test_batch_random_equalize(self, device, dtype):
        """Batched equalization."""
        self._check_equalize(2, device, dtype)

    def _check_equalize(self, bs, device, dtype):
        """Check applied (p=1) and skipped (p=0) equalization, with and without the returned transform."""
        on_with_mat = RandomEqualize(p=1.0, return_transform=True)
        off_with_mat = RandomEqualize(p=0.0, return_transform=True)
        on_plain = RandomEqualize(p=1.0)
        off_plain = RandomEqualize(p=0.0)

        channels, height, width = 3, 20, 20
        inputs = self.build_input(channels, height, width, bs, device=device, dtype=dtype)
        expected = self.build_input(
            channels, height, width, bs, row=torch.tensor(self._EQUALIZED_ROW), device=device, dtype=dtype
        )
        identity = kornia.eye_like(3, expected)

        assert_close(on_with_mat(inputs)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(on_with_mat(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(off_with_mat(inputs)[0], inputs, rtol=1e-4, atol=1e-4)
        assert_close(off_with_mat(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(on_plain(inputs), expected, rtol=1e-4, atol=1e-4)
        assert_close(off_plain(inputs), inputs, rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """``same_on_batch=True`` treats identical batch elements identically."""
        aug = RandomEqualize(p=0.5, same_on_batch=True)
        batch = torch.eye(4, device=device, dtype=dtype)[None, None].repeat(2, 1, 1, 1)
        res = aug(batch)
        assert (res[0] == res[1]).all()

    def test_gradcheck(self, device, dtype):
        """RandomEqualize passes gradcheck."""
        torch.manual_seed(0)  # for reproducibility

        img = utils.tensor_to_gradcheck_var(torch.rand((3, 3, 3), device=device, dtype=dtype))
        assert gradcheck(RandomEqualize(p=0.5), (img,), raise_exception=True)

    @staticmethod
    def build_input(channels, height, width, bs=1, row=None, device='cpu', dtype=torch.float32):
        """Build a ``(bs, channels, height, len(row))`` image in which every row equals ``row``.

        When ``row`` is omitted it defaults to the ramp ``arange(width) / width``.
        """
        if row is None:
            row = torch.arange(width, device=device, dtype=dtype) / float(width)
        # Tile the 1-D row over height, channels and batch (a new contiguous tensor).
        tiled = row.repeat(bs, channels, height, 1)
        return tiled.to(device, dtype)
|
|
|
|
|
|
|
|
class TestGaussianBlur:
    """Tests for ``RandomGaussianBlur``."""

    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        """The repr only reports the sampling options."""
        blur = RandomGaussianBlur((3, 3), (0.1, 2.0), p=1.0)
        expected_repr = "RandomGaussianBlur(p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)"
        assert str(blur) == expected_repr
|
|
|
|
|
|
|
|
class TestRandomInvert:
    """Tests for ``RandomInvert``."""

    def test_smoke(self, device, dtype):
        """Always inverting (p=1) an all-ones image yields an all-zeros image."""
        ones = torch.ones(1, 3, 4, 5, device=device, dtype=dtype)
        inverted = RandomInvert(p=1.0)(ones)
        assert_close(inverted, torch.zeros_like(ones))
|
|
|
|
|
|
|
|
class TestRandomChannelShuffle:
    """Tests for ``RandomChannelShuffle``."""

    def test_smoke(self, device, dtype):
        """With seed 0 the three channels come out in the order (2, 0, 1)."""
        torch.manual_seed(0)
        img = torch.arange(12, device=device, dtype=dtype).view(1, 3, 2, 2)

        out_expected = torch.tensor(
            [[[[8.0, 9.0], [10.0, 11.0]], [[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]],
            device=device,
            dtype=dtype,
        )

        shuffled = RandomChannelShuffle(p=1.0)(img)
        assert_close(shuffled, out_expected)
|
|
|
|
|
|
|
|
class TestRandomGaussianNoise:
    """Tests for ``RandomGaussianNoise``."""

    def test_smoke(self, device, dtype):
        """Adding noise preserves the input shape."""
        torch.manual_seed(0)
        img = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        noisy = RandomGaussianNoise(p=1.0)(img)
        assert noisy.shape == img.shape
|
|
|
|
|
|
|
|
class TestNormalize: |
|
|
|
|
|
|
|
|
@pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
def test_smoke(self, device, dtype):
    """The repr lists mean, std and the sampling options."""
    normalize = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([1.0]))
    expected_repr = (
        "Normalize(mean=torch.tensor([1.]), std=torch.tensor([1.]), p=1., p_batch=1.0, "
        "same_on_batch=False, return_transform=False)"
    )
    assert str(normalize) == expected_repr
|
|
|
|
|
@staticmethod
@pytest.mark.parametrize(
    "mean, std", [((1.0, 1.0, 1.0), (0.5, 0.5, 0.5)), (1.0, 0.5), (torch.tensor([1.0]), torch.tensor([0.5]))]
)
def test_random_normalize_different_parameter_types(mean, std):
    """Normalize accepts tuples, plain floats and tensors for ``mean``/``std``."""
    f = Normalize(mean=mean, std=std, p=1)
    data = torch.ones(2, 3, 256, 313)
    # Every parametrization uses the same value for all channels, so the scalar
    # itself (or the first element of a sequence/tensor) defines the expectation.
    ref_mean, ref_std = (mean, std) if isinstance(mean, float) else (mean[0], std[0])
    expected = (data - torch.as_tensor(ref_mean)) / torch.as_tensor(ref_std)
    assert_close(f(data), expected)
|
|
|
|
|
@staticmethod |
|
|
@pytest.mark.parametrize("mean, std", [((1.0, 1.0, 1.0, 1.0), (0.5, 0.5, 0.5, 0.5)), ((1.0, 1.0), (0.5, 0.5))]) |
|
|
def test_random_normalize_invalid_parameter_shape(mean, std): |
|
|
f = Normalize(mean=mean, std=std, p=1.0, return_transform=True) |
|
|
inputs = torch.arange(0.0, 16.0, step=1).reshape(1, 4, 4).unsqueeze(0) |
|
|
with pytest.raises(ValueError): |
|
|
f(inputs) |
|
|
|
|
|
def test_random_normalize(self, device, dtype): |
|
|
f = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0, return_transform=True) |
|
|
f1 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0, return_transform=True) |
|
|
f2 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0) |
|
|
f3 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0) |
|
|
|
|
|
inputs = torch.arange(0.0, 16.0, step=1, device=device, dtype=dtype).reshape(1, 4, 4).unsqueeze(0) |
|
|
|
|
|
expected = (inputs - 1) * 2 |
|
|
|
|
|
identity = kornia.eye_like(3, expected) |
|
|
|
|
|
assert_close(f(inputs)[0], expected, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f(inputs)[1], identity, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f1(inputs)[0], inputs, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f1(inputs)[1], identity, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4) |
|
|
|
|
|
def test_batch_random_normalize(self, device, dtype): |
|
|
f = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0, return_transform=True) |
|
|
f1 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0, return_transform=True) |
|
|
f2 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0) |
|
|
f3 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0) |
|
|
|
|
|
inputs = torch.arange(0.0, 16.0 * 2, step=1, device=device, dtype=dtype).reshape(2, 1, 4, 4) |
|
|
|
|
|
expected = (inputs - 1) * 2 |
|
|
|
|
|
identity = kornia.eye_like(3, expected) |
|
|
|
|
|
assert_close(f(inputs)[0], expected, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f(inputs)[1], identity, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f1(inputs)[0], inputs, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f1(inputs)[1], identity, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4) |
|
|
assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4) |
|
|
|
|
|
def test_gradcheck(self, device, dtype): |
|
|
|
|
|
torch.manual_seed(0) |
|
|
|
|
|
input = torch.rand((3, 3, 3), device=device, dtype=dtype) |
|
|
input = utils.tensor_to_gradcheck_var(input) |
|
|
assert gradcheck( |
|
|
Normalize(mean=torch.tensor([1.0]), std=torch.tensor([1.0]), p=1.0), (input,), raise_exception=True |
|
|
) |
|
|
|
|
|
|
|
|
class TestDenormalize:
    """Tests for :class:`Denormalize`, which is expected to map ``x`` to ``x * std + mean``."""

    # The exact float formatting in the repr may differ between platforms.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing precision.")
    def test_smoke(self, device, dtype):
        f = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([1.0]))
        expected_repr = (
            "Denormalize(mean=torch.tensor([1.]), std=torch.tensor([1.]), p=1., p_batch=1.0, "
            "same_on_batch=False, return_transform=False)"
        )
        assert str(f) == expected_repr

    @staticmethod
    def _check_denormalize(inputs):
        """Check ``Denormalize(mean=1, std=0.5)`` on ``inputs`` for p=1 / p=0, with and without transform.

        With p=1.0 the output must equal ``inputs / 2 + 1``. With p=0.0 it must
        equal ``inputs`` unchanged. When ``return_transform=True`` the second
        element must match the 3x3 identity built by ``kornia.eye_like``.
        """
        f = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0, return_transform=True)
        f1 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0, return_transform=True)
        f2 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0)
        f3 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0)

        # x * std + mean with mean=1, std=0.5.
        expected = inputs / 2 + 1
        # Denormalization does not move pixels, so the reported transform is the identity.
        identity = kornia.eye_like(3, expected)

        # Run each augmentation once and reuse the result for all of its assertions.
        applied = f(inputs)
        skipped = f1(inputs)
        assert_close(applied[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(applied[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(skipped[0], inputs, rtol=1e-4, atol=1e-4)
        assert_close(skipped[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4)

    def test_random_denormalize(self, device, dtype):
        inputs = torch.arange(0.0, 16.0, step=1, device=device, dtype=dtype).reshape(1, 4, 4).unsqueeze(0)
        self._check_denormalize(inputs)

    def test_batch_random_denormalize(self, device, dtype):
        inputs = torch.arange(0.0, 16.0 * 2, step=1, device=device, dtype=dtype).reshape(2, 1, 4, 4)
        self._check_denormalize(inputs)

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)
        # Prepare the random image for gradcheck (see kornia.testing.tensor_to_gradcheck_var).
        img = utils.tensor_to_gradcheck_var(torch.rand((3, 3, 3), device=device, dtype=dtype))
        assert gradcheck(
            Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([1.0]), p=1.0), (img,), raise_exception=True
        )
|
|
|
|
|
|
|
|
class TestRandomFisheye:
    """Checks for :class:`RandomFisheye`."""

    def test_smoke(self, device, dtype):
        torch.manual_seed(0)
        # Two-element tensors that appear to bound the sampled centre/gamma values
        # (presumably low/high of a uniform range — confirm against RandomFisheye).
        center_x, center_y = torch.tensor([-0.3, 0.3]), torch.tensor([-0.3, 0.3])
        gamma = torch.tensor([-1.0, 1.0])
        img = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        warped = RandomFisheye(center_x, center_y, gamma, p=1.0)(img)
        assert warped.shape == img.shape

    @pytest.mark.skip(reason="RuntimeError: Jacobian mismatch for output 0 with respect to input 0")
    def test_gradcheck(self, device, dtype):
        img = utils.tensor_to_gradcheck_var(torch.rand(1, 1, 3, 3, device=device, dtype=dtype))
        center_x, center_y, gamma = (
            utils.tensor_to_gradcheck_var(torch.tensor(bounds, device=device, dtype=dtype))
            for bounds in ([-0.3, 0.3], [-0.3, 0.3], [-1.0, 1.0])
        )
        assert gradcheck(RandomFisheye(center_x, center_y, gamma), (img,), raise_exception=True)
|
|
|
|
|
|
|
|
class TestRandomElasticTransform:
    """Checks for :class:`RandomElasticTransform`."""

    def test_smoke(self, device, dtype):
        img = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        warped = RandomElasticTransform(p=1.0)(img)
        assert warped.shape == img.shape

    def test_same_on_batch(self, device, dtype):
        f = RandomElasticTransform(p=1.0, same_on_batch=True)
        # Two identical 3x3 identity images stacked on the batch axis: (2, 1, 3, 3).
        batch = torch.eye(3, device=device, dtype=dtype)[None, None].repeat(2, 1, 1, 1)
        res = f(batch)
        # same_on_batch=True must warp both identical samples identically.
        assert (res[0] == res[1]).all()
|
|
|
|
|
|
|
|
class TestRandomThinPlateSpline:
    """Shape-preservation check for :class:`RandomThinPlateSpline`."""

    def test_smoke(self, device, dtype):
        img = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        warped = RandomThinPlateSpline(p=1.0)(img)
        assert warped.shape == img.shape
|
|
|
|
|
|
|
|
class TestRandomBoxBlur:
    """Shape-preservation check for :class:`RandomBoxBlur`."""

    def test_smoke(self, device, dtype):
        img = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        blurred = RandomBoxBlur(p=1.0)(img)
        assert blurred.shape == img.shape
|
|
|
|
|
|
|
|
class TestPadTo:
    """Round-trip check for :class:`PadTo`."""

    def test_smoke(self, device, dtype):
        img = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        pad = PadTo(size=(4, 5))
        padded = pad(img)
        # The spatial size must grow to the requested (4, 5).
        assert padded.shape == (1, 1, 4, 5)
        # Undoing the padding must recover the original pixels exactly.
        restored = pad.inverse(padded)
        assert (restored == img).all()
|
|
|