|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.utils.checkpoint as checkpoint
|
|
|
| from ..util.misc import _get_activation_fn, _get_clones, inverse_sigmoid
|
|
|
|
|
| class GlobalCrossAttention(nn.Module):
|
| def __init__(
|
| self,
|
| dim,
|
| num_heads,
|
| qkv_bias=True,
|
| qk_scale=None,
|
| attn_drop=0.0,
|
| proj_drop=0.0,
|
| ):
|
| super().__init__()
|
| self.dim = dim
|
| self.num_heads = num_heads
|
| head_dim = dim // num_heads
|
| self.scale = qk_scale or head_dim**-0.5
|
|
|
| self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
| self.k = nn.Linear(dim, dim, bias=qkv_bias)
|
| self.v = nn.Linear(dim, dim, bias=qkv_bias)
|
| self.attn_drop = nn.Dropout(attn_drop)
|
| self.proj = nn.Linear(dim, dim)
|
| self.proj_drop = nn.Dropout(proj_drop)
|
| self.softmax = nn.Softmax(dim=-1)
|
|
|
| def forward(
|
| self,
|
| query,
|
| k_input_flatten,
|
| v_input_flatten,
|
| input_padding_mask=None,
|
| ):
|
| B_, N, C = k_input_flatten.shape
|
| k = self.k(k_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| v = self.v(v_input_flatten).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| B_, N, C = query.shape
|
| q = self.q(query).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
|
|
| attn_mask = None
|
| if input_padding_mask is not None:
|
| attn_mask = input_padding_mask[:, None, None] * -100
|
| attn_mask = attn_mask.contiguous()
|
|
|
| x = torch.nn.functional.scaled_dot_product_attention(
|
| query=q,
|
| key=k,
|
| value=v,
|
| attn_mask=attn_mask,
|
| dropout_p=self.attn_drop.p if self.training else 0,
|
| scale=self.scale,
|
| )
|
| x = x.transpose(1, 2).reshape(B_, N, C)
|
| x = self.proj(x)
|
| x = self.proj_drop(x)
|
| return x
|
|
|
|
|
class GlobalDecoderLayer(nn.Module):
    """One decoder layer: self-attention, global cross-attention, then an FFN.

    Submodule naming (kept for state-dict compatibility): ``norm2``/``dropout2``
    belong to the self-attention sublayer, ``norm1``/``dropout1`` to the
    cross-attention sublayer, and ``norm3``/``dropout3``/``dropout4`` to the
    feed-forward sublayer.

    Args:
        d_model: Feature width of queries and memory.
        d_ffn: Hidden width of the feed-forward network.
        dropout: Dropout used in self-attention and on every residual branch.
            The cross-attention module itself is built with its default
            (zero) dropout.
        activation: Name passed to ``_get_activation_fn`` for the FFN.
        n_heads: Number of attention heads for both attention sublayers.
        norm_type: ``"pre_norm"`` (normalize before each sublayer) or
            ``"post_norm"`` (normalize after each residual add).
    """

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_heads=8,
        norm_type="post_norm",
    ):
        super().__init__()

        self.norm_type = norm_type

        # Cross-attention from queries to the flattened memory.
        self.cross_attn = GlobalCrossAttention(d_model, n_heads)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Self-attention among queries. nn.MultiheadAttention defaults to
        # sequence-first inputs, hence the transposes in the forward passes.
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_pre(
        self,
        tgt,
        query_pos,
        src,
        src_pos_embed,
        src_padding_mask=None,
        self_attn_mask=None,
    ):
        """Pre-norm variant: each sublayer sees a normalized copy of ``tgt``.

        ``tgt`` is batch-first ``(B, N_q, C)`` (it is transposed to
        sequence-first only around the self-attention call).
        """
        # Self-attention; positional embedding is added to queries/keys only.
        tgt2 = self.norm2(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q.transpose(0, 1), k.transpose(0, 1), tgt2.transpose(0, 1), attn_mask=self_attn_mask, need_weights=False
        )[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)

        # Cross-attention; keys carry src_pos_embed, values are raw src.
        tgt2 = self.norm1(tgt)
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt2, query_pos),
            self.with_pos_embed(src, src_pos_embed),
            src,
            src_padding_mask,
        )
        tgt = tgt + self.dropout1(tgt2)

        # Feed-forward.
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout4(tgt2)

        return tgt

    def forward_post(
        self,
        tgt,
        query_pos,
        src,
        src_pos_embed,
        src_padding_mask=None,
        self_attn_mask=None,
    ):
        """Post-norm variant: each residual sum is normalized afterwards.

        ``tgt`` is batch-first ``(B, N_q, C)`` (it is transposed to
        sequence-first only around the self-attention call).
        """
        # Self-attention; positional embedding is added to queries/keys only.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), attn_mask=self_attn_mask, need_weights=False
        )[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # Cross-attention; keys carry src_pos_embed, values are raw src.
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, query_pos),
            self.with_pos_embed(src, src_pos_embed),
            src,
            src_padding_mask,
        )
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # Feed-forward.
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)

        return tgt

    def forward(
        self,
        tgt,
        query_pos,
        src,
        src_pos_embed,
        src_padding_mask=None,
        self_attn_mask=None,
    ):
        """Dispatch to ``forward_pre`` or ``forward_post`` based on ``norm_type``.

        Raises:
            ValueError: if ``norm_type`` is neither ``"pre_norm"`` nor
                ``"post_norm"`` (previously this silently returned ``None``).
        """
        if self.norm_type == "pre_norm":
            return self.forward_pre(tgt, query_pos, src, src_pos_embed, src_padding_mask, self_attn_mask)
        if self.norm_type == "post_norm":
            return self.forward_post(tgt, query_pos, src, src_pos_embed, src_padding_mask, self_attn_mask)
        raise ValueError(f"Unsupported norm_type {self.norm_type!r}; expected 'pre_norm' or 'post_norm'.")
|
|
|
|
|
class GlobalDecoder(nn.Module):
    """Stack of cloned decoder layers with optional iterative box refinement.

    ``bbox_embed`` and ``class_embed`` start as ``None``. When ``bbox_embed``
    is assigned externally, it is indexed per layer (``bbox_embed[lid]``) to
    refine the reference points after each layer.

    Args:
        decoder_layer: Prototype layer, replicated ``num_layers`` times via
            ``_get_clones``.
        num_layers: Number of decoder layers.
        return_intermediate: If True, ``forward`` returns the stacked per-layer
            outputs and reference points instead of only the last ones.
        look_forward_twice: If True, the intermediate reference points are the
            refined points before ``.detach()``; otherwise the detached ones.
        use_checkpoint: Run each layer under ``torch.utils.checkpoint``.
        d_model: Feature width (used for the final LayerNorm).
        norm_type: With ``"pre_norm"`` a final LayerNorm is applied to every
            layer's output; any other value disables it.
    """

    def __init__(
        self,
        decoder_layer,
        num_layers,
        return_intermediate=False,
        look_forward_twice=False,
        use_checkpoint=False,
        d_model=256,
        norm_type="post_norm",
    ):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        self.look_forward_twice = look_forward_twice
        self.use_checkpoint = use_checkpoint

        # Assigned by the owning model when box refinement / classification
        # heads are attached.
        self.bbox_embed = None
        self.class_embed = None

        self.norm_type = norm_type
        if self.norm_type == "pre_norm":
            # Pre-norm layers leave their output un-normalized, so normalize it
            # before it is used by the heads or returned.
            self.final_layer_norm = nn.LayerNorm(d_model)
        else:
            self.final_layer_norm = None

    def _reset_parameters(self):
        """Re-initialize every Linear and LayerNorm reachable from this module.

        Linear weights are drawn from a truncated normal (std 0.02) and biases
        are zeroed; LayerNorm affine parameters are reset to weight=1, bias=0.
        This method is not called from ``__init__``.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                # LayerNorm(elementwise_affine=False) has weight/bias = None.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def forward(
        self,
        tgt,
        reference_points,
        src,
        src_pos_embed,
        src_spatial_shapes,
        src_level_start_index,
        src_valid_ratios,
        query_pos=None,
        src_padding_mask=None,
        self_attn_mask=None,
        max_shape=None,
    ):
        """Run all decoder layers, optionally refining reference points.

        ``src_spatial_shapes``, ``src_level_start_index``, ``src_valid_ratios``
        and ``max_shape`` are accepted but not used by this implementation.

        Args:
            tgt: Initial query features passed to the first layer.
            reference_points: Reference points whose last dim is 4 or (only
                when ``bbox_embed`` is set) 2; presumably in sigmoid space,
                since ``inverse_sigmoid`` is applied before refinement.
            src, src_pos_embed, query_pos, src_padding_mask, self_attn_mask:
                Forwarded unchanged to every layer.

        Returns:
            If ``return_intermediate``: ``(stacked outputs, stacked reference
            points)`` with a leading layer dimension. Otherwise the last
            (optionally normalized) output and the final reference points.
        """
        output = tgt

        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            if self.use_checkpoint:
                output = checkpoint.checkpoint(
                    layer,
                    output,
                    query_pos,
                    src,
                    src_pos_embed,
                    src_padding_mask,
                    self_attn_mask,
                )
            else:
                output = layer(
                    output,
                    query_pos,
                    src,
                    src_pos_embed,
                    src_padding_mask,
                    self_attn_mask,
                )

            # The next layer consumes the un-normalized ``output``; the
            # normalized copy feeds the box head and the returned results.
            if self.final_layer_norm is not None:
                output_after_norm = self.final_layer_norm(output)
            else:
                output_after_norm = output

            # Without a box head the points are not refined. Defaulting here
            # keeps the look_forward_twice branch below from hitting an
            # undefined name.
            new_reference_points = reference_points
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[lid](output_after_norm)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    # Aliases ``tmp`` and writes into it in place; the
                    # right-hand side is evaluated before the write.
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                # Stop gradients flowing through the points into later layers.
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output_after_norm)
                intermediate_reference_points.append(
                    new_reference_points if self.look_forward_twice else reference_points
                )

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)

        return output_after_norm, reference_points
|
|
|
|
|
def build_global_ape_decoder(args):
    """Construct a ``GlobalDecoder`` from a configuration object.

    Reads ``hidden_dim``, ``dim_feedforward``, ``dropout``, ``nheads``,
    ``norm_type``, ``dec_layers``, ``look_forward_twice`` and
    ``decoder_use_checkpoint`` from ``args``. The layer activation is fixed to
    ReLU, and the decoder always returns per-layer intermediate results.
    """
    hidden_dim = args.hidden_dim
    norm_type = args.norm_type

    # Prototype layer; GlobalDecoder replicates it via ``_get_clones``.
    prototype_layer = GlobalDecoderLayer(
        d_model=hidden_dim,
        d_ffn=args.dim_feedforward,
        dropout=args.dropout,
        activation="relu",
        n_heads=args.nheads,
        norm_type=norm_type,
    )
    return GlobalDecoder(
        prototype_layer,
        num_layers=args.dec_layers,
        return_intermediate=True,
        look_forward_twice=args.look_forward_twice,
        use_checkpoint=args.decoder_use_checkpoint,
        d_model=hidden_dim,
        norm_type=norm_type,
    )
|
|
|