import pytest
import torch

from kornia.testing import assert_close
from kornia.utils import _extract_device_dtype
from kornia.utils.helpers import (
    _torch_histc_cast,
    _torch_inverse_cast,
    _torch_solve_cast,
    _torch_svd_cast,
    safe_inverse_with_mask,
    safe_solve_with_mask,
)


@pytest.mark.parametrize(
| | "tensor_list,out_device,out_dtype,will_throw_error", |
| | [ |
| | ([], torch.device('cpu'), torch.get_default_dtype(), False), |
| | ([None, None], torch.device('cpu'), torch.get_default_dtype(), False), |
| | ([torch.tensor(0, device='cpu', dtype=torch.float16), None], torch.device('cpu'), torch.float16, False), |
| | ([torch.tensor(0, device='cpu', dtype=torch.float32), None], torch.device('cpu'), torch.float32, False), |
| | ([torch.tensor(0, device='cpu', dtype=torch.float64), None], torch.device('cpu'), torch.float64, False), |
| | ([torch.tensor(0, device='cpu', dtype=torch.float16)] * 2, torch.device('cpu'), torch.float16, False), |
| | ([torch.tensor(0, device='cpu', dtype=torch.float32)] * 2, torch.device('cpu'), torch.float32, False), |
| | ([torch.tensor(0, device='cpu', dtype=torch.float64)] * 2, torch.device('cpu'), torch.float64, False), |
| | ( |
| | [torch.tensor(0, device='cpu', dtype=torch.float16), torch.tensor(0, device='cpu', dtype=torch.float64)], |
| | None, |
| | None, |
| | True, |
| | ), |
| | ( |
| | [torch.tensor(0, device='cpu', dtype=torch.float32), torch.tensor(0, device='cpu', dtype=torch.float64)], |
| | None, |
| | None, |
| | True, |
| | ), |
| | ( |
| | [torch.tensor(0, device='cpu', dtype=torch.float16), torch.tensor(0, device='cpu', dtype=torch.float32)], |
| | None, |
| | None, |
| | True, |
| | ), |
| | ], |
| | ) |
def test_extract_device_dtype(tensor_list, out_device, out_dtype, will_throw_error):
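    # Lists mixing several floating point dtypes are expected to raise; for
    # consistent (or empty / all-None) inputs the shared device and dtype
    # should be returned, falling back to CPU and the default dtype.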
    if will_throw_error:
        with pytest.raises(ValueError):
            _extract_device_dtype(tensor_list)
    else:
        device, dtype = _extract_device_dtype(tensor_list)
        assert device == out_device
        assert dtype == out_dtype


class TestInverseCast:
| | @pytest.mark.parametrize("input_shape", [(1, 3, 4, 4), (2, 4, 5, 5)]) |
| | def test_smoke(self, device, dtype, input_shape): |
| | x = torch.rand(input_shape, device=device, dtype=dtype) |
| | y = _torch_inverse_cast(x) |
| | assert y.shape == x.shape |
| |
|
| | def test_values(self, device, dtype): |
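        # [[4, 7], [2, 6]] has determinant 10, so by the 2x2 closed form
        # adj(x) / det(x) its exact inverse is [[0.6, -0.7], [-0.2, 0.4]].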
        x = torch.tensor([[4.0, 7.0], [2.0, 6.0]], device=device, dtype=dtype)

        y_expected = torch.tensor([[0.6, -0.7], [-0.2, 0.4]], device=device, dtype=dtype)

        y = _torch_inverse_cast(x)

        assert_close(y, y_expected)
    def test_jit(self, device, dtype):
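        # Scripting the op must not change its numerical output.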
        x = torch.rand(1, 3, 4, 4, device=device, dtype=dtype)
        op = _torch_inverse_cast
        op_jit = torch.jit.script(op)
        assert_close(op(x), op_jit(x))


class TestHistcCast:
    def test_smoke(self, device, dtype):
        x = torch.tensor([1.0, 2.0, 1.0], device=device, dtype=dtype)
        y_expected = torch.tensor([0.0, 2.0, 1.0, 0.0], device=device, dtype=dtype)
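        # 4 bins over [0, 3] have width 0.75: the two 1.0 values land in the
        # second bin and 2.0 in the third, so the counts are [0, 2, 1, 0].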
        y = _torch_histc_cast(x, bins=4, min=0, max=3)

        assert_close(y, y_expected)


class TestSvdCast:
    def test_smoke(self, device, dtype):
        a = torch.randn(5, 3, 3, device=device, dtype=dtype)
        u, s, v = _torch_svd_cast(a)
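        # _torch_svd_cast returns V (not V^T), so the input is reconstructed
        # as U @ diag(S) @ V^T; half precision needs a much looser tolerance.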
        tol_val: float = 1e-1 if dtype == torch.float16 else 1e-3
        assert_close(a, u @ torch.diag_embed(s) @ v.transpose(-2, -1), atol=tol_val, rtol=tol_val)


class TestSolveCast:
    def test_smoke(self, device, dtype):
        A = torch.randn(2, 3, 1, 4, 4, device=device, dtype=dtype)
        B = torch.randn(2, 3, 1, 4, 6, device=device, dtype=dtype)

        X, _ = _torch_solve_cast(B, A)
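        # X solves the batched system A @ X = B, so the residual ||B - A @ X||
        # should vanish up to the dtype-dependent tolerance.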
        error = torch.dist(B, A.matmul(X))

        tol_val: float = 1e-1 if dtype == torch.float16 else 1e-4
        assert_close(error, torch.zeros_like(error), atol=tol_val, rtol=tol_val)


class TestSolveWithMask:
    def test_smoke(self, device, dtype):
        A = torch.randn(2, 3, 1, 4, 4, device=device, dtype=dtype)
        B = torch.randn(2, 3, 1, 4, 6, device=device, dtype=dtype)

        X, _, mask = safe_solve_with_mask(B, A)
        X2, _ = _torch_solve_cast(B, A)
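        # Wherever the mask marks the solve as valid, the masked solver should
        # agree with the plain cast-based solver.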
        tol_val: float = 1e-1 if dtype == torch.float16 else 1e-4
        if mask.sum() > 0:
            assert_close(X[mask], X2[mask], atol=tol_val, rtol=tol_val)

    @pytest.mark.skipif(
        (int(torch.__version__.split('.')[0]) == 1) and (int(torch.__version__.split('.')[1]) < 10),
        reason='not supported in torch < 1.10.0',
    )
    def test_all_bad(self, device, dtype):
        A = torch.ones(10, 3, 3, device=device, dtype=dtype)
        B = torch.ones(3, 10, device=device, dtype=dtype)
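        # All-ones matrices are singular, so every entry of the validity mask
        # should be False.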
        X, _, mask = safe_solve_with_mask(B, A)
        assert torch.equal(mask, torch.zeros_like(mask))


class TestInverseWithMask:
    def test_smoke(self, device, dtype):
        x = torch.tensor([[4.0, 7.0], [2.0, 6.0]], device=device, dtype=dtype)

        y_expected = torch.tensor([[0.6, -0.7], [-0.2, 0.4]], device=device, dtype=dtype)

        y, mask = safe_inverse_with_mask(x)
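        # The input is well conditioned (det = 10), so the inverse should match
        # the closed-form values and the whole mask should be True.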
        assert_close(y, y_expected)
        assert torch.equal(mask, torch.ones_like(mask))

    @pytest.mark.skipif(
        (int(torch.__version__.split('.')[0]) == 1) and (int(torch.__version__.split('.')[1]) < 9),
        reason='not supported in torch < 1.9.0',
    )
    def test_all_bad(self, device, dtype):
        A = torch.ones(10, 3, 3, device=device, dtype=dtype)
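        # Singular (all-ones) inputs cannot be inverted, so the mask should be
        # all False.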
        X, mask = safe_inverse_with_mask(A)
        assert torch.equal(mask, torch.zeros_like(mask))