| """ |
| Simple LoRA implementation for custom PyTorch transformer modules. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Iterable, List |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
@dataclass
class LoRAConfig:
    """Hyperparameters for LoRA injection.

    Attributes:
        r: Rank of the low-rank update matrices.
        alpha: Scaling numerator; the effective residual scale is alpha / r.
        dropout: Dropout probability applied to the LoRA input path.
        target_keywords: Substrings matched against dotted module names to
            decide which ``nn.Linear`` layers receive adapters.  ``None``
            (the default) selects the standard attention/MLP projections.
    """

    r: int = 8
    alpha: int = 16
    dropout: float = 0.05
    # The field is Optional only at construction time; __post_init__ replaces
    # None with a fresh default list, so each instance owns its own copy
    # (annotating it as a plain `List[str] = None` was a type error).
    target_keywords: List[str] | None = None

    def __post_init__(self) -> None:
        if self.target_keywords is None:
            self.target_keywords = ["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]
|
|
|
|
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank residual.

    Forward pass: ``base(x) + scaling * (dropout(x) @ A.T @ B.T)``.  ``A`` is
    Kaiming-uniform initialised and ``B`` starts at zero, so immediately after
    construction the wrapper is an exact no-op on top of ``base``.
    """

    def __init__(self, base: nn.Linear, r: int, alpha: int, dropout: float) -> None:
        super().__init__()
        if base.bias is not None:

            raise ValueError("LoRALinear expects base Linear with bias=None in this project.")

        # Keep the wrapped layer but freeze it: only the adapter factors train.
        self.base = base
        self.base.weight.requires_grad = False

        self.in_features = base.in_features
        self.out_features = base.out_features
        self.r = r
        self.scaling = alpha / max(r, 1)  # guard against r == 0

        self.dropout = nn.Dropout(dropout)
        self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))

        # A gets noise, B stays zero => the residual is zero at start.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Frozen base projection plus the scaled low-rank residual."""
        dropped = self.dropout(x)
        residual = (dropped @ self.lora_A.t()) @ self.lora_B.t()
        return self.base(x) + self.scaling * residual
|
|
|
|
| def _replace_module(root: nn.Module, dotted_name: str, new_module: nn.Module) -> None: |
| parts = dotted_name.split(".") |
| parent = root |
| for p in parts[:-1]: |
| parent = getattr(parent, p) |
| setattr(parent, parts[-1], new_module) |
|
|
|
|
def apply_lora(model: nn.Module, cfg: LoRAConfig) -> List[str]:
    """Inject LoRA adapters into ``model`` in place.

    Every ``nn.Linear`` whose dotted module name contains one of
    ``cfg.target_keywords`` is wrapped in a :class:`LoRALinear`.  Afterwards
    every parameter is frozen except the ``lora_A`` / ``lora_B`` factors.

    Safe to call more than once: a previously wrapped base Linear is renamed
    ``…<keyword>.base`` and would otherwise still match the keyword filter and
    be wrapped again; such Linears are skipped, so a second call is a no-op
    that returns an empty list.

    Args:
        model: Module tree to modify in place.
        cfg: Injection hyperparameters.

    Returns:
        Dotted names of the modules that were replaced, in traversal order.
    """
    # Prefixes of existing LoRALinear modules; any Linear underneath one of
    # them is an already-wrapped base and must not be wrapped a second time.
    wrapped_prefixes = [
        n + "." for n, m in model.named_modules() if isinstance(m, LoRALinear)
    ]

    replaced: List[str] = []
    # Snapshot named_modules() up front: the tree is mutated while iterating.
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if any(name.startswith(p) for p in wrapped_prefixes):
            continue
        if not any(k in name for k in cfg.target_keywords):
            continue
        lora_mod = LoRALinear(base=module, r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
        _replace_module(model, name, lora_mod)
        replaced.append(name)

    # Train only the adapter factors: freeze everything, then re-enable LoRA.
    for p in model.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if "lora_A" in n or "lora_B" in n:
            p.requires_grad = True

    return replaced
|
|
|
|
def lora_state_dict(model: nn.Module) -> dict:
    """Return only the LoRA adapter tensors of ``model``, detached and on CPU."""
    adapters: dict = {}
    for key, tensor in model.state_dict().items():
        if "lora_A" not in key and "lora_B" not in key:
            continue
        adapters[key] = tensor.detach().cpu()
    return adapters
|
|
|
|
def load_lora_state_dict(model: nn.Module, state: dict) -> None:
    """Copy LoRA adapter tensors from ``state`` into ``model`` in place.

    Uses the supported ``load_state_dict(strict=False)`` API instead of
    calling ``copy_`` on the tensors returned by ``model.state_dict()`` —
    that trick only worked because ``state_dict`` happens to return views
    aliasing the live parameters, which is an implementation detail.
    ``strict=False`` tolerates both model keys absent from ``state`` (all
    the non-LoRA weights) and stray keys in ``state`` absent from the model,
    matching the previous best-effort behaviour.

    Args:
        model: Target module; modified in place.
        state: Mapping of parameter names to tensors, e.g. the output of
            ``lora_state_dict``.
    """
    model.load_state_dict(state, strict=False)
|
|