# ------------------------------------------------------------------------------------------ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, List


class HRALinear(nn.Linear):
    """Linear layer with a Householder Reflection Adaptation (HRA) residual.

    The frozen pretrained weight ``W`` is adapted as

        W_eff = W + scale * (W @ (Q - I))

    where ``Q`` is a product of ``r`` Householder reflections built from the
    trainable matrix ``hra_u`` (optionally Gram-Schmidt orthogonalized when
    ``apply_GS`` is set).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        config,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        config = config.hra  # HRA-specific sub-config (attribute access, not a dict)
        self.r = config.r
        self.apply_GS = config.apply_GS

        # Initialize r/2 directions and duplicate each column: two identical
        # Householder reflections cancel, so Q == I at initialization and the
        # layer starts out exactly equal to the frozen linear layer.
        half_u = torch.zeros(self.in_features, self.r // 2)
        nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
        self.hra_u = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1), requires_grad=True)

        # Only hra_u is trainable; the pretrained weight stays frozen.
        self.weight.requires_grad = False
        self.register_buffer(
            "eye", torch.eye(self.in_features)
        )
        self.alpha = getattr(config, "alpha", 16.0)
        self.scale = self.alpha / self.r
        nn.Linear.reset_parameters(self)

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)

    def forward(self, x):
        # Do NOT go through .data here: gradients must flow through hra_u
        # while self.weight stays frozen (requires_grad=False).
        W = self.weight

        # ===== build the orthogonal matrix Q =====
        if self.apply_GS:
            # Gram-Schmidt orthonormalize the columns of hra_u, then apply
            # all reflections at once: Q = I - 2 U U^T.
            U = []
            for i in range(self.r):
                ui = self.hra_u[:, i]
                for uj in U:
                    ui = ui - torch.dot(uj, ui) * uj
                ui = ui / (ui.norm() + 1e-6)  # epsilon guards against zero vectors
                U.append(ui)
            U = torch.stack(U, dim=1)  # [in_features, r]
            Q = self.eye - 2.0 * (U @ U.t())
        else:
            # Chain of r Householder reflections from the normalized columns.
            hra_u_norm = self.hra_u / (self.hra_u.norm(dim=0, keepdim=True) + 1e-6)
            Q = self.eye
            for i in range(self.r):
                ui = hra_u_norm[:, i:i + 1]
                Q = Q @ (self.eye - 2.0 * ui @ ui.t())

        # ===== HRA residual =====
        deltaW = self.scale * (W @ (Q - self.eye))
        W_eff = W + deltaW
        return F.linear(x, W_eff, self.bias)


def project(R, eps):
    """Project ``R`` into the Frobenius-norm ball of radius ``eps`` around the
    zero matrix (COFT constraint on the skew parameter)."""
    I = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    diff = R - I
    norm_diff = torch.norm(diff)
    if norm_diff <= eps:
        return R
    else:
        return I + eps * (diff / norm_diff)


def project_batch(R, eps=1e-5):
    """Batched :func:`project`: constrain every block of ``R`` independently."""
    # scaling factor for each of the smaller block matrix
    eps = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
    I = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    diff = R - I
    norm_diff = torch.norm(R - I, dim=(1, 2), keepdim=True)
    mask = (norm_diff <= eps).bool()
    out = torch.where(mask, R, I + eps * (diff / norm_diff))
    return out


class OFTLinear(nn.Linear):
    """Linear layer with Orthogonal Finetuning (OFT).

    The frozen weight is right-multiplied by a block-diagonal orthogonal
    matrix whose blocks are Cayley transforms of trainable skew parameters.
    With ``oft_R`` initialized to zero the transform is the identity, so the
    layer starts exactly equal to the frozen linear layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        config,
        fan_in_fan_out: bool = False,
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out).
        # NOTE: not a fan_in_fan_out issue here, since no module sets it to True.
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        config = config.oft
        self.block_size = config.block_size
        self.r = in_features // self.block_size  # number of diagonal blocks
        self.is_coft = config.is_coft  # constrained OFT: project R into an eps-ball
        self.block_share = config.block_share
        self.eps = config.eps

        # Actual trainable parameters (zero-initialized => identity rotation)
        if self.block_share:
            R_shape = [self.block_size, self.block_size]
            self.oft_R = nn.Parameter(self.weight.new_zeros(R_shape[0], R_shape[0]))
            self.eps = self.eps * R_shape[0] * R_shape[0]
        else:
            R_shape = [self.r, self.block_size, self.block_size]
            self.oft_R = self.weight.new_zeros(R_shape[1], R_shape[1])
            self.oft_R = torch.stack([self.oft_R] * self.r)
            self.oft_R = nn.Parameter(self.oft_R)
            self.eps = self.eps * R_shape[1] * R_shape[1]

        self.weight.requires_grad = False
        # self.reset_parameters()

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        # Fixed: the original tested hasattr(self, 'R'), but the parameter is
        # named 'oft_R', so this branch was dead code and could never run.
        if hasattr(self, 'oft_R'):
            nn.init.kaiming_uniform_(self.oft_R, a=math.sqrt(5))

    def forward(self, x):
        if self.block_share:
            if self.is_coft:
                # COFT: clamp the skew parameter inside the eps-ball in place.
                with torch.no_grad():
                    self.oft_R.copy_(project(self.oft_R, eps=self.eps))
            orth_rotate = self.cayley(self.oft_R)
        else:
            if self.is_coft:
                with torch.no_grad():
                    self.oft_R.copy_(project_batch(self.oft_R, eps=self.eps))
            orth_rotate = self.cayley_batch(self.oft_R)

        # Block-diagonal parametrization of the full rotation matrix
        block_diagonal_matrix = self.block_diagonal(orth_rotate)
        out = F.linear(input=x, weight=self.weight @ block_diagonal_matrix.to(x.dtype).t(), bias=self.bias)
        return out

    def cayley(self, data):
        r, c = list(data.shape)
        # Ensure the input matrix is skew-symmetric
        skew = 0.5 * (data - data.t())
        # dtype must follow the data, otherwise mm fails for non-fp32 params
        I = torch.eye(r, device=data.device, dtype=data.dtype)
        # Cayley parametrization: Q = (I + A)(I - A)^{-1} is orthogonal
        Q = torch.mm(I + skew, torch.inverse(I - skew))
        return Q

    def cayley_batch(self, data):
        b, r, c = data.shape
        # Ensure the input matrix is skew-symmetric
        skew = 0.5 * (data - data.transpose(1, 2))
        I = torch.eye(r, device=data.device, dtype=data.dtype).unsqueeze(0).expand(b, r, c)
        # Fixed: use the same (I + A)(I - A)^{-1} convention as cayley(), so the
        # block-shared and per-block paths agree in sign (the original mixed
        # (I - A)(I + A)^{-1} here, i.e. the inverse rotation).
        Q = torch.bmm(I + skew, torch.inverse(I - skew))
        return Q

    def block_diagonal(self, R):
        if self.block_share:
            # Repeat the single shared block r times
            blocks = [R] * self.r
        else:
            # One slice per block along the leading dimension
            blocks = [R[i, ...] for i in range(self.r)]
        # Use torch.block_diag to assemble the full block-diagonal rotation
        A = torch.block_diag(*blocks)
        return A


class LoRALayer():
    """Mixin holding the hyper-parameters and merge state shared by all
    LoRA-style layers below."""

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights


class Embedding(nn.Embedding, LoRALayer):
    # LoRA implemented in an embedding layer
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Embedding convention: A is zeroed (so the adapter starts as a
            # no-op) and B gets a normal init — the transpose of the Linear
            # convention.  (The original comment had A and B swapped.)
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            # Look up rows of A^T with the same embedding options as the base
            after_A = F.embedding(
                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)


class LoRALinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        config,
        fan_in_fan_out: bool = False,
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        config = config.lora
        LoRALayer.__init__(self, r=config.lora_r, lora_alpha=config.lora_alpha,
                           lora_dropout=config.lora_dropout,
                           merge_weights=config.merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if self.r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, self.r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize B the same way as the default for nn.Linear and A to zero
            # this is different than what is described in the paper but should not affect performance
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)


class MergedLinear(nn.Linear, LoRALayer):
    # LoRA for a dense layer that fuses several projections (e.g. q,k,v);
    # enable_lora selects which of the fused groups receive a LoRA update.
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the boolean mask that scatters the grouped LoRA output
            # back into the full output dimension.
            lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            lora_ind[enable_lora, :] = True
            # Fixed: register as a NON-persistent buffer so the mask follows
            # .to(device)/.cuda() (a plain tensor attribute stayed on CPU and
            # broke CUDA indexing) while leaving the state_dict unchanged.
            self.register_buffer('lora_ind', lora_ind.view(-1), persistent=False)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        # Scatter the rows of x into the positions selected by lora_ind,
        # leaving the disabled groups as zeros.
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_AB(self):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        # Grouped 1x1 conv computes B @ A per enabled group in one call.
        delta_w = F.conv1d(
            self.lora_A.unsqueeze(0),
            self.lora_B.unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return T(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data -= self.merge_AB() * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data += self.merge_AB() * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.merged:
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                result += self.lora_dropout(x) @ T(self.merge_AB().T) * self.scaling
            return result


class ConvLoRA(nn.Module, LoRALayer):
    """LoRA wrapper for convolution layers: the low-rank update B @ A is
    reshaped into the conv weight's shape."""

    def __init__(self, conv_module, in_channels, out_channels, kernel_size,
                 r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels // self.conv.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False

    def reset_parameters(self):
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged
                    self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights and mark it
                    self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = True

    def forward(self, x):
        if self.r > 0 and not self.merged:
            return self.conv._conv_forward(
                x,
                self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
                self.conv.bias
            )
        return self.conv(x)


class Conv2d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)


class Conv1d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)


# Can Extend to other ones like this
class Conv3d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)


# Maps adapter name -> Linear adapter class used when patching a model.
adapter_dict = {
    'lora': LoRALinear,
    'oft': OFTLinear,
    'hra': HRALinear,
}