# EDM/src/edm/head/coarse_matching.py
# lixi042 · update · 7e31006
import torch
import torch.nn as nn
import torch.nn.functional as F
INF = 1e9  # Magnitude of the fill value for masked similarities (original note: -1e4 for fp16matmul).


class CoarseMatching(nn.Module):
    """Coarse-level matcher: dual-softmax confidence plus top-k match selection.

    ``forward`` turns two sets of coarse features into a confidence matrix
    and, unless ``deploy`` is set, calls ``coarse_matching_selection`` to pick
    a fixed number of matches per batch element.

    Attributes read from ``config``:
        window_size: ``config["local_resolution"]`` (stored; not used by the
            methods of this class).
        thr: ``config["coarse"]["mconf_thr"]`` (stored; not used by the
            methods of this class).
        temperature: ``config["coarse"]["dsmax_temperature"]``, divisor
            applied to the similarity matrix.
        ds_opt: ``config["coarse"]["ds_opt"]``, enables the exp/L1-normalize
            formulation of dual-softmax when the module is in eval mode.
        pad_num: ``config["coarse"]["train_pad_num"]`` (stored; not used by
            the methods of this class).
        topk: ``config["coarse"]["topk"]``, number of matches kept per batch
            element; ``-1`` keeps one match per row of the confidence matrix.
        deploy: ``config["deploy"]``, when true ``forward`` skips match
            selection and only produces the confidence matrix.
    """

    def __init__(self, config):
        super().__init__()
        self.window_size = config["local_resolution"]
        coarse_cfg = config["coarse"]
        self.thr = coarse_cfg["mconf_thr"]
        self.temperature = coarse_cfg["dsmax_temperature"]
        self.ds_opt = coarse_cfg["ds_opt"]
        self.pad_num = coarse_cfg["train_pad_num"]
        self.topk = coarse_cfg["topk"]
        self.deploy = config["deploy"]

    def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
        """Compute the dual-softmax confidence matrix and (optionally) matches.

        Args:
            feat_c0 (torch.Tensor): [N, L, C]
            feat_c1 (torch.Tensor): [N, S, C]
            data (dict): updated in place (see below).
            mask_c0 (torch.Tensor): [N, L] (optional), nonzero/True = valid.
            mask_c1 (torch.Tensor): [N, S] (optional), nonzero/True = valid.
                Either mask may be given on its own; a missing mask is
                treated as "everything valid".
        Update:
            data (dict): {
                'conf_matrix' (torch.Tensor): [N, L, S]}
            and, when ``self.deploy`` is false, every key written by
            ``coarse_matching_selection``.
        Returns:
            torch.Tensor: the confidence matrix, [N, L, S].
        """
        # Each feature set is divided by sqrt(C), so the dot product below is
        # effectively divided by C.
        feat_c0 = feat_c0 / feat_c0.shape[-1] ** 0.5
        feat_c1 = feat_c1 / feat_c1.shape[-1] ** 0.5

        # The similarity is computed with CUDA autocast explicitly disabled.
        with torch.autocast(enabled=False, device_type="cuda"):
            sim_matrix = (
                torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) / self.temperature
            )
            del feat_c0, feat_c1

            if mask_c0 is not None or mask_c1 is not None:
                if mask_c1 is None:
                    # Only rows are masked; broadcasts over the S axis.
                    valid = mask_c0[..., None].bool()
                elif mask_c0 is None:
                    # Only columns are masked; broadcasts over the L axis.
                    valid = mask_c1[:, None].bool()
                else:
                    valid = (mask_c0[..., None] * mask_c1[:, None]).bool()
                # Cast to float32 first: INF does not fit in half precision.
                sim_matrix = sim_matrix.float().masked_fill(~valid, -INF)

        if not self.training and self.ds_opt:
            # Alternative implementation of the dual-softmax operator for
            # efficient inference: exponentiate once, then L1-normalize along
            # each of the two matching axes.
            sim_matrix = torch.exp(sim_matrix)
            conf_matrix = F.normalize(sim_matrix, p=1, dim=1) * F.normalize(
                sim_matrix, p=1, dim=2
            )
        else:
            # Native dual-softmax operator.
            conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)

        data.update({"conf_matrix": conf_matrix})

        if not self.deploy:
            # Predict coarse matches from conf_matrix.
            self.coarse_matching_selection(data)

        # Returning the sim_matrix can be faster, but it may reduce the accuracy.
        return conf_matrix

    # Static tensor shape for mini-batch inference.
    @torch.no_grad()
    def coarse_matching_selection(self, data):
        """Select a fixed number of coarse matches per batch element.

        For every row of the confidence matrix the best column is taken; the
        ``k`` rows with the highest such score are kept for each batch
        element, so every output has exactly ``N * k`` entries. No confidence
        threshold is applied here.

        Args:
            data (dict): must contain 'conf_matrix' ([N, L, S]), 'hw0_i',
                'hw0_c' and 'hw1_c'; 'scale0' / 'scale1' are used when present.
        Update:
            data (dict): {
                'b_ids' (torch.Tensor): [N*k],
                'i_ids' (torch.Tensor): [N*k],
                'j_ids' (torch.Tensor): [N*k],
                'mkpts0_c' (torch.Tensor): [N*k, 2],
                'mkpts1_c' (torch.Tensor): [N*k, 2],
                'mconf' (torch.Tensor): [N*k]}
        Returns:
            None. Results are written into ``data``.
        """
        conf_matrix = data["conf_matrix"]

        # Best column (and its score) for every row of the confidence matrix.
        row_max_val, row_max_idx = torch.max(conf_matrix, dim=2)

        # Prevent out of range: -1 means "keep every row", and k can never
        # exceed the number of rows.
        k = self.topk
        num_rows = row_max_val.shape[-1]
        if k == -1 or k > num_rows:
            k = num_rows

        _, topk_idx = torch.topk(row_max_val, k)

        # Batch index repeated k times per element: [0]*k, [1]*k, ...
        b_ids = (
            torch.arange(conf_matrix.shape[0], device=conf_matrix.device)
            .unsqueeze(1)
            .repeat(1, k)
            .flatten()
        )
        i_ids = topk_idx.flatten()
        j_ids = row_max_idx[b_ids, i_ids].flatten()
        mconf = conf_matrix[b_ids, i_ids, j_ids]

        # Ratio between input and coarse resolution, taken from image 0's
        # height and reused for both images.
        # NOTE(review): assumes both images share this ratio; confirm.
        scale = data["hw0_i"][0] / data["hw0_c"][0]
        scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
        scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale

        # Flat index -> (x, y) on the coarse grid, then scaled up.
        w0_c = data["hw0_c"][1]
        w1_c = data["hw1_c"][1]
        mkpts0_c = (
            torch.stack(
                [i_ids % w0_c, torch.div(i_ids, w0_c, rounding_mode="floor")],
                dim=1,
            )
            * scale0
        )
        mkpts1_c = (
            torch.stack(
                [j_ids % w1_c, torch.div(j_ids, w1_c, rounding_mode="floor")],
                dim=1,
            )
            * scale1
        )

        data.update(
            {
                "mconf": mconf,
                "mkpts0_c": mkpts0_c,
                "mkpts1_c": mkpts1_c,
                "b_ids": b_ids,
                "i_ids": i_ids,
                "j_ids": j_ids,
            }
        )