"""
Simple LoRA implementation for custom PyTorch transformer modules.
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Iterable, List

import torch
import torch.nn as nn
# Default module-name substrings to adapt (attention projections + MLP).
_DEFAULT_TARGET_KEYWORDS = ["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]


@dataclass
class LoRAConfig:
    """Hyperparameters for applying LoRA adapters to a model.

    Attributes:
        r: Rank of the low-rank update matrices.
        alpha: Scaling numerator; the effective scale is ``alpha / r``.
        dropout: Dropout probability applied to the LoRA branch input.
        target_keywords: Substrings of dotted module names to adapt.
    """
    r: int = 8
    alpha: int = 16
    dropout: float = 0.05
    # field(default_factory=...) gives each instance its own list and removes
    # the `= None  # type: ignore` workaround for mutable defaults.
    target_keywords: List[str] = field(
        default_factory=lambda: list(_DEFAULT_TARGET_KEYWORDS)
    )

    def __post_init__(self) -> None:
        # Back-compat: callers that explicitly pass None still get the defaults.
        if self.target_keywords is None:
            self.target_keywords = list(_DEFAULT_TARGET_KEYWORDS)
class LoRALinear(nn.Module):
    """A frozen ``nn.Linear`` augmented with a trainable low-rank branch.

    Computes ``base(x) + (alpha / r) * dropout(x) @ A^T @ B^T``. A uses
    Kaiming-uniform init and B starts at zero, so the wrapper is an exact
    identity adapter immediately after construction.
    """

    def __init__(self, base: nn.Linear, r: int, alpha: int, dropout: float) -> None:
        super().__init__()
        # This project only ever wraps bias-free Linear layers; reject others.
        if base.bias is not None:
            raise ValueError("LoRALinear expects base Linear with bias=None in this project.")
        self.base = base
        self.base.weight.requires_grad = False  # pretrained weight stays frozen
        self.in_features = base.in_features
        self.out_features = base.out_features
        self.r = r
        self.scaling = alpha / max(1, r)  # max() guards against r == 0
        self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
        self.dropout = nn.Dropout(dropout)
        # Standard LoRA init: random A, zero B => zero delta at start.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project down to rank r, back up to out_features, then scale.
        down = self.dropout(x) @ self.lora_A.t()
        delta = down @ self.lora_B.t()
        return self.base(x) + self.scaling * delta
def _replace_module(root: nn.Module, dotted_name: str, new_module: nn.Module) -> None:
parts = dotted_name.split(".")
parent = root
for p in parts[:-1]:
parent = getattr(parent, p)
setattr(parent, parts[-1], new_module)
def apply_lora(model: nn.Module, cfg: LoRAConfig) -> List[str]:
    """Swap every targeted ``nn.Linear`` in `model` for a ``LoRALinear``.

    A module is targeted when it is an ``nn.Linear`` whose dotted name
    contains any of ``cfg.target_keywords``. After replacement, only the
    lora_A / lora_B adapter parameters are trainable; every other parameter
    is frozen. Returns the dotted names of the replaced modules.
    """
    adapted: List[str] = []
    # Snapshot named_modules() first: we mutate the tree while iterating.
    for name, submodule in list(model.named_modules()):
        if not isinstance(submodule, nn.Linear):
            continue
        if not any(keyword in name for keyword in cfg.target_keywords):
            continue
        wrapper = LoRALinear(base=submodule, r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
        _replace_module(model, name, wrapper)
        adapted.append(name)
    # A parameter is trainable iff it belongs to a LoRA adapter.
    for param_name, param in model.named_parameters():
        param.requires_grad = "lora_A" in param_name or "lora_B" in param_name
    return adapted
def lora_state_dict(model: nn.Module) -> dict:
    """Return only the LoRA adapter tensors, detached and moved to CPU."""
    adapters = {}
    for key, tensor in model.state_dict().items():
        if "lora_A" in key or "lora_B" in key:
            adapters[key] = tensor.detach().cpu()
    return adapters
def load_lora_state_dict(model: nn.Module, state: dict) -> None:
    """Copy tensors from `state` into `model`'s matching state-dict entries.

    Keys absent from the model are silently skipped (best-effort load).
    """
    current = model.state_dict()
    for key, tensor in state.items():
        target = current.get(key)
        if target is not None:
            # state_dict() tensors share storage with the parameters, so
            # copy_ updates the model in place.
            target.copy_(tensor)
|