import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import LycorisBaseModule
from ..functional import tucker_weight_from_conv


class GLoRAModule(LycorisBaseModule):
    name = "glora"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    # Keys stored for this algorithm; weight_list_det is the subset used to
    # detect GLoRA weights in a state dict.
    weight_list = [
        "a1.weight",
        "a2.weight",
        "b1.weight",
        "b2.weight",
        "bm.weight",
        "alpha",
    ]
    weight_list_det = ["a1.weight"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        **kwargs,
    ):
| | """ |
| | f(x) = WX + WAX + BX, where A and B are low-rank matrices |
| | bypass_forward(x) = W(X+A(X)) + B(X) |
| | bypass_forward_diff(x) = W(A(X)) + B(X) |
| | get_merged_weight() = W + WA + B |
| | get_diff_weight() = WA + B |
| | """ |
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in GLoRA algo.")
        self.lora_dim = lora_dim
        self.tucker = False
        self.rs_lora = rs_lora

        if self.module_type.startswith("conv"):
            self.isconv = True
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            out_dim = org_module.out_channels
            # Tucker decomposition needs a kernel with spatial extent, so it
            # is disabled for pure 1x1 kernels.
            use_tucker = use_tucker and any(i != 1 for i in k_size)
            self.down_op = self.op
            self.up_op = self.op
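
            # A path: a2 projects channels down to the rank and a1 projects
            # back up to in_dim (both 1x1 convs), so their composition is the
            # low-rank matrix A that re-mixes the inputs of the frozen W.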
            self.a2 = self.module(in_dim, lora_dim, 1, bias=False)
            self.a1 = self.module(lora_dim, in_dim, 1, bias=False)

            if use_tucker:
                # Tucker B path: 1x1 down-projection, a small spatial core on
                # the rank channels, then the shared 1x1 up-projection b1.
                self.b2 = self.module(in_dim, lora_dim, 1, bias=False)
                self.bm = self.module(
                    lora_dim, lora_dim, k_size, stride, padding, bias=False
                )
                self.tucker = True
            else:
                # Plain B path: b2 carries the full kernel.
                self.b2 = self.module(
                    in_dim, lora_dim, k_size, stride, padding, bias=False
                )
            # b1 (the B-path up-projection) is needed in both cases.
            self.b1 = self.module(lora_dim, out_dim, 1, bias=False)
        else:
            self.isconv = False
            self.down_op = F.linear
            self.up_op = F.linear
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.a2 = nn.Linear(in_dim, lora_dim, bias=False)
            self.a1 = nn.Linear(lora_dim, in_dim, bias=False)
            self.b2 = nn.Linear(in_dim, lora_dim, bias=False)
            self.b1 = nn.Linear(lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            # detach and upcast first: NumPy has no bfloat16
            alpha = alpha.detach().float().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha

        # rsLoRA-style scaling divides by sqrt(r) instead of r, keeping the
        # update magnitude stable as the rank grows.
        r_factor = lora_dim
        if self.rs_lora:
            r_factor = math.sqrt(r_factor)

        self.scale = alpha / r_factor

        self.register_buffer("alpha", torch.tensor(alpha))
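
        # A single scalar gates the whole update: trainable and starting at 0
        # when use_scalar is set, otherwise a constant 1 buffer. Combined with
        # the init below, the delta is always zero at initialization.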
        if use_scalar:
            self.scalar = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("scalar", torch.tensor(1.0), persistent=False)

        torch.nn.init.kaiming_uniform_(self.a1.weight, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.b1.weight, a=math.sqrt(5))
        if use_scalar:
            torch.nn.init.kaiming_uniform_(self.a2.weight, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.b2.weight, a=math.sqrt(5))
        else:
            torch.nn.init.zeros_(self.a2.weight)
            torch.nn.init.zeros_(self.b2.weight)

    @classmethod
    def make_module_from_state_dict(
        cls, lora_name, orig_module, a1, a2, b1, b2, bm, alpha
    ):
        module = cls(
            lora_name,
            orig_module,
            1,
            a2.size(0),  # a2 is (lora_dim, in_dim, ...): recover the rank
            float(alpha),
            use_tucker=bm is not None,
        )
        module.a1.weight.data.copy_(a1)
        module.a2.weight.data.copy_(a2)
        module.b1.weight.data.copy_(b1)
        module.b2.weight.data.copy_(b2)
        if bm is not None:
            module.bm.weight.data.copy_(bm)
        return module

    def custom_state_dict(self):
        destination = {}
        destination["alpha"] = self.alpha
        destination["a1.weight"] = self.a1.weight
        # Fold the scalar gate into the saved down-projections so the
        # checkpoint needs no separate "scalar" entry.
        destination["a2.weight"] = self.a2.weight * self.scalar
        destination["b1.weight"] = self.b1.weight
        destination["b2.weight"] = self.b2.weight * self.scalar
        if self.tucker:
            destination["bm.weight"] = self.bm.weight
        return destination

    def load_weight_hook(self, module: nn.Module, incompatible_keys):
        missing_keys = incompatible_keys.missing_keys
        # Iterate over a copy: deleting from a list while iterating it
        # skips entries.
        for key in list(missing_keys):
            if "scalar" in key:
                del missing_keys[missing_keys.index(key)]
        # Saved weights already have scalar folded in, so reset it to 1.
        if isinstance(self.scalar, nn.Parameter):
            self.scalar.data.copy_(torch.ones_like(self.scalar))
        elif getattr(self, "scalar", None) is not None:
            self.scalar.copy_(torch.ones_like(self.scalar))
        else:
            self.register_buffer("scalar", torch.tensor(1.0), persistent=False)

    def make_weight(self, device=None):
        wa1 = self.a1.weight.view(self.a1.weight.size(0), -1)
        wa2 = self.a2.weight.view(self.a2.weight.size(0), -1)
        orig = self.org_weight

        # B term: rebuilt from the Tucker factors, or a plain low-rank
        # product reshaped to the original weight's shape.
        if self.tucker:
            wb = tucker_weight_from_conv(
                self.b1.weight, self.b2.weight, self.bm.weight
            )
        else:
            wb1 = self.b1.weight.view(self.b1.weight.size(0), -1)
            wb2 = self.b2.weight.view(self.b2.weight.size(0), -1)
            wb = wb1 @ wb2
        wb = wb.view(*orig.shape)

        # WA term: contract W with a1 and then a2 over the input-channel
        # axis, leaving any spatial kernel dims untouched.
        if orig.dim() > 2:
            w_wa1 = torch.einsum("o i ..., i j -> o j ...", orig, wa1)
            w_wa2 = torch.einsum("o i ..., i j -> o j ...", w_wa1, wa2)
        else:
            w_wa2 = (orig @ wa1) @ wa2
        return (wb + w_wa2) * self.scale * self.scalar

    def get_diff_weight(self, multiplier=1.0, shape=None, device=None):
        weight = self.make_weight(device) * multiplier
        if shape is not None:
            weight = weight.view(shape)
        return weight, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        diff_w, _ = self.get_diff_weight(multiplier, shape, device)
        return self.org_weight + diff_w, None

    def _bypass_forward(self, x, scale=1, diff=False):
        # Apply self.scale (alpha / r) once per branch here; `scale` carries
        # the external multiplier and is applied to the branch outputs below.
        ax_mid = self.a2(x) * self.scale
        bx_mid = self.b2(x) * self.scale
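
        # Rank dropout: randomly zero whole rank channels during training
        # (each kept with probability 1 - rank_dropout), optionally rescaling
        # the mask so the expected magnitude is unchanged.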
        if self.rank_dropout and self.training:
            drop_a = (
                torch.rand(self.lora_dim, device=ax_mid.device) > self.rank_dropout
            ).to(ax_mid.dtype)
            drop_b = (
                torch.rand(self.lora_dim, device=bx_mid.device) > self.rank_dropout
            ).to(bx_mid.dtype)
            if self.rank_dropout_scale:
                drop_a /= drop_a.mean()
                drop_b /= drop_b.mean()
            if (dims := len(x.shape)) == 4:
                drop_a = drop_a.view(1, -1, 1, 1)
                drop_b = drop_b.view(1, -1, 1, 1)
            else:
                drop_a = drop_a.view(*[1] * (dims - 1), -1)
                drop_b = drop_b.view(*[1] * (dims - 1), -1)
            ax_mid = ax_mid * drop_a
            bx_mid = bx_mid * drop_b

        # diff=True drops the identity path: W(A(X)) + B(X) instead of
        # W(X + A(X)) + B(X).
        return (
            self.org_forward(
                (0 if diff else x) + self.drop(self.a1(ax_mid)) * scale
            )
            + self.drop(self.b1(bx_mid)) * scale
        )

    def bypass_forward_diff(self, x, scale=1):
        return self._bypass_forward(x, scale=scale, diff=True)

    def bypass_forward(self, x, scale=1):
        return self._bypass_forward(x, scale=scale, diff=False)

    def forward(self, x, *args, **kwargs):
        # module_dropout: occasionally run the original layer alone, so the
        # network also trains with the update absent.
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        else:
            # Fused path: add the low-rank delta onto the frozen weight and
            # run the original op once.
            weight = (
                self.org_module[0].weight.data.to(self.dtype)
                + self.get_diff_weight(multiplier=self.multiplier)[0]
            )
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
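

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the library's test
    # suite). It assumes LycorisBaseModule wires up self.op, self.kw_dict and
    # self.multiplier so the module can run standalone on a plain Linear.
    base = nn.Linear(320, 640)
    glora = GLoRAModule("test_glora", base, multiplier=1.0, lora_dim=4, alpha=4)
    # Give the zero-initialized down-projections some signal so the check
    # below is not trivially 0 == 0.
    nn.init.normal_(glora.a2.weight, std=0.02)
    nn.init.normal_(glora.b2.weight, std=0.02)
    x = torch.randn(2, 320)
    with torch.no_grad():
        y_fused = glora(x)  # fused path: a single op with W + dW
        y_bypass = glora.bypass_forward(x, glora.multiplier)
    print(torch.allclose(y_fused, y_bypass, atol=1e-5))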