File size: 3,279 Bytes
146a630
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import types
import math
from torch import nn
import torch.nn.functional as F


# Dotted sub-module paths used as the default ``target_modules`` of
# ``apply_lora``; each path is resolved relative to a single decoder layer.
# (Presumably these match the Qwen2 decoder layer layout -- confirm against
# the model implementation in use.)
QWEN2_TARGET_MODULES = [
    *(f"self_attn.{proj}" for proj in ("q_proj", "k_proj", "v_proj", "o_proj")),
    *(f"mlp.{proj}" for proj in ("up_proj", "gate_proj", "down_proj")),
]


class LoRALayer(nn.Linear):
    """``nn.Linear`` with an optional additive low-rank branch.

    The base layer is always constructed with a bias term.  When ``r`` is
    non-negative, two bias-free linear maps ``lora_A`` (in_features -> r) and
    ``lora_B`` (r -> out_features) are created and ``lora_B(lora_A(x))`` is
    added to the base output.  ``lora_B`` is zero-initialised, so the branch
    contributes nothing until its weights are changed.  When ``r`` is
    negative no low-rank weights are created and the instance's ``forward``
    is pointed at :meth:`naive_forward`.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        r: Rank of the low-rank branch; a negative value disables it.
        **kwargs: Accepted and ignored, so a larger config dict can be
            splatted into the constructor.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 1024,
        **kwargs
    ):
        super().__init__(in_features, out_features)

        if r < 0:
            # Adapter disabled: shadow the class-level forward on this
            # instance with the plain linear path.
            self.forward = self.naive_forward
            return

        # lora_alpha is eliminated here because it was found unnecessary in
        # VoRA, so the low-rank output is added without any scaling factor.
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor):
        """Return the base linear output plus the low-rank update."""
        return F.linear(x, self.weight, bias=self.bias) + self.lora_B(self.lora_A(x))

    def naive_forward(self, x: torch.Tensor):
        """Return the base linear output only, ignoring any low-rank branch."""
        return F.linear(x, self.weight, bias=self.bias)

def _get_submodules(self, key):
    parent = self.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = self.get_submodule(key)
    return parent, target, target_name

def _find_and_replace(self, lora_params):
    target_modules = lora_params["target_modules"]

    for llm_module_name in target_modules:
        parent, target, target_name = self._get_submodules(llm_module_name)
        vora_layer = LoRALayer(
            target.in_features,
            target.out_features,
            **lora_params
        )
        self._replace_module(parent, target_name, vora_layer, target)

def _replace_module(self, parent_module, child_name, new_module, old_module):
    setattr(parent_module, child_name, new_module)
    new_module.weight = old_module.weight
    if old_module.bias is not None:
        new_module.bias = old_module.bias
    if getattr(old_module, "state", None) is not None:
        new_module.state = old_module.state
        new_module.to(old_module.weight.device)

def apply_lora(llm, lora_params={"layers": "all", "r": 1024, "target_modules": QWEN2_TARGET_MODULES}):
    """Replace the targeted linear sub-modules of ``llm`` with LoRA layers.

    The model is modified in place.  Each selected decoder layer gets the
    helper functions ``_get_submodules``, ``_find_and_replace`` and
    ``_replace_module`` bound to it as methods before the replacement runs.

    Args:
        llm: Model exposing ``config.num_hidden_layers`` and an indexable
            ``model.layers``.
        lora_params: Configuration mapping.  ``"layers"`` is either ``"all"``
            (case-insensitive, also the default when the key is absent) or an
            integer ``n`` selecting the first ``n`` layers.  The mapping is
            passed on unchanged to ``_find_and_replace``.  The default dict
            is only read, never mutated.

    Raises:
        ValueError: If ``"layers"`` is a string other than ``"all"`` or an
            integer larger than the number of layers.
        TypeError: If ``"layers"`` is neither a string nor an integer.
    """
    llm_num_layers = llm.config.num_hidden_layers
    total_layers = lora_params.get("layers", "all")

    # -------------------- validation check ---------------------
    # Everything is validated before the model is touched so that a bad
    # configuration cannot leave it partially patched.
    if isinstance(total_layers, str):
        if total_layers.lower() != "all":
            raise ValueError(
                f"total_layers must be an integer or 'all', got {total_layers!r}"
            )
        layer_indices = range(llm_num_layers)
    elif isinstance(total_layers, int):
        if total_layers > llm_num_layers:
            raise ValueError(
                f"total_layers ({total_layers}) exceeds the number of "
                f"hidden layers ({llm_num_layers})"
            )
        layer_indices = range(total_layers)
    else:
        raise TypeError("total_layers must be an integer or 'all'")
    # -------------------- validation check ---------------------

    # -------------------- replace llm layers ---------------------
    for i in layer_indices:
        llm_layer = llm.model.layers[i]
        # _find_and_replace reaches the other helpers through `self`, so all
        # three have to be bound to the layer before it is invoked.
        llm_layer._get_submodules = types.MethodType(_get_submodules, llm_layer)
        llm_layer._find_and_replace = types.MethodType(_find_and_replace, llm_layer)
        llm_layer._replace_module = types.MethodType(_replace_module, llm_layer)
        llm_layer._find_and_replace(lora_params)