File size: 2,933 Bytes
53f0cc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Simple LoRA implementation for custom PyTorch transformer modules.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Iterable, List

import torch
import torch.nn as nn


@dataclass
class LoRAConfig:
    """Hyper-parameters controlling LoRA adapter injection.

    Attributes:
        r: Rank of the low-rank decomposition.
        alpha: Scaling numerator; the effective adapter scale is alpha / r.
        dropout: Dropout probability applied to the LoRA input path.
        target_keywords: Substrings matched against module names to pick
            which nn.Linear layers get adapters; None selects the defaults
            below (attention projections plus the two MLP layers).
    """

    r: int = 8
    alpha: int = 16
    dropout: float = 0.05
    # None is a sentinel for "use defaults": a list literal here would be a
    # shared mutable default, which dataclasses reject. The honest annotation
    # is Optional, so no `type: ignore` is needed.
    target_keywords: List[str] | None = None

    def __post_init__(self) -> None:
        # Swap the sentinel for a fresh list per instance.
        if self.target_keywords is None:
            self.target_keywords = ["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]


class LoRALinear(nn.Module):
    """A frozen nn.Linear augmented with a trainable low-rank residual.

    Computes ``base(x) + scaling * dropout(x) @ A^T @ B^T`` where A is
    (r, in_features) and B is (out_features, r). B starts at zero, so the
    wrapper is an exact no-op relative to the base layer until training
    moves the adapter weights.
    """

    def __init__(self, base: nn.Linear, r: int, alpha: int, dropout: float) -> None:
        super().__init__()
        if base.bias is not None:
            # Keep implementation simple and stable for current model (bias=False modules).
            raise ValueError("LoRALinear expects base Linear with bias=None in this project.")

        self.base = base
        self.base.weight.requires_grad = False  # only the adapter factors train

        self.in_features, self.out_features = base.in_features, base.out_features
        self.r = r
        self.scaling = alpha / max(1, r)  # max() guards against r == 0

        # A gets the standard kaiming init; B stays zero so the initial
        # low-rank contribution is exactly zero.
        self.lora_A = nn.Parameter(torch.empty(r, self.in_features))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Low-rank path: project down through A, then back up through B.
        down = self.dropout(x) @ self.lora_A.t()
        up = down @ self.lora_B.t()
        return self.base(x) + self.scaling * up


def _replace_module(root: nn.Module, dotted_name: str, new_module: nn.Module) -> None:
    parts = dotted_name.split(".")
    parent = root
    for p in parts[:-1]:
        parent = getattr(parent, p)
    setattr(parent, parts[-1], new_module)


def apply_lora(model: nn.Module, cfg: LoRAConfig) -> List[str]:
    """Wrap every targeted nn.Linear in *model* with a LoRALinear adapter.

    A Linear is targeted when its dotted module name contains any of
    cfg.target_keywords. After wrapping, all parameters are frozen except
    the lora_A / lora_B adapter factors.

    Returns the dotted names of the modules that were replaced.
    """
    # Snapshot the matches first: mutating the module tree while iterating
    # named_modules() would be unsafe.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(keyword in name for keyword in cfg.target_keywords)
    ]

    for name, module in targets:
        adapter = LoRALinear(base=module, r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
        _replace_module(model, name, adapter)

    # Train only the LoRA factors; everything else stays frozen.
    for param_name, param in model.named_parameters():
        param.requires_grad = "lora_A" in param_name or "lora_B" in param_name

    return [name for name, _ in targets]


def lora_state_dict(model: nn.Module) -> dict:
    """Return only the LoRA adapter tensors from *model*, detached and on CPU.

    The keys are the usual state-dict keys, filtered to entries whose name
    contains ``lora_A`` or ``lora_B`` — the exact set that
    :func:`load_lora_state_dict` expects back.
    """
    adapter: dict = {}
    for key, tensor in model.state_dict().items():
        if "lora_A" in key or "lora_B" in key:
            adapter[key] = tensor.detach().cpu()
    return adapter


def load_lora_state_dict(model: nn.Module, state: dict) -> None:
    """Load LoRA adapter tensors into *model* in place.

    Delegates to ``nn.Module.load_state_dict(strict=False)`` instead of
    hand-copying into ``model.state_dict()`` tensors: the hand-rolled loop
    relied on the subtlety that state-dict tensors share storage with the
    live parameters, whereas ``load_state_dict`` is the supported API with
    the same contract — keys absent from the model are ignored and model
    tensors absent from *state* are left untouched, so a file produced by
    :func:`lora_state_dict` loads best-effort.

    Args:
        model: Module previously prepared with :func:`apply_lora`.
        state: Mapping of state-dict keys to tensors.
    """
    model.load_state_dict(state, strict=False)