|
|
|
|
|
|
|
|
|
|
| import math
|
| import functools
|
| from collections import defaultdict
|
|
|
| from typing import Optional
|
|
|
| import torch
|
|
|
|
|
class LoRAUPParallel(torch.nn.Module):
    """Run several sub-modules side by side, each on its own slice of the input.

    The last dimension of the input is cut into ``len(blocks)`` equal chunks,
    chunk ``i`` is passed through ``blocks[i]``, and the per-block results are
    concatenated back together along the last dimension.
    """

    def __init__(self, blocks):
        """Store ``blocks`` (an iterable of modules) in a ``ModuleList``."""
        super().__init__()
        self.blocks = torch.nn.ModuleList(blocks)

    def forward(self, x):
        """Apply each block to its chunk of ``x`` and concatenate the outputs.

        The size of the last dimension of ``x`` must be divisible by the
        number of blocks; this is checked with an ``assert``.
        """
        n_blocks = len(self.blocks)
        assert x.shape[-1] % n_blocks == 0

        pieces = torch.chunk(x, n_blocks, dim=-1)
        results = []
        for block, piece in zip(self.blocks, pieces):
            results.append(block(piece))
        return torch.cat(results, dim=-1)
|
|
|
|
|
class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter (a ``lora_down`` / ``lora_up`` projection pair) sized to
    match an existing Linear layer.

    NOTE(review): this class only builds and initialises the adapter weights.
    It defines no ``forward`` and does not keep a reference to ``org_module``,
    so how the adapter is combined with the original layer is not shown here.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        n_seperate=1
    ):
        """Create the adapter weights for ``org_module``.

        Args:
            lora_name: Name stored on the instance as ``self.lora_name``.
            org_module: Layer whose class name must be ``"Linear"``; only its
                ``in_features`` / ``out_features`` are read.
            multiplier: Stored as ``self.multiplier``.
            lora_dim: Rank of the down projection (per block when
                ``n_seperate > 1``).
            alpha: Scaling numerator; ``alpha / lora_dim`` is registered as
                the ``alpha_scale`` buffer. ``None`` or ``0`` falls back to
                ``lora_dim`` (scale 1). May be a number or a one-element
                tensor on any device.
            n_seperate: When greater than 1, the up projection is split into
                that many independent blocks, each producing
                ``out_features // n_seperate`` outputs.

        Raises:
            AssertionError: If ``org_module`` is not a ``Linear`` or
                ``out_features`` is not divisible by ``n_seperate``.
        """
        super().__init__()
        self.lora_name = lora_name

        assert org_module.__class__.__name__ == "Linear"
        in_dim = org_module.in_features
        out_dim = org_module.out_features

        self.lora_dim = lora_dim
        if n_seperate > 1:
            assert out_dim % n_seperate == 0
            # One wide down projection feeds n_seperate independent up blocks.
            self.lora_down = torch.nn.Linear(in_dim, n_seperate * self.lora_dim, bias=False)
            self.lora_up = LoRAUPParallel(
                [
                    torch.nn.Linear(self.lora_dim, out_dim // n_seperate, bias=False)
                    for _ in range(n_seperate)
                ]
            )
            up_layers = list(self.lora_up.blocks)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
            up_layers = [self.lora_up]

        # isinstance also covers Tensor subclasses such as nn.Parameter, and
        # .item() works for tensors on any device, whereas .numpy() requires
        # a CPU tensor.
        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().item()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        alpha_scale = alpha / self.lora_dim
        # Registered as a buffer so it is included in the module's state_dict.
        self.register_buffer("alpha_scale", torch.tensor(alpha_scale))

        # Down projection gets a Kaiming-uniform init; every up projection
        # starts at zero, so lora_up(lora_down(x)) is initially all zeros.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        for layer in up_layers:
            torch.nn.init.zeros_(layer.weight)

        self.multiplier = multiplier
        self.use_lora = True

    def set_use_lora(self, use_lora):
        """Set the ``use_lora`` flag on this module."""
        self.use_lora = use_lora
|
|
|
|
|
class LoRANetwork(torch.nn.Module):
    """Collection of ``LoRAModule`` adapters derived from a LoRA state dict.

    For every ``*.lora_down.weight`` key in the supplied state dict, the
    matching sub-module of ``model`` is looked up and, if it is a ``Linear``,
    a ``LoRAModule`` is created for it and registered as a child module under
    the LoRA name. The state dict is only used to discover names and block
    counts; no weights are loaded from it here.
    """

    LORA_PREFIX = "lora"
    LORA_HYPHEN = "___lorahyphen___"

    def __init__(
        self,
        model,
        lora_network_state_dict_loaded,
        multiplier: float = 1.0,
        lora_dim: int = 128,
        alpha: float = 64,
    ) -> None:
        """Build one adapter per LoRA entry that maps onto a Linear layer.

        Args:
            model: Module whose sub-modules are resolved by attribute path.
            lora_network_state_dict_loaded: Mapping whose keys name the LoRA
                weights (e.g. ``<lora_name>.lora_down.weight``).
            multiplier: Passed to every created ``LoRAModule``.
            lora_dim: Rank passed to every created ``LoRAModule``.
            alpha: Alpha passed to every created ``LoRAModule``.
        """
        super().__init__()
        self.multiplier = multiplier
        self.use_lora = True
        self.lora_dim = lora_dim
        self.alpha = alpha

        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")

        # Sorted so that module registration order (and therefore optimizer
        # parameter order) is the same on every run; iterating the raw set
        # would depend on string hashing.
        lora_module_names = sorted(
            {
                key.split(".lora_down.weight")[0]
                for key in lora_network_state_dict_loaded.keys()
                if key.endswith("lora_down.weight")
            }
        )

        loras = []
        for lora_name in lora_module_names:
            module_name = self._module_path_for(lora_name)

            # Best effort: entries that do not resolve to a sub-module of
            # `model` are reported and skipped rather than aborting the build.
            try:
                module = model
                for part in module_name.split('.'):
                    module = getattr(module, part)
            except Exception as e:
                print(f"Cannot find module: {module_name}, error: {e}")
                continue
            if module.__class__.__name__ != "Linear":
                continue

            # The number of state-dict entries under "<name>.lora_up.blocks"
            # tells how many parallel up-projection blocks were saved.
            blocks_prefix = lora_name + ".lora_up.blocks"
            n_blocks = sum(1 for k in lora_network_state_dict_loaded if k.startswith(blocks_prefix))
            n_seperate = n_blocks if n_blocks > 0 else 1

            loras.append(
                LoRAModule(
                    lora_name,
                    module,
                    self.multiplier,
                    self.lora_dim,
                    self.alpha,
                    n_seperate=n_seperate,
                )
            )

        # Names are unique by construction (they were collected in a set), so
        # registering them as child modules cannot overwrite one another.
        self.loras = loras
        for lora in self.loras:
            self.add_module(lora.lora_name, lora)
        print(f"create LoRA for model: {len(self.loras)} modules.")

    @classmethod
    def _module_path_for(cls, lora_name):
        """Convert a LoRA name into the dotted attribute path of its module.

        Only a leading ``LORA_PREFIX + LORA_HYPHEN`` marker is removed; the
        remaining ``LORA_HYPHEN`` separators become dots. Stripping the marker
        anywhere in the name would corrupt path segments that end in "lora".
        """
        marker = cls.LORA_PREFIX + cls.LORA_HYPHEN
        if lora_name.startswith(marker):
            lora_name = lora_name[len(marker):]
        return lora_name.replace(cls.LORA_HYPHEN, ".")

    def disapply_to(self):
        """Call ``disapply_to()`` on every adapter.

        NOTE(review): the ``LoRAModule`` class in this file does not define
        ``disapply_to``, so this raises ``AttributeError`` when ``self.loras``
        is non-empty unless the method is provided elsewhere — confirm.
        """
        for lora in self.loras:
            lora.disapply_to()

    def set_multiplier(self, multiplier):
        """Store ``multiplier`` on the network and on every adapter."""
        self.multiplier = multiplier
        for lora in self.loras:
            lora.multiplier = self.multiplier

    def set_use_lora(self, use_lora):
        """Store the ``use_lora`` flag on the network and on every adapter."""
        self.use_lora = use_lora
        for lora in self.loras:
            lora.set_use_lora(use_lora)

    def prepare_optimizer_params(self, lr):
        """Enable gradients and return optimizer parameter groups.

        Returns:
            A single-element list ``[{"params": [...], "lr": lr}]`` holding
            the parameters of all adapters in ``self.loras``.
        """
        self.requires_grad_(True)
        params = [param for lora in self.loras for param in lora.parameters()]
        return [{"params": params, "lr": lr}]
|
|
|
|
|
def create_lora_network(
    transformer,
    lora_network_state_dict_loaded,
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
):
    """Construct and return a ``LoRANetwork`` for ``transformer``.

    ``network_dim`` and ``network_alpha`` are forwarded unchanged as the
    network's ``lora_dim`` and ``alpha``; ``multiplier`` and the loaded LoRA
    state dict are passed through as-is.
    """
    network_options = {
        "multiplier": multiplier,
        "lora_dim": network_dim,
        "alpha": network_alpha,
    }
    return LoRANetwork(transformer, lora_network_state_dict_loaded, **network_options)
|
|
|
|
|