File size: 5,150 Bytes
7e31006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import torch.nn as nn
import torch.nn.functional as F


INF = 1e9  # -1e4 for fp16matmul


class CoarseMatching(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.window_size = config["local_resolution"]
        self.thr = config["coarse"]["mconf_thr"]
        self.temperature = config["coarse"]["dsmax_temperature"]
        self.ds_opt = config["coarse"]["ds_opt"]
        self.pad_num = config["coarse"]["train_pad_num"]
        self.topk = config["coarse"]["topk"]
        self.deploy = config["deploy"]

    def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            data (dict)
            mask_c0 (torch.Tensor): [N, L] (optional)
            mask_c1 (torch.Tensor): [N, S] (optional)
        Update:
            data (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
            NOTE: M' != M during training.
        """
        # normalize
        feat_c0, feat_c1 = map(
            lambda feat: feat / feat.shape[-1] ** 0.5, [feat_c0, feat_c1]
        )

        with torch.autocast(enabled=False, device_type="cuda"):
            sim_matrix = (
                torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) /
                self.temperature
            )
            del feat_c0, feat_c1
            if mask_c0 is not None:
                sim_matrix = sim_matrix.float().masked_fill(
                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                    -INF,
                )

        if not self.training and self.ds_opt:
            # Alternative implementation of daul-softmax operator for efficient inference
            sim_matrix = torch.exp(sim_matrix)
            conf_matrix = F.normalize(sim_matrix, p=1, dim=1) * F.normalize(
                sim_matrix, p=1, dim=2
            )
        else:
            # Native daul-softmax operator
            conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
        
        data.update(
            {
                "conf_matrix": conf_matrix,
            }
        )

        if not self.deploy:
            # predict coarse matches from conf_matrix
            self.coarse_matching_selection(data)

        return conf_matrix # Returning the sim_matrix can be faster, but it may reduce the accuracy. 

    # Static tensor shape for mini-batch inference.
    @torch.no_grad()
    def coarse_matching_selection(self, data):
        """
        Args:
            data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        conf_matrix = data["conf_matrix"]

        # mutual nearest
        # mask = (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
        #     * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
        # conf_matrix[~mask]=0

        k = self.topk
        row_max_val, row_max_idx = torch.max(conf_matrix, dim=2)

        # prevent out of range
        if k == -1 or k > row_max_val.shape[-1]:
            k = row_max_val.shape[-1]

        topk_val, topk_idx = torch.topk(row_max_val, k)
        b_ids = (
            torch.arange(conf_matrix.shape[0], device=conf_matrix.device)
            .unsqueeze(1)
            .repeat(1, k)
            .flatten()
        )
        i_ids = topk_idx.flatten()
        j_ids = row_max_idx[b_ids, i_ids].flatten()
        mconf = conf_matrix[b_ids, i_ids, j_ids]

        scale = data["hw0_i"][0] / data["hw0_c"][0]
        scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
        scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
        mkpts0_c = (
            torch.stack(
                [
                    i_ids % data["hw0_c"][1],
                    torch.div(i_ids, data["hw0_c"][1], rounding_mode="floor"),
                ],
                dim=1,
            )
            * scale0
        )
        mkpts1_c = (
            torch.stack(
                [
                    j_ids % data["hw1_c"][1],
                    torch.div(j_ids, data["hw1_c"][1], rounding_mode="floor"),
                ],
                dim=1,
            )
            * scale1
        )

        data.update(
            {
                "mconf": mconf,
                "mkpts0_c": mkpts0_c,
                "mkpts1_c": mkpts1_c,
                "b_ids": b_ids,
                "i_ids": i_ids,
                "j_ids": j_ids,
            }
        )