import torch
import torch.nn as nn

from timm.models.layers import DropPath

from models.convformer import LayerNormWithoutBias
from models.common import ConvGLU
from models.mat_pytorch_impl import compute_bilinear_weights, compute_match_attention, compute_bilinear_softmax, attention_aggregate
from models.match_former_ops import MF_FusedForwardOps
from utils.utils import bilinear_sample_by_offset, init_coords

class MatchAttention(torch.nn.Module):
    r"""MatchAttention: each query attends to a small local window centered at its
    per-head matched position (query coordinate plus predicted relative position);
    fractional positions are handled by bilinear interpolation of the attention weights.
    """
    def __init__(self, args, dim, win_r=[1, 1], num_head=8, head_dim=None, qkv_bias=False, 
                 attn_drop=0., proj_drop=0., proj_bias=False, cross=False, noc_embed=False, **kargs):
        super().__init__()

        self.num_head = num_head
        self.cross = cross
        self.noc_embed = noc_embed if not cross else False # noc (non-occlusion) embedding applies to self-attention only

        self.head_dim = dim // num_head if head_dim is None else head_dim
        self.scale = self.head_dim ** -0.5

        self.attention_dim = self.num_head * self.head_dim

        self.win_r = win_r
        self.attn_num = (2*win_r[0]+2)*(2*win_r[1]+2) # (2r+1)+1 taps per axis; the extra tap covers bilinear interpolation, e.g. win_r=[2, 2] -> 6*6 = 36

        embed_dim = dim + 1 if noc_embed else dim # '1' for noc_mask
        self.q = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.k = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.v = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        if self.cross:
            self.g = nn.Sequential(nn.Linear(embed_dim, self.attention_dim,bias=qkv_bias), nn.SiLU())
            self.proj = nn.Linear(self.attention_dim + self.num_head*self.attn_num, dim, bias=proj_bias)
        else:
            self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_pytorch = (args.mat_impl == 'pytorch') # pure-PyTorch reference path vs. fused forward ops
        self.mf_fused = MF_FusedForwardOps()

    def clamp_max_offset(self, max_offset, H, W):
        r"""Clamp match positions so the full attention window stays inside the feature map."""
        max_offset_x, max_offset_y = max_offset.chunk(2, dim=-1) # chunk to avoid an in-place op on the input

        # for ONNX support
        min_x = torch.tensor(self.win_r[0], dtype=max_offset.dtype, device=max_offset.device)
        max_x = torch.tensor(W - 1 - self.win_r[0] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)
        min_y = torch.tensor(self.win_r[1], dtype=max_offset.dtype, device=max_offset.device)
        max_y = torch.tensor(H - 1 - self.win_r[1] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)

        max_offset_x = torch.clamp(max_offset_x, min=min_x, max=max_x)
        max_offset_y = torch.clamp(max_offset_y, min=min_y, max=max_y)

        # equivalent non-ONNX form: max_offset_x.clamp(min=self.win_r[0], max=W-1-self.win_r[0]-1e-3), and likewise for y
        return torch.cat((max_offset_x, max_offset_y), dim=-1).contiguous()

    def forward(self, x, max_offset, noc_mask=None): # x: [B, H, W, C], max_offset: [B, N, h, 2]
        B, H, W, _ = x.shape
        N = H*W
        assert (2*self.win_r[1] + 2 <= H) and (2*self.win_r[0] + 2 <= W)
        x = x.view(B, N, -1).contiguous()

        if self.cross:
            ref_, tgt_ = x.chunk(2, dim=0) # the two views are stacked along the batch dimension
            ref = torch.cat((ref_, tgt_), dim=0) # original order
            tgt = torch.cat((tgt_, ref_), dim=0) # swapped order, so each view attends to the other
            g = self.g(ref)
        else: # self-attn
            if self.noc_embed:
                x = torch.cat((x, noc_mask.view(B, N, -1)), dim=-1).contiguous()
            ref, tgt = x, x
        q, k, v = self.q(ref), self.k(tgt), self.v(tgt)

        # parameter-free matching ops
        max_offset = self.clamp_max_offset(max_offset, H, W)

        if self.use_pytorch:
            m_id = torch.floor(max_offset).to(torch.int32) # integer part of the match position, [B, N, h, 2]
            bilinear_weight = compute_bilinear_weights(max_offset) # weights from the fractional part

            attn, indices_gather = compute_match_attention(
                q.view(B, N, self.num_head, -1), k.view(B, N, self.num_head, -1),
                m_id, self.win_r, H, W)
            attn = attn * self.scale

            attn = compute_bilinear_softmax(attn, bilinear_weight, self.win_r)
            attn = self.attn_drop(attn)

            x = attention_aggregate(v.view(B, N, self.num_head, -1), attn, indices_gather, self.win_r)
        else:
            x, attn = self.mf_fused(max_offset, q, k, v, H, W, self.win_r, self.attn_num,
                                    attn_type='l1_norm', scale=self.scale)

        if self.cross:
            x = g * x # SiLU gate computed from the reference features
            attn = attn.view(B, N, -1).contiguous()
            x = torch.cat((x, attn), dim=-1).contiguous() # expose the attention weights to the output projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x.view(B, H, W, -1).contiguous()
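
# Shape sketch for the self-attention case (illustrative only; `SimpleNamespace`
# stands in for the real args object and is an assumption of this sketch):
#
#   from types import SimpleNamespace
#   attn = MatchAttention(SimpleNamespace(mat_impl='pytorch'), dim=64,
#                         win_r=[2, 2], num_head=8)
#   x          : [B, 16, 16, 64]  channels-last features
#   max_offset : [B, 256, 8, 2]   per-head (x, y) match positions, clamped internally
#   out = attn(x, max_offset)     # -> [B, 16, 16, 64]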


class MatchAttentionLayer(nn.Module):
    r"""MatchAttention layer with interleaved self-MatchAttention, cross-MatchAttention, and ConvGLU
    """

    def __init__(self, args, dim, win_r,
                 num_head=8, head_dim=32, mlp=ConvGLU, mlp_ratio=2, field_dim=2,
                 norm_layer=nn.LayerNorm, drop=0., drop_path=0.):
        super().__init__()
        self.num_head = num_head
        self.field_dim = field_dim

        self.match_attention_self = MatchAttention(args, dim + self.field_dim + self.num_head*2, [win_r, win_r], num_head=num_head, head_dim=head_dim, noc_embed=True)
        self.norm0 = norm_layer(dim + self.field_dim + self.num_head*2)

        self.match_attention_cross = MatchAttention(args, dim + self.field_dim, [win_r, win_r], num_head=num_head, head_dim=head_dim, cross=True)
        self.norm1 = norm_layer(dim + self.field_dim)

        self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
        self.norm2 = norm_layer(dim)

        self.field_scale = nn.Parameter(0.1*torch.ones(1, 1, 1, 2))

        self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def consistency_mask(self, field, A=2):
        r"""Forward-backward consistency check; returns a binary non-occlusion (noc) mask."""
        offset = field + init_coords(field) # absolute match positions, [B, H, W, 2]
        field_ref_, field_tgt_ = field.chunk(2, dim=0)
        field_ref = torch.cat((field_ref_, field_tgt_), dim=0) # original order
        field_tgt = torch.cat((field_tgt_, field_ref_), dim=0) # swapped order
        field_tgt_to_ref = bilinear_sample_by_offset(field_tgt.permute(0, 3, 1, 2).contiguous(), offset).permute(0, 2, 3, 1).contiguous()
        field_diff = torch.abs(field_ref + field_tgt_to_ref).sum(dim=-1, keepdim=True) # ref and tgt fields have opposite signs, so consistent matches sum to ~0
        noc_mask = (field_diff < A).to(field_diff.dtype) # A: consistency threshold in pixels
        return noc_mask
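
# Worked example of the check above (illustrative numbers): with a reference flow
# of +5.0 px and a backward flow of -4.6 px sampled at the matched location,
# |5.0 + (-4.6)| = 0.4 < A = 2, so the pixel is kept (mask = 1); a disagreeing
# backward flow of -1.0 px gives |5.0 + (-1.0)| = 4.0 >= 2 and is masked out.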

    def forward(self, x, self_rpos, field, stereo=True): # self_rpos [B, H, W, h*2], field [B, H, W, 2]

        field_out = {}
        B, H, W, C = x.shape

        noc_mask = self.consistency_mask(field.detach())

        x = torch.cat((x, field*self.field_scale.to(field.dtype), self_rpos), dim=-1).contiguous()

        coords_0 = init_coords(field).repeat(1, 1, 1, self.num_head) # absolute pixel coordinates, tiled per head
        self_offset = self_rpos + coords_0 # relative positions -> absolute match positions
        self_offset = self_offset.view(B, H*W, self.num_head, 2).contiguous()

        x = x + self.drop_path0(self.match_attention_self(self.norm0(x), self_offset, noc_mask))

        self_rpos = x[..., -(self.num_head*2):].contiguous() # [B, H, W, h*2]
        x = x[..., :-(self.num_head*2)].contiguous()

        if stereo: x[..., -1] = 0 # stereo: zero the vertical field component, leaving horizontal disparity only
        field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
        field_out['self'] = field.clone()

        offset = field.repeat(1, 1, 1, self.num_head).contiguous() + coords_0 # [B, H, W, h*2]
        offset = offset.view(B, H*W, self.num_head, 2).contiguous()

        x = x + self.drop_path1(self.match_attention_cross(self.norm1(x), offset))

        if stereo: x[..., -1] = 0 # stereo: zero the vertical field component again after cross-attention
        field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
        field_out['cross'] = field.clone()

        x = x[..., :-self.field_dim].contiguous() # No field feature in MLP

        x = x + self.drop_path2(self.mlp(self.norm2(x)))

        return x, self_rpos, field, field_out
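
# Per-layer data flow (all branches residual, channels-last):
#   [x | field * field_scale | self_rpos] -> self-MatchAttention (noc mask embedded)
#   -> split off updated self_rpos and field -> cross-MatchAttention guided by the field
#   -> split off updated field -> ConvGLU MLP on the feature channels only.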


class MatchAttentionBlock(nn.Module):
    r"""MatchAttention block with multiple match-attention layers
    """

    def __init__(self, args, dim, win_r=2,
                 num_layer=6, num_head=8, head_dim=32,
                 mlp=ConvGLU, mlp_ratio=2, field_dim=2,
                 norm_layer=LayerNormWithoutBias,
                 drop=0., dp_rates=None):

        super().__init__()
        self.num_head = num_head

        if dp_rates is None:
            dp_rates = [0.] * num_layer # one drop-path rate per layer; a single-element default would fail for num_layer > 1

        self.layers = nn.ModuleList()
        for i in range(num_layer):
            layer = MatchAttentionLayer(args, dim, win_r=win_r, num_head=num_head, head_dim=head_dim,
                                        mlp=mlp, mlp_ratio=mlp_ratio, field_dim=field_dim,
                                        norm_layer=norm_layer, drop=drop, drop_path=dp_rates[i])
            self.layers.append(layer)

    def forward(self, x, self_rpos, field, stereo=True):
        fields = []
        B, H, W, C = x.shape
        self_rpos = self_rpos.repeat(1, 1, 1, self.num_head) # [B, H, W, 2] -> [B, H, W, h*2]

        for layer in self.layers:

            x, self_rpos, field, field_out = layer(x, self_rpos, field, stereo)
            fields.append(field_out)

        self_rpos = self_rpos.view(B, H, W, self.num_head, 2).mean(dim=-2, keepdim=False) # average per-head relative positions back to [B, H, W, 2]

        return x, self_rpos, field, fields
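

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative): it assumes `args` only needs the
    # `mat_impl` attribute read above, and that the repo-local ConvGLU and
    # LayerNormWithoutBias accept the constructor arguments used in this file.
    from types import SimpleNamespace

    args = SimpleNamespace(mat_impl='pytorch')
    B, H, W, C = 2, 16, 16, 64 # B must be even: cross-attention splits the batch into two views

    block = MatchAttentionBlock(args, dim=C, win_r=2, num_layer=2, num_head=8,
                                head_dim=32, dp_rates=[0., 0.])

    x = torch.randn(B, H, W, C)
    self_rpos = torch.zeros(B, H, W, 2) # initial per-pixel relative positions
    field = torch.zeros(B, H, W, 2)     # initial matching field (disparity / flow)

    x, self_rpos, field, fields = block(x, self_rpos, field, stereo=True)
    print(x.shape, self_rpos.shape, field.shape, len(fields))
    # expected: [2, 16, 16, 64], [2, 16, 16, 2], [2, 16, 16, 2], 2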