# ayjays132's picture
# Upload 478 files
# 101858b verified
"""Minimal LoRA wrappers + injector for fine-tuning a frozen base model.
LoRALinear / LoRAConv2d: forward = frozen_base(x) + scaling * B(A(x))
where A: (in -> r), B: (r -> out). A is Kaiming init, B is zero init,
so the wrapped module starts as an exact identity to the base layer.
inject_lora(model, ...) walks ``model.named_modules()`` and replaces target
Linear/Conv2d layers in-place. The original base weights remain on the
module (just .requires_grad_(False)); only the LoRA A/B matrices train.
This is intentionally tiny — no scaling schedules, no rank-stabilization,
no merging. If you need PEFT's full feature set, install peft. For our
single-checkpoint fine-tune use case this is enough.
"""
from __future__ import annotations
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank update.

    ``forward(x) = base(x) + scaling * lora_B(lora_A(x))`` where
    ``scaling = alpha / rank``. ``lora_B`` is zero-initialised, so a freshly
    built wrapper returns exactly what the base layer returns.

    Args:
        base: Linear layer to wrap. Its parameters are frozen in place via
            ``requires_grad_(False)`` and the module is kept as ``self.base``.
        rank: Inner dimension of the low-rank pair; must be positive.
        alpha: Numerator of the scaling factor. Defaults to ``rank``.

    Raises:
        TypeError: If ``base`` is not an ``nn.Linear``.
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, base: nn.Linear, rank: int, alpha: Optional[float] = None):
        super().__init__()
        if not isinstance(base, nn.Linear):
            raise TypeError(f"LoRALinear expects nn.Linear, got {type(base).__name__}")
        # Validate before touching ``base`` so a bad rank neither freezes the
        # caller's layer nor surfaces later as a ZeroDivisionError.
        self.rank = int(rank)
        if self.rank <= 0:
            raise ValueError(f"LoRALinear rank must be a positive integer, got {rank!r}")

        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)

        self.alpha = float(alpha) if alpha is not None else float(rank)
        self.scaling = self.alpha / self.rank

        # Build the adapters where the base weight already lives; otherwise a
        # base that is on an accelerator or in a non-default float dtype
        # makes forward() fail with a device/dtype mismatch. Non-float
        # weights fall back to the default dtype because the adapters have
        # to be trainable.
        weight = base.weight
        adapter_dtype = weight.dtype if weight.is_floating_point() else None
        self.lora_A = nn.Linear(
            base.in_features, self.rank, bias=False,
            device=weight.device, dtype=adapter_dtype,
        )
        self.lora_B = nn.Linear(
            self.rank, base.out_features, bias=False,
            device=weight.device, dtype=adapter_dtype,
        )
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        # Zero B => the low-rank term contributes nothing at initialisation.
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base layer's output plus the scaled low-rank update."""
        return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling
class LoRAConv2d(nn.Module):
    """Rank-r low-rank update for a frozen ``nn.Conv2d``.

    ``lora_A`` is a 1x1 convolution (in_channels -> rank) and ``lora_B``
    reuses the base layer's kernel size, stride, padding and dilation
    (rank -> out_channels), so the update has the same output shape as the
    base convolution. ``forward(x) = base(x) + scaling * lora_B(lora_A(x))``
    with ``scaling = alpha / rank``. ``lora_B`` is zero-initialised, so a
    freshly built wrapper returns exactly what the base layer returns.

    Args:
        base: Convolution to wrap. Its parameters are frozen in place via
            ``requires_grad_(False)`` and the module is kept as ``self.base``.
        rank: Number of intermediate channels; must be positive.
        alpha: Numerator of the scaling factor. Defaults to ``rank``.

    Raises:
        TypeError: If ``base`` is not an ``nn.Conv2d``.
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, base: nn.Conv2d, rank: int, alpha: Optional[float] = None):
        super().__init__()
        if not isinstance(base, nn.Conv2d):
            raise TypeError(f"LoRAConv2d expects nn.Conv2d, got {type(base).__name__}")
        # Validate before touching ``base`` so a bad rank neither freezes the
        # caller's layer nor surfaces later as a ZeroDivisionError.
        self.rank = int(rank)
        if self.rank <= 0:
            raise ValueError(f"LoRAConv2d rank must be a positive integer, got {rank!r}")

        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)

        self.alpha = float(alpha) if alpha is not None else float(rank)
        self.scaling = self.alpha / self.rank

        # Build the adapters where the base weight already lives; otherwise a
        # base that is on an accelerator or in a non-default float dtype
        # makes forward() fail with a device/dtype mismatch. Non-float
        # weights fall back to the default dtype because the adapters have
        # to be trainable.
        weight = base.weight
        adapter_dtype = weight.dtype if weight.is_floating_point() else None
        self.lora_A = nn.Conv2d(
            base.in_channels, self.rank,
            kernel_size=1, stride=1, padding=0, bias=False,
            device=weight.device, dtype=adapter_dtype,
        )
        # groups=1 on purpose: the low-rank path mixes all ``rank`` channels
        # regardless of how the base convolution is grouped.
        self.lora_B = nn.Conv2d(
            self.rank, base.out_channels,
            kernel_size=base.kernel_size,
            stride=base.stride,
            padding=base.padding,
            dilation=base.dilation,
            groups=1,
            bias=False,
            device=weight.device, dtype=adapter_dtype,
        )
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        # Zero B => the low-rank term contributes nothing at initialisation.
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base convolution's output plus the scaled low-rank update."""
        return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling
def _module_matches(name: str, patterns: Iterable[str]) -> bool:
return any(p in name for p in patterns)
def inject_lora(
    root: nn.Module,
    target_substrings: Iterable[str],
    rank: int = 16,
    alpha: Optional[float] = None,
    include_linear: bool = True,
    include_conv2d: bool = True,
    skip_substrings: Iterable[str] = (),
) -> Tuple[int, List[str]]:
    """Replace target Linear / Conv2d layers under ``root`` with LoRA wrappers.

    A module is replaced when its qualified name contains at least one of
    ``target_substrings``, contains none of ``skip_substrings`` (nor the
    built-in ``"text_model"`` skip), and it is an ``nn.Linear`` /
    ``nn.Conv2d`` whose kind is enabled. Modules that are already LoRA
    wrappers, and everything nested inside them, are left untouched, so
    calling this function again on the same model does not wrap a layer twice.

    Args:
        root: Model to modify in place. ``root`` itself is never replaced.
        target_substrings: Substrings selecting modules by qualified name.
            A single ``str`` is treated as one pattern.
        rank: LoRA rank passed to every wrapper.
        alpha: LoRA alpha passed to every wrapper (``None`` = wrapper default).
        include_linear: Wrap matching ``nn.Linear`` modules.
        include_conv2d: Wrap matching ``nn.Conv2d`` modules.
        skip_substrings: Substrings excluding modules by qualified name.
            A single ``str`` is treated as one pattern.

    Returns:
        ``(count, names_replaced)`` where ``names_replaced`` lists the
        qualified names in traversal order and ``count`` is its length.
    """
    if not target_substrings:
        return 0, []
    # A bare string is itself an iterable of characters; without this guard
    # ``list("attn")`` would silently turn into four one-letter patterns.
    if isinstance(target_substrings, str):
        target_substrings = [target_substrings]
    if isinstance(skip_substrings, str):
        skip_substrings = [skip_substrings] if skip_substrings else []
    targets = list(target_substrings)
    if not targets:
        return 0, []
    skips = list(skip_substrings) + ["text_model"]

    replaced: List[str] = []
    # Name prefixes ("<wrapper>.") of LoRA wrappers already present in the tree.
    wrapper_prefixes: Tuple[str, ...] = ()
    # Snapshot first: parents are mutated while we walk.
    for qname, module in list(root.named_modules()):
        if not qname:
            continue
        # Children of an existing wrapper (.base / .lora_A / .lora_B) are
        # plain Linear/Conv2d layers whose names still contain the target
        # substring; wrapping them again would stack LoRA on LoRA.
        if qname.startswith(wrapper_prefixes):
            continue
        if isinstance(module, (LoRALinear, LoRAConv2d)):
            wrapper_prefixes += (qname + ".",)
            continue
        if _module_matches(qname, skips) or not _module_matches(qname, targets):
            continue

        if include_linear and isinstance(module, nn.Linear):
            wrapper: nn.Module = LoRALinear(module, rank=rank, alpha=alpha)
        elif include_conv2d and isinstance(module, nn.Conv2d):
            wrapper = LoRAConv2d(module, rank=rank, alpha=alpha)
        else:
            continue

        # Re-bind the attribute on the owning module.
        parent_path, _, leaf = qname.rpartition(".")
        parent = root.get_submodule(parent_path) if parent_path else root
        setattr(parent, leaf, wrapper)
        replaced.append(qname)
    return len(replaced), replaced
def lora_parameter_count(root: nn.Module) -> int:
    """Return the total element count of all LoRA A/B parameters under ``root``.

    Only the ``lora_A`` and ``lora_B`` submodules of ``LoRALinear`` /
    ``LoRAConv2d`` wrappers are counted; the wrapped base layers are not.
    """
    total = 0
    for module in root.modules():
        if not isinstance(module, (LoRALinear, LoRAConv2d)):
            continue
        for adapter in (module.lora_A, module.lora_B):
            total += sum(weight.numel() for weight in adapter.parameters())
    return total