| import math |
| from functools import cache |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .base import LycorisBaseModule |
| from ..functional import factorization, rebuild_tucker |
| from ..functional.lokr import make_kron |
| from ..logging import logger |
|
|
|
|
| @cache |
| def logging_force_full_matrix(lora_dim, dim, factor): |
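    """Warn once per unique (lora_dim, dim, factor) combination (memoized by @cache)."""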
| logger.warning( |
| f"lora_dim {lora_dim} is too large for" |
| f" dim={dim} and {factor=}" |
| ", using full matrix mode." |
| ) |
|
|
|
|
| class LokrModule(LycorisBaseModule): |
| name = "kron" |
| support_module = { |
| "linear", |
| "conv1d", |
| "conv2d", |
| "conv3d", |
| } |
| weight_list = [ |
| "lokr_w1", |
| "lokr_w1_a", |
| "lokr_w1_b", |
| "lokr_w2", |
| "lokr_w2_a", |
| "lokr_w2_b", |
| "lokr_t1", |
| "lokr_t2", |
| "alpha", |
| "dora_scale", |
| ] |
| weight_list_det = ["lokr_w1", "lokr_w1_a"] |
|
|
| def __init__( |
| self, |
| lora_name, |
| org_module: nn.Module, |
| multiplier=1.0, |
| lora_dim=4, |
| alpha=1, |
| dropout=0.0, |
| rank_dropout=0.0, |
| module_dropout=0.0, |
| use_tucker=False, |
| use_scalar=False, |
| decompose_both=False, |
| factor: int = -1, |
| rank_dropout_scale=False, |
| weight_decompose=False, |
| wd_on_out=False, |
| full_matrix=False, |
| bypass_mode=None, |
| rs_lora=False, |
| unbalanced_factorization=False, |
| **kwargs, |
| ): |
| super().__init__( |
| lora_name, |
| org_module, |
| multiplier, |
| dropout, |
| rank_dropout, |
| module_dropout, |
| rank_dropout_scale, |
| bypass_mode, |
| ) |
| if self.module_type not in self.support_module: |
| raise ValueError(f"{self.module_type} is not supported in LoKr algo.") |
|
|
| factor = int(factor) |
| self.lora_dim = lora_dim |
| self.tucker = False |
| self.use_w1 = False |
| self.use_w2 = False |
| self.full_matrix = full_matrix |
| self.rs_lora = rs_lora |
|
|
| if self.module_type.startswith("conv"): |
| in_dim = org_module.in_channels |
| k_size = org_module.kernel_size |
| out_dim = org_module.out_channels |
| self.shape = (out_dim, in_dim, *k_size) |
|
|
| in_m, in_n = factorization(in_dim, factor) |
| out_l, out_k = factorization(out_dim, factor) |
| if unbalanced_factorization: |
| out_l, out_k = out_k, out_l |
| shape = ((out_l, out_k), (in_m, in_n), *k_size) |
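            # The two Kronecker factors: w1 is (out_l, in_m) and w2 is
            # (out_k, in_n, *k_size); their Kronecker product rebuilds the full
            # (out_dim, in_dim, *k_size) delta weight.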
| self.tucker = use_tucker and any(i != 1 for i in k_size) |
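            # Tucker decomposition is only used when the kernel has spatial extent (> 1x1).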
| if ( |
| decompose_both |
| and lora_dim < max(shape[0][0], shape[1][0]) / 2 |
| and not self.full_matrix |
| ): |
| self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim)) |
| self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0])) |
| else: |
| self.use_w1 = True |
| self.lokr_w1 = nn.Parameter( |
| torch.empty(shape[0][0], shape[1][0]) |
| ) |
|
|
| if lora_dim >= max(shape[0][1], shape[1][1]) / 2 or self.full_matrix: |
| if not self.full_matrix: |
| logging_force_full_matrix(lora_dim, max(in_dim, out_dim), factor) |
| self.use_w2 = True |
| self.lokr_w2 = nn.Parameter( |
| torch.empty(shape[0][1], shape[1][1], *k_size) |
| ) |
| elif self.tucker: |
| self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *shape[2:])) |
| self.lokr_w2_a = nn.Parameter( |
| torch.empty(lora_dim, shape[0][1]) |
| ) |
| self.lokr_w2_b = nn.Parameter( |
| torch.empty(lora_dim, shape[1][1]) |
| ) |
| else: |
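                # Low-rank factors for the larger Kronecker block; the kernel dims
                # are folded into lokr_w2_b: (lora_dim, in_n * prod(k_size)).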
| |
| self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim)) |
| self.lokr_w2_b = nn.Parameter( |
| torch.empty( |
| lora_dim, shape[1][1] * torch.tensor(shape[2:]).prod().item() |
| ) |
| ) |
| |
| else: |
| in_dim = org_module.in_features |
| out_dim = org_module.out_features |
| self.shape = (out_dim, in_dim) |
|
|
| in_m, in_n = factorization(in_dim, factor) |
| out_l, out_k = factorization(out_dim, factor) |
| if unbalanced_factorization: |
| out_l, out_k = out_k, out_l |
| shape = ( |
| (out_l, out_k), |
| (in_m, in_n), |
| ) |
| |
| if ( |
| decompose_both |
| and lora_dim < max(shape[0][0], shape[1][0]) / 2 |
| and not self.full_matrix |
| ): |
| self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim)) |
| self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0])) |
| else: |
| self.use_w1 = True |
| self.lokr_w1 = nn.Parameter( |
| torch.empty(shape[0][0], shape[1][0]) |
| ) |
| if lora_dim < max(shape[0][1], shape[1][1]) / 2 and not self.full_matrix: |
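                # Low-rank factors for the larger Kronecker block:
                # w2 ~ w2_a (out_k, lora_dim) @ w2_b (lora_dim, in_n).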
| |
| self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim)) |
| self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) |
| |
| else: |
| if not self.full_matrix: |
| logging_force_full_matrix(lora_dim, max(in_dim, out_dim), factor) |
| self.use_w2 = True |
| self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1])) |
|
|
| self.wd = weight_decompose |
| self.wd_on_out = wd_on_out |
| if self.wd: |
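            # DoRA-style decomposition: store the norms of the original weight
            # (per output channel, or per input dim) as a trainable scale.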
| org_weight = org_module.weight.cpu().clone().float() |
| self.dora_norm_dims = org_weight.dim() - 1 |
| if self.wd_on_out: |
| self.dora_scale = nn.Parameter( |
| torch.norm( |
| org_weight.reshape(org_weight.shape[0], -1), |
| dim=1, |
| keepdim=True, |
| ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims) |
| ).float() |
| else: |
| self.dora_scale = nn.Parameter( |
| torch.norm( |
| org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1), |
| dim=1, |
| keepdim=True, |
| ) |
| .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims) |
| .transpose(1, 0) |
| ).float() |
|
|
| self.dropout = dropout |
| if dropout: |
| print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.") |
| self.rank_dropout = rank_dropout |
| self.rank_dropout_scale = rank_dropout_scale |
| self.module_dropout = module_dropout |
|
|
| if isinstance(alpha, torch.Tensor): |
| alpha = alpha.detach().float().numpy() |
| alpha = lora_dim if alpha is None or alpha == 0 else alpha |
| if self.use_w2 and self.use_w1: |
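            # Both Kronecker blocks are full matrices, so there is no rank to
            # compensate for; pin alpha to lora_dim (scale == 1 unless rs_lora).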
| |
| alpha = lora_dim |
|
|
| r_factor = lora_dim |
| if self.rs_lora: |
| r_factor = math.sqrt(r_factor) |
|
|
| self.scale = alpha / r_factor |
|
|
| self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor))) |
|
|
| if use_scalar: |
| self.scalar = nn.Parameter(torch.tensor(0.0)) |
| else: |
| self.register_buffer("scalar", torch.tensor(1.0), persistent=False) |
|
|
| if self.use_w2: |
| if use_scalar: |
| torch.nn.init.kaiming_uniform_(self.lokr_w2, a=math.sqrt(5)) |
| else: |
| torch.nn.init.constant_(self.lokr_w2, 0) |
| else: |
| if self.tucker: |
| torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5)) |
| torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5)) |
| if use_scalar: |
| torch.nn.init.kaiming_uniform_(self.lokr_w2_b, a=math.sqrt(5)) |
| else: |
| torch.nn.init.constant_(self.lokr_w2_b, 0) |
|
|
| if self.use_w1: |
| torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5)) |
| else: |
| torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5)) |
| torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5)) |
|
|
| @classmethod |
| def make_module_from_state_dict( |
| cls, |
| lora_name, |
| orig_module, |
| w1, |
| w1a, |
| w1b, |
| w2, |
| w2a, |
| w2b, |
| _, |
| t2, |
| alpha, |
| dora_scale, |
| ): |
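        """Rebuild a LokrModule from checkpoint tensors, inferring lora_dim, the
        factorization factor, and the tucker / full-matrix flags from which
        tensors are present."""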
| full_matrix = False |
| if w1a is not None: |
| lora_dim = w1a.size(1) |
| elif w2a is not None: |
| lora_dim = w2a.size(1) |
| else: |
| full_matrix = True |
| lora_dim = 1 |
|
|
| if w1 is None: |
| out_dim = w1a.size(0) |
| in_dim = w1b.size(1) |
| else: |
| out_dim, in_dim = w1.shape |
|
|
| shape_s = [out_dim, in_dim] |
|
|
| if w2 is None: |
| out_dim *= w2a.size(0) |
| in_dim *= w2b.size(1) |
| else: |
| out_dim *= w2.size(0) |
| in_dim *= w2.size(1) |
|
|
| if ( |
| shape_s[0] == factorization(out_dim, -1)[0] |
| and shape_s[1] == factorization(in_dim, -1)[0] |
| ): |
| factor = -1 |
| else: |
| w1_shape = w1.shape if w1 is not None else (w1a.size(0), w1b.size(1)) |
| w2_shape = w2.shape if w2 is not None else (w2a.size(0), w2b.size(1)) |
| shape_group_1 = (w1_shape[0], w2_shape[0]) |
| shape_group_2 = (w1_shape[1], w2_shape[1]) |
| w_shape = (w1_shape[0] * w2_shape[0], w1_shape[1] * w2_shape[1]) |
| factor1 = max(w1.shape) if w1 is not None else max(w1a.size(0), w1b.size(1)) |
| factor2 = max(w2.shape) if w2 is not None else max(w2a.size(0), w2b.size(1)) |
| if ( |
| w_shape[0] % factor1 == 0 |
| and w_shape[1] % factor1 == 0 |
| and factor1 in shape_group_1 |
| and factor1 in shape_group_2 |
| ): |
| factor = factor1 |
| elif ( |
| w_shape[0] % factor2 == 0 |
| and w_shape[1] % factor2 == 0 |
| and factor2 in shape_group_1 |
| and factor2 in shape_group_2 |
| ): |
| factor = factor2 |
| else: |
| factor = min(factor1, factor2) |
|
|
| module = cls( |
| lora_name, |
| orig_module, |
| 1, |
| lora_dim, |
| float(alpha), |
| use_tucker=t2 is not None, |
| decompose_both=w1 is None and w2 is None, |
| factor=factor, |
| weight_decompose=dora_scale is not None, |
| full_matrix=full_matrix, |
| ) |
| if w1 is not None: |
| module.lokr_w1.copy_(w1) |
| else: |
| module.lokr_w1_a.copy_(w1a) |
| module.lokr_w1_b.copy_(w1b) |
| if w2 is not None: |
| module.lokr_w2.copy_(w2) |
| else: |
| module.lokr_w2_a.copy_(w2a) |
| module.lokr_w2_b.copy_(w2b) |
| if t2 is not None: |
| module.lokr_t2.copy_(t2) |
| if dora_scale is not None: |
| module.dora_scale.copy_(dora_scale) |
| return module |
|
|
    def load_weight_hook(self, module: nn.Module, incompatible_keys):
        missing_keys = incompatible_keys.missing_keys
        # "scalar" is optional in checkpoints: drop it from the missing-key report
        # and fall back to 1.0 so the loaded weights are used unscaled.
        scalar_missing = any("scalar" in key for key in missing_keys)
        missing_keys[:] = [key for key in missing_keys if "scalar" not in key]
        if scalar_missing:
            if isinstance(self.scalar, nn.Parameter):
                self.scalar.data.copy_(torch.ones_like(self.scalar))
            elif getattr(self, "scalar", None) is not None:
                self.scalar.copy_(torch.ones_like(self.scalar))
            else:
                self.register_buffer("scalar", torch.tensor(1.0), persistent=False)
|
|
| def get_weight(self, shape): |
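        # Rebuild the dense delta weight: kron(w1, w2) * scale, where each factor
        # may itself be stored as a low-rank (or Tucker-decomposed) product.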
| weight = make_kron( |
| self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b, |
| ( |
| self.lokr_w2 |
| if self.use_w2 |
| else ( |
| rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) |
| if self.tucker |
| else self.lokr_w2_a @ self.lokr_w2_b |
| ) |
| ), |
| self.scale, |
| ) |
| dtype = weight.dtype |
| if shape is not None: |
| weight = weight.view(shape) |
| if self.training and self.rank_dropout: |
            drop = (torch.rand(weight.size(0), device=weight.device) > self.rank_dropout).to(dtype)
| drop = drop.view(-1, *[1] * len(weight.shape[1:])) |
| if self.rank_dropout_scale: |
| drop /= drop.mean() |
| weight *= drop |
| return weight |
|
|
    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        # get_weight() already applies self.scale, so only the multiplier is applied here.
        diff = self.get_weight(shape) * multiplier
| if device is not None: |
| diff = diff.to(device) |
| return diff, None |
|
|
| def get_merged_weight(self, multiplier=1, shape=None, device=None): |
| diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0] |
| weight = self.org_weight |
| if self.wd: |
| merged = self.apply_weight_decompose(weight + diff, multiplier) |
| else: |
| merged = weight + diff * multiplier |
| return merged, None |
|
|
| def apply_weight_decompose(self, weight, multiplier=1): |
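        # DoRA: rescale the merged weight so its per-output (or per-input) norms
        # match the learned dora_scale.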
| weight = weight.to(self.dora_scale.dtype) |
| if self.wd_on_out: |
| weight_norm = ( |
| weight.reshape(weight.shape[0], -1) |
| .norm(dim=1) |
| .reshape(weight.shape[0], *[1] * self.dora_norm_dims) |
| ) + torch.finfo(weight.dtype).eps |
| else: |
| weight_norm = ( |
| weight.transpose(0, 1) |
| .reshape(weight.shape[1], -1) |
| .norm(dim=1, keepdim=True) |
| .reshape(weight.shape[1], *[1] * self.dora_norm_dims) |
| .transpose(0, 1) |
| ) + torch.finfo(weight.dtype).eps |
|
|
| scale = self.dora_scale.to(weight.device) / weight_norm |
| if multiplier != 1: |
| scale = multiplier * (scale - 1) + 1 |
|
|
| return weight * scale |
|
|
| def custom_state_dict(self): |
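        # Fold the trainable scalar into w1 / w1_a when saving so checkpoints do
        # not need a separate "scalar" entry.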
| destination = {} |
| destination["alpha"] = self.alpha |
| if self.wd: |
| destination["dora_scale"] = self.dora_scale |
| if self.use_w1: |
| destination["lokr_w1"] = self.lokr_w1 * self.scalar |
| else: |
| destination["lokr_w1_a"] = self.lokr_w1_a * self.scalar |
| destination["lokr_w1_b"] = self.lokr_w1_b |
|
|
| if self.use_w2: |
| destination["lokr_w2"] = self.lokr_w2 |
| else: |
| destination["lokr_w2_a"] = self.lokr_w2_a |
| destination["lokr_w2_b"] = self.lokr_w2_b |
| if self.tucker: |
| destination["lokr_t2"] = self.lokr_t2 |
| return destination |
|
|
| @torch.no_grad() |
| def apply_max_norm(self, max_norm, device=None): |
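        # Clamp the norm of the rebuilt delta weight to max_norm by distributing
        # the rescale ratio (as a root) across every trainable factor.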
| orig_norm = self.get_weight(self.shape).norm() |
| norm = torch.clamp(orig_norm, max_norm / 2) |
| desired = torch.clamp(norm, max=max_norm) |
| ratio = desired.cpu() / norm.cpu() |
|
|
| scaled = norm != desired |
| if scaled: |
| modules = 4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.tucker) |
| if self.use_w1: |
| self.lokr_w1 *= ratio ** (1 / modules) |
| else: |
| self.lokr_w1_a *= ratio ** (1 / modules) |
| self.lokr_w1_b *= ratio ** (1 / modules) |
|
|
| if self.use_w2: |
| self.lokr_w2 *= ratio ** (1 / modules) |
| else: |
| if self.tucker: |
| self.lokr_t2 *= ratio ** (1 / modules) |
| self.lokr_w2_a *= ratio ** (1 / modules) |
| self.lokr_w2_b *= ratio ** (1 / modules) |
|
|
| return scaled, orig_norm * ratio |
|
|
| def bypass_forward_diff(self, h, scale=1): |
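        # Bypass mode: apply the Kronecker factors to the activations directly,
        # grouping the channel dim, instead of materializing kron(w1, w2).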
| is_conv = self.module_type.startswith("conv") |
| if self.use_w2: |
| ba = self.lokr_w2 |
| else: |
| a = self.lokr_w2_b |
| b = self.lokr_w2_a |
|
|
| if self.tucker: |
| t = self.lokr_t2 |
| a = a.view(*a.shape, *[1] * (len(t.shape) - 2)) |
                # lokr_w2_a is (lora_dim, out_k) in tucker mode; transpose it for the up projection
                b = b.transpose(0, 1).reshape(b.size(1), b.size(0), *[1] * (len(t.shape) - 2))
| elif is_conv: |
                # unfold the kernel dims that were folded into lokr_w2_b's last axis
                a = a.view(a.size(0), -1, *self.shape[2:])
| b = b.view(*b.shape, *[1] * (len(self.shape) - 2)) |
|
|
| if self.use_w1: |
| c = self.lokr_w1 |
| else: |
| c = self.lokr_w1_a @ self.lokr_w1_b |
| uq = c.size(1) |
|
|
        if is_conv:
            # fold the uq input groups into the batch dim:
            # (B, uq * vq, *spatial) -> (B * uq, vq, *spatial)
            batch, _, *rest = h.shape
            h_in_group = h.reshape(batch * uq, -1, *rest)
        else:
            # split the last dim into (uq, vq) groups
            h_in_group = h.reshape(*h.shape[:-1], uq, -1)
|
|
| if self.use_w2: |
| hb = self.op(h_in_group, ba, **self.kw_dict) |
| else: |
| if is_conv: |
| if self.tucker: |
| ha = self.op(h_in_group, a) |
| ht = self.op(ha, t, **self.kw_dict) |
| hb = self.op(ht, b) |
| else: |
| ha = self.op(h_in_group, a, **self.kw_dict) |
| hb = self.op(ha, b) |
| else: |
| ha = self.op(h_in_group, a, **self.kw_dict) |
| hb = self.op(ha, b) |
|
|
        if is_conv:
            # split the batch back into (B, uq) and move uq to the last dim so the
            # F.linear with c below can mix the groups
            hb = hb.view(batch, -1, *hb.shape[1:])
            h_cross_group = hb.transpose(1, -1)
        else:
            # (B, ..., uq, vp) -> (B, ..., vp, uq)
            h_cross_group = hb.transpose(-1, -2)
|
|
| hc = F.linear(h_cross_group, c) |
        if is_conv:
            # move the mixed groups back from the last dim and merge (out_l, out_k)
            # into the channel dim
            hc = hc.transpose(1, -1)
            h = hc.reshape(batch, -1, *hc.shape[3:])
        else:
            # (B, ..., vp, up) -> (B, ..., up, vp) -> (B, ..., out_dim)
            hc = hc.transpose(-1, -2)
            h = hc.reshape(*hc.shape[:-2], -1)
|
|
| return self.drop(h * scale * self.scalar) |
|
|
| def bypass_forward(self, x, scale=1): |
| return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale) |
|
|
| def forward(self, x: torch.Tensor, *args, **kwargs): |
| if self.module_dropout and self.training: |
| if torch.rand(1) < self.module_dropout: |
| return self.org_forward(x) |
| if self.bypass_mode: |
| return self.bypass_forward(x, self.multiplier) |
| else: |
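            # Merge path: materialize the delta weight and run the original op
            # with the patched weight.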
| diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar |
| weight = self.org_module[0].weight.data.to(self.dtype) |
| if self.wd: |
| weight = self.apply_weight_decompose( |
| weight + diff_weight, self.multiplier |
| ) |
| elif self.multiplier == 1: |
| weight = weight + diff_weight |
| else: |
| weight = weight + diff_weight * self.multiplier |
| bias = ( |
| None |
| if self.org_module[0].bias is None |
| else self.org_module[0].bias.data |
| ) |
| return self.op(x, weight, bias, **self.kw_dict) |
|
|
|
|
| if __name__ == "__main__": |
| base = nn.Conv2d(128, 128, 3, 1, 1) |
| net = LokrModule( |
| "", |
| base, |
| multiplier=1, |
| lora_dim=4, |
| alpha=1, |
| weight_decompose=False, |
| use_tucker=False, |
| use_scalar=False, |
| decompose_both=True, |
| ) |
| net.apply_to() |
| sd = net.state_dict() |
| for key in sd: |
| if key != "alpha": |
| sd[key] = torch.randn_like(sd[key]) |
| net.load_state_dict(sd) |
|
|
| test_input = torch.randn(1, 128, 16, 16) |
| test_output = net(test_input) |
| print(test_output.shape) |
|
|
| net2 = LokrModule( |
| "", |
| base, |
| multiplier=1, |
| lora_dim=4, |
| alpha=1, |
| weight_decompose=False, |
| use_tucker=False, |
| use_scalar=False, |
| bypass_mode=True, |
| decompose_both=True, |
| ) |
| net2.apply_to() |
| net2.load_state_dict(sd) |
| print(net2) |
|
|
    test_output2 = net2(test_input)
| print(F.mse_loss(test_output, test_output2)) |
|
|