File size: 8,710 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
from unittest.mock import patch
import pytest
import torch
from torch.autograd import gradcheck
import kornia.testing as utils # test utils
from kornia.augmentation.base import _BasicAugmentationBase, AugmentationBase2D
from kornia.testing import assert_close
class TestBasicAugmentationBase:
    """Tests for the private `_BasicAugmentationBase` (probability sampling, tensor reshaping, forward)."""

    def test_smoke(self, device, dtype):
        """The string representation should expose the probability configuration."""
        aug = _BasicAugmentationBase(p=0.5, p_batch=1.0, same_on_batch=True)
        expected_repr = "p=0.5, p_batch=1.0, same_on_batch=True"
        assert str(aug) == expected_repr

    def test_infer_input(self, device, dtype):
        """A mocked `transform_tensor` should be applied verbatim to the input tensor."""
        in_tensor = torch.rand((2, 3, 4, 5), device=device, dtype=dtype)
        aug = _BasicAugmentationBase(p=1.0, p_batch=1)
        with patch.object(aug, "transform_tensor", autospec=True) as transform_mock:
            # Insert a singleton depth dimension so the round trip is easy to verify.
            transform_mock.side_effect = lambda x: x.unsqueeze(dim=2)
            out = aug.transform_tensor(in_tensor)
            assert out.shape == torch.Size([2, 3, 1, 4, 5])
            assert_close(in_tensor, out[:, :, 0, :, :])

    @pytest.mark.parametrize(
        "p,p_batch,same_on_batch,num,seed",
        [
            (1.0, 1.0, False, 12, 1),
            (1.0, 0.0, False, 0, 1),
            (0.0, 1.0, False, 0, 1),
            (0.0, 0.0, False, 0, 1),
            (0.5, 0.1, False, 7, 3),
            (0.5, 0.1, True, 12, 3),
            (0.3, 1.0, False, 2, 1),
            (0.3, 1.0, True, 0, 1),
        ],
    )
    def test_forward_params(self, p, p_batch, same_on_batch, num, seed, device, dtype):
        """`forward_parameters` must sample exactly `num` active batch elements for a fixed seed."""
        batch_shape = (12,)
        torch.manual_seed(seed)
        aug = _BasicAugmentationBase(p, p_batch, same_on_batch)
        with patch.object(aug, "generate_parameters", autospec=True) as gen_mock:
            # One parameter value per element selected by the probability sampling.
            gen_mock.side_effect = lambda shape: {
                'degrees': torch.arange(0, shape[0], device=device, dtype=dtype)
            }
            params = aug.forward_parameters(batch_shape)
        assert "batch_prob" in params
        assert len(params['degrees']) == params['batch_prob'].sum().item() == num

    @pytest.mark.parametrize('keepdim', (True, False))
    def test_forward(self, device, dtype, keepdim):
        """The full forward pass should route through the (mocked) transform pipeline."""
        torch.manual_seed(42)
        in_tensor = torch.rand((12, 3, 4, 5), device=device, dtype=dtype)
        # With keepdim=False the base class keeps the batch dim added by transform_tensor.
        expected = in_tensor[..., :2, :2] if keepdim else in_tensor.unsqueeze(dim=0)[..., :2, :2]
        aug = _BasicAugmentationBase(p=0.3, p_batch=1.0, keepdim=keepdim)
        with patch.object(aug, "apply_transform", autospec=True) as apply_mock:
            with patch.object(aug, "generate_parameters", autospec=True) as gen_mock:
                with patch.object(aug, "transform_tensor", autospec=True) as trans_mock:
                    with patch.object(aug, "__check_batching__", autospec=True) as batching_mock:
                        gen_mock.side_effect = lambda shape: {
                            'degrees': torch.arange(0, shape[0], device=device, dtype=dtype)
                        }
                        trans_mock.side_effect = lambda x: x.unsqueeze(dim=0)
                        apply_mock.side_effect = lambda input, params: input[..., :2, :2]
                        batching_mock.side_effect = lambda input: None
                        out = aug(in_tensor)
                        assert out.shape == expected.shape
                        assert_close(out, expected)
class TestAugmentationBase2D:
    """Tests for `AugmentationBase2D`: batching validation, forward dispatch and gradients."""

    @pytest.mark.parametrize(
        'input_shape, in_trans_shape',
        [
            ((2, 3, 4, 5), (2, 3, 3)),
            ((3, 4, 5), (3, 3)),
            ((4, 5), (3, 3)),
            pytest.param((1, 2, 3, 4, 5), (2, 3, 3), marks=pytest.mark.xfail),
            pytest.param((2, 3, 4, 5), (1, 3, 3), marks=pytest.mark.xfail),
            pytest.param((2, 3, 4, 5), (3, 3), marks=pytest.mark.xfail),
        ],
    )
    def test_check_batching(self, device, dtype, input_shape, in_trans_shape):
        """`__check_batching__` accepts consistent input/transform shapes; mismatches are xfail cases."""
        input = torch.rand(input_shape, device=device, dtype=dtype)
        in_trans = torch.rand(in_trans_shape, device=device, dtype=dtype)
        augmentation = AugmentationBase2D(p=1.0, p_batch=1)
        augmentation.__check_batching__(input)
        augmentation.__check_batching__((input, in_trans))

    def test_forward(self, device, dtype):
        """Forward should return the mocked output, and compose transforms when requested."""
        torch.manual_seed(42)
        input = torch.rand((2, 3, 4, 5), device=device, dtype=dtype)
        input_transform = torch.rand((2, 3, 3), device=device, dtype=dtype)
        expected_output = torch.rand((2, 3, 4, 5), device=device, dtype=dtype)
        expected_transform = torch.rand((2, 3, 3), device=device, dtype=dtype)
        augmentation = AugmentationBase2D(return_transform=False, p=1.0)
        with patch.object(augmentation, "apply_transform", autospec=True) as apply_transform, patch.object(
            augmentation, "generate_parameters", autospec=True
        ) as generate_parameters, patch.object(
            augmentation, "compute_transformation", autospec=True
        ) as compute_transformation:
            # Calling the augmentation with a single tensor shall return the expected tensor using the generated params.
            params = {'params': {}, 'flags': {'foo': 0}}
            generate_parameters.return_value = params
            apply_transform.return_value = expected_output
            compute_transformation.return_value = expected_transform
            output = augmentation(input)
            # RuntimeError: Boolean value of Tensor with more than one value is ambiguous
            # Not an easy fix, happens on verifying torch.tensor([True, True])
            # _params = {'batch_prob': torch.tensor([True, True]), 'params': {}, 'flags': {'foo': 0}}
            # apply_transform.assert_called_once_with(input, _params)
            assert output is expected_output
            # Calling the augmentation with a tensor and set return_transform shall
            # return the expected tensor and transformation.
            output, transformation = augmentation(input, return_transform=True)
            assert output is expected_output
            assert_close(transformation, expected_transform)
            # Calling the augmentation with a tensor and params shall return the expected tensor using the given params.
            params = {'params': {}, 'flags': {'bar': 1}}
            apply_transform.reset_mock()
            generate_parameters.return_value = None
            output = augmentation(input, params=params)
            # RuntimeError: Boolean value of Tensor with more than one value is ambiguous
            # Not an easy fix, happens on verifying torch.tensor([True, True])
            # _params = {'batch_prob': torch.tensor([True, True]), 'params': {}, 'flags': {'bar': 1}}
            # apply_transform.assert_called_once_with(input, _params)
            assert output is expected_output
            # Calling the augmentation with a tensor,a transformation and set
            # return_transform shall return the expected tensor and the proper
            # transformation matrix.
            expected_final_transformation = expected_transform @ input_transform
            output, transformation = augmentation((input, input_transform), return_transform=True)
            assert output is expected_output
            # Use assert_close for consistency with the rest of the file (was torch.allclose).
            assert_close(transformation, expected_final_transformation)
            assert transformation.shape[0] == input.shape[0]

    def test_gradcheck(self, device, dtype):
        """Gradients must flow through the (mocked) transform pipeline."""
        torch.manual_seed(42)
        input = torch.rand((1, 1, 3, 3), device=device, dtype=dtype)
        output = torch.rand((1, 1, 3, 3), device=device, dtype=dtype)
        input_transform = torch.rand((1, 3, 3), device=device, dtype=dtype)
        other_transform = torch.rand((1, 3, 3), device=device, dtype=dtype)
        input = utils.tensor_to_gradcheck_var(input)  # to var
        input_transform = utils.tensor_to_gradcheck_var(input_transform)  # to var
        output = utils.tensor_to_gradcheck_var(output)  # to var
        other_transform = utils.tensor_to_gradcheck_var(other_transform)  # to var
        input_param = {'batch_prob': torch.tensor([True]), 'params': {'x': input_transform}, 'flags': {}}
        augmentation = AugmentationBase2D(return_transform=True, p=1.0)
        with patch.object(augmentation, "apply_transform", autospec=True) as apply_transform, patch.object(
            augmentation, "compute_transformation", autospec=True
        ) as compute_transformation:
            apply_transform.return_value = output
            compute_transformation.return_value = other_transform
            # gradcheck takes the inputs as a single tuple (was double-parenthesized).
            assert gradcheck(augmentation, (input, input_param), raise_exception=True)
|