# Source path: HRA/nlu/adapterlib/layers.py
# Uploaded by nvan13 with the upload-large-folder tool (commit ab0f6ec).
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, List
class HRALinear(nn.Linear):
    """Linear layer adapted with Householder Reflection Adaptation (HRA).

    The frozen weight ``W`` (shape ``(out_features, in_features)``) is combined
    with an orthogonal matrix ``Q`` built from the ``r`` columns of the
    trainable matrix ``hra_u`` (shape ``(in_features, r)``)::

        W_eff = W + scale * (W @ Q - W),    scale = alpha / r

    so with ``scale == 1`` the effective weight is exactly ``W @ Q``.

    * ``apply_GS=False``: ``Q`` is the product of the ``r`` Householder
      reflections ``I - 2 u_i u_i^T`` of the (normalised) columns. Columns are
      initialised in identical pairs, so consecutive reflections cancel and
      the layer starts out (up to the 1e-6 stabiliser) equal to ``W``.
    * ``apply_GS=True``: the columns are first orthonormalised with
      Gram-Schmidt and ``Q = I - 2 U U^T``.

    Hyper-parameters are read from ``config.hra``: ``r`` (must be even),
    ``apply_GS`` and optional ``alpha`` (default 16.0).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        config: dict,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        config = config.hra
        self.r = config.r
        self.apply_GS = config.apply_GS
        # hra_u is built from r // 2 columns, each repeated twice; an odd r
        # would leave hra_u with r - 1 columns and forward() would index past it.
        if self.r % 2 != 0:
            raise ValueError(
                f"HRA rank r must be even (hra_u columns are initialised in pairs), got r={self.r}"
            )
        half_u = torch.zeros(self.in_features, self.r // 2)
        nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
        self.hra_u = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1), requires_grad=True)
        self.weight.requires_grad = False
        # forward() no longer needs a dense identity, but the buffer is kept so
        # existing state_dicts (which contain an ``eye`` entry) still load with
        # strict=True.
        self.register_buffer(
            "eye",
            torch.eye(self.in_features)
        )
        self.alpha = getattr(config, "alpha", 16.0)
        self.scale = self.alpha / self.r
        # Re-draws weight/bias exactly as the original code did, keeping the
        # random-number stream (and thus the initial values) unchanged.
        nn.Linear.reset_parameters(self)

    def train(self, mode: bool = True):
        """Switch train/eval mode; HRA has no merge/unmerge step."""
        nn.Linear.train(self, mode)

    def _gram_schmidt_basis(self):
        """Return hra_u's columns orthonormalised by Gram-Schmidt, shape (in_features, r)."""
        basis = []
        for i in range(self.r):
            ui = self.hra_u[:, i]
            for uj in basis:
                ui = ui - torch.dot(uj, ui) * uj
            # The 1e-6 keeps a (near-)dependent column from dividing by zero.
            basis.append(ui / (ui.norm() + 1e-6))
        return torch.stack(basis, dim=1)

    def _rotated_minus_weight(self, W):
        """Return ``W @ Q - W`` without materialising any in_features x in_features matrix."""
        if self.apply_GS:
            U = self._gram_schmidt_basis()
            # W @ ((I - 2 U U^T) - I) == -2 (W U) U^T
            return -2.0 * (W @ U) @ U.t()
        hra_u_norm = self.hra_u / (self.hra_u.norm(dim=0, keepdim=True) + 1e-6)
        W_rot = W
        for i in range(self.r):
            u = hra_u_norm[:, i]
            # W_rot @ (I - 2 u u^T) == W_rot - 2 (W_rot u) u^T: O(out*in) per
            # reflection instead of an O(in^3) dense matrix product.
            W_rot = W_rot - 2.0 * torch.outer(W_rot @ u, u)
        return W_rot - W

    def forward(self, x):
        """Apply the HRA-adapted linear map ``x @ W_eff^T + bias``."""
        # Do not use ``.data``: the frozen weight tensor is used directly.
        W = self.weight
        W_eff = W + self.scale * self._rotated_minus_weight(W)
        return F.linear(x, W_eff, self.bias)
def project(R, eps):
    """Clip ``R`` to the Frobenius-norm ball of radius ``eps`` around zero.

    ``R`` is returned unchanged (same object) when ``||R||_F <= eps``;
    otherwise it is rescaled onto the ball's surface. The ball's centre is
    a zero matrix of size ``R.size(0) x R.size(0)``. OFTLinear applies this to
    its parameter before the Cayley transform, where zero maps to the identity.
    """
    centre = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    offset = R - centre
    radius = torch.norm(offset)
    inside = bool(radius <= eps)
    return R if inside else centre + eps * (offset / radius)
def project_batch(R, eps=1e-5):
    """Clip every block ``R[b]`` to a Frobenius-norm ball around zero.

    ``R`` is a stack of square blocks indexed along dim 0. Each block's radius
    is ``eps / sqrt(R.shape[0])``, so the budget shrinks as the number of blocks
    grows. Blocks already inside the ball are kept as-is; the others are
    rescaled onto its surface.
    """
    per_block_eps = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
    side = R.size(1)
    centre = R.new_zeros((side, side)).unsqueeze(0).expand_as(R)
    offset = R - centre
    radii = torch.norm(offset, dim=(1, 2), keepdim=True)
    keep = radii <= per_block_eps
    # Zero-norm blocks give NaN in the rescaled branch, but `keep` selects R there.
    return torch.where(keep, R, centre + per_block_eps * (offset / radii))
class OFTLinear(nn.Linear):
    """Linear layer adapted with Orthogonal Fine-Tuning (OFT).

    The frozen weight ``W`` (shape ``(out_features, in_features)``) is
    multiplied on the input side by a block-diagonal orthogonal matrix made of
    ``self.r = in_features // block_size`` blocks. Each block is the Cayley
    transform of the skew-symmetric part of a trainable
    ``block_size x block_size`` matrix. The parameters start at zero, whose
    Cayley transform is the identity, so the layer initially equals ``W``.

    Hyper-parameters are read from ``config.oft``:

    * ``block_size``: side length of each block; it must divide ``in_features``.
    * ``is_coft``: before each forward pass, project the parameter (in place,
      without autograd) onto an ``eps`` ball around zero.
    * ``block_share``: use one block for all diagonal positions.
    * ``eps``: projection radius; it is scaled by ``block_size ** 2``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        config: dict,
        fan_in_fan_out: bool = False,  # Accepted for signature parity with LoRALinear; unused here.
        # Not a fan_in_fan_out issue: no module sets it to True.
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        config = config.oft
        self.block_size = config.block_size
        # block_diagonal() builds an (r * block_size)-square matrix, which must
        # match in_features for the weight product in forward().
        if in_features % self.block_size != 0:
            raise ValueError(
                f"in_features ({in_features}) must be divisible by oft block_size ({self.block_size})"
            )
        self.r = in_features // self.block_size
        self.is_coft = config.is_coft
        self.block_share = config.block_share
        self.eps = config.eps
        # Zero-initialised trainable parameters (Cayley(0) == identity).
        if self.block_share:
            self.oft_R = nn.Parameter(self.weight.new_zeros(self.block_size, self.block_size))
        else:
            self.oft_R = nn.Parameter(self.weight.new_zeros(self.r, self.block_size, self.block_size))
        # The projection radius scales with the number of entries per block.
        self.eps = self.eps * self.block_size * self.block_size
        self.weight.requires_grad = False

    def reset_parameters(self):
        """Re-initialise the linear weights and reset ``oft_R`` to zero (identity rotation)."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before oft_R exists, hence the guard.
        if hasattr(self, 'oft_R'):
            nn.init.zeros_(self.oft_R)

    def forward(self, x):
        """Compute ``x @ (W @ R^T)^T + bias`` with ``R`` the block-diagonal rotation."""
        if self.block_share:
            if self.is_coft:
                with torch.no_grad():
                    self.oft_R.copy_(project(self.oft_R, eps=self.eps))
            orth_rotate = self.cayley(self.oft_R)
        else:
            if self.is_coft:
                with torch.no_grad():
                    self.oft_R.copy_(project_batch(self.oft_R, eps=self.eps))
            orth_rotate = self.cayley_batch(self.oft_R)
        block_diagonal_matrix = self.block_diagonal(orth_rotate)
        return F.linear(input=x, weight=self.weight @ block_diagonal_matrix.to(x.dtype).t(), bias=self.bias)

    def cayley(self, data):
        """Return the orthogonal ``(I + S)(I - S)^{-1}``, where ``S`` is the skew-symmetric part of ``data``."""
        r, c = list(data.shape)
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device, dtype=data.dtype)
        return torch.mm(I + skew, torch.inverse(I - skew))

    def cayley_batch(self, data):
        """Batched Cayley transform ``(I - S)(I + S)^{-1}`` over dim 0.

        Note the sign convention is the transpose of :meth:`cayley`'s; both
        give orthogonal matrices and both map zero to the identity.
        """
        b, r, c = data.shape
        skew = 0.5 * (data - data.transpose(1, 2))
        I = torch.eye(r, device=data.device, dtype=data.dtype).unsqueeze(0).expand(b, r, c)
        return torch.bmm(I - skew, torch.inverse(I + skew))

    def block_diagonal(self, R):
        """Assemble the ``self.r`` blocks (shared or per-position) into one block-diagonal matrix."""
        if self.block_share:
            blocks = [R] * self.r
        else:
            blocks = list(R.unbind(0))
        return torch.block_diag(*blocks)
class LoRALayer:
    """Mixin holding the hyper-parameters and merge state shared by LoRA layers.

    Args:
        r: rank of the low-rank update (0 disables it in subclasses).
        lora_alpha: numerator of the ``lora_alpha / r`` scaling used by subclasses.
        lora_dropout: dropout probability applied to the LoRA input; ``0``
            (or less) installs an identity function instead.
        merge_weights: whether subclasses fold the update into the frozen
            weight when switching to eval mode.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        # The low-rank update starts out kept separate from the frozen weight.
        self.merged = False
        self.merge_weights = merge_weights
class Embedding(nn.Embedding, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
r: int = 0,
lora_alpha: int = 1,
merge_weights: bool = True,
**kwargs
):
nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
merge_weights=merge_weights)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
nn.Embedding.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.zeros_(self.lora_A)
nn.init.normal_(self.lora_B)
def train(self, mode: bool = True):
nn.Embedding.train(self, mode)
if mode:
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
self.merged = False
else:
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
result = nn.Embedding.forward(self, x)
after_A = F.embedding(
x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse
)
result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
return result
else:
return nn.Embedding.forward(self, x)
class LoRALinear(nn.Linear, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
config: dict,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
config = config.lora
LoRALayer.__init__(self, r=config.lora_r, lora_alpha=config.lora_alpha, lora_dropout=config.lora_dropout,
merge_weights=config.merge_weights)
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if self.r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, self.r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize B the same way as the default for nn.Linear and A to zero
# this is different than what is described in the paper but should not affect performance
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
if mode:
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
else:
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
class MergedLinear(nn.Linear, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
enable_lora: List[bool] = [False],
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
merge_weights=merge_weights)
assert out_features % len(enable_lora) == 0, \
'The length of enable_lora must divide out_features'
self.enable_lora = enable_lora
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0 and any(enable_lora):
self.lora_A = nn.Parameter(
self.weight.new_zeros((r * sum(enable_lora), in_features)))
self.lora_B = nn.Parameter(
self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
) # weights for Conv1D with groups=sum(enable_lora)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
self.lora_ind = self.weight.new_zeros(
(out_features, ), dtype=torch.bool
).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def zero_pad(self, x):
result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
result[self.lora_ind] = x
return result
def merge_AB(self):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
delta_w = F.conv1d(
self.lora_A.unsqueeze(0),
self.lora_B.unsqueeze(-1),
groups=sum(self.enable_lora)
).squeeze(0)
return T(self.zero_pad(delta_w))
def train(self, mode: bool = True):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
if mode:
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0 and any(self.enable_lora):
self.weight.data -= self.merge_AB() * self.scaling
self.merged = False
else:
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0 and any(self.enable_lora):
self.weight.data += self.merge_AB() * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
if self.merged:
return F.linear(x, T(self.weight), bias=self.bias)
else:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
result += self.lora_dropout(x) @ T(self.merge_AB().T) * self.scaling
return result
class ConvLoRA(nn.Module, LoRALayer):
def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
super(ConvLoRA, self).__init__()
self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
assert isinstance(kernel_size, int)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(
self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
)
self.lora_B = nn.Parameter(
self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.conv.weight.requires_grad = False
self.reset_parameters()
self.merged = False
def reset_parameters(self):
self.conv.reset_parameters()
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode=True):
super(ConvLoRA, self).train(mode)
if mode:
if self.merge_weights and self.merged:
if self.r > 0:
# Make sure that the weights are not merged
self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
self.merged = False
else:
if self.merge_weights and not self.merged:
if self.r > 0:
# Merge the weights and mark it
self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
self.merged = True
def forward(self, x):
if self.r > 0 and not self.merged:
return self.conv._conv_forward(
x,
self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
self.conv.bias
)
return self.conv(x)
class Conv2d(ConvLoRA):
    """LoRA-adapted ``nn.Conv2d``; every argument is forwarded to ``ConvLoRA``."""

    def __init__(self, *args, **kwargs):
        super().__init__(nn.Conv2d, *args, **kwargs)
class Conv1d(ConvLoRA):
    """LoRA-adapted ``nn.Conv1d``; every argument is forwarded to ``ConvLoRA``."""

    def __init__(self, *args, **kwargs):
        super().__init__(nn.Conv1d, *args, **kwargs)
# Further convolution variants can be added the same way.
class Conv3d(ConvLoRA):
    """LoRA-adapted ``nn.Conv3d``; every argument is forwarded to ``ConvLoRA``."""

    def __init__(self, *args, **kwargs):
        super().__init__(nn.Conv3d, *args, **kwargs)
# Registry from adapter-type name to the linear-layer class implementing it.
adapter_dict = dict(
    lora=LoRALinear,
    oft=OFTLinear,
    hra=HRALinear,
)