# ACE-LoRA / loralib/layers.py
# Uploaded by aydnarda ("upload supp files", commit 05a82cf, verified).
# ------------------------------------------------------------------------------------------
# This code is reconstructed based on loralib (https://github.com/microsoft/LoRA) by Baijiong Lin.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, List
from torch.jit import Final
from timm.layers import use_fused_attn
from timm.models.vision_transformer import Attention
from transformers.models.bert.modeling_bert import BertAttention
from typing import Optional, Tuple
def set_param(curr_mod, name, param=None, mode='update'):
    r"""Read or replace a (possibly nested) attribute of a module by dotted name.

    Refer to https://github.com/Baijiong-Lin/MOML/blob/main/MTL/utils.py

    Args:
        curr_mod: module the lookup starts from.
        name: attribute path such as ``'weight'`` or ``'0.weight'``. Every
            component before the last is matched against
            ``curr_mod.named_children()``.
        param: replacement value, used only when ``mode == 'update'``.
        mode: ``'update'`` deletes the attribute and re-sets it to ``param``;
            ``'get'`` returns the attribute.

    Returns:
        In ``'get'`` mode, the attribute if it exists. ``None`` otherwise,
        including when a path component does not name a child module.
    """
    head, dot, tail = name.partition('.')
    if dot:
        # Descend into the first child whose name matches the leading component.
        child = next(
            (module for child_name, module in curr_mod.named_children() if child_name == head),
            None,
        )
        if child is None:
            return None
        return set_param(child, tail, param, mode=mode)

    if mode == 'get':
        return getattr(curr_mod, name, None)
    if mode == 'update':
        delattr(curr_mod, name)
        setattr(curr_mod, name, param)
    return None
class LoRALayer():
    """Mixin with the LoRA bookkeeping shared by the layers in this module.

    A subclass fills ``params_with_lora`` with ``{'param_name': 'lora_name'}``
    entries. For each entry, ``register_lora_param`` creates
    ``{lora_name}_lora_A`` of shape ``(r, param.size(1))`` and
    ``{lora_name}_lora_B`` of shape ``(param.size(0), r)``. The effective
    parameter is ``param + scaling * B @ A``, passed through ``transpose``.
    The host class must be an ``nn.Module``, because ``register_lora_param``
    calls ``register_parameter``.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        fan_in_fan_out: bool = False,
        dropout_rate: float = 0,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout_rate
        if self.r > 0:
            # Rank-stabilised scaling alpha / sqrt(r); the classic LoRA choice
            # would be alpha / r. Only defined when r > 0.
            self.scaling = self.lora_alpha / math.sqrt(self.r)
        # Mark the weight as unmerged
        self.merged = False
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        self.fan_in_fan_out = fan_in_fan_out
        # define params that require LoRA {'param_name': 'lora_name'}
        self.params_with_lora = {}

    def _lora_attr(self, name: str):
        """Return the attribute at dotted path ``name`` on ``self``.

        Equivalent to the former ``eval(f'self.{name}')``, without evaluating
        arbitrary code.
        """
        obj = self
        for part in name.split('.'):
            obj = getattr(obj, part)
        return obj

    def register_lora_param(self):
        r"""Register zero-initialised LoRA A/B matrices and freeze each base parameter."""
        for param_name, lora_name in self.params_with_lora.items():
            param = self._lora_attr(param_name)
            assert len(param.size()) == 2
            rows, cols = param.size()
            self.register_parameter(f'{lora_name}_lora_A',
                                    nn.Parameter(param.new_zeros((self.r, cols))))
            self.register_parameter(f'{lora_name}_lora_B',
                                    nn.Parameter(param.new_zeros((rows, self.r))))
            param.requires_grad = False

    def init_lora_param(self):
        """Initialise A like the default ``nn.Linear`` weight and B to zero.

        With B zero, ``B @ A`` is zero at start.
        """
        for lora_name in self.params_with_lora.values():
            if hasattr(self, f'{lora_name}_lora_A'):
                nn.init.kaiming_uniform_(self._lora_attr(f'{lora_name}_lora_A'), a=math.sqrt(5))
                nn.init.zeros_(self._lora_attr(f'{lora_name}_lora_B'))

    def transpose(self, w: torch.Tensor):
        """Swap the first two dims of ``w`` when ``fan_in_fan_out`` is set; otherwise return ``w``."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def merge_BA(self, param_name: str):
        """Return ``B @ A`` for ``param_name``, viewed as the parameter's shape and passed through ``transpose``."""
        lora_name = self.params_with_lora[param_name]
        delta = self._lora_attr(f'{lora_name}_lora_B') @ self._lora_attr(f'{lora_name}_lora_A')
        return self.transpose(delta.view(self._lora_attr(param_name).shape))

    def merge_lora_param(self):
        r"""p_new = p + scaling * B @ A and keep differentiable to A and B"""
        for param_name in self.params_with_lora:
            p = set_param(self, param_name, mode='get')
            # detach() is very important here: gradients must reach A and B
            # only, never the frozen base parameter.
            p_new = p.detach() + self.merge_BA(param_name) * self.scaling
            # NOTE: p_new is a plain tensor (not nn.Parameter), so the attribute
            # is swapped for a non-Parameter tensor by set_param.
            set_param(self, param_name, param=p_new, mode='update')

    def add_lora_data(self):
        r"""NOT differentiable: add scaling * B @ A into each parameter's storage in place."""
        for param_name in self.params_with_lora:
            self._lora_attr(param_name).data += self.merge_BA(param_name) * self.scaling

    def sub_lora_data(self):
        r"""NOT differentiable: subtract scaling * B @ A from each parameter's storage in place."""
        for param_name in self.params_with_lora:
            self._lora_attr(param_name).data -= self.merge_BA(param_name) * self.scaling

    def lora_train(self, mode: bool = True):
        """Unmerge the LoRA update for training (``mode=True``) or merge it for eval (``mode=False``)."""
        if mode:
            if self.merged and self.r > 0:
                # Make sure that the weights are not merged
                self.sub_lora_data()
            self.merged = False
        else:
            if not self.merged and self.r > 0:
                # Merge the weights and mark it
                self.add_lora_data()
            self.merged = True
class Embedding(nn.Embedding, LoRALayer):
    """``nn.Embedding`` with a trainable low-rank (LoRA) update on ``weight``.

    When ``r > 0`` the embedding table is frozen, and ``w_lora_A``/``w_lora_B``
    are registered by ``register_lora_param``. Unlike the linear layers, A is
    initialised to zero and B from a standard normal, so ``B @ A`` still starts
    at zero.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        nn.Embedding.reset_parameters(self)
        self.init_lora_param()

    def init_lora_param(self):
        """Zero-initialise A and draw B from N(0, 1); a no-op when no LoRA matrices exist."""
        if not hasattr(self, 'w_lora_A'):
            return
        nn.init.zeros_(self.w_lora_A)
        nn.init.normal_(self.w_lora_B)

    def train(self, mode: bool = True):
        """Set the module mode, then merge (eval) or unmerge (train) the LoRA update."""
        nn.Embedding.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        """Look up ``x`` using the base table plus the (unmerged) LoRA update."""
        if self.r <= 0 or self.merged:
            return nn.Embedding.forward(self, x, **kwargs)
        # Temporarily fold the differentiable update into the weight for this lookup.
        self.merge_lora_param()
        output = nn.Embedding.forward(self, x, **kwargs)
        self.sub_lora_data()
        return output
class LinearLoRA(nn.Linear, LoRALayer):
    """Wrap an existing ``nn.Linear`` with a LoRA update on its weight.

    A new ``nn.Linear`` with the same in/out features and the same bias setting
    is created. It is loaded with ``existing_linear``'s ``state_dict``, and its
    weight is then frozen and adapted through ``w_lora_A``/``w_lora_B``.

    Args:
        existing_linear: layer whose parameters are copied.
        r: LoRA rank; ``0`` disables the adapter.
        lora_alpha: numerator of the LoRA scaling computed in ``LoRALayer.__init__``.
        fan_in_fan_out: transpose the stored weight after loading.
            NOTE(review): ``nn.Linear.forward`` still uses the weight as-is, so
            ``True`` looks unsupported; callers in this file always pass ``False``.
        dropout_rate: if > 0, dropout is applied to the input of the LoRA branch
            only, and the low-rank term is added to the output instead of being
            merged into the weight.
        seed: accepted for signature compatibility; unused here.
        **kwargs: ignored.
    """

    def __init__(
        self,
        existing_linear: nn.Linear,
        r: int = 0,
        lora_alpha: int = 1,
        fan_in_fan_out: bool = False,
        dropout_rate = 0.,
        seed: int = 1,
        **kwargs
    ):
        super().__init__(
            in_features=existing_linear.in_features,
            out_features=existing_linear.out_features,
            # Match the wrapped layer: a bias-free source would otherwise make
            # load_state_dict fail on the missing 'bias' key.
            bias=existing_linear.bias is not None)
        self.load_state_dict(existing_linear.state_dict())
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, fan_in_fan_out=fan_in_fan_out)
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
            self.init_lora_param()
        self.weight.data = self.transpose(self.weight.data)
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

    def train(self, mode: bool = True):
        """Set the module mode, then merge (eval) or unmerge (train) the LoRA update."""
        super().train(mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        """Apply the frozen linear map plus ``scaling * x @ (B @ A).T`` when LoRA is active and unmerged."""
        if self.dropout is None:
            if self.r > 0 and not self.merged:
                # Fold the differentiable update into the weight for this call only.
                self.merge_lora_param()
                result = nn.Linear.forward(self, x, **kwargs)
                self.sub_lora_data()
                return result
            return nn.Linear.forward(self, x, **kwargs)
        # Dropout path: base output is computed on the undropped input; dropout
        # only affects the input of the low-rank branch.
        original_output = nn.Linear.forward(self, x)
        if self.training and self.dropout.p > 0:
            x = self.dropout(x)
        if self.r > 0 and not self.merged:
            lora_adjustment = torch.matmul(x, self.merge_BA('weight').transpose(0, 1)) * self.scaling
            return original_output + lora_adjustment
        return original_output
class Conv1d(nn.Conv1d, LoRALayer):
    """``nn.Conv1d`` with a trainable low-rank (LoRA) update on ``weight``.

    ``B @ A`` is viewed as the kernel shape
    ``(out_channels, in_channels // groups, kernel_size)``. So its element count
    must be ``out_channels * in_channels * kernel_size / groups``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)
        assert type(kernel_size) is int
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # A: (r*k, in*k), B: (out/groups, r*k), so B @ A has
            # out*in*k/groups elements, matching the 1-D kernel.
            # The 2-D sizing of B, (out/groups*k, r*k), over-counts by a
            # factor k and breaks the view in merge_BA whenever k > 1.
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups, r * kernel_size))
            )
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        nn.Conv1d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        """Set the module mode, then merge (eval) or unmerge (train) the LoRA update."""
        nn.Conv1d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        """Convolve with the base kernel plus the (unmerged) LoRA update."""
        if self.r > 0 and not self.merged:
            self.merge_lora_param()
            result = nn.Conv1d.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        return nn.Conv1d.forward(self, x, **kwargs)
class Conv2d(nn.Conv2d, LoRALayer):
    """``nn.Conv2d`` with a trainable low-rank (LoRA) update on ``weight``.

    A is ``(r*k, in*k)`` and B is ``(out/groups*k, r*k)``. ``B @ A`` therefore
    has ``out*in*k*k/groups`` elements and is viewed as the
    ``(out, in/groups, k, k)`` kernel.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)
        assert type(kernel_size) is int
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            inner = r * kernel_size
            a_cols = in_channels * kernel_size
            b_rows = out_channels // self.groups * kernel_size
            self.w_lora_A = nn.Parameter(self.weight.new_zeros((inner, a_cols)))
            self.w_lora_B = nn.Parameter(self.weight.new_zeros((b_rows, inner)))
            # The pre-trained kernel stays frozen; only A and B train.
            self.weight.requires_grad = False
        nn.Conv2d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        """Set the module mode, then merge (eval) or unmerge (train) the LoRA update."""
        nn.Conv2d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        """Convolve with the base kernel plus the (unmerged) LoRA update."""
        if self.r <= 0 or self.merged:
            return nn.Conv2d.forward(self, x, **kwargs)
        self.merge_lora_param()
        output = nn.Conv2d.forward(self, x, **kwargs)
        self.sub_lora_data()
        return output
class Conv3d(nn.Conv3d, LoRALayer):
    """``nn.Conv3d`` with a trainable low-rank (LoRA) update on ``weight``.

    ``B @ A`` is viewed as the kernel shape
    ``(out_channels, in_channels // groups, k, k, k)``. So its element count
    must be ``out_channels * in_channels * k**3 / groups``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv3d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)
        assert type(kernel_size) is int
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # A: (r*k, in*k), B: (out/groups*k*k, r*k), so B @ A has
            # out*in*k^3/groups elements, matching the 3-D kernel.
            # The 2-D sizing of B, (out/groups*k, r*k), is a factor k short
            # and breaks the view in merge_BA whenever k > 1.
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size * kernel_size,
                                       r * kernel_size))
            )
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        nn.Conv3d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        """Set the module mode, then merge (eval) or unmerge (train) the LoRA update."""
        nn.Conv3d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        """Convolve with the base kernel plus the (unmerged) LoRA update."""
        if self.r > 0 and not self.merged:
            self.merge_lora_param()
            result = nn.Conv3d.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        return nn.Conv3d.forward(self, x, **kwargs)
class PlainMultiheadAttentionLoRA(nn.Module):
    """Re-implementation of ``nn.MultiheadAttention`` whose projections can carry LoRA.

    The packed ``in_proj_weight``/``in_proj_bias`` of ``existing_mha`` are split
    into ``q_proj``/``k_proj``/``v_proj``, and ``out_proj`` is copied into
    ``proj``. Each projection named in ``enable_lora`` (``'q'``, ``'k'``, ``'v'``,
    ``'o'``) is then wrapped in ``LinearLoRA``. Attention runs through
    ``F.scaled_dot_product_attention`` with no attention dropout. Attention
    weights are never returned: the second output is always ``None``.

    NOTE(review): the split assumes a packed in-projection (``_qkv_same_embed_dim``).
    A source module with distinct ``kdim``/``vdim`` has no ``in_proj_weight``
    and is not supported.
    """

    def __init__(
        self,
        existing_mha: nn.MultiheadAttention,
        enable_lora: list = ['q', 'k', 'v', 'o'],
        r: int = 0,
        lora_alpha: int = 1,
        dropout_rate: float = 0.,
        **kwargs
    ):
        super().__init__()
        self.dropout = 0  # this module is not used to retrain the main block
        self.embed_dim = existing_mha.embed_dim
        self.kdim = existing_mha.kdim
        self.vdim = existing_mha.vdim
        self._qkv_same_embed_dim = existing_mha._qkv_same_embed_dim
        self.num_heads = existing_mha.num_heads
        self.batch_first = existing_mha.batch_first
        self.head_dim = existing_mha.head_dim
        has_in_bias = existing_mha.in_proj_bias is not None
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_in_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_in_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_in_bias)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.out_proj.bias is not None)
        with torch.no_grad():
            existing_weight = existing_mha.in_proj_weight.data
            existing_bias = existing_mha.in_proj_bias.data if has_in_bias else None
            # in_proj_weight / in_proj_bias hold q, k, v stacked along dim 0.
            self.q_proj.weight.data.copy_(existing_weight[:self.embed_dim, :])
            if existing_bias is not None:
                self.q_proj.bias.data.copy_(existing_bias[:self.embed_dim])
            self.k_proj.weight.data.copy_(existing_weight[self.embed_dim:2 * self.embed_dim, :])
            if existing_bias is not None:
                self.k_proj.bias.data.copy_(existing_bias[self.embed_dim:2 * self.embed_dim])
            self.v_proj.weight.data.copy_(existing_weight[2 * self.embed_dim:, :])
            if existing_bias is not None:
                self.v_proj.bias.data.copy_(existing_bias[2 * self.embed_dim:])
            self.proj.weight.data.copy_(existing_mha.out_proj.weight.data)
            if self.proj.bias is not None:
                self.proj.bias.data.copy_(existing_mha.out_proj.bias.data)
        self.scaled_dot_product_attention = F.scaled_dot_product_attention
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, dropout_rate=dropout_rate)
        # Wrap the requested projections in LoRA linear layers.
        for item in enable_lora:
            if item == 'q':
                self.q_proj = LinearLoRA(self.q_proj, r=r, lora_alpha=lora_alpha,
                                         fan_in_fan_out=False, dropout_rate=dropout_rate)
            elif item == 'k':
                self.k_proj = LinearLoRA(self.k_proj, r=r, lora_alpha=lora_alpha,
                                         fan_in_fan_out=False, dropout_rate=dropout_rate)
            elif item == 'v':
                self.v_proj = LinearLoRA(self.v_proj, r=r, lora_alpha=lora_alpha,
                                         fan_in_fan_out=False, dropout_rate=dropout_rate)
            elif item == 'o':
                self.proj = LinearLoRA(self.proj, r=r, lora_alpha=lora_alpha,
                                       fan_in_fan_out=False, dropout_rate=dropout_rate)

    def forward_module(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            average_attn_weights=True,
            is_causal=False):
        """Attention forward pass mirroring ``nn.MultiheadAttention.forward``.

        ``need_weights`` and ``average_attn_weights`` are accepted for signature
        compatibility but ignored. ``key_padding_mask`` (shape
        ``(batch, src_len)``) is folded into the attention mask, as
        ``nn.MultiheadAttention`` does. Only batched (3-D) inputs are handled.

        Returns:
            ``(attn_output, None)``. ``attn_output`` is batch-first when
            ``batch_first`` is set.

        Raises:
            AssertionError: if both ``attn_mask`` and ``is_causal`` are given.
            RuntimeError: on mask shapes that do not match the inputs.
        """
        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")
        is_batched = query.dim() == 3
        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )
        if self.batch_first and is_batched:
            # Convert to (seq, batch, embed); keep identity between shared inputs.
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=F._none_or_dtype(key_padding_mask),
            other_name="key_padding_mask",
            target_type=q.dtype,
            check_other=False,
        )
        if attn_mask is not None:
            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * self.num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
        if key_padding_mask is not None:
            # Previously the padding mask was canonicalised but never applied.
            if key_padding_mask.shape != (bsz, src_len):
                raise RuntimeError(
                    f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}")
            if is_causal:
                # SDPA cannot combine is_causal with an explicit mask, so build
                # the (top-left aligned) causal mask and merge padding into it.
                attn_mask = torch.triu(
                    torch.full((tgt_len, src_len), float('-inf'), dtype=q.dtype, device=q.device),
                    diagonal=1,
                ).unsqueeze(0)
                is_causal = False
            key_padding_mask = (key_padding_mask.view(bsz, 1, 1, src_len)
                                .expand(-1, self.num_heads, -1, -1)
                                .reshape(bsz * self.num_heads, 1, src_len))
            attn_mask = key_padding_mask if attn_mask is None else attn_mask + key_padding_mask
        if attn_mask is not None:
            # Reshape to something broadcastable to (bsz, heads, tgt, src) for SDPA.
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len)
        dropout_p = self.dropout if self.training else 0.
        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
        k = k.view(bsz, self.num_heads, src_len, self.head_dim)
        v = v.view(bsz, self.num_heads, src_len, self.head_dim)
        attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        attn_output = self.proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), None
        return attn_output, None

    def train(self, mode: bool = True):
        """Set the mode on this module and its children (including any LinearLoRA projections)."""
        super().train(mode)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                **kwargs):
        """Delegate to ``forward_module``; keyword arguments are forwarded unchanged."""
        return self.forward_module(query, key, value, **kwargs)
class AttentionLoRA(nn.Module):
    """LoRA-augmented replacement for a timm ``Attention`` block.

    The fused ``qkv`` projection of ``existing_mha`` is split into
    ``q_proj``/``k_proj``/``v_proj``, and ``proj`` is copied. Projections named in
    ``enable_lora`` (``'q'``, ``'k'``, ``'v'``, ``'o'``) are wrapped in
    ``LinearLoRA``. ``q_norm``/``k_norm`` are reused from the source module.
    Attention and projection dropout are fixed at p=0.
    """

    fused_attn: Final[bool]

    def __init__(
            self,
            existing_mha: Attention,
            enable_lora: list = ['q', 'k', 'v', 'o'],
            r: int = 0,
            lora_alpha: int = 1,
            dropout_rate: float = 0.,
            seed: int = 1,
    ) -> None:
        super().__init__()
        # Reseeds the global CPU/CUDA RNGs so the random initialisation below is repeatable.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        self.embed_dim = existing_mha.proj.in_features
        self.num_heads = existing_mha.num_heads
        self.head_dim = existing_mha.head_dim
        assert self.embed_dim % self.num_heads == 0, 'dim should be divisible by num_heads'
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.dropout = 0
        self.q_norm = existing_mha.q_norm
        self.k_norm = existing_mha.k_norm
        self.attn_drop = nn.Dropout(self.dropout)
        self.proj_drop = nn.Dropout(self.dropout)
        self.r = r
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout_rate
        self.enable_lora = enable_lora
        self.seed = seed
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, dropout_rate=dropout_rate)
        has_qkv_bias = existing_mha.qkv.bias is not None
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_qkv_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_qkv_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_qkv_bias)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.proj.bias is not None)
        with torch.no_grad():
            existing_weight = existing_mha.qkv.weight.data
            # qkv may be bias-free (e.g. qkv_bias=False); the old unconditional
            # `.bias.data` raised AttributeError in that case.
            existing_bias = existing_mha.qkv.bias.data if has_qkv_bias else None
            # qkv.weight / qkv.bias hold q, k, v stacked along dim 0.
            self.q_proj.weight.data.copy_(existing_weight[:self.embed_dim, :])
            if existing_bias is not None:
                self.q_proj.bias.data.copy_(existing_bias[:self.embed_dim])
            self.k_proj.weight.data.copy_(existing_weight[self.embed_dim:2 * self.embed_dim, :])
            if existing_bias is not None:
                self.k_proj.bias.data.copy_(existing_bias[self.embed_dim:2 * self.embed_dim])
            self.v_proj.weight.data.copy_(existing_weight[2 * self.embed_dim:, :])
            if existing_bias is not None:
                self.v_proj.bias.data.copy_(existing_bias[2 * self.embed_dim:])
            self.proj.weight.data.copy_(existing_mha.proj.weight.data)
            if self.proj.bias is not None:
                self.proj.bias.data.copy_(existing_mha.proj.bias.data)
        self.q_proj, self.k_proj, self.v_proj, self.proj = self.inject_lora(
            self.q_proj, self.k_proj, self.v_proj, self.proj)

    def inject_lora(self, q, k, v, proj):
        """Return ``(q, k, v, proj)`` with the layers named in ``self.enable_lora`` wrapped in ``LinearLoRA``."""
        for item in self.enable_lora:
            if item == 'q':
                q = LinearLoRA(q, r=self.r, lora_alpha=self.lora_alpha, fan_in_fan_out=False,
                               dropout_rate=self.dropout_rate, seed=self.seed)
            elif item == 'k':
                k = LinearLoRA(k, r=self.r, lora_alpha=self.lora_alpha, fan_in_fan_out=False,
                               dropout_rate=self.dropout_rate, seed=self.seed)
            elif item == 'v':
                v = LinearLoRA(v, r=self.r, lora_alpha=self.lora_alpha, fan_in_fan_out=False,
                               dropout_rate=self.dropout_rate, seed=self.seed)
            elif item == 'o':
                proj = LinearLoRA(proj, r=self.r, lora_alpha=self.lora_alpha, fan_in_fan_out=False,
                                  dropout_rate=self.dropout_rate, seed=self.seed)
        return q, k, v, proj

    def forward(self, x: torch.Tensor, return_attn_scores=False) -> torch.Tensor:
        """Self-attention over ``x`` of shape ``(B, N, C)``.

        Returns:
            The projected output. When ``return_attn_scores`` is set, returns
            ``(output, attn_scores)``, where ``attn_scores`` are the scaled
            pre-softmax logits of shape ``(B, heads, N, N)``.
        """
        B, N, C = x.shape
        q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q, k = self.q_norm(q), self.k_norm(k)
        if return_attn_scores:
            # Explicit path so the logits can be returned.
            q = q * self.scale
            attn_scores = q @ k.transpose(-2, -1)
            attn = attn_scores.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
            x = x.transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return (x, attn_scores)
        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class BertAttentionLoRA(nn.Module):
    """LoRA-augmented replacement for a Hugging Face ``BertAttention`` module.

    Weights are copied from ``existing_mha.self`` and ``existing_mha.output.dense``:
    the query/key/value projections go to ``q_proj``/``k_proj``/``v_proj`` and the
    output dense layer goes to ``proj``. Projections named in ``enable_lora``
    (``'q'``, ``'k'``, ``'v'``, ``'o'``) are wrapped in ``LinearLoRA``.

    The forward pass reproduces BERT self-attention: absolute or relative-key
    position types, optional cross-attention and key/value cache. It then
    returns ``LayerNorm(proj(context) + hidden_states)``. Attention dropout is
    fixed at p=0, and ``self.output``'s dropout is not used.
    """

    def __init__(self,
                 existing_mha: BertAttention,
                 enable_lora: list = ['q', 'k', 'v', 'o'],
                 r: int = 0,
                 lora_alpha: int = 1,
                 dropout_rate: float = 0.,
                 seed: int = 1,):
        super().__init__()
        # Reseeds the global CPU/CUDA RNGs so the initialisation below is repeatable.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        self.self_attn = existing_mha.self
        self.output = existing_mha.output
        self.num_attention_heads = self.self_attn.num_attention_heads
        self.attention_head_size = self.self_attn.attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.hidden_size = self.self_attn.query.in_features
        src_q, src_k, src_v = self.self_attn.query, self.self_attn.key, self.self_attn.value
        src_o = self.output.dense
        # Mirror each source layer's shape and bias setting so the copies below
        # cannot hit a shape mismatch or a missing bias. The old proj used
        # (in, in) instead of (in, out), and `.bias.data` crashed on bias-free layers.
        self.q_proj = nn.Linear(src_q.in_features, src_q.out_features, bias=src_q.bias is not None)
        self.k_proj = nn.Linear(src_k.in_features, src_k.out_features, bias=src_k.bias is not None)
        self.v_proj = nn.Linear(src_v.in_features, src_v.out_features, bias=src_v.bias is not None)
        self.proj = nn.Linear(src_o.in_features, src_o.out_features, bias=src_o.bias is not None)
        self.LayerNorm = self.output.LayerNorm
        self.dropout = nn.Dropout(0)
        self.r = r
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout_rate
        self.enable_lora = enable_lora
        self.seed = seed
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, dropout_rate=dropout_rate)
        with torch.no_grad():
            for dst, src in ((self.q_proj, src_q), (self.k_proj, src_k),
                             (self.v_proj, src_v), (self.proj, src_o)):
                dst.weight.copy_(src.weight)
                if src.bias is not None:
                    dst.bias.copy_(src.bias)
        self.q_proj, self.k_proj, self.v_proj, self.proj = self.inject_lora(
            self.q_proj, self.k_proj, self.v_proj, self.proj)
        self.position_embedding_type = self.self_attn.position_embedding_type
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = self.self_attn.max_position_embeddings
            # Reuse the pre-trained distance table. A new nn.Embedding here would
            # be randomly initialised and discard the learned relative positions.
            self.distance_embedding = self.self_attn.distance_embedding
        self.is_decoder = self.self_attn.is_decoder

    def inject_lora(self, q, k, v, proj):
        """Return ``(q, k, v, proj)`` with the layers named in ``self.enable_lora`` wrapped in ``LinearLoRA``."""
        for item in self.enable_lora:
            if item == 'q':
                q = LinearLoRA(q, r=self.r, lora_alpha=self.lora_alpha, fan_in_fan_out=False,
                               dropout_rate=self.dropout_rate, seed=self.seed)
            elif item == 'k':
                k = LinearLoRA(k, r=self.r, lora_alpha=self.lora_alpha, fan_in_fan_out=False,
                               dropout_rate=self.dropout_rate, seed=self.seed)
            elif item == 'v':
                v = LinearLoRA(v, r=self.r, lora_alpha=self.lora_alpha, fan_in_fan_out=False,
                               dropout_rate=self.dropout_rate, seed=self.seed)
            elif item == 'o':
                proj = LinearLoRA(proj, r=self.r, lora_alpha=self.lora_alpha, fan_in_fan_out=False,
                                  dropout_rate=self.dropout_rate, seed=self.seed)
        return q, k, v, proj

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(..., seq, all_head_size)`` to ``(batch, heads, seq, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention and the output projection + residual LayerNorm.

        Returns:
            ``(attention_output,)``, extended with ``attention_probs`` when
            ``output_attentions`` is set and with ``(key_layer, value_layer)``
            when ``is_decoder`` is set.
        """
        mixed_query_layer = self.q_proj(hidden_states)
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.k_proj(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        use_cache = past_key_value is not None
        if self.is_decoder:
            # Save (key, value) so later calls can reuse them: cross-attention
            # reuses them directly (first branch above); decoder self-attention
            # concatenates new states onto them (third branch).
            past_key_value = (key_layer, value_layer)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask (as precomputed by BertModel.forward for all layers).
            attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # Dropout on attention probabilities (p=0 in this module).
        attention_probs = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        self_attn_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        if self.is_decoder:
            self_attn_outputs = self_attn_outputs + (past_key_value,)
        # Output projection + residual LayerNorm (the output dropout is skipped).
        self_outputs = self.proj(self_attn_outputs[0])
        attention_output = self.LayerNorm(self_outputs + hidden_states)
        outputs = (attention_output,) + self_attn_outputs[1:]  # add attentions if we output them
        return outputs
class MergedLinear(nn.Linear, LoRALayer):
    """``nn.Linear`` whose output rows are split into ``len(enable_lora)`` equal groups.

    LoRA is applied only to the groups whose flag in ``enable_lora`` is True.
    The per-group low-rank products are computed with a grouped ``conv1d``
    (``groups=sum(enable_lora)``). They are then scattered back into the full
    ``(out_features, in_features)`` shape by ``zero_pad``.

    If ``r == 0`` or no group is enabled, no LoRA parameters exist and the layer
    behaves as a plain ``nn.Linear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0 and any(enable_lora):
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # weights for Conv1D with groups=sum(enable_lora)
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Boolean mask over output rows that belong to an enabled group.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        nn.Linear.reset_parameters(self)
        self.init_lora_param()
        self.weight.data = self.transpose(self.weight.data)

    def _has_lora(self) -> bool:
        """True when LoRA matrices exist (r > 0 and at least one group enabled)."""
        return self.r > 0 and any(self.enable_lora)

    def zero_pad(self, x):
        """Scatter the enabled-group rows ``x`` into a zero tensor with ``out_features`` rows."""
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_BA(self, param_name: str):
        """Return the full-size LoRA delta for ``param_name`` (zeros on disabled groups)."""
        lora_name = self.params_with_lora[param_name]
        delta_w = F.conv1d(
            getattr(self, f'{lora_name}_lora_A').unsqueeze(0),
            getattr(self, f'{lora_name}_lora_B').unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return self.transpose(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        """Set the module mode and, if LoRA matrices exist, merge/unmerge the update."""
        nn.Linear.train(self, mode)
        # With r > 0 but no enabled group, there are no A/B matrices; merging
        # would raise AttributeError.
        if self._has_lora():
            self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        """Apply the linear map plus the (unmerged) LoRA update on enabled groups."""
        if self._has_lora() and not self.merged:
            self.merge_lora_param()
            result = nn.Linear.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        return nn.Linear.forward(self, x, **kwargs)