import torch
import torch.nn as nn
from timm.models.layers import DropPath

from models.convformer import LayerNormWithoutBias
from models.common import ConvGLU
from models.mat_pytorch_impl import (
    compute_bilinear_weights,
    compute_match_attention,
    compute_bilinear_softmax,
    attention_aggregate,
)
from models.match_former_ops import MF_FusedForwardOps
from utils.utils import bilinear_sample_by_offset, init_coords


class MatchAttention(torch.nn.Module):
    r"""MatchAttention: Matching the relative positions.

    Computes attention restricted to a local window around a per-pixel,
    per-head floating-point offset (``max_offset``). Two execution paths are
    supported: a pure-PyTorch reference implementation and a fused op
    (``MF_FusedForwardOps``), selected by ``args.mat_impl``.

    Args:
        args: config namespace; only ``args.mat_impl`` is read here
            ('pytorch' selects the reference path).
        dim: channel dimension of the input feature map.
        win_r: window radii ``[rx, ry]`` of the local attention neighborhood.
        num_head: number of attention heads.
        head_dim: per-head dimension; defaults to ``dim // num_head``.
        qkv_bias: whether q/k/v projections carry a bias.
        attn_drop: dropout rate applied to attention weights.
        proj_drop: dropout rate applied after the output projection.
        proj_bias: whether the output projection carries a bias.
        cross: if True, attend between the two halves of the batch
            (ref <-> tgt) and gate the output; otherwise self-attention.
        noc_embed: if True (self-attention only), append the non-occlusion
            mask as an extra input channel.
    """

    def __init__(self, args, dim, win_r=[1, 1], num_head=8, head_dim=None,
                 qkv_bias=False, attn_drop=0., proj_drop=0., proj_bias=False,
                 cross=False, noc_embed=False, **kargs):
        super().__init__()
        self.num_head = num_head
        self.cross = cross
        self.noc_embed = noc_embed if not cross else False  # only for self attention
        self.head_dim = dim // num_head if head_dim is None else head_dim
        self.scale = self.head_dim ** -0.5
        self.attention_dim = self.num_head * self.head_dim
        self.win_r = win_r
        # Window covers 2*r+2 positions per axis (the +2 accounts for the
        # bilinear-interpolated fractional offset needing one extra cell).
        self.attn_num = (2 * win_r[0] + 2) * (2 * win_r[1] + 2)
        # NOTE(review): this uses the raw `noc_embed` argument, not
        # `self.noc_embed` (which is forced False when cross=True). If a
        # caller ever passed cross=True together with noc_embed=True, the
        # projections would expect dim+1 channels while forward() never
        # concatenates the mask — confirm no caller does this.
        embed_dim = dim + 1 if noc_embed else dim  # '1' for noc_mask
        self.q = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.k = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.v = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        if self.cross:
            # Gating branch: output projection also consumes the raw
            # attention weights (num_head * attn_num extra channels).
            self.g = nn.Sequential(
                nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias),
                nn.SiLU())
            self.proj = nn.Linear(self.attention_dim + self.num_head * self.attn_num,
                                  dim, bias=proj_bias)
        else:
            self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_pytorch = (args.mat_impl == 'pytorch')
        self.mf_fused = MF_FusedForwardOps()

    def clamp_max_offset(self, max_offset, H, W):
        """Clamp per-head offsets so the attention window stays inside the
        H x W feature map.

        Args:
            max_offset: [B, N, h, 2] tensor of (x, y) offsets.
            H, W: spatial extent of the feature map.

        Returns:
            Clamped, contiguous [B, N, h, 2] tensor.
        """
        # chunk instead of indexing to avoid an inplace operation
        max_offset_x, max_offset_y = max_offset.chunk(2, dim=-1)
        # Bounds built as tensors (rather than clamping with Python floats)
        # for ONNX support; the 1e-3 margin keeps floor(offset)+1 in range.
        min_x = torch.tensor(self.win_r[0], dtype=max_offset.dtype, device=max_offset.device)
        max_x = torch.tensor(W - 1 - self.win_r[0] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)
        min_y = torch.tensor(self.win_r[1], dtype=max_offset.dtype, device=max_offset.device)
        max_y = torch.tensor(H - 1 - self.win_r[1] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)
        max_offset_x = torch.clamp(max_offset_x, min=min_x, max=max_x)
        max_offset_y = torch.clamp(max_offset_y, min=min_y, max=max_y)
        return torch.cat((max_offset_x, max_offset_y), dim=-1).contiguous()

    def forward(self, x, max_offset, noc_mask=None):
        """Apply match attention.

        Args:
            x: [B, H, W, C] feature map. For cross attention the batch is
                expected to hold ref and tgt stacked along dim 0.
            max_offset: [B, N, h, 2] per-pixel, per-head attention centers.
            noc_mask: optional [B, H, W, 1]-shaped non-occlusion mask,
                required when ``self.noc_embed`` is True.

        Returns:
            [B, H, W, C] output feature map.
        """
        B, H, W, _ = x.shape
        N = H * W
        # The window (plus bilinear margin) must fit inside the map.
        assert (2 * self.win_r[1] + 2 <= H) and (2 * self.win_r[0] + 2 <= W)
        x = x.view(B, N, -1).contiguous()
        if self.cross:
            ref_, tgt_ = x.chunk(2, dim=0)  # split along batch dimension
            ref = torch.cat((ref_, tgt_), dim=0)  # order
            tgt = torch.cat((tgt_, ref_), dim=0)  # reverse order
            g = self.g(ref)
        else:  # self-attn
            if self.noc_embed:
                x = torch.cat((x, noc_mask.view(B, N, -1)), dim=-1).contiguous()
            ref, tgt = x, x
        q, k, v = self.q(ref), self.k(tgt), self.v(tgt)
        # ---- non-parameter modules ----
        max_offset = self.clamp_max_offset(max_offset, H, W)
        if self.use_pytorch:
            # Reference path: integer anchor + bilinear weights, explicit
            # gather-based attention.
            m_id = torch.floor(max_offset).to(torch.int32)  # [B, N, h, 2]
            bilinear_weight = compute_bilinear_weights(max_offset)
            attn, indices_gather = compute_match_attention(
                q.view(B, N, self.num_head, -1),
                k.view(B, N, self.num_head, -1),
                m_id, self.win_r, H, W)
            attn = attn * self.scale
            attn = compute_bilinear_softmax(attn, bilinear_weight, self.win_r)
            attn = self.attn_drop(attn)
            x = attention_aggregate(v.view(B, N, self.num_head, -1),
                                    attn, indices_gather, self.win_r)
        else:
            # Fused path returns both the aggregated values and the raw
            # attention weights.
            x, attn = self.mf_fused(max_offset, q, k, v, H, W, self.win_r,
                                    self.attn_num, attn_type='l1_norm',
                                    scale=self.scale)
        if self.cross:
            x = g * x  # gate
            # Feed the attention weights into the output projection as well.
            attn = attn.view(B, N, -1).contiguous()
            x = torch.cat((x, attn), dim=-1).contiguous()
        x = self.proj(x)
        x = self.proj_drop(x)
        return x.view(B, H, W, -1).contiguous()


class MatchAttentionLayer(nn.Module):
    r"""MatchAttention layer with interleaved self-MatchAttention,
    cross-MatchAttention, and ConvGLU.

    The feature map is temporarily concatenated with the (scaled) field and
    the per-head self relative positions, so the attention blocks can update
    them in place; they are peeled off again after each stage.
    """

    def __init__(self, args, dim, win_r, num_head=8, head_dim=32, mlp=ConvGLU,
                 mlp_ratio=2, field_dim=2, norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.):
        super().__init__()
        self.num_head = num_head
        self.field_dim = field_dim
        # Self-attention also sees field (+field_dim) and self_rpos
        # (+num_head*2) channels, plus the noc mask via noc_embed.
        self.match_attention_self = MatchAttention(
            args, dim + self.field_dim + self.num_head * 2, [win_r, win_r],
            num_head=num_head, head_dim=head_dim, noc_embed=True)
        self.norm0 = norm_layer(dim + self.field_dim + self.num_head * 2)
        # Cross attention sees field channels only (self_rpos already peeled).
        self.match_attention_cross = MatchAttention(
            args, dim + self.field_dim, [win_r, win_r],
            num_head=num_head, head_dim=head_dim, cross=True)
        self.norm1 = norm_layer(dim + self.field_dim)
        self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
        self.norm2 = norm_layer(dim)
        # Learned per-channel scale applied to the field before it is mixed
        # into the feature channels (and divided out when reading it back).
        self.field_scale = nn.Parameter(0.1 * torch.ones(1, 1, 1, 2))
        self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def consistency_mask(self, field, A=2):
        """Forward-backward consistency mask.

        Warps the reverse-direction field to the reference view and marks
        pixels whose round-trip residual is below threshold ``A``.

        Args:
            field: [B, H, W, 2] field with ref and tgt stacked along dim 0.
            A: consistency threshold in pixels.

        Returns:
            [B, H, W, 1] mask in the field's dtype (1 = consistent).
        """
        offset = field + init_coords(field)  # [B, H, W, 2] absolute coords
        field_ref_, field_tgt_ = field.chunk(2, dim=0)
        field_ref = torch.cat((field_ref_, field_tgt_), dim=0)  # order
        field_tgt = torch.cat((field_tgt_, field_ref_), dim=0)  # reverse order
        field_tgt_to_ref = bilinear_sample_by_offset(
            field_tgt.permute(0, 3, 1, 2).contiguous(),
            offset).permute(0, 2, 3, 1).contiguous()
        # ref and tgt flow has different sign, so a consistent pair sums to ~0
        field_diff = torch.abs(field_ref + field_tgt_to_ref).sum(dim=-1, keepdim=True)
        noc_mask = (field_diff < A).to(field_diff.dtype)
        return noc_mask

    def forward(self, x, self_rpos, field, stereo=True):
        """Run self-attention, cross-attention, and MLP stages.

        Args:
            x: [B, H, W, C] features.
            self_rpos: [B, H, W, h*2] per-head self relative positions.
            field: [B, H, W, 2] current field estimate.
            stereo: if True, zero the y-component of the field after each
                attention stage (disparity is horizontal-only).

        Returns:
            Tuple ``(x, self_rpos, field, field_out)`` where ``field_out``
            maps 'self'/'cross' to intermediate field snapshots.
        """
        field_out = {}
        B, H, W, C = x.shape
        noc_mask = self.consistency_mask(field.detach())
        # Mix field and self_rpos into the channel dimension so attention
        # residuals update them jointly with the features.
        x = torch.cat((x, field * self.field_scale.to(field.dtype), self_rpos),
                      dim=-1).contiguous()
        coords_0 = init_coords(field).repeat(1, 1, 1, self.num_head)
        # Self-attention centers: relative position + absolute coordinate.
        self_offset = self_rpos + coords_0
        self_offset = self_offset.view(B, H * W, self.num_head, 2).contiguous()
        x = x + self.drop_path0(
            self.match_attention_self(self.norm0(x), self_offset, noc_mask))
        # Peel the updated self_rpos back off the channel dimension.
        self_rpos = x[..., -(self.num_head * 2):].contiguous()  # [B, H, W, h*2]
        x = x[..., :-(self.num_head * 2)].contiguous()
        if stereo:
            x[..., -1] = 0
        field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
        field_out['self'] = field.clone()
        # Cross-attention centers: current field estimate + coordinate.
        offset = field.repeat(1, 1, 1, self.num_head).contiguous() + coords_0  # [B, H, W, h*2]
        offset = offset.view(B, H * W, self.num_head, 2).contiguous()
        x = x + self.drop_path1(self.match_attention_cross(self.norm1(x), offset))
        if stereo:
            x[..., -1] = 0
        field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
        field_out['cross'] = field.clone()
        x = x[..., :-self.field_dim].contiguous()  # No field feature in MLP
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x, self_rpos, field, field_out


class MatchAttentionBlock(nn.Module):
    r"""MatchAttention block with multiple match-attention layers."""

    def __init__(self, args, dim, win_r=2, num_layer=6, num_head=8, head_dim=32,
                 mlp=ConvGLU, mlp_ratio=2, field_dim=2,
                 norm_layer=LayerNormWithoutBias, drop=0., dp_rates=[0.]):
        super().__init__()
        self.num_head = num_head
        self.layers = nn.ModuleList()
        for i in range(num_layer):
            # dp_rates supplies one stochastic-depth rate per layer.
            layer = MatchAttentionLayer(
                args, dim, win_r=win_r, num_head=num_head, head_dim=head_dim,
                mlp=mlp, mlp_ratio=mlp_ratio, field_dim=field_dim,
                norm_layer=norm_layer, drop=drop, drop_path=dp_rates[i])
            self.layers.append(layer)

    def forward(self, x, self_rpos, field, stereo=True):
        """Run all layers, collecting per-layer field snapshots.

        Returns:
            ``(x, self_rpos, field, fields)`` where ``fields`` is the list of
            per-layer ``field_out`` dicts and ``self_rpos`` is averaged back
            over heads to [B, H, W, 2].
        """
        fields = []
        B, H, W, C = x.shape
        # Broadcast the shared rpos to all heads: [B,H,W,2] -> [B,H,W,h*2]
        self_rpos = self_rpos.repeat(1, 1, 1, self.num_head)
        for layer in self.layers:
            x, self_rpos, field, field_out = layer(x, self_rpos, field, stereo)
            fields.append(field_out)
        # Average the per-head rpos back down to a single 2-vector per pixel.
        self_rpos = self_rpos.view(B, H, W, self.num_head, 2).mean(dim=-2, keepdim=False)
        return x, self_rpos, field, fields