from typing import Dict

import torch
import torch.nn as nn

from detect_tools.upn import DECODERS, build_decoder
from detect_tools.upn.models.module import MLP
from detect_tools.upn.models.utils import (gen_sineembed_for_position,
                                            get_activation_fn, get_clones,
                                            inverse_sigmoid)
from detect_tools.upn.ops.modules import MSDeformAttn


@DECODERS.register_module()
class DeformableTransformerDecoderLayer(nn.Module):
    """Deformable Transformer Decoder Layer.

    This is a modified version of the decoder layer in Grounding DINO: the
    cross-attention to the text feature is removed, so the query only attends
    to the image feature. The execution order is:
    self_attn -> cross_attn to image -> ffn

    Args:
        d_model (int): The dimension of keys/values/queries in
            :class:`MultiheadAttention`.
        d_ffn (int): The dimension of the feedforward network model.
        dropout (float): Probability of an element to be zeroed.
        activation (str): Activation function in the feedforward network.
            'relu' and 'gelu' are supported.
        n_levels (int): The number of levels in Multi-Scale Deformable
            Attention.
        n_heads (int): Parallel attention heads.
        n_points (int): Number of sampling points in Multi-Scale Deformable
            Attention.
        ffn_extra_layernorm (bool): If True, add an extra layernorm after
            ffn. Not implemented.
    """

    def __init__(
        self,
        d_model: int = 256,
        d_ffn: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        n_levels: int = 4,
        n_heads: int = 8,
        n_points: int = 4,
        ffn_extra_layernorm: bool = False,
    ) -> None:
        super().__init__()

        # cross attention for visual features
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)

        # self attention for query
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)

        if ffn_extra_layernorm:
            raise NotImplementedError("ffn_extra_layernorm not implemented")
        self.norm_ext = None

        self.key_aware_proj = None

    def rm_self_attn_modules(self):
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_query_pos: torch.Tensor = None,
        tgt_reference_points: torch.Tensor = None,
        memory: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
        memory_level_start_index: torch.Tensor = None,
        memory_spatial_shapes: torch.Tensor = None,
        self_attn_mask: torch.Tensor = None,
        cross_attn_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Forward function.

        Args:
            tgt (torch.Tensor): Input target queries in shape
                (num_queries, bs, C).
            tgt_query_pos (torch.Tensor): Positional encoding of the query.
            tgt_reference_points (torch.Tensor): Reference points for the
                query in shape (num_queries, bs, n_levels, 4).
            memory (torch.Tensor): Input image feature in shape (HW, bs, C).
            memory_key_padding_mask (torch.Tensor): Mask for the image
                feature in shape (bs, HW).
            memory_level_start_index (torch.Tensor): Starting index of each
                level in memory.
            memory_spatial_shapes (torch.Tensor): Spatial shape of each level
                in memory.
            self_attn_mask (torch.Tensor): Mask used for self-attention.
            cross_attn_mask (torch.Tensor): Mask used for cross-attention.
                Must be None (not supported).

        Returns:
            torch.Tensor: Output tensor in shape (num_queries, bs, C).
        """
        assert cross_attn_mask is None

        # self attention
        if self.self_attn is not None:
            q = k = self.with_pos_embed(tgt, tgt_query_pos)
            tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

        # attend to image features
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
            tgt_reference_points.transpose(0, 1).contiguous(),
            memory.transpose(0, 1),
            memory_spatial_shapes,
            memory_level_start_index,
            memory_key_padding_mask,
        ).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)
        return tgt
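

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). A layer can be
# instantiated directly, or built from the DECODERS registry via
# `build_decoder`. The shapes below follow the transposes in `forward` above:
# queries and memory are sequence-first, the padding mask is batch-first.
#
#   layer = DeformableTransformerDecoderLayer(d_model=256, n_levels=4)
#   # tgt:                     (num_queries, bs, 256)
#   # tgt_query_pos:           (num_queries, bs, 256)
#   # tgt_reference_points:    (num_queries, bs, n_levels, 4), normalized boxes
#   # memory:                  (sum(H_l * W_l), bs, 256)
#   # memory_key_padding_mask: (bs, sum(H_l * W_l)), True at padded positions
#   out = layer(
#       tgt,
#       tgt_query_pos=query_pos,
#       tgt_reference_points=reference_points,
#       memory=memory,
#       memory_key_padding_mask=padding_mask,
#       memory_level_start_index=level_start_index,
#       memory_spatial_shapes=spatial_shapes,
#   )  # (num_queries, bs, 256)
# ---------------------------------------------------------------------------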


@DECODERS.register_module()
class UPNDecoder(nn.Module):
    """Decoder used in UPN.

    Each layer is a :class:`DeformableTransformerDecoderLayer`. The query
    attends to the image feature only. The execution order of each layer is:
    self_attn -> cross_attn to image -> ffn

    Args:
        decoder_layer_cfg (Dict): Config for the
            DeformableTransformerDecoderLayer.
        num_layers (int): Number of decoder layers.
        norm (str, optional): Type of the final normalization layer; only
            'layernorm' is supported. Defaults to 'layernorm'.
        return_intermediate (bool, optional): Whether to return the
            intermediate results of every layer. Defaults to True.
        d_model (int, optional): Dimension of the model. Defaults to 256.
        query_dim (int, optional): Dimension of the query reference points,
            either 2 (points) or 4 (boxes). Defaults to 4.
        modulate_hw_attn (bool, optional): Whether to modulate the attention
            weights by the height and width of the image feature.
            Defaults to False.
        num_feature_levels (int, optional): Number of feature levels.
            Defaults to 1.
        deformable_decoder (bool, optional): Whether to use the deformable
            decoder. Defaults to True.
        decoder_query_perturber (optional): Optional module that perturbs the
            query reference points; stored but not used in this forward pass.
            Defaults to None.
        dec_layer_number (list, optional): Per-layer configuration list; must
            have num_layers entries if given. Defaults to None.
        rm_dec_query_scale (bool, optional): If True, no query scale MLP is
            applied (False is not implemented). Defaults to True.
        dec_layer_share (bool, optional): Whether the decoder layers share
            parameters. Defaults to False.
        dec_layer_dropout_prob (list, optional): Per-layer dropout
            probabilities in [0, 1]; must have num_layers entries if given.
            Defaults to None.
        use_detached_boxes_dec_out (bool, optional): If True, record the
            detached updated reference points in the returned list.
            Defaults to False.
""" def __init__( self, decoder_layer_cfg: Dict, num_layers: int, norm: str = "layernorm", return_intermediate: bool = True, d_model: int = 256, query_dim: int = 4, modulate_hw_attn: bool = False, num_feature_levels: int = 1, deformable_decoder: bool = True, decoder_query_perturber=None, dec_layer_number=None, rm_dec_query_scale: bool = True, dec_layer_share: bool = False, dec_layer_dropout_prob=None, use_detached_boxes_dec_out: bool = False, ): super().__init__() decoder_layer = build_decoder(decoder_layer_cfg) if num_layers > 0: self.layers = get_clones( decoder_layer, num_layers, layer_share=dec_layer_share ) else: self.layers = [] self.num_layers = num_layers if norm == "layernorm": self.norm = nn.LayerNorm(d_model) self.return_intermediate = return_intermediate self.query_dim = query_dim assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim) self.num_feature_levels = num_feature_levels self.use_detached_boxes_dec_out = use_detached_boxes_dec_out self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2) self.ref_point_head_point = MLP( d_model, d_model, d_model, 2 ) # for point reference only if not deformable_decoder: self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2) else: self.query_pos_sine_scale = None if rm_dec_query_scale: self.query_scale = None else: raise NotImplementedError self.query_scale = MLP(d_model, d_model, d_model, 2) self.bbox_embed = None self.class_embed = None self.d_model = d_model self.modulate_hw_attn = modulate_hw_attn self.deformable_decoder = deformable_decoder if not deformable_decoder and modulate_hw_attn: self.ref_anchor_head = MLP(d_model, d_model, 2, 2) else: self.ref_anchor_head = None self.decoder_query_perturber = decoder_query_perturber self.box_pred_damping = None self.dec_layer_number = dec_layer_number if dec_layer_number is not None: assert isinstance(dec_layer_number, list) assert len(dec_layer_number) == num_layers self.dec_layer_dropout_prob = dec_layer_dropout_prob if dec_layer_dropout_prob is not None: assert isinstance(dec_layer_dropout_prob, list) assert len(dec_layer_dropout_prob) == num_layers for i in dec_layer_dropout_prob: assert 0.0 <= i <= 1.0 self.rm_detach = None def forward( self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: torch.Tensor = None, memory_mask: torch.Tensor = None, tgt_key_padding_mask: torch.Tensor = None, memory_key_padding_mask: torch.Tensor = None, pos: torch.Tensor = None, refpoints_unsigmoid: torch.Tensor = None, level_start_index: torch.Tensor = None, spatial_shapes: torch.Tensor = None, valid_ratios: torch.Tensor = None, memory_ref_image: torch.Tensor = None, refImg_padding_mask: torch.Tensor = None, memory_visual_prompt: torch.Tensor = None, ): """Forward function. Args: tgt (torch.Tensor): target feature, [bs, num_queries, d_model] memory (torch.Tensor): Image feature, [bs, hw, d_model] tgt_mask (torch.Tensor, optional): target mask for attention. Defaults to None. memory_mask (torch.Tensor, optional): image mask for attention. Defaults to None. tgt_key_padding_mask (torch.Tensor, optional): target mask for padding. Defaults to None. memory_key_padding_mask (torch.Tensor, optional): image mask for padding. Defaults to None. pos (torch.Tensor, optional): query position embedding refpoints_unsigmoid (torch.Tensor, optional): reference points. Defaults to None. level_start_index (torch.Tensor, optional): start index of each level. Defaults to None. spatial_shapes (torch.Tensor, optional): spatial shape of each level. Defaults to None. 
            valid_ratios (torch.Tensor, optional): Valid ratio of each level,
                [bs, num_levels, 2]. Defaults to None.
            memory_ref_image (torch.Tensor, optional): Reference image
                feature, [bs, num_ref, d_model]. Unused in this forward pass.
                Defaults to None.
            refImg_padding_mask (torch.Tensor, optional): Padding mask for
                the reference image feature. Unused in this forward pass.
                Defaults to None.
            memory_visual_prompt (torch.Tensor, optional): Visual prompt
                feature. Unused in this forward pass. Defaults to None.

        Returns:
            list | torch.Tensor: If ``return_intermediate`` is True, a pair
            ``[intermediate_outputs, reference_points]``, where each entry is
            a list of per-layer tensors transposed to [bs, num_queries, ...];
            otherwise the final normalized output of shape
            [bs, num_queries, d_model].
        """
        output = tgt
        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):
            # scale the reference points by the valid ratio of each level
            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None]
                    * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
                )  # nq, bs, nlevel, 4
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = (
                    reference_points[:, :, None] * valid_ratios[None, :]
                )
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :]
            )  # nq, bs, 256*2

            # conditional query
            if query_sine_embed.shape[-1] == 512:
                # box reference: use the box head; the point head is
                # multiplied by 0.0 so it still participates in the graph
                raw_query_pos = (
                    self.ref_point_head(query_sine_embed)
                    + self.ref_point_head_point(
                        torch.zeros_like(query_sine_embed)[:, :, :256]
                    )
                    * 0.0
                )
            else:
                # point reference: use the point head; the box head is
                # multiplied by 0.0 so it still participates in the graph
                raw_query_pos = (
                    self.ref_point_head_point(query_sine_embed)
                    + self.ref_point_head(
                        torch.zeros(
                            query_sine_embed.shape[0],
                            query_sine_embed.shape[1],
                            512,
                            device=query_sine_embed.device,
                        )
                    )
                    * 0.0
                )
            pos_scale = self.query_scale(output) if self.query_scale is not None else 1
            query_pos = pos_scale * raw_query_pos

            # main process
            output = layer(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_reference_points=reference_points_input,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
            )
            if output.isnan().any() | output.isinf().any():
                print(f"output of layer_id {layer_id} contains nan or inf")
                try:
                    num_nan = output.isnan().sum().item()
                    num_inf = output.isinf().sum().item()
                    print(f"num_nan {num_nan}, num_inf {num_inf}")
                except Exception as e:
                    print(e)

            # iterative update of the reference points
            if self.bbox_embed is not None:
                reference_before_sigmoid = inverse_sigmoid(reference_points)
                delta_unsig = self.bbox_embed[layer_id](output)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()

                if self.rm_detach and "dec" in self.rm_detach:
                    reference_points = new_reference_points
                else:
                    reference_points = new_reference_points.detach()
                if self.use_detached_boxes_dec_out:
                    ref_points.append(reference_points)
                else:
                    ref_points.append(new_reference_points)

            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.return_intermediate:
            return [
                [itm_out.transpose(0, 1) for itm_out in intermediate],
                [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
            ]
        else:
            return self.norm(output).transpose(0, 1)
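

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original
    # module). It assumes a registry-style config with a "type" key for
    # `build_decoder`; the dummy forward pass additionally needs the
    # MSDeformAttn CUDA extension, so it is only attempted when a CUDA
    # device is available. Shapes follow the query-first layout used in
    # the inline comments above.
    num_queries, bs, d_model, n_levels = 900, 2, 256, 4
    spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16], [8, 8]])
    level_start_index = torch.cat(
        [spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1]]
    )
    hw = int(spatial_shapes.prod(1).sum())

    decoder = UPNDecoder(
        decoder_layer_cfg=dict(
            type="DeformableTransformerDecoderLayer",  # assumed registry key
            d_model=d_model,
            d_ffn=1024,
            dropout=0.0,
            n_levels=n_levels,
        ),
        num_layers=6,
        d_model=d_model,
        num_feature_levels=n_levels,
    )
    print(sum(p.numel() for p in decoder.parameters()), "parameters")

    if torch.cuda.is_available():
        device = torch.device("cuda")
        decoder = decoder.to(device)
        intermediate, refs = decoder(
            tgt=torch.randn(num_queries, bs, d_model, device=device),
            memory=torch.randn(hw, bs, d_model, device=device),
            memory_key_padding_mask=torch.zeros(
                bs, hw, dtype=torch.bool, device=device
            ),
            refpoints_unsigmoid=torch.randn(num_queries, bs, 4, device=device),
            level_start_index=level_start_index.to(device),
            spatial_shapes=spatial_shapes.to(device),
            valid_ratios=torch.ones(bs, n_levels, 2, device=device),
        )
        # per-layer outputs are transposed back to (bs, num_queries, d_model)
        print(len(intermediate), "intermediate outputs,", intermediate[-1].shape)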