# compvis / test/feature/test_matching.py
# Scraped page header (kept as comments so the module parses):
#   "Dexter's picture" - "Upload folder using huggingface_hub" - commit 36c95ba (verified)
import pytest
import torch
from torch.autograd import gradcheck
import kornia.testing as utils # test utils
from kornia.feature.matching import DescriptorMatcher, match_mnn, match_nn, match_smnn, match_snn
from kornia.testing import assert_close
class TestMatchNN:
    """Tests for plain nearest-neighbour matching: ``match_nn`` and ``DescriptorMatcher('nn')``."""

    @pytest.mark.parametrize("num_desc1, num_desc2, dim", [(1, 4, 4), (2, 5, 128), (6, 2, 32)])
    def test_shape(self, num_desc1, num_desc2, dim, device):
        # The test expects exactly one match row per descriptor in desc1.
        desc1 = torch.rand(num_desc1, dim, device=device)
        desc2 = torch.rand(num_desc2, dim, device=device)
        dists, idxs = match_nn(desc1, desc2)
        assert idxs.shape == (num_desc1, 2)
        assert dists.shape == (num_desc1, 1)

    def test_matching(self, device):
        desc1 = torch.tensor([[0, 0.0], [1, 1], [2, 2], [3, 3.0], [5, 5.0]], device=device)
        desc2 = torch.tensor([[5, 5.0], [3, 3.0], [2.3, 2.4], [1, 1], [0, 0.0]], device=device)
        dists, idxs = match_nn(desc1, desc2)
        # [2, 2] is 0.5 away from [2.3, 2.4]; every other query has an exact twin.
        expected_dists = torch.tensor([0, 0, 0.5, 0, 0], device=device).view(-1, 1)
        expected_idx = torch.tensor([[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]], device=device)
        assert_close(dists, expected_dists)
        assert_close(idxs, expected_idx)
        # The module wrapper must agree with the functional API, as in the
        # MNN/SNN/SMNN tests (previously this just re-ran match_nn).
        matcher = DescriptorMatcher('nn').to(device)
        dists1, idxs1 = matcher(desc1, desc2)
        assert_close(dists1, expected_dists)
        assert_close(idxs1, expected_idx)

    def test_gradcheck(self, device):
        desc1 = torch.rand(5, 8, device=device)
        desc2 = torch.rand(7, 8, device=device)
        desc1 = utils.tensor_to_gradcheck_var(desc1)  # to var
        desc2 = utils.tensor_to_gradcheck_var(desc2)  # to var
        # Check match_nn itself. The earlier version checked match_mnn here
        # (copy-paste), so match_nn's gradients were never verified.
        assert gradcheck(match_nn, (desc1, desc2), raise_exception=True, nondet_tol=1e-4)
class TestMatchMNN:
    """Tests for mutual-nearest-neighbour matching: ``match_mnn`` and ``DescriptorMatcher('mnn')``."""

    @pytest.mark.parametrize("num_desc1, num_desc2, dim", [(1, 4, 4), (2, 5, 128), (6, 2, 32)])
    def test_shape(self, num_desc1, num_desc2, dim, device):
        desc1 = torch.rand(num_desc1, dim, device=device)
        desc2 = torch.rand(num_desc2, dim, device=device)
        dists, idxs = match_mnn(desc1, desc2)
        assert idxs.shape[1] == 2
        assert dists.shape[1] == 1
        assert idxs.shape[0] == dists.shape[0]
        # Mutual matches should be one-to-one, so the match count is bounded by
        # the size of BOTH descriptor sets. TestMatchSMNN.test_shape checks the same bound.
        assert dists.shape[0] <= num_desc1
        assert dists.shape[0] <= num_desc2

    def test_matching(self, device):
        desc1 = torch.tensor([[0, 0.0], [1, 1], [2, 2], [3, 3.0], [5, 5.0]], device=device)
        desc2 = torch.tensor([[5, 5.0], [3, 3.0], [2.3, 2.4], [1, 1], [0, 0.0]], device=device)
        dists, idxs = match_mnn(desc1, desc2)
        expected_dists = torch.tensor([0, 0, 0.5, 0, 0], device=device).view(-1, 1)
        expected_idx = torch.tensor([[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]], device=device)
        assert_close(dists, expected_dists)
        assert_close(idxs, expected_idx)
        # The module wrapper must agree with the functional API.
        matcher = DescriptorMatcher('mnn').to(device)
        dists1, idxs1 = matcher(desc1, desc2)
        assert_close(dists1, expected_dists)
        assert_close(idxs1, expected_idx)

    def test_gradcheck(self, device):
        desc1 = torch.rand(5, 8, device=device)
        desc2 = torch.rand(7, 8, device=device)
        desc1 = utils.tensor_to_gradcheck_var(desc1)  # to var
        desc2 = utils.tensor_to_gradcheck_var(desc2)  # to var
        assert gradcheck(match_mnn, (desc1, desc2), raise_exception=True, nondet_tol=1e-4)
class TestMatchSNN:
    """Tests for ratio-test matching: ``match_snn`` and ``DescriptorMatcher('snn', th)``."""

    @staticmethod
    def _toy_descriptors(device):
        """Build the small 2-D query/database pair shared by the matching tests."""
        queries = torch.tensor([[0, 0.0], [1, 1], [2, 2], [3, 3.0], [5, 5.0]], device=device)
        database = torch.tensor([[5, 5.0], [3, 3.0], [2.3, 2.4], [1, 1], [0, 0.0]], device=device)
        return queries, database

    @staticmethod
    def _assert_matches(result, want_dists, want_idxs):
        """Compare a ``(dists, idxs)`` result against the expected tensors."""
        got_dists, got_idxs = result
        assert_close(got_dists, want_dists)
        assert_close(got_idxs, want_idxs)

    @pytest.mark.parametrize("num_desc1, num_desc2, dim", [(2, 4, 4), (2, 5, 128), (6, 2, 32)])
    def test_shape(self, num_desc1, num_desc2, dim, device):
        dists, idxs = match_snn(
            torch.rand(num_desc1, dim, device=device),
            torch.rand(num_desc2, dim, device=device),
        )
        n_matches = dists.shape[0]
        assert idxs.shape[1] == 2
        assert dists.shape[1] == 1
        assert idxs.shape[0] == n_matches
        assert n_matches <= num_desc1

    def test_matching1(self, device):
        queries, database = self._toy_descriptors(device)
        # With th=0.8 every query survives. The value for [2, 2] is
        # 0.5 / sqrt(2): its distance to [2.3, 2.4] over its distance to [1, 1] or [3, 3].
        want_dists = torch.tensor([0, 0, 0.35355339059327373, 0, 0], device=device).view(-1, 1)
        want_idxs = torch.tensor([[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]], device=device)
        self._assert_matches(match_snn(queries, database, 0.8), want_dists, want_idxs)
        wrapper = DescriptorMatcher('snn', 0.8).to(device)
        self._assert_matches(wrapper(queries, database), want_dists, want_idxs)

    def test_matching2(self, device):
        queries, database = self._toy_descriptors(device)
        # A strict th=0.1 drops the inexact [2, 2] -> [2.3, 2.4] pair.
        want_dists = torch.tensor([0.0, 0, 0, 0], device=device).view(-1, 1)
        want_idxs = torch.tensor([[0, 4], [1, 3], [3, 1], [4, 0]], device=device)
        self._assert_matches(match_snn(queries, database, 0.1), want_dists, want_idxs)
        wrapper = DescriptorMatcher('snn', 0.1).to(device)
        self._assert_matches(wrapper(queries, database), want_dists, want_idxs)

    def test_gradcheck(self, device):
        lhs = utils.tensor_to_gradcheck_var(torch.rand(5, 8, device=device))
        rhs = utils.tensor_to_gradcheck_var(torch.rand(7, 8, device=device))
        assert gradcheck(match_snn, (lhs, rhs, 0.8), raise_exception=True, nondet_tol=1e-4)
class TestMatchSMNN:
    """Tests for mutual ratio-test matching (``match_smnn``), plus a TorchScript check for all matchers."""

    @staticmethod
    def _toy_descriptors(device):
        """Build the small 2-D query/database pair shared by the matching tests."""
        queries = torch.tensor([[0, 0.0], [1, 1], [2, 2], [3, 3.0], [5, 5.0]], device=device)
        database = torch.tensor([[5, 5.0], [3, 3.0], [2.3, 2.4], [1, 1], [0, 0.0]], device=device)
        return queries, database

    @staticmethod
    def _assert_matches(result, want_dists, want_idxs):
        """Compare a ``(dists, idxs)`` result against the expected tensors."""
        got_dists, got_idxs = result
        assert_close(got_dists, want_dists)
        assert_close(got_idxs, want_idxs)

    @pytest.mark.parametrize("num_desc1, num_desc2, dim", [(2, 4, 4), (2, 5, 128), (6, 2, 32)])
    def test_shape(self, num_desc1, num_desc2, dim, device):
        dists, idxs = match_smnn(
            torch.rand(num_desc1, dim, device=device),
            torch.rand(num_desc2, dim, device=device),
            0.8,
        )
        n_matches = dists.shape[0]
        assert idxs.shape[1] == 2
        assert dists.shape[1] == 1
        assert idxs.shape[0] == n_matches
        # The match count is expected to be bounded by both descriptor sets.
        assert n_matches <= num_desc1
        assert n_matches <= num_desc2

    def test_matching1(self, device):
        queries, database = self._toy_descriptors(device)
        # With th=0.8 every query survives; [2, 2] <-> [2.3, 2.4] keeps a non-zero ratio.
        want_dists = torch.tensor([0, 0, 0.5423, 0, 0], device=device).view(-1, 1)
        want_idxs = torch.tensor([[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]], device=device)
        self._assert_matches(match_smnn(queries, database, 0.8), want_dists, want_idxs)
        wrapper = DescriptorMatcher('smnn', 0.8).to(device)
        self._assert_matches(wrapper(queries, database), want_dists, want_idxs)

    def test_matching2(self, device):
        queries, database = self._toy_descriptors(device)
        # A strict th=0.1 drops the inexact [2, 2] -> [2.3, 2.4] pair.
        want_dists = torch.tensor([0.0, 0, 0, 0], device=device).view(-1, 1)
        want_idxs = torch.tensor([[0, 4], [1, 3], [3, 1], [4, 0]], device=device)
        self._assert_matches(match_smnn(queries, database, 0.1), want_dists, want_idxs)
        wrapper = DescriptorMatcher('smnn', 0.1).to(device)
        self._assert_matches(wrapper(queries, database), want_dists, want_idxs)

    def test_gradcheck(self, device):
        lhs = utils.tensor_to_gradcheck_var(torch.rand(5, 8, device=device))
        rhs = utils.tensor_to_gradcheck_var(torch.rand(7, 8, device=device))
        wrapper = DescriptorMatcher('smnn', 0.8).to(device)
        # Check both the functional API and the module wrapper.
        assert gradcheck(match_smnn, (lhs, rhs, 0.8), raise_exception=True, nondet_tol=1e-4)
        assert gradcheck(wrapper, (lhs, rhs), raise_exception=True, nondet_tol=1e-4)

    @pytest.mark.jit
    @pytest.mark.parametrize("match_type", ["nn", "snn", "mnn", "smnn"])
    def test_jit(self, match_type, device, dtype):
        lhs = torch.rand(5, 8, device=device, dtype=dtype)
        rhs = torch.rand(7, 8, device=device, dtype=dtype)
        eager = DescriptorMatcher(match_type, 0.8).to(device)
        scripted = torch.jit.script(DescriptorMatcher(match_type, 0.8).to(device))
        # Compare the distances (0), then the index pairs (1), of eager vs. scripted output.
        for out_pos in (0, 1):
            assert_close(eager(lhs, rhs)[out_pos], scripted(lhs, rhs)[out_pos])