import math
import torch
import torch.nn as nn
from .base import LycorisBaseModule
from ..functional.loha import diff_weight as loha_diff_weight
class LohaModule(LycorisBaseModule):
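    """LoHa (low-rank Hadamard product) adapter module.

    The weight difference is built from two low-rank factor pairs that are
    combined elementwise (see ``functional.loha.diff_weight``); in the
    non-Tucker case it is roughly
    ``(hada_w1_a @ hada_w1_b) * (hada_w2_a @ hada_w2_b) * scale``.
    """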
name = "loha"
support_module = {
"linear",
"conv1d",
"conv2d",
"conv3d",
}
weight_list = [
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
"alpha",
"dora_scale",
]
weight_list_det = ["hada_w1_a"]
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.0,
rank_dropout=0.0,
module_dropout=0.0,
use_tucker=False,
use_scalar=False,
rank_dropout_scale=False,
weight_decompose=False,
wd_on_out=False,
bypass_mode=None,
rs_lora=False,
**kwargs,
):
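        """Build a LoHa adapter around ``org_module``.

        ``lora_dim`` is the rank of each factor pair; ``alpha`` sets the update
        scale (``alpha / rank``, or ``alpha / sqrt(rank)`` when ``rs_lora`` is
        set).  ``use_tucker`` adds Tucker cores for conv kernels larger than
        1x1, and ``weight_decompose`` enables DoRA-style magnitude/direction
        decomposition via ``dora_scale``.
        """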
super().__init__(
lora_name,
org_module,
multiplier,
dropout,
rank_dropout,
module_dropout,
rank_dropout_scale,
bypass_mode,
)
if self.module_type not in self.support_module:
raise ValueError(f"{self.module_type} is not supported in LoHa algo.")
self.lora_name = lora_name
self.lora_dim = lora_dim
self.tucker = False
self.rs_lora = rs_lora
w_shape = self.shape
if self.module_type.startswith("conv"):
in_dim = org_module.in_channels
k_size = org_module.kernel_size
out_dim = org_module.out_channels
self.shape = (out_dim, in_dim, *k_size)
self.tucker = use_tucker and any(i != 1 for i in k_size)
if self.tucker:
w_shape = (out_dim, in_dim, *k_size)
else:
w_shape = (out_dim, in_dim * torch.tensor(k_size).prod().item())
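        # Two low-rank factor pairs are combined elementwise (Hadamard product).
        # With Tucker, hada_w*_a / hada_w*_b hold the mode-1 / mode-2 factors and
        # hada_t* the spatial core; otherwise each pair is a plain low-rank
        # factorization of the flattened weight.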
if self.tucker:
self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:]))
self.hada_w1_a = nn.Parameter(
torch.empty(lora_dim, w_shape[0])
) # out_dim, 1-mode
self.hada_w1_b = nn.Parameter(
torch.empty(lora_dim, w_shape[1])
) # in_dim , 2-mode
self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:]))
self.hada_w2_a = nn.Parameter(
torch.empty(lora_dim, w_shape[0])
) # out_dim, 1-mode
self.hada_w2_b = nn.Parameter(
torch.empty(lora_dim, w_shape[1])
) # in_dim , 2-mode
else:
self.hada_w1_a = nn.Parameter(torch.empty(w_shape[0], lora_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, w_shape[1]))
self.hada_w2_a = nn.Parameter(torch.empty(w_shape[0], lora_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, w_shape[1]))
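        # DoRA-style weight decomposition: dora_scale holds the per-channel norm
        # of the original weight (per output channel when wd_on_out, otherwise
        # per input channel) as a trainable magnitude.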
self.wd = weight_decompose
self.wd_on_out = wd_on_out
if self.wd:
org_weight = org_module.weight.cpu().clone().float()
self.dora_norm_dims = org_weight.dim() - 1
if self.wd_on_out:
self.dora_scale = nn.Parameter(
torch.norm(
org_weight.reshape(org_weight.shape[0], -1),
dim=1,
keepdim=True,
).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims)
).float()
else:
self.dora_scale = nn.Parameter(
torch.norm(
org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1),
dim=1,
keepdim=True,
)
.reshape(org_weight.shape[1], *[1] * self.dora_norm_dims)
.transpose(1, 0)
).float()
        if self.dropout:
            print("[WARN] normal dropout is not yet implemented for LoHa/LoKr.")
        if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
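        # Effective scale is alpha / rank (alpha / sqrt(rank) for rank-stabilized
        # LoRA).  The stored alpha buffer is pre-multiplied so that reloading it
        # with the default scaling reproduces the same effective scale.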
r_factor = lora_dim
if self.rs_lora:
r_factor = math.sqrt(r_factor)
self.scale = alpha / r_factor
self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor)))
if use_scalar:
self.scalar = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("scalar", torch.tensor(1.0), persistent=False)
# Need more experiments on init method
if self.tucker:
torch.nn.init.normal_(self.hada_t1, std=0.1)
torch.nn.init.normal_(self.hada_t2, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1)
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
torch.nn.init.normal_(self.hada_w2_b, std=1)
if use_scalar:
torch.nn.init.normal_(self.hada_w2_a, std=0.1)
else:
torch.nn.init.constant_(self.hada_w2_a, 0)
    # run under no_grad so copy_() can fill the freshly created leaf Parameters
    @classmethod
    @torch.no_grad()
def make_module_from_state_dict(
cls, lora_name, orig_module, w1a, w1b, w2a, w2b, t1, t2, alpha, dora_scale
):
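        """Rebuild a LoHa module from tensors taken out of a state dict.

        The rank is recovered from ``w1b.size(0)``; Tucker and DoRA variants are
        inferred from whether ``t1`` and ``dora_scale`` are provided.
        """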
module = cls(
lora_name,
orig_module,
1,
w1b.size(0),
float(alpha),
use_tucker=t1 is not None,
weight_decompose=dora_scale is not None,
)
module.hada_w1_a.copy_(w1a)
module.hada_w1_b.copy_(w1b)
module.hada_w2_a.copy_(w2a)
module.hada_w2_b.copy_(w2b)
if t1 is not None:
module.hada_t1.copy_(t1)
module.hada_t2.copy_(t2)
if dora_scale is not None:
module.dora_scale.copy_(dora_scale)
return module
    def load_weight_hook(self, module: nn.Module, incompatible_keys):
        missing_keys = incompatible_keys.missing_keys
        # iterate over a copy: deleting from the list being looped over
        # would otherwise skip entries
        for key in list(missing_keys):
            if "scalar" in key:
                del missing_keys[missing_keys.index(key)]
        if isinstance(self.scalar, nn.Parameter):
            self.scalar.data.copy_(torch.ones_like(self.scalar))
        elif getattr(self, "scalar", None) is not None:
            self.scalar.copy_(torch.ones_like(self.scalar))
        else:
            # no usable scalar to mirror, so register a fresh buffer of 1.0
            self.register_buffer("scalar", torch.tensor(1.0), persistent=False)
def get_weight(self, shape):
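        """Return the LoHa weight difference, scaled by ``self.scale``.

        During training, rank dropout randomly zeroes output rows of the
        difference.
        """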
scale = torch.tensor(
self.scale, dtype=self.hada_w1_b.dtype, device=self.hada_w1_b.device
)
if self.tucker:
weight = loha_diff_weight(
self.hada_w1_b,
self.hada_w1_a,
self.hada_w2_b,
self.hada_w2_a,
self.hada_t1,
self.hada_t2,
gamma=scale,
)
else:
weight = loha_diff_weight(
self.hada_w1_b,
self.hada_w1_a,
self.hada_w2_b,
self.hada_w2_a,
None,
None,
gamma=scale,
)
if shape is not None:
weight = weight.reshape(shape)
if self.training and self.rank_dropout:
drop = (torch.rand(weight.size(0)) > self.rank_dropout).to(weight.dtype)
drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
if self.rank_dropout_scale:
drop /= drop.mean()
weight *= drop
return weight
    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        # get_weight() already applies self.scale internally, so only the
        # learned scalar and the caller's multiplier are applied here,
        # matching forward() and apply_max_norm()
        diff = self.get_weight(shape) * self.scalar * multiplier
        if device is not None:
            diff = diff.to(device)
        return diff, None
def get_merged_weight(self, multiplier=1, shape=None, device=None):
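        """Return the original weight with the LoHa difference merged in."""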
diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0]
weight = self.org_weight
if self.wd:
merged = self.apply_weight_decompose(weight + diff, multiplier)
else:
merged = weight + diff * multiplier
return merged, None
def apply_weight_decompose(self, weight, multiplier=1):
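        """Rescale ``weight`` so each channel's norm matches ``dora_scale``.

        ``multiplier`` interpolates between no rescaling (0) and the full DoRA
        rescaling (1).
        """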
weight = weight.to(self.dora_scale.dtype)
if self.wd_on_out:
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1)
.reshape(weight.shape[0], *[1] * self.dora_norm_dims)
) + torch.finfo(weight.dtype).eps
else:
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * self.dora_norm_dims)
.transpose(0, 1)
) + torch.finfo(weight.dtype).eps
scale = self.dora_scale.to(weight.device) / weight_norm
if multiplier != 1:
scale = multiplier * (scale - 1) + 1
return weight * scale
def custom_state_dict(self):
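        """Collect the tensors to save; the learned scalar is folded into
        ``hada_w1_a`` so the stored weights are self-contained."""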
destination = {}
destination["alpha"] = self.alpha
if self.wd:
destination["dora_scale"] = self.dora_scale
destination["hada_w1_a"] = self.hada_w1_a * self.scalar
destination["hada_w1_b"] = self.hada_w1_b
destination["hada_w2_a"] = self.hada_w2_a
destination["hada_w2_b"] = self.hada_w2_b
if self.tucker:
destination["hada_t1"] = self.hada_t1
destination["hada_t2"] = self.hada_t2
return destination
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
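        """Shrink ``self.scalar`` so the norm of the scaled update stays within ``max_norm``."""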
orig_norm = (self.get_weight(self.shape) * self.scalar).norm()
norm = torch.clamp(orig_norm, max_norm / 2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu() / norm.cpu()
scaled = norm != desired
if scaled:
self.scalar *= ratio
return scaled, orig_norm * ratio
def bypass_forward_diff(self, x, scale=1):
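        """Apply only the LoHa difference to ``x`` (used in bypass mode)."""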
diff_weight = self.get_weight(self.shape) * self.scalar * scale
return self.drop(self.op(x, diff_weight, **self.kw_dict))
def bypass_forward(self, x, scale=1):
return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale)
    def forward(self, x: torch.Tensor, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                # module dropout: run the original layer untouched for this step
                return self.op(
                    x,
                    self.org_module[0].weight.data,
                    (
                        None
                        if self.org_module[0].bias is None
                        else self.org_module[0].bias.data
                    ),
                    **self.kw_dict,
                )
if self.bypass_mode:
return self.bypass_forward(x, scale=self.multiplier)
else:
diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar
weight = self.org_module[0].weight.data.to(self.dtype)
if self.wd:
weight = self.apply_weight_decompose(
weight + diff_weight, self.multiplier
)
else:
weight = weight + diff_weight * self.multiplier
bias = (
None
if self.org_module[0].bias is None
else self.org_module[0].bias.data
)
return self.op(x, weight, bias, **self.kw_dict)