| """ |
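| LoRA (Low-Rank Adaptation) implementation for Linear and Conv2d layers. |
| Used to replace the qkv projections in attention and the FFN (MLP) linear layers; also provides delta-only adapters that can be initialised from weight differences, and adaptive-rank selection helpers. |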
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
# Module-level LoRA configuration.
# Flags annotated as ``bool`` hold real booleans (they were previously the
# ints 1/0, which compare equal to True/False but contradict the annotation).

# Master switch: the remaining LoRA flags are only defined when this is on.
USE_LORA: bool = True
if USE_LORA:
    # NOTE(review): not read anywhere in this file - presumably the dropout
    # probability callers pass to the LoRA layers; confirm against callers.
    LORA_dropout: float = 0.0
    # NOTE(review): not read anywhere in this file - presumably enables
    # wrapping Conv2d layers as well as Linear ones; confirm against callers.
    LORA_apply_to_conv: bool = True
    # Passed as ``freeze_base`` to LoRALinear by replace_linear_with_lora.
    LORA_freeze_base: bool = False
    # Enables the diagnostic prints in the adapters' init_from_diff methods.
    LORA_DEBUG: bool = False
    # When False, adaptive-rank helpers default to one rank per task.
    FORCE_SAME_RANK_ACROSS_TASKS: bool = False
    # NOTE(review): not read anywhere in this file - the names suggest layers
    # with a dimension below this value / a rank fraction above this value
    # are skipped; confirm against callers.
    DONT_lora_if_dim_lt: int = 90
    DONT_lora_if_rankFrac_gt: float = 0.3
|
|
class LoRALinear(nn.Module):
    """
    Wrap an ``nn.Linear`` and add a low-rank update to its output.

    The forward pass returns
    ``original_linear(x) + scaling * (dropout(x) @ A^T @ B^T)``.
    ``B`` starts at zero, so a freshly built wrapper reproduces the wrapped
    layer exactly.

    Args:
        original_linear: the nn.Linear layer to wrap (kept as a submodule)
        rank: LoRA rank (r)
        dropout: dropout probability applied to the LoRA branch input only
        freeze_base: when True, ``requires_grad`` is switched off for every
            parameter of ``original_linear``
        scaling: multiplier applied to the LoRA branch output (default 2.0)
    """
    def __init__(
        self,
        original_linear: nn.Linear,
        rank: int = 4,
        dropout: float = 0.0,
        freeze_base: bool = True,
        scaling: float = 2.0,
    ):
        super().__init__()

        self.in_features = original_linear.in_features
        self.out_features = original_linear.out_features
        self.rank = rank
        self.scaling = scaling

        self.original_linear = original_linear
        if freeze_base:
            for param in self.original_linear.parameters():
                param.requires_grad = False

        # Create the factors next to the base weight (same device, same
        # floating dtype); otherwise wrapping a CUDA or float64/half layer
        # fails in forward() with a device/dtype mismatch.
        base_weight = original_linear.weight
        factor_dtype = (
            base_weight.dtype
            if base_weight.is_floating_point()
            else torch.get_default_dtype()
        )
        factory = {"device": base_weight.device, "dtype": factor_dtype}
        self.lora_A = nn.Parameter(torch.zeros(rank, self.in_features, **factory))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, rank, **factory))

        # A is random and B stays zero, so the initial LoRA contribution is 0.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base layer output plus the scaled low-rank update."""
        result = self.original_linear(x)
        lora_out = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        return result + lora_out * self.scaling

    def __repr__(self):
        return f"LoRALinear(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank}, scaling={self.scaling})"
|
|
|
|
def replace_linear_with_lora(
    module: nn.Module,
    rank: int = 4,
    dropout: float = 0.0,
    target_modules: list = None,
    verbose: bool = True,
    _prefix: str = "",
):
    """
    Recursively replace nn.Linear layers within a module with LoRALinear wrappers.

    Replacement happens in place. Children of an ``nn.Sequential`` are swapped
    inside the existing container, so its type and child names are preserved.
    Layers that are already ``LoRALinear`` are skipped, so calling the function
    twice does not wrap twice.

    Args:
        module: module whose linear layers should be replaced
        rank: LoRA rank
        dropout: dropout probability
        target_modules: specific module names to replace; None means all linears.
            An entry matches a child when it equals the child's own name or is a
            dotted suffix of the child's path below the root ``module``
            e.g.: ['to_q', 'to_k', 'to_v', 'to_out'] for attention
                  ['net.0', 'net.2'] for FeedForward
            When a matched child is an nn.Sequential, every nn.Linear directly
            inside it is replaced regardless of its own name.
        verbose: whether to log replacements
        _prefix: internal - dotted path of ``module`` below the root; callers
            should leave the default

    Returns:
        the module with replacements applied
    """
    def _is_target(qualified_name: str) -> bool:
        # No target list means every child is eligible.
        if target_modules is None:
            return True
        return any(
            qualified_name == target or qualified_name.endswith("." + target)
            for target in target_modules
        )

    def _wrap(linear: nn.Linear, label: str) -> nn.Module:
        lora_layer = LoRALinear(linear, rank=rank, dropout=dropout, freeze_base=LORA_freeze_base)
        if verbose:
            print(f"[LoRA] Replaced {label}: {linear.in_features} -> {linear.out_features} with rank={rank}")
        return lora_layer

    # Snapshot the children: the loop body replaces entries of ``module``.
    for name, child in list(module.named_children()):
        qualified = f"{_prefix}.{name}" if _prefix else name

        if isinstance(child, LoRALinear):
            # Already wrapped; descending would wrap its inner base layer again.
            continue

        if not _is_target(qualified):
            # Not selected itself, but something deeper may be.
            replace_linear_with_lora(child, rank, dropout, target_modules, verbose, _prefix=qualified)
            continue

        if isinstance(child, nn.Linear):
            setattr(module, name, _wrap(child, name))
        elif isinstance(child, nn.Sequential):
            for key, submodule in list(child.named_children()):
                if isinstance(submodule, nn.Linear):
                    setattr(child, key, _wrap(submodule, f"{name}.{key}"))
                elif not isinstance(submodule, LoRALinear):
                    # Nested blocks may contain further linears.
                    replace_linear_with_lora(
                        submodule, rank, dropout, target_modules, verbose,
                        _prefix=f"{qualified}.{key}",
                    )
        else:
            replace_linear_with_lora(child, rank, dropout, target_modules, verbose, _prefix=qualified)

    return module
|
|
|
|
def count_lora_parameters(module: nn.Module):
    """
    Tally the parameters of a module by whether they require gradients.

    All parameters reported by ``module.named_parameters()`` are counted, not
    only the LoRA factors.

    Returns:
        dict: {'trainable': trainable params, 'frozen': frozen params,
               'total': total params,
               'trainable_ratio': trainable / total (0 when total is 0)}
    """
    # Index 1 accumulates trainable elements, index 0 frozen ones.
    tallies = [0, 0]
    for _, tensor in module.named_parameters():
        tallies[1 if tensor.requires_grad else 0] += tensor.numel()

    frozen_count, trainable_count = tallies
    overall = trainable_count + frozen_count
    if overall > 0:
        share = trainable_count / overall
    else:
        share = 0

    return {
        'trainable': trainable_count,
        'frozen': frozen_count,
        'total': overall,
        'trainable_ratio': share,
    }
|
|
|
|
def print_lora_parameters(module: nn.Module, name: str = "Model"):
    """Print the trainable/frozen/total parameter counts of ``module`` under the heading ``name``."""
    stats = count_lora_parameters(module)
    rule = '=' * 60
    trainable_pct = stats['trainable_ratio'] * 100
    frozen_pct = (1 - stats['trainable_ratio']) * 100
    report_lines = (
        f"\n{rule}",
        f"{name} Parameter Statistics:",
        f"{rule}",
        f"Trainable params: {stats['trainable']:,} ({trainable_pct:.2f}%)",
        f"Frozen params: {stats['frozen']:,} ({frozen_pct:.2f}%)",
        f"Total params: {stats['total']:,}",
        f"{rule}\n",
    )
    for line in report_lines:
        print(line)
|
|
|
|
class LoRAConv2d(nn.Module):
    """
    Wrap an ``nn.Conv2d`` and add a low-rank update to its output.

    Treat Conv2d as a matrix multiplication:
    - flatten kernel: (out_channels, in_channels, k, k) -> (out_channels, in_channels*k*k)
    - apply low-rank decomposition: W = W_0 + B @ A

    ``A`` is applied as a convolution with the wrapped layer's kernel size,
    stride, padding, dilation and groups; ``B`` is a 1x1 convolution back to
    ``out_channels``. ``B`` starts at zero, so a freshly built wrapper
    reproduces the wrapped layer exactly.

    NOTE(review): the ``A`` convolution reuses ``groups``, so ``rank`` has to be
    divisible by ``groups`` for grouped convolutions - confirm that callers
    guarantee this.

    Args:
        original_conv: the nn.Conv2d layer to wrap (kept as a submodule)
        rank: LoRA rank (r)
        dropout: dropout probability applied to the LoRA branch input only
        freeze_base: when True, ``requires_grad`` is switched off for every
            parameter of ``original_conv``
        scaling: multiplier applied to the LoRA branch output (default 2.0)
    """
    def __init__(
        self,
        original_conv: nn.Conv2d,
        rank: int = 4,
        dropout: float = 0.0,
        freeze_base: bool = True,
        scaling: float = 2.0,
    ):
        super().__init__()

        self.out_channels = original_conv.out_channels
        self.in_channels = original_conv.in_channels
        self.kernel_size = original_conv.kernel_size
        self.stride = original_conv.stride
        self.padding = original_conv.padding
        self.dilation = original_conv.dilation
        self.groups = original_conv.groups

        self.rank = rank
        self.scaling = scaling

        self.original_conv = original_conv
        if freeze_base:
            for param in self.original_conv.parameters():
                param.requires_grad = False

        # Create the factors next to the base weight (same device, same
        # floating dtype); otherwise wrapping a CUDA or float64/half layer
        # fails in forward() with a device/dtype mismatch.
        base_weight = original_conv.weight
        factor_dtype = (
            base_weight.dtype
            if base_weight.is_floating_point()
            else torch.get_default_dtype()
        )
        factory = {"device": base_weight.device, "dtype": factor_dtype}
        self.lora_A = nn.Parameter(torch.zeros(
            rank,
            self.in_channels // self.groups,
            self.kernel_size[0],
            self.kernel_size[1],
            **factory,
        ))
        self.lora_B = nn.Parameter(torch.zeros(self.out_channels, rank, 1, 1, **factory))

        # A is random and B stays zero, so the initial LoRA contribution is 0.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

        # Sum both factors BEFORE converting to "M"; previously only
        # lora_B.numel() was divided, so the LoRA figure was wrong.
        lora_numel = self.lora_A.numel() + self.lora_B.numel()
        print(f"param orig:lora (M) = {self.original_conv.weight.numel()/1024/1024}:{lora_numel/1024/1024}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base conv output plus the scaled low-rank update."""
        result = self.original_conv(x)

        x_dropped = self.dropout(x)
        lora_out = F.conv2d(
            x_dropped,
            self.lora_A,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups
        )
        # 1x1 convolution mixes the rank channels back to out_channels.
        lora_out = F.conv2d(lora_out, self.lora_B)

        return result + lora_out * self.scaling

    def __repr__(self):
        return (f"LoRAConv2d(in_channels={self.in_channels}, out_channels={self.out_channels}, "
                f"kernel_size={self.kernel_size}, rank={self.rank}, scaling={self.scaling})")
|
|
|
|
|
|
|
|
|
|
def _auto_lora_rank(in_features: int, out_features: int) -> int:
    """
    Derive a default LoRA rank from a layer's dimensions.

    The smaller dimension is divided by ``LORA_rank_ratio`` (treated as at
    least 1.0) and rounded, raised to ``LORA_rank_min``, capped at
    ``LORA_rank_max`` when that is not None, and never returned below 1.
    """
    floor_rank = LORA_rank_min
    divisor = max(1.0, LORA_rank_ratio)
    smaller_dim = min(in_features, out_features)

    rank = int(round(smaller_dim / divisor))
    if not rank > floor_rank:
        rank = floor_rank
    if LORA_rank_max is not None and rank > LORA_rank_max:
        rank = LORA_rank_max
    return rank if rank > 1 else 1
|
|
def _svd_low_rank(M: torch.Tensor, rank: int):
    """
    Factor ``M`` into a rank-limited product ``B @ A`` via truncated SVD.

    The singular values are split evenly between the factors
    (``B = U_r * sqrt(S_r)``, ``A = sqrt(S_r) * Vh_r``). The SVD runs in
    float32, on CUDA when it is available and on ``M``'s own device otherwise,
    so the function also works on CPU-only hosts.

    Args:
        M: 2-D matrix to factor
        rank: maximum number of singular components to keep; clamped to the
            number of singular values the SVD returns (and to at least 0)

    Returns:
        (B, A, S): ``B`` has one column and ``A`` one row per kept component,
        ``S`` holds ALL singular values. All three are returned on ``M``'s
        original device and in its original dtype.
    """
    orig_device = M.device
    orig_dtype = M.dtype

    # Keep tensors that already live on a GPU where they are; only move CPU
    # tensors to CUDA when a GPU actually exists.
    if M.is_cuda or not torch.cuda.is_available():
        compute_device = orig_device
    else:
        compute_device = torch.device('cuda')
    M = M.to(device=compute_device, dtype=torch.float32)

    U, S, Vh = torch.linalg.svd(M, full_matrices=False)
    r = max(0, min(rank, U.shape[1], Vh.shape[0]))

    # sqrt of the (non-negative) singular values, shared between both factors.
    S_root = torch.sqrt(torch.clamp(S[:r], min=0))
    B = U[:, :r] * S_root                  # scale columns: U_r @ diag(S_root)
    A = S_root.unsqueeze(1) * Vh[:r, :]    # scale rows:    diag(S_root) @ Vh_r

    B = B.to(device=orig_device, dtype=orig_dtype)
    A = A.to(device=orig_device, dtype=orig_dtype)
    S = S.to(device=orig_device, dtype=orig_dtype)
    return B, A, S
|
|
|
|
def _svdvals_squared(M: torch.Tensor) -> torch.Tensor:
    """
    Return the squared singular values of ``M`` as a float32 tensor.

    The decomposition runs in float32, on CUDA when it is available and on
    ``M``'s own device otherwise, so the function also works on CPU-only
    hosts. The result is moved back to ``M``'s original device.
    """
    orig_device = M.device

    # Keep tensors that already live on a GPU where they are; only move CPU
    # tensors to CUDA when a GPU actually exists.
    if M.is_cuda or not torch.cuda.is_available():
        compute_device = orig_device
    else:
        compute_device = torch.device('cuda')
    M = M.to(device=compute_device, dtype=torch.float32)

    S = torch.linalg.svdvals(M)
    S2 = (S.float() ** 2)
    return S2.to(device=orig_device, dtype=torch.float32)
|
|
|
|
def _compute_adaptive_rank_from_S2_list(
    list_S2: list,
    avg_threshold: float = None,
    min_threshold: float = None,
    max_rank: int = None,
) -> int:
    """
    Pick one rank shared by every spectrum in ``list_S2``.

    Each entry of ``list_S2`` is a 1-D tensor of squared singular values. The
    smallest rank ``r`` is chosen for which the fraction of energy captured by
    the first ``r`` values is at least ``avg_threshold`` on average and at
    least ``min_threshold`` for every entry.

    Args:
        list_S2: non-empty list of squared-singular-value tensors
        avg_threshold: required mean captured fraction; None falls back to
            ``ADAPTIVE_RANK_AVG_ENERGY_THRESH``
        min_threshold: required worst-case captured fraction; None falls back
            to ``ADAPTIVE_RANK_MIN_ENERGY_THRESH``
        max_rank: optional extra upper bound on the searched rank

    Returns:
        the chosen rank, raised to ``LORA_rank_min`` and then capped at the
        search limit

    Raises:
        AssertionError: if ``list_S2`` or an entry is empty, an entry has zero
            total energy, or no rank up to the limit meets both thresholds
    """
    assert len(list_S2) > 0
    if avg_threshold is None:
        avg_threshold = ADAPTIVE_RANK_AVG_ENERGY_THRESH
    if min_threshold is None:
        min_threshold = ADAPTIVE_RANK_MIN_ENERGY_THRESH

    energy_per_entry = []
    spectrum_sizes = []
    for spectrum in list_S2:
        assert spectrum.numel() > 0
        energy = spectrum.sum()
        assert float(energy.item()) > 0.0, "Zero energy in weight_diff; cannot determine adaptive rank"
        energy_per_entry.append(energy)
        spectrum_sizes.append(int(spectrum.shape[0]))

    # The search limit is the shortest spectrum, tightened by the optional
    # global and per-call caps, but never below 1.
    rank_cap = min(spectrum_sizes)
    for limit in (LORA_rank_max, max_rank):
        if limit is not None:
            rank_cap = min(rank_cap, int(limit))
    rank_cap = max(1, rank_cap)

    for candidate in range(1, rank_cap + 1):
        fractions = [
            float((spectrum[:candidate].sum() / energy).item())
            for spectrum, energy in zip(list_S2, energy_per_entry)
        ]
        mean_fraction = sum(fractions) / len(fractions)
        worst_fraction = min(fractions)
        if mean_fraction >= avg_threshold and worst_fraction >= min_threshold:
            return min(int(rank_cap), max(int(LORA_rank_min), int(candidate)))

    raise AssertionError(f"No rank satisfies avg>={avg_threshold} and min>={min_threshold} up to R_cap={rank_cap}")
|
|
def _compute_per_task_ranks_from_S2_list(
    list_S2: list,
    min_threshold: float = None,
    max_rank: int = None,
) -> list:
    """
    Pick one rank per spectrum in ``list_S2``.

    Each entry of ``list_S2`` is a 1-D tensor of squared singular values. For
    entry ``i`` the smallest rank whose leading values capture at least the
    entry's threshold fraction of the total energy is chosen; if none does,
    the entry's cap is used.

    The threshold for entry ``i`` is
    ``TASK_2_adaptive_rank_min_energy_thresh[i]`` unless
    ``FORCE_SAME_RANK_ACROSS_TASKS`` is set, in which case ``min_threshold``
    is used.

    Args:
        list_S2: non-empty list of squared-singular-value tensors
        min_threshold: fallback threshold; None resolves to
            ``ADAPTIVE_RANK_MIN_ENERGY_THRESH``
        max_rank: optional extra upper bound on each rank

    Returns:
        list of ints, one rank per entry. Each rank is raised to
        ``LORA_rank_min`` and then capped at the entry's limit (spectrum
        length, ``LORA_rank_max``, ``max_rank``), matching
        ``_compute_adaptive_rank_from_S2_list``.

    Raises:
        AssertionError: if ``list_S2`` or an entry is empty, or an entry has
            zero total energy
    """
    assert len(list_S2) > 0
    if min_threshold is None:
        min_threshold = ADAPTIVE_RANK_MIN_ENERGY_THRESH
    ret = []
    for i, s2 in enumerate(list_S2):
        assert s2.numel() > 0
        total = s2.sum()
        assert float(total.item()) > 0.0, "Zero energy in weight_diff; cannot determine adaptive rank"
        R_cap = int(s2.shape[0])
        if LORA_rank_max is not None:
            R_cap = min(R_cap, int(LORA_rank_max))
        if max_rank is not None:
            R_cap = min(R_cap, int(max_rank))
        R_cap = max(1, R_cap)
        found = R_cap

        thres_this = TASK_2_adaptive_rank_min_energy_thresh[i] if (not FORCE_SAME_RANK_ACROSS_TASKS) else min_threshold
        for r in range(1, R_cap + 1):
            ratio = s2[:r].sum() / total
            if float(ratio.item()) >= float(thres_this):
                found = r
                break
        # Apply the floor, then re-apply the cap: a rank above the number of
        # singular values cannot be filled by a truncated SVD.
        ret.append(min(int(R_cap), max(int(LORA_rank_min), int(found))))
    return ret
|
|
def compute_adaptive_rank_for_linear_diffs(
    weight_diffs: list,
    avg_threshold: float = None,
    min_threshold: float = None,
    max_rank: int = None,
    per_task: bool = None,
):
    """
    Choose LoRA rank(s) for a set of 2-D (linear) weight differences.

    Args:
        weight_diffs: non-empty list/tuple of 2-D tensors, unpacked as
            (out, in)
        avg_threshold: mean energy threshold for the shared-rank mode; it is
            now forwarded to the rank search (it used to be dropped). None
            lets the search use its own default. Unused in per-task mode.
        min_threshold: per-entry energy threshold forwarded to the rank search
        max_rank: optional upper bound forwarded to the rank search
        per_task: True for one rank per entry, False for a single shared rank;
            None resolves to ``not FORCE_SAME_RANK_ACROSS_TASKS``

    Returns:
        a list of ranks when ``per_task`` is true, otherwise a single int

    Raises:
        AssertionError: if ``weight_diffs`` is not a non-empty list/tuple
    """
    assert isinstance(weight_diffs, (list, tuple)) and len(weight_diffs) > 0
    if per_task is None:
        per_task = not FORCE_SAME_RANK_ACROSS_TASKS
    list_S2 = [_svdvals_squared(M) for M in weight_diffs]
    # Only the first entry's shape is reported in the log line.
    out0, in0 = weight_diffs[0].shape
    if per_task:
        ranks = _compute_per_task_ranks_from_S2_list(list_S2, min_threshold, max_rank)
        print(f"[AdaptiveRank-Linear per-task] in={in0} out={out0} ranks={ranks}")
        return ranks

    ret = _compute_adaptive_rank_from_S2_list(list_S2, avg_threshold, min_threshold, max_rank)
    print(f"[AdaptiveRank-Linear] in={in0} out={out0} rank={ret}")
    return ret
|
|
|
|
def compute_adaptive_rank_for_conv_diffs(
    weight_diffs: list,
    avg_threshold: float = None,
    min_threshold: float = None,
    max_rank: int = None,
    per_task: bool = None,
):
    """
    Choose LoRA rank(s) for a set of 4-D (conv) weight differences.

    Each kernel of shape (out_c, in_c, kH, kW) is flattened to
    (out_c, in_c*kH*kW) before its squared singular values are computed.

    Args:
        weight_diffs: non-empty list/tuple of 4-D tensors
        avg_threshold: mean energy threshold for the shared-rank mode; it is
            now forwarded to the rank search (it used to be dropped). None
            lets the search use its own default. Unused in per-task mode.
        min_threshold: per-entry energy threshold forwarded to the rank search
        max_rank: optional upper bound forwarded to the rank search
        per_task: True for one rank per entry, False for a single shared rank;
            None resolves to ``not FORCE_SAME_RANK_ACROSS_TASKS``

    Returns:
        a list of ranks when ``per_task`` is true, otherwise a single int

    Raises:
        AssertionError: if ``weight_diffs`` is not a non-empty list/tuple
    """
    assert isinstance(weight_diffs, (list, tuple)) and len(weight_diffs) > 0
    if per_task is None:
        per_task = not FORCE_SAME_RANK_ACROSS_TASKS
    list_S2 = []
    for W in weight_diffs:
        out_c, in_c, kH, kW = W.shape
        M = W.reshape(out_c, in_c * kH * kW)
        list_S2.append(_svdvals_squared(M))
    # Only the first entry's shape is reported in the log line.
    out0, in0, kH0, kW0 = weight_diffs[0].shape
    if per_task:
        ranks = _compute_per_task_ranks_from_S2_list(list_S2, min_threshold, max_rank)
        print(f"[AdaptiveRank-Conv per-task] in_ch={in0} out_ch={out0} kernel=({kH0},{kW0}) ranks={ranks}")
        return ranks

    ret = _compute_adaptive_rank_from_S2_list(list_S2, avg_threshold, min_threshold, max_rank)
    print(f"[AdaptiveRank-Conv] in_ch={in0} out_ch={out0} kernel=({kH0},{kW0}) rank={ret}")
    return ret
|
|
|
|
class LoRAAdapterLinearOnly(nn.Module):
    """
    Incremental LoRA (no base Linear) that returns x @ A^T @ B^T + bias_delta.

    The whole result, including the bias delta, is multiplied by ``scaling``.

    Args:
        in_features: input feature size
        out_features: output feature size
        rank: LoRA rank; None derives it with ``_auto_lora_rank``
        dropout: dropout probability applied to the input
        scaling: multiplier applied to the returned delta
        use_bias_delta: whether to hold a learnable bias delta
    """
    def __init__(self, in_features: int, out_features: int, rank: int = None, dropout: float = 0.0, scaling: float = 1.0, use_bias_delta: bool = True):
        super().__init__()
        if rank is None:
            rank = _auto_lora_rank(in_features, out_features)
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scaling = scaling
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.use_bias_delta = use_bias_delta
        if use_bias_delta:
            self.lora_bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('lora_bias', None)

        # A is random and B is zero, so a fresh adapter contributes nothing.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    @torch.no_grad()
    def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None):
        """
        Initialise the factors so that ``B @ A`` approximates ``weight_diff``.

        Uses the truncated SVD from ``_svd_low_rank``. If it yields fewer
        components than ``self.rank`` (the rank exceeds the smaller matrix
        dimension), only the leading components are filled and the remaining
        columns of ``lora_B`` are zeroed, so ``B @ A`` still equals the
        truncated SVD. Previously ``copy_`` either raised or, for a single
        component, broadcast it across all rank slots and multiplied the delta.

        Args:
            weight_diff: 2-D weight difference of shape (out_features, in_features)
            bias_diff: optional bias difference copied into ``lora_bias``
        """
        B, A, S = _svd_low_rank(weight_diff.float(), self.rank)
        kept = A.shape[0]
        # copy_ converts dtype/device to match the parameters.
        self.lora_A[:kept].copy_(A)
        self.lora_B.zero_()
        self.lora_B[:, :kept].copy_(B)
        if self.use_bias_delta and (bias_diff is not None):
            self.lora_bias.copy_(bias_diff)
        if LORA_DEBUG:
            # Report captured energy and relative reconstruction error.
            energy_total = (S.float() ** 2).sum().item()
            energy_top = (S[: self.rank].float() ** 2).sum().item()
            energy_ratio = energy_top / max(1e-12, energy_total)
            approx = (B @ A).to(weight_diff.device).to(weight_diff.dtype)
            err = torch.linalg.norm((approx - weight_diff).float()).item()
            base = torch.linalg.norm(weight_diff.float()).item()
            rel_err = err / max(1e-12, base)
            bias_norm = 0.0 if (bias_diff is None) else float(torch.linalg.norm(bias_diff.float()).item())
            print(f"[LoRA-Linear init] shape={tuple(weight_diff.shape)} rank={self.rank} energy={energy_ratio:.4f} rel_err={rel_err:.6f} bias_norm={bias_norm:.6f}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank delta (plus bias delta) for ``x``."""
        update = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        if self.lora_bias is not None:
            update = update + self.lora_bias
        return update * self.scaling
|
|
|
|
class LoRAAdapterConv2dOnly(nn.Module):
    """
    Incremental LoRA for Conv2d: convolve with A then 1x1 B, return the delta.

    The whole result, including the bias delta, is multiplied by ``scaling``.

    Args:
        in_channels: input channels
        out_channels: output channels
        kernel_size, stride, padding, dilation: convolution geometry used for
            the ``A`` convolution; ints are expanded to 2-tuples
        groups: groups used for the ``A`` convolution
        rank: LoRA rank; None derives it with ``_auto_lora_rank`` from
            (in_channels * kH * kW, out_channels)
        dropout: dropout probability applied to the input
        scaling: multiplier applied to the returned delta
        use_bias_delta: whether to hold a learnable per-channel bias delta
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple, stride: tuple, padding: tuple, dilation: tuple, groups: int = 1, rank: int = None, dropout: float = 0.0, scaling: float = 1.0, use_bias_delta: bool = True):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation)
        kH, kW = kernel_size
        if rank is None:
            rank = _auto_lora_rank(in_channels * kH * kW, out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.rank = rank
        self.scaling = scaling
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

        # A carries the spatial kernel; B is a 1x1 projection to out_channels.
        self.lora_A = nn.Parameter(torch.zeros(rank, in_channels // groups, kH, kW))
        self.lora_B = nn.Parameter(torch.zeros(out_channels, rank, 1, 1))
        self.use_bias_delta = use_bias_delta
        if use_bias_delta:
            self.lora_bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('lora_bias', None)

        # A is random and B is zero, so a fresh adapter contributes nothing.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    @torch.no_grad()
    def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None):
        """
        Initialise the factors from a 4-D conv weight difference.

        ``weight_diff`` of shape (out_c, in_c, kH, kW) is flattened to
        (out_c, in_c*kH*kW) and factored with ``_svd_low_rank``. If that yields
        fewer components than ``self.rank``, only the leading components are
        filled and the remaining ``lora_B`` channels are zeroed, instead of
        failing on the reshape.

        Args:
            weight_diff: 4-D weight difference
            bias_diff: optional bias difference copied into ``lora_bias``
        """
        out_c, in_c, kH, kW = weight_diff.shape
        M = weight_diff.reshape(out_c, in_c * kH * kW)
        B, A, S = _svd_low_rank(M.float(), self.rank)
        kept = A.shape[0]
        # copy_ converts dtype/device to match the parameters.
        self.lora_A[:kept].copy_(A.reshape(kept, in_c, kH, kW))
        self.lora_B.zero_()
        self.lora_B[:, :kept].copy_(B.reshape(out_c, kept, 1, 1))
        if self.lora_bias is not None and (bias_diff is not None):
            self.lora_bias.copy_(bias_diff)
        if LORA_DEBUG:
            # Report captured energy and relative reconstruction error.
            energy_total = (S.float() ** 2).sum().item()
            energy_top = (S[: self.rank].float() ** 2).sum().item()
            energy_ratio = energy_top / max(1e-12, energy_total)
            approx = (B @ A).to(M.device).to(M.dtype)
            err = torch.linalg.norm((approx - M).float()).item()
            base = torch.linalg.norm(M.float()).item()
            rel_err = err / max(1e-12, base)
            bias_norm = 0.0 if (bias_diff is None) else float(torch.linalg.norm(bias_diff.float()).item())
            print(f"[LoRA-Conv init] out_in_k=({out_c},{in_c},{kH}x{kW}) rank={self.rank} energy={energy_ratio:.4f} rel_err={rel_err:.6f} bias_norm={bias_norm:.6f}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank conv delta (plus bias delta) for ``x``."""
        x_d = self.dropout(x)
        u = F.conv2d(x_d, self.lora_A, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        u = F.conv2d(u, self.lora_B)
        if self.lora_bias is not None:
            # Broadcast the per-channel bias over batch and spatial dims.
            u = u + self.lora_bias.view(1, -1, 1, 1)
        return u * self.scaling
|
|