# compvis/test/augmentation/test_base.py
# Dexter's picture
# Upload folder using huggingface_hub
# 36c95ba verified
from unittest.mock import patch
import pytest
import torch
from torch.autograd import gradcheck
import kornia.testing as utils # test utils
from kornia.augmentation.base import _BasicAugmentationBase, AugmentationBase2D
from kornia.testing import assert_close
class TestBasicAugmentationBase:
    """Tests for ``_BasicAugmentationBase``.

    Every test receives ``device`` and ``dtype`` arguments; these are
    presumably pytest fixtures defined outside this file (TODO confirm).
    """

    def test_smoke(self, device, dtype):
        """``str()`` of the module reports its probability settings."""
        module = _BasicAugmentationBase(p=0.5, p_batch=1.0, same_on_batch=True)
        assert str(module) == "p=0.5, p_batch=1.0, same_on_batch=True"

    def test_infer_input(self, device, dtype):
        """A patched ``transform_tensor`` that inserts an axis is reflected in the output."""
        images = torch.rand((2, 3, 4, 5), device=device, dtype=dtype)
        aug = _BasicAugmentationBase(p=1.0, p_batch=1)

        def insert_axis(x):
            # Add a singleton dimension right after the channel axis.
            return x.unsqueeze(dim=2)

        with patch.object(aug, "transform_tensor", autospec=True) as mocked_transform:
            mocked_transform.side_effect = insert_axis
            result = aug.transform_tensor(images)
            assert result.shape == torch.Size([2, 3, 1, 4, 5])
            # Dropping the inserted axis must give back the original data.
            assert_close(images, result[:, :, 0, :, :])

    @pytest.mark.parametrize(
        ("p", "p_batch", "same_on_batch", "num", "seed"),
        [
            (1.0, 1.0, False, 12, 1),
            (1.0, 0.0, False, 0, 1),
            (0.0, 1.0, False, 0, 1),
            (0.0, 0.0, False, 0, 1),
            (0.5, 0.1, False, 7, 3),
            (0.5, 0.1, True, 12, 3),
            (0.3, 1.0, False, 2, 1),
            (0.3, 1.0, True, 0, 1),
        ],
    )
    def test_forward_params(self, p, p_batch, same_on_batch, num, seed, device, dtype):
        """``forward_parameters`` selects ``num`` of the 12 batch elements for the given seed.

        ``num`` is the expected count of selected elements; it must match both
        the sum of ``batch_prob`` and the length of the generated ``degrees``.
        """
        batch_shape = (12,)
        # Seed before constructing the module so the sampling is reproducible.
        torch.manual_seed(seed)
        aug = _BasicAugmentationBase(p, p_batch, same_on_batch)

        def fake_generate(shape):
            # One "degree" per element of the shape handed to the generator.
            return {'degrees': torch.arange(0, shape[0], device=device, dtype=dtype)}

        with patch.object(aug, "generate_parameters", autospec=True) as mocked_generate:
            mocked_generate.side_effect = fake_generate
            params = aug.forward_parameters(batch_shape)
            assert "batch_prob" in params
            selected = params['batch_prob'].sum().item()
            assert len(params['degrees']) == selected == num

    @pytest.mark.parametrize('keepdim', (True, False))
    def test_forward(self, device, dtype, keepdim):
        """Calling the module yields the cropped input, with or without the extra leading axis.

        ``transform_tensor`` is patched to prepend an axis and ``apply_transform``
        to crop the last two dimensions to 2x2. With ``keepdim=True`` the output
        is expected to have the input's rank; otherwise the prepended axis stays.
        """
        torch.manual_seed(42)
        images = torch.rand((12, 3, 4, 5), device=device, dtype=dtype)
        uncropped = images if keepdim else images.unsqueeze(dim=0)
        expected = uncropped[..., :2, :2]
        aug = _BasicAugmentationBase(p=0.3, p_batch=1.0, keepdim=keepdim)

        # The helper parameter names below mirror the original stand-ins so
        # that keyword-style calls from the module keep working.
        def fake_generate(shape):
            return {'degrees': torch.arange(0, shape[0], device=device, dtype=dtype)}

        def prepend_axis(x):
            return x.unsqueeze(dim=0)

        def crop_corner(input, params):
            return input[..., :2, :2]

        def skip_check(input):
            return None

        with patch.object(aug, "apply_transform", autospec=True) as mocked_apply, \
                patch.object(aug, "generate_parameters", autospec=True) as mocked_generate, \
                patch.object(aug, "transform_tensor", autospec=True) as mocked_transform, \
                patch.object(aug, "__check_batching__", autospec=True) as mocked_check:
            mocked_generate.side_effect = fake_generate
            mocked_transform.side_effect = prepend_axis
            mocked_apply.side_effect = crop_corner
            mocked_check.side_effect = skip_check

            result = aug(images)
            assert result.shape == expected.shape
            assert_close(result, expected)
class TestAugmentationBase2D:
    """Tests for ``AugmentationBase2D``.

    Every test receives ``device`` and ``dtype`` arguments; these are
    presumably pytest fixtures defined outside this file (TODO confirm).
    """

    @pytest.mark.parametrize(
        ("input_shape", "in_trans_shape"),
        [
            ((2, 3, 4, 5), (2, 3, 3)),
            ((3, 4, 5), (3, 3)),
            ((4, 5), (3, 3)),
            # The remaining combinations are expected to fail the check.
            pytest.param((1, 2, 3, 4, 5), (2, 3, 3), marks=pytest.mark.xfail),
            pytest.param((2, 3, 4, 5), (1, 3, 3), marks=pytest.mark.xfail),
            pytest.param((2, 3, 4, 5), (3, 3), marks=pytest.mark.xfail),
        ],
    )
    def test_check_batching(self, device, dtype, input_shape, in_trans_shape):
        """``__check_batching__`` accepts a tensor and a ``(tensor, transform)`` pair."""
        images = torch.rand(input_shape, device=device, dtype=dtype)
        matrices = torch.rand(in_trans_shape, device=device, dtype=dtype)
        aug = AugmentationBase2D(p=1.0, p_batch=1)
        # First the bare tensor, then the tensor paired with its transform.
        for candidate in (images, (images, matrices)):
            aug.__check_batching__(candidate)

    def test_forward(self, device, dtype):
        """Exercise the call paths of the module with its hooks patched out.

        ``apply_transform``, ``generate_parameters`` and
        ``compute_transformation`` are replaced by mocks returning fixed
        values, so only the plumbing of the forward call is checked.
        """
        torch.manual_seed(42)

        def rand(*shape):
            return torch.rand(shape, device=device, dtype=dtype)

        images = rand(2, 3, 4, 5)
        given_matrix = rand(2, 3, 3)
        expected_images = rand(2, 3, 4, 5)
        expected_matrix = rand(2, 3, 3)
        aug = AugmentationBase2D(return_transform=False, p=1.0)

        with patch.object(aug, "apply_transform", autospec=True) as mocked_apply, \
                patch.object(aug, "generate_parameters", autospec=True) as mocked_generate, \
                patch.object(aug, "compute_transformation", autospec=True) as mocked_compute:
            # Case 1: a single tensor in, the patched tensor out, using
            # generated parameters.
            mocked_generate.return_value = {'params': {}, 'flags': {'foo': 0}}
            mocked_apply.return_value = expected_images
            mocked_compute.return_value = expected_matrix
            result = aug(images)
            # NOTE: ``mocked_apply.assert_called_once_with(images, {...})`` is
            # deliberately not asserted here. The expected params would hold
            # ``'batch_prob': torch.tensor([True, True])`` and comparing it
            # raises "RuntimeError: Boolean value of Tensor with more than one
            # value is ambiguous". Not an easy fix.
            assert result is expected_images

            # Case 2: with ``return_transform=True`` the call returns both the
            # tensor and the transformation.
            result, matrix = aug(images, return_transform=True)
            assert result is expected_images
            assert_close(matrix, expected_matrix)

            # Case 3: explicitly supplied params are used instead of
            # generated ones.
            supplied = {'params': {}, 'flags': {'bar': 1}}
            mocked_apply.reset_mock()
            mocked_generate.return_value = None
            result = aug(images, params=supplied)
            # NOTE: the call arguments are not asserted here either, for the
            # same tensor-comparison reason as in case 1.
            assert result is expected_images

            # Case 4: a (tensor, transformation) pair with
            # ``return_transform=True`` returns the tensor and the patched
            # transformation composed with the incoming one.
            composed = expected_matrix @ given_matrix
            result, matrix = aug((images, given_matrix), return_transform=True)
            assert result is expected_images
            assert torch.allclose(composed, matrix)
            assert matrix.shape[0] == images.shape[0]

    def test_gradcheck(self, device, dtype):
        """Run ``gradcheck`` on the module with its hooks patched to fixed tensors."""
        torch.manual_seed(42)

        def rand(*shape):
            return torch.rand(shape, device=device, dtype=dtype)

        # Draw all random tensors first so the sampling order stays fixed.
        images = rand(1, 1, 3, 3)
        patched_output = rand(1, 1, 3, 3)
        given_matrix = rand(1, 3, 3)
        patched_matrix = rand(1, 3, 3)

        # Convert each tensor with the kornia helper before gradcheck.
        images = utils.tensor_to_gradcheck_var(images)
        given_matrix = utils.tensor_to_gradcheck_var(given_matrix)
        patched_output = utils.tensor_to_gradcheck_var(patched_output)
        patched_matrix = utils.tensor_to_gradcheck_var(patched_matrix)

        given_params = {'batch_prob': torch.tensor([True]), 'params': {'x': given_matrix}, 'flags': {}}
        aug = AugmentationBase2D(return_transform=True, p=1.0)

        with patch.object(aug, "apply_transform", autospec=True) as mocked_apply, \
                patch.object(aug, "compute_transformation", autospec=True) as mocked_compute:
            mocked_apply.return_value = patched_output
            mocked_compute.return_value = patched_matrix
            assert gradcheck(aug, (images, given_params), raise_exception=True)