| | from unittest.mock import patch |
| |
|
| | import pytest |
| | import torch |
| | from torch.autograd import gradcheck |
| |
|
| | import kornia.testing as utils |
| | from kornia.augmentation.base import _BasicAugmentationBase, AugmentationBase2D |
| | from kornia.testing import assert_close |
| |
|
| |
|
| | class TestBasicAugmentationBase: |
| | def test_smoke(self, device, dtype): |
| | base = _BasicAugmentationBase(p=0.5, p_batch=1.0, same_on_batch=True) |
| | __repr__ = "p=0.5, p_batch=1.0, same_on_batch=True" |
| | assert str(base) == __repr__ |
| |
|
| | def test_infer_input(self, device, dtype): |
| | input = torch.rand((2, 3, 4, 5), device=device, dtype=dtype) |
| | augmentation = _BasicAugmentationBase(p=1.0, p_batch=1) |
| | with patch.object(augmentation, "transform_tensor", autospec=True) as transform_tensor: |
| | transform_tensor.side_effect = lambda x: x.unsqueeze(dim=2) |
| | output = augmentation.transform_tensor(input) |
| | assert output.shape == torch.Size([2, 3, 1, 4, 5]) |
| | assert_close(input, output[:, :, 0, :, :]) |
| |
|
| | @pytest.mark.parametrize( |
| | "p,p_batch,same_on_batch,num,seed", |
| | [ |
| | (1.0, 1.0, False, 12, 1), |
| | (1.0, 0.0, False, 0, 1), |
| | (0.0, 1.0, False, 0, 1), |
| | (0.0, 0.0, False, 0, 1), |
| | (0.5, 0.1, False, 7, 3), |
| | (0.5, 0.1, True, 12, 3), |
| | (0.3, 1.0, False, 2, 1), |
| | (0.3, 1.0, True, 0, 1), |
| | ], |
| | ) |
| | def test_forward_params(self, p, p_batch, same_on_batch, num, seed, device, dtype): |
| | input_shape = (12,) |
| | torch.manual_seed(seed) |
| | augmentation = _BasicAugmentationBase(p, p_batch, same_on_batch) |
| | with patch.object(augmentation, "generate_parameters", autospec=True) as generate_parameters: |
| | generate_parameters.side_effect = lambda shape: { |
| | 'degrees': torch.arange(0, shape[0], device=device, dtype=dtype) |
| | } |
| | output = augmentation.forward_parameters(input_shape) |
| | assert "batch_prob" in output |
| | assert len(output['degrees']) == output['batch_prob'].sum().item() == num |
| |
|
| | @pytest.mark.parametrize('keepdim', (True, False)) |
| | def test_forward(self, device, dtype, keepdim): |
| | torch.manual_seed(42) |
| | input = torch.rand((12, 3, 4, 5), device=device, dtype=dtype) |
| | expected_output = input[..., :2, :2] if keepdim else input.unsqueeze(dim=0)[..., :2, :2] |
| | augmentation = _BasicAugmentationBase(p=0.3, p_batch=1.0, keepdim=keepdim) |
| | with patch.object(augmentation, "apply_transform", autospec=True) as apply_transform, patch.object( |
| | augmentation, "generate_parameters", autospec=True |
| | ) as generate_parameters, patch.object( |
| | augmentation, "transform_tensor", autospec=True |
| | ) as transform_tensor, patch.object( |
| | augmentation, "__check_batching__", autospec=True |
| | ) as check_batching: |
| |
|
| | generate_parameters.side_effect = lambda shape: { |
| | 'degrees': torch.arange(0, shape[0], device=device, dtype=dtype) |
| | } |
| | transform_tensor.side_effect = lambda x: x.unsqueeze(dim=0) |
| | apply_transform.side_effect = lambda input, params: input[..., :2, :2] |
| | check_batching.side_effect = lambda input: None |
| | output = augmentation(input) |
| | assert output.shape == expected_output.shape |
| | assert_close(output, expected_output) |
| |
|
| |
|
| | class TestAugmentationBase2D: |
| | @pytest.mark.parametrize( |
| | 'input_shape, in_trans_shape', |
| | [ |
| | ((2, 3, 4, 5), (2, 3, 3)), |
| | ((3, 4, 5), (3, 3)), |
| | ((4, 5), (3, 3)), |
| | pytest.param((1, 2, 3, 4, 5), (2, 3, 3), marks=pytest.mark.xfail), |
| | pytest.param((2, 3, 4, 5), (1, 3, 3), marks=pytest.mark.xfail), |
| | pytest.param((2, 3, 4, 5), (3, 3), marks=pytest.mark.xfail), |
| | ], |
| | ) |
| | def test_check_batching(self, device, dtype, input_shape, in_trans_shape): |
| | input = torch.rand(input_shape, device=device, dtype=dtype) |
| | in_trans = torch.rand(in_trans_shape, device=device, dtype=dtype) |
| | augmentation = AugmentationBase2D(p=1.0, p_batch=1) |
| | augmentation.__check_batching__(input) |
| | augmentation.__check_batching__((input, in_trans)) |
| |
|
| | def test_forward(self, device, dtype): |
| | torch.manual_seed(42) |
| | input = torch.rand((2, 3, 4, 5), device=device, dtype=dtype) |
| | input_transform = torch.rand((2, 3, 3), device=device, dtype=dtype) |
| | expected_output = torch.rand((2, 3, 4, 5), device=device, dtype=dtype) |
| | expected_transform = torch.rand((2, 3, 3), device=device, dtype=dtype) |
| | augmentation = AugmentationBase2D(return_transform=False, p=1.0) |
| |
|
| | with patch.object(augmentation, "apply_transform", autospec=True) as apply_transform, patch.object( |
| | augmentation, "generate_parameters", autospec=True |
| | ) as generate_parameters, patch.object( |
| | augmentation, "compute_transformation", autospec=True |
| | ) as compute_transformation: |
| |
|
| | |
| | params = {'params': {}, 'flags': {'foo': 0}} |
| | generate_parameters.return_value = params |
| | apply_transform.return_value = expected_output |
| | compute_transformation.return_value = expected_transform |
| | output = augmentation(input) |
| | |
| | |
| | |
| | |
| | assert output is expected_output |
| |
|
| | |
| | |
| | output, transformation = augmentation(input, return_transform=True) |
| | assert output is expected_output |
| | assert_close(transformation, expected_transform) |
| |
|
| | |
| | params = {'params': {}, 'flags': {'bar': 1}} |
| | apply_transform.reset_mock() |
| | generate_parameters.return_value = None |
| | output = augmentation(input, params=params) |
| | |
| | |
| | |
| | |
| | assert output is expected_output |
| |
|
| | |
| | |
| | |
| | expected_final_transformation = expected_transform @ input_transform |
| | output, transformation = augmentation((input, input_transform), return_transform=True) |
| | assert output is expected_output |
| | assert torch.allclose(expected_final_transformation, transformation) |
| | assert transformation.shape[0] == input.shape[0] |
| |
|
| | def test_gradcheck(self, device, dtype): |
| | torch.manual_seed(42) |
| |
|
| | input = torch.rand((1, 1, 3, 3), device=device, dtype=dtype) |
| | output = torch.rand((1, 1, 3, 3), device=device, dtype=dtype) |
| | input_transform = torch.rand((1, 3, 3), device=device, dtype=dtype) |
| | other_transform = torch.rand((1, 3, 3), device=device, dtype=dtype) |
| |
|
| | input = utils.tensor_to_gradcheck_var(input) |
| | input_transform = utils.tensor_to_gradcheck_var(input_transform) |
| | output = utils.tensor_to_gradcheck_var(output) |
| | other_transform = utils.tensor_to_gradcheck_var(other_transform) |
| |
|
| | input_param = {'batch_prob': torch.tensor([True]), 'params': {'x': input_transform}, 'flags': {}} |
| |
|
| | augmentation = AugmentationBase2D(return_transform=True, p=1.0) |
| |
|
| | with patch.object(augmentation, "apply_transform", autospec=True) as apply_transform, patch.object( |
| | augmentation, "compute_transformation", autospec=True |
| | ) as compute_transformation: |
| |
|
| | apply_transform.return_value = output |
| | compute_transformation.return_value = other_transform |
| | assert gradcheck(augmentation, ((input, input_param)), raise_exception=True) |
| |
|