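"""Minimal LoRA (Low-Rank Adaptation) wrappers for torch Linear layers.

LoRAModule patches the forward method of an existing Linear in place;
LoRANetwork builds one LoRAModule per entry found in a loaded LoRA state dict.
"""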

import math
from typing import Optional

import torch


class LoRAUPParallel(torch.nn.Module):
    """Applies one up-projection block per chunk of the input and concatenates the results."""

    def __init__(self, blocks):
        super().__init__()
        self.blocks = torch.nn.ModuleList(blocks)

    def forward(self, x):
        # Split the rank dimension evenly across the blocks, run each chunk
        # through its own up-projection, then reassemble along the last dim.
        assert x.shape[-1] % len(self.blocks) == 0
        xs = torch.chunk(x, len(self.blocks), dim=-1)
        return torch.cat([block(chunk) for block, chunk in zip(self.blocks, xs)], dim=-1)


class LoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear instead of replacing
    the original Linear module itself.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        n_separate=1,
    ):
        super().__init__()
        self.lora_name = lora_name

        assert org_module.__class__.__name__ == "Linear"
        in_dim = org_module.in_features
        out_dim = org_module.out_features

        if n_separate > 1:
            assert out_dim % n_separate == 0

        self.lora_dim = lora_dim
        if n_separate > 1:
            # One shared down-projection producing n_separate rank-`lora_dim`
            # chunks, each decoded by its own up-projection over a slice of
            # the output features.
            self.lora_down = torch.nn.Linear(in_dim, n_separate * self.lora_dim, bias=False)
            self.lora_up = LoRAUPParallel(
                [torch.nn.Linear(self.lora_dim, out_dim // n_separate, bias=False) for _ in range(n_separate)]
            )
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().item()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        alpha_scale = alpha / self.lora_dim
        self.register_buffer("alpha_scale", torch.tensor(alpha_scale))

        # Standard LoRA init: random down-projection, zeroed up-projection,
        # so the adapter starts out as a zero update.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        if n_separate > 1:
            for block in self.lora_up.blocks:
                torch.nn.init.zeros_(block.weight)
        else:
            torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.use_lora = True

        # Keep the wrapped Linear in a plain list so its parameters are not
        # registered under this module.
        self.org_module = [org_module]
        self.org_forward = None

    def set_use_lora(self, use_lora):
        self.use_lora = use_lora
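
    # The three methods below implement the forward swap described in the
    # class docstring. The original file referenced disapply_to() without
    # defining it; this is a minimal reconstruction in the usual
    # monkey-patching style, under the assumption that apply_to() is called
    # once before the model runs.
    def apply_to(self):
        # Replace the wrapped Linear's forward with ours, keeping the
        # original around so it can be restored.
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def disapply_to(self):
        # Restore the wrapped Linear's original forward.
        if self.org_forward is not None:
            self.org_module[0].forward = self.org_forward
            self.org_forward = None

    def forward(self, x):
        org_out = self.org_forward(x)
        if not self.use_lora:
            return org_out
        # LoRA update: x -> up(down(x)), scaled by alpha/rank and the multiplier.
        return org_out + self.lora_up(self.lora_down(x)) * self.multiplier * self.alpha_scale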


class LoRANetwork(torch.nn.Module):

    LORA_PREFIX = "lora"
    LORA_HYPHEN = "___lorahyphen___"

    def __init__(
        self,
        model,
        lora_network_state_dict_loaded,
        multiplier: float = 1.0,
        lora_dim: int = 128,
        alpha: float = 64,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier
        self.use_lora = True
        self.lora_dim = lora_dim
        self.alpha = alpha

        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")

        # Collect the names of the LoRA modules present in the loaded state dict.
        lora_module_names = set()
        for key in lora_network_state_dict_loaded.keys():
            if key.endswith("lora_down.weight"):
                lora_name = key.split(".lora_down.weight")[0]
                lora_module_names.add(lora_name)

        loras = []
        for lora_name in lora_module_names:
            # Undo the name mangling: drop the prefix and turn the hyphen
            # placeholder back into attribute-path dots.
            module_name = lora_name.replace(self.LORA_PREFIX + self.LORA_HYPHEN, "").replace(self.LORA_HYPHEN, ".")

            try:
                module = model
                for part in module_name.split("."):
                    module = getattr(module, part)
            except Exception as e:
                print(f"Cannot find module: {module_name}, error: {e}")
                continue
            if module.__class__.__name__ != "Linear":
                continue

            # If the checkpoint stores the up-projection as parallel blocks
            # (<lora_name>.lora_up.blocks.<i>.weight), recover the block count.
            n_separate = 1
            prefix = lora_name + ".lora_up.blocks"
            n_blocks = sum(1 for k in lora_network_state_dict_loaded if k.startswith(prefix))
            if n_blocks > 0:
                n_separate = n_blocks

            dim = self.lora_dim
            alpha = self.alpha

            lora = LoRAModule(
                lora_name,
                module,
                self.multiplier,
                dim,
                alpha,
                n_separate=n_separate,
            )
            loras.append(lora)

        self.loras = loras
        for lora in self.loras:
            self.add_module(lora.lora_name, lora)
        print(f"create LoRA for model: {len(self.loras)} modules.")

        # Guard against two modules mapping to the same LoRA name.
        names = set()
        for lora in self.loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)
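
    def apply_to(self):
        # Counterpart to disapply_to() below; added for symmetry, since the
        # original file defined only the restore direction at network level.
        for lora in self.loras:
            lora.apply_to()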

    def disapply_to(self):
        for lora in self.loras:
            lora.disapply_to()

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.loras:
            lora.multiplier = self.multiplier

    def set_use_lora(self, use_lora):
        self.use_lora = use_lora
        for lora in self.loras:
            lora.set_use_lora(use_lora)

    def prepare_optimizer_params(self, lr):
        # All LoRA parameters go into a single param group with one learning rate.
        self.requires_grad_(True)
        all_params = []

        params = []
        for lora in self.loras:
            params.extend(lora.parameters())

        param_data = {"params": params}
        param_data["lr"] = lr
        all_params.append(param_data)

        return all_params


def create_lora_network(
    transformer,
    lora_network_state_dict_loaded,
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
):
    network = LoRANetwork(
        transformer,
        lora_network_state_dict_loaded,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
    )
    return network
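

# A minimal usage sketch (hypothetical names: `model` is any module whose
# Linear submodules match the names encoded in the LoRA checkpoint, and
# `state_dict` is that loaded checkpoint):
#
#     network = create_lora_network(model, state_dict, multiplier=1.0,
#                                   network_dim=128, network_alpha=64)
#     network.load_state_dict(state_dict, strict=False)  # checkpoint may lack alpha_scale buffers
#     network.apply_to()  # patch each Linear's forward
#     params = network.prepare_optimizer_params(lr=1e-4)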