from typing import Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine import ConfigDict
from mmengine.model import BaseModule, ModuleList
from torch import Tensor

from mmdet.utils import ConfigType, OptConfigType


class DetrTransformerEncoder(BaseModule):
    """Encoder of DETR.

    Args:
        num_layers (int): Number of encoder layers.
        layer_cfg (:obj:`ConfigDict` or dict): The config of each encoder
            layer. All the layers will share the same config.
        init_cfg (:obj:`ConfigDict` or dict, optional): The config to control
            the initialization. Defaults to None.
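
    Example:
        A minimal usage sketch; the config values and tensor shapes below
        are illustrative assumptions, not required defaults:

        >>> encoder = DetrTransformerEncoder(
        ...     num_layers=6,
        ...     layer_cfg=dict(
        ...         self_attn_cfg=dict(embed_dims=256, num_heads=8),
        ...         ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
        >>> bs, num_queries, dim = 2, 100, 256
        >>> query = torch.rand(bs, num_queries, dim)
        >>> query_pos = torch.rand(bs, num_queries, dim)
        >>> key_padding_mask = torch.zeros(bs, num_queries).bool()
        >>> encoder(query, query_pos, key_padding_mask).shape
        torch.Size([2, 100, 256])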
| """ |
|
|
| def __init__(self, |
| num_layers: int, |
| layer_cfg: ConfigType, |
| init_cfg: OptConfigType = None) -> None: |
|
|
| super().__init__(init_cfg=init_cfg) |
| self.num_layers = num_layers |
| self.layer_cfg = layer_cfg |
| self._init_layers() |
|
|
| def _init_layers(self) -> None: |
| """Initialize encoder layers.""" |
| self.layers = ModuleList([ |
| DetrTransformerEncoderLayer(**self.layer_cfg) |
| for _ in range(self.num_layers) |
| ]) |
| self.embed_dims = self.layers[0].embed_dims |
|
|
| def forward(self, query: Tensor, query_pos: Tensor, |
| key_padding_mask: Tensor, **kwargs) -> Tensor: |
| """Forward function of encoder. |
| |
| Args: |
| query (Tensor): Input queries of encoder, has shape |
| (bs, num_queries, dim). |
| query_pos (Tensor): The positional embeddings of the queries, has |
| shape (bs, num_queries, dim). |
| key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` |
| input. ByteTensor, has shape (bs, num_queries). |
| |
| Returns: |
| Tensor: Has shape (bs, num_queries, dim) if `batch_first` is |
| `True`, otherwise (num_queries, bs, dim). |
| """ |
| for layer in self.layers: |
| query = layer(query, query_pos, key_padding_mask, **kwargs) |
| return query |
|
|
|
|
| class DetrTransformerDecoder(BaseModule): |
| """Decoder of DETR. |
| |
| Args: |
| num_layers (int): Number of decoder layers. |
| layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder |
| layer. All the layers will share the same config. |
| post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the |
| post normalization layer. Defaults to `LN`. |
| return_intermediate (bool, optional): Whether to return outputs of |
| intermediate layers. Defaults to `True`, |
| init_cfg (:obj:`ConfigDict` or dict, optional): the config to control |
| the initialization. Defaults to None. |
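
    Example:
        A minimal usage sketch; the config values and tensor shapes are
        illustrative assumptions, and `memory` stands for the encoder
        output:

        >>> decoder = DetrTransformerDecoder(
        ...     num_layers=6,
        ...     layer_cfg=dict(
        ...         self_attn_cfg=dict(embed_dims=256, num_heads=8),
        ...         cross_attn_cfg=dict(embed_dims=256, num_heads=8),
        ...         ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
        >>> bs, num_queries, num_keys, dim = 2, 100, 1300, 256
        >>> query = torch.zeros(bs, num_queries, dim)
        >>> memory = torch.rand(bs, num_keys, dim)
        >>> query_pos = torch.rand(bs, num_queries, dim)
        >>> key_pos = torch.rand(bs, num_keys, dim)
        >>> key_padding_mask = torch.zeros(bs, num_keys).bool()
        >>> out = decoder(query, memory, memory, query_pos, key_pos,
        ...               key_padding_mask)
        >>> out.shape  # one output per layer since return_intermediate=True
        torch.Size([6, 2, 100, 256])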
| """ |
|
|
| def __init__(self, |
| num_layers: int, |
| layer_cfg: ConfigType, |
| post_norm_cfg: OptConfigType = dict(type='LN'), |
| return_intermediate: bool = True, |
| init_cfg: Union[dict, ConfigDict] = None) -> None: |
| super().__init__(init_cfg=init_cfg) |
| self.layer_cfg = layer_cfg |
| self.num_layers = num_layers |
| self.post_norm_cfg = post_norm_cfg |
| self.return_intermediate = return_intermediate |
| self._init_layers() |
|
|
| def _init_layers(self) -> None: |
| """Initialize decoder layers.""" |
| self.layers = ModuleList([ |
| DetrTransformerDecoderLayer(**self.layer_cfg) |
| for _ in range(self.num_layers) |
| ]) |
| self.embed_dims = self.layers[0].embed_dims |
| self.post_norm = build_norm_layer(self.post_norm_cfg, |
| self.embed_dims)[1] |
|
|
| def forward(self, query: Tensor, key: Tensor, value: Tensor, |
| query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor, |
| **kwargs) -> Tensor: |
| """Forward function of decoder |
| Args: |
| query (Tensor): The input query, has shape (bs, num_queries, dim). |
| key (Tensor): The input key, has shape (bs, num_keys, dim). |
| value (Tensor): The input value with the same shape as `key`. |
| query_pos (Tensor): The positional encoding for `query`, with the |
| same shape as `query`. |
| key_pos (Tensor): The positional encoding for `key`, with the |
| same shape as `key`. |
| key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` |
| input. ByteTensor, has shape (bs, num_value). |
| |
| Returns: |
| Tensor: The forwarded results will have shape |
| (num_decoder_layers, bs, num_queries, dim) if |
| `return_intermediate` is `True` else (1, bs, num_queries, dim). |
| """ |
| intermediate = [] |
| for layer in self.layers: |
| query = layer( |
| query, |
| key=key, |
| value=value, |
| query_pos=query_pos, |
| key_pos=key_pos, |
| key_padding_mask=key_padding_mask, |
| **kwargs) |
| if self.return_intermediate: |
| intermediate.append(self.post_norm(query)) |
| query = self.post_norm(query) |
|
|
| if self.return_intermediate: |
| return torch.stack(intermediate) |
|
|
| return query.unsqueeze(0) |
|
|
|
|
| class DetrTransformerEncoderLayer(BaseModule): |
| """Implements encoder layer in DETR transformer. |
| |
| Args: |
| self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self |
| attention. |
| ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. |
| norm_cfg (:obj:`ConfigDict` or dict, optional): Config for |
| normalization layers. All the layers will share the same |
| config. Defaults to `LN`. |
| init_cfg (:obj:`ConfigDict` or dict, optional): Config to control |
| the initialization. Defaults to None. |
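
    Example:
        A minimal sketch of a single layer with the default 256-dim config;
        the shapes are illustrative assumptions:

        >>> layer = DetrTransformerEncoderLayer()
        >>> query = torch.rand(2, 100, 256)
        >>> query_pos = torch.rand(2, 100, 256)
        >>> key_padding_mask = torch.zeros(2, 100).bool()
        >>> layer(query, query_pos, key_padding_mask).shape
        torch.Size([2, 100, 256])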
| """ |
|
|
| def __init__(self, |
| self_attn_cfg: OptConfigType = dict( |
| embed_dims=256, num_heads=8, dropout=0.0), |
| ffn_cfg: OptConfigType = dict( |
| embed_dims=256, |
| feedforward_channels=1024, |
| num_fcs=2, |
| ffn_drop=0., |
| act_cfg=dict(type='ReLU', inplace=True)), |
| norm_cfg: OptConfigType = dict(type='LN'), |
| init_cfg: OptConfigType = None) -> None: |
|
|
| super().__init__(init_cfg=init_cfg) |
|
|
| self.self_attn_cfg = self_attn_cfg |
| if 'batch_first' not in self.self_attn_cfg: |
| self.self_attn_cfg['batch_first'] = True |
| else: |
| assert self.self_attn_cfg['batch_first'] is True, 'First \ |
| dimension of all DETRs in mmdet is `batch`, \ |
| please set `batch_first` flag.' |
|
|
| self.ffn_cfg = ffn_cfg |
| self.norm_cfg = norm_cfg |
| self._init_layers() |
|
|
| def _init_layers(self) -> None: |
| """Initialize self-attention, FFN, and normalization.""" |
| self.self_attn = MultiheadAttention(**self.self_attn_cfg) |
| self.embed_dims = self.self_attn.embed_dims |
| self.ffn = FFN(**self.ffn_cfg) |
| norms_list = [ |
| build_norm_layer(self.norm_cfg, self.embed_dims)[1] |
| for _ in range(2) |
| ] |
| self.norms = ModuleList(norms_list) |
|
|
| def forward(self, query: Tensor, query_pos: Tensor, |
| key_padding_mask: Tensor, **kwargs) -> Tensor: |
| """Forward function of an encoder layer. |
| |
| Args: |
| query (Tensor): The input query, has shape (bs, num_queries, dim). |
| query_pos (Tensor): The positional encoding for query, with |
| the same shape as `query`. |
| key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` |
| input. ByteTensor. has shape (bs, num_queries). |
| Returns: |
| Tensor: forwarded results, has shape (bs, num_queries, dim). |
| """ |
| query = self.self_attn( |
| query=query, |
| key=query, |
| value=query, |
| query_pos=query_pos, |
| key_pos=query_pos, |
| key_padding_mask=key_padding_mask, |
| **kwargs) |
| query = self.norms[0](query) |
| query = self.ffn(query) |
| query = self.norms[1](query) |
|
|
| return query |
|
|
|
|
| class DetrTransformerDecoderLayer(BaseModule): |
| """Implements decoder layer in DETR transformer. |
| |
| Args: |
| self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self |
| attention. |
| cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross |
| attention. |
| ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. |
| norm_cfg (:obj:`ConfigDict` or dict, optional): Config for |
| normalization layers. All the layers will share the same |
| config. Defaults to `LN`. |
| init_cfg (:obj:`ConfigDict` or dict, optional): Config to control |
| the initialization. Defaults to None. |
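
    Example:
        A minimal sketch with the default 256-dim config; `memory` is an
        assumed encoder output, and positional encodings and masks default
        to `None`:

        >>> layer = DetrTransformerDecoderLayer()
        >>> query = torch.rand(2, 100, 256)
        >>> memory = torch.rand(2, 1300, 256)
        >>> layer(query, key=memory, value=memory).shape
        torch.Size([2, 100, 256])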
| """ |
|
|
| def __init__(self, |
| self_attn_cfg: OptConfigType = dict( |
| embed_dims=256, |
| num_heads=8, |
| dropout=0.0, |
| batch_first=True), |
| cross_attn_cfg: OptConfigType = dict( |
| embed_dims=256, |
| num_heads=8, |
| dropout=0.0, |
| batch_first=True), |
| ffn_cfg: OptConfigType = dict( |
| embed_dims=256, |
| feedforward_channels=1024, |
| num_fcs=2, |
| ffn_drop=0., |
| act_cfg=dict(type='ReLU', inplace=True), |
| ), |
| norm_cfg: OptConfigType = dict(type='LN'), |
| init_cfg: OptConfigType = None) -> None: |
|
|
| super().__init__(init_cfg=init_cfg) |
|
|
| self.self_attn_cfg = self_attn_cfg |
| self.cross_attn_cfg = cross_attn_cfg |
| if 'batch_first' not in self.self_attn_cfg: |
| self.self_attn_cfg['batch_first'] = True |
| else: |
| assert self.self_attn_cfg['batch_first'] is True, 'First \ |
| dimension of all DETRs in mmdet is `batch`, \ |
| please set `batch_first` flag.' |
|
|
| if 'batch_first' not in self.cross_attn_cfg: |
| self.cross_attn_cfg['batch_first'] = True |
| else: |
| assert self.cross_attn_cfg['batch_first'] is True, 'First \ |
| dimension of all DETRs in mmdet is `batch`, \ |
| please set `batch_first` flag.' |
|
|
| self.ffn_cfg = ffn_cfg |
| self.norm_cfg = norm_cfg |
| self._init_layers() |
|
|
| def _init_layers(self) -> None: |
| """Initialize self-attention, FFN, and normalization.""" |
| self.self_attn = MultiheadAttention(**self.self_attn_cfg) |
| self.cross_attn = MultiheadAttention(**self.cross_attn_cfg) |
| self.embed_dims = self.self_attn.embed_dims |
| self.ffn = FFN(**self.ffn_cfg) |
| norms_list = [ |
| build_norm_layer(self.norm_cfg, self.embed_dims)[1] |
| for _ in range(3) |
| ] |
| self.norms = ModuleList(norms_list) |
|
|
| def forward(self, |
| query: Tensor, |
| key: Tensor = None, |
| value: Tensor = None, |
| query_pos: Tensor = None, |
| key_pos: Tensor = None, |
| self_attn_mask: Tensor = None, |
| cross_attn_mask: Tensor = None, |
| key_padding_mask: Tensor = None, |
| **kwargs) -> Tensor: |
| """ |
| Args: |
| query (Tensor): The input query, has shape (bs, num_queries, dim). |
| key (Tensor, optional): The input key, has shape (bs, num_keys, |
| dim). If `None`, the `query` will be used. Defaults to `None`. |
| value (Tensor, optional): The input value, has the same shape as |
| `key`, as in `nn.MultiheadAttention.forward`. If `None`, the |
| `key` will be used. Defaults to `None`. |
| query_pos (Tensor, optional): The positional encoding for `query`, |
| has the same shape as `query`. If not `None`, it will be added |
| to `query` before forward function. Defaults to `None`. |
| key_pos (Tensor, optional): The positional encoding for `key`, has |
| the same shape as `key`. If not `None`, it will be added to |
| `key` before forward function. If None, and `query_pos` has the |
| same shape as `key`, then `query_pos` will be used for |
| `key_pos`. Defaults to None. |
| self_attn_mask (Tensor, optional): ByteTensor mask, has shape |
| (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. |
| Defaults to None. |
| cross_attn_mask (Tensor, optional): ByteTensor mask, has shape |
| (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. |
| Defaults to None. |
| key_padding_mask (Tensor, optional): The `key_padding_mask` of |
| `self_attn` input. ByteTensor, has shape (bs, num_value). |
| Defaults to None. |
| |
| Returns: |
| Tensor: forwarded results, has shape (bs, num_queries, dim). |
| """ |
|
|
| query = self.self_attn( |
| query=query, |
| key=query, |
| value=query, |
| query_pos=query_pos, |
| key_pos=query_pos, |
| attn_mask=self_attn_mask, |
| **kwargs) |
| query = self.norms[0](query) |
| query = self.cross_attn( |
| query=query, |
| key=key, |
| value=value, |
| query_pos=query_pos, |
| key_pos=key_pos, |
| attn_mask=cross_attn_mask, |
| key_padding_mask=key_padding_mask, |
| **kwargs) |
| query = self.norms[1](query) |
| query = self.ffn(query) |
| query = self.norms[2](query) |
|
|
| return query |
|
|