# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
import torch.nn.functional as F
from mmcv import ConfigDict, deprecated_api_warning
from mmcv.cnn.bricks.wrappers import Linear
from mmcv.cnn.bricks.activation import build_activation_layer
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
from mmdet3d_plugin.uniad.custom_modules.peft import (LoRALinear, ZeroAdapter, LoRACLAdapter, LoRAMoECLAdapter, MOELoRALinear,
finetuning_detach, frozen_grad, peft_wrapper_forward, lora_wrapper)
from mmdet3d_plugin.uniad.custom_modules.custom_mha_function import custom_multi_head_attention_forward
def build_positional_encoding(cfg, default_args=None):
"""Builder for Position Encoding."""
return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
def build_attention(cfg, default_args=None):
"""Builder for attention."""
return build_from_cfg(cfg, ATTENTION, default_args)
def build_feedforward_network(cfg, default_args=None):
"""Builder for feed-forward network (FFN)."""
return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
def build_transformer_layer(cfg, default_args=None):
"""Builder for transformer layer."""
return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
def build_transformer_layer_sequence(cfg, default_args=None):
"""Builder for transformer encoder and transformer decoder."""
return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
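# A minimal builder sketch (illustrative values only): each helper above is a
# thin wrapper around mmcv's ``build_from_cfg``, so any module registered below
# can be created from a plain config dict, e.g.
#
#   attn = build_attention(dict(type='MultiheadAttention',
#                               embed_dims=256, num_heads=8))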
@ATTENTION.register_module(force=True)
class MultiheadAttention(BaseModule):
"""A wrapper for ``torch.nn.MultiheadAttention``.
This module implements MultiheadAttention with identity connection,
and positional encoding is also passed as input.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Default: 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
        batch_first (bool): When True, the Key, Query and Value tensors are of
            shape (batch, n, embed_dim); otherwise (n, batch, embed_dim).
            Default to False.
        with_cp (bool): Whether to use gradient checkpointing in the attention
            forward. Default: False.
        use_lora (bool): Whether to freeze the wrapped attention and attach
            LoRA branches to the q/k/v and output projections. Default: False.
        multi_lora_task (bool): Whether to use task-aware LoRA adapters
            (``LoRACLAdapter``) for the key/value projections. Default: False.
        lora_rank (int): Rank of the LoRA decomposition. Default: 16.
        num_task (int): Number of tasks for the task-aware / MoE adapters.
            Default: 6.
        moe_lora (bool): Whether to use MoE-style LoRA (``MOELoRALinear``)
            projections. Default: False.
    """
def __init__(self,
embed_dims,
num_heads,
attn_drop=0.,
proj_drop=0.,
dropout_layer=dict(type='Dropout', drop_prob=0.),
init_cfg=None,
batch_first=False,
with_cp=False,
use_lora=False,
multi_lora_task=False,
lora_rank=16,
num_task=6,
moe_lora=False,
**kwargs):
super(MultiheadAttention, self).__init__(init_cfg)
if 'dropout' in kwargs:
            warnings.warn('The argument `dropout` in MultiheadAttention '
                          'has been deprecated; now you can separately '
                          'set `attn_drop`(float), `proj_drop`(float), '
                          'and `dropout_layer`(dict).')
attn_drop = kwargs['dropout']
dropout_layer['drop_prob'] = kwargs.pop('dropout')
self.embed_dims = embed_dims
self.num_heads = num_heads
self.batch_first = batch_first
self.moe_lora = moe_lora
self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
**kwargs)
self.use_lora = use_lora
self.multi_lora_task = multi_lora_task
self.proj_drop = nn.Dropout(proj_drop)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
self.with_cp = with_cp
if self.use_lora:
            # Freeze all parameters of the wrapped attention; only the LoRA
            # branches stay trainable.
for param in self.attn.parameters():
param.requires_grad = False
if self.multi_lora_task:
self.q_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)
self.k_lora = LoRACLAdapter(embed_dims, embed_dims, num_task=num_task, r=lora_rank)
self.v_lora = LoRACLAdapter(embed_dims, embed_dims, num_task=num_task, r=lora_rank)
elif moe_lora:
self.q_lora = MOELoRALinear(embed_dims, embed_dims, num_task=num_task, r=lora_rank)
self.k_lora = MOELoRALinear(embed_dims, embed_dims, num_task=num_task, r=lora_rank)
self.v_lora = MOELoRALinear(embed_dims, embed_dims, num_task=num_task, r=lora_rank)
else:
self.q_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)
self.k_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)
self.v_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)
if self.moe_lora:
self.out_lora = MOELoRALinear(embed_dims, embed_dims, num_task=num_task, r=lora_rank)
else:
self.out_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)
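            # Keep direct references to the frozen in-projection slices so the
            # base q/k/v projections can be applied manually (and detached) in
            # ``lora_attn_forward`` alongside the trainable LoRA branches.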
self.q_proj_weight = self.attn.in_proj_weight[:embed_dims, :]
self.k_proj_weight = self.attn.in_proj_weight[embed_dims:2*embed_dims, :]
self.v_proj_weight = self.attn.in_proj_weight[2*embed_dims:, :]
self.q_proj_bias = self.attn.in_proj_bias[:embed_dims]
self.k_proj_bias = self.attn.in_proj_bias[embed_dims:2*embed_dims]
self.v_proj_bias = self.attn.in_proj_bias[2*embed_dims:]
finetuning_detach(self)
def lora_attn_forward(
self,
query,
key,
value,
attn_mask,
key_padding_mask,
task_mask=None,
task_idx=None,
):
if self.attn.batch_first:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
lora_query = self.q_lora(query, task_mask, task_idx=task_idx)
lora_key = self.k_lora(key, task_mask, task_idx=task_idx)
lora_value = self.v_lora(value, task_mask, task_idx=task_idx)
        # The in-projection slices captured in ``__init__`` are plain tensor
        # attributes, so they are not moved by ``module.to(device)``; move them
        # next to the inputs before applying the frozen base projections.
        self.q_proj_weight = self.q_proj_weight.to(query.device)
        self.q_proj_bias = self.q_proj_bias.to(query.device)
        self.k_proj_weight = self.k_proj_weight.to(query.device)
        self.k_proj_bias = self.k_proj_bias.to(query.device)
        self.v_proj_weight = self.v_proj_weight.to(query.device)
        self.v_proj_bias = self.v_proj_bias.to(query.device)
new_query = F.linear(query, self.q_proj_weight, self.q_proj_bias)
new_key = F.linear(key, self.k_proj_weight, self.k_proj_bias)
new_value = F.linear(value, self.v_proj_weight, self.v_proj_bias)
query = new_query.detach() + lora_query
key = new_key.detach() + lora_key
value = new_value.detach() + lora_value
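        # The frozen base projections above are detached, so gradients flow only
        # through the LoRA branches. The already-projected q/k/v are then fed
        # straight into the attention kernel (``use_separate_proj_weight`` /
        # ``use_direct_input``), while the output projection reuses the frozen
        # ``out_proj`` of the wrapped ``nn.MultiheadAttention``.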
attn_output, attn_output_weights = custom_multi_head_attention_forward(
query, key, value, self.attn.embed_dim, self.attn.num_heads,
in_proj_weight=None, # Projections are already done manually
in_proj_bias=None, # Bias is not used
bias_k=None, bias_v=None, add_zero_attn=self.attn.add_zero_attn,
dropout_p=self.attn.dropout,
out_proj_weight=self.attn.out_proj.weight, out_proj_bias=self.attn.out_proj.bias,
training=self.attn.training,
key_padding_mask=key_padding_mask, need_weights=True,
use_separate_proj_weight=True,
use_direct_input=True,
attn_mask=attn_mask,
)
if self.attn.batch_first:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
@deprecated_api_warning({'residual': 'identity'},
cls_name='MultiheadAttention')
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None,
task_mask=None,
forward_origin=False,
task_idx=None,
**kwargs):
"""Forward function for `MultiheadAttention`.
**kwargs allow passing a more general data flow when combining
with other operations in `transformerlayer`.
Args:
query (Tensor): The input query with shape [num_queries, bs,
embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
If None, the ``query`` will be used. Defaults to None.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Defaults to None.
If None, the `key` will be used.
identity (Tensor): This tensor, with the same shape as x,
will be used for the identity link.
If None, `x` will be used. Defaults to None.
query_pos (Tensor): The positional encoding for query, with
the same shape as `x`. If not None, it will
be added to `x` before forward function. Defaults to None.
key_pos (Tensor): The positional encoding for `key`, with the
same shape as `key`. Defaults to None. If not None, it will
be added to `key` before forward function. If None, and
`query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`. Defaults to None.
attn_mask (Tensor): ByteTensor mask with shape [num_queries,
num_keys]. Same in `nn.MultiheadAttention.forward`.
Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.
            task_mask (Tensor, optional): Task mask passed to the task-aware
                LoRA adapters. Defaults to None.
            forward_origin (bool): If True, use the original attention forward
                instead of ``lora_attn_forward``. Defaults to False.
            task_idx (int, optional): Index of the current task for the
                LoRA / MoE-LoRA adapters. Defaults to None.
Returns:
Tensor: forwarded results with shape
[num_queries, bs, embed_dims]
if self.batch_first is False, else
                [bs, num_queries, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
        # Because the dataflow ('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch, embed_dims),
        # we adjust the shape from batch_first (batch, num_query, embed_dims)
        # to (num_query, batch, embed_dims), and recover ``attn_output``
        # back to batch_first afterwards.
if self.batch_first:
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
if self.with_cp:
out = cp.checkpoint(self.attn, use_reentrant=False, query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]
else:
            if self.use_lora and not forward_origin:
out = self.lora_attn_forward(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
task_mask=task_mask,
task_idx=task_idx)[0]
else:
out = self.attn(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]
if self.batch_first:
out = out.transpose(0, 1)
if self.use_lora:
out = out + self.out_lora(out, i=task_idx)
return identity + self.dropout_layer(self.proj_drop(out))
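# Usage sketch for the LoRA-enabled attention (illustrative values only; the
# exact task_idx / task_mask semantics come from the LoRA adapters in
# ``custom_modules.peft``):
#
#   attn = MultiheadAttention(embed_dims=256, num_heads=8,
#                             use_lora=True, lora_rank=16)
#   x = torch.randn(100, 2, 256)      # (num_queries, bs, embed_dims)
#   out = attn(query=x, task_idx=0)   # frozen base projections + LoRA updates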
# Force-register to override mmcv's built-in FFN implementation.
@FEEDFORWARD_NETWORK.register_module(force=True)
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with identity connection.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`. Defaults: 256.
feedforward_channels (int): The hidden dimension of FFNs.
Defaults: 1024.
num_fcs (int, optional): The number of fully-connected layers in
FFNs. Default: 2.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
add_identity (bool, optional): Whether to add the
identity connection. Default: `True`.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        use_lora (bool): Whether to freeze the original FFN and attach a LoRA
            branch. Default: False.
        lora_rank (int): Rank of the LoRA decomposition. Default: 16.
        use_adapter (bool): Whether to use the LoRA-MoE adapter
            (``LoRAMoECLAdapter``) instead of per-layer LoRA wrappers.
            Requires ``use_lora=True``. Default: False.
        adapter_num (int): Number of tasks for the adapter. Default: 6.
        lora_moe (bool): Whether to use MoE-style LoRA (``MOELoRALinear``) for
            the wrapped linear layers. Default: False.
        num_task (int): Number of tasks for the MoE LoRA wrappers. Default: 6.
    """
@deprecated_api_warning(
{
'dropout': 'ffn_drop',
'add_residual': 'add_identity'
},
cls_name='FFN')
def __init__(self,
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.,
dropout_layer=None,
add_identity=True,
init_cfg=None,
use_lora=False,
lora_rank=16,
use_adapter=False,
                 adapter_num=6,
lora_moe=False,
num_task=6,
**kwargs):
super(FFN, self).__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
self.lora_moe = lora_moe
layers = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
Sequential(
Linear(in_channels, feedforward_channels), self.activate,
nn.Dropout(ffn_drop)))
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
layers.append(nn.Dropout(ffn_drop))
self.use_lora = use_lora
self.use_adapter = use_adapter
        if use_adapter:
            assert use_lora, 'use_adapter=True requires use_lora=True.'
self.layers = Sequential(*layers)
if self.use_lora:
if self.use_adapter:
                self.lora_layers = LoRAMoECLAdapter(
                    embed_dims, feedforward_channels, embed_dims,
                    num_task=adapter_num, r=lora_rank, dropout=ffn_drop)
else:
lora_layer = MOELoRALinear if self.lora_moe else LoRALinear
                self.lora_layers = lora_wrapper(self.layers, LoraLayer=lora_layer,
                                                rank=lora_rank, dropout=0.0,
                                                num_task=num_task)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity
if self.use_lora:
self.layers = frozen_grad(self.layers)
finetuning_detach(self)
@deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None, forward_origin=False, task_idx=None, **kwargs):
        """Forward function for `FFN`.
        The function adds ``identity`` (or ``x`` if ``identity`` is None) to the
        output tensor when ``add_identity`` is True.
        """
org_x = x.clone()
if (not self.use_lora) or forward_origin:
out = self.layers(x)
else:
if self.use_adapter:
out = self.layers(x).detach() + self.lora_layers(x)
else:
out = peft_wrapper_forward(x, self.layers, self.lora_layers, task_idx=task_idx)
if not self.add_identity:
return self.dropout_layer(out)
if identity is None:
identity = org_x
return identity + self.dropout_layer(out)
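# Usage sketch for the LoRA-enabled FFN (illustrative values): the original MLP
# stays frozen and only the LoRA branch receives gradients, while
# ``forward_origin=True`` bypasses the LoRA branch entirely.
#
#   ffn = FFN(embed_dims=256, feedforward_channels=1024,
#             use_lora=True, lora_rank=16)
#   y = ffn(torch.randn(2, 100, 256), task_idx=0)
#   y_frozen = ffn(torch.randn(2, 100, 256), forward_origin=True)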
@TRANSFORMER_LAYER.register_module(force=True)
class BaseTransformerLayer(BaseModule):
"""Base `TransformerLayer` for vision transformer.
It can be built from `mmcv.ConfigDict` and support more flexible
customization, for example, using any number of `FFN or LN ` and
use different kinds of `attention` by specifying a list of `ConfigDict`
named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
More details about the `prenorm`: `On Layer Normalization in the
Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
Args:
attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for `self_attention` or `cross_attention` modules,
The order of the configs in the list should be consistent with
corresponding attentions in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config. Default: None.
ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for FFN, The order of the configs in the list should be
consistent with corresponding ffn in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Supports `prenorm` when you specify the first element as
            `norm`. Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
        batch_first (bool): Whether the Key, Query and Value tensors are of
            shape (batch, n, embed_dim) or (n, batch, embed_dim).
            Default to False.
        with_cp (bool): Whether to apply gradient checkpointing to the FFNs.
            Default: False.
        ffn_use_lora (bool): Whether the FFNs carry LoRA branches; when False
            the original FFN parameters are frozen. Default: False.
        ffn_lora_rank (int): Rank of the FFN LoRA decomposition. Default: 16.
        ffn_use_adapter (bool): Whether the FFNs use the LoRA-MoE adapter
            instead of per-layer LoRA wrappers. Default: False.
        ffn_adapter_num (int): Number of tasks for the FFN adapter. Default: 6.
        moe_lora (bool): Whether the FFNs use MoE-style LoRA. Default: False.
        num_task (int): Number of tasks for the MoE / task-aware adapters.
            Default: 6.
    """
def __init__(self,
attn_cfgs=None,
with_cp=False,
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='ReLU', inplace=True),
),
operation_order=None,
norm_cfg=dict(type='LN'),
init_cfg=None,
batch_first=False,
use_lora=False,
lora_rank=16,
ffn_use_lora=False,
ffn_lora_rank=16,
ffn_use_adapter=False,
ffn_adapter_num=6,
moe_lora=False,
num_task=6,
**kwargs):
deprecated_args = dict(
feedforward_channels='feedforward_channels',
ffn_dropout='ffn_drop',
ffn_num_fcs='num_fcs')
for ori_name, new_name in deprecated_args.items():
if ori_name in kwargs:
warnings.warn(
                    f'The argument `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated; now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`.')
ffn_cfgs[new_name] = kwargs[ori_name]
super(BaseTransformerLayer, self).__init__(init_cfg)
self.batch_first = batch_first
        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should only contain operations ' \
            f"from {['self_attn', 'norm', 'ffn', 'cross_attn']}"
num_attn = operation_order.count('self_attn') + operation_order.count(
'cross_attn')
attn_use_lora = False
if isinstance(attn_cfgs, dict):
if 'use_lora' in attn_cfgs:
attn_use_lora = attn_cfgs['use_lora']
else:
attn_use_lora = False
if isinstance(attn_cfgs, dict):
attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
else:
            assert num_attn == len(attn_cfgs), \
                f'The length of attn_cfgs {len(attn_cfgs)} is not consistent ' \
                f'with the number of attentions ({num_attn}) ' \
                f'in operation_order {operation_order}.'
self.num_attn = num_attn
self.operation_order = operation_order
self.norm_cfg = norm_cfg
self.pre_norm = operation_order[0] == 'norm'
self.attentions = ModuleList()
self.use_lora = use_lora
index = 0
have_use_lora = False
for attn_cfg in attn_cfgs:
if 'use_lora' in attn_cfg:
if attn_cfg['use_lora'] == True:
have_use_lora = True
for operation_name in operation_order:
if operation_name in ['self_attn', 'cross_attn']:
if 'batch_first' in attn_cfgs[index]:
assert self.batch_first == attn_cfgs[index]['batch_first']
else:
attn_cfgs[index]['batch_first'] = self.batch_first
attention = build_attention(attn_cfgs[index])
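                # When any attention in this layer carries a LoRA branch, freeze
                # the attentions that do not explicitly enable ``use_lora``, so
                # that only the LoRA parameters remain trainable.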
if have_use_lora:
if 'use_lora' not in attn_cfgs[index].keys():
for param in attention.parameters():
param.requires_grad = False
else:
if attn_cfgs[index]['use_lora']==False:
for param in attention.parameters():
param.requires_grad = False
# Some custom attentions used as `self_attn`
# or `cross_attn` can have different behavior.
attention.operation_name = operation_name
self.attentions.append(attention)
index += 1
self.embed_dims = self.attentions[0].embed_dims
self.ffns = ModuleList()
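        # Propagate the FFN-level PEFT switches into the FFN config; when LoRA
        # is disabled for the FFN, the whole FFN is frozen further below.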
ffn_cfgs['use_lora'] = ffn_use_lora
ffn_cfgs['lora_rank'] = ffn_lora_rank
ffn_cfgs['use_adapter'] = ffn_use_adapter
ffn_cfgs['adapter_num'] = ffn_adapter_num
ffn_cfgs['moe_lora'] = moe_lora
ffn_cfgs['num_task'] = num_task
        # Freeze the original FFN parameters unless the FFN carries a LoRA branch.
        self.freeze_ffn = not ffn_use_lora
num_ffns = operation_order.count('ffn')
if isinstance(ffn_cfgs, dict):
ffn_cfgs = ConfigDict(ffn_cfgs)
if isinstance(ffn_cfgs, dict):
ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
assert len(ffn_cfgs) == num_ffns
for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
else:
assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
self.ffns.append(
build_feedforward_network(ffn_cfgs[ffn_index],
dict(type='FFN')))
self.norms = ModuleList()
num_norms = operation_order.count('norm')
for _ in range(num_norms):
self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
self.with_cp = with_cp
if self.freeze_ffn:
for param in self.ffns.parameters():
param.requires_grad = False
def forward(self,
query,
key=None,
value=None,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
task_mask=None,
forward_origin=False,
task_idx=None,
**kwargs):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
                [bs, num_queries, embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
attn_masks (List[Tensor] | None): 2D Tensor used in
calculation of corresponding attention. The length of
it should equal to the number of `attention` in
`operation_order`. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in `self_attn` layer.
Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.
            task_mask (Tensor, optional): Task mask passed to the task-aware
                LoRA adapters in the attentions. Default: None.
            forward_origin (bool): If True, the attentions and FFNs run their
                original (frozen) forward instead of the LoRA branches.
                Default: False.
            task_idx (int, optional): Index of the current task for the
                LoRA / MoE-LoRA adapters. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index = 0
attn_index = 0
ffn_index = 0
identity = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
elif isinstance(attn_masks, torch.Tensor):
attn_masks = [
copy.deepcopy(attn_masks) for _ in range(self.num_attn)
]
warnings.warn(f'Use same attn_mask in all attentions in '
f'{self.__class__.__name__} ')
else:
assert len(attn_masks) == self.num_attn, f'The length of ' \
f'attn_masks {len(attn_masks)} must be equal ' \
f'to the number of attention in ' \
f'operation_order {self.num_attn}'
for layer in self.operation_order:
if layer == 'self_attn':
temp_key = temp_value = query
query = self.attentions[attn_index](
query,
temp_key,
temp_value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
forward_origin=forward_origin,
task_idx=task_idx,
**kwargs)
attn_index += 1
identity = query
elif layer == 'norm':
query = self.norms[norm_index](query)
norm_index += 1
elif layer == 'cross_attn':
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
task_mask=task_mask,
forward_origin=forward_origin,
task_idx=task_idx,
**kwargs)
attn_index += 1
identity = query
elif layer == 'ffn':
if self.with_cp:
query = cp.checkpoint(self.ffns[ffn_index], query)
else:
query = self.ffns[ffn_index](
query, identity if self.pre_norm else None,
forward_origin=forward_origin,
task_idx=task_idx,)
ffn_index += 1
return query
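# Config sketch (illustrative): LoRA can be switched on per sub-module, both in
# the attention configs and via the layer-level FFN arguments, e.g.
#
#   layer = build_transformer_layer(dict(
#       type='BaseTransformerLayer',
#       attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8,
#                      use_lora=True, lora_rank=16),
#       ffn_use_lora=True, ffn_lora_rank=16,
#       operation_order=('self_attn', 'norm', 'ffn', 'norm')))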
@TRANSFORMER_LAYER_SEQUENCE.register_module(force=True)
class TransformerLayerSequence(BaseModule):
"""Base class for TransformerEncoder and TransformerDecoder in vision
transformer.
As base-class of Encoder and Decoder in vision transformer.
Support customization such as specifying different kind
of `transformer_layer` in `transformer_coder`.
Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of the transformer layers
            in TransformerCoder. If it is an obj:`mmcv.ConfigDict`,
            it would be repeated `num_layers` times to a
            list[`mmcv.ConfigDict`]. Default: None.
num_layers (int): The number of `TransformerLayer`. Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
super(TransformerLayerSequence, self).__init__(init_cfg)
if isinstance(transformerlayers, dict):
transformerlayers = [
copy.deepcopy(transformerlayers) for _ in range(num_layers)
]
else:
assert isinstance(transformerlayers, list) and \
len(transformerlayers) == num_layers
self.num_layers = num_layers
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(build_transformer_layer(transformerlayers[i]))
self.embed_dims = self.layers[0].embed_dims
self.pre_norm = self.layers[0].pre_norm
def forward(self,
query,
key,
value,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
task_mask=None,
forward_origin=False,
task_idx=None,
**kwargs):
"""Forward function for `TransformerCoder`.
Args:
query (Tensor): Input query with shape
`(num_queries, bs, embed_dims)`.
key (Tensor): The key tensor with shape
`(num_keys, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_keys, bs, embed_dims)`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
attn_masks (List[Tensor], optional): Each element is 2D Tensor
which is used in calculation of corresponding attention in
operation_order. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in self-attention
Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.
Returns:
Tensor: results with shape [num_queries, bs, embed_dims].
"""
for layer in self.layers:
query = layer(
query,
key,
value,
query_pos=query_pos,
key_pos=key_pos,
attn_masks=attn_masks,
query_key_padding_mask=query_key_padding_mask,
key_padding_mask=key_padding_mask,
task_mask=task_mask,
forward_origin=forward_origin,
task_idx=task_idx,
**kwargs)
return query
def transformer_test():
    # Build one LoRA-enabled attention and one LoRA-enabled FFN and inspect
    # which parameters remain trainable.
    attn_model = MultiheadAttention(64, 4, use_lora=True)
    ffn_model = FFN(use_lora=True)
    print("Attention structure after attaching LoRA layers:\n", attn_model)
    print("FFN structure after attaching LoRA layers:\n", ffn_model)
    for name, param in ffn_model.named_parameters():
        print(name, param.shape, param.requires_grad)
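    # Hedged smoke check on CPU through the frozen (non-LoRA) path only; the
    # LoRA branches expect the task_idx / task_mask handling of the training code.
    y = ffn_model(torch.randn(2, 10, 256), forward_origin=True)
    print('FFN output shape (forward_origin=True):', y.shape)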
if __name__ == '__main__':
transformer_test()