import math
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F


class HRALinear(nn.Linear):
    """Linear layer adapted with a chain of Householder reflections (HRA)."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        config: dict,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        # config is a namespace-style object (attribute access) carrying an `hra` sub-config.
        config = config.hra
        self.r = config.r                # number of Householder reflections; assumed even
        self.apply_GS = config.apply_GS  # Gram-Schmidt-orthogonalize the reflection vectors

        # Sample r/2 reflection vectors and repeat each one, so the reflections start out
        # in identical pairs (two identical reflections compose to the identity).
        half_u = torch.zeros(self.in_features, self.r // 2)
        nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
        self.hra_u = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1), requires_grad=True)

        # Only the reflection vectors are trained; the pretrained weight stays frozen.
        self.weight.requires_grad = False

        self.register_buffer(
            "eye",
            torch.eye(self.in_features)
        )
        self.alpha = getattr(config, "alpha", 16.0)
        self.scale = self.alpha / self.r

        # Re-initialize the (frozen) base weight with the default nn.Linear scheme.
        nn.Linear.reset_parameters(self)

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)
    def forward(self, x):
        W = self.weight

        if self.apply_GS:
            # Gram-Schmidt: build an orthonormal basis U from the reflection vectors,
            # then apply all reflections at once as Q = I - 2 U U^T.
            U = []
            for i in range(self.r):
                ui = self.hra_u[:, i]
                for uj in U:
                    ui = ui - torch.dot(uj, ui) * uj
                ui = ui / (ui.norm() + 1e-6)
                U.append(ui)
            U = torch.stack(U, dim=1)
            Q = self.eye - 2.0 * (U @ U.t())
        else:
            # Normalize each reflection vector and compose the Householder reflections
            # one by one: Q = prod_i (I - 2 u_i u_i^T).
            hra_u_norm = self.hra_u / (self.hra_u.norm(dim=0, keepdim=True) + 1e-6)
            Q = self.eye
            for i in range(self.r):
                ui = hra_u_norm[:, i:i+1]
                Q = Q @ (self.eye - 2.0 * ui @ ui.t())

        # Effective weight: W + scale * W (Q - I), i.e. a scaled rotation of the frozen weight.
        deltaW = self.scale * (W @ (Q - self.eye))
        W_eff = W + deltaW

        return F.linear(x, W_eff, self.bias)
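# A minimal usage sketch (illustrative, not part of the original module). It assumes a
# namespace-style config (e.g. types.SimpleNamespace) carrying an `hra` sub-config with
# the fields read in __init__: r, apply_GS and (optionally) alpha.
def _hra_linear_example():
    from types import SimpleNamespace

    cfg = SimpleNamespace(hra=SimpleNamespace(r=8, apply_GS=False, alpha=8.0))
    layer = HRALinear(16, 32, config=cfg)

    x = torch.randn(4, 16)
    out = layer(x)  # shape (4, 32)

    # Only the reflection vectors (and the bias) remain trainable; the base weight is frozen.
    assert not layer.weight.requires_grad
    assert layer.hra_u.requires_grad
    return out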
def project(R, eps):
    # COFT projection: despite the name, `I` is the zero matrix here, so this clips the
    # Frobenius norm of R itself to eps, keeping the Cayley rotation close to the identity.
    I = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    diff = R - I
    norm_diff = torch.norm(diff)
    if norm_diff <= eps:
        return R
    else:
        return I + eps * (diff / norm_diff)


def project_batch(R, eps=1e-5):
    # Batched version of project(); eps is scaled down by the square root of the number of blocks.
    eps = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
    I = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    diff = R - I
    norm_diff = torch.norm(R - I, dim=(1, 2), keepdim=True)
    mask = (norm_diff <= eps).bool()
    out = torch.where(mask, R, I + eps * (diff / norm_diff))
    return out
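# A small numerical sketch (illustrative): project() clips the Frobenius norm of R to eps,
# and project_batch() does the same per block with eps rescaled by the number of blocks.
def _coft_projection_example():
    R = torch.randn(4, 4)
    R_proj = project(R, eps=0.1)
    assert torch.norm(R_proj) <= 0.1 + 1e-6  # norm is clipped to eps

    R_batch = torch.randn(3, 4, 4)
    R_batch_proj = project_batch(R_batch, eps=0.1)
    assert R_batch_proj.shape == R_batch.shape
    return R_proj, R_batch_proj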
class OFTLinear(nn.Linear):
    """Linear layer adapted with an Orthogonal Finetuning (OFT) block-diagonal rotation."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        config: dict,
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        # config is a namespace-style object (attribute access) carrying an `oft` sub-config.
        config = config.oft
        self.block_size = config.block_size
        self.r = in_features // self.block_size   # number of diagonal blocks
        self.is_coft = config.is_coft             # constrained OFT: project R before the Cayley map
        self.block_share = config.block_share     # share one block across the whole diagonal
        self.eps = config.eps

        if self.block_share:
            # One shared block of shape (block_size, block_size).
            R_shape = [self.block_size, self.block_size]
            self.oft_R = nn.Parameter(self.weight.new_zeros(R_shape[0], R_shape[0]))
            self.eps = self.eps * R_shape[0] * R_shape[0]
        else:
            # One block per group: shape (r, block_size, block_size).
            R_shape = [self.r, self.block_size, self.block_size]
            self.oft_R = self.weight.new_zeros(R_shape[1], R_shape[1])
            self.oft_R = torch.stack([self.oft_R] * self.r)
            self.oft_R = nn.Parameter(self.oft_R)
            self.eps = self.eps * R_shape[1] * R_shape[1]

        # Only the rotation parameters are trained; the pretrained weight stays frozen.
        self.weight.requires_grad = False

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'oft_R'):
            nn.init.kaiming_uniform_(self.oft_R, a=math.sqrt(5))
    def forward(self, x):
        if self.block_share:
            if self.is_coft:
                # COFT: project the shared block back into an eps-ball before the Cayley map.
                with torch.no_grad():
                    self.oft_R.copy_(project(self.oft_R, eps=self.eps))
            orth_rotate = self.cayley(self.oft_R)
        else:
            if self.is_coft:
                with torch.no_grad():
                    self.oft_R.copy_(project_batch(self.oft_R, eps=self.eps))
            orth_rotate = self.cayley_batch(self.oft_R)

        # Assemble the per-block rotations into one block-diagonal orthogonal matrix and
        # rotate the frozen weight with it.
        block_diagonal_matrix = self.block_diagonal(orth_rotate)
        out = F.linear(input=x, weight=self.weight @ block_diagonal_matrix.to(x.dtype).t(), bias=self.bias)

        return out

    def cayley(self, data):
        r, c = list(data.shape)
        # Cayley transform of the skew-symmetric part: Q = (I + A)(I - A)^{-1} is orthogonal.
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device)
        Q = torch.mm(I + skew, torch.inverse(I - skew))
        return Q

    def cayley_batch(self, data):
        b, r, c = data.shape
        # Batched Cayley transform, one orthogonal block per batch entry.
        skew = 0.5 * (data - data.transpose(1, 2))
        I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c)
        Q = torch.bmm(I - skew, torch.inverse(I + skew))
        return Q

    def block_diagonal(self, R):
        if self.block_share:
            # Repeat the single shared block r times along the diagonal.
            blocks = [R] * self.r
        else:
            # One learned block per group.
            blocks = [R[i, ...] for i in range(self.r)]

        A = torch.block_diag(*blocks)
        return A
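# A minimal usage sketch (illustrative, not part of the original module). It assumes a
# namespace-style config with an `oft` sub-config carrying the fields read in __init__
# (block_size, is_coft, block_share, eps); block_size must divide in_features. It also
# checks that the block-diagonal Cayley rotation is numerically orthogonal.
def _oft_linear_example():
    from types import SimpleNamespace

    cfg = SimpleNamespace(oft=SimpleNamespace(block_size=4, is_coft=False, block_share=False, eps=6e-5))
    layer = OFTLinear(16, 32, config=cfg)

    x = torch.randn(4, 16)
    out = layer(x)  # shape (4, 32)

    # The block-diagonal matrix built from the Cayley transform is orthogonal.
    rot = layer.block_diagonal(layer.cayley_batch(layer.oft_R))
    assert torch.allclose(rot @ rot.t(), torch.eye(16), atol=1e-5)
    return out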
class LoRALayer():
    """Shared bookkeeping for LoRA modules: rank, scaling, dropout and merge state."""

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout on the LoRA branch input.
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weights as unmerged initially.
        self.merged = False
        self.merge_weights = merge_weights
class Embedding(nn.Embedding, LoRALayer):
    """LoRA-augmented embedding layer."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Initialize A to zero and B normally, so the LoRA update is zero at the start of training.
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged during training.
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights for inference.
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            after_A = F.embedding(
                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)
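# A minimal usage sketch (illustrative). Because lora_A is zero-initialized, the low-rank
# update is exactly zero at the start of training, so the adapted embedding initially
# matches the frozen pretrained table.
def _lora_embedding_example():
    emb = Embedding(num_embeddings=100, embedding_dim=32, r=4, lora_alpha=8)
    ids = torch.randint(0, 100, (2, 10))

    out = emb(ids)                       # LoRA path (unmerged)
    base = F.embedding(ids, emb.weight)  # frozen pretrained table only
    assert torch.allclose(out, base)
    return out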
class LoRALinear(nn.Linear, LoRALayer):
    """LoRA-augmented linear layer."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        config: dict,
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        # config is a namespace-style object (attribute access) carrying a `lora` sub-config.
        config = config.lora
        LoRALayer.__init__(self, r=config.lora_r, lora_alpha=config.lora_alpha, lora_dropout=config.lora_dropout,
                           merge_weights=config.merge_weights)

        # Set fan_in_fan_out=True if the pretrained weight is stored as (fan_in, fan_out).
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if self.r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, self.r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Initialize A the same way as the default for nn.Linear and B to zero,
            # so the LoRA update is zero at the start of training.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged during training.
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights for inference.
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
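# A minimal usage sketch (illustrative). It assumes a namespace-style config with a `lora`
# sub-config carrying lora_r, lora_alpha, lora_dropout and merge_weights, and shows that
# switching to eval() merges the low-rank update into the frozen weight without changing
# the layer's output.
def _lora_linear_example():
    from types import SimpleNamespace

    cfg = SimpleNamespace(lora=SimpleNamespace(lora_r=8, lora_alpha=16, lora_dropout=0.0,
                                               merge_weights=True))
    layer = LoRALinear(64, 64, config=cfg)

    # Pretend a few training steps happened by giving lora_B non-zero values.
    nn.init.normal_(layer.lora_B, std=0.02)

    x = torch.randn(2, 64)
    y_unmerged = layer(x)  # explicit low-rank path
    layer.eval()           # merges B @ A into the frozen weight
    y_merged = layer(x)    # plain F.linear on the merged weight
    assert torch.allclose(y_unmerged, y_merged, atol=1e-5)
    return y_merged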
class MergedLinear(nn.Linear, LoRALayer):
    """LoRA for a packed projection (e.g. fused QKV), with LoRA enabled per output slice."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Boolean mask over the output dimension marking which slices receive LoRA updates.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Initialize A the same way as the default for nn.Linear and B to zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        # Scatter the LoRA delta rows into the full output dimension, leaving disabled slices at zero.
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_AB(self):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        # A grouped 1x1 convolution computes B @ A for every enabled slice in one shot.
        delta_w = F.conv1d(
            self.lora_A.unsqueeze(0),
            self.lora_B.unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return T(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged during training.
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data -= self.merge_AB() * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights for inference.
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data += self.merge_AB() * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.merged:
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                result += self.lora_dropout(x) @ T(self.merge_AB().T) * self.scaling
            return result
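# A minimal usage sketch (illustrative): a packed QKV projection where only the query and
# value slices get LoRA updates. enable_lora carries one flag per equally-sized slice of
# the output dimension.
def _merged_linear_example():
    qkv = MergedLinear(
        in_features=64,
        out_features=3 * 64,
        r=8,
        lora_alpha=16,
        enable_lora=[True, False, True],  # adapt Q and V, leave K frozen
    )
    x = torch.randn(2, 10, 64)
    out = qkv(x)
    assert out.shape == (2, 10, 192)
    return out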
class ConvLoRA(nn.Module, LoRALayer):
    """LoRA applied to a convolution by low-rank factorizing the (flattened) kernel."""

    def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False

    def reset_parameters(self):
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # Initialize A the same way as the default for nn.Linear and B to zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged during training.
                    self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights for inference.
                    self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = True

    def forward(self, x):
        if self.r > 0 and not self.merged:
            return self.conv._conv_forward(
                x,
                self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
                self.conv.bias
            )
        return self.conv(x)
class Conv2d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)


class Conv1d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)


class Conv3d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
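# A minimal usage sketch (illustrative): ConvLoRA factorizes the convolution kernel as a
# low-rank product reshaped back to the weight shape. With lora_B zero-initialized, the
# adapted layer initially behaves exactly like the frozen base convolution.
def _conv_lora_example():
    conv = Conv2d(in_channels=3, out_channels=16, kernel_size=3, r=4, lora_alpha=8, padding=1)
    x = torch.randn(2, 3, 32, 32)

    out = conv(x)        # LoRA path (delta is zero at init)
    base = conv.conv(x)  # frozen base convolution only
    assert torch.allclose(out, base)
    assert out.shape == (2, 16, 32, 32)
    return out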
# Registry mapping an adapter name to the Linear replacement that implements it.
adapter_dict = {
    'lora': LoRALinear,
    'oft': OFTLinear,
    'hra': HRALinear,
}
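# A minimal sketch (illustrative) of how the registry might be used to pick an adapter type
# from a single string. The helper name is hypothetical; `config` is assumed to be the same
# namespace-style object the adapter classes read (config.lora / config.oft / config.hra).
def _build_adapter_linear(adapter_name: str, in_features: int, out_features: int, config):
    adapter_cls = adapter_dict[adapter_name]
    return adapter_cls(in_features, out_features, config=config)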