"""
Simple LoRA implementation for custom PyTorch transformer modules.

Public entry points:

* ``LoRAConfig`` - hyper-parameters and the name keywords selecting layers.
* ``LoRALinear`` - wraps a frozen, bias-free ``nn.Linear`` with a low-rank update.
* ``apply_lora`` - swaps matching ``nn.Linear`` layers for ``LoRALinear`` in place.
* ``lora_state_dict`` / ``load_lora_state_dict`` - save / restore adapter weights.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

# Substrings that identify adapter tensors inside parameter / state-dict names.
_LORA_PARAM_TAGS = ("lora_A", "lora_B")


def _is_lora_param(name: str) -> bool:
    """Return True when *name* refers to a LoRA adapter tensor."""
    return any(tag in name for tag in _LORA_PARAM_TAGS)


@dataclass
class LoRAConfig:
    """Hyper-parameters consumed by :func:`apply_lora`."""

    # Rank of the low-rank update (rows of lora_A / columns of lora_B).
    r: int = 8
    # Numerator of the scaling factor ``alpha / max(1, r)``.
    alpha: int = 16
    # Dropout probability applied to the adapter input only.
    dropout: float = 0.05
    # A Linear layer is wrapped when any keyword is a substring of its dotted
    # module name.  ``None`` selects the default list set in ``__post_init__``.
    target_keywords: Optional[List[str]] = None

    def __post_init__(self) -> None:
        # A fresh list per instance, so configs never share a mutable default.
        if self.target_keywords is None:
            self.target_keywords = ["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]


class LoRALinear(nn.Module):
    """Frozen bias-free ``nn.Linear`` plus a trainable low-rank update.

    ``forward`` returns ``base(x) + scaling * (dropout(x) @ A^T @ B^T)`` where
    ``A`` has shape ``(r, in_features)`` and ``B`` has shape
    ``(out_features, r)``.  ``B`` starts at zero, so a freshly built layer
    produces exactly the output of ``base``.
    """

    def __init__(self, base: nn.Linear, r: int, alpha: int, dropout: float) -> None:
        """Wrap *base* and create the adapter parameters.

        Args:
            base: Linear layer to wrap; its weight is frozen here.
            r: Rank of the update.
            alpha: Scaling numerator; the effective scale is ``alpha / max(1, r)``.
            dropout: Dropout probability for the adapter branch.

        Raises:
            ValueError: If *base* has a bias term.
        """
        super().__init__()
        if base.bias is not None:
            # Keep implementation simple and stable for current model (bias=False modules).
            raise ValueError("LoRALinear expects base Linear with bias=None in this project.")

        self.base = base
        self.base.weight.requires_grad = False
        self.in_features = base.in_features
        self.out_features = base.out_features
        self.r = r
        # max(1, r) avoids a division by zero when r == 0.
        self.scaling = alpha / max(1, r)

        # Create the adapters next to the wrapped weight.  Using the default
        # device/dtype would make forward() fail with a mismatch whenever the
        # model was moved or cast before being wrapped.
        weight = base.weight
        factory = {"device": weight.device}
        if weight.is_floating_point():
            factory["dtype"] = weight.dtype
        self.lora_A = nn.Parameter(torch.zeros(r, self.in_features, **factory))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, r, **factory))
        self.dropout = nn.Dropout(dropout)

        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base projection of *x* plus the scaled low-rank update."""
        base_out = self.base(x)
        lora_out = self.dropout(x) @ self.lora_A.t() @ self.lora_B.t()
        return base_out + (self.scaling * lora_out)


def _resolve_parent(root: nn.Module, dotted_name: str) -> Tuple[nn.Module, str]:
    """Return ``(parent_module, attribute_name)`` for *dotted_name* under *root*."""
    parent_path, _, leaf = dotted_name.rpartition(".")
    parent = root
    if parent_path:
        for part in parent_path.split("."):
            parent = getattr(parent, part)
    return parent, leaf


def _replace_module(root: nn.Module, dotted_name: str, new_module: nn.Module) -> None:
    """Replace the submodule at *dotted_name* (relative to *root*) with *new_module*."""
    parent, leaf = _resolve_parent(root, dotted_name)
    setattr(parent, leaf, new_module)


def apply_lora(model: nn.Module, cfg: LoRAConfig) -> List[str]:
    """Wrap matching ``nn.Linear`` layers of *model* in :class:`LoRALinear`, in place.

    A layer matches when any entry of ``cfg.target_keywords`` is a substring
    of its dotted module name.  Afterwards every parameter of *model* is
    frozen except those whose name contains ``lora_A`` or ``lora_B``.

    Calling this again on an already adapted model does not nest adapters:
    the Linear stored inside an existing ``LoRALinear`` is skipped.

    Args:
        model: Module tree to modify.
        cfg: Rank, scaling, dropout and target keywords.

    Returns:
        Dotted names of the layers replaced by this call.

    Raises:
        ValueError: If a matching Linear layer has a bias (from ``LoRALinear``).
    """
    replaced: List[str] = []
    # Snapshot the traversal: the tree is mutated while we walk it.
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if not any(k in name for k in cfg.target_keywords):
            continue
        parent, _ = _resolve_parent(model, name)
        if isinstance(parent, LoRALinear):
            # "<layer>.base" still contains the keyword; wrapping it again
            # would stack a second adapter inside the first one.
            continue
        lora_mod = LoRALinear(base=module, r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
        _replace_module(model, name, lora_mod)
        replaced.append(name)

    # Freeze everything except LoRA params.
    for n, p in model.named_parameters():
        p.requires_grad = _is_lora_param(n)
    return replaced


def lora_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Return CPU copies of all state-dict entries named ``*lora_A*`` / ``*lora_B*``.

    The tensors are real copies (``copy=True``), so the result stays unchanged
    if the model keeps training afterwards, including when it lives on CPU.
    """
    return {
        k: v.detach().to("cpu", copy=True)
        for k, v in model.state_dict().items()
        if _is_lora_param(k)
    }


def load_lora_state_dict(model: nn.Module, state: dict) -> None:
    """Copy tensors from *state* into the matching entries of *model*, in place.

    Keys of *state* that the model does not have are ignored.  Shapes are
    checked for every matching key before anything is written, so a failed
    load leaves the model untouched.

    Raises:
        RuntimeError: If a matching entry has a different shape.  Without the
            check ``Tensor.copy_`` would silently broadcast, e.g. a rank-1
            adapter into a rank-8 one.
    """
    own = model.state_dict()
    pending = []
    for key, value in state.items():
        if key not in own:
            continue
        target = own[key]
        if tuple(target.shape) != tuple(value.shape):
            raise RuntimeError(
                f"size mismatch for {key}: checkpoint has shape {tuple(value.shape)}, "
                f"model has shape {tuple(target.shape)}"
            )
        pending.append((target, value))

    # state_dict() tensors share storage with the live parameters, so copy_
    # updates the model; no_grad keeps autograd out of the in-place write.
    with torch.no_grad():
        for target, value in pending:
            target.copy_(value)