# compvis/test/augmentation/test_augmentation.py
# NOTE: the following lines were Hugging Face file-viewer metadata captured with
# the file ("Dexter's picture", "Upload folder using huggingface_hub",
# commit 36c95ba (verified), raw / history / blame links, size 117 kB).
# They are preserved here as comments so the module remains valid Python.
from typing import Any, Dict, Optional, Tuple, Type
import pytest
import torch
import torch.nn as nn
from torch.autograd import gradcheck
import kornia
import kornia.testing as utils # test utils
from kornia.augmentation import (
CenterCrop,
ColorJitter,
Denormalize,
Normalize,
PadTo,
RandomBoxBlur,
RandomChannelShuffle,
RandomCrop,
RandomElasticTransform,
RandomEqualize,
RandomErasing,
RandomFisheye,
RandomGaussianBlur,
RandomGaussianNoise,
RandomGrayscale,
RandomHorizontalFlip,
RandomInvert,
RandomResizedCrop,
RandomRotation,
RandomThinPlateSpline,
RandomVerticalFlip,
)
from kornia.augmentation.base import AugmentationBase2D
from kornia.constants import pi, Resample
from kornia.geometry import transform_points
from kornia.testing import assert_close, BaseTester, default_with_one_parameter_changed
from kornia.utils import create_meshgrid
from kornia.utils.helpers import _torch_inverse_cast
# TODO same_on_batch tests?
@pytest.mark.usefixtures("device", "dtype")
class CommonTests(BaseTester):
    """Shared test-case implementations for 2D augmentation classes.

    Subclasses must set ``_augmentation_cls`` and ``_default_param_set``,
    override the ``param_set`` fixture, and implement the test stubs that
    raise ``NotImplementedError`` below.
    """
    # Names of pytest fixtures copied onto the instance by auto_injector_fixture.
    fixture_names = ("device", "dtype")
    ############################################################################################################
    # Attribute variables to set
    ############################################################################################################
    # The augmentation class under test (must subclass AugmentationBase2D).
    _augmentation_cls: Optional[Type[AugmentationBase2D]] = None
    # Default constructor kwargs used by tests that do not supply their own.
    _default_param_set: Dict["str", Any] = {}
    ############################################################################################################
    # Fixtures
    ############################################################################################################
    @pytest.fixture(autouse=True)
    def auto_injector_fixture(self, request):
        # Expose each declared fixture as an instance attribute
        # (self.device / self.dtype) for the helper methods below.
        for fixture_name in self.fixture_names:
            setattr(self, fixture_name, request.getfixturevalue(fixture_name))
    @pytest.fixture(scope="class")
    def param_set(self, request):
        # Subclasses parametrize this fixture with dicts of constructor kwargs.
        raise NotImplementedError("param_set must be overridden in subclasses")
    ############################################################################################################
    # Test cases
    ############################################################################################################
    def test_smoke(self, param_set):
        # Runs all smoke variants for every parametrized kwargs dict.
        self._test_smoke_implementation(params=param_set)
        self._test_smoke_call_implementation(params=param_set)
        self._test_smoke_return_transform_implementation(params=param_set)
    @pytest.mark.parametrize(
        "input_shape,expected_output_shape",
        [((4, 5), (1, 1, 4, 5)), ((3, 4, 5), (1, 3, 4, 5)), ((2, 3, 4, 5), (2, 3, 4, 5))],
    )
    def test_cardinality(self, input_shape, expected_output_shape):
        self._test_cardinality_implementation(
            input_shape=input_shape, expected_output_shape=expected_output_shape, params=self._default_param_set
        )
    def test_random_p_0(self):
        self._test_random_p_0_implementation(params=self._default_param_set)
    def test_random_p_0_return_transform(self):
        self._test_random_p_0_return_transform_implementation(params=self._default_param_set)
    def test_random_p_1(self):
        # Override with augmentation-specific inputs and expected outputs.
        raise NotImplementedError("Implement a stupid routine.")
    def test_random_p_1_return_transform(self):
        # Override with augmentation-specific expected output and transform.
        raise NotImplementedError("Implement a stupid routine.")
    def test_inverse_coordinate_check(self):
        self._test_inverse_coordinate_check_implementation(params=self._default_param_set)
    def test_exception(self):
        # Override to validate constructor / input error handling.
        raise NotImplementedError("Implement a stupid routine.")
    def test_batch(self):
        # Override with a batched-input check.
        raise NotImplementedError("Implement a stupid routine.")
    @pytest.mark.skip(reason="turn off all jit for a while")
    def test_jit(self):
        raise NotImplementedError("Implement a stupid routine.")
    def test_module(self):
        self._test_module_implementation(params=self._default_param_set)
    def test_gradcheck(self):
        self._test_gradcheck_implementation(params=self._default_param_set)
    # TODO Implement
    # test_batch
    # test_batch_return_transform
    # test_coordinate check
    # test_jit
    # test_gradcheck
    def _create_augmentation_from_params(self, **params):
        # Instantiate the augmentation under test with the given kwargs.
        return self._augmentation_cls(**params)
    ############################################################################################################
    # Test case implementations
    ############################################################################################################
    def _test_smoke_implementation(self, params):
        """Check class hierarchy and the generate/compute/apply call chain."""
        assert issubclass(
            self._augmentation_cls, AugmentationBase2D
        ), f"{self._augmentation_cls} is not a subclass of AugmentationBase2D"
        # Can be instantiated
        augmentation = self._create_augmentation_from_params(**params, return_transform=False)
        assert issubclass(
            type(augmentation), AugmentationBase2D
        ), f"{type(augmentation)} is not a subclass of AugmentationBase2D"
        # generate_parameters can be called and returns the correct amount of parameters
        batch_shape = (4, 3, 5, 6)
        generated_params = augmentation.generate_parameters(batch_shape)
        assert isinstance(generated_params, dict)
        # compute_transformation can be called and returns the correct shaped transformation matrix
        expected_transformation_shape = torch.Size((batch_shape[0], 3, 3))
        test_input = torch.ones(batch_shape, device=self.device, dtype=self.dtype)
        transformation = augmentation.compute_transformation(test_input, generated_params)
        assert transformation.shape == expected_transformation_shape
        # apply_transform can be called and returns the correct batch sized output
        output = augmentation.apply_transform(test_input, generated_params, transformation)
        assert output.shape[0] == batch_shape[0]
    def _test_smoke_call_implementation(self, params):
        """Check __call__ overloads (plain, return_transform, (input, transform) tuple)."""
        batch_shape = (4, 3, 5, 6)
        expected_transformation_shape = torch.Size((batch_shape[0], 3, 3))
        test_input = torch.ones(batch_shape, device=self.device, dtype=self.dtype)
        augmentation = self._create_augmentation_from_params(**params, return_transform=False)
        generated_params = augmentation.generate_parameters(batch_shape)
        test_transform = torch.rand(expected_transformation_shape, device=self.device, dtype=self.dtype)
        output = augmentation(test_input, params=generated_params)
        assert output.shape[0] == batch_shape[0]
        output, transformation = augmentation(test_input, params=generated_params, return_transform=True)
        assert output.shape[0] == batch_shape[0]
        assert transformation.shape == expected_transformation_shape
        # Passing a (tensor, transform) tuple must compose the new transform
        # on top of the provided one.
        output, final_transformation = augmentation(
            (test_input, test_transform), params=generated_params, return_transform=True
        )
        assert output.shape[0] == batch_shape[0]
        assert final_transformation.shape == expected_transformation_shape
        assert_close(final_transformation, transformation @ test_transform, atol=1e-4, rtol=1e-4)
        # Without return_transform=True the provided transform is passed through unchanged.
        output, transformation = augmentation((test_input, test_transform), params=generated_params)
        assert output.shape[0] == batch_shape[0]
        assert transformation.shape == expected_transformation_shape
        assert (transformation == test_transform).all()
    def _test_smoke_return_transform_implementation(self, params):
        """Same as the call smoke test, but with return_transform=True at construction."""
        batch_shape = (4, 3, 5, 6)
        expected_transformation_shape = torch.Size((batch_shape[0], 3, 3))
        test_input = torch.ones(batch_shape, device=self.device, dtype=self.dtype)
        augmentation = self._create_augmentation_from_params(**params, return_transform=True)
        generated_params = augmentation.generate_parameters(batch_shape)
        test_transform = torch.rand(expected_transformation_shape, device=self.device, dtype=self.dtype)
        output, transformation = augmentation(test_input, params=generated_params)
        assert output.shape[0] == batch_shape[0]
        assert transformation.shape == expected_transformation_shape
        output, final_transformation = augmentation((test_input, test_transform), params=generated_params)
        assert output.shape[0] == batch_shape[0]
        assert final_transformation.shape == expected_transformation_shape
        assert_close(final_transformation, transformation @ test_transform, atol=1e-4, rtol=1e-4)
        output, final_transformation = augmentation(
            (test_input, test_transform), params=generated_params, return_transform=True
        )
        assert output.shape[0] == batch_shape[0]
        assert final_transformation.shape == expected_transformation_shape
        assert_close(final_transformation, transformation @ test_transform, atol=1e-4, rtol=1e-4)
    def _test_cardinality_implementation(self, input_shape, expected_output_shape, params):
        """Check output shapes for both p=0 (identity) and p=1 (applied)."""
        # p==0.0
        augmentation = self._create_augmentation_from_params(**params, p=0.0)
        test_input = torch.rand(input_shape, device=self.device, dtype=self.dtype)
        output = augmentation(test_input)
        assert len(output.shape) == 4
        # With p=0 the input is only broadcast to 4D, never resized.
        assert output.shape == torch.Size((1,) * (4 - len(input_shape)) + tuple(input_shape))
        # p==1.0
        augmentation = self._create_augmentation_from_params(**params, p=1.0)
        test_input = torch.rand(input_shape, device=self.device, dtype=self.dtype)
        output = augmentation(test_input)
        assert len(output.shape) == 4
        assert output.shape == expected_output_shape
    def _test_random_p_0_implementation(self, params):
        """With p=0 the augmentation must be the identity on the input."""
        augmentation = self._create_augmentation_from_params(**params, p=0.0, return_transform=False)
        test_input = torch.rand((2, 3, 4, 5), device=self.device, dtype=self.dtype)
        output = augmentation(test_input)
        assert (output == test_input).all()
    def _test_random_p_0_return_transform_implementation(self, params):
        """With p=0 the returned transform must be the identity matrix."""
        augmentation = self._create_augmentation_from_params(**params, p=0.0, return_transform=True)
        expected_transformation_shape = torch.Size((2, 3, 3))
        test_input = torch.rand((2, 3, 4, 5), device=self.device, dtype=self.dtype)
        output, transformation = augmentation(test_input)
        assert (output == test_input).all()
        assert transformation.shape == expected_transformation_shape
        assert (transformation == kornia.eye_like(3, transformation)).all()
    def _test_random_p_1_implementation(self, input_tensor, expected_output, params):
        """With p=1 the output must match the provided expected tensor."""
        augmentation = self._create_augmentation_from_params(**params, p=1.0, return_transform=False)
        output = augmentation(input_tensor.to(self.device).to(self.dtype))
        # Output should match
        assert output.shape == expected_output.shape
        assert_close(output, expected_output.to(self.device).to(self.dtype), atol=1e-4, rtol=1e-4)
    def _test_random_p_1_return_transform_implementation(
        self, input_tensor, expected_output, expected_transformation, params
    ):
        """With p=1 both output and transformation must match the expectations."""
        augmentation = self._create_augmentation_from_params(**params, p=1.0, return_transform=True)
        output, transformation = augmentation(input_tensor.to(self.device).to(self.dtype))
        # Output should match
        assert output.shape == expected_output.shape
        assert_close(output, expected_output.to(self.device).to(self.dtype), atol=1e-4, rtol=1e-4)
        # Transformation should match
        assert transformation.shape == expected_transformation.shape
        assert_close(transformation, expected_transformation.to(self.device).to(self.dtype), atol=1e-4, rtol=1e-4)
    def _test_module_implementation(self, params):
        """nn.Sequential of two augmentations must equal two sequential calls.

        The RNG is reseeded before each run so both paths draw the same
        random parameters; the composed transform is transform2 @ transform1.
        """
        augmentation = self._create_augmentation_from_params(**params, p=0.5, return_transform=True)
        augmentation_sequence = nn.Sequential(augmentation, augmentation)
        input_tensor = torch.rand(3, 5, 5, device=self.device, dtype=self.dtype)  # 3 x 5 x 5
        torch.manual_seed(42)
        out1, transform1 = augmentation(input_tensor)
        out2, transform2 = augmentation(out1)
        transform = transform2 @ transform1
        torch.manual_seed(42)
        out_sequence, transform_sequence = augmentation_sequence(input_tensor)
        assert out2.shape == out_sequence.shape
        assert transform.shape == transform_sequence.shape
        assert_close(out2, out_sequence, atol=1e-4, rtol=1e-4)
        assert_close(transform, transform_sequence, atol=1e-4, rtol=1e-4)
    def _test_inverse_coordinate_check_implementation(self, params):
        """Map output pixel coordinates back through the inverse transform and
        check that saturated output pixels come from saturated input pixels."""
        torch.manual_seed(42)
        input_tensor = torch.zeros((1, 3, 50, 100), device=self.device, dtype=self.dtype)
        input_tensor[:, :, 20:30, 40:60] = 1.0
        augmentation = self._create_augmentation_from_params(**params, p=1.0, return_transform=True)
        output, transform = augmentation(input_tensor)
        if (transform == kornia.eye_like(3, transform)).all():
            pytest.skip("Test not relevant for intensity augmentations.")
        indices = create_meshgrid(
            height=output.shape[-2], width=output.shape[-1], normalized_coordinates=False, device=self.device
        )
        output_indices = indices.reshape((1, -1, 2)).to(dtype=self.dtype)
        input_indices = transform_points(_torch_inverse_cast(transform.to(self.dtype)), output_indices)
        output_indices = output_indices.round().long().squeeze(0)
        input_indices = input_indices.round().long().squeeze(0)
        output_values = output[0, 0, output_indices[:, 1], output_indices[:, 0]]
        # Only compare pixels that are (numerically) fully saturated to avoid
        # interpolation artifacts at region borders.
        value_mask = output_values > 0.9999
        output_values = output[0, :, output_indices[:, 1][value_mask], output_indices[:, 0][value_mask]]
        input_values = input_tensor[0, :, input_indices[:, 1][value_mask], input_indices[:, 0][value_mask]]
        assert_close(output_values, input_values, atol=1e-4, rtol=1e-4)
    def _test_gradcheck_implementation(self, params):
        """Check gradients flow through the augmentation via torch.autograd.gradcheck."""
        input_tensor = torch.rand((3, 5, 5), device=self.device, dtype=self.dtype)  # 3 x 5 x 5
        input_tensor = utils.tensor_to_gradcheck_var(input_tensor)  # to var
        assert gradcheck(
            self._create_augmentation_from_params(**params, p=1.0, return_transform=False),
            (input_tensor,),
            raise_exception=True,
        )
class TestRandomEqualizeAlternative(CommonTests):
    """CommonTests-based suite for ``RandomEqualize``.

    Equalization is an intensity-only augmentation, so the expected
    transformation matrix is always the identity.
    """

    possible_params: Dict["str", Tuple] = {}
    _augmentation_cls = RandomEqualize
    _default_param_set: Dict["str", Any] = {}

    @pytest.fixture(params=[_default_param_set], scope="class")
    def param_set(self, request):
        return request.param

    def test_random_p_1(self):
        # A linear ramp of 20 gray levels equalizes to this reference
        # (values above ~0.65 saturate at 1.0).
        input_tensor = torch.arange(20.0, device=self.device, dtype=self.dtype) / 20
        input_tensor = input_tensor.repeat(1, 2, 20, 1)
        expected_output = torch.tensor(
            [0.0000, 0.07843, 0.15686, 0.2353, 0.3137, 0.3922, 0.4706, 0.5490, 0.6275, 0.7059,
             0.7843, 0.8627, 0.9412, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            device=self.device,
            dtype=self.dtype,
        )
        expected_output = expected_output.repeat(1, 2, 20, 1)
        parameters = {}
        self._test_random_p_1_implementation(
            input_tensor=input_tensor, expected_output=expected_output, params=parameters
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        input_tensor = torch.rand(1, 1, 3, 4, device=self.device, dtype=self.dtype)
        # Note: For small inputs it should return the input image
        expected_output = input_tensor
        expected_transformation = kornia.eye_like(3, input_tensor)
        parameters = {}
        self._test_random_p_1_return_transform_implementation(
            input_tensor=input_tensor,
            expected_output=expected_output,
            expected_transformation=expected_transformation,
            params=parameters,
        )

    def test_batch(self):
        # Same ramp as test_random_p_1, but batched (2 x 3 x 20 x 20).
        input_tensor = torch.arange(20.0, device=self.device, dtype=self.dtype) / 20
        input_tensor = input_tensor.repeat(2, 3, 20, 1)
        expected_output = torch.tensor(
            [0.0000, 0.07843, 0.15686, 0.2353, 0.3137, 0.3922, 0.4706, 0.5490, 0.6275, 0.7059,
             0.7843, 0.8627, 0.9412, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            device=self.device,
            dtype=self.dtype,
        )
        expected_output = expected_output.repeat(2, 3, 20, 1)
        expected_transformation = kornia.eye_like(3, input_tensor)
        parameters = {}
        self._test_random_p_1_return_transform_implementation(
            input_tensor=input_tensor,
            expected_output=expected_output,
            expected_transformation=expected_transformation,
            params=parameters,
        )

    def test_exception(self):
        # Values far outside [0, 1] must be rejected with a ValueError.
        # BUG FIX: the scaling factor must multiply the tensor, not the shape
        # tuple -- ``(1, 3, 4, 5) * 200`` replicates the tuple into an
        # 800-element shape, which raises the wrong error type and never
        # exercises the intended out-of-range input validation.
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(p=1.0)(
                torch.ones((1, 3, 4, 5), device=self.device, dtype=self.dtype) * 200
            )
class TestCenterCropAlternative(CommonTests):
    """CommonTests-based suite for ``CenterCrop``."""

    possible_params: Dict["str", Tuple] = {
        "size": (2, (2, 2)),
        "resample": (0, Resample.BILINEAR.name, Resample.BILINEAR),
        "align_corners": (False, True),
    }
    _augmentation_cls = CenterCrop
    _default_param_set: Dict["str", Any] = {"size": (2, 2), "align_corners": True}

    @pytest.fixture(
        params=default_with_one_parameter_changed(default=_default_param_set, **possible_params), scope="class"
    )
    def param_set(self, request):
        return request.param

    @pytest.mark.parametrize(
        "input_shape,expected_output_shape",
        [((4, 5), (1, 1, 2, 3)), ((3, 4, 5), (1, 3, 2, 3)), ((2, 3, 4, 5), (2, 3, 2, 3))],
    )
    def test_cardinality(self, input_shape, expected_output_shape):
        # Cropping changes the spatial dims, so an explicit (2, 3) crop size is
        # used instead of the default parameter set.
        crop_params = {"size": (2, 3), "align_corners": True}
        self._test_cardinality_implementation(
            input_shape=input_shape, expected_output_shape=expected_output_shape, params=crop_params
        )

    @pytest.mark.xfail(reason="size=(1,2) results in RuntimeError: solve_cpu: For batch 0: U(3,3) is zero, singular U.")
    def test_random_p_1(self):
        torch.manual_seed(42)
        image = torch.tensor(
            [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 0.0, 0.1, 0.2]]], device=self.device, dtype=self.dtype
        )
        # The (1, 2) center crop of a 3x4 image keeps the middle two pixels
        # of the middle row.
        cropped = torch.tensor([[[[0.6, 0.7]]]], device=self.device, dtype=self.dtype)
        self._test_random_p_1_implementation(
            input_tensor=image,
            expected_output=cropped,
            params={"size": (1, 2), "align_corners": True, "resample": 0},
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        image = torch.tensor(
            [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 0.0, 0.1, 0.2]]], device=self.device, dtype=self.dtype
        )
        # (3, 2) crop keeps all rows, columns 1..2.
        cropped = torch.tensor([[[[0.2, 0.3], [0.6, 0.7], [0.0, 0.1]]]], device=self.device, dtype=self.dtype)
        crop_matrix = torch.tensor(
            [[[1.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_return_transform_implementation(
            input_tensor=image,
            expected_output=cropped,
            expected_transformation=crop_matrix,
            params={"size": (3, 2), "align_corners": True, "resample": 0},
        )

    def test_batch(self):
        torch.manual_seed(42)
        images = torch.rand((2, 3, 4, 4), device=self.device, dtype=self.dtype)
        # A (2, 2) center crop of a 4x4 image keeps rows/cols 1..2, shifting
        # both coordinates by -1 in the transform.
        cropped = images[:, :, 1:3, 1:3]
        crop_matrix = torch.tensor(
            [[[1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        ).repeat(2, 1, 1)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=images,
            expected_output=cropped,
            expected_transformation=crop_matrix,
            params={"size": (2, 2), "align_corners": True, "resample": 0},
        )

    @pytest.mark.xfail(reason="No input validation is implemented yet.")
    def test_exception(self):
        # Wrong argument types must be rejected ...
        with pytest.raises(TypeError):
            self._create_augmentation_from_params(size=0.0)
        with pytest.raises(TypeError):
            self._create_augmentation_from_params(size=2, align_corners=0)
        with pytest.raises(TypeError):
            self._create_augmentation_from_params(size=2, resample=True)
        # ... and so must out-of-bound sizes.
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(size=-1)
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(size=(-1, 2))
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(size=(2, -1))
class TestRandomHorizontalFlipAlternative(CommonTests):
    """CommonTests-based suite for ``RandomHorizontalFlip``."""

    possible_params: Dict["str", Tuple] = {}
    _augmentation_cls = RandomHorizontalFlip
    _default_param_set: Dict["str", Any] = {}

    @pytest.fixture(params=[_default_param_set], scope="class")
    def param_set(self, request):
        return request.param

    def test_random_p_1(self):
        torch.manual_seed(42)
        image = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )
        # Each row mirrored left-to-right; the augmentation adds the batch dim.
        mirrored = torch.tensor(
            [[[[0.3, 0.2, 0.1], [0.6, 0.5, 0.4], [0.9, 0.8, 0.7]]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_implementation(input_tensor=image, expected_output=mirrored, params={})

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        image = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )
        mirrored = torch.tensor(
            [[[[0.3, 0.2, 0.1], [0.6, 0.5, 0.4], [0.9, 0.8, 0.7]]]], device=self.device, dtype=self.dtype
        )
        # x' = -x + (W - 1): for W == 3 the translation column carries +2.
        flip_matrix = torch.tensor(
            [[[-1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_return_transform_implementation(
            input_tensor=image,
            expected_output=mirrored,
            expected_transformation=flip_matrix,
            params={},
        )

    def test_batch(self):
        torch.manual_seed(12)
        batch_size = 2
        image = torch.tensor(
            [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]], device=self.device, dtype=self.dtype
        ).repeat((batch_size, 1, 1, 1))
        mirrored = torch.tensor(
            [[[[0.3, 0.2, 0.1], [0.6, 0.5, 0.4], [0.9, 0.8, 0.7]]]], device=self.device, dtype=self.dtype
        ).repeat((batch_size, 1, 1, 1))
        flip_matrix = torch.tensor(
            [[[-1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        ).repeat((batch_size, 1, 1))
        self._test_random_p_1_return_transform_implementation(
            input_tensor=image,
            expected_output=mirrored,
            expected_transformation=flip_matrix,
            params={},
        )

    @pytest.mark.skip(reason="No special parameters to validate.")
    def test_exception(self):
        pass
class TestRandomVerticalFlipAlternative(CommonTests):
    """CommonTests-based suite for ``RandomVerticalFlip``."""

    possible_params: Dict["str", Tuple] = {}
    _augmentation_cls = RandomVerticalFlip
    _default_param_set: Dict["str", Any] = {}

    @pytest.fixture(params=[_default_param_set], scope="class")
    def param_set(self, request):
        return request.param

    def test_random_p_1(self):
        torch.manual_seed(42)
        image = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )
        # Row order reversed top-to-bottom; the augmentation adds the batch dim.
        flipped = torch.tensor(
            [[[[0.7, 0.8, 0.9], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_implementation(input_tensor=image, expected_output=flipped, params={})

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        image = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )
        flipped = torch.tensor(
            [[[[0.7, 0.8, 0.9], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]]]], device=self.device, dtype=self.dtype
        )
        # y' = -y + (H - 1): for H == 3 the translation column carries +2.
        flip_matrix = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_return_transform_implementation(
            input_tensor=image,
            expected_output=flipped,
            expected_transformation=flip_matrix,
            params={},
        )

    def test_batch(self):
        torch.manual_seed(12)
        batch_size = 2
        image = torch.tensor(
            [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]], device=self.device, dtype=self.dtype
        ).repeat((batch_size, 1, 1, 1))
        flipped = torch.tensor(
            [[[[0.7, 0.8, 0.9], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]]]], device=self.device, dtype=self.dtype
        ).repeat((batch_size, 1, 1, 1))
        flip_matrix = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        ).repeat((batch_size, 1, 1))
        self._test_random_p_1_return_transform_implementation(
            input_tensor=image,
            expected_output=flipped,
            expected_transformation=flip_matrix,
            params={},
        )

    @pytest.mark.skip(reason="No special parameters to validate.")
    def test_exception(self):
        pass
class TestRandomRotationAlternative(CommonTests):
    """CommonTests-based suite for ``RandomRotation``."""

    possible_params: Dict["str", Tuple] = {
        "degrees": (0.0, (-360.0, 360.0), [0.0, 0.0], torch.tensor((-180.0, 180))),
        "resample": (0, Resample.BILINEAR.name, Resample.BILINEAR),
        "align_corners": (False, True),
    }
    _augmentation_cls = RandomRotation
    _default_param_set: Dict["str", Any] = {"degrees": (30.0, 30.0), "align_corners": True}

    @pytest.fixture(
        params=default_with_one_parameter_changed(default=_default_param_set, **possible_params), scope="class"
    )
    def param_set(self, request):
        return request.param

    def test_random_p_1(self):
        torch.manual_seed(42)
        image = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )
        # A fixed +90 degree rotation of the 3x3 pattern.
        rotated = torch.tensor(
            [[[[0.3, 0.6, 0.9], [0.2, 0.5, 0.8], [0.1, 0.4, 0.7]]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_implementation(
            input_tensor=image,
            expected_output=rotated,
            params={"degrees": (90.0, 90.0), "align_corners": True},
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        image = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], device=self.device, dtype=self.dtype
        )
        # A fixed -90 degree rotation and its homogeneous matrix.
        rotated = torch.tensor(
            [[[[0.7, 0.4, 0.1], [0.8, 0.5, 0.2], [0.9, 0.6, 0.3]]]], device=self.device, dtype=self.dtype
        )
        rot_matrix = torch.tensor(
            [[[0.0, -1.0, 2.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]], device=self.device, dtype=self.dtype
        )
        self._test_random_p_1_return_transform_implementation(
            input_tensor=image,
            expected_output=rotated,
            expected_transformation=rot_matrix,
            params={"degrees": (-90.0, -90.0), "align_corners": True},
        )

    def test_batch(self):
        torch.manual_seed(12)
        image = torch.tensor(
            [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]], device=self.device, dtype=self.dtype
        ).repeat((2, 1, 1, 1))
        # A full -360 degree turn is the identity for both image and matrix.
        self._test_random_p_1_return_transform_implementation(
            input_tensor=image,
            expected_output=image,
            expected_transformation=kornia.eye_like(3, image),
            params={"degrees": (-360.0, -360.0), "align_corners": True},
        )

    @pytest.mark.xfail(reason="No input validation is implemented yet.")
    def test_exception(self):
        # Wrong argument types must be rejected ...
        with pytest.raises(TypeError):
            self._create_augmentation_from_params(degrees="")
        with pytest.raises(TypeError):
            self._create_augmentation_from_params(degrees=(3, 3), align_corners=0)
        with pytest.raises(TypeError):
            self._create_augmentation_from_params(degrees=(3, 3), resample=True)
        # ... and so must out-of-bound or inverted degree ranges.
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(degrees=-361.0)
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(degrees=(-361.0, 360.0))
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(degrees=(-360.0, 361.0))
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(degrees=(360.0, -360.0))
class TestRandomGrayscaleAlternative(CommonTests):
    """CommonTests-based suite for ``RandomGrayscale``."""

    possible_params: Dict["str", Tuple] = {}
    _augmentation_cls = RandomGrayscale
    _default_param_set: Dict["str", Any] = {}

    @pytest.fixture(params=[_default_param_set], scope="class")
    def param_set(self, request):
        return request.param

    def _expected_grayscale(self, rgb):
        # Weighted channel sum with BT.601 luma coefficients, broadcast back
        # to three identical channels.
        weights = torch.tensor([0.299, 0.587, 0.114], device=self.device, dtype=self.dtype)
        gray = (rgb * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
        return gray.repeat(1, 3, 1, 1)

    @pytest.mark.parametrize(
        "input_shape,expected_output_shape", [((3, 4, 5), (1, 3, 4, 5)), ((2, 3, 4, 5), (2, 3, 4, 5))]
    )
    def test_cardinality(self, input_shape, expected_output_shape):
        # Grayscale requires three channels, so the unbatched/unchanneled
        # shapes of the base test are not applicable here.
        self._test_cardinality_implementation(
            input_shape=input_shape, expected_output_shape=expected_output_shape, params=self._default_param_set
        )

    def test_random_p_1(self):
        torch.manual_seed(42)
        base = torch.tensor(
            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 0.0, 0.1, 0.2]], device=self.device, dtype=self.dtype
        )
        image = base.repeat(1, 3, 1, 1)
        self._test_random_p_1_implementation(
            input_tensor=image, expected_output=self._expected_grayscale(image), params={}
        )

    def test_random_p_1_return_transform(self):
        torch.manual_seed(42)
        base = torch.tensor(
            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 0.0, 0.1, 0.2]], device=self.device, dtype=self.dtype
        )
        image = base.repeat(1, 3, 1, 1)
        # Grayscale is intensity-only, so the transform is the identity.
        self._test_random_p_1_return_transform_implementation(
            input_tensor=image,
            expected_output=self._expected_grayscale(image),
            expected_transformation=kornia.eye_like(3, image),
            params={},
        )

    def test_batch(self):
        torch.manual_seed(42)
        base = torch.tensor(
            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 0.0, 0.1, 0.2]], device=self.device, dtype=self.dtype
        )
        images = base.repeat(2, 3, 1, 1)
        self._test_random_p_1_return_transform_implementation(
            input_tensor=images,
            expected_output=self._expected_grayscale(images),
            expected_transformation=kornia.eye_like(3, images),
            params={},
        )

    @pytest.mark.xfail(reason="No input validation is implemented yet when p=0.")
    def test_exception(self):
        # Non-3-channel input must be rejected regardless of p.
        torch.manual_seed(42)
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(p=0.0)(torch.rand((1, 1, 4, 5), device=self.device, dtype=self.dtype))
        with pytest.raises(ValueError):
            self._create_augmentation_from_params(p=1.0)(torch.rand((1, 4, 4, 5), device=self.device, dtype=self.dtype))
class TestRandomHorizontalFlip:
# TODO: improve and implement more meaningful smoke tests e.g check for a consistent
# return values such a torch.Tensor variable.
@pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
def test_smoke(self):
f = RandomHorizontalFlip(p=0.5)
repr = "RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
assert str(f) == repr
def test_random_hflip(self, device, dtype):
f = RandomHorizontalFlip(p=1.0, return_transform=True)
f1 = RandomHorizontalFlip(p=0.0, return_transform=True)
f2 = RandomHorizontalFlip(p=1.0)
f3 = RandomHorizontalFlip(p=0.0)
input = torch.tensor(
[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 2.0]], device=device, dtype=dtype
) # 3 x 4
expected = torch.tensor(
[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [2.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype
) # 3 x 4
expected = expected.to(device)
expected_transform = torch.tensor(
[[-1.0, 0.0, 3.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device, dtype=dtype
) # 3 x 3
identity = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device, dtype=dtype
) # 3 x 3
assert (f(input)[0] == expected).all()
assert (f(input)[1] == expected_transform).all()
assert (f1(input)[0] == input).all()
assert (f1(input)[1] == identity).all()
assert (f2(input) == expected).all()
assert (f3(input) == input).all()
assert (f.inverse(expected) == input).all()
assert (f1.inverse(expected) == expected).all()
assert (f2.inverse(expected) == input).all()
assert (f3.inverse(expected) == expected).all()
def test_batch_random_hflip(self, device, dtype):
f = RandomHorizontalFlip(p=1.0, return_transform=True)
f1 = RandomHorizontalFlip(p=0.0, return_transform=True)
input = torch.tensor(
[[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype
) # 1 x 1 x 3 x 3
expected = torch.tensor(
[[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 1.0, 0.0]]]], device=device, dtype=dtype
) # 1 x 1 x 3 x 3
expected_transform = torch.tensor(
[[[-1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
) # 1 x 3 x 3
identity = torch.tensor(
[[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
) # 1 x 3 x 3
input = input.repeat(5, 3, 1, 1) # 5 x 3 x 3 x 3
expected = expected.repeat(5, 3, 1, 1) # 5 x 3 x 3 x 3
expected_transform = expected_transform.repeat(5, 1, 1) # 5 x 3 x 3
identity = identity.repeat(5, 1, 1) # 5 x 3 x 3
assert (f(input)[0] == expected).all()
assert (f(input)[1] == expected_transform).all()
assert (f1(input)[0] == input).all()
assert (f1(input)[1] == identity).all()
assert (f.inverse(expected) == input).all()
assert (f1.inverse(expected) == expected).all()
def test_same_on_batch(self, device, dtype):
f = RandomHorizontalFlip(p=0.5, same_on_batch=True)
input = torch.eye(3, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 1, 1)
res = f(input)
assert (res[0] == res[1]).all()
assert (f.inverse(res) == input).all()
def test_sequential(self, device, dtype):
    """Two chained flips cancel out; their transforms compose by matrix product."""
    both = nn.Sequential(
        RandomHorizontalFlip(p=1.0, return_transform=True),
        RandomHorizontalFlip(p=1.0, return_transform=True),
    )
    mixed = nn.Sequential(RandomHorizontalFlip(p=1.0, return_transform=True), RandomHorizontalFlip(p=1.0))
    inp = torch.tensor(
        [[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype
    )  # 1 x 1 x 3 x 3
    flip_mat = torch.tensor(
        [[[-1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
    )  # 1 x 3 x 3
    composed = flip_mat @ flip_mat  # two flips == identity

    assert (both(inp)[0] == inp).all()
    assert (both(inp)[1] == composed).all()
    # When the second stage does not return a transform, only the first one's survives.
    assert (mixed(inp)[0] == inp).all()
    assert (mixed(inp)[1] == flip_mat).all()
    # TODO: Introduce Kornia.Sequential to do the inverse.
def test_random_hflip_coord_check(self, device, dtype):
    """Check that the returned transform maps input pixel coordinates onto valid
    output coordinates and that pixel values follow the coordinate mapping."""
    f = RandomHorizontalFlip(p=1.0, return_transform=True)
    input = torch.tensor(
        [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]]], device=device, dtype=dtype
    )  # 1 x 1 x 3 x 4
    # Homogeneous (x, y, 1) coordinates of every pixel in the 3 x 4 image.
    input_coordinates = torch.tensor(
        [
            [
                [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],  # x coord
                [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],  # y coord
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            ]
        ],
        device=device,
        dtype=dtype,
    )  # 1 x 3 x 12
    expected_output = torch.tensor(
        [[[[4.0, 3.0, 2.0, 1.0], [8.0, 7.0, 6.0, 5.0], [12.0, 11.0, 10.0, 9.0]]]], device=device, dtype=dtype
    )  # 1 x 1 x 3 x 4
    output, transform = f(input)
    result_coordinates = transform @ input_coordinates
    # NOTE: without rounding it might produce unexpected results
    input_coordinates = input_coordinates.round().long()
    result_coordinates = result_coordinates.round().long()
    # Tensors must have the same shapes and values
    assert output.shape == expected_output.shape
    assert (output == expected_output).all()
    # Transformed indices must not be out of bound.
    # (fixed: was ``torch.torch.logical_and`` -- the double module reference only
    # worked because ``torch`` happens to re-export itself)
    assert (
        torch.logical_and(result_coordinates[0, 0, :] >= 0, result_coordinates[0, 0, :] < input.shape[-1])
    ).all()
    assert (
        torch.logical_and(result_coordinates[0, 1, :] >= 0, result_coordinates[0, 1, :] < input.shape[-2])
    ).all()
    # Values in the output tensor at the places of transformed indices must
    # have the same value as the input tensor has at the corresponding
    # positions
    assert (
        output[..., result_coordinates[0, 1, :], result_coordinates[0, 0, :]]
        == input[..., input_coordinates[0, 1, :], input_coordinates[0, 0, :]]
    ).all()
def test_gradcheck(self, device, dtype):
    """Horizontal flip must be differentiable with respect to its input."""
    tensor = torch.rand((3, 3), device=device, dtype=dtype)  # 3 x 3
    tensor = utils.tensor_to_gradcheck_var(tensor)  # to var
    assert gradcheck(RandomHorizontalFlip(p=1.0), (tensor,), raise_exception=True)
class TestRandomVerticalFlip:
    """Tests for ``RandomVerticalFlip``: output values, transform matrices,
    inverse round-trips, batching and coordinate consistency."""

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        f = RandomVerticalFlip(p=0.5)
        repr = "RandomVerticalFlip(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
        assert str(f) == repr

    def test_random_vflip(self, device, dtype):
        """Single-sample flip with p in {0, 1}, with and without transform output."""
        f = RandomVerticalFlip(p=1.0, return_transform=True)
        f1 = RandomVerticalFlip(p=0.0, return_transform=True)
        f2 = RandomVerticalFlip(p=1.0)
        f3 = RandomVerticalFlip(p=0.0)
        input = torch.tensor(
            [[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        expected = torch.tensor(
            [[[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )  # 1 x 3 x 3
        identity = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )  # 1 x 3 x 3
        assert_close(f(input)[0], expected, atol=1e-4, rtol=1e-4)
        assert_close(f(input)[1], expected_transform, atol=1e-4, rtol=1e-4)
        assert_close(f1(input)[0], input, atol=1e-4, rtol=1e-4)
        assert_close(f1(input)[1], identity, atol=1e-4, rtol=1e-4)
        assert_close(f2(input), expected, atol=1e-4, rtol=1e-4)
        assert_close(f3(input), input, atol=1e-4, rtol=1e-4)
        assert_close(f.inverse(expected), input, atol=1e-4, rtol=1e-4)
        assert_close(f1.inverse(input), input, atol=1e-4, rtol=1e-4)
        assert_close(f2.inverse(expected), input, atol=1e-4, rtol=1e-4)
        assert_close(f3.inverse(input), input, atol=1e-4, rtol=1e-4)

    def test_batch_random_vflip(self, device, dtype):
        """Same checks as above, replicated over a batch of 5 three-channel images."""
        f = RandomVerticalFlip(p=1.0, return_transform=True)
        f1 = RandomVerticalFlip(p=0.0, return_transform=True)
        input = torch.tensor(
            [[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        expected = torch.tensor(
            [[[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )  # 1 x 3 x 3
        identity = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )  # 1 x 3 x 3
        input = input.repeat(5, 3, 1, 1)  # 5 x 3 x 3 x 3
        expected = expected.repeat(5, 3, 1, 1)  # 5 x 3 x 3 x 3
        expected_transform = expected_transform.repeat(5, 1, 1)  # 5 x 3 x 3
        identity = identity.repeat(5, 1, 1)  # 5 x 3 x 3
        assert_close(f(input)[0], expected, atol=1e-4, rtol=1e-4)
        assert_close(f(input)[1], expected_transform, atol=1e-4, rtol=1e-4)
        assert_close(f1(input)[0], input, atol=1e-4, rtol=1e-4)
        assert_close(f1(input)[1], identity, atol=1e-4, rtol=1e-4)
        assert_close(f.inverse(expected), input, atol=1e-4, rtol=1e-4)
        assert_close(f1.inverse(input), input, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` all batch elements share one flip decision."""
        f = RandomVerticalFlip(p=0.5, same_on_batch=True)
        input = torch.eye(3, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 1, 1)
        res = f(input)
        assert (res[0] == res[1]).all()
        assert (f.inverse(res) == input).all()

    def test_sequential(self, device, dtype):
        """Two chained vertical flips cancel; transforms compose by matmul."""
        f = nn.Sequential(
            RandomVerticalFlip(p=1.0, return_transform=True), RandomVerticalFlip(p=1.0, return_transform=True)
        )
        f1 = nn.Sequential(RandomVerticalFlip(p=1.0, return_transform=True), RandomVerticalFlip(p=1.0))
        input = torch.tensor(
            [[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        expected_transform = torch.tensor(
            [[[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]], device=device, dtype=dtype
        )  # 1 x 3 x 3
        expected_transform_1 = expected_transform @ expected_transform
        assert_close(f(input)[0], input, atol=1e-4, rtol=1e-4)
        assert_close(f(input)[1], expected_transform_1, atol=1e-4, rtol=1e-4)
        assert_close(f1(input)[0], input, atol=1e-4, rtol=1e-4)
        assert_close(f1(input)[1], expected_transform, atol=1e-4, rtol=1e-4)

    def test_random_vflip_coord_check(self, device, dtype):
        """Check that the returned transform maps input pixel coordinates onto valid
        output coordinates and that pixel values follow the coordinate mapping."""
        f = RandomVerticalFlip(p=1.0, return_transform=True)
        input = torch.tensor(
            [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 4
        # Homogeneous (x, y, 1) coordinates of every pixel in the 3 x 4 image.
        input_coordinates = torch.tensor(
            [
                [
                    [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],  # x coord
                    [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],  # y coord
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 12
        expected_output = torch.tensor(
            [[[[9.0, 10.0, 11.0, 12.0], [5.0, 6.0, 7.0, 8.0], [1.0, 2.0, 3.0, 4.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 4
        output, transform = f(input)
        result_coordinates = transform @ input_coordinates
        # NOTE: without rounding it might produce unexpected results
        input_coordinates = input_coordinates.round().long()
        result_coordinates = result_coordinates.round().long()
        # Tensors must have the same shapes and values
        assert output.shape == expected_output.shape
        assert (output == expected_output).all()
        # Transformed indices must not be out of bound.
        # (fixed: was ``torch.torch.logical_and`` -- the double module reference only
        # worked because ``torch`` happens to re-export itself)
        assert (
            torch.logical_and(result_coordinates[0, 0, :] >= 0, result_coordinates[0, 0, :] < input.shape[-1])
        ).all()
        assert (
            torch.logical_and(result_coordinates[0, 1, :] >= 0, result_coordinates[0, 1, :] < input.shape[-2])
        ).all()
        # Values in the output tensor at the places of transformed indices must
        # have the same value as the input tensor has at the corresponding
        # positions
        assert (
            output[..., result_coordinates[0, 1, :], result_coordinates[0, 0, :]]
            == input[..., input_coordinates[0, 1, :], input_coordinates[0, 0, :]]
        ).all()
class TestColorJitter:
    """Tests for ``ColorJitter``: identity behavior with default arguments, and
    seeded brightness / contrast / saturation / hue jitter against precomputed
    expected values.

    NOTE: several tests call ``torch.manual_seed`` and compare against fixed
    expectations, so the order and number of RNG-consuming calls matters.
    """

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        f = ColorJitter(brightness=0.5, contrast=0.3, saturation=[0.2, 1.2], hue=0.1)
        repr = (
            "ColorJitter(brightness=tensor([0.5000, 1.5000]), contrast=tensor([0.7000, 1.3000]), "
            "saturation=tensor([0.2000, 1.2000]), hue=tensor([-0.1000, 0.1000]), "
            "p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(f) == repr

    def test_color_jitter(self, device, dtype):
        """Default ColorJitter is the identity; geometric transform is identity too."""
        f = ColorJitter()
        f1 = ColorJitter(return_transform=True)
        input = torch.rand(3, 5, 5, device=device, dtype=dtype).unsqueeze(0)  # 1 x 3 x 5 x 5
        expected = input
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0)  # 1 x 3 x 3
        assert_close(f(input), expected, atol=1e-4, rtol=1e-5)
        assert_close(f1(input)[0], expected, atol=1e-4, rtol=1e-5)
        assert_close(f1(input)[1], expected_transform, atol=1e-4, rtol=1e-4)

    def test_color_jitter_batch(self, device, dtype):
        """Same identity behavior over a batch of two images."""
        f = ColorJitter()
        f1 = ColorJitter(return_transform=True)
        input = torch.rand(2, 3, 5, 5, device=device, dtype=dtype)  # 2 x 3 x 5 x 5
        expected = input
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand((2, 3, 3))  # 2 x 3 x 3
        assert_close(f(input), expected, atol=1e-4, rtol=1e-5)
        assert_close(f1(input)[0], expected, atol=1e-4, rtol=1e-5)
        assert_close(f1(input)[1], expected_transform, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` both batch elements get identical jitter."""
        f = ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, same_on_batch=True)
        input = torch.eye(3).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 3, 1, 1)
        res = f(input)
        assert (res[0] == res[1]).all()

    def _get_expected_brightness(self, device, dtype):
        # Precomputed output for seed 42 with the brightness settings used below.
        return torch.tensor(
            [
                [
                    [[0.2529, 0.3529, 0.4529], [0.7529, 0.6529, 0.5529], [0.8529, 0.9529, 1.0000]],
                    [[0.2529, 0.3529, 0.4529], [0.7529, 0.6529, 0.5529], [0.8529, 0.9529, 1.0000]],
                    [[0.2529, 0.3529, 0.4529], [0.7529, 0.6529, 0.5529], [0.8529, 0.9529, 1.0000]],
                ],
                [
                    [[0.2660, 0.3660, 0.4660], [0.7660, 0.6660, 0.5660], [0.8660, 0.9660, 1.0000]],
                    [[0.2660, 0.3660, 0.4660], [0.7660, 0.6660, 0.5660], [0.8660, 0.9660, 1.0000]],
                    [[0.2660, 0.3660, 0.4660], [0.7660, 0.6660, 0.5660], [0.8660, 0.9660, 1.0000]],
                ],
            ],
            device=device,
            dtype=dtype,
        )

    def test_random_brightness(self, device, dtype):
        """Scalar brightness=0.2 samples factors from [0.8, 1.2]."""
        torch.manual_seed(42)
        f = ColorJitter(brightness=0.2)
        input = torch.tensor(
            [[[[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        input = input.repeat(2, 3, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_brightness(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-4)

    def test_random_brightness_tuple(self, device, dtype):
        """Explicit tuple (0.8, 1.2) must match the scalar 0.2 shortcut."""
        torch.manual_seed(42)
        f = ColorJitter(brightness=(0.8, 1.2))
        input = torch.tensor(
            [[[[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        input = input.repeat(2, 3, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_brightness(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-4)

    def _get_expected_contrast(self, device, dtype):
        # Precomputed output for seed 42 with the contrast settings used below.
        return torch.tensor(
            [
                [
                    [[0.0953, 0.1906, 0.2859], [0.5719, 0.4766, 0.3813], [0.6672, 0.7625, 0.9531]],
                    [[0.0953, 0.1906, 0.2859], [0.5719, 0.4766, 0.3813], [0.6672, 0.7625, 0.9531]],
                    [[0.0953, 0.1906, 0.2859], [0.5719, 0.4766, 0.3813], [0.6672, 0.7625, 0.9531]],
                ],
                [
                    [[0.1184, 0.2367, 0.3551], [0.7102, 0.5919, 0.4735], [0.8286, 0.9470, 1.0000]],
                    [[0.1184, 0.2367, 0.3551], [0.7102, 0.5919, 0.4735], [0.8286, 0.9470, 1.0000]],
                    [[0.1184, 0.2367, 0.3551], [0.7102, 0.5919, 0.4735], [0.8286, 0.9470, 1.0000]],
                ],
            ],
            device=device,
            dtype=dtype,
        )

    def test_random_contrast(self, device, dtype):
        """Scalar contrast=0.2 samples factors from [0.8, 1.2]."""
        torch.manual_seed(42)
        f = ColorJitter(contrast=0.2)
        input = torch.tensor(
            [[[[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        input = input.repeat(2, 3, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_contrast(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-5)

    def test_random_contrast_list(self, device, dtype):
        """Explicit list [0.8, 1.2] must match the scalar 0.2 shortcut."""
        torch.manual_seed(42)
        f = ColorJitter(contrast=[0.8, 1.2])
        input = torch.tensor(
            [[[[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]]]], device=device, dtype=dtype
        )  # 1 x 1 x 3 x 3
        input = input.repeat(2, 3, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_contrast(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-5)

    def _get_expected_saturation(self, device, dtype):
        # Precomputed output for seed 42 with the saturation settings used below.
        return torch.tensor(
            [
                [
                    [[0.1876, 0.2584, 0.3389], [0.6292, 0.5000, 0.4000], [0.7097, 0.8000, 1.0000]],
                    [[1.0000, 0.5292, 0.6097], [0.6292, 0.3195, 0.2195], [0.8000, 0.1682, 0.2779]],
                    [[0.6389, 0.8000, 0.7000], [0.9000, 0.3195, 0.2195], [0.8000, 0.4389, 0.5487]],
                ],
                [
                    [[0.0000, 0.1295, 0.2530], [0.5648, 0.5000, 0.4000], [0.6883, 0.8000, 1.0000]],
                    [[1.0000, 0.4648, 0.5883], [0.5648, 0.2765, 0.1765], [0.8000, 0.0178, 0.1060]],
                    [[0.5556, 0.8000, 0.7000], [0.9000, 0.2765, 0.1765], [0.8000, 0.3530, 0.4413]],
                ],
            ],
            device=device,
            dtype=dtype,
        )

    def test_random_saturation(self, device, dtype):
        """Scalar saturation=0.2 samples factors from [0.8, 1.2]."""
        torch.manual_seed(42)
        f = ColorJitter(saturation=0.2)
        input = torch.tensor(
            [
                [
                    [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]],
                    [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
                    [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, 0.5]],
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3 x 3
        input = input.repeat(2, 1, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_saturation(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-4)

    def test_random_saturation_tensor(self, device, dtype):
        """Tensor-valued saturation must behave like the scalar shortcut."""
        torch.manual_seed(42)
        f = ColorJitter(saturation=torch.tensor(0.2))
        input = torch.tensor(
            [
                [
                    [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]],
                    [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
                    [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, 0.5]],
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3 x 3
        input = input.repeat(2, 1, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_saturation(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-4)

    def test_random_saturation_tuple(self, device, dtype):
        """Explicit tuple (0.8, 1.2) must match the scalar 0.2 shortcut."""
        torch.manual_seed(42)
        f = ColorJitter(saturation=(0.8, 1.2))
        input = torch.tensor(
            [
                [
                    [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]],
                    [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
                    [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, 0.5]],
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3 x 3
        input = input.repeat(2, 1, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_saturation(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-4)

    def _get_expected_hue(self, device, dtype):
        # Precomputed output for seed 42 with the hue settings used below.
        return torch.tensor(
            [
                [
                    [[0.1000, 0.2000, 0.3000], [0.6000, 0.5000, 0.4000], [0.7000, 0.8000, 1.0000]],
                    [[1.0000, 0.5251, 0.6167], [0.6126, 0.3000, 0.2000], [0.8000, 0.1000, 0.2000]],
                    [[0.5623, 0.8000, 0.7000], [0.9000, 0.3084, 0.2084], [0.7958, 0.4293, 0.5335]],
                ],
                [
                    [[0.1000, 0.2000, 0.3000], [0.6116, 0.5000, 0.4000], [0.7000, 0.8000, 1.0000]],
                    [[1.0000, 0.4769, 0.5846], [0.6000, 0.3077, 0.2077], [0.7961, 0.1000, 0.2000]],
                    [[0.6347, 0.8000, 0.7000], [0.9000, 0.3000, 0.2000], [0.8000, 0.3730, 0.4692]],
                ],
            ],
            device=device,
            dtype=dtype,
        )

    def test_random_hue(self, device, dtype):
        """Scalar hue bound expressed as a fraction of pi."""
        torch.manual_seed(42)
        f = ColorJitter(hue=0.1 / pi.item())
        input = torch.tensor(
            [
                [
                    [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]],
                    [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
                    [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, 0.5]],
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3 x 3
        input = input.repeat(2, 1, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_hue(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-4)

    def test_random_hue_list(self, device, dtype):
        """Hue bounds given as a list of tensors (pi is a tensor constant)."""
        torch.manual_seed(42)
        f = ColorJitter(hue=[-0.1 / pi, 0.1 / pi])
        input = torch.tensor(
            [
                [
                    [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]],
                    [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
                    [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, 0.5]],
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3 x 3
        input = input.repeat(2, 1, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_hue(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-4)

    def test_random_hue_list_batch(self, device, dtype):
        """Hue bounds given as a list of floats."""
        torch.manual_seed(42)
        f = ColorJitter(hue=[-0.1 / pi.item(), 0.1 / pi.item()])
        input = torch.tensor(
            [
                [
                    [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.0]],
                    [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
                    [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, 0.5]],
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3 x 3
        input = input.repeat(2, 1, 1, 1)  # 2 x 3 x 3 x 3
        expected = self._get_expected_hue(device, dtype)
        assert_close(f(input), expected, atol=1e-4, rtol=1e-4)

    def test_sequential(self, device, dtype):
        """Two chained default jitters remain the identity."""
        f = nn.Sequential(ColorJitter(return_transform=True), ColorJitter(return_transform=True))
        input = torch.rand(3, 5, 5, device=device, dtype=dtype).unsqueeze(0)  # 1 x 3 x 5 x 5
        expected = input
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0)  # 1 x 3 x 3
        assert_close(f(input)[0], expected, atol=1e-4, rtol=1e-5)
        assert_close(f(input)[1], expected_transform, atol=1e-4, rtol=1e-5)

    def test_color_jitter_batch_sequential(self, device, dtype):
        """Chained default jitters on a batch stay the identity."""
        f = nn.Sequential(ColorJitter(return_transform=True), ColorJitter(return_transform=True))
        input = torch.rand(2, 3, 5, 5, device=device, dtype=dtype)  # 2 x 3 x 5 x 5
        expected = input
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand((2, 3, 3))  # 2 x 3 x 3
        # fixed: a byte-identical duplicate of the next assertion was removed here
        assert_close(f(input)[0], expected, atol=1e-4, rtol=1e-5)
        assert_close(f(input)[1], expected_transform, atol=1e-4, rtol=1e-5)

    def test_gradcheck(self, device, dtype):
        """ColorJitter must be differentiable with respect to its input."""
        input = torch.rand((3, 5, 5), device=device, dtype=dtype).unsqueeze(0)  # 1 x 3 x 5 x 5
        input = utils.tensor_to_gradcheck_var(input)  # to var
        assert gradcheck(ColorJitter(p=1.0), (input,), raise_exception=True)
class TestRectangleRandomErasing:
    """Tests for ``RandomErasing`` rectangle cut-out augmentation."""

    @pytest.mark.parametrize("erase_scale_range", [(0.001, 0.001), (1.0, 1.0)])
    @pytest.mark.parametrize("aspect_ratio_range", [(0.1, 0.1), (10.0, 10.0)])
    @pytest.mark.parametrize("batch_shape", [(1, 4, 8, 15), (2, 3, 11, 7)])
    def test_random_rectangle_erasing_shape(self, batch_shape, erase_scale_range, aspect_ratio_range):
        """Erasing with p=1 never changes the tensor shape."""
        aug = RandomErasing(erase_scale_range, aspect_ratio_range, p=1.0)
        assert aug(torch.rand(batch_shape)).shape == batch_shape

    @pytest.mark.parametrize("erase_scale_range", [(0.001, 0.001), (1.0, 1.0)])
    @pytest.mark.parametrize("aspect_ratio_range", [(0.1, 0.1), (10.0, 10.0)])
    @pytest.mark.parametrize("batch_shape", [(1, 4, 8, 15), (2, 3, 11, 7)])
    def test_no_rectangle_erasing_shape(self, batch_shape, erase_scale_range, aspect_ratio_range):
        """With p=0 the augmentation is a strict no-op."""
        aug = RandomErasing(erase_scale_range, aspect_ratio_range, p=0.0)
        batch = torch.rand(batch_shape)
        assert aug(batch).equal(batch)

    @pytest.mark.parametrize("erase_scale_range", [(0.001, 0.001), (1.0, 1.0)])
    @pytest.mark.parametrize("aspect_ratio_range", [(0.1, 0.1), (10.0, 10.0)])
    @pytest.mark.parametrize("shape", [(3, 11, 7)])
    def test_same_on_batch(self, shape, erase_scale_range, aspect_ratio_range):
        """With ``same_on_batch=True`` the same rectangle is erased in every element."""
        aug = RandomErasing(erase_scale_range, aspect_ratio_range, same_on_batch=True, p=0.5)
        batch = torch.rand(shape)[None].repeat(2, 1, 1, 1)
        out = aug(batch)
        assert (out[0] == out[1]).all()

    def test_gradcheck(self, device, dtype):
        """Erasing with pre-sampled parameters must be differentiable."""
        batch_shape = (2, 3, 11, 7)
        # Sample the rectangle parameters first so gradcheck sees a fixed op.
        aug = RandomErasing((0.2, 0.4), (0.3, 0.5), p=1.0)
        rect_params = aug.forward_parameters(batch_shape)
        tensor = torch.rand(batch_shape, device=device, dtype=dtype)
        tensor = utils.tensor_to_gradcheck_var(tensor)  # to var
        assert gradcheck(aug, (tensor, rect_params), raise_exception=True)
class TestRandomGrayscale:
    """Tests for ``RandomGrayscale``: transform matrix, batch consistency, and
    numerical agreement with OpenCV's RGB->GRAY conversion."""

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        """The printed representation must list the default hyper-parameters."""
        f = RandomGrayscale()
        repr = "RandomGrayscale(p=0.1, p_batch=1.0, same_on_batch=False, return_transform=False)"
        assert str(f) == repr

    def test_random_grayscale(self, device, dtype):
        """Grayscale is a color-space change only: the geometric transform is identity."""
        f = RandomGrayscale(return_transform=True)
        input = torch.rand(3, 5, 5, device=device, dtype=dtype)  # 3 x 5 x 5
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0)  # 1 x 3 x 3
        assert_close(f(input)[1], expected_transform, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` both batch elements get the same decision."""
        f = RandomGrayscale(p=0.5, same_on_batch=True)
        input = torch.eye(3, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 3, 1, 1)
        res = f(input)
        assert (res[0] == res[1]).all()

    def test_opencv_true(self, device, dtype):
        """With p=1 the result must match OpenCV's RGB->GRAY on a 3-channel image."""
        data = torch.tensor(
            [
                [
                    [
                        [0.3944633, 0.8597369, 0.1670904, 0.2825457, 0.0953912],
                        [0.1251704, 0.8020709, 0.8933256, 0.9170977, 0.1497008],
                        [0.2711633, 0.1111478, 0.0783281, 0.2771807, 0.5487481],
                        [0.0086008, 0.8288748, 0.9647092, 0.8922020, 0.7614344],
                        [0.2898048, 0.1282895, 0.7621747, 0.5657831, 0.9918593],
                    ],
                    [
                        [0.5414237, 0.9962701, 0.8947155, 0.5900949, 0.9483274],
                        [0.0468036, 0.3933847, 0.8046577, 0.3640994, 0.0632100],
                        [0.6171775, 0.8624780, 0.4126036, 0.7600935, 0.7279997],
                        [0.4237089, 0.5365476, 0.5591233, 0.1523191, 0.1382165],
                        [0.8932794, 0.8517839, 0.7152701, 0.8983801, 0.5905426],
                    ],
                    [
                        [0.2869580, 0.4700376, 0.2743714, 0.8135023, 0.2229074],
                        [0.9306560, 0.3734594, 0.4566821, 0.7599275, 0.7557513],
                        [0.7415742, 0.6115875, 0.3317572, 0.0379378, 0.1315770],
                        [0.8692724, 0.0809556, 0.7767404, 0.8742208, 0.1522012],
                        [0.7708948, 0.4509611, 0.0481175, 0.2358997, 0.6900532],
                    ],
                ]
            ],
            device=device,
            dtype=dtype,
        )
        # Output data generated with OpenCV 4.1.1: cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        # The single gray channel is replicated to keep a 3-channel layout.
        expected = torch.tensor(
            [
                [
                    [
                        [0.4684734, 0.8954562, 0.6064363, 0.5236061, 0.6106016],
                        [0.1709944, 0.5133104, 0.7915002, 0.5745703, 0.1680204],
                        [0.5279005, 0.6092287, 0.3034387, 0.5333768, 0.6064113],
                        [0.3503858, 0.5720159, 0.7052018, 0.4558409, 0.3261529],
                        [0.6988886, 0.5897652, 0.6532392, 0.7234108, 0.7218805],
                    ],
                    [
                        [0.4684734, 0.8954562, 0.6064363, 0.5236061, 0.6106016],
                        [0.1709944, 0.5133104, 0.7915002, 0.5745703, 0.1680204],
                        [0.5279005, 0.6092287, 0.3034387, 0.5333768, 0.6064113],
                        [0.3503858, 0.5720159, 0.7052018, 0.4558409, 0.3261529],
                        [0.6988886, 0.5897652, 0.6532392, 0.7234108, 0.7218805],
                    ],
                    [
                        [0.4684734, 0.8954562, 0.6064363, 0.5236061, 0.6106016],
                        [0.1709944, 0.5133104, 0.7915002, 0.5745703, 0.1680204],
                        [0.5279005, 0.6092287, 0.3034387, 0.5333768, 0.6064113],
                        [0.3503858, 0.5720159, 0.7052018, 0.4558409, 0.3261529],
                        [0.6988886, 0.5897652, 0.6532392, 0.7234108, 0.7218805],
                    ],
                ]
            ],
            device=device,
            dtype=dtype,
        )
        img_gray = RandomGrayscale(p=1.0)(data)
        assert_close(img_gray, expected, atol=1e-4, rtol=1e-4)

    def test_opencv_false(self, device, dtype):
        """With p=0 the input must pass through unchanged."""
        data = torch.tensor(
            [
                [
                    [
                        [0.3944633, 0.8597369, 0.1670904, 0.2825457, 0.0953912],
                        [0.1251704, 0.8020709, 0.8933256, 0.9170977, 0.1497008],
                        [0.2711633, 0.1111478, 0.0783281, 0.2771807, 0.5487481],
                        [0.0086008, 0.8288748, 0.9647092, 0.8922020, 0.7614344],
                        [0.2898048, 0.1282895, 0.7621747, 0.5657831, 0.9918593],
                    ],
                    [
                        [0.5414237, 0.9962701, 0.8947155, 0.5900949, 0.9483274],
                        [0.0468036, 0.3933847, 0.8046577, 0.3640994, 0.0632100],
                        [0.6171775, 0.8624780, 0.4126036, 0.7600935, 0.7279997],
                        [0.4237089, 0.5365476, 0.5591233, 0.1523191, 0.1382165],
                        [0.8932794, 0.8517839, 0.7152701, 0.8983801, 0.5905426],
                    ],
                    [
                        [0.2869580, 0.4700376, 0.2743714, 0.8135023, 0.2229074],
                        [0.9306560, 0.3734594, 0.4566821, 0.7599275, 0.7557513],
                        [0.7415742, 0.6115875, 0.3317572, 0.0379378, 0.1315770],
                        [0.8692724, 0.0809556, 0.7767404, 0.8742208, 0.1522012],
                        [0.7708948, 0.4509611, 0.0481175, 0.2358997, 0.6900532],
                    ],
                ]
            ],
            device=device,
            dtype=dtype,
        )
        expected = data
        img_gray = RandomGrayscale(p=0.0)(data)
        assert_close(img_gray, expected, atol=1e-4, rtol=1e-4)

    def test_opencv_true_batch(self, device, dtype):
        """Same OpenCV agreement check replicated over a batch of 4 images."""
        data = torch.tensor(
            [
                [
                    [0.3944633, 0.8597369, 0.1670904, 0.2825457, 0.0953912],
                    [0.1251704, 0.8020709, 0.8933256, 0.9170977, 0.1497008],
                    [0.2711633, 0.1111478, 0.0783281, 0.2771807, 0.5487481],
                    [0.0086008, 0.8288748, 0.9647092, 0.8922020, 0.7614344],
                    [0.2898048, 0.1282895, 0.7621747, 0.5657831, 0.9918593],
                ],
                [
                    [0.5414237, 0.9962701, 0.8947155, 0.5900949, 0.9483274],
                    [0.0468036, 0.3933847, 0.8046577, 0.3640994, 0.0632100],
                    [0.6171775, 0.8624780, 0.4126036, 0.7600935, 0.7279997],
                    [0.4237089, 0.5365476, 0.5591233, 0.1523191, 0.1382165],
                    [0.8932794, 0.8517839, 0.7152701, 0.8983801, 0.5905426],
                ],
                [
                    [0.2869580, 0.4700376, 0.2743714, 0.8135023, 0.2229074],
                    [0.9306560, 0.3734594, 0.4566821, 0.7599275, 0.7557513],
                    [0.7415742, 0.6115875, 0.3317572, 0.0379378, 0.1315770],
                    [0.8692724, 0.0809556, 0.7767404, 0.8742208, 0.1522012],
                    [0.7708948, 0.4509611, 0.0481175, 0.2358997, 0.6900532],
                ],
            ],
            device=device,
            dtype=dtype,
        )
        data = data.unsqueeze(0).repeat(4, 1, 1, 1)
        # Output data generated with OpenCV 4.1.1: cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        expected = torch.tensor(
            [
                [
                    [0.4684734, 0.8954562, 0.6064363, 0.5236061, 0.6106016],
                    [0.1709944, 0.5133104, 0.7915002, 0.5745703, 0.1680204],
                    [0.5279005, 0.6092287, 0.3034387, 0.5333768, 0.6064113],
                    [0.3503858, 0.5720159, 0.7052018, 0.4558409, 0.3261529],
                    [0.6988886, 0.5897652, 0.6532392, 0.7234108, 0.7218805],
                ],
                [
                    [0.4684734, 0.8954562, 0.6064363, 0.5236061, 0.6106016],
                    [0.1709944, 0.5133104, 0.7915002, 0.5745703, 0.1680204],
                    [0.5279005, 0.6092287, 0.3034387, 0.5333768, 0.6064113],
                    [0.3503858, 0.5720159, 0.7052018, 0.4558409, 0.3261529],
                    [0.6988886, 0.5897652, 0.6532392, 0.7234108, 0.7218805],
                ],
                [
                    [0.4684734, 0.8954562, 0.6064363, 0.5236061, 0.6106016],
                    [0.1709944, 0.5133104, 0.7915002, 0.5745703, 0.1680204],
                    [0.5279005, 0.6092287, 0.3034387, 0.5333768, 0.6064113],
                    [0.3503858, 0.5720159, 0.7052018, 0.4558409, 0.3261529],
                    [0.6988886, 0.5897652, 0.6532392, 0.7234108, 0.7218805],
                ],
            ],
            device=device,
            dtype=dtype,
        )
        expected = expected.unsqueeze(0).repeat(4, 1, 1, 1)
        img_gray = RandomGrayscale(p=1.0)(data)
        assert_close(img_gray, expected, atol=1e-4, rtol=1e-4)

    def test_opencv_false_batch(self, device, dtype):
        """With p=0 a whole batch must pass through unchanged."""
        data = torch.tensor(
            [
                [
                    [0.3944633, 0.8597369, 0.1670904, 0.2825457, 0.0953912],
                    [0.1251704, 0.8020709, 0.8933256, 0.9170977, 0.1497008],
                    [0.2711633, 0.1111478, 0.0783281, 0.2771807, 0.5487481],
                    [0.0086008, 0.8288748, 0.9647092, 0.8922020, 0.7614344],
                    [0.2898048, 0.1282895, 0.7621747, 0.5657831, 0.9918593],
                ],
                [
                    [0.5414237, 0.9962701, 0.8947155, 0.5900949, 0.9483274],
                    [0.0468036, 0.3933847, 0.8046577, 0.3640994, 0.0632100],
                    [0.6171775, 0.8624780, 0.4126036, 0.7600935, 0.7279997],
                    [0.4237089, 0.5365476, 0.5591233, 0.1523191, 0.1382165],
                    [0.8932794, 0.8517839, 0.7152701, 0.8983801, 0.5905426],
                ],
                [
                    [0.2869580, 0.4700376, 0.2743714, 0.8135023, 0.2229074],
                    [0.9306560, 0.3734594, 0.4566821, 0.7599275, 0.7557513],
                    [0.7415742, 0.6115875, 0.3317572, 0.0379378, 0.1315770],
                    [0.8692724, 0.0809556, 0.7767404, 0.8742208, 0.1522012],
                    [0.7708948, 0.4509611, 0.0481175, 0.2358997, 0.6900532],
                ],
            ],
            device=device,
            dtype=dtype,
        )
        data = data.unsqueeze(0).repeat(4, 1, 1, 1)
        expected = data
        img_gray = RandomGrayscale(p=0.0)(data)
        assert_close(img_gray, expected, atol=1e-4, rtol=1e-4)

    def test_random_grayscale_sequential_batch(self, device, dtype):
        """Chained no-op grayscales keep the batch and accumulate identity transforms."""
        f = nn.Sequential(RandomGrayscale(p=0.0, return_transform=True), RandomGrayscale(p=0.0, return_transform=True))
        input = torch.rand(2, 3, 5, 5, device=device, dtype=dtype)  # 2 x 3 x 5 x 5
        expected = input
        expected_transform = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand((2, 3, 3))  # 2 x 3 x 3
        expected_transform = expected_transform.to(device)
        assert_close(f(input)[0], expected, atol=1e-4, rtol=1e-4)
        assert_close(f(input)[1], expected_transform, atol=1e-4, rtol=1e-4)

    def test_gradcheck(self, device, dtype):
        """RandomGrayscale must be differentiable for both p=1 and p=0."""
        input = torch.rand((3, 5, 5), device=device, dtype=dtype)  # 3 x 5 x 5
        input = utils.tensor_to_gradcheck_var(input)  # to var
        assert gradcheck(RandomGrayscale(p=1.0), (input,), raise_exception=True)
        assert gradcheck(RandomGrayscale(p=0.0), (input,), raise_exception=True)
class TestCenterCrop:
    """Tests for the ``CenterCrop`` augmentation in both cropping modes."""

    def test_no_transform(self, device, dtype):
        """Square crop size; 'resample' mode additionally supports inverse()."""
        image = torch.rand(1, 2, 4, 4, device=device, dtype=dtype)
        assert CenterCrop(2)(image).shape == (1, 2, 2, 2)
        crop = CenterCrop(2, cropping_mode="resample")
        cropped = crop(image)
        assert cropped.shape == (1, 2, 2, 2)
        assert crop.inverse(cropped).shape == (1, 2, 4, 4)

    def test_transform(self, device, dtype):
        """With return_transform=True a (crop, 3x3 matrix) pair is produced."""
        image = torch.rand(1, 2, 5, 4, device=device, dtype=dtype)
        result = CenterCrop(2, return_transform=True)(image)
        assert len(result) == 2
        assert result[0].shape == (1, 2, 2, 2)
        assert result[1].shape == (1, 3, 3)
        crop = CenterCrop(2, cropping_mode="resample", return_transform=True)
        result = crop(image)
        assert result[0].shape == (1, 2, 2, 2)
        assert result[1].shape == (1, 3, 3)
        assert crop.inverse(result).shape == (1, 2, 5, 4)

    def test_no_transform_tuple(self, device, dtype):
        """Rectangular (h, w) crop size works the same way."""
        image = torch.rand(1, 2, 5, 4, device=device, dtype=dtype)
        assert CenterCrop((3, 4))(image).shape == (1, 2, 3, 4)
        crop = CenterCrop((3, 4), cropping_mode="resample")
        cropped = crop(image)
        assert cropped.shape == (1, 2, 3, 4)
        assert crop.inverse(cropped).shape == (1, 2, 5, 4)

    def test_crop_modes(self, device, dtype):
        """'resample' and 'slice' modes must agree when given the same parameters."""
        torch.manual_seed(0)
        image = torch.rand(1, 3, 5, 5, device=device, dtype=dtype)
        resample_op = CenterCrop(size=(2, 2), cropping_mode='resample')
        resampled = resample_op(image)
        slice_op = CenterCrop(size=(2, 2), cropping_mode='slice')
        assert_close(resampled, slice_op(image, resample_op._params))

    def test_gradcheck(self, device, dtype):
        """CenterCrop must be differentiable with respect to its input."""
        image = torch.rand(1, 2, 3, 4, device=device, dtype=dtype)
        image = utils.tensor_to_gradcheck_var(image)  # to var
        assert gradcheck(CenterCrop(3), (image,), raise_exception=True)
class TestRandomRotation:
    """Tests for ``kornia.augmentation.RandomRotation``.

    All expected tensors are pinned to specific ``torch.manual_seed`` values;
    statement order matters because each augmentation call consumes RNG state.
    """

    torch.manual_seed(0)  # for random reproductibility

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        # A scalar `degrees` is expanded to a symmetric (-d, d) range in the repr.
        f = RandomRotation(degrees=45.5)
        repr = (
            "RandomRotation(degrees=tensor([-45.5000, 45.5000]), interpolation=BILINEAR, p=0.5, "
            "p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(f) == repr

    def test_random_rotation(self, device, dtype):
        # This is included in doctest
        torch.manual_seed(0)  # for random reproductibility

        f = RandomRotation(degrees=45.0, return_transform=True, p=1.0)
        f1 = RandomRotation(degrees=45.0, p=1.0)

        input = torch.tensor(
            [[1.0, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 1.0, 2.0]],
            device=device,
            dtype=dtype,
        )  # 4 x 4

        # Expected warp for the first angle sampled under seed 0.
        expected = torch.tensor(
            [
                [
                    [
                        [0.9824, 0.0088, 0.0000, 1.9649],
                        [0.0000, 0.0029, 0.0000, 0.0176],
                        [0.0029, 1.0000, 1.9883, 0.0000],
                        [0.0000, 0.0088, 1.0117, 1.9649],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 4 x 4

        expected_transform = torch.tensor(
            [[[1.0000, -0.0059, 0.0088], [0.0059, 1.0000, -0.0088], [0.0000, 0.0000, 1.0000]]],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3

        # Expected warp for the *second* angle sampled (the call to `f` above
        # already consumed the first sample from the RNG stream).
        expected_2 = torch.tensor(
            [
                [
                    [
                        [0.1322, 0.0000, 0.7570, 0.2644],
                        [0.3785, 0.0000, 0.4166, 0.0000],
                        [0.0000, 0.6309, 1.5910, 1.2371],
                        [0.0000, 0.1444, 0.3177, 0.6499],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 4 x 4

        out, mat = f(input)
        assert_close(out, expected, rtol=1e-6, atol=1e-4)
        assert_close(mat, expected_transform, rtol=1e-6, atol=1e-4)
        assert_close(f1(input), expected_2, rtol=1e-6, atol=1e-4)

    def test_batch_random_rotation(self, device, dtype):
        torch.manual_seed(0)  # for random reproductibility

        f = RandomRotation(degrees=45.0, return_transform=True, p=1.0)

        input = torch.tensor(
            [[[[1.0, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 1.0, 2.0]]]],
            device=device,
            dtype=dtype,
        )  # 1 x 1 x 4 x 4

        expected = torch.tensor(
            [
                [
                    [
                        [0.9824, 0.0088, 0.0000, 1.9649],
                        [0.0000, 0.0029, 0.0000, 0.0176],
                        [0.0029, 1.0000, 1.9883, 0.0000],
                        [0.0000, 0.0088, 1.0117, 1.9649],
                    ]
                ],
                [
                    [
                        [0.1322, 0.0000, 0.7570, 0.2644],
                        [0.3785, 0.0000, 0.4166, 0.0000],
                        [0.0000, 0.6309, 1.5910, 1.2371],
                        [0.0000, 0.1444, 0.3177, 0.6499],
                    ]
                ],
            ],
            device=device,
            dtype=dtype,
        )  # 2 x 1 x 4 x 4

        expected_transform = torch.tensor(
            [
                [[1.0000, -0.0059, 0.0088], [0.0059, 1.0000, -0.0088], [0.0000, 0.0000, 1.0000]],
                [[0.9125, 0.4090, -0.4823], [-0.4090, 0.9125, 0.7446], [0.0000, 0.0000, 1.0000]],
            ],
            device=device,
            dtype=dtype,
        )  # 2 x 3 x 3

        input = input.repeat(2, 1, 1, 1)  # 2 x 1 x 4 x 4

        out, mat = f(input)
        assert_close(out, expected, rtol=1e-4, atol=1e-4)
        assert_close(mat, expected_transform, rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        # With same_on_batch=True every batch element gets the same rotation.
        f = RandomRotation(degrees=40, same_on_batch=True)
        input = torch.eye(6, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 3, 1, 1)
        res = f(input)
        assert (res[0] == res[1]).all()

    def test_sequential(self, device, dtype):
        torch.manual_seed(0)  # for random reproductibility

        # Two rotations chained; the returned transform is the composition.
        f = nn.Sequential(
            RandomRotation(torch.tensor([-45.0, 90]), return_transform=True, p=1.0),
            RandomRotation(10.4, return_transform=True, p=1.0),
        )
        f1 = nn.Sequential(
            RandomRotation(torch.tensor([-45.0, 90]), return_transform=True, p=1.0), RandomRotation(10.4, p=1.0)
        )

        input = torch.tensor(
            [[1.0, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 1.0, 2.0]],
            device=device,
            dtype=dtype,
        )  # 4 x 4

        expected = torch.tensor(
            [
                [
                    [
                        [0.1314, 0.1050, 0.6649, 0.2628],
                        [0.3234, 0.0202, 0.4256, 0.1671],
                        [0.0525, 0.5976, 1.5199, 1.1306],
                        [0.0000, 0.1453, 0.3224, 0.5796],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )  # 1 x 4 x 4

        expected_transform = torch.tensor(
            [[[0.8864, 0.4629, -0.5240], [-0.4629, 0.8864, 0.8647], [0.0000, 0.0000, 1.0000]]],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3

        expected_transform_2 = torch.tensor(
            [[[0.8381, -0.5455, 1.0610], [0.5455, 0.8381, -0.5754], [0.0000, 0.0000, 1.0000]]],
            device=device,
            dtype=dtype,
        )  # 1 x 3 x 3

        out, mat = f(input)
        _, mat_2 = f1(input)
        assert_close(out, expected, rtol=1e-4, atol=1e-4)
        assert_close(mat, expected_transform, rtol=1e-4, atol=1e-4)
        assert_close(mat_2, expected_transform_2, rtol=1e-4, atol=1e-4)

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)  # for random reproductibility
        input = torch.rand((3, 3), device=device, dtype=dtype)  # 3 x 3
        input = utils.tensor_to_gradcheck_var(input)  # to var
        # Fixed-angle range (15, 15) keeps the sampled parameter deterministic.
        assert gradcheck(RandomRotation(degrees=(15.0, 15.0), p=1.0), (input,), raise_exception=True)
class TestRandomCrop:
    """Tests for ``kornia.augmentation.RandomCrop``.

    Expected tensors are pinned to fixed seeds; ``cropping_mode="resample"``
    variants additionally verify that ``aug.inverse`` restores the crop into
    a zero-padded canvas of the original size.
    """

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        f = RandomCrop(size=(2, 3), padding=(0, 1), fill=10, pad_if_needed=False, p=1.0)
        repr = (
            "RandomCrop(crop_size=(2, 3), padding=(0, 1), fill=10, pad_if_needed=False, padding_mode=constant, "
            "resample=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(f) == repr

    def test_no_padding(self, device, dtype):
        torch.manual_seed(0)
        inp = torch.tensor([[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        expected = torch.tensor([[[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        rc = RandomCrop(size=(2, 3), padding=None, align_corners=True, p=1.0)
        out = rc(inp)
        # Re-seed so the unbatched (3D) input samples the same crop offsets.
        torch.manual_seed(0)
        out2 = rc(inp.squeeze())

        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(out2, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        # inverse() pastes the crop back; the uncovered first row becomes zeros.
        inversed = torch.tensor([[[[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        aug = RandomCrop(size=(2, 3), padding=None, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_no_padding_batch(self, device, dtype):
        torch.manual_seed(42)
        batch_size = 2
        inp = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype).repeat(
            batch_size, 1, 1, 1
        )
        # Under seed 42 the two batch elements get different crop offsets.
        expected = torch.tensor(
            [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]], [[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), padding=None, align_corners=True, p=1.0)
        out = rc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(42)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [0.0, 0.0, 0.0]]],
                [[[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomCrop(size=(2, 3), padding=None, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        f = RandomCrop(size=(2, 3), padding=1, same_on_batch=True, align_corners=True, p=1.0)
        input = torch.eye(3, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 3, 1, 1)
        res = f(input)
        assert (res[0] == res[1]).all()

    def test_padding(self, device, dtype):
        torch.manual_seed(42)
        inp = torch.tensor([[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        # Values come from reflect-padding by 1 before cropping.
        expected = torch.tensor([[[[7.0, 8.0, 7.0], [4.0, 5.0, 4.0]]]], device=device, dtype=dtype)
        rc = RandomCrop(size=(2, 3), padding=1, padding_mode='reflect', align_corners=True, p=1.0)
        out = rc(inp)
        torch.manual_seed(42)
        out2 = rc(inp.squeeze())

        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(out2, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(42)
        inversed = torch.tensor([[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        aug = RandomCrop(
            size=(2, 3), padding=1, padding_mode='reflect', align_corners=True, p=1.0, cropping_mode="resample"
        )
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_padding_batch_1(self, device, dtype):
        # Scalar padding applied to all four sides (constant zeros).
        torch.manual_seed(42)
        batch_size = 2
        inp = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype).repeat(
            batch_size, 1, 1, 1
        )
        expected = torch.tensor(
            [[[[1.0, 2.0, 0.0], [4.0, 5.0, 0.0]]], [[[7.0, 8.0, 0.0], [0.0, 0.0, 0.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), padding=1, align_corners=True, p=1.0)
        out = rc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(42)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 0.0, 0.0]]],
                [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 7.0, 8.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomCrop(size=(2, 3), padding=1, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_padding_batch_2(self, device, dtype):
        # 2-tuple padding (left/right, top/bottom) with a non-zero fill value.
        torch.manual_seed(42)
        batch_size = 2
        inp = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype).repeat(
            batch_size, 1, 1, 1
        )
        expected = torch.tensor(
            [[[[1.0, 2.0, 10.0], [4.0, 5.0, 10.0]]], [[[4.0, 5.0, 10.0], [7.0, 8.0, 10.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), padding=(0, 1), fill=10, align_corners=True, p=1.0)
        out = rc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(42)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 0.0, 0.0]]],
                [[[0.0, 0.0, 0.0], [0.0, 4.0, 5.0], [0.0, 7.0, 8.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomCrop(size=(2, 3), padding=(0, 1), fill=10, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_padding_batch_3(self, device, dtype):
        # 4-tuple padding (left, top, right, bottom) with fill value 8.
        torch.manual_seed(0)
        batch_size = 2
        inp = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype).repeat(
            batch_size, 1, 1, 1
        )
        expected = torch.tensor(
            [[[[8.0, 8.0, 8.0], [8.0, 0.0, 1.0]]], [[[8.0, 8.0, 8.0], [1.0, 2.0, 8.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), padding=(0, 1, 2, 3), fill=8, align_corners=True, p=1.0)
        out = rc(inp)

        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
                [[[0.0, 1.0, 2.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomCrop(size=(2, 3), padding=(0, 1, 2, 3), fill=8, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_padding_no_forward(self, device, dtype):
        # With p=0.0 the op is a no-op: inputs (and transforms) pass through.
        torch.manual_seed(0)
        inp = torch.tensor([[[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        trans = torch.tensor([[[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

        # Not return transform
        rc = RandomCrop(size=(2, 3), padding=(0, 1, 2, 3), fill=9, align_corners=True, p=0.0)
        out = rc(inp)
        assert_close(out, inp, atol=1e-4, rtol=1e-4)
        out = rc((inp, trans))
        assert_close(out[0], inp, atol=1e-4, rtol=1e-4)
        assert_close(out[1], trans, atol=1e-4, rtol=1e-4)

        # with return transform
        rc = RandomCrop(size=(2, 3), padding=(0, 1, 2, 3), fill=9, align_corners=True, p=0.0, return_transform=True)
        out = rc(inp)
        assert_close(out[0], inp, atol=1e-4, rtol=1e-4)
        out = rc((inp, trans))
        assert_close(out[0], inp, atol=1e-4, rtol=1e-4)
        assert_close(out[1], trans, atol=1e-4, rtol=1e-4)

    def test_pad_if_needed(self, device, dtype):
        # Input is smaller than the crop size; pad_if_needed fills with 9.
        torch.manual_seed(0)
        batch_size = 2
        inp = torch.tensor([[[0.0, 1.0, 2.0]]], device=device, dtype=dtype).repeat(batch_size, 1, 1, 1)
        expected = torch.tensor(
            [[[[9.0, 9.0, 9.0], [0.0, 1.0, 2.0]]], [[[9.0, 9.0, 9.0], [0.0, 1.0, 2.0]]]], device=device, dtype=dtype
        )
        rc = RandomCrop(size=(2, 3), pad_if_needed=True, fill=9, align_corners=True, p=1.0)
        out = rc(inp)

        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor([[[[0.0, 1.0, 2.0]]], [[[0.0, 1.0, 2.0]]]], device=device, dtype=dtype)
        aug = RandomCrop(size=(2, 3), pad_if_needed=True, fill=9, align_corners=True, p=1.0, cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_crop_modes(self, device, dtype):
        # 'resample' and 'slice' must agree when given identical parameters.
        torch.manual_seed(0)
        img = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

        op1 = RandomCrop(size=(2, 2), cropping_mode='resample')
        out = op1(img)

        op2 = RandomCrop(size=(2, 2), cropping_mode='slice')

        assert_close(out, op2(img, op1._params))

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)  # for random reproductibility
        inp = torch.rand((3, 3, 3), device=device, dtype=dtype)  # 3 x 3 x 3
        inp = utils.tensor_to_gradcheck_var(inp)  # to var
        assert gradcheck(RandomCrop(size=(3, 3), p=1.0), (inp,), raise_exception=True)

    @pytest.mark.skip("Need to fix Union type")
    def test_jit(self, device, dtype):
        # Define script
        op = RandomCrop(size=(3, 3), p=1.0).forward
        op_script = torch.jit.script(op)
        img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)

        actual = op_script(img)
        expected = kornia.geometry.transform.center_crop3d(img)
        assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    @pytest.mark.skip("Need to fix Union type")
    def test_jit_trace(self, device, dtype):
        # Define script
        op = RandomCrop(size=(3, 3), p=1.0).forward
        op_script = torch.jit.script(op)
        # 1. Trace op
        img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)

        op_trace = torch.jit.trace(op_script, (img,))

        # 2. Generate new input
        img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)

        # 3. Evaluate
        actual = op_trace(img)
        expected = op(img)
        assert_close(actual, expected, atol=1e-4, rtol=1e-4)
class TestRandomResizedCrop:
    """Tests for ``kornia.augmentation.RandomResizedCrop``.

    Fixed scale/ratio pairs make the sampled crop deterministic under the
    pinned seeds; ``cropping_mode="resample"`` variants also verify
    ``aug.inverse`` restores pixels into a zero-padded original-size canvas.
    """

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        f = RandomResizedCrop(size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0))
        repr = (
            "RandomResizedCrop(size=(2, 3), scale=tensor([1., 1.]), ratio=tensor([1., 1.]), "
            "interpolation=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)"
        )
        assert str(f) == repr

    def test_no_resize(self, device, dtype):
        torch.manual_seed(0)
        inp = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

        expected = torch.tensor([[[[0.0000, 1.0000, 2.0000], [6.0000, 7.0000, 8.0000]]]], device=device, dtype=dtype)

        rrc = RandomResizedCrop(size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0))
        # It will crop a size of (2, 3) from the aspect ratio implementation of torch
        out = rrc(inp)
        assert_close(out, expected, rtol=1e-4, atol=1e-4)

        torch.manual_seed(0)
        aug = RandomResizedCrop(size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0), cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        # The full image was covered, so inverse() recovers it exactly.
        assert_close(aug.inverse(out), inp[None], atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        f = RandomResizedCrop(size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0), same_on_batch=True)
        input = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype).repeat(
            2, 1, 1, 1
        )
        res = f(input)
        assert (res[0] == res[1]).all()

        torch.manual_seed(0)
        aug = RandomResizedCrop(
            size=(2, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0), same_on_batch=True, cropping_mode="resample"
        )
        out = aug(input)
        inversed = aug.inverse(out)
        assert (inversed[0] == inversed[1]).all()

    def test_crop_scale_ratio(self, device, dtype):
        # This is included in doctest
        torch.manual_seed(0)
        inp = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

        expected = torch.tensor(
            [[[[1.0000, 1.5000, 2.0000], [4.0000, 4.5000, 5.0000], [7.0000, 7.5000, 8.0000]]]],
            device=device,
            dtype=dtype,
        )
        rrc = RandomResizedCrop(size=(3, 3), scale=(3.0, 3.0), ratio=(2.0, 2.0))
        # It will crop a size of (3, 3) from the aspect ratio implementation of torch
        out = rrc(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor([[[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        aug = RandomResizedCrop(size=(3, 3), scale=(3.0, 3.0), ratio=(2.0, 2.0), cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_crop_size_greater_than_input(self, device, dtype):
        # This is included in doctest
        torch.manual_seed(0)
        inp = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

        # Bilinear upsampling of the sampled crop to the larger (4, 4) target.
        exp = torch.tensor(
            [
                [
                    [
                        [1.0000, 1.3333, 1.6667, 2.0000],
                        [3.0000, 3.3333, 3.6667, 4.0000],
                        [5.0000, 5.3333, 5.6667, 6.0000],
                        [7.0000, 7.3333, 7.6667, 8.0000],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        )

        rrc = RandomResizedCrop(size=(4, 4), scale=(3.0, 3.0), ratio=(2.0, 2.0))
        # It will crop a size of (3, 3) from the aspect ratio implementation of torch
        out = rrc(inp)
        assert out.shape == torch.Size([1, 1, 4, 4])
        assert_close(out, exp, atol=1e-4, rtol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor([[[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 7.0, 8.0]]]], device=device, dtype=dtype)
        aug = RandomResizedCrop(size=(4, 4), scale=(3.0, 3.0), ratio=(2.0, 2.0), cropping_mode="resample")
        out = aug(inp)
        assert_close(out, exp, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_crop_scale_ratio_batch(self, device, dtype):
        torch.manual_seed(0)
        batch_size = 2
        inp = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype).repeat(
            batch_size, 1, 1, 1
        )

        # Each batch element samples its own crop window under seed 0.
        expected = torch.tensor(
            [
                [[[1.0000, 1.5000, 2.0000], [4.0000, 4.5000, 5.0000], [7.0000, 7.5000, 8.0000]]],
                [[[0.0000, 0.5000, 1.0000], [3.0000, 3.5000, 4.0000], [6.0000, 6.5000, 7.0000]]],
            ],
            device=device,
            dtype=dtype,
        )
        rrc = RandomResizedCrop(size=(3, 3), scale=(3.0, 3.0), ratio=(2.0, 2.0))
        # It will crop a size of (2, 2) from the aspect ratio implementation of torch
        out = rrc(inp)
        assert_close(out, expected, rtol=1e-4, atol=1e-4)

        torch.manual_seed(0)
        inversed = torch.tensor(
            [
                [[[0.0, 1.0, 2.0], [0.0, 4.0, 5.0], [0.0, 7.0, 8.0]]],
                [[[0.0, 1.0, 0.0], [3.0, 4.0, 0.0], [6.0, 7.0, 0.0]]],
            ],
            device=device,
            dtype=dtype,
        )
        aug = RandomResizedCrop(size=(3, 3), scale=(3.0, 3.0), ratio=(2.0, 2.0), cropping_mode="resample")
        out = aug(inp)
        assert_close(out, expected, atol=1e-4, rtol=1e-4)
        assert_close(aug.inverse(out), inversed, atol=1e-4, rtol=1e-4)

    def test_crop_modes(self, device, dtype):
        # 'resample' and 'slice' must agree when given identical parameters.
        torch.manual_seed(0)
        img = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]], device=device, dtype=dtype)

        op1 = RandomResizedCrop(size=(4, 4), cropping_mode='resample')
        out = op1(img)

        op2 = RandomResizedCrop(size=(4, 4), cropping_mode='slice')

        assert_close(out, op2(img, op1._params))

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)  # for random reproductibility
        inp = torch.rand((1, 3, 3), device=device, dtype=dtype)  # 1 x 3 x 3
        inp = utils.tensor_to_gradcheck_var(inp)  # to var
        assert gradcheck(
            RandomResizedCrop(size=(3, 3), scale=(1.0, 1.0), ratio=(1.0, 1.0)), (inp,), raise_exception=True
        )
class TestRandomEqualize:
    """Tests for ``kornia.augmentation.RandomEqualize``.

    Inputs are synthetic ramp images built by :meth:`build_input`; each
    assertion re-runs the augmentation, so assertion order matters for RNG
    consumption (here p is 0.0 or 1.0, making results deterministic).
    """

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing precision.")
    def test_smoke(self, device, dtype):
        f = RandomEqualize(p=0.5)
        repr = "RandomEqualize(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
        assert str(f) == repr

    def test_random_equalize(self, device, dtype):
        f = RandomEqualize(p=1.0, return_transform=True)
        f1 = RandomEqualize(p=0.0, return_transform=True)
        f2 = RandomEqualize(p=1.0)
        f3 = RandomEqualize(p=0.0)

        bs, channels, height, width = 1, 3, 20, 20

        inputs = self.build_input(channels, height, width, bs, device=device, dtype=dtype)

        # Histogram-equalized version of one ramp row (saturates at 1.0).
        row_expected = torch.tensor(
            [
                0.0000,
                0.07843,
                0.15686,
                0.2353,
                0.3137,
                0.3922,
                0.4706,
                0.5490,
                0.6275,
                0.7059,
                0.7843,
                0.8627,
                0.9412,
                1.0000,
                1.0000,
                1.0000,
                1.0000,
                1.0000,
                1.0000,
                1.0000,
            ]
        )
        expected = self.build_input(channels, height, width, bs=1, row=row_expected, device=device, dtype=dtype)
        identity = kornia.eye_like(3, expected)  # 3 x 3

        # Equalize is pixel-wise, so the returned transform is the identity.
        assert_close(f(inputs)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(f(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[0], inputs, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4)

    def test_batch_random_equalize(self, device, dtype):
        f = RandomEqualize(p=1.0, return_transform=True)
        f1 = RandomEqualize(p=0.0, return_transform=True)
        f2 = RandomEqualize(p=1.0)
        f3 = RandomEqualize(p=0.0)

        bs, channels, height, width = 2, 3, 20, 20

        inputs = self.build_input(channels, height, width, bs, device=device, dtype=dtype)

        row_expected = torch.tensor(
            [
                0.0000,
                0.07843,
                0.15686,
                0.2353,
                0.3137,
                0.3922,
                0.4706,
                0.5490,
                0.6275,
                0.7059,
                0.7843,
                0.8627,
                0.9412,
                1.0000,
                1.0000,
                1.0000,
                1.0000,
                1.0000,
                1.0000,
                1.0000,
            ]
        )
        expected = self.build_input(channels, height, width, bs, row=row_expected, device=device, dtype=dtype)

        identity = kornia.eye_like(3, expected)  # 2 x 3 x 3

        assert_close(f(inputs)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(f(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[0], inputs, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        f = RandomEqualize(p=0.5, same_on_batch=True)
        input = torch.eye(4, device=device, dtype=dtype)
        input = input.unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 1, 1)
        res = f(input)
        assert (res[0] == res[1]).all()

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)  # for random reproductibility
        input = torch.rand((3, 3, 3), device=device, dtype=dtype)  # 3 x 3 x 3
        input = utils.tensor_to_gradcheck_var(input)  # to var
        assert gradcheck(RandomEqualize(p=0.5), (input,), raise_exception=True)

    @staticmethod
    def build_input(channels, height, width, bs=1, row=None, device='cpu', dtype=torch.float32):
        # Builds a (bs, channels, height, width) image where every row equals
        # `row` (default: a 0..1 ramp of length `width`).
        if row is None:
            row = torch.arange(width, device=device, dtype=dtype) / float(width)

        channel = torch.stack([row] * height)
        image = torch.stack([channel] * channels)
        batch = torch.stack([image] * bs)

        return batch.to(device, dtype)
class TestGaussianBlur:
    """Smoke test for ``RandomGaussianBlur``."""

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self):
        # Only probability-related settings appear in the repr string.
        expected_repr = "RandomGaussianBlur(p=1.0, p_batch=1.0, same_on_batch=False, return_transform=False)"
        aug = RandomGaussianBlur((3, 3), (0.1, 2.0), p=1.0)
        assert str(aug) == expected_repr
class TestRandomInvert:
    """Smoke test for ``RandomInvert``."""

    def test_smoke(self, device, dtype):
        # With p=1.0 an all-ones image inverts to all zeros.
        ones = torch.ones(1, 3, 4, 5, device=device, dtype=dtype)
        inverted = RandomInvert(p=1.0)(ones)
        assert_close(inverted, torch.zeros_like(ones))
class TestRandomChannelShuffle:
    """Smoke test for ``RandomChannelShuffle`` with a pinned RNG seed."""

    def test_smoke(self, device, dtype):
        torch.manual_seed(0)
        image = torch.arange(1 * 3 * 2 * 2, device=device, dtype=dtype).view(1, 3, 2, 2)
        # Under seed 0 the channels come out in order (2, 0, 1).
        expected = torch.tensor(
            [[[[8.0, 9.0], [10.0, 11.0]], [[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]],
            device=device,
            dtype=dtype,
        )
        shuffled = RandomChannelShuffle(p=1.0)(image)
        assert_close(shuffled, expected)
class TestRandomGaussianNoise:
    """Smoke test for ``RandomGaussianNoise``."""

    def test_smoke(self, device, dtype):
        torch.manual_seed(0)
        image = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        # Additive noise must preserve the input shape.
        noisy = RandomGaussianNoise(p=1.0)(image)
        assert noisy.shape == image.shape
class TestNormalize:
    """Tests for ``kornia.augmentation.Normalize`` (out = (in - mean) / std)."""

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self, device, dtype):
        f = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([1.0]))
        repr = (
            "Normalize(mean=torch.tensor([1.]), std=torch.tensor([1.]), p=1., p_batch=1.0, "
            "same_on_batch=False, return_transform=False)"
        )
        assert str(f) == repr

    @staticmethod
    @pytest.mark.parametrize(
        "mean, std", [((1.0, 1.0, 1.0), (0.5, 0.5, 0.5)), (1.0, 0.5), (torch.tensor([1.0]), torch.tensor([0.5]))]
    )
    def test_random_normalize_different_parameter_types(mean, std):
        # mean/std may be floats, tuples, or tensors; all must normalize alike.
        f = Normalize(mean=mean, std=std, p=1)
        data = torch.ones(2, 3, 256, 313)
        if isinstance(mean, float):
            expected = (data - torch.as_tensor(mean)) / torch.as_tensor(std)
        else:
            expected = (data - torch.as_tensor(mean[0])) / torch.as_tensor(std[0])
        assert_close(f(data), expected)

    @staticmethod
    @pytest.mark.parametrize("mean, std", [((1.0, 1.0, 1.0, 1.0), (0.5, 0.5, 0.5, 0.5)), ((1.0, 1.0), (0.5, 0.5))])
    def test_random_normalize_invalid_parameter_shape(mean, std):
        # mean/std length must match the channel count (1 here) or raise.
        f = Normalize(mean=mean, std=std, p=1.0, return_transform=True)
        inputs = torch.arange(0.0, 16.0, step=1).reshape(1, 4, 4).unsqueeze(0)
        with pytest.raises(ValueError):
            f(inputs)

    def test_random_normalize(self, device, dtype):
        f = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0, return_transform=True)
        f1 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0, return_transform=True)
        f2 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0)
        f3 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0)

        inputs = torch.arange(0.0, 16.0, step=1, device=device, dtype=dtype).reshape(1, 4, 4).unsqueeze(0)

        expected = (inputs - 1) * 2  # (x - mean) / std with mean=1, std=0.5

        identity = kornia.eye_like(3, expected)

        # Normalize is pixel-wise, so the returned transform is the identity.
        assert_close(f(inputs)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(f(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[0], inputs, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4)

    def test_batch_random_normalize(self, device, dtype):
        f = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0, return_transform=True)
        f1 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0, return_transform=True)
        f2 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0)
        f3 = Normalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0)

        inputs = torch.arange(0.0, 16.0 * 2, step=1, device=device, dtype=dtype).reshape(2, 1, 4, 4)

        expected = (inputs - 1) * 2

        identity = kornia.eye_like(3, expected)

        assert_close(f(inputs)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(f(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[0], inputs, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4)

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)  # for random reproductibility
        input = torch.rand((3, 3, 3), device=device, dtype=dtype)  # 3 x 3 x 3
        input = utils.tensor_to_gradcheck_var(input)  # to var
        assert gradcheck(
            Normalize(mean=torch.tensor([1.0]), std=torch.tensor([1.0]), p=1.0), (input,), raise_exception=True
        )
class TestDenormalize:
    """Tests for ``kornia.augmentation.Denormalize`` (out = in * std + mean)."""

    # TODO: improve and implement more meaningful smoke tests e.g check for a consistent
    # return values such a torch.Tensor variable.
    @pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
    def test_smoke(self, device, dtype):
        f = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([1.0]))
        repr = (
            "Denormalize(mean=torch.tensor([1.]), std=torch.tensor([1.]), p=1., p_batch=1.0, "
            "same_on_batch=False, return_transform=False)"
        )
        assert str(f) == repr

    def test_random_denormalize(self, device, dtype):
        f = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0, return_transform=True)
        f1 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0, return_transform=True)
        f2 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0)
        f3 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0)

        inputs = torch.arange(0.0, 16.0, step=1, device=device, dtype=dtype).reshape(1, 4, 4).unsqueeze(0)

        expected = inputs / 2 + 1  # x * std + mean with mean=1, std=0.5

        identity = kornia.eye_like(3, expected)

        # Denormalize is pixel-wise, so the returned transform is the identity.
        assert_close(f(inputs)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(f(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[0], inputs, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4)

    def test_batch_random_denormalize(self, device, dtype):
        f = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0, return_transform=True)
        f1 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0, return_transform=True)
        f2 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=1.0)
        f3 = Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([0.5]), p=0.0)

        inputs = torch.arange(0.0, 16.0 * 2, step=1, device=device, dtype=dtype).reshape(2, 1, 4, 4)

        expected = inputs / 2 + 1

        identity = kornia.eye_like(3, expected)

        assert_close(f(inputs)[0], expected, rtol=1e-4, atol=1e-4)
        assert_close(f(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[0], inputs, rtol=1e-4, atol=1e-4)
        assert_close(f1(inputs)[1], identity, rtol=1e-4, atol=1e-4)
        assert_close(f2(inputs), expected, rtol=1e-4, atol=1e-4)
        assert_close(f3(inputs), inputs, rtol=1e-4, atol=1e-4)

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(0)  # for random reproductibility
        input = torch.rand((3, 3, 3), device=device, dtype=dtype)  # 3 x 3 x 3
        input = utils.tensor_to_gradcheck_var(input)  # to var
        assert gradcheck(
            Denormalize(mean=torch.tensor([1.0]), std=torch.tensor([1.0]), p=1.0), (input,), raise_exception=True
        )
class TestRandomFisheye:
    """Tests for ``RandomFisheye``."""

    def test_smoke(self, device, dtype):
        torch.manual_seed(0)
        cx = torch.tensor([-0.3, 0.3])
        cy = torch.tensor([-0.3, 0.3])
        distortion = torch.tensor([-1.0, 1.0])
        image = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        transform = RandomFisheye(cx, cy, distortion, p=1.0)
        # The fisheye warp must keep the spatial resolution intact.
        assert transform(image).shape == image.shape

    @pytest.mark.skip(reason="RuntimeError: Jacobian mismatch for output 0 with respect to input 0")
    def test_gradcheck(self, device, dtype):
        # Turn every tensor gradcheck touches into a leaf variable first.
        image = utils.tensor_to_gradcheck_var(torch.rand(1, 1, 3, 3, device=device, dtype=dtype))
        cx = utils.tensor_to_gradcheck_var(torch.tensor([-0.3, 0.3], device=device, dtype=dtype))
        cy = utils.tensor_to_gradcheck_var(torch.tensor([-0.3, 0.3], device=device, dtype=dtype))
        distortion = utils.tensor_to_gradcheck_var(torch.tensor([-1.0, 1.0], device=device, dtype=dtype))
        assert gradcheck(RandomFisheye(cx, cy, distortion), (image,), raise_exception=True)
class TestRandomElasticTransform:
    """Tests for ``RandomElasticTransform``."""

    def test_smoke(self, device, dtype):
        image = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        warped = RandomElasticTransform(p=1.0)(image)
        assert warped.shape == image.shape

    def test_same_on_batch(self, device, dtype):
        aug = RandomElasticTransform(p=1.0, same_on_batch=True)
        batch = torch.eye(3, device=device, dtype=dtype).unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 1, 1)
        warped = aug(batch)
        # same_on_batch=True means both elements share one displacement field.
        assert (warped[0] == warped[1]).all()
class TestRandomThinPlateSpline:
    """Smoke test for ``RandomThinPlateSpline``."""

    def test_smoke(self, device, dtype):
        image = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        # The TPS warp must preserve the input shape.
        warped = RandomThinPlateSpline(p=1.0)(image)
        assert warped.shape == image.shape
class TestRandomBoxBlur:
    """Smoke test for ``RandomBoxBlur``."""

    def test_smoke(self, device, dtype):
        image = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        # Blurring must preserve the input shape.
        blurred = RandomBoxBlur(p=1.0)(image)
        assert blurred.shape == image.shape
class TestPadTo:
    """Tests for the ``PadTo`` augmentation."""

    def test_smoke(self, device, dtype):
        image = torch.rand(1, 1, 2, 2, device=device, dtype=dtype)
        pad = PadTo(size=(4, 5))
        padded = pad(image)
        # The output is padded up to the requested spatial size...
        assert padded.shape == (1, 1, 4, 5)
        # ...and inverse() recovers the original pixels exactly.
        assert (pad.inverse(padded) == image).all()