| |
| import warnings |
| import numpy as np |
| from typing import Optional, Tuple, Union,List |
| import torch |
| from mmcv.cnn import build_norm_layer |
| from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention |
| |
| from .specdetr_atten import MultiScaleDeformableAttention_1 as MultiScaleDeformableAttention |
| from mmengine.model import ModuleList |
| from torch import Tensor, nn |
| from mmengine.model import BaseModule |
|
|
| from mmdet.structures import SampleList |
| from mmdet.structures.bbox import bbox_xyxy_to_cxcywh,bbox_cxcywh_to_xyxy |
| from mmengine import ConfigDict |
| from mmdet.utils import ConfigType, OptConfigType |
| |
| |
| from .utils import MLP, coordinate_to_encoding, inverse_sigmoid |
| import random |
| import math |
|
|
|
|
class SpecDetrTransformerEncoder(BaseModule):
    """Transformer encoder of SpecDETR (Deformable-DETR style encoder).

    Args:
        num_layers (int): Number of encoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
            layer. All the layers will share the same config.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 init_cfg: OptConfigType = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_layers = num_layers
        self.layer_cfg = layer_cfg
        self._init_layers()
        # Counts completed forward passes. It previously only drove a
        # now-removed no-op debug hook; kept so any external code reading
        # this attribute keeps working.
        self.save_id = 0

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            SpecDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).

        Returns:
            Tensor: Output queries of Transformer encoder, which is also
            called 'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim)
        """
        # Reference points are shared by all encoder layers.
        reference_points = self.get_encoder_reference_points(
            spatial_shapes, valid_ratios, device=query.device)
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        self.save_id += 1
        return query

    @staticmethod
    def get_encoder_reference_points(
            spatial_shapes: Tensor, valid_ratios: Tensor,
            device: Union[torch.device, str]) -> Tensor:
        """Get the reference points used in encoder.

        Args:
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            device (obj:`device` or str): The device acquired by the
                `reference_points`.

        Returns:
            Tensor: Reference points used in decoder, has shape (bs, length,
            num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # Pixel-center coordinates, normalized by the valid extent of
            # each feature map so padded regions map outside [0, 1].
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        # (bs, sum(h*w), num_levels, 2): one reference per query per level.
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
|
|
|
|
class SpecDetrTransformerDecoder(BaseModule):
    """Transformer decoder of SpecDETR (DINO-style decoder).

    Iteratively refines box reference points across layers via the provided
    regression branches.

    Args:
        num_layers (int): Number of decoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): the config of each decoder
            layer. All the layers will share the same config.
        post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the
            post normalization layer. NOTE(review): `_init_layers` raises when
            this is not None — including for the default ``dict(type='LN')`` —
            so callers must explicitly pass ``post_norm_cfg=None``.
        return_intermediate (bool, optional): Whether to return outputs of
            intermediate layers. Defaults to `True`,
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 num_layers: int,
                 layer_cfg: ConfigType,
                 post_norm_cfg: OptConfigType = dict(type='LN'),
                 return_intermediate: bool = True,
                 init_cfg: Union[dict, ConfigDict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.layer_cfg = layer_cfg
        self.num_layers = num_layers
        self.post_norm_cfg = post_norm_cfg
        self.return_intermediate = return_intermediate
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize decoder layers, reference-point head and final norm."""
        self.layers = ModuleList([
            SpecDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        # Post-norm is not supported; see the class docstring note about the
        # default `post_norm_cfg` value triggering this.
        if self.post_norm_cfg is not None:
            raise ValueError('There is not post_norm in '
                             f'{self._get_name()}')
        # Maps the sine positional encoding of reference points to query_pos.
        self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims,
                                  self.embed_dims, 2)
        self.norm = nn.LayerNorm(self.embed_dims)

    def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor,
                self_attn_mask: Tensor, reference_points: Tensor,
                spatial_shapes: Tensor, level_start_index: Tensor,
                valid_ratios: Tensor, reg_branches: nn.ModuleList,
                **kwargs) -> Tensor:
        """Forward function of Transformer decoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            value (Tensor): The input values, has shape (bs, num_value, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of the
                attention input. ByteTensor, has shape (bs, num_value).
            self_attn_mask (Tensor): The attention mask to prevent information
                leakage from different denoising groups and matching parts, has
                shape (num_queries_total, num_queries_total). It is `None` when
                `self.training` is `False`.
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reg_branches: (obj:`nn.ModuleList`): Used for refining the
                regression results.

        Returns:
            tuple[Tensor]: When `return_intermediate` is True, stacked
            per-layer outputs and the per-layer reference points (the initial
            points prepended); otherwise the final query and reference points.
        """
        intermediate = []
        # Initial references are included so index i holds the input to
        # layer i's refinement.
        intermediate_reference_points = [reference_points]
        for lid, layer in enumerate(self.layers):
            # Scale references by valid ratios so they live in the
            # (padded) feature coordinate frame of each level.
            if reference_points.shape[-1] == 4:
                reference_points_input = \
                    reference_points[:, :, None] * torch.cat(
                        [valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = \
                    reference_points[:, :, None] * valid_ratios[:, None]

            # NOTE(review): `self.embed_dims/2` is a float; mmdet uses integer
            # division here — presumably coordinate_to_encoding tolerates a
            # float num_feats, TODO confirm.
            query_sine_embed = coordinate_to_encoding(
                reference_points_input[:, :, 0, :], self.embed_dims/2 )
            query_pos = self.ref_point_head(query_sine_embed)

            query = layer(
                query,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                self_attn_mask=self_attn_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

            if reg_branches is not None:
                # Iterative refinement: predict a delta in inverse-sigmoid
                # space and detach so gradients do not flow across layers.
                tmp = reg_branches[lid](query)
                assert reference_points.shape[-1] == 4
                new_reference_points = tmp + inverse_sigmoid(
                    reference_points, eps=1e-3)
                new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            # NOTE(review): if `reg_branches` is None while
            # `return_intermediate` is True, `new_reference_points` is unbound
            # here (NameError). Callers appear to always pass reg_branches.
            if self.return_intermediate:
                intermediate.append(self.norm(query))
                intermediate_reference_points.append(new_reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return query, reference_points
|
|
|
|
class SpecDetrTransformerEncoderLayer(BaseModule):
    """Encoder layer of SpecDETR (Deformable-DETR style).

    Consists of multi-scale deformable self-attention followed by an FFN,
    each with a post-norm.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256, num_heads=8, dropout=0.0),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True)),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:

        super().__init__(init_cfg=init_cfg)

        # Shallow-copy before injecting `batch_first` so that neither the
        # shared mutable default dict nor a caller-supplied config is
        # mutated in place.
        self.self_attn_cfg = dict(self_attn_cfg)
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize self_attn, ffn, and norms."""
        self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # norms[0] follows self-attention, norms[1] follows the FFN.
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs) -> Tensor:
        """Forward function of an encoder layer.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor. has shape (bs, num_queries).
        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.ffn(query)
        query = self.norms[1](query)
        return query
|
|
|
|
class SpecDetrTransformerDecoderLayer(BaseModule):
    """Decoder layer of SpecDETR (Deformable-DETR style).

    NOTE: unlike the standard DETR decoder layer, no self-attention module is
    built here — `_init_layers` only creates the deformable cross-attention,
    FFN and norms. `self_attn_cfg` is still validated and stored, and
    `self_attn_mask` is accepted in `forward` for interface compatibility,
    but neither is used.

    Args:
        self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
            attention (validated/stored only; see note above).
        cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross
            attention.
        ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
            normalization layers. All the layers will share the same
            config. Defaults to `LN`.
        init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 self_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 cross_attn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 ffn_cfg: OptConfigType = dict(
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 norm_cfg: OptConfigType = dict(type='LN'),
                 init_cfg: OptConfigType = None) -> None:

        super().__init__(init_cfg=init_cfg)

        # Shallow-copy both attention configs before injecting `batch_first`
        # so the shared mutable default dicts (and caller-supplied configs)
        # are never mutated in place.
        self.self_attn_cfg = dict(self_attn_cfg)
        self.cross_attn_cfg = dict(cross_attn_cfg)
        if 'batch_first' not in self.self_attn_cfg:
            self.self_attn_cfg['batch_first'] = True
        else:
            assert self.self_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        if 'batch_first' not in self.cross_attn_cfg:
            self.cross_attn_cfg['batch_first'] = True
        else:
            assert self.cross_attn_cfg['batch_first'] is True, 'First \
            dimension of all DETRs in mmdet is `batch`, \
            please set `batch_first` flag.'

        self.ffn_cfg = ffn_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize cross-attn, ffn, and norms (no self-attn; see class
        docstring)."""
        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
        self.embed_dims = self.cross_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # norms[0] follows cross-attention, norms[1] follows the FFN.
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
        """
        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
                `key` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. If not `None`, it will be added
                to `query` before forward function. Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`, has
                the same shape as `key`. If not `None`, it will be added to
                `key` before forward function. If None, and `query_pos` has the
                same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to None.
            self_attn_mask (Tensor, optional): Accepted for interface
                compatibility but unused (no self-attention in this layer).
                Defaults to None.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                the attention input. ByteTensor, has shape (bs, num_value).
                Defaults to None.

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.ffn(query)
        query = self.norms[1](query)
        return query
|
|
|
|
class CdnQueryGenerator(BaseModule):
    """Implement query generator of the Contrastive denoising (CDN) proposed in
    `DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object
    Detection <https://arxiv.org/abs/2203.03605>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DINO>`_.

    Args:
        num_classes (int): Number of object classes.
        embed_dims (int): The embedding dimensions of the generated queries.
        num_matching_queries (int): The queries number of the matching part.
            Used for generating dn_mask.
        label_noise_scale (float): The scale of label noise, defaults to 0.5.
        box_noise_scale (float): The scale of box noise, defaults to 1.0.
        query_initial (str): How the content part of the denoising queries is
            initialized: 'one' (all-ones), 'random' (uniform random), or
            'embed' (label-embedding lookup). Defaults to 'one'.
        group_cfg (:obj:`ConfigDict` or dict, optional): The config of the
            denoising queries grouping, includes `dynamic`, `num_dn_queries`,
            and `num_groups`. Two grouping strategies, 'static dn groups' and
            'dynamic dn groups', are supported. When `dynamic` is `False`,
            the `num_groups` should be set, and the number of denoising query
            groups will always be `num_groups`. When `dynamic` is `True`, the
            `num_dn_queries` should be set, and the group number will be
            dynamic to ensure that the denoising queries number will not exceed
            `num_dn_queries` to prevent large fluctuations of memory. Defaults
            to `None`.
    """

    def __init__(self,
                 num_classes: int,
                 embed_dims: int,
                 num_matching_queries: int,
                 label_noise_scale: float = 0.5,
                 box_noise_scale: float = 1.0,
                 query_initial: str = 'one',
                 group_cfg: OptConfigType = None) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_matching_queries = num_matching_queries
        self.label_noise_scale = label_noise_scale
        self.box_noise_scale = box_noise_scale

        # Prepare grouping strategy: dynamic (cap total dn queries) or static
        # (fixed number of groups).
        group_cfg = {} if group_cfg is None else group_cfg
        self.dynamic_dn_groups = group_cfg.get('dynamic', True)
        if self.dynamic_dn_groups:
            if 'num_dn_queries' not in group_cfg:
                warnings.warn("'num_dn_queries' should be set when using "
                              'dynamic dn groups, use 100 as default.')
            self.num_dn_queries = group_cfg.get('num_dn_queries', 100)
            assert isinstance(self.num_dn_queries, int), \
                f'Expected the num_dn_queries to have type int, but got ' \
                f'{self.num_dn_queries}({type(self.num_dn_queries)}). '
        else:
            assert 'num_groups' in group_cfg, \
                'num_groups should be set when using static dn groups'
            self.num_groups = group_cfg['num_groups']
            assert isinstance(self.num_groups, int), \
                f'Expected the num_groups to have type int, but got ' \
                f'{self.num_groups}({type(self.num_groups)}). '

        # The label embedding table is only needed for the 'embed' strategy.
        self.query_initial =query_initial
        if self.query_initial == 'embed':
            self.label_embedding = nn.Embedding(self.num_classes, self.embed_dims)

    def __call__(self, batch_data_samples: SampleList) -> tuple:
        """Generate contrastive denoising (cdn) queries with ground truth.

        Descriptions of the Number Values in code and comments:
          - num_target_total: the total target number of the input batch
            samples.
          - max_num_target: the max target number over the images of the
            input batch samples.
          - num_noisy_targets: the total targets number after adding noise,
            i.e., num_target_total * num_groups * 2.
          - num_denoising_queries: the length of the output batched queries,
            i.e., max_num_target * num_groups * 2.

        NOTE The format of input bboxes in batch_data_samples is unnormalized
        (x, y, x, y), and the output bbox queries are embedded by normalized
        (cx, cy, w, h) format bboxes going through inverse_sigmoid.

        Args:
            batch_data_samples (list[:obj:`DetDataSample`]): List of the batch
                data samples, each includes `gt_instance` which has attributes
                `bboxes` and `labels`. The `bboxes` has unnormalized coordinate
                format (x, y, x, y).

        Returns:
            tuple: The outputs of the dn query generator.

            - dn_label_query (Tensor): The output content queries for denoising
              part, has shape (bs, num_denoising_queries, dim), where
              `num_denoising_queries = max_num_target * num_groups * 2`.
            - dn_bbox_query (Tensor): The output reference bboxes as positions
              of queries for denoising part, which are embedded by normalized
              (cx, cy, w, h) format bboxes going through inverse_sigmoid, has
              shape (bs, num_denoising_queries, 4) with the last dimension
              arranged as (cx, cy, w, h).
            - attn_mask (Tensor): The attention mask to prevent information
              leakage from different denoising groups and matching parts,
              will be used as `self_attn_mask` of the `decoder`, has shape
              (num_queries_total, num_queries_total), where `num_queries_total`
              is the sum of `num_denoising_queries` and `num_matching_queries`.
            - dn_meta (Dict[str, int]): The dictionary saves information about
              group collation, including 'num_denoising_queries' and
              'num_denoising_groups'. It will be used for split outputs of
              denoising and matching parts and loss calculation.

        """
        # Normalize gt boxes to [0, 1] by image size and concatenate the
        # whole batch into flat tensors.
        gt_labels_list = []
        gt_bboxes_list = []
        for sample in batch_data_samples:
            img_h, img_w = sample.img_shape
            bboxes = sample.gt_instances.bboxes
            factor = bboxes.new_tensor([img_w, img_h, img_w,
                                        img_h]).unsqueeze(0)
            bboxes_normalized = bboxes / factor
            gt_bboxes_list.append(bboxes_normalized)
            gt_labels_list.append(sample.gt_instances.labels)
        gt_labels = torch.cat(gt_labels_list)
        gt_bboxes = torch.cat(gt_bboxes_list)

        num_target_list = [len(bboxes) for bboxes in gt_bboxes_list]
        max_num_target = max(num_target_list)
        num_groups = self.get_num_groups(max_num_target)

        dn_label_query = self.generate_dn_label_query(gt_labels, num_groups)
        dn_bbox_query = self.generate_dn_bbox_query(gt_bboxes, num_groups)

        # Batch index of every flattened target, used to scatter the queries
        # back into a padded per-image layout.
        batch_idx = torch.cat([
            torch.full_like(t.long(), i) for i, t in enumerate(gt_labels_list)
        ])

        dn_label_query, dn_bbox_query = self.collate_dn_queries(
            dn_label_query, dn_bbox_query, batch_idx, len(batch_data_samples),
            num_groups)

        attn_mask = self.generate_dn_mask(
            max_num_target, num_groups, device=dn_label_query.device)

        dn_meta = dict(
            num_denoising_queries=int(max_num_target * 2 * num_groups),
            num_denoising_groups=num_groups)

        return dn_label_query, dn_bbox_query, attn_mask, dn_meta

    def get_num_groups(self, max_num_target: int = None) -> int:
        """Calculate denoising query groups number.

        Two grouping strategies, 'static dn groups' and 'dynamic dn groups',
        are supported. When `self.dynamic_dn_groups` is `False`, the number
        of denoising query groups will always be `self.num_groups`. When
        `self.dynamic_dn_groups` is `True`, the group number will be dynamic,
        ensuring the denoising queries number will not exceed
        `self.num_dn_queries` to prevent large fluctuations of memory.

        NOTE The `num_group` is shared for different samples in a batch. When
        the target numbers in the samples varies, the denoising queries of the
        samples containing fewer targets are padded to the max length.

        Args:
            max_num_target (int, optional): The max target number of the batch
                samples. It will only be used when `self.dynamic_dn_groups` is
                `True`. Defaults to `None`.

        Returns:
            int: The denoising group number of the current batch.
        """
        if self.dynamic_dn_groups:
            assert max_num_target is not None, \
                'group_queries should be provided when using ' \
                'dynamic dn groups'
            if max_num_target == 0:
                num_groups = 1
            else:
                num_groups = self.num_dn_queries // max_num_target
        else:
            num_groups = self.num_groups
        # Always run at least one denoising group.
        if num_groups < 1:
            num_groups = 1
        return int(num_groups)

    def generate_dn_label_query(self, gt_labels: Tensor,
                                num_groups: int) -> Tensor:
        """Generate the content (label) part of the denoising queries.

        Unlike the original DINO implementation, labels themselves are not
        noised here; the content queries are initialized according to
        `self.query_initial` ('one' / 'random' / 'embed').

        Args:
            gt_labels (Tensor): The concatenated gt labels of all samples
                in the batch, has shape (num_target_total, ) where
                `num_target_total = sum(num_target_list)`.
            num_groups (int): The number of denoising query groups.

        Returns:
            Tensor: The query embeddings of noisy labels, has shape
            (num_noisy_targets, embed_dims), where `num_noisy_targets =
            num_target_total * num_groups * 2`.
        """
        # NOTE(review): if `query_initial` is not one of
        # {'one', 'random', 'embed'}, `dn_label_query` is never assigned and
        # this raises UnboundLocalError — consider validating in __init__.
        if self.query_initial == 'one':
            dn_label_query = torch.ones((gt_labels.size(0)*num_groups*2, self.embed_dims), device=gt_labels.device)
        elif self.query_initial == 'random':
            dn_label_query = torch.rand((gt_labels.size(0)*num_groups*2, self.embed_dims), device=gt_labels.device)
        elif self.query_initial == 'embed':
            gt_labels_expand = gt_labels.repeat(2 * num_groups,
                                                1).view(-1)
            dn_label_query = self.label_embedding(gt_labels_expand)
        return dn_label_query

    def generate_dn_bbox_query(self, gt_bboxes: Tensor,
                               num_groups: int) -> Tensor:
        """Generate noisy bboxes and their query embeddings.

        The strategy for generating noisy bboxes is as follow:

        .. code:: text

                    +--------------------+
                    |      negative      |
                    |    +----------+    |
                    |    | positive |    |
                    |    |    +-----|----+------------+
                    |    |    |     |    |            |
                    |    +----+-----+    |            |
                    |         |          |            |
                    +---------+----------+            |
                              |                       |
                              |        gt bbox        |
                              |                       |
                              |        +---------+----------+
                              |        |         |          |
                              |        |    +----+-----+    |
                              |        |    |    |     |    |
                              +--------|--- +----+     |    |
                                       |    | positive |    |
                                       |    +----------+    |
                                       |      negative      |
                                       +--------------------+

        The random noise is added to the top-left and down-right point
        positions, hence, normalized (x, y, x, y) format of bboxes are
        required. The noisy bboxes of positive queries have the points
        both within the inner square, while those of negative queries
        have the points both between the inner and outer squares.

        Besides, the length of outer square is twice as long as that of
        the inner square, i.e., self.box_noise_scale * w_or_h / 2.
        NOTE The noise is added to all the bboxes. Moreover, there is still
        unconsidered case when one point is within the positive square and
        the others is between the inner and outer squares.

        Args:
            gt_bboxes (Tensor): The concatenated gt bboxes of all samples
                in the batch, has shape (num_target_total, 4) with the last
                dimension arranged as (cx, cy, w, h) where
                `num_target_total = sum(num_target_list)`.
            num_groups (int): The number of denoising query groups.

        Returns:
            Tensor: The output noisy bboxes, which are embedded by normalized
            (cx, cy, w, h) format bboxes going through inverse_sigmoid, has
            shape (num_noisy_targets, 4) with the last dimension arranged as
            (cx, cy, w, h), where
            `num_noisy_targets = num_target_total * num_groups * 2`.
        """
        assert self.box_noise_scale > 0
        device = gt_bboxes.device

        # Each group contributes a positive and a negative copy of every gt.
        gt_bboxes_expand = gt_bboxes.repeat(2 * num_groups, 1)

        # Flattened indices of the positive copies; negatives directly follow
        # the positives within each group.
        positive_idx = torch.arange(
            len(gt_bboxes), dtype=torch.long, device=device)
        positive_idx = positive_idx.unsqueeze(0).repeat(num_groups, 1)
        positive_idx += 2 * len(gt_bboxes) * torch.arange(
            num_groups, dtype=torch.long, device=device)[:, None]
        positive_idx = positive_idx.flatten()
        negative_idx = positive_idx + len(gt_bboxes)

        # Noise in (cx, cy, w, h) space, proportional to box width/height.
        # NOTE(review): the center (cx, cy) noise is scaled by
        # `label_noise_scale` rather than `box_noise_scale` (DINO uses the
        # latter for all four coords) — confirm this is intentional.
        bboxes_cxcywh_expand = bbox_xyxy_to_cxcywh(gt_bboxes_expand)
        bboxes_whwh = bbox_xyxy_to_cxcywh(gt_bboxes_expand)[:, 2:].repeat(1, 2)
        rand_part = torch.rand_like(gt_bboxes_expand) * 2.0 - 1.0
        rand_part[:,:2] *= self.label_noise_scale
        rand_part[:, 2:] *= self.box_noise_scale
        noisy_bboxes_expand = bboxes_cxcywh_expand + torch.mul(rand_part, bboxes_whwh)/2

        # Larger, sign-randomized offsets push the negative queries' centers
        # outside the positive region.
        rand_sign = torch.randint_like(
            gt_bboxes_expand, low=0, high=2,
            dtype=torch.float32) * 2.0 - 1.0

        rand_part = torch.rand_like(gt_bboxes_expand)

        rand_part = self.label_noise_scale + rand_part * self.label_noise_scale

        rand_part *= rand_sign
        # NOTE(review): the wh columns (2:) of `rand_part` are used to offset
        # the negative centers; negatives keep the w/h noise applied above.
        noisy_bboxes_expand[negative_idx,:2] = bboxes_cxcywh_expand[negative_idx,:2]+torch.mul(rand_part[negative_idx,2:],bboxes_cxcywh_expand[negative_idx,2:]*0.5)
        # Clamp in (x1, y1, x2, y2) space so boxes stay inside the image.
        noisy_bboxes_expand = bbox_cxcywh_to_xyxy(noisy_bboxes_expand)
        noisy_bboxes_expand = noisy_bboxes_expand.clamp(min=0.0, max=1.0)
        noisy_bboxes_expand = bbox_xyxy_to_cxcywh(noisy_bboxes_expand)

        dn_bbox_query = inverse_sigmoid(noisy_bboxes_expand, eps=1e-3)
        return dn_bbox_query

    def collate_dn_queries(self, input_label_query: Tensor,
                           input_bbox_query: Tensor, batch_idx: Tensor,
                           batch_size: int, num_groups: int) -> Tuple[Tensor]:
        """Collate generated queries to obtain batched dn queries.

        The strategy for query collation is as follow:

        .. code:: text

                    input_queries (num_target_total, query_dim)
            P_A1 P_B1 P_B2 N_A1 N_B1 N_B2 P'A1 P'B1 P'B2 N'A1 N'B1 N'B2
              |________ group1 ________|    |________ group2 ________|
                                         |
                                         V
                      P_A1 Pad0 N_A1 Pad0 P'A1 Pad0 N'A1 Pad0
                      P_B1 P_B2 N_B1 N_B2 P'B1 P'B2 N'B1 N'B2
                       |____ group1 ____|    |____ group2 ____|
             batched_queries (batch_size, max_num_target, query_dim)

            where query_dim is 4 for bbox and self.embed_dims for label.
            Notation: _-group 1; '-group 2;
                      A-Sample1(has 1 target); B-sample2(has 2 targets)

        Args:
            input_label_query (Tensor): The generated label queries of all
                targets, has shape (num_target_total, embed_dims) where
                `num_target_total = sum(num_target_list)`.
            input_bbox_query (Tensor): The generated bbox queries of all
                targets, has shape (num_target_total, 4) with the last
                dimension arranged as (cx, cy, w, h).
            batch_idx (Tensor): The batch index of the corresponding sample
                for each target, has shape (num_target_total).
            batch_size (int): The size of the input batch.
            num_groups (int): The number of denoising query groups.

        Returns:
            tuple[Tensor]: Output batched label and bbox queries.
            - batched_label_query (Tensor): The output batched label queries,
              has shape (batch_size, max_num_target, embed_dims).
            - batched_bbox_query (Tensor): The output batched bbox queries,
              has shape (batch_size, max_num_target, 4) with the last dimension
              arranged as (cx, cy, w, h).
        """
        device = input_label_query.device
        num_target_list = [
            torch.sum(batch_idx == idx) for idx in range(batch_size)
        ]
        max_num_target = max(num_target_list)
        num_denoising_queries = int(max_num_target * 2 * num_groups)

        # Destination index of each flattened target inside the padded
        # per-image query sequence.
        map_query_index = torch.cat([
            torch.arange(num_target, device=device)
            for num_target in num_target_list
        ])
        map_query_index = torch.cat([
            map_query_index + max_num_target * i for i in range(2 * num_groups)
        ]).long()
        batch_idx_expand = batch_idx.repeat(2 * num_groups, 1).view(-1)
        mapper = (batch_idx_expand, map_query_index)

        # Zero-padded outputs; positions without a target stay zero.
        batched_label_query = torch.zeros(
            batch_size, num_denoising_queries, self.embed_dims, device=device)
        batched_bbox_query = torch.zeros(
            batch_size, num_denoising_queries, 4, device=device)

        batched_label_query[mapper] = input_label_query
        batched_bbox_query[mapper] = input_bbox_query
        return batched_label_query, batched_bbox_query

    def generate_dn_mask(self, max_num_target: int, num_groups: int,
                         device: Union[torch.device, str]) -> Tensor:
        """Generate attention mask to prevent information leakage from
        different denoising groups and matching parts.

        .. code:: text

                        0 0 0 0 1 1 1 1 0 0 0 0 0
                        0 0 0 0 1 1 1 1 0 0 0 0 0
                        0 0 0 0 1 1 1 1 0 0 0 0 0
                        0 0 0 0 1 1 1 1 0 0 0 0 0
                        1 1 1 1 0 0 0 0 0 0 0 0 0
                        1 1 1 1 0 0 0 0 0 0 0 0 0
                        1 1 1 1 0 0 0 0 0 0 0 0 0
                        1 1 1 1 0 0 0 0 0 0 0 0 0
                        1 1 1 1 1 1 1 1 0 0 0 0 0
                        1 1 1 1 1 1 1 1 0 0 0 0 0
                        1 1 1 1 1 1 1 1 0 0 0 0 0
                        1 1 1 1 1 1 1 1 0 0 0 0 0
                        1 1 1 1 1 1 1 1 0 0 0 0 0
         max_num_target |_|           |_________| num_matching_queries
                        |_____________| num_denoising_queries

               1 -> True  (Masked), means 'can not see'.
               0 -> False (UnMasked), means 'can see'.

        NOTE(review): the diagram above describes the original DINO mask, but
        this implementation returns an all-False (fully unmasked) tensor, so
        no cross-group masking actually happens — confirm this is intentional
        for SpecDETR.

        Args:
            max_num_target (int): The max target number of the input batch
                samples.
            num_groups (int): The number of denoising query groups.
            device (obj:`device` or str): The device of generated mask.

        Returns:
            Tensor: The attention mask to prevent information leakage from
            different denoising groups and matching parts, will be used as
            `self_attn_mask` of the `decoder`, has shape (num_queries_total,
            num_queries_total), where `num_queries_total` is the sum of
            `num_denoising_queries` and `num_matching_queries`.
        """
        num_denoising_queries = int(max_num_target * 2 * num_groups)
        num_queries_total = num_denoising_queries + self.num_matching_queries
        attn_mask = torch.zeros(
            num_queries_total,
            num_queries_total,
            device=device,
            dtype=torch.bool)
        return attn_mask
|
|
|
|