"""Minimal LoRA wrappers + injector for fine-tuning a frozen base model. LoRALinear / LoRAConv2d: forward = frozen_base(x) + scaling * B(A(x)) where A: (in -> r), B: (r -> out). A is Kaiming init, B is zero init, so the wrapped module starts as an exact identity to the base layer. inject_lora(model, ...) walks ``model.named_modules()`` and replaces target Linear/Conv2d layers in-place. The original base weights remain on the module (just .requires_grad_(False)); only the LoRA A/B matrices train. This is intentionally tiny — no scaling schedules, no rank-stabilization, no merging. If you need PEFT's full feature set, install peft. For our single-checkpoint fine-tune use case this is enough. """ from __future__ import annotations from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn class LoRALinear(nn.Module): def __init__(self, base: nn.Linear, rank: int, alpha: Optional[float] = None): super().__init__() if not isinstance(base, nn.Linear): raise TypeError(f"LoRALinear expects nn.Linear, got {type(base).__name__}") self.base = base for p in self.base.parameters(): p.requires_grad_(False) self.rank = int(rank) self.alpha = float(alpha) if alpha is not None else float(rank) self.scaling = self.alpha / self.rank self.lora_A = nn.Linear(base.in_features, self.rank, bias=False) self.lora_B = nn.Linear(self.rank, base.out_features, bias=False) nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5) nn.init.zeros_(self.lora_B.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling class LoRAConv2d(nn.Module): """Rank-r low-rank decomposition for a Conv2d: A is 1x1 (in->r), B is the original kernel size (r->out). Adds to the base conv output.""" def __init__(self, base: nn.Conv2d, rank: int, alpha: Optional[float] = None): super().__init__() if not isinstance(base, nn.Conv2d): raise TypeError(f"LoRAConv2d expects nn.Conv2d, got {type(base).__name__}") self.base = base for p in self.base.parameters(): p.requires_grad_(False) self.rank = int(rank) self.alpha = float(alpha) if alpha is not None else float(rank) self.scaling = self.alpha / self.rank self.lora_A = nn.Conv2d( base.in_channels, self.rank, kernel_size=1, stride=1, padding=0, bias=False, ) self.lora_B = nn.Conv2d( self.rank, base.out_channels, kernel_size=base.kernel_size, stride=base.stride, padding=base.padding, dilation=base.dilation, groups=1, bias=False, ) nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5) nn.init.zeros_(self.lora_B.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling def _module_matches(name: str, patterns: Iterable[str]) -> bool: return any(p in name for p in patterns) def inject_lora( root: nn.Module, target_substrings: Iterable[str], rank: int = 16, alpha: Optional[float] = None, include_linear: bool = True, include_conv2d: bool = True, skip_substrings: Iterable[str] = (), ) -> Tuple[int, List[str]]: """Replace target Linear / Conv2d layers under ``root`` with LoRA wrappers. Returns (count, names_replaced). The walk does a snapshot of ``named_modules()`` first so we can mutate parents during iteration. Skips ``root.text_model`` and any module whose qualified name contains one of ``skip_substrings``. """ if not target_substrings: return 0, [] skip_substrings = list(skip_substrings) + ["text_model"] targets = list(target_substrings) snapshot = list(root.named_modules()) replaced: List[str] = [] count = 0 for qname, module in snapshot: if not qname: continue if _module_matches(qname, skip_substrings): continue if not _module_matches(qname, targets): continue if include_linear and isinstance(module, nn.Linear): new_mod = LoRALinear(module, rank=rank, alpha=alpha) elif include_conv2d and isinstance(module, nn.Conv2d): new_mod = LoRAConv2d(module, rank=rank, alpha=alpha) else: continue # Set on parent parent_path, _, leaf = qname.rpartition(".") parent = root.get_submodule(parent_path) if parent_path else root setattr(parent, leaf, new_mod) replaced.append(qname) count += 1 return count, replaced def lora_parameter_count(root: nn.Module) -> int: n = 0 for m in root.modules(): if isinstance(m, (LoRALinear, LoRAConv2d)): n += sum(p.numel() for p in m.lora_A.parameters()) n += sum(p.numel() for p in m.lora_B.parameters()) return n