| |
| from typing import Dict, Optional, Tuple |
|
|
| import torch |
| from torch import Tensor, nn |
| from torch.nn.init import normal_ |
|
|
| from mmdet.registry import MODELS |
| from mmdet.structures import OptSampleList |
| from mmdet.utils import OptConfigType |
| from ..layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder, |
| DinoTransformerDecoder, SinePositionalEncoding) |
| from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention |
|
|
|
|
@MODELS.register_module()
class DINO(DeformableDETR):
    r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes
    for End-to-End Object Detection <https://arxiv.org/abs/2203.03605>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DINO>`_.

    Args:
        dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising
            query generator. The keys `num_classes`, `embed_dims` and
            `num_matching_queries` are filled in by the detector and must
            not be present in it. Although the signature defaults to
            `None`, a config is required because the denoising query
            generator is always built.
    """

    def __init__(self, *args, dn_cfg: OptConfigType = None, **kwargs) -> None:
        """Build the detector and its denoising query generator.

        Raises:
            AssertionError: If `as_two_stage` or `with_box_refine` is not
                enabled, or if `dn_cfg` already contains one of the keys
                that are set by the detector.
            TypeError: If `dn_cfg` is `None`.
        """
        super().__init__(*args, **kwargs)
        assert self.as_two_stage, 'as_two_stage must be True for DINO'
        assert self.with_box_refine, 'with_box_refine must be True for DINO'

        if dn_cfg is None:
            # `forward_decoder` and the training branch of `pre_decoder`
            # both use `self.dn_query_generator`, so it cannot be omitted.
            raise TypeError(
                '`dn_cfg` must be a dict-like config of the denoising '
                'query generator, but got None')

        # These are exactly the keys injected below; a user-provided value
        # would otherwise be overridden without notice.
        reserved_keys = ('num_classes', 'embed_dims', 'num_matching_queries')
        assert not any(key in dn_cfg for key in reserved_keys), \
            'The three keyword args `num_classes`, `embed_dims`, and ' \
            '`num_matching_queries` are set in `detector.__init__()`, ' \
            'users should not set them in `dn_cfg` config.'

        # Pass the detector-derived values as extra keyword arguments
        # instead of writing them into `dn_cfg`, so the caller's config
        # object stays untouched and can be reused to build another model.
        self.dn_query_generator = CdnQueryGenerator(
            **dn_cfg,
            num_classes=self.bbox_head.num_classes,
            embed_dims=self.embed_dims,
            num_matching_queries=self.num_queries)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        # Each config attribute is replaced by the module built from it.
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
        self.decoder = DinoTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        # NOTE: `query_embedding` only supplies the content part of the
        # matching queries. Their reference boxes are not learned here;
        # `pre_decoder` takes them from the top-k encoder proposals.
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

        # Allocated uninitialized; filled by `normal_` in `init_weights`.
        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
        self.memory_trans_norm = nn.LayerNorm(self.embed_dims)

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        # The two-argument `super` starts the lookup after `DeformableDETR`,
        # i.e. `DeformableDETR.init_weights` itself is skipped on purpose
        # and everything it would do is redone below for DINO's layers.
        super(DeformableDETR, self).init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        # Runs after the generic Xavier pass so that the attention modules'
        # own initialization takes precedence over it.
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.xavier_uniform_(self.memory_trans_fc.weight)
        nn.init.xavier_uniform_(self.query_embedding.weight)
        normal_(self.level_embed)

    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        batch_data_samples: OptSampleList = None,
    ) -> Dict:
        """Forward process of Transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.
        The difference is that the ground truth in `batch_data_samples` is
        required for the `pre_decoder` to prepare the query of DINO.
        Additionally, DINO inherits the `pre_transformer` method and the
        `forward_encoder` method of DeformableDETR. More details about the
        two methods can be found in `mmdet/detector/deformable_detr.py`.

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may contain
            `references` including the initial and intermediate references.
        """
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)

        encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

        # Unlike the other stages, `pre_decoder` also receives the data
        # samples: it builds the denoising queries from them in training.
        tmp_dec_in, head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict, batch_data_samples=batch_data_samples)
        decoder_inputs_dict.update(tmp_dec_in)

        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict, Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `query_pos`, and `reference_points`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points). Will only be used when
                `as_two_stage` is `True`.
            spatial_shapes (Tensor): Spatial shapes of features in all levels.
                With shape (num_levels, 2), last dimension represents (h, w).
                Will only be used when `as_two_stage` is `True`.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                It is passed to the denoising query generator when
                `self.training` is `True`. Defaults to None.

        Returns:
            tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict.

            - decoder_inputs_dict (dict): The keyword dictionary args of
              `self.forward_decoder()`, which includes 'query', 'memory',
              `reference_points`, and `dn_mask`. The reference points of
              decoder input here are 4D boxes, although it has `points`
              in its name.
            - head_inputs_dict (dict): The keyword dictionary args of the
              bbox_head functions, which includes `enc_outputs_class`,
              `enc_outputs_coord`, and `dn_meta` when `self.training` is
              `True`, else is empty.
        """
        bs = memory.size(0)
        # The branches at index `decoder.num_layers` (one past the decoder
        # layers) are the ones applied to the encoder output.
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers].out_features

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](
                output_memory)
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        # Rank the encoder proposals by their highest score over all
        # classes and keep the best `num_queries` of them.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1]
        topk_score = torch.gather(
            enc_outputs_class, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features))
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4))
        # `topk_coords` keeps its gradient for the head inputs, whereas the
        # copy used as decoder reference boxes is cut off from the graph.
        topk_coords = topk_coords_unact.sigmoid()
        topk_coords_unact = topk_coords_unact.detach()

        # (num_queries, dim) -> (bs, num_queries, dim)
        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)
        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            # Denoising queries are placed in front of the matching ones.
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask)
        # Only the selected top-k encoder predictions are handed to the
        # head, and only in training mode.
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self,
                        query: Tensor,
                        memory: Tensor,
                        memory_mask: Tensor,
                        reference_points: Tensor,
                        spatial_shapes: Tensor,
                        level_start_index: Tensor,
                        valid_ratios: Tensor,
                        dn_mask: Optional[Tensor] = None,
                        **kwargs) -> Dict:
        """Forward with Transformer decoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries_total, dim), where `num_queries_total` is the
                sum of `num_denoising_queries` and `num_matching_queries` when
                `self.training` is `True`, else `num_matching_queries`.
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries_total, 4) with the last dimension arranged as
                (cx, cy, w, h).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            dn_mask (Tensor, optional): The attention mask to prevent
                information leakage from different denoising groups and
                matching parts, will be used as `self_attn_mask` of the
                `self.decoder`, has shape (num_queries_total,
                num_queries_total).
                It is `None` when `self.training` is `False`.
            **kwargs: Extra keyword arguments forwarded unchanged to
                `self.decoder`.

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output and `references` including
            the initial and intermediate reference_points.
        """
        inter_states, references = self.decoder(
            query=query,
            value=memory,
            key_padding_mask=memory_mask,
            self_attn_mask=dn_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=self.bbox_head.reg_branches,
            **kwargs)

        # `query` is batch-first, so the number of queries is `shape[1]`
        # (`len(query)` would be the batch size). When it equals
        # `num_queries`, no denoising query was prepended and the label
        # embedding of the denoising generator did not take part in the
        # forward pass. Adding it multiplied by zero leaves the values
        # unchanged but keeps the parameter in the autograd graph
        # (presumably to avoid unused-parameter errors in distributed
        # training -- confirm against the training setup).
        if query.shape[1] == self.num_queries:
            inter_states[0] += \
                self.dn_query_generator.label_embedding.weight[0, 0] * 0.0

        decoder_outputs_dict = dict(
            hidden_states=inter_states, references=list(references))
        return decoder_outputs_dict
|
|