import os
import math
from typing import Optional, List, Type, Set, Literal

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file

# Module classes whose Linear/Conv children receive LoRA adapters.
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
    "Attention",
]

UNET_TARGET_REPLACE_MODULE_CONV = [
    "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
    "DownBlock2D",
    "UpBlock2D",
]

LORA_PREFIX_UNET = "lora_unet"

DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER

TRAINING_METHODS = Literal[
    "noxattn",  # all layers except cross-attention ("attn2") and time-embedding layers
    "innoxattn",  # all layers except cross-attention layers
    "selfattn",  # only self-attention ("attn1") layers
    "xattn",  # cross-attention variant; see the filters in create_modules
    "full",  # no filtering
    "xattn-strict",  # like "xattn", but also skips "to_k" and output projections
    "noxattn-hspace",  # "noxattn" restricted to mid-block (h-space) layers
    "noxattn-hspace-last",  # "noxattn" restricted to the last mid-block conv
]
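
# For reference, typical module names from a Stable Diffusion 1.x UNet that the
# training-method filters in create_modules match against (illustrative only):
#   down_blocks.0.attentions.0.transformer_blocks.0.attn1   (self-attention)
#   down_blocks.0.attentions.0.transformer_blocks.0.attn2   (cross-attention)
#   mid_block.attentions.0.transformer_blocks.0.attn1
# each with Linear children to_q, to_k, to_v, and to_out.0.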


class LoRAModule(nn.Module):
    """
    Replaces the forward method of the original Linear module instead of
    replacing the original Linear module itself.
    """

    def __init__(
        self,
        lora_name,
        proj,
        v,
        mean,
        std,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """If alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.in_dim = org_module.in_features
        self.out_dim = org_module.out_features

        # Split the per-layer statistics and principal directions into the part
        # belonging to the input ("down") projection and the part belonging to
        # the output ("up") projection of the low-rank update.
        self.proj = proj.bfloat16()
        self.mean1 = mean[: self.in_dim].bfloat16()
        self.mean2 = mean[self.in_dim:].bfloat16()
        self.std1 = std[: self.in_dim].bfloat16()
        self.std2 = std[self.in_dim:].bfloat16()
        self.v1 = v[: self.in_dim].bfloat16()
        self.v2 = v[self.in_dim:].bfloat16()

        # Standard LoRA scaling: scale = alpha / rank.
        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim

        self.multiplier = multiplier
        self.org_module = org_module  # removed in apply_to()

    def apply_to(self):
        # Swap this module's forward in place of the wrapped module's forward,
        # then drop the reference so the wrapped module is no longer registered
        # as a submodule.
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        # Reconstruct the "down" and "up" factors of the low-rank update from
        # the shared projection, the principal directions, and the
        # per-coordinate mean/std, then add the update to the original output.
        down = (self.proj @ self.v1.T) * self.std1 + self.mean1
        up = (self.proj @ self.v2.T) * self.std2 + self.mean2
        return self.org_forward(x) + (x @ down.T) @ up * self.multiplier * self.scale
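

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): how LoRAModule.forward
# assembles its low-rank update.  The shapes used here are assumptions for the
# example -- `proj` is taken as (rank, latent_dim), `v` as
# (in_dim + out_dim, latent_dim), and `mean`/`std` as vectors of length
# in_dim + out_dim; the real tensors come from the pretrained w2w model.
def _lora_delta_example():
    rank, latent_dim, in_dim, out_dim = 1, 16, 8, 4
    proj = torch.randn(rank, latent_dim)
    v = torch.randn(in_dim + out_dim, latent_dim)
    mean = torch.randn(in_dim + out_dim)
    std = torch.rand(in_dim + out_dim)
    # "Down" factor (rank x in_dim) and "up" factor (rank x out_dim), exactly as
    # in LoRAModule.forward.
    down = (proj @ v[:in_dim].T) * std[:in_dim] + mean[:in_dim]
    up = (proj @ v[in_dim:].T) * std[in_dim:] + mean[in_dim:]
    x = torch.randn(2, in_dim)
    delta = (x @ down.T) @ up  # (2, out_dim), added to the wrapped layer's output
    return delta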


class LoRAw2w(nn.Module):
    def __init__(
        self,
        proj,
        mean,
        std,
        v,
        unet: UNet2DConditionModel,
        rank: int = 4,
        multiplier: float = 1.0,
        alpha: float = 1.0,
        train_method: TRAINING_METHODS = "full",
    ) -> None:
        super().__init__()
        self.lora_scale = 1
        self.multiplier = multiplier
        self.lora_dim = rank
        self.alpha = alpha

        # The w2w projection is the trainable parameter; the mean, std, and
        # principal directions are registered as fixed buffers.
        self.proj = torch.nn.Parameter(proj)
        self.register_buffer("mean", torch.tensor(mean))
        self.register_buffer("std", torch.tensor(std))
        self.register_buffer("v", torch.tensor(v))

        self.module = LoRAModule

        self.unet_loras = self.create_modules(
            LORA_PREFIX_UNET,
            unet,
            DEFAULT_TARGET_REPLACE,
            self.lora_dim,
            self.multiplier,
            train_method=train_method,
        )

        # Guard against duplicate adapter names, then patch every adapter into
        # the UNet and register it as a submodule.
        self.lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in self.lora_names
            ), f"duplicated lora name: {lora.lora_name}. {self.lora_names}"
            self.lora_names.add(lora.lora_name)

        for lora in self.unet_loras:
            lora.apply_to()
            self.add_module(
                lora.lora_name,
                lora,
            )

        del unet
        torch.cuda.empty_cache()

    def reset(self):
        # Re-initialize every adapter's projection from the shared parameter.
        for lora in self.unet_loras:
            lora.proj = torch.nn.Parameter(self.proj.bfloat16())

    def create_modules(
        self,
        prefix: str,
        root_module: nn.Module,
        target_replace_modules: List[str],
        rank: int,
        multiplier: float,
        train_method: TRAINING_METHODS,
    ) -> list:
        # Offset into the flattened mean/std/v statistics; every adapter
        # consumes in_dim + out_dim entries.
        counter = 0

        # Collect all (name, module) pairs, then reorder them so the mid-block
        # modules come before the up-block modules (named_modules() yields
        # up_blocks before mid_block), presumably matching the order in which
        # the flattened w2w statistics were stored.
        module_list = []
        module_names = []
        for name, module in root_module.named_modules():
            module_names.append(name)
            module_list.append(module)

        midstart = 0
        upstart = 0
        for i in range(len(module_names)):
            if "mid_block" in module_names[i]:
                midstart = i
                break
        for i in range(len(module_names)):
            if "up_block" in module_names[i]:
                upstart = i
                break

        module_list = module_list[:upstart] + module_list[midstart:] + module_list[upstart:midstart]
        module_names = module_names[:upstart] + module_names[midstart:] + module_names[upstart:midstart]

        loras = []
        names = []

        for i in range(len(module_list)):
            name = module_names[i]
            module = module_list[i]
            # Filter whole modules according to the chosen training method.
            if train_method in ("noxattn", "noxattn-hspace", "noxattn-hspace-last"):
                if "attn2" in name or "time_embed" in name:
                    continue
            elif train_method == "innoxattn":
                if "attn2" in name:
                    continue
            elif train_method == "selfattn":
                if "attn1" not in name:
                    continue
            elif train_method == "xattn" or train_method == "xattn-strict":
                if "to_k" in name:
                    continue
            elif train_method == "full":
                pass
            else:
                raise NotImplementedError(
                    f"train_method: {train_method} is not implemented."
                )

            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
                        # Per-child filters for the stricter training methods.
                        if train_method == "xattn-strict":
                            if "out" in child_name:
                                continue
                            if "to_k" in child_name:
                                continue
                        if train_method == "noxattn-hspace":
                            if "mid_block" not in name:
                                continue
                        if train_method == "noxattn-hspace-last":
                            if "mid_block" not in name or ".1" not in name or "conv2" not in child_name:
                                continue
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")

                        # Each adapter consumes the next in_dim + out_dim rows of
                        # the flattened statistics, starting at `counter`.
                        in_dim = child_module.in_features
                        out_dim = child_module.out_features
                        combined_dim = in_dim + out_dim

                        lora = self.module(
                            lora_name,
                            self.proj,
                            self.v[counter:counter + combined_dim],
                            self.mean[counter:counter + combined_dim],
                            self.std[counter:counter + combined_dim],
                            child_module,
                            multiplier,
                            rank,
                            self.alpha,
                        )
                        counter += combined_dim
                        if lora_name not in names:
                            loras.append(lora)
                            names.append(lora_name)

        return loras

    def prepare_optimizer_params(self):
        # Collect all adapter parameters into a single optimizer param group.
        all_params = []

        if self.unet_loras:
            params = []
            for lora in self.unet_loras:
                params.extend(lora.parameters())
            param_data = {"params": params}
            all_params.append(param_data)

        return all_params

    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        state_dict = self.state_dict()

        if dtype is not None:
            # Cast every tensor to the requested dtype on CPU before saving.
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def set_lora_slider(self, scale):
        # Strength applied to the adapters while inside the context manager.
        self.lora_scale = scale

    def __enter__(self):
        # Enable the adapters at the configured slider strength.
        for lora in self.unet_loras:
            lora.multiplier = 1.0 * self.lora_scale

    def __exit__(self, exc_type, exc_value, tb):
        # Disable the adapters on exit.
        for lora in self.unet_loras:
            lora.multiplier = 0
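

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original code.  The checkpoint id, the
# latent dimension, and the placeholder statistics below are assumptions made
# for illustration; in practice `proj`, `mean`, `std`, and `v` are loaded from
# the pretrained w2w principal-component model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical base model; use whichever UNet the w2w statistics were
    # computed for.
    unet = UNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.bfloat16
    )

    # Length of the flattened statistics: in_features + out_features summed
    # over every Linear child of each targeted Attention module.
    total_dim = 0
    for module in unet.modules():
        if module.__class__.__name__ in DEFAULT_TARGET_REPLACE:
            for child in module.modules():
                if isinstance(child, torch.nn.Linear):
                    total_dim += child.in_features + child.out_features

    latent_dim = 64  # placeholder dimensionality of the w2w projection space
    proj = torch.zeros(1, latent_dim, dtype=torch.bfloat16)
    mean = torch.zeros(total_dim)
    std = torch.ones(total_dim)
    v = torch.zeros(total_dim, latent_dim)

    network = LoRAw2w(proj, mean, std, v, unet, rank=1, multiplier=1.0, alpha=1.0)

    # The shared w2w projection is the quantity being optimized.
    optimizer = torch.optim.Adam([network.proj], lr=1e-1)

    # At sampling time the adapters are active only inside the context manager.
    network.set_lora_slider(1.0)
    with network:
        pass  # run the diffusion pipeline / UNet forward passes here

    network.save_weights("w2w_lora.safetensors", dtype=torch.bfloat16)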