| |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| from mmcv.cnn import build_norm_layer |
| from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention |
| from mmcv.ops import MultiScaleDeformableAttention |
| from mmengine.model import ModuleList |
| from torch import Tensor, nn |
|
|
| from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, |
| DetrTransformerEncoder, DetrTransformerEncoderLayer) |
| from .utils import inverse_sigmoid |
|
|
|
|
class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
    """Transformer encoder of Deformable DETR.

    Stacks ``num_layers`` :class:`DeformableDetrTransformerEncoderLayer`
    modules built from ``layer_cfg`` and feeds every layer the same set of
    per-level reference points.
    """

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        # Every layer is built from the same config, so the first layer's
        # width is the width of the whole encoder.
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).

        Returns:
            Tensor: Output queries of Transformer encoder, which is also
            called 'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim)
        """
        # The reference points depend only on the feature-map geometry, so
        # they are computed once and shared by every layer.
        reference_points = self.get_encoder_reference_points(
            spatial_shapes, valid_ratios, device=query.device)
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        return query

    @staticmethod
    def get_encoder_reference_points(
            spatial_shapes: Tensor, valid_ratios: Tensor,
            device: Union[torch.device, str]) -> Tensor:
        """Get the reference points used in encoder.

        One point is placed at the centre of every cell of every feature
        level. Each point is normalized by the valid extent of its own level
        and then re-scaled by the valid ratios of every level, so that each
        point has a location on all levels.

        Args:
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            device (obj:`device` or str): The device acquired by the
                `reference_points`.

        Returns:
            Tensor: Reference points used in encoder, has shape (bs, length,
            num_levels, 2), where ``length`` is the sum of ``h * w`` over all
            levels and the last dimension is arranged as (x, y).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # Cell-centre coordinates (0.5, 1.5, ...) of an H x W grid.
            # torch.meshgrid's default 'ij' indexing makes ref_y vary along
            # rows and ref_x along columns.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            # Normalize by the valid height / width of this level
            # (valid ratio * feature size), not by the full feature size.
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        # (bs, length, 2): points of all levels concatenated in level order.
        reference_points = torch.cat(reference_points_list, 1)
        # Broadcast every point to all levels, scaling by each level's valid
        # ratio -> (bs, length, num_levels, 2).
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
|
|
|
|
class DeformableDetrTransformerDecoder(DetrTransformerDecoder):
    """Transformer Decoder of Deformable DETR.

    Stacks ``num_layers`` :class:`DeformableDetrTransformerDecoderLayer`
    modules and, when regression branches are given, refines the reference
    points after every layer (iterative box refinement).
    """

    def _init_layers(self) -> None:
        """Initialize decoder layers.

        Raises:
            ValueError: If ``post_norm_cfg`` is set, since this decoder does
                not apply a post norm.
        """
        self.layers = ModuleList([
            DeformableDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        if self.post_norm_cfg is not None:
            raise ValueError('There is not post_norm in '
                             f'{self._get_name()}')

    def forward(self,
                query: Tensor,
                query_pos: Tensor,
                value: Tensor,
                key_padding_mask: Tensor,
                reference_points: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                reg_branches: Optional[nn.Module] = None,
                **kwargs) -> Tuple[Tensor, Tensor]:
        """Forward function of Transformer decoder.

        Args:
            query (Tensor): The input queries, has shape (bs, num_queries,
                dim).
            query_pos (Tensor): The input positional query, has shape
                (bs, num_queries, dim). It will be added to `query` before
                forward function.
            value (Tensor): The input values, has shape (bs, num_value, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
                shape (bs, num_queries, 2) with the last dimension arranged
                as (cx, cy).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`, optional): Used for refining
                the regression results. Indexed by layer id, so it must hold
                at least one branch per decoder layer. Only would be passed
                when `with_box_refine` is `True`, otherwise would be `None`.

        Returns:
            tuple[Tensor]: Outputs of Deformable Transformer Decoder.

            - output (Tensor): Output embeddings of the last decoder, has
              shape (bs, num_queries, embed_dims) when `return_intermediate`
              is `False`. Otherwise, Intermediate output embeddings of all
              decoder layers, has shape (num_decoder_layers, bs, num_queries,
              embed_dims).
            - reference_points (Tensor): The reference of the last decoder
              layer, has shape (bs, num_queries, 4) or (bs, num_queries, 2)
              (see `reference_points` above; it becomes the width of the
              regression output once refined) when `return_intermediate` is
              `False`. Otherwise, Intermediate references of all decoder
              layers, has shape (num_decoder_layers, bs, num_queries, 4 or 2).
        """
        # Scale factors that map normalized reference points onto the valid
        # region of every level. They do not change between layers, so they
        # are built once: (bs, 1, num_levels, 2) for (cx, cy) points and
        # (bs, 1, num_levels, 4) for (cx, cy, w, h) boxes.
        valid_ratios_xy = valid_ratios[:, None]
        valid_ratios_xywh = torch.cat([valid_ratios, valid_ratios], -1)[:, None]

        output = query
        intermediate = []
        intermediate_reference_points = []
        for layer_id, layer in enumerate(self.layers):
            # The width of reference_points is re-checked every layer because
            # refinement below may change it (2 -> width of the reg output).
            if reference_points.shape[-1] == 4:
                reference_points_input = \
                    reference_points[:, :, None] * valid_ratios_xywh
            else:
                assert reference_points.shape[-1] == 2, \
                    'reference_points must have 2 or 4 channels'
                reference_points_input = \
                    reference_points[:, :, None] * valid_ratios_xy
            output = layer(
                output,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

            if reg_branches is not None:
                tmp_reg_preds = reg_branches[layer_id](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp_reg_preds + inverse_sigmoid(
                        reference_points)
                else:
                    assert reference_points.shape[-1] == 2, \
                        'reference_points must have 2 or 4 channels'
                    # Only (cx, cy) is predicted as an offset from the current
                    # points; any extra channels are taken as-is. Build a new
                    # tensor instead of writing into the branch output in
                    # place, which would alias it and can break autograd.
                    new_reference_points = torch.cat([
                        tmp_reg_preds[..., :2] +
                        inverse_sigmoid(reference_points),
                        tmp_reg_preds[..., 2:]
                    ], dim=-1)
                # Back to [0, 1]; detached so the next layer's sampling
                # locations do not back-propagate into this refinement.
                reference_points = new_reference_points.sigmoid().detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points
|
|
|
|
class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer):
    """Encoder layer of Deformable DETR.

    Uses ``MultiScaleDeformableAttention`` as its self-attention, followed by
    an FFN, with two normalization layers.
    """

    def _init_layers(self) -> None:
        """Build the deformable self-attention, the FFN and two norms."""
        self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
        # The layer width is dictated by the attention module.
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # build_norm_layer yields (name, module); only the module is kept.
        norm_modules = []
        for _ in range(2):
            norm_modules.append(
                build_norm_layer(self.norm_cfg, self.embed_dims)[1])
        self.norms = ModuleList(norm_modules)
|
|
|
|
class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Decoder layer of Deformable DETR.

    Combines a standard ``MultiheadAttention`` self-attention, a
    ``MultiScaleDeformableAttention`` cross-attention and an FFN, with three
    normalization layers.
    """

    def _init_layers(self) -> None:
        """Build self-attention, deformable cross-attention, FFN and norms."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
        # The layer width is taken from the self-attention module.
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # build_norm_layer yields (name, module); only the module is kept.
        norm_modules = []
        for _ in range(3):
            norm_modules.append(
                build_norm_layer(self.norm_cfg, self.embed_dims)[1])
        self.norms = ModuleList(norm_modules)
|
|