tonyshark's picture
Upload 119 files
0bb1a82 verified
import math
from functools import cache
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import LycorisBaseModule
from ..functional.general import rebuild_tucker
from ..logging import logger
@cache
def log_wd():
return logger.warning(
"Using weight_decompose=True with LoRA (DoRA) will ignore network_dropout."
"Only rank dropout and module dropout will be applied"
)
class LoConModule(LycorisBaseModule):
name = "locon"
support_module = {
"linear",
"conv1d",
"conv2d",
"conv3d",
}
weight_list = [
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
"alpha",
"dora_scale",
]
weight_list_det = ["lora_up.weight"]
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.0,
rank_dropout=0.0,
module_dropout=0.0,
use_tucker=False,
use_scalar=False,
rank_dropout_scale=False,
weight_decompose=False,
wd_on_out=False,
bypass_mode=None,
rs_lora=False,
**kwargs,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__(
lora_name,
org_module,
multiplier,
dropout,
rank_dropout,
module_dropout,
rank_dropout_scale,
bypass_mode,
)
if self.module_type not in self.support_module:
raise ValueError(f"{self.module_type} is not supported in LoRA/LoCon algo.")
self.lora_dim = lora_dim
self.tucker = False
self.rs_lora = rs_lora
if self.module_type.startswith("conv"):
self.isconv = True
# For general LoCon
in_dim = org_module.in_channels
k_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
out_dim = org_module.out_channels
use_tucker = use_tucker and any(i != 1 for i in k_size)
self.down_op = self.op
self.up_op = self.op
if use_tucker and any(i != 1 for i in k_size):
self.lora_down = self.module(in_dim, lora_dim, 1, bias=False)
self.lora_mid = self.module(
lora_dim, lora_dim, k_size, stride, padding, bias=False
)
self.tucker = True
else:
self.lora_down = self.module(
in_dim, lora_dim, k_size, stride, padding, bias=False
)
self.lora_up = self.module(lora_dim, out_dim, 1, bias=False)
elif isinstance(org_module, nn.Linear):
self.isconv = False
self.down_op = F.linear
self.up_op = F.linear
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
else:
raise NotImplementedError
self.wd = weight_decompose
self.wd_on_out = wd_on_out
if self.wd:
org_weight = org_module.weight.cpu().clone().float()
self.dora_norm_dims = org_weight.dim() - 1
if self.wd_on_out:
self.dora_scale = nn.Parameter(
torch.norm(
org_weight.reshape(org_weight.shape[0], -1),
dim=1,
keepdim=True,
).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
).float()
else:
self.dora_scale = nn.Parameter(
torch.norm(
org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
dim=1,
keepdim=True,
)
.reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
.transpose(1, 0)
).float()
if dropout:
self.dropout = nn.Dropout(dropout)
if self.wd:
log_wd()
else:
self.dropout = nn.Identity()
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
r_factor = lora_dim
if self.rs_lora:
r_factor = math.sqrt(r_factor)
self.scale = alpha / r_factor
self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor)))
if use_scalar:
self.scalar = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("scalar", torch.tensor(1.0), persistent=False)
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
if use_scalar:
torch.nn.init.kaiming_uniform_(self.lora_up.weight, a=math.sqrt(5))
else:
torch.nn.init.constant_(self.lora_up.weight, 0)
if self.tucker:
torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
@classmethod
def make_module_from_state_dict(
cls, lora_name, orig_module, up, down, mid, alpha, dora_scale
):
module = cls(
lora_name,
orig_module,
1,
down.size(0),
float(alpha),
use_tucker=mid is not None,
weight_decompose=dora_scale is not None,
)
module.lora_up.weight.data.copy_(up)
module.lora_down.weight.data.copy_(down)
if mid is not None:
module.lora_mid.weight.data.copy_(mid)
if dora_scale is not None:
module.dora_scale.copy_(dora_scale)
return module
def load_weight_hook(self, module: nn.Module, incompatible_keys):
missing_keys = incompatible_keys.missing_keys
for key in missing_keys:
if "scalar" in key:
del missing_keys[missing_keys.index(key)]
if isinstance(self.scalar, nn.Parameter):
self.scalar.data.copy_(torch.ones_like(self.scalar))
elif getattr(self, "scalar", None) is not None:
self.scalar.copy_(torch.ones_like(self.scalar))
else:
self.register_buffer(
"scalar", torch.ones_like(self.scalar), persistent=False
)
def make_weight(self, device=None):
wa = self.lora_up.weight.to(device)
wb = self.lora_down.weight.to(device)
if self.tucker:
t = self.lora_mid.weight
wa = wa.view(wa.size(0), -1).transpose(0, 1)
wb = wb.view(wb.size(0), -1)
weight = rebuild_tucker(t, wa, wb)
else:
weight = wa.view(wa.size(0), -1) @ wb.view(wb.size(0), -1)
weight = weight.view(self.shape)
if self.training and self.rank_dropout:
drop = (torch.rand(weight.size(0), device=device) > self.rank_dropout).to(
weight.dtype
)
drop = drop.view(-1, *[1] * len(weight.shape[1:]))
if self.rank_dropout_scale:
drop /= drop.mean()
weight *= drop
return weight * self.scalar.to(device)
def get_diff_weight(self, multiplier=1, shape=None, device=None):
scale = self.scale * multiplier
diff = self.make_weight(device=device) * scale
if shape is not None:
diff = diff.view(shape)
if device is not None:
diff = diff.to(device)
return diff, None
def get_merged_weight(self, multiplier=1, shape=None, device=None):
diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0]
weight = self.org_weight
if self.wd:
merged = self.apply_weight_decompose(weight + diff, multiplier)
else:
merged = weight + diff * multiplier
return merged, None
def apply_weight_decompose(self, weight, multiplier=1):
weight = weight.to(self.dora_scale.dtype)
if self.wd_on_out:
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1)
.reshape(weight.shape[0], *[1] * self.dora_norm_dims)
) + torch.finfo(weight.dtype).eps
else:
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * self.dora_norm_dims)
.transpose(0, 1)
) + torch.finfo(weight.dtype).eps
scale = self.dora_scale.to(weight.device) / weight_norm
if multiplier != 1:
scale = multiplier * (scale - 1) + 1
return weight * scale
def custom_state_dict(self):
destination = {}
if self.wd:
destination["dora_scale"] = self.dora_scale
destination["alpha"] = self.alpha
destination["lora_up.weight"] = self.lora_up.weight * self.scalar
destination["lora_down.weight"] = self.lora_down.weight
if self.tucker:
destination["lora_mid.weight"] = self.lora_mid.weight
return destination
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
orig_norm = self.make_weight(device).norm() * self.scale
norm = torch.clamp(orig_norm, max_norm / 2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu() / norm.cpu()
scaled = norm != desired
if scaled:
self.scalar *= ratio
return scaled, orig_norm * ratio
def bypass_forward_diff(self, x, scale=1):
if self.tucker:
mid = self.lora_mid(self.lora_down(x))
else:
mid = self.lora_down(x)
if self.rank_dropout and self.training:
drop = (
torch.rand(self.lora_dim, device=mid.device) > self.rank_dropout
).to(mid.dtype)
if self.rank_dropout_scale:
drop /= drop.mean()
if (dims := len(x.shape)) == 4:
drop = drop.view(1, -1, 1, 1)
else:
drop = drop.view(*[1] * (dims - 1), -1)
mid = mid * drop
return self.dropout(self.lora_up(mid) * self.scalar * self.scale * scale)
def bypass_forward(self, x, scale=1):
return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale)
def forward(self, x):
if self.module_dropout and self.training:
if torch.rand(1) < self.module_dropout:
return self.org_forward(x)
scale = self.scale
dtype = self.dtype
if not self.bypass_mode:
diff_weight = self.make_weight(x.device).to(dtype) * scale
weight = self.org_module[0].weight.data.to(dtype)
if self.wd:
weight = self.apply_weight_decompose(
weight + diff_weight, self.multiplier
)
else:
weight = weight + diff_weight * self.multiplier
bias = (
None
if self.org_module[0].bias is None
else self.org_module[0].bias.data
)
return self.op(x, weight, bias, **self.kw_dict)
else:
return self.bypass_forward(x, scale=self.multiplier)