import pytest import torch from torch.autograd import gradcheck import kornia from kornia.testing import assert_close from kornia.utils._compat import torch_version_geq def _sample_points(batch_size, device, dtype=torch.float32): src = torch.tensor([[[0.0, 0.0], [0.0, 10.0], [10.0, 0.0], [10.0, 10.0], [5.0, 5.0]]], device=device, dtype=dtype) src = src.repeat(batch_size, 1, 1) dst = src + torch.rand_like(src) * 2.5 return src, dst class TestTransformParameters: @pytest.mark.parametrize('batch_size', [1, 3]) def test_smoke(self, batch_size, device, dtype): src = torch.rand(batch_size, 4, 2, device=device) out = kornia.geometry.transform.get_tps_transform(src, src) assert len(out) == 2 @pytest.mark.parametrize('batch_size', [1, 3]) def test_no_warp(self, batch_size, device, dtype): src = torch.rand(batch_size, 5, 2, device=device) kernel, affine = kornia.geometry.transform.get_tps_transform(src, src) target_kernel = torch.zeros(batch_size, 5, 2, device=device) target_affine = torch.zeros(batch_size, 3, 2, device=device) target_affine[:, [1, 2], [0, 1]] = 1.0 assert_close(kernel, target_kernel, atol=1e-4, rtol=1e-4) assert_close(affine, target_affine, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize('batch_size', [1, 3]) def test_affine_only(self, batch_size, device, dtype): src = torch.tensor([[[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.5, 0.5]]], device=device).repeat( batch_size, 1, 1 ) dst = src.clone() * 2.0 kernel, _ = kornia.geometry.transform.get_tps_transform(src, dst) assert_close(kernel, torch.zeros_like(kernel), atol=1e-4, rtol=1e-4) @pytest.mark.parametrize('batch_size', [1, 3]) def test_exception(self, batch_size, device, dtype): with pytest.raises(TypeError): src = torch.rand(batch_size, 5, 2).numpy() assert kornia.geometry.transform.get_tps_transform(src, src) with pytest.raises(ValueError): src = torch.rand(batch_size, 5) assert kornia.geometry.transform.get_tps_transform(src, src) @pytest.mark.grad @pytest.mark.parametrize('batch_size', [1, 3]) @pytest.mark.parametrize('requires_grad', [True, False]) def test_gradcheck(self, batch_size, device, dtype, requires_grad): opts = dict(device=device, dtype=torch.float64) src, dst = _sample_points(batch_size, **opts) src.requires_grad_(requires_grad) dst.requires_grad_(not requires_grad) assert gradcheck(kornia.geometry.transform.get_tps_transform, (src, dst), raise_exception=True) @pytest.mark.jit @pytest.mark.parametrize('batch_size', [1, 3]) def test_jit(self, batch_size, device, dtype): src, dst = _sample_points(batch_size, device) op = kornia.geometry.transform.get_tps_transform op_jit = torch.jit.script(op) op_output = op(src, dst) jit_output = op_jit(src, dst) assert_close(op_output[0], jit_output[0]) assert_close(op_output[1], jit_output[1]) class TestWarpPoints: @pytest.mark.parametrize('batch_size', [1, 3]) def test_smoke(self, batch_size, device, dtype): src, dst = _sample_points(batch_size, device) kernel, affine = kornia.geometry.transform.get_tps_transform(src, dst) warp = kornia.geometry.transform.warp_points_tps(src, dst, kernel, affine) assert warp.shape == src.shape @pytest.mark.parametrize('batch_size', [1, 3]) def test_warp(self, batch_size, device, dtype): src, dst = _sample_points(batch_size, device) kernel, affine = kornia.geometry.transform.get_tps_transform(src, dst) warp = kornia.geometry.transform.warp_points_tps(src, dst, kernel, affine) assert_close(warp, dst, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize('batch_size', [1, 3]) def test_exception(self, batch_size, device, dtype): src = torch.rand(batch_size, 5, 2) kernel = torch.zeros_like(src) affine = torch.zeros(batch_size, 3, 2) with pytest.raises(TypeError): assert kornia.geometry.transform.warp_points_tps(src.numpy(), src, kernel, affine) with pytest.raises(TypeError): assert kornia.geometry.transform.warp_points_tps(src, src.numpy(), kernel, affine) with pytest.raises(TypeError): assert kornia.geometry.transform.warp_points_tps(src, src, kernel.numpy(), affine) with pytest.raises(TypeError): assert kornia.geometry.transform.warp_points_tps(src, src, kernel, affine.numpy()) with pytest.raises(ValueError): src_bad = torch.rand(batch_size, 5) assert kornia.geometry.transform.warp_points_tps(src_bad, src, kernel, affine) with pytest.raises(ValueError): src_bad = torch.rand(batch_size, 5) assert kornia.geometry.transform.warp_points_tps(src, src_bad, kernel, affine) with pytest.raises(ValueError): kernel_bad = torch.rand(batch_size, 5) assert kornia.geometry.transform.warp_points_tps(src, src, kernel_bad, affine) with pytest.raises(ValueError): affine_bad = torch.rand(batch_size, 3) assert kornia.geometry.transform.warp_points_tps(src, src, kernel, affine_bad) @pytest.mark.grad @pytest.mark.parametrize('batch_size', [1, 3]) @pytest.mark.parametrize('requires_grad', [True, False]) def test_gradcheck(self, batch_size, device, dtype, requires_grad): opts = dict(device=device, dtype=torch.float64) src, dst = _sample_points(batch_size, **opts) kernel, affine = kornia.geometry.transform.get_tps_transform(src, dst) kernel.requires_grad_(requires_grad) affine.requires_grad_(not requires_grad) assert gradcheck(kornia.geometry.transform.warp_points_tps, (src, dst, kernel, affine), raise_exception=True) @pytest.mark.jit @pytest.mark.parametrize('batch_size', [1, 3]) def test_jit(self, batch_size, device, dtype): src, dst = _sample_points(batch_size, device) kernel, affine = kornia.geometry.transform.get_tps_transform(src, dst) op = kornia.geometry.transform.warp_points_tps op_jit = torch.jit.script(op) assert_close(op(src, dst, kernel, affine), op_jit(src, dst, kernel, affine)) class TestWarpImage: @pytest.mark.parametrize('batch_size', [1, 3]) def test_smoke(self, batch_size, device, dtype): src, dst = _sample_points(batch_size, device) tensor = torch.rand(batch_size, 3, 32, 32, device=device) kernel, affine = kornia.geometry.transform.get_tps_transform(src, dst) warp = kornia.geometry.transform.warp_image_tps(tensor, dst, kernel, affine) assert warp.shape == tensor.shape @pytest.mark.skipif( torch_version_geq(1, 10), reason="for some reason the solver detects singular matrices in pytorch >=1.10.") @pytest.mark.parametrize('batch_size', [1, 3]) def test_warp(self, batch_size, device, dtype): src = torch.tensor([[[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0], [0.0, 0.0]]], device=device).repeat( batch_size, 1, 1 ) # zoom in by a factor of 2 dst = src.clone() * 2.0 tensor = torch.zeros(batch_size, 3, 8, 8, device=device) tensor[:, :, 2:6, 2:6] = 1.0 expected = torch.ones_like(tensor) # nn.grid_sample interpolates the at the edges it seems, so the boundaries have values < 1 expected[:, :, [0, -1], :] *= 0.5 expected[:, :, :, [0, -1]] *= 0.5 kernel, affine = kornia.geometry.transform.get_tps_transform(dst, src) warp = kornia.geometry.transform.warp_image_tps(tensor, src, kernel, affine) assert_close(warp, expected, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize('batch_size', [1, 3]) def test_exception(self, batch_size, device, dtype): image = torch.rand(batch_size, 3, 32, 32) dst = torch.rand(batch_size, 5, 2) kernel = torch.zeros_like(dst) affine = torch.zeros(batch_size, 3, 2) with pytest.raises(TypeError): assert kornia.geometry.transform.warp_image_tps(image.numpy(), dst, kernel, affine) with pytest.raises(TypeError): assert kornia.geometry.transform.warp_image_tps(image, dst.numpy(), kernel, affine) with pytest.raises(TypeError): assert kornia.geometry.transform.warp_image_tps(image, dst, kernel.numpy(), affine) with pytest.raises(TypeError): assert kornia.geometry.transform.warp_image_tps(image, dst, kernel, affine.numpy()) with pytest.raises(ValueError): image_bad = torch.rand(batch_size, 32, 32) assert kornia.geometry.transform.warp_image_tps(image_bad, dst, kernel, affine) with pytest.raises(ValueError): dst_bad = torch.rand(batch_size, 5) assert kornia.geometry.transform.warp_image_tps(image, dst_bad, kernel, affine) with pytest.raises(ValueError): kernel_bad = torch.rand(batch_size, 5) assert kornia.geometry.transform.warp_image_tps(image, dst, kernel_bad, affine) with pytest.raises(ValueError): affine_bad = torch.rand(batch_size, 3) assert kornia.geometry.transform.warp_image_tps(image, dst, kernel, affine_bad) @pytest.mark.grad @pytest.mark.parametrize('batch_size', [1, 3]) def test_gradcheck(self, batch_size, device, dtype): opts = dict(device=device, dtype=torch.float64) src, dst = _sample_points(batch_size, **opts) kernel, affine = kornia.geometry.transform.get_tps_transform(src, dst) image = torch.rand(batch_size, 3, 8, 8, requires_grad=True, **opts) assert gradcheck( kornia.geometry.transform.warp_image_tps, (image, dst, kernel, affine), raise_exception=True, atol=1e-4, rtol=1e-4, ) @pytest.mark.jit @pytest.mark.parametrize('batch_size', [1, 3]) def test_jit(self, batch_size, device, dtype): src, dst = _sample_points(batch_size, device) kernel, affine = kornia.geometry.transform.get_tps_transform(src, dst) image = torch.rand(batch_size, 3, 32, 32, device=device) op = kornia.geometry.transform.warp_image_tps op_jit = torch.jit.script(op) assert_close(op(image, dst, kernel, affine), op_jit(image, dst, kernel, affine), rtol=1e-4, atol=1e-4)