import math
from typing import Dict, List

import torch
import torch.nn as nn

from detect_tools.upn import ARCHITECTURES, build_decoder, build_encoder
from detect_tools.upn.models.utils import (gen_encoder_output_proposals,
                                           inverse_sigmoid)
from detect_tools.upn.ops.modules import MSDeformAttn


@ARCHITECTURES.register_module()
class DeformableTransformer(nn.Module):
    """Implementation of Deformable DETR.

    Args:
        encoder_cfg (Dict): Config for the TransformerEncoder.
        decoder_cfg (Dict): Config for the TransformerDecoder.
        mask_decoder_cfg (Dict, optional): Config for a standalone mask
            decoder. Default: None (no mask decoder is built).
        num_queries (int): Number of queries (for the matching part).
            Default: 900.
        d_model (int): Dimension of the model. Default: 256.
        num_feature_levels (int): Number of feature levels. Default: 4.
        binary_query_selection (bool): Whether to use binary query selection.
            When True, a Linear(d_model, 1) head scores encoder tokens to
            select the top-k proposals. Otherwise the class embedding head
            (``enc_out_class_embed``, e.g. ContrastiveAssign) is used.
            Default: False.
        learnable_tgt_init (bool): Whether the target (query content) init is
            learnable. This implementation requires True; the alternative
            (top-k encoder features as target init) is not supported here.
            Default: True.
        random_refpoints_xy (bool): Randomly initialize (and freeze) the xy
            part of the reference points. Only used when ``two_stage_type``
            is 'no'. Default: False.
        two_stage_type (str): Type of two stage. Options: 'no', 'standard'.
            Default: 'standard'.
        two_stage_learn_wh (bool): Whether to learn the width and height of
            anchor boxes. Default: False.
        two_stage_keep_all_tokens (bool): If False, the returned hs_enc,
            ref_enc, init_box_proposal contain only the top-k proposals;
            otherwise all encoder proposals are returned. Default: False.
        two_stage_bbox_embed_share (bool): Whether to share the bbox
            embedding between the two stages. Default: False.
        two_stage_class_embed_share (bool): Whether to share the class
            embedding between the two stages. Default: False.
        rm_self_attn_layers (List[int], optional): Indices of decoder layers
            whose self-attention is removed. Default: None.
        rm_detach (List[str], optional): Which decoder tensors to keep
            attached (subset of {'enc_ref', 'enc_tgt', 'dec'}).
            Default: None.
        with_encoder_out (bool): Build the linear + layernorm projection
            applied to encoder output proposals in the standard two-stage
            path. Default: True.
    """

    def __init__(
        self,
        encoder_cfg: Dict,
        decoder_cfg: Dict,
        mask_decoder_cfg: Dict = None,
        num_queries: int = 900,
        d_model: int = 256,
        num_feature_levels: int = 4,
        binary_query_selection: bool = False,
        # init query (target)
        learnable_tgt_init=True,
        random_refpoints_xy=False,
        # for two stage
        two_stage_type: str = "standard",
        two_stage_learn_wh: bool = False,
        two_stage_keep_all_tokens: bool = False,
        two_stage_bbox_embed_share: bool = False,
        two_stage_class_embed_share: bool = False,
        # evo of #anchors
        rm_self_attn_layers: List[int] = None,
        # for detach
        rm_detach: bool = None,
        with_encoder_out: bool = True,
    ) -> None:
        super().__init__()
        self.binary_query_selection = binary_query_selection
        self.num_queries = num_queries
        self.num_feature_levels = num_feature_levels
        self.rm_self_attn_layers = rm_self_attn_layers
        self.d_model = d_model
        self.two_stage_bbox_embed_share = two_stage_bbox_embed_share
        self.two_stage_class_embed_share = two_stage_class_embed_share

        if self.binary_query_selection:
            # Single-logit scorer used for top-k proposal selection.
            self.binary_query_selection_layer = nn.Linear(d_model, 1)

        # build encoder
        self.encoder = build_encoder(encoder_cfg)
        # build decoder
        self.decoder = build_decoder(decoder_cfg)
        self.num_decoder_layers = self.decoder.num_layers

        # build sole mask decoder
        if mask_decoder_cfg is not None:
            self.mask_decoder = build_decoder(mask_decoder_cfg)
        else:
            self.mask_decoder = None

        # level embedding (only needed when there is more than one level)
        if num_feature_levels > 1:
            self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        # learnable target embedding
        self.learnable_tgt_init = learnable_tgt_init
        assert learnable_tgt_init, (
            "only learnable target init is supported (learnable_tgt_init must be True)"
        )
        self.tgt_embed = nn.Embedding(num_queries, d_model)
        nn.init.normal_(self.tgt_embed.weight.data)

        # for two stage
        self.two_stage_type = two_stage_type
        self.two_stage_learn_wh = two_stage_learn_wh
        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
        assert two_stage_type in [
            "no",
            "standard",
        ], "unknown param {} of two_stage_type".format(two_stage_type)
        self.with_encoder_out = with_encoder_out
        if two_stage_type == "standard":
            # anchor selection at the output of encoder
            if with_encoder_out:
                self.enc_output = nn.Linear(d_model, d_model)
                self.enc_output_norm = nn.LayerNorm(d_model)

            if two_stage_learn_wh:
                # Learned log-odds of the initial anchor width/height.
                self.two_stage_wh_embedding = nn.Embedding(1, 2)
            else:
                self.two_stage_wh_embedding = None
        elif two_stage_type == "no":
            # init self.refpoint_embed
            self.init_ref_points(num_queries, random_refpoints_xy)

        # These heads are initialized outside of the model.
        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None

        self._reset_parameters()

        # remove some self_attn_layers
        if rm_self_attn_layers is not None:
            print(
                "Removing the self-attn in {} decoder layers".format(
                    rm_self_attn_layers
                )
            )
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()

        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            assert any(i in ["enc_ref", "enc_tgt", "dec"] for i in rm_detach)
            self.decoder.rm_detach = rm_detach

    def _reset_parameters(self):
        """Xavier-init all weight matrices and re-init special modules."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)

        if self.two_stage_learn_wh:
            # log(0.05 / 0.95): initial wh of ~0.05 after sigmoid.
            nn.init.constant_(
                self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05))
            )

    def init_ref_points(self, num_queries: int, random_refpoints_xy: bool = False):
        """Initialize learnable reference points for each query.

        Args:
            num_queries (int): number of queries
            random_refpoints_xy (bool, optional): whether to init the
                refpoints randomly (and freeze their xy). Defaults to False.
        """
        self.refpoint_embed = nn.Embedding(num_queries, 4)

        if random_refpoints_xy:
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            # Stored unsigmoided so downstream sigmoid recovers the uniform xy.
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
                self.refpoint_embed.weight.data[:, :2]
            )
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

    def get_valid_ratio(self, mask):
        """Compute the (w, h) fraction of non-padded area per sample.

        Args:
            mask (torch.Tensor): padding mask of shape (bs, H, W), where
                True marks padded positions.

        Returns:
            torch.Tensor: valid ratios of shape (bs, 2) as (ratio_w, ratio_h).
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(
        self,
        src_flatten: torch.Tensor,
        lvl_pos_embed_flatten: torch.Tensor,
        level_start_index: List[int],
        spatial_shapes: List[torch.Tensor],
        valid_ratios: List[torch.Tensor],
        mask_flatten: torch.Tensor,
        prompt_type: str,
    ) -> List[torch.Tensor]:
        """Forward function.

        Encodes the flattened multi-level features, builds a prompt-based
        reference dict, selects top-k two-stage proposals, and decodes.

        Returns:
            tuple: (hs, references, ref_dict) — decoder hidden states,
            reference points per decoder layer, and the reference dict.
        """
        memory = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
        )

        batch_size = src_flatten.shape[0]
        crop_region_features = torch.zeros(batch_size, 1, self.d_model).to(
            memory.device
        )
        # NOTE(review): self.fine_grained_prompt / self.coarse_grained_prompt
        # are not created in this file — presumably nn.Embedding modules
        # attached externally. Confirm they exist before using these prompt
        # types, otherwise this raises AttributeError.
        if prompt_type == "fine_grained_prompt":
            crop_region_features = (
                self.fine_grained_prompt.weight[0]
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )
        elif prompt_type == "coarse_grained_prompt":
            crop_region_features = (
                self.coarse_grained_prompt.weight[0]
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )
        pad_mask = torch.ones(batch_size, 1).to(crop_region_features.device).bool()
        self_attn_mask = torch.ones(batch_size, 1, 1).to(crop_region_features.device)
        ref_dict = dict(
            encoded_ref_feature=crop_region_features,
            pad_mask=pad_mask,
            self_attn_mask=self_attn_mask,
            prompt_type="universal_prompt",
        )

        (
            refpoint_embed,
            tgt,
            init_box_proposal,
        ) = self.get_two_stage_proposal(memory, mask_flatten, spatial_shapes, ref_dict)

        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            tgt_key_padding_mask=None,
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=None,
            # we ~ the mask . False means use the token; True means pad the token
        )

        return (
            hs,
            references,
            ref_dict,
        )

    def get_two_stage_proposal(
        self,
        memory: torch.Tensor,
        mask_flatten: torch.Tensor,
        spatial_shapes: List[torch.Tensor],
        ref_dict: Dict,
    ) -> List[torch.Tensor]:
        """Two stage proposal generation for decoder.

        Args:
            memory (torch.Tensor): Image encoded feature. [bs, n, 256]
            mask_flatten (torch.Tensor): Flattened mask. [bs, n]
            spatial_shapes (List[torch.Tensor]): Spatial shapes of each
                feature map. [bs, num_levels, 2]
            ref_dict (Dict): A dict containing all kinds of reference image
                related features.

        Returns:
            tuple: (refpoint_embed, tgt, init_box_proposal) — detached top-k
            unsigmoided boxes, learnable target embeddings, and the sigmoided
            initial box proposals.
        """
        bs = memory.shape[0]
        input_hw = None
        output_memory, output_proposals = gen_encoder_output_proposals(
            memory, mask_flatten, spatial_shapes, input_hw
        )
        output_memory = self.enc_output_norm(self.enc_output(output_memory))

        if self.binary_query_selection:
            topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1)
        else:
            if ref_dict is not None:
                # Not a linear prediction layer but a contrastive similarity,
                # shape [B, len_image, len_text].
                enc_outputs_class_unselected = self.enc_out_class_embed(
                    output_memory, ref_dict
                )
            else:
                enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
            # shape [B, len_image]
            topk_logits = enc_outputs_class_unselected.max(-1)[0]

        enc_outputs_coord_unselected = (
            self.enc_out_bbox_embed(output_memory) + output_proposals
        )  # (bs, \sum{hw}, 4) unsigmoid
        topk = self.num_queries

        # Fail with a meaningful error instead of a bare except around topk.
        if topk_logits.shape[1] < topk:
            raise ValueError(
                "Cannot select top-{} proposals from logits of shape {}".format(
                    topk, tuple(topk_logits.shape)
                )
            )
        topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq

        # gather boxes
        refpoint_embed_undetach = torch.gather(
            enc_outputs_coord_unselected,
            1,
            topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
        )  # unsigmoid
        refpoint_embed_ = refpoint_embed_undetach.detach()
        init_box_proposal = torch.gather(
            output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
        ).sigmoid()  # sigmoid

        # gather tgt
        tgt_undetach = torch.gather(
            output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
        )
        tgt_ = (
            self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
        )  # nq, bs, d_model

        refpoint_embed, tgt = refpoint_embed_, tgt_

        return (
            refpoint_embed,
            tgt,
            init_box_proposal,
        )