import torch
import torch.nn as nn
from timm.models.layers import DropPath
from models.convformer import LayerNormWithoutBias
from models.common import ConvGLU
from models.mat_pytorch_impl import (compute_bilinear_weights, compute_match_attention,
                                     compute_bilinear_softmax, attention_aggregate)
from models.match_former_ops import MF_FusedForwardOps
from utils.utils import bilinear_sample_by_offset, init_coords


class MatchAttention(nn.Module):
    r"""MatchAttention: attention over a local window centered at a per-head
    matching position (``max_offset``), with bilinear weights handling the
    sub-pixel part of the position.
    """

    def __init__(self, args, dim, win_r=[1, 1], num_head=8, head_dim=None, qkv_bias=False,
                 attn_drop=0., proj_drop=0., proj_bias=False, cross=False, noc_embed=False, **kwargs):
        super().__init__()
        self.num_head = num_head
        self.cross = cross
        self.noc_embed = noc_embed if not cross else False  # noc embedding applies to self-attention only
        self.head_dim = dim // num_head if head_dim is None else head_dim
        self.scale = self.head_dim ** -0.5
        self.attention_dim = self.num_head * self.head_dim
        self.win_r = win_r
        # (2*r + 1) window samples per axis, plus 1 for the bilinear neighbor
        self.attn_num = (2*win_r[0] + 2) * (2*win_r[1] + 2)
        embed_dim = dim + 1 if self.noc_embed else dim  # '+1' for the noc_mask channel
        self.q = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.k = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.v = nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        if self.cross:
            self.g = nn.Sequential(nn.Linear(embed_dim, self.attention_dim, bias=qkv_bias), nn.SiLU())
            # cross attention also feeds the attention weights into the output projection
            self.proj = nn.Linear(self.attention_dim + self.num_head*self.attn_num, dim, bias=proj_bias)
        else:
            self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_pytorch = (args.mat_impl == 'pytorch')
        self.mf_fused = MF_FusedForwardOps()

    def clamp_max_offset(self, max_offset, H, W):
        max_offset_x, max_offset_y = max_offset.chunk(2, dim=-1)  # avoid an in-place op on the input
        # Keep the full sampling window (radius win_r plus the bilinear
        # neighbor) inside the image. Tensor bounds are used instead of
        # Python scalars for ONNX export support; the scalar equivalent is
        #   max_offset_x.clamp(min=self.win_r[0], max=W-1-self.win_r[0]-1e-3)
        #   max_offset_y.clamp(min=self.win_r[1], max=H-1-self.win_r[1]-1e-3)
        min_x = torch.tensor(self.win_r[0], dtype=max_offset.dtype, device=max_offset.device)
        max_x = torch.tensor(W - 1 - self.win_r[0] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)
        min_y = torch.tensor(self.win_r[1], dtype=max_offset.dtype, device=max_offset.device)
        max_y = torch.tensor(H - 1 - self.win_r[1] - 1e-3, dtype=max_offset.dtype, device=max_offset.device)
        max_offset_x = torch.clamp(max_offset_x, min=min_x, max=max_x)
        max_offset_y = torch.clamp(max_offset_y, min=min_y, max=max_y)
        return torch.cat((max_offset_x, max_offset_y), dim=-1).contiguous()

    def forward(self, x, max_offset, noc_mask=None):  # max_offset: [B, N, num_head, 2]
        B, H, W, _ = x.shape
        N = H*W
        # the sampling window must fit inside the feature map
        assert (2*self.win_r[1] + 2 <= H) and (2*self.win_r[0] + 2 <= W)
        x = x.view(B, N, -1).contiguous()
        if self.cross:
            # The batch stacks both views; swap the halves so that each
            # sample attends from its own view (ref) into the other (tgt).
            ref_, tgt_ = x.chunk(2, dim=0)  # split along the batch dimension
            ref = torch.cat((ref_, tgt_), dim=0)  # original order
            tgt = torch.cat((tgt_, ref_), dim=0)  # swapped order
            g = self.g(ref)
        else:  # self-attention
            if self.noc_embed:
                x = torch.cat((x, noc_mask.view(B, N, -1)), dim=-1).contiguous()
            ref, tgt = x, x
        q, k, v = self.q(ref), self.k(tgt), self.v(tgt)
        # parameter-free matching ops
        max_offset = self.clamp_max_offset(max_offset, H, W)
        if self.use_pytorch:
            m_id = torch.floor(max_offset).to(torch.int32)  # integer part of the position, [B, N, h, 2]
            bilinear_weight = compute_bilinear_weights(max_offset)
            attn, indices_gather = compute_match_attention(q.view(B, N, self.num_head, -1),
                                                           k.view(B, N, self.num_head, -1),
                                                           m_id, self.win_r, H, W)
            attn = attn * self.scale
            attn = compute_bilinear_softmax(attn, bilinear_weight, self.win_r)
            attn = self.attn_drop(attn)
            x = attention_aggregate(v.view(B, N, self.num_head, -1), attn, indices_gather, self.win_r)
        else:
            x, attn = self.mf_fused(max_offset, q, k, v, H, W, self.win_r, self.attn_num,
                                    attn_type='l1_norm', scale=self.scale)
        if self.cross:
            x = g * x  # gate the aggregated values
            attn = attn.view(B, N, -1).contiguous()
            x = torch.cat((x, attn), dim=-1).contiguous()
        x = self.proj(x)
        x = self.proj_drop(x)
        return x.view(B, H, W, -1).contiguous()
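
# Usage sketch for MatchAttention (illustrative only; `args` is a hypothetical
# stand-in for the real config object, with mat_impl selecting the
# pure-PyTorch reference path):
#
#   from types import SimpleNamespace
#   attn = MatchAttention(SimpleNamespace(mat_impl='pytorch'),
#                         dim=64, win_r=[2, 2], num_head=8)
#   x = torch.randn(2, 32, 64, 64)                   # [B, H, W, C]
#   max_offset = torch.rand(2, 32*64, 8, 2) * 16     # [B, N, h, 2], (x, y)
#   out = attn(x, max_offset)                        # [B, H, W, dim]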


class MatchAttentionLayer(nn.Module):
    r"""MatchAttention layer with interleaved self-MatchAttention,
    cross-MatchAttention, and a ConvGLU MLP.
    """

    def __init__(self, args, dim, win_r,
                 num_head=8, head_dim=32, mlp=ConvGLU, mlp_ratio=2, field_dim=2,
                 norm_layer=nn.LayerNorm, drop=0., drop_path=0.):
        super().__init__()
        self.num_head = num_head
        self.field_dim = field_dim
        # Self-attention runs on [features | scaled field | self_rpos] and is
        # conditioned on the (non-)occlusion mask via noc_embed.
        self.match_attention_self = MatchAttention(args, dim + self.field_dim + self.num_head*2,
                                                   [win_r, win_r], num_head=num_head,
                                                   head_dim=head_dim, noc_embed=True)
        self.norm0 = norm_layer(dim + self.field_dim + self.num_head*2)
        self.match_attention_cross = MatchAttention(args, dim + self.field_dim, [win_r, win_r],
                                                    num_head=num_head, head_dim=head_dim, cross=True)
        self.norm1 = norm_layer(dim + self.field_dim)
        self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
        self.norm2 = norm_layer(dim)
        self.field_scale = nn.Parameter(0.1*torch.ones(1, 1, 1, 2))
        self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def consistency_mask(self, field, A=2):
        """Mask non-occluded pixels via forward-backward consistency."""
        offset = field + init_coords(field)  # absolute matching coordinates, [B, H, W, 2]
        field_ref_, field_tgt_ = field.chunk(2, dim=0)
        field_ref = torch.cat((field_ref_, field_tgt_), dim=0)  # original order
        field_tgt = torch.cat((field_tgt_, field_ref_), dim=0)  # swapped order
        field_tgt_to_ref = bilinear_sample_by_offset(field_tgt.permute(0, 3, 1, 2).contiguous(),
                                                     offset).permute(0, 2, 3, 1).contiguous()
        # ref and tgt fields have opposite signs for consistent matches,
        # so their sum should be close to zero
        field_diff = torch.abs(field_ref + field_tgt_to_ref).sum(dim=-1, keepdim=True)
        noc_mask = (field_diff < A).to(field_diff.dtype)
        return noc_mask
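
    # The mask above is the standard forward-backward (left-right) check:
    # a pixel p in the reference view counts as non-occluded when
    #     | F_ref(p) + F_tgt(p + F_ref(p)) |_1 < A,
    # since the two fields point in opposite directions for a correct match.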

    def forward(self, x, self_rpos, field, stereo=True):  # self_rpos: [B, H, W, h*2], field: [B, H, W, 2]
        field_out = {}
        B, H, W, C = x.shape
        noc_mask = self.consistency_mask(field.detach())
        # pack [features | scaled field | self_rpos] into one residual stream
        x = torch.cat((x, field*self.field_scale.to(field.dtype), self_rpos), dim=-1).contiguous()
        coords_0 = init_coords(field).repeat(1, 1, 1, self.num_head)
        self_offset = self_rpos + coords_0
        self_offset = self_offset.view(B, H*W, self.num_head, 2).contiguous()
        x = x + self.drop_path0(self.match_attention_self(self.norm0(x), self_offset, noc_mask))
        self_rpos = x[..., -(self.num_head*2):].contiguous()  # [B, H, W, h*2]
        x = x[..., :-(self.num_head*2)].contiguous()
        if stereo: x[..., -1] = 0  # stereo matching: zero the vertical field component
        field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
        field_out['self'] = field.clone()
        offset = field.repeat(1, 1, 1, self.num_head).contiguous() + coords_0  # [B, H, W, h*2]
        offset = offset.view(B, H*W, self.num_head, 2).contiguous()
        x = x + self.drop_path1(self.match_attention_cross(self.norm1(x), offset))
        if stereo: x[..., -1] = 0  # stereo matching: zero the vertical field component
        field = x[..., -self.field_dim:].contiguous() / self.field_scale.to(field.dtype)
        field_out['cross'] = field.clone()
        x = x[..., :-self.field_dim].contiguous()  # no field features in the MLP
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x, self_rpos, field, field_out
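
# Channel layout of the residual stream inside MatchAttentionLayer.forward:
#   [ features (dim) | field * field_scale (field_dim) | self_rpos (num_head*2) ]
# Self-attention updates all slices jointly; the field and self_rpos slices
# are then peeled off and reused as attention offsets for the next stage.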


class MatchAttentionBlock(nn.Module):
    r"""MatchAttention block with multiple MatchAttention layers.
    """

    def __init__(self, args, dim, win_r=2,
                 num_layer=6, num_head=8, head_dim=32,
                 mlp=ConvGLU, mlp_ratio=2, field_dim=2,
                 norm_layer=LayerNormWithoutBias,
                 drop=0., dp_rates=[0.]):
        super().__init__()
        self.num_head = num_head
        self.layers = nn.ModuleList()
        # dp_rates must provide one drop-path rate per layer
        for i in range(num_layer):
            layer = MatchAttentionLayer(args, dim, win_r=win_r, num_head=num_head, head_dim=head_dim,
                                        mlp=mlp, mlp_ratio=mlp_ratio, field_dim=field_dim,
                                        norm_layer=norm_layer, drop=drop, drop_path=dp_rates[i])
            self.layers.append(layer)

    def forward(self, x, self_rpos, field, stereo=True):
        fields = []
        B, H, W, C = x.shape
        self_rpos = self_rpos.repeat(1, 1, 1, self.num_head)  # [B, H, W, 2] -> [B, H, W, h*2]
        for layer in self.layers:
            x, self_rpos, field, field_out = layer(x, self_rpos, field, stereo)
            fields.append(field_out)
        # average the per-head relative positions back to a single [B, H, W, 2] map
        self_rpos = self_rpos.view(B, H, W, self.num_head, 2).mean(dim=-2, keepdim=False)
        return x, self_rpos, field, fields
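

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the model).
    # `args` is a hypothetical stand-in for the real config object; only the
    # mat_impl attribute is read here, selecting the pure-PyTorch path.
    from types import SimpleNamespace

    args = SimpleNamespace(mat_impl='pytorch')
    block = MatchAttentionBlock(args, dim=64, win_r=2, num_layer=2,
                                num_head=8, head_dim=32, dp_rates=[0., 0.])
    B, H, W = 2, 32, 64                    # batch stacks the [ref, tgt] views
    x = torch.randn(B, H, W, 64)
    self_rpos = torch.zeros(B, H, W, 2)    # initial per-pixel relative position
    field = torch.zeros(B, H, W, 2)        # initial disparity/flow field
    x, self_rpos, field, fields = block(x, self_rpos, field, stereo=True)
    print(x.shape, self_rpos.shape, field.shape, len(fields))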