# References:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# https://github.com/bmaltais/kohya_ss
import math
from typing import Optional
import torch
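# LoRAUPParallel splits the incoming features into len(blocks) equal chunks,
# applies an independent up projection to each chunk, and concatenates the
# results. It is used by LoRAModule when n_seperate > 1.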
class LoRAUPParallel(torch.nn.Module):
def __init__(self, blocks):
super().__init__()
self.blocks = torch.nn.ModuleList(blocks)
def forward(self, x):
assert x.shape[-1] % len(self.blocks) == 0
xs = torch.chunk(x, len(self.blocks), dim=-1)
out = torch.cat([self.blocks[i](xs[i]) for i in range(len(self.blocks))], dim=-1)
return out
class LoRAModule(torch.nn.Module):
"""
    Replaces the forward method of the original Linear instead of replacing the original Linear module itself.
"""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
n_seperate=1
):
super().__init__()
self.lora_name = lora_name
assert org_module.__class__.__name__ == "Linear"
in_dim = org_module.in_features
out_dim = org_module.out_features
if n_seperate > 1:
assert out_dim % n_seperate == 0
self.lora_dim = lora_dim
if n_seperate > 1:
self.lora_down = torch.nn.Linear(in_dim, n_seperate * self.lora_dim, bias=False)
self.lora_up = LoRAUPParallel([torch.nn.Linear(self.lora_dim, out_dim // n_seperate, bias=False) for _ in range(n_seperate)])
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
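        # Standard LoRA scaling: the adapter output is scaled by alpha / rank
        # (see the Microsoft LoRA reference at the top of this file).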
        if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
alpha_scale = alpha / self.lora_dim
self.register_buffer("alpha_scale", torch.tensor(alpha_scale))
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
if n_seperate > 1:
for block in self.lora_up.blocks:
torch.nn.init.zeros_(block.weight)
else:
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.use_lora = True
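        # Keep the wrapped Linear in a plain list so it is not registered as a
        # submodule: its weights stay out of this module's state dict and out of
        # the optimizer parameters collected by LoRANetwork.
        self.org_module = [org_module]
        self.org_forward = None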
def set_use_lora(self, use_lora):
self.use_lora = use_lora
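    # Kohya-style forward swapping (see the references at the top of this file):
    # apply_to() replaces org_module.forward with this module's forward, and
    # disapply_to() (used by LoRANetwork.disapply_to) restores the original.
    def apply_to(self):
        if self.org_forward is None:
            self.org_forward = self.org_module[0].forward
            self.org_module[0].forward = self.forward
    def disapply_to(self):
        if self.org_forward is not None:
            self.org_module[0].forward = self.org_forward
            self.org_forward = None
    def forward(self, x):
        # Original output plus the (optionally disabled) low-rank update.
        out = self.org_forward(x)
        if not self.use_lora:
            return out
        return out + self.lora_up(self.lora_down(x)) * self.multiplier * self.alpha_scale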
class LoRANetwork(torch.nn.Module):
LORA_PREFIX = "lora"
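    # "." is not allowed in names registered via add_module, so this marker
    # stands in for "." when a module path is encoded into a LoRA name.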
LORA_HYPHEN = "___lorahyphen___"
def __init__(
self,
model,
lora_network_state_dict_loaded,
multiplier: float = 1.0,
lora_dim: int = 128,
alpha: float = 64,
) -> None:
super().__init__()
self.multiplier = multiplier
self.use_lora = True
self.lora_dim = lora_dim
self.alpha = alpha
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
lora_module_names = set()
for key in lora_network_state_dict_loaded.keys():
if key.endswith("lora_down.weight"):
lora_name = key.split(".lora_down.weight")[0]
lora_module_names.add(lora_name)
loras = []
for lora_name in lora_module_names:
# Restore the real module name in the model.
            module_name = lora_name.replace(self.LORA_PREFIX + self.LORA_HYPHEN, "").replace(self.LORA_HYPHEN, ".")
# Find the module.
try:
module = model
for part in module_name.split('.'):
module = getattr(module, part)
except Exception as e:
print(f"Cannot find module: {module_name}, error: {e}")
continue
if module.__class__.__name__ != "Linear":
continue
# Infer n_seperate.
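            # A split (n_seperate > 1) module stores its up projection as
            # "<lora_name>.lora_up.blocks.<i>.weight"; counting those keys
            # recovers the number of splits.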
n_seperate = 1
prefix = lora_name + ".lora_up.blocks"
n_blocks = sum(1 for k in lora_network_state_dict_loaded if k.startswith(prefix))
if n_blocks > 0:
n_seperate = n_blocks
dim = self.lora_dim
alpha = self.alpha
lora = LoRAModule(
lora_name,
module,
self.multiplier,
dim,
alpha,
n_seperate=n_seperate
)
loras.append(lora)
self.loras = loras
for lora in self.loras:
self.add_module(lora.lora_name, lora)
print(f"create LoRA for model: {len(self.loras)} modules.")
# assertion
names = set()
for lora in self.loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
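    def apply_to(self):
        # Counterpart to disapply_to: swap the LoRA forward into every wrapped Linear.
        for lora in self.loras:
            lora.apply_to()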
def disapply_to(self):
for lora in self.loras:
lora.disapply_to()
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.loras:
lora.multiplier = self.multiplier
def set_use_lora(self, use_lora):
self.use_lora = use_lora
for lora in self.loras:
lora.set_use_lora(use_lora)
def prepare_optimizer_params(self, lr):
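        # Only the LoRA adapter weights are returned, as a single parameter group;
        # the base model's weights are not part of this network.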
self.requires_grad_(True)
all_params = []
params = []
for lora in self.loras:
params.extend(lora.parameters())
param_data = {"params": params}
param_data["lr"] = lr
all_params.append(param_data)
return all_params
def create_lora_network(
transformer,
lora_network_state_dict_loaded,
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
):
    # Fall back to the LoRANetwork defaults when dim / alpha are not provided.
    kwargs = {}
    if network_dim is not None:
        kwargs["lora_dim"] = network_dim
    if network_alpha is not None:
        kwargs["alpha"] = network_alpha
    network = LoRANetwork(
        transformer,
        lora_network_state_dict_loaded,
        multiplier=multiplier,
        **kwargs,
    )
return network
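# Minimal usage sketch (illustrative only: the toy model and state-dict keys
# below are made-up examples, not taken from any real checkpoint).
if __name__ == "__main__":
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.to_q = torch.nn.Linear(16, 32)
    model = ToyModel()
    # Key names follow the "<LORA_PREFIX><LORA_HYPHEN><module path>" convention used above.
    fake_state_dict = {
        "lora___lorahyphen___to_q.lora_down.weight": torch.zeros(4, 16),
        "lora___lorahyphen___to_q.lora_up.weight": torch.zeros(32, 4),
    }
    network = create_lora_network(
        model, fake_state_dict, multiplier=1.0, network_dim=4, network_alpha=4
    )
    network.apply_to()
    out = model.to_q(torch.randn(2, 16))  # original Linear plus the zero-initialized LoRA branch
    print(f"{len(network.loras)} LoRA module(s), output shape {tuple(out.shape)}")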