|
|
from unittest.mock import patch |
|
|
|
|
|
import pytest |
|
|
import torch |
|
|
from torch.autograd import gradcheck |
|
|
|
|
|
import kornia.testing as utils |
|
|
from kornia.augmentation.base import _BasicAugmentationBase, AugmentationBase2D |
|
|
from kornia.testing import assert_close |
|
|
|
|
|
|
|
|
class TestBasicAugmentationBase:
    """Tests for ``_BasicAugmentationBase`` with its overridable hooks replaced by mocks."""

    def test_smoke(self, device, dtype):
        # The string form lists the probability / batching options passed to the constructor.
        expected_repr = "p=0.5, p_batch=1.0, same_on_batch=True"
        base = _BasicAugmentationBase(p=0.5, p_batch=1.0, same_on_batch=True)
        assert str(base) == expected_repr

    def test_infer_input(self, device, dtype):
        images = torch.rand((2, 3, 4, 5), device=device, dtype=dtype)
        augmentation = _BasicAugmentationBase(p=1.0, p_batch=1)

        with patch.object(augmentation, "transform_tensor", autospec=True) as transform_tensor:
            # NOTE(review): transform_tensor is patched before being called, so this checks the
            # mock's side effect (insert a singleton axis at dim 2), not the real implementation.
            transform_tensor.side_effect = lambda x: x.unsqueeze(dim=2)
            inferred = augmentation.transform_tensor(images)

            assert inferred.shape == torch.Size([2, 3, 1, 4, 5])
            assert_close(images, inferred[:, :, 0, :, :])

    # Columns: sample probability, batch probability, same_on_batch flag,
    # expected number of selected samples (out of 12) for the given RNG seed, seed.
    @pytest.mark.parametrize(
        "p,p_batch,same_on_batch,num,seed",
        [
            (1.0, 1.0, False, 12, 1),
            (1.0, 0.0, False, 0, 1),
            (0.0, 1.0, False, 0, 1),
            (0.0, 0.0, False, 0, 1),
            (0.5, 0.1, False, 7, 3),
            (0.5, 0.1, True, 12, 3),
            (0.3, 1.0, False, 2, 1),
            (0.3, 1.0, True, 0, 1),
        ],
    )
    def test_forward_params(self, p, p_batch, same_on_batch, num, seed, device, dtype):
        batch_shape = (12,)
        torch.manual_seed(seed)
        augmentation = _BasicAugmentationBase(p, p_batch, same_on_batch)

        def fake_generate_parameters(shape):
            # One 'degrees' entry per element of the leading dimension of ``shape``.
            return {'degrees': torch.arange(0, shape[0], device=device, dtype=dtype)}

        with patch.object(augmentation, "generate_parameters", autospec=True) as generate_parameters:
            generate_parameters.side_effect = fake_generate_parameters
            params = augmentation.forward_parameters(batch_shape)

            assert "batch_prob" in params
            # The generated parameters and the selection mask must agree on the selected count.
            assert len(params['degrees']) == params['batch_prob'].sum().item() == num

    @pytest.mark.parametrize('keepdim', (True, False))
    def test_forward(self, device, dtype, keepdim):
        torch.manual_seed(42)
        images = torch.rand((12, 3, 4, 5), device=device, dtype=dtype)
        if keepdim:
            expected_output = images[..., :2, :2]
        else:
            # Without keepdim the extra leading axis added by the mocked transform_tensor stays.
            expected_output = images.unsqueeze(dim=0)[..., :2, :2]
        augmentation = _BasicAugmentationBase(p=0.3, p_batch=1.0, keepdim=keepdim)

        with patch.object(augmentation, "apply_transform", autospec=True) as apply_transform, patch.object(
            augmentation, "generate_parameters", autospec=True
        ) as generate_parameters, patch.object(
            augmentation, "transform_tensor", autospec=True
        ) as transform_tensor, patch.object(
            augmentation, "__check_batching__", autospec=True
        ) as check_batching:
            # Every hook is stubbed so only the base class's orchestration is exercised.
            check_batching.side_effect = lambda input: None
            transform_tensor.side_effect = lambda x: x.unsqueeze(dim=0)
            generate_parameters.side_effect = lambda shape: {
                'degrees': torch.arange(0, shape[0], device=device, dtype=dtype)
            }
            apply_transform.side_effect = lambda input, params: input[..., :2, :2]

            result = augmentation(images)

            assert result.shape == expected_output.shape
            assert_close(result, expected_output)
|
|
|
|
|
|
|
|
class TestAugmentationBase2D:
    """Tests for ``AugmentationBase2D`` with its transform hooks replaced by mocks."""

    @pytest.mark.parametrize(
        'input_shape, in_trans_shape',
        [
            ((2, 3, 4, 5), (2, 3, 3)),
            ((3, 4, 5), (3, 3)),
            ((4, 5), (3, 3)),
            # Presumably shape combinations that __check_batching__ should reject (hence xfail).
            pytest.param((1, 2, 3, 4, 5), (2, 3, 3), marks=pytest.mark.xfail),
            pytest.param((2, 3, 4, 5), (1, 3, 3), marks=pytest.mark.xfail),
            pytest.param((2, 3, 4, 5), (3, 3), marks=pytest.mark.xfail),
        ],
    )
    def test_check_batching(self, device, dtype, input_shape, in_trans_shape):
        images = torch.rand(input_shape, device=device, dtype=dtype)
        in_trans = torch.rand(in_trans_shape, device=device, dtype=dtype)
        augmentation = AugmentationBase2D(p=1.0, p_batch=1)
        # Both a bare tensor and a (tensor, transformation) pair are checked.
        augmentation.__check_batching__(images)
        augmentation.__check_batching__((images, in_trans))

    def test_forward(self, device, dtype):
        torch.manual_seed(42)
        images = torch.rand((2, 3, 4, 5), device=device, dtype=dtype)
        input_transform = torch.rand((2, 3, 3), device=device, dtype=dtype)
        expected_output = torch.rand((2, 3, 4, 5), device=device, dtype=dtype)
        expected_transform = torch.rand((2, 3, 3), device=device, dtype=dtype)
        augmentation = AugmentationBase2D(return_transform=False, p=1.0)

        with patch.object(augmentation, "apply_transform", autospec=True) as apply_transform, patch.object(
            augmentation, "generate_parameters", autospec=True
        ) as generate_parameters, patch.object(
            augmentation, "compute_transformation", autospec=True
        ) as compute_transformation:

            # 1) A single tensor: the output is whatever the mocked apply_transform returns.
            params = {'params': {}, 'flags': {'foo': 0}}
            generate_parameters.return_value = params
            apply_transform.return_value = expected_output
            compute_transformation.return_value = expected_transform
            output = augmentation(images)
            assert output is expected_output

            # 2) Requesting the transform per call also returns the mocked transformation.
            output, transformation = augmentation(images, return_transform=True)
            assert output is expected_output
            assert_close(transformation, expected_transform)

            # 3) Explicitly supplied params. generate_parameters now returns None, presumably so
            #    this call cannot succeed by relying on freshly generated parameters.
            params = {'params': {}, 'flags': {'bar': 1}}
            apply_transform.reset_mock()
            generate_parameters.return_value = None
            output = augmentation(images, params=params)
            assert output is expected_output

            # 4) A (tensor, transformation) pair: the returned transformation is expected to be
            #    the mocked transformation composed with the input one.
            expected_final_transformation = expected_transform @ input_transform
            output, transformation = augmentation((images, input_transform), return_transform=True)
            assert output is expected_output
            # assert_close (unlike torch.allclose) reports the mismatch and does not broadcast shapes.
            assert_close(transformation, expected_final_transformation)
            assert transformation.shape[0] == images.shape[0]

    def test_gradcheck(self, device, dtype):
        torch.manual_seed(42)

        # Draw order is kept fixed so the seeded values stay the same.
        images = torch.rand((1, 1, 3, 3), device=device, dtype=dtype)
        output = torch.rand((1, 1, 3, 3), device=device, dtype=dtype)
        input_transform = torch.rand((1, 3, 3), device=device, dtype=dtype)
        other_transform = torch.rand((1, 3, 3), device=device, dtype=dtype)

        images = utils.tensor_to_gradcheck_var(images)
        input_transform = utils.tensor_to_gradcheck_var(input_transform)
        output = utils.tensor_to_gradcheck_var(output)
        other_transform = utils.tensor_to_gradcheck_var(other_transform)

        input_param = {'batch_prob': torch.tensor([True]), 'params': {'x': input_transform}, 'flags': {}}

        augmentation = AugmentationBase2D(return_transform=True, p=1.0)

        with patch.object(augmentation, "apply_transform", autospec=True) as apply_transform, patch.object(
            augmentation, "compute_transformation", autospec=True
        ) as compute_transformation:

            apply_transform.return_value = output
            compute_transformation.return_value = other_transform
            # gradcheck invokes augmentation(images, input_param): params go in positionally.
            assert gradcheck(augmentation, (images, input_param), raise_exception=True)
|
|
|