| | |
| | import copy |
| | import warnings |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from annotator.mmpkg.mmcv import ConfigDict, deprecated_api_warning |
| | from annotator.mmpkg.mmcv.cnn import Linear, build_activation_layer, build_norm_layer |
| | from annotator.mmpkg.mmcv.runner.base_module import BaseModule, ModuleList, Sequential |
| | from annotator.mmpkg.mmcv.utils import build_from_cfg |
| | from .drop import build_dropout |
| | from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING, |
| | TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE) |
| |
|
| | |
| | try: |
| | from annotator.mmpkg.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention |
| | warnings.warn( |
| | ImportWarning( |
| | '``MultiScaleDeformableAttention`` has been moved to ' |
| | '``mmcv.ops.multi_scale_deform_attn``, please change original path ' |
| | '``from annotator.mmpkg.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' |
| | 'to ``from annotator.mmpkg.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' |
| | )) |
| |
|
| | except ImportError: |
| | warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from ' |
| | '``mmcv.ops.multi_scale_deform_attn``, ' |
| | 'You should install ``mmcv-full`` if you need this module. ') |
| |
|
| |
|
| | def build_positional_encoding(cfg, default_args=None): |
| | """Builder for Position Encoding.""" |
| | return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args) |
| |
|
| |
|
| | def build_attention(cfg, default_args=None): |
| | """Builder for attention.""" |
| | return build_from_cfg(cfg, ATTENTION, default_args) |
| |
|
| |
|
| | def build_feedforward_network(cfg, default_args=None): |
| | """Builder for feed-forward network (FFN).""" |
| | return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args) |
| |
|
| |
|
| | def build_transformer_layer(cfg, default_args=None): |
| | """Builder for transformer layer.""" |
| | return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args) |
| |
|
| |
|
| | def build_transformer_layer_sequence(cfg, default_args=None): |
| | """Builder for transformer encoder and transformer decoder.""" |
| | return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args) |
| |
|
| |
|
| | @ATTENTION.register_module() |
| | class MultiheadAttention(BaseModule): |
| | """A wrapper for ``torch.nn.MultiheadAttention``. |
| | |
| | This module implements MultiheadAttention with identity connection, |
| | and positional encoding is also passed as input. |
| | |
| | Args: |
| | embed_dims (int): The embedding dimension. |
| | num_heads (int): Parallel attention heads. |
| | attn_drop (float): A Dropout layer on attn_output_weights. |
| | Default: 0.0. |
| | proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. |
| | Default: 0.0. |
| | dropout_layer (obj:`ConfigDict`): The dropout_layer used |
| | when adding the shortcut. |
| | init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. |
| | Default: None. |
| | batch_first (bool): When it is True, Key, Query and Value are shape of |
| | (batch, n, embed_dim), otherwise (n, batch, embed_dim). |
| | Default to False. |
| | """ |
| |
|
| | def __init__(self, |
| | embed_dims, |
| | num_heads, |
| | attn_drop=0., |
| | proj_drop=0., |
| | dropout_layer=dict(type='Dropout', drop_prob=0.), |
| | init_cfg=None, |
| | batch_first=False, |
| | **kwargs): |
| | super(MultiheadAttention, self).__init__(init_cfg) |
| | if 'dropout' in kwargs: |
| | warnings.warn('The arguments `dropout` in MultiheadAttention ' |
| | 'has been deprecated, now you can separately ' |
| | 'set `attn_drop`(float), proj_drop(float), ' |
| | 'and `dropout_layer`(dict) ') |
| | attn_drop = kwargs['dropout'] |
| | dropout_layer['drop_prob'] = kwargs.pop('dropout') |
| |
|
| | self.embed_dims = embed_dims |
| | self.num_heads = num_heads |
| | self.batch_first = batch_first |
| |
|
| | self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, |
| | **kwargs) |
| |
|
| | self.proj_drop = nn.Dropout(proj_drop) |
| | self.dropout_layer = build_dropout( |
| | dropout_layer) if dropout_layer else nn.Identity() |
| |
|
| | @deprecated_api_warning({'residual': 'identity'}, |
| | cls_name='MultiheadAttention') |
| | def forward(self, |
| | query, |
| | key=None, |
| | value=None, |
| | identity=None, |
| | query_pos=None, |
| | key_pos=None, |
| | attn_mask=None, |
| | key_padding_mask=None, |
| | **kwargs): |
| | """Forward function for `MultiheadAttention`. |
| | |
| | **kwargs allow passing a more general data flow when combining |
| | with other operations in `transformerlayer`. |
| | |
| | Args: |
| | query (Tensor): The input query with shape [num_queries, bs, |
| | embed_dims] if self.batch_first is False, else |
| | [bs, num_queries embed_dims]. |
| | key (Tensor): The key tensor with shape [num_keys, bs, |
| | embed_dims] if self.batch_first is False, else |
| | [bs, num_keys, embed_dims] . |
| | If None, the ``query`` will be used. Defaults to None. |
| | value (Tensor): The value tensor with same shape as `key`. |
| | Same in `nn.MultiheadAttention.forward`. Defaults to None. |
| | If None, the `key` will be used. |
| | identity (Tensor): This tensor, with the same shape as x, |
| | will be used for the identity link. |
| | If None, `x` will be used. Defaults to None. |
| | query_pos (Tensor): The positional encoding for query, with |
| | the same shape as `x`. If not None, it will |
| | be added to `x` before forward function. Defaults to None. |
| | key_pos (Tensor): The positional encoding for `key`, with the |
| | same shape as `key`. Defaults to None. If not None, it will |
| | be added to `key` before forward function. If None, and |
| | `query_pos` has the same shape as `key`, then `query_pos` |
| | will be used for `key_pos`. Defaults to None. |
| | attn_mask (Tensor): ByteTensor mask with shape [num_queries, |
| | num_keys]. Same in `nn.MultiheadAttention.forward`. |
| | Defaults to None. |
| | key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. |
| | Defaults to None. |
| | |
| | Returns: |
| | Tensor: forwarded results with shape |
| | [num_queries, bs, embed_dims] |
| | if self.batch_first is False, else |
| | [bs, num_queries embed_dims]. |
| | """ |
| |
|
| | if key is None: |
| | key = query |
| | if value is None: |
| | value = key |
| | if identity is None: |
| | identity = query |
| | if key_pos is None: |
| | if query_pos is not None: |
| | |
| | if query_pos.shape == key.shape: |
| | key_pos = query_pos |
| | else: |
| | warnings.warn(f'position encoding of key is' |
| | f'missing in {self.__class__.__name__}.') |
| | if query_pos is not None: |
| | query = query + query_pos |
| | if key_pos is not None: |
| | key = key + key_pos |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | if self.batch_first: |
| | query = query.transpose(0, 1) |
| | key = key.transpose(0, 1) |
| | value = value.transpose(0, 1) |
| |
|
| | out = self.attn( |
| | query=query, |
| | key=key, |
| | value=value, |
| | attn_mask=attn_mask, |
| | key_padding_mask=key_padding_mask)[0] |
| |
|
| | if self.batch_first: |
| | out = out.transpose(0, 1) |
| |
|
| | return identity + self.dropout_layer(self.proj_drop(out)) |
| |
|
| |
|
| | @FEEDFORWARD_NETWORK.register_module() |
| | class FFN(BaseModule): |
| | """Implements feed-forward networks (FFNs) with identity connection. |
| | |
| | Args: |
| | embed_dims (int): The feature dimension. Same as |
| | `MultiheadAttention`. Defaults: 256. |
| | feedforward_channels (int): The hidden dimension of FFNs. |
| | Defaults: 1024. |
| | num_fcs (int, optional): The number of fully-connected layers in |
| | FFNs. Default: 2. |
| | act_cfg (dict, optional): The activation config for FFNs. |
| | Default: dict(type='ReLU') |
| | ffn_drop (float, optional): Probability of an element to be |
| | zeroed in FFN. Default 0.0. |
| | add_identity (bool, optional): Whether to add the |
| | identity connection. Default: `True`. |
| | dropout_layer (obj:`ConfigDict`): The dropout_layer used |
| | when adding the shortcut. |
| | init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. |
| | Default: None. |
| | """ |
| |
|
| | @deprecated_api_warning( |
| | { |
| | 'dropout': 'ffn_drop', |
| | 'add_residual': 'add_identity' |
| | }, |
| | cls_name='FFN') |
| | def __init__(self, |
| | embed_dims=256, |
| | feedforward_channels=1024, |
| | num_fcs=2, |
| | act_cfg=dict(type='ReLU', inplace=True), |
| | ffn_drop=0., |
| | dropout_layer=None, |
| | add_identity=True, |
| | init_cfg=None, |
| | **kwargs): |
| | super(FFN, self).__init__(init_cfg) |
| | assert num_fcs >= 2, 'num_fcs should be no less ' \ |
| | f'than 2. got {num_fcs}.' |
| | self.embed_dims = embed_dims |
| | self.feedforward_channels = feedforward_channels |
| | self.num_fcs = num_fcs |
| | self.act_cfg = act_cfg |
| | self.activate = build_activation_layer(act_cfg) |
| |
|
| | layers = [] |
| | in_channels = embed_dims |
| | for _ in range(num_fcs - 1): |
| | layers.append( |
| | Sequential( |
| | Linear(in_channels, feedforward_channels), self.activate, |
| | nn.Dropout(ffn_drop))) |
| | in_channels = feedforward_channels |
| | layers.append(Linear(feedforward_channels, embed_dims)) |
| | layers.append(nn.Dropout(ffn_drop)) |
| | self.layers = Sequential(*layers) |
| | self.dropout_layer = build_dropout( |
| | dropout_layer) if dropout_layer else torch.nn.Identity() |
| | self.add_identity = add_identity |
| |
|
| | @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN') |
| | def forward(self, x, identity=None): |
| | """Forward function for `FFN`. |
| | |
| | The function would add x to the output tensor if residue is None. |
| | """ |
| | out = self.layers(x) |
| | if not self.add_identity: |
| | return self.dropout_layer(out) |
| | if identity is None: |
| | identity = x |
| | return identity + self.dropout_layer(out) |
| |
|
| |
|
| | @TRANSFORMER_LAYER.register_module() |
| | class BaseTransformerLayer(BaseModule): |
| | """Base `TransformerLayer` for vision transformer. |
| | |
| | It can be built from `mmcv.ConfigDict` and support more flexible |
| | customization, for example, using any number of `FFN or LN ` and |
| | use different kinds of `attention` by specifying a list of `ConfigDict` |
| | named `attn_cfgs`. It is worth mentioning that it supports `prenorm` |
| | when you specifying `norm` as the first element of `operation_order`. |
| | More details about the `prenorm`: `On Layer Normalization in the |
| | Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ . |
| | |
| | Args: |
| | attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): |
| | Configs for `self_attention` or `cross_attention` modules, |
| | The order of the configs in the list should be consistent with |
| | corresponding attentions in operation_order. |
| | If it is a dict, all of the attention modules in operation_order |
| | will be built with this config. Default: None. |
| | ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): |
| | Configs for FFN, The order of the configs in the list should be |
| | consistent with corresponding ffn in operation_order. |
| | If it is a dict, all of the attention modules in operation_order |
| | will be built with this config. |
| | operation_order (tuple[str]): The execution order of operation |
| | in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). |
| | Support `prenorm` when you specifying first element as `norm`. |
| | Default:None. |
| | norm_cfg (dict): Config dict for normalization layer. |
| | Default: dict(type='LN'). |
| | init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. |
| | Default: None. |
| | batch_first (bool): Key, Query and Value are shape |
| | of (batch, n, embed_dim) |
| | or (n, batch, embed_dim). Default to False. |
| | """ |
| |
|
| | def __init__(self, |
| | attn_cfgs=None, |
| | ffn_cfgs=dict( |
| | type='FFN', |
| | embed_dims=256, |
| | feedforward_channels=1024, |
| | num_fcs=2, |
| | ffn_drop=0., |
| | act_cfg=dict(type='ReLU', inplace=True), |
| | ), |
| | operation_order=None, |
| | norm_cfg=dict(type='LN'), |
| | init_cfg=None, |
| | batch_first=False, |
| | **kwargs): |
| |
|
| | deprecated_args = dict( |
| | feedforward_channels='feedforward_channels', |
| | ffn_dropout='ffn_drop', |
| | ffn_num_fcs='num_fcs') |
| | for ori_name, new_name in deprecated_args.items(): |
| | if ori_name in kwargs: |
| | warnings.warn( |
| | f'The arguments `{ori_name}` in BaseTransformerLayer ' |
| | f'has been deprecated, now you should set `{new_name}` ' |
| | f'and other FFN related arguments ' |
| | f'to a dict named `ffn_cfgs`. ') |
| | ffn_cfgs[new_name] = kwargs[ori_name] |
| |
|
| | super(BaseTransformerLayer, self).__init__(init_cfg) |
| |
|
| | self.batch_first = batch_first |
| |
|
| | assert set(operation_order) & set( |
| | ['self_attn', 'norm', 'ffn', 'cross_attn']) == \ |
| | set(operation_order), f'The operation_order of' \ |
| | f' {self.__class__.__name__} should ' \ |
| | f'contains all four operation type ' \ |
| | f"{['self_attn', 'norm', 'ffn', 'cross_attn']}" |
| |
|
| | num_attn = operation_order.count('self_attn') + operation_order.count( |
| | 'cross_attn') |
| | if isinstance(attn_cfgs, dict): |
| | attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)] |
| | else: |
| | assert num_attn == len(attn_cfgs), f'The length ' \ |
| | f'of attn_cfg {num_attn} is ' \ |
| | f'not consistent with the number of attention' \ |
| | f'in operation_order {operation_order}.' |
| |
|
| | self.num_attn = num_attn |
| | self.operation_order = operation_order |
| | self.norm_cfg = norm_cfg |
| | self.pre_norm = operation_order[0] == 'norm' |
| | self.attentions = ModuleList() |
| |
|
| | index = 0 |
| | for operation_name in operation_order: |
| | if operation_name in ['self_attn', 'cross_attn']: |
| | if 'batch_first' in attn_cfgs[index]: |
| | assert self.batch_first == attn_cfgs[index]['batch_first'] |
| | else: |
| | attn_cfgs[index]['batch_first'] = self.batch_first |
| | attention = build_attention(attn_cfgs[index]) |
| | |
| | |
| | attention.operation_name = operation_name |
| | self.attentions.append(attention) |
| | index += 1 |
| |
|
| | self.embed_dims = self.attentions[0].embed_dims |
| |
|
| | self.ffns = ModuleList() |
| | num_ffns = operation_order.count('ffn') |
| | if isinstance(ffn_cfgs, dict): |
| | ffn_cfgs = ConfigDict(ffn_cfgs) |
| | if isinstance(ffn_cfgs, dict): |
| | ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)] |
| | assert len(ffn_cfgs) == num_ffns |
| | for ffn_index in range(num_ffns): |
| | if 'embed_dims' not in ffn_cfgs[ffn_index]: |
| | ffn_cfgs['embed_dims'] = self.embed_dims |
| | else: |
| | assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims |
| | self.ffns.append( |
| | build_feedforward_network(ffn_cfgs[ffn_index], |
| | dict(type='FFN'))) |
| |
|
| | self.norms = ModuleList() |
| | num_norms = operation_order.count('norm') |
| | for _ in range(num_norms): |
| | self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1]) |
| |
|
| | def forward(self, |
| | query, |
| | key=None, |
| | value=None, |
| | query_pos=None, |
| | key_pos=None, |
| | attn_masks=None, |
| | query_key_padding_mask=None, |
| | key_padding_mask=None, |
| | **kwargs): |
| | """Forward function for `TransformerDecoderLayer`. |
| | |
| | **kwargs contains some specific arguments of attentions. |
| | |
| | Args: |
| | query (Tensor): The input query with shape |
| | [num_queries, bs, embed_dims] if |
| | self.batch_first is False, else |
| | [bs, num_queries embed_dims]. |
| | key (Tensor): The key tensor with shape [num_keys, bs, |
| | embed_dims] if self.batch_first is False, else |
| | [bs, num_keys, embed_dims] . |
| | value (Tensor): The value tensor with same shape as `key`. |
| | query_pos (Tensor): The positional encoding for `query`. |
| | Default: None. |
| | key_pos (Tensor): The positional encoding for `key`. |
| | Default: None. |
| | attn_masks (List[Tensor] | None): 2D Tensor used in |
| | calculation of corresponding attention. The length of |
| | it should equal to the number of `attention` in |
| | `operation_order`. Default: None. |
| | query_key_padding_mask (Tensor): ByteTensor for `query`, with |
| | shape [bs, num_queries]. Only used in `self_attn` layer. |
| | Defaults to None. |
| | key_padding_mask (Tensor): ByteTensor for `query`, with |
| | shape [bs, num_keys]. Default: None. |
| | |
| | Returns: |
| | Tensor: forwarded results with shape [num_queries, bs, embed_dims]. |
| | """ |
| |
|
| | norm_index = 0 |
| | attn_index = 0 |
| | ffn_index = 0 |
| | identity = query |
| | if attn_masks is None: |
| | attn_masks = [None for _ in range(self.num_attn)] |
| | elif isinstance(attn_masks, torch.Tensor): |
| | attn_masks = [ |
| | copy.deepcopy(attn_masks) for _ in range(self.num_attn) |
| | ] |
| | warnings.warn(f'Use same attn_mask in all attentions in ' |
| | f'{self.__class__.__name__} ') |
| | else: |
| | assert len(attn_masks) == self.num_attn, f'The length of ' \ |
| | f'attn_masks {len(attn_masks)} must be equal ' \ |
| | f'to the number of attention in ' \ |
| | f'operation_order {self.num_attn}' |
| |
|
| | for layer in self.operation_order: |
| | if layer == 'self_attn': |
| | temp_key = temp_value = query |
| | query = self.attentions[attn_index]( |
| | query, |
| | temp_key, |
| | temp_value, |
| | identity if self.pre_norm else None, |
| | query_pos=query_pos, |
| | key_pos=query_pos, |
| | attn_mask=attn_masks[attn_index], |
| | key_padding_mask=query_key_padding_mask, |
| | **kwargs) |
| | attn_index += 1 |
| | identity = query |
| |
|
| | elif layer == 'norm': |
| | query = self.norms[norm_index](query) |
| | norm_index += 1 |
| |
|
| | elif layer == 'cross_attn': |
| | query = self.attentions[attn_index]( |
| | query, |
| | key, |
| | value, |
| | identity if self.pre_norm else None, |
| | query_pos=query_pos, |
| | key_pos=key_pos, |
| | attn_mask=attn_masks[attn_index], |
| | key_padding_mask=key_padding_mask, |
| | **kwargs) |
| | attn_index += 1 |
| | identity = query |
| |
|
| | elif layer == 'ffn': |
| | query = self.ffns[ffn_index]( |
| | query, identity if self.pre_norm else None) |
| | ffn_index += 1 |
| |
|
| | return query |
| |
|
| |
|
| | @TRANSFORMER_LAYER_SEQUENCE.register_module() |
| | class TransformerLayerSequence(BaseModule): |
| | """Base class for TransformerEncoder and TransformerDecoder in vision |
| | transformer. |
| | |
| | As base-class of Encoder and Decoder in vision transformer. |
| | Support customization such as specifying different kind |
| | of `transformer_layer` in `transformer_coder`. |
| | |
| | Args: |
| | transformerlayer (list[obj:`mmcv.ConfigDict`] | |
| | obj:`mmcv.ConfigDict`): Config of transformerlayer |
| | in TransformerCoder. If it is obj:`mmcv.ConfigDict`, |
| | it would be repeated `num_layer` times to a |
| | list[`mmcv.ConfigDict`]. Default: None. |
| | num_layers (int): The number of `TransformerLayer`. Default: None. |
| | init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. |
| | Default: None. |
| | """ |
| |
|
| | def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None): |
| | super(TransformerLayerSequence, self).__init__(init_cfg) |
| | if isinstance(transformerlayers, dict): |
| | transformerlayers = [ |
| | copy.deepcopy(transformerlayers) for _ in range(num_layers) |
| | ] |
| | else: |
| | assert isinstance(transformerlayers, list) and \ |
| | len(transformerlayers) == num_layers |
| | self.num_layers = num_layers |
| | self.layers = ModuleList() |
| | for i in range(num_layers): |
| | self.layers.append(build_transformer_layer(transformerlayers[i])) |
| | self.embed_dims = self.layers[0].embed_dims |
| | self.pre_norm = self.layers[0].pre_norm |
| |
|
| | def forward(self, |
| | query, |
| | key, |
| | value, |
| | query_pos=None, |
| | key_pos=None, |
| | attn_masks=None, |
| | query_key_padding_mask=None, |
| | key_padding_mask=None, |
| | **kwargs): |
| | """Forward function for `TransformerCoder`. |
| | |
| | Args: |
| | query (Tensor): Input query with shape |
| | `(num_queries, bs, embed_dims)`. |
| | key (Tensor): The key tensor with shape |
| | `(num_keys, bs, embed_dims)`. |
| | value (Tensor): The value tensor with shape |
| | `(num_keys, bs, embed_dims)`. |
| | query_pos (Tensor): The positional encoding for `query`. |
| | Default: None. |
| | key_pos (Tensor): The positional encoding for `key`. |
| | Default: None. |
| | attn_masks (List[Tensor], optional): Each element is 2D Tensor |
| | which is used in calculation of corresponding attention in |
| | operation_order. Default: None. |
| | query_key_padding_mask (Tensor): ByteTensor for `query`, with |
| | shape [bs, num_queries]. Only used in self-attention |
| | Default: None. |
| | key_padding_mask (Tensor): ByteTensor for `query`, with |
| | shape [bs, num_keys]. Default: None. |
| | |
| | Returns: |
| | Tensor: results with shape [num_queries, bs, embed_dims]. |
| | """ |
| | for layer in self.layers: |
| | query = layer( |
| | query, |
| | key, |
| | value, |
| | query_pos=query_pos, |
| | key_pos=key_pos, |
| | attn_masks=attn_masks, |
| | query_key_padding_mask=query_key_padding_mask, |
| | key_padding_mask=key_padding_mask, |
| | **kwargs) |
| | return query |
| |
|