import torch
import types
import math
from torch import nn
import torch.nn.functional as F
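
# Linear submodules inside each Qwen2 decoder layer that receive a LoRA adapter.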
QWEN2_TARGET_MODULES = [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.up_proj",
"mlp.gate_proj",
"mlp.down_proj",
]
class LoRALayer(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 1024,
        **kwargs  # absorbs unused lora_params keys such as "layers" and "target_modules"
    ):
        nn.Linear.__init__(self, in_features, out_features)
        if r <= 0:
            # A non-positive rank disables the adapter; behave as a plain nn.Linear.
            self.forward = self.naive_forward
        else:
            # we eliminate lora_alpha here because we find it unnecessary in VoRA
            self.lora_A = nn.Linear(in_features, r, bias=False)
            self.lora_B = nn.Linear(r, out_features, bias=False)
            # Standard LoRA init: random A, zero B, so the adapter starts as a no-op.
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor):
        # Base projection plus the low-rank update lora_B(lora_A(x)).
        intermediate = F.linear(x, self.weight, bias=self.bias)
        result = intermediate + self.lora_B(self.lora_A(x))
        return result

    def naive_forward(self, x: torch.Tensor):
        return F.linear(x, self.weight, bias=self.bias)
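
# Illustrative sketch: because lora_B is zero-initialized, a freshly
# constructed LoRALayer reproduces the base nn.Linear exactly; only
# training moves the low-rank path.
#
#   layer = LoRALayer(16, 32, r=4)
#   x = torch.randn(2, 16)
#   torch.allclose(layer(x), F.linear(x, layer.weight, layer.bias))  # True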
def _get_submodules(self, key):
    # Resolve a dotted path such as "self_attn.q_proj" relative to this module
    # and return (parent module, target module, attribute name on the parent).
    parent = self.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = self.get_submodule(key)
    return parent, target, target_name

def _find_and_replace(self, lora_params):
    target_modules = lora_params["target_modules"]
    for llm_module_name in target_modules:
        parent, target, target_name = self._get_submodules(llm_module_name)
        # Extra keys in lora_params (e.g. "layers", "target_modules") are
        # swallowed by LoRALayer's **kwargs; only "r" matters here.
        vora_layer = LoRALayer(
            target.in_features,
            target.out_features,
            **lora_params
        )
        self._replace_module(parent, target_name, vora_layer, target)
def _replace_module(self, parent_module, child_name, new_module, old_module):
    setattr(parent_module, child_name, new_module)
    # Reuse the pretrained weight/bias so the replacement is lossless.
    new_module.weight = old_module.weight
    if old_module.bias is not None:
        new_module.bias = old_module.bias
    if getattr(old_module, "state", None) is not None:
        new_module.state = old_module.state
    # Match the pretrained weight's device and dtype so the freshly created
    # lora_A/lora_B work with half-precision models as well.
    new_module.to(device=old_module.weight.device, dtype=old_module.weight.dtype)
def apply_lora(llm, lora_params=None):
    # Use None as the default to avoid sharing a mutable dict across calls.
    if lora_params is None:
        lora_params = {"layers": "all", "r": 1024, "target_modules": QWEN2_TARGET_MODULES}
    llm_num_layers = llm.config.num_hidden_layers
    total_layers = lora_params.get("layers", "all")
    # -------------------- validation check ---------------------
    if isinstance(total_layers, str):
        assert total_layers.lower() == "all", "layers must be an integer or 'all'"
        total_layers = list(range(llm_num_layers))
    else:
        assert isinstance(total_layers, int), "layers must be an integer or 'all'"
        # An integer n means: apply LoRA to the first n decoder layers.
        total_layers = list(range(total_layers))
    # -------------------- validation check ---------------------
    # -------------------- replace llm layers ---------------------
    for i in total_layers:
        llm_layer = llm.model.layers[i]
        # Bind the helpers above to each decoder layer so that submodule
        # lookups in _get_submodules are relative to that layer.
        llm_layer._get_submodules = types.MethodType(_get_submodules, llm_layer)
        llm_layer._find_and_replace = types.MethodType(_find_and_replace, llm_layer)
        llm_layer._replace_module = types.MethodType(_replace_module, llm_layer)
        llm_layer._find_and_replace(lora_params)
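
if __name__ == "__main__":
    # Usage sketch; the checkpoint id below is only an example, and any
    # Qwen2-style model exposing `config.num_hidden_layers` and
    # `model.layers` should work the same way.
    from transformers import AutoModelForCausalLM

    llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
    apply_lora(llm, {"layers": "all", "r": 256, "target_modules": QWEN2_TARGET_MODULES})
    print(llm.model.layers[0].self_attn.q_proj)  # now a LoRALayer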