# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
import torch.nn.functional as F
from mmcv import ConfigDict, deprecated_api_warning
from mmcv.cnn.bricks.wrappers import Linear
from mmcv.cnn.bricks.activation import build_activation_layer
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
# from mmcv.models.bricks import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK,
                                      POSITIONAL_ENCODING, TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)

from mmdet3d_plugin.uniad.custom_modules.peft import (
    LoRALinear, ZeroAdapter, LoRACLAdapter, LoRAMoECLAdapter, MOELoRALinear,
    finetuning_detach, frozen_grad, peft_wrapper_forward, lora_wrapper)
from mmdet3d_plugin.uniad.custom_modules.custom_mha_function import \
    custom_multi_head_attention_forward


def build_positional_encoding(cfg, default_args=None):
    """Builder for Position Encoding."""
    return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)


def build_attention(cfg, default_args=None):
    """Builder for attention."""
    return build_from_cfg(cfg, ATTENTION, default_args)


def build_feedforward_network(cfg, default_args=None):
    """Builder for feed-forward network (FFN)."""
    return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)


def build_transformer_layer(cfg, default_args=None):
    """Builder for transformer layer."""
    return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)


def build_transformer_layer_sequence(cfg, default_args=None):
    """Builder for transformer encoder and transformer decoder."""
    return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)


@ATTENTION.register_module(force=True)
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention`` with optional LoRA.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input. When ``use_lora`` is
    True, every parameter of the wrapped ``nn.MultiheadAttention`` is frozen
    and LoRA branches are added on the query/key/value inputs and on the
    attention output.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value are shape
            of (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default to False.
        with_cp (bool): Run the attention call under
            ``torch.utils.checkpoint``. Default: False.
        use_lora (bool): Freeze ``self.attn`` and add LoRA branches.
            Default: False.
        multi_lora_task (bool): Build the query branch with ``LoRALinear``
            and the key/value branches with ``LoRACLAdapter`` (which take
            ``num_task``). Checked before ``moe_lora`` for q/k/v.
            Default: False.
        lora_rank (int): ``r`` passed to every LoRA module. Default: 16.
        num_task (int): ``num_task`` passed to the task-aware LoRA modules.
            Default: 6.
        moe_lora (bool): Use ``MOELoRALinear`` for the q/k/v branches (unless
            ``multi_lora_task`` is set) and for the output branch.
            Default: False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 with_cp=False,
                 use_lora=False,
                 multi_lora_task=False,
                 lora_rank=16,
                 num_task=6,
                 moe_lora=False,
                 **kwargs):
        super(MultiheadAttention, self).__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn('The arguments `dropout` in MultiheadAttention '
                          'has been deprecated, now you can separately '
                          'set `attn_drop`(float), proj_drop(float), '
                          'and `dropout_layer`(dict) ')
            attn_drop = kwargs['dropout']
            # Edit a private copy: `dropout_layer` may be the shared default
            # dict, and mutating it would leak into every later instance.
            dropout_layer = copy.deepcopy(dropout_layer)
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first
        self.moe_lora = moe_lora

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)
        self.use_lora = use_lora
        self.multi_lora_task = multi_lora_task

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()
        self.with_cp = with_cp

        if self.use_lora:
            # Freeze the wrapped attention; only the LoRA branches are
            # left trainable by this constructor.
            for param in self.attn.parameters():
                param.requires_grad = False

            if self.multi_lora_task:
                self.q_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)
                self.k_lora = LoRACLAdapter(
                    embed_dims, embed_dims, num_task=num_task, r=lora_rank)
                self.v_lora = LoRACLAdapter(
                    embed_dims, embed_dims, num_task=num_task, r=lora_rank)
            elif moe_lora:
                self.q_lora = MOELoRALinear(
                    embed_dims, embed_dims, num_task=num_task, r=lora_rank)
                self.k_lora = MOELoRALinear(
                    embed_dims, embed_dims, num_task=num_task, r=lora_rank)
                self.v_lora = MOELoRALinear(
                    embed_dims, embed_dims, num_task=num_task, r=lora_rank)
            else:
                self.q_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)
                self.k_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)
                self.v_lora = LoRALinear(embed_dims, embed_dims, r=lora_rank)

            if self.moe_lora:
                self.out_lora = MOELoRALinear(
                    embed_dims, embed_dims, num_task=num_task, r=lora_rank)
            else:
                self.out_lora = LoRALinear(
                    embed_dims, embed_dims, r=lora_rank)

            # Kept only for backward compatibility with code that reads these
            # attributes. They are views of the parameter storage taken at
            # construction time and do NOT follow later device moves;
            # `lora_attn_forward` slices the live `self.attn` parameters.
            self.q_proj_weight = self.attn.in_proj_weight[:embed_dims, :]
            self.k_proj_weight = self.attn.in_proj_weight[
                embed_dims:2 * embed_dims, :]
            self.v_proj_weight = self.attn.in_proj_weight[2 * embed_dims:, :]
            self.q_proj_bias = self.attn.in_proj_bias[:embed_dims]
            self.k_proj_bias = self.attn.in_proj_bias[
                embed_dims:2 * embed_dims]
            self.v_proj_bias = self.attn.in_proj_bias[2 * embed_dims:]
            finetuning_detach(self)

    def lora_attn_forward(
        self,
        query,
        key,
        value,
        attn_mask,
        key_padding_mask,
        task_mask=None,
        task_idx=None,
    ):
        """Attention with the frozen in-projection plus LoRA branches.

        The frozen q/k/v projection of ``self.attn`` is computed without
        gradient tracking. Gradients w.r.t. the inputs therefore only flow
        through ``q_lora`` / ``k_lora`` / ``v_lora``. The summed projections
        are passed to ``custom_multi_head_attention_forward`` with
        ``in_proj_weight=None`` and ``use_direct_input=True``.

        Returns:
            tuple: ``(attn_output, attn_output_weights)`` as returned by
            ``custom_multi_head_attention_forward``. The output is transposed
            back to batch-first if ``self.attn.batch_first`` is set.
        """
        if self.attn.batch_first:
            query, key, value = [
                x.transpose(1, 0) for x in (query, key, value)
            ]

        lora_query = self.q_lora(query, task_mask, task_idx=task_idx)
        lora_key = self.k_lora(key, task_mask, task_idx=task_idx)
        lora_value = self.v_lora(value, task_mask, task_idx=task_idx)

        # Slice the packed in-projection on every call so the frozen weights
        # are always on the inputs' device and reflect the latest loaded
        # checkpoint (construction-time views go stale after `.cuda()`).
        q_weight, k_weight, v_weight = self.attn.in_proj_weight.chunk(3)
        q_bias, k_bias, v_bias = self.attn.in_proj_bias.chunk(3)
        with torch.no_grad():
            frozen_query = F.linear(query, q_weight, q_bias)
            frozen_key = F.linear(key, k_weight, k_bias)
            frozen_value = F.linear(value, v_weight, v_bias)

        query = frozen_query + lora_query
        key = frozen_key + lora_key
        value = frozen_value + lora_value

        attn_output, attn_output_weights = custom_multi_head_attention_forward(
            query,
            key,
            value,
            self.attn.embed_dim,
            self.attn.num_heads,
            in_proj_weight=None,  # Projections are already done manually
            in_proj_bias=None,  # Bias is not used
            bias_k=None,
            bias_v=None,
            add_zero_attn=self.attn.add_zero_attn,
            dropout_p=self.attn.dropout,
            out_proj_weight=self.attn.out_proj.weight,
            out_proj_bias=self.attn.out_proj.bias,
            training=self.attn.training,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            use_separate_proj_weight=True,
            use_direct_input=True,
            attn_mask=attn_mask,
        )

        if self.attn.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        return attn_output, attn_output_weights

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiheadAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                task_mask=None,
                forward_origin=False,
                task_idx=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims]. If None, the ``query`` will be
                used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                If None, the `key` will be used. Defaults to None.
            identity (Tensor): Tensor used for the identity link. If None,
                `query` (before adding `query_pos`) will be used.
                Defaults to None.
            query_pos (Tensor): Positional encoding added to `query`.
                Defaults to None.
            key_pos (Tensor): Positional encoding added to `key`. If None
                and `query_pos` has the same shape as `key`, `query_pos` is
                used. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.
            task_mask: Forwarded to the q/k/v LoRA branches.
            forward_origin (bool): If truthy, use the frozen ``self.attn``
                instead of the LoRA attention path. Defaults to False.
            task_idx: Forwarded to the LoRA branches.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries embed_dims].
        """
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # ``torch.nn.MultiheadAttention`` expects (num_query, batch,
        # embed_dims); convert from batch_first and back afterwards.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        attn_kwargs = dict(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)
        if self.use_lora and not forward_origin:
            attn_fn = self.lora_attn_forward
            attn_kwargs.update(task_mask=task_mask, task_idx=task_idx)
        else:
            attn_fn = self.attn

        # Checkpoint whichever attention path is active, so enabling
        # `with_cp` no longer silently bypasses the LoRA branches.
        if self.with_cp:
            out = cp.checkpoint(attn_fn, use_reentrant=False,
                                **attn_kwargs)[0]
        else:
            out = attn_fn(**attn_kwargs)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        if self.use_lora:
            # NOTE(review): the output LoRA branch is added even when
            # `forward_origin` is True; confirm this is intended.
            out = out + self.out_lora(out, i=task_idx)

        return identity + self.dropout_layer(self.proj_drop(out))


#forceful register
@FEEDFORWARD_NETWORK.register_module(force=True)
class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        use_lora (bool): Freeze ``self.layers`` (via ``frozen_grad``) and add
            trainable LoRA layers. Default: False.
        lora_rank (int): Rank passed to the LoRA modules. Default: 16.
        use_adapter (bool): Use a ``LoRAMoECLAdapter`` instead of wrapping
            ``self.layers``; requires ``use_lora``. Default: False.
        adatper_num (int): ``num_task`` for ``LoRAMoECLAdapter``. Also
            accepted as ``adapter_num``. Default: 6.
        lora_moe (bool): Use ``MOELoRALinear`` instead of ``LoRALinear``
            when wrapping ``self.layers``. Also accepted as ``moe_lora``.
            Default: False.
        num_task (int): ``num_task`` passed to ``lora_wrapper``. Default: 6.
    """

    @deprecated_api_warning(
        {
            'dropout': 'ffn_drop',
            'add_residual': 'add_identity'
        },
        cls_name='FFN')
    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 use_lora=False,
                 lora_rank=16,
                 use_adapter=False,
                 adatper_num=6,
                 lora_moe=False,
                 num_task=6,
                 **kwargs):
        super(FFN, self).__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        # `BaseTransformerLayer` forwards these options as `adapter_num` and
        # `moe_lora`. Accept those spellings so they are not silently
        # swallowed by **kwargs; the original keyword names keep working.
        adatper_num = kwargs.pop('adapter_num', adatper_num)
        lora_moe = kwargs.pop('moe_lora', lora_moe)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)
        self.lora_moe = lora_moe

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))

        self.use_lora = use_lora
        self.use_adapter = use_adapter
        if use_adapter:
            assert use_lora == True

        self.layers = Sequential(*layers)
        if self.use_lora:
            if self.use_adapter:
                self.lora_layers = LoRAMoECLAdapter(
                    embed_dims,
                    feedforward_channels,
                    embed_dims,
                    num_task=adatper_num,
                    r=lora_rank,
                    dropout=ffn_drop)
            else:
                lora_layer = MOELoRALinear if self.lora_moe else LoRALinear
                self.lora_layers = lora_wrapper(
                    self.layers,
                    LoraLayer=lora_layer,
                    rank=lora_rank,
                    dropout=0.0,
                    num_task=num_task)

        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

        if self.use_lora:
            self.layers = frozen_grad(self.layers)
            finetuning_detach(self)

    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None, forward_origin=False, task_idx=None,
                **kwargs):
        """Forward function for `FFN`.

        Output selection:
        * LoRA disabled or ``forward_origin`` truthy: ``self.layers(x)``.
        * ``use_adapter``: ``self.layers(x).detach() + self.lora_layers(x)``.
        * otherwise: ``peft_wrapper_forward`` over the frozen layers and the
          LoRA layers.

        If ``add_identity`` is set, the output is added to ``identity``
        (a copy of the input ``x`` when ``identity`` is None).
        """
        # Snapshot the input before the layers run, for the residual.
        org_x = x.clone()
        if (not self.use_lora) or forward_origin:
            out = self.layers(x)
        elif self.use_adapter:
            out = self.layers(x).detach() + self.lora_layers(x)
        else:
            out = peft_wrapper_forward(
                x, self.layers, self.lora_layers, task_idx=task_idx)

        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = org_x
        return identity + self.dropout_layer(out)


def _set_ffn_cfg(ffn_cfgs, key, value):
    """Set ``key`` to ``value`` in an FFN config, or in every config of a list.
    """
    if isinstance(ffn_cfgs, dict):
        ffn_cfgs[key] = value
    else:
        for cfg in ffn_cfgs:
            cfg[key] = value


@TRANSFORMER_LAYER.register_module(force=True)
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and support more flexible
    customization, for example, using any number of `FFN or LN ` and
    use different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specifying `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture `_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. If any config sets a truthy
            ``use_lora``, the attentions whose config does not are frozen.
            Default: None.
        with_cp (bool): Run the FFNs under ``torch.utils.checkpoint``.
            Default: False.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for FFN. The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the FFN modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Support `prenorm` when you specifying first element as `norm`.
            Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        use_lora (bool): Stored as ``self.use_lora``; attention LoRA itself
            is configured through ``attn_cfgs``. Default: False.
        lora_rank (int): Not used by this class. Default: 16.
        ffn_use_lora (bool): Written to the FFN configs as ``use_lora``.
            When falsy, all FFN parameters are frozen. Default: False.
        ffn_lora_rank (int): Written to the FFN configs as ``lora_rank``.
        ffn_use_adapter (bool): Written to the FFN configs as
            ``use_adapter``.
        ffn_adapter_num (int): Written to the FFN configs as
            ``adapter_num``.
        moe_lora (bool): Written to the FFN configs as ``moe_lora``.
        num_task (int): Written to the FFN configs as ``num_task``.
    """

    def __init__(self,
                 attn_cfgs=None,
                 with_cp=False,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 use_lora=False,
                 lora_rank=16,
                 ffn_use_lora=False,
                 ffn_lora_rank=16,
                 ffn_use_adapter=False,
                 ffn_adapter_num=6,
                 moe_lora=False,
                 num_task=6,
                 **kwargs):
        # Work on a private copy: `ffn_cfgs` defaults to a shared dict and
        # is edited below (deprecated args, LoRA options, embed_dims).
        ffn_cfgs = copy.deepcopy(ffn_cfgs)

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ')
                _set_ffn_cfg(ffn_cfgs, new_name, kwargs[ori_name])

        super(BaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')

        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {len(attn_cfgs)} is ' \
                f'not consistent with the number of attention ' \
                f'{num_attn} in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()
        self.use_lora = use_lora

        # If any attention opts into LoRA, every attention that does not is
        # frozen completely.
        have_use_lora = any(
            bool(attn_cfg.get('use_lora', False)) for attn_cfg in attn_cfgs)

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                if have_use_lora and not attn_cfgs[index].get(
                        'use_lora', False):
                    for param in attention.parameters():
                        param.requires_grad = False
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        _set_ffn_cfg(ffn_cfgs, 'use_lora', ffn_use_lora)
        _set_ffn_cfg(ffn_cfgs, 'lora_rank', ffn_lora_rank)
        _set_ffn_cfg(ffn_cfgs, 'use_adapter', ffn_use_adapter)
        _set_ffn_cfg(ffn_cfgs, 'adapter_num', ffn_adapter_num)
        _set_ffn_cfg(ffn_cfgs, 'moe_lora', moe_lora)
        _set_ffn_cfg(ffn_cfgs, 'num_task', num_task)
        # Without FFN LoRA the FFN weights stay frozen (see end of __init__).
        self.freeze_ffn = not ffn_use_lora

        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

        self.with_cp = with_cp
        if self.freeze_ffn:
            for param in self.ffns.parameters():
                param.requires_grad = False

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                task_mask=None,
                forward_origin=False,
                task_idx=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_keys]. Default: None.
            task_mask: Passed to the cross-attention modules only.
            forward_origin (bool): Passed to every attention and FFN.
            task_idx: Passed to every attention and FFN.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                # NOTE(review): `task_mask` is not forwarded to self-attention
                # (only to cross-attention); confirm this is intended.
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    forward_origin=forward_origin,
                    task_idx=task_idx,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    task_mask=task_mask,
                    forward_origin=forward_origin,
                    task_idx=task_idx,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                ffn_identity = identity if self.pre_norm else None
                if self.with_cp:
                    # Same arguments as the plain call, so pre-norm residuals
                    # and LoRA routing are preserved under checkpointing.
                    query = cp.checkpoint(
                        self.ffns[ffn_index],
                        query,
                        ffn_identity,
                        use_reentrant=False,
                        forward_origin=forward_origin,
                        task_idx=task_idx)
                else:
                    query = self.ffns[ffn_index](
                        query,
                        ffn_identity,
                        forward_origin=forward_origin,
                        task_idx=task_idx)
                ffn_index += 1

        return query


@TRANSFORMER_LAYER_SEQUENCE.register_module(force=True)
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    As base-class of Encoder and Decoder in vision transformer.
    Support customization such as specifying different kind
    of `transformer_layer` in `transformer_coder`.

    Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of transformerlayer
            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
            it would be repeated `num_layer` times to a
            list[`mmcv.ConfigDict`]. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super(TransformerLayerSequence, self).__init__(init_cfg)
        if isinstance(transformerlayers, dict):
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]
        else:
            assert isinstance(transformerlayers, list) and \
                len(transformerlayers) == num_layers
        self.num_layers = num_layers
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(build_transformer_layer(transformerlayers[i]))
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                task_mask=None,
                forward_origin=False,
                task_idx=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Runs every layer in ``self.layers`` in order, feeding each layer's
        output as the next layer's ``query``. All other arguments are
        passed unchanged to every layer.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)`.
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is 2D Tensor
                which is used in calculation of corresponding attention in
                operation_order. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_keys]. Default: None.
            task_mask, forward_origin, task_idx: Forwarded to each layer.

        Returns:
            Tensor: results with shape [num_queries, bs, embed_dims].
        """
        for layer in self.layers:
            query = layer(
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                task_mask=task_mask,
                forward_origin=forward_origin,
                task_idx=task_idx,
                **kwargs)
        return query


def transformer_test():
    """Smoke test: build LoRA modules and print their trainable state."""
    # Constructed only to exercise MultiheadAttention's LoRA __init__ path;
    # the FFN below replaces it for printing.
    model = MultiheadAttention(
        64, 4, use_lora=True
    )
    model = FFN(use_lora=True)
    # finetuning_detach(model)
    print("Model structure after attaching LoRA layers:\n", model)
    for name, param in model.named_parameters():
        print(name, param.shape, param.requires_grad)


if __name__ == '__main__':
    transformer_test()