|
|
import pytest |
|
|
import torch |
|
|
from torch.autograd import gradcheck |
|
|
|
|
|
import kornia.testing as utils |
|
|
from kornia.feature.matching import DescriptorMatcher, match_mnn, match_nn, match_smnn, match_snn |
|
|
from kornia.testing import assert_close |
|
|
|
|
|
|
|
|
class TestMatchNN:
    """Tests for plain nearest-neighbour matching: ``match_nn`` and ``DescriptorMatcher('nn')``."""

    @pytest.mark.parametrize("num_desc1, num_desc2, dim", [(1, 4, 4), (2, 5, 128), (6, 2, 32)])
    def test_shape(self, num_desc1, num_desc2, dim, device):
        desc1 = torch.rand(num_desc1, dim, device=device)
        desc2 = torch.rand(num_desc2, dim, device=device)

        dists, idxs = match_nn(desc1, desc2)
        # Plain NN matching returns one match per query descriptor (no filtering).
        assert idxs.shape == (num_desc1, 2)
        assert dists.shape == (num_desc1, 1)

    def test_matching(self, device):
        desc1 = torch.tensor([[0, 0.0], [1, 1], [2, 2], [3, 3.0], [5, 5.0]], device=device)
        desc2 = torch.tensor([[5, 5.0], [3, 3.0], [2.3, 2.4], [1, 1], [0, 0.0]], device=device)

        dists, idxs = match_nn(desc1, desc2)
        # [2, 2] vs [2.3, 2.4] is the only non-exact pair: sqrt(0.3^2 + 0.4^2) = 0.5.
        expected_dists = torch.tensor([0, 0, 0.5, 0, 0], device=device).view(-1, 1)
        expected_idx = torch.tensor([[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]], device=device)
        assert_close(dists, expected_dists)
        assert_close(idxs, expected_idx)

        # The module wrapper must agree with the functional API. Previously this
        # section just called match_nn again, so DescriptorMatcher('nn') went untested.
        matcher = DescriptorMatcher('nn').to(device)
        dists1, idxs1 = matcher(desc1, desc2)
        assert_close(dists1, expected_dists)
        assert_close(idxs1, expected_idx)

    def test_gradcheck(self, device):
        desc1 = torch.rand(5, 8, device=device)
        desc2 = torch.rand(7, 8, device=device)
        desc1 = utils.tensor_to_gradcheck_var(desc1)
        desc2 = utils.tensor_to_gradcheck_var(desc2)
        # Check the function under test (match_nn), not match_mnn as before.
        assert gradcheck(match_nn, (desc1, desc2), raise_exception=True, nondet_tol=1e-4)
|
|
|
|
|
|
|
|
class TestMatchMNN:
    """Mutual-nearest-neighbour matching: functional ``match_mnn`` and ``DescriptorMatcher('mnn')``."""

    @staticmethod
    def _toy_descriptors(device):
        """Return a fixed (query, reference) pair of 2-D descriptors with a known matching."""
        query = torch.tensor([[0, 0.0], [1, 1], [2, 2], [3, 3.0], [5, 5.0]], device=device)
        reference = torch.tensor([[5, 5.0], [3, 3.0], [2.3, 2.4], [1, 1], [0, 0.0]], device=device)
        return query, reference

    @pytest.mark.parametrize("num_desc1, num_desc2, dim", [(1, 4, 4), (2, 5, 128), (6, 2, 32)])
    def test_shape(self, num_desc1, num_desc2, dim, device):
        dists, idxs = match_mnn(
            torch.rand(num_desc1, dim, device=device),
            torch.rand(num_desc2, dim, device=device),
        )
        # Index pairs have two columns; distances have one.
        assert (idxs.shape[1], dists.shape[1]) == (2, 1)
        # Same number of rows, and mutual filtering can only drop query descriptors.
        assert idxs.shape[0] == dists.shape[0] <= num_desc1

    def test_matching(self, device):
        desc1, desc2 = self._toy_descriptors(device)
        expected_dists = torch.tensor([0, 0, 0.5, 0, 0], device=device).view(-1, 1)
        expected_idx = torch.tensor([[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]], device=device)

        # The functional API and the module wrapper must produce identical results.
        for matcher in (match_mnn, DescriptorMatcher('mnn').to(device)):
            dists, idxs = matcher(desc1, desc2)
            assert_close(dists, expected_dists)
            assert_close(idxs, expected_idx)

    def test_gradcheck(self, device):
        inputs = tuple(utils.tensor_to_gradcheck_var(torch.rand(n, 8, device=device)) for n in (5, 7))
        assert gradcheck(match_mnn, inputs, raise_exception=True, nondet_tol=1e-4)
|
|
|
|
|
|
|
|
class TestMatchSNN:
    """Ratio-threshold matching: ``match_snn`` and ``DescriptorMatcher('snn', th)``.

    The threshold argument filters matches: with 0.8 every toy descriptor is kept,
    with 0.1 the inexact [2, 2] -> [2.3, 2.4] pair is dropped.
    """

    @staticmethod
    def _toy_descriptors(device):
        """Return a fixed (query, reference) pair of 2-D descriptors with a known matching."""
        query = torch.tensor([[0, 0.0], [1, 1], [2, 2], [3, 3.0], [5, 5.0]], device=device)
        reference = torch.tensor([[5, 5.0], [3, 3.0], [2.3, 2.4], [1, 1], [0, 0.0]], device=device)
        return query, reference

    def _check_against(self, device, th, expected_dists, expected_idx):
        """Run both the functional and the module matcher at threshold ``th`` and compare."""
        desc1, desc2 = self._toy_descriptors(device)

        dists, idxs = match_snn(desc1, desc2, th)
        assert_close(dists, expected_dists)
        assert_close(idxs, expected_idx)

        dists_mod, idxs_mod = DescriptorMatcher('snn', th).to(device)(desc1, desc2)
        assert_close(dists_mod, expected_dists)
        assert_close(idxs_mod, expected_idx)

    @pytest.mark.parametrize("num_desc1, num_desc2, dim", [(2, 4, 4), (2, 5, 128), (6, 2, 32)])
    def test_shape(self, num_desc1, num_desc2, dim, device):
        dists, idxs = match_snn(
            torch.rand(num_desc1, dim, device=device),
            torch.rand(num_desc2, dim, device=device),
        )
        assert (idxs.shape[1], dists.shape[1]) == (2, 1)
        assert idxs.shape[0] == dists.shape[0] <= num_desc1

    def test_matching1(self, device):
        self._check_against(
            device,
            0.8,
            torch.tensor([0, 0, 0.35355339059327373, 0, 0], device=device).view(-1, 1),
            torch.tensor([[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]], device=device),
        )

    def test_matching2(self, device):
        self._check_against(
            device,
            0.1,
            torch.tensor([0.0, 0, 0, 0], device=device).view(-1, 1),
            torch.tensor([[0, 4], [1, 3], [3, 1], [4, 0]], device=device),
        )

    def test_gradcheck(self, device):
        inputs = tuple(utils.tensor_to_gradcheck_var(torch.rand(n, 8, device=device)) for n in (5, 7))
        assert gradcheck(match_snn, (*inputs, 0.8), raise_exception=True, nondet_tol=1e-4)
|
|
|
|
|
|
|
|
class TestMatchSMNN:
    """Threshold-based mutual matching (``match_smnn``), plus eager/TorchScript parity of ``DescriptorMatcher``."""

    @staticmethod
    def _toy_descriptors(device):
        """Return a fixed (query, reference) pair of 2-D descriptors with a known matching."""
        query = torch.tensor([[0, 0.0], [1, 1], [2, 2], [3, 3.0], [5, 5.0]], device=device)
        reference = torch.tensor([[5, 5.0], [3, 3.0], [2.3, 2.4], [1, 1], [0, 0.0]], device=device)
        return query, reference

    def _check_against(self, device, th, expected_dists, expected_idx):
        """Run both the functional and the module matcher at threshold ``th`` and compare."""
        desc1, desc2 = self._toy_descriptors(device)

        dists, idxs = match_smnn(desc1, desc2, th)
        assert_close(dists, expected_dists)
        assert_close(idxs, expected_idx)

        dists_mod, idxs_mod = DescriptorMatcher('smnn', th).to(device)(desc1, desc2)
        assert_close(dists_mod, expected_dists)
        assert_close(idxs_mod, expected_idx)

    @pytest.mark.parametrize("num_desc1, num_desc2, dim", [(2, 4, 4), (2, 5, 128), (6, 2, 32)])
    def test_shape(self, num_desc1, num_desc2, dim, device):
        dists, idxs = match_smnn(
            torch.rand(num_desc1, dim, device=device),
            torch.rand(num_desc2, dim, device=device),
            0.8,
        )
        assert (idxs.shape[1], dists.shape[1]) == (2, 1)
        num_matches = dists.shape[0]
        assert idxs.shape[0] == num_matches
        # A mutual match set cannot be larger than either descriptor set.
        assert num_matches <= min(num_desc1, num_desc2)

    def test_matching1(self, device):
        self._check_against(
            device,
            0.8,
            torch.tensor([0, 0, 0.5423, 0, 0], device=device).view(-1, 1),
            torch.tensor([[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]], device=device),
        )

    def test_matching2(self, device):
        self._check_against(
            device,
            0.1,
            torch.tensor([0.0, 0, 0, 0], device=device).view(-1, 1),
            torch.tensor([[0, 4], [1, 3], [3, 1], [4, 0]], device=device),
        )

    def test_gradcheck(self, device):
        desc1, desc2 = (utils.tensor_to_gradcheck_var(torch.rand(n, 8, device=device)) for n in (5, 7))
        matcher = DescriptorMatcher('smnn', 0.8).to(device)
        # Both the functional API and the module wrapper must have correct gradients.
        for fn, args in ((match_smnn, (desc1, desc2, 0.8)), (matcher, (desc1, desc2))):
            assert gradcheck(fn, args, raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.jit
    @pytest.mark.parametrize("match_type", ["nn", "snn", "mnn", "smnn"])
    def test_jit(self, match_type, device, dtype):
        desc1 = torch.rand(5, 8, device=device, dtype=dtype)
        desc2 = torch.rand(7, 8, device=device, dtype=dtype)
        eager = DescriptorMatcher(match_type, 0.8).to(device)
        scripted = torch.jit.script(DescriptorMatcher(match_type, 0.8).to(device))

        eager_out = eager(desc1, desc2)
        scripted_out = scripted(desc1, desc2)
        # Element 0 holds distances and element 1 holds index pairs; both must agree.
        for part in (0, 1):
            assert_close(eager_out[part], scripted_out[part])
|
|
|