from functools import cache

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import LycorisBaseModule
from ..functional import factorization
from ..logging import logger


@cache
def log_oft_factorize(dim, factor, num, bdim):
    logger.info(
        f"Use OFT(block num: {num}, block dim: {bdim})"
        f" (equivalent to lora_dim={num})"
        f" for {dim=} and {factor=}"
    )


class DiagOFTModule(LycorisBaseModule):
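    """Diagonal OFT (Orthogonal Fine-Tuning) module.

    Instead of adding a low-rank update to the weight, Diag-OFT multiplies it
    by a block-diagonal orthogonal matrix R. Each of the ``block_num`` blocks
    is produced by the Cayley transform of a skew-symmetric matrix built from
    ``oft_blocks`` (see ``get_r``), so every block is orthogonal by
    construction and equals the identity at initialization.
    """
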
| name = "diag-oft" |
| support_module = { |
| "linear", |
| "conv1d", |
| "conv2d", |
| "conv3d", |
| } |
| weight_list = [ |
| "oft_blocks", |
| "rescale", |
| "alpha", |
| ] |
| weight_list_det = ["oft_blocks"] |
|
|
    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,  # unused here; kept for interface parity
        use_scalar=False,  # unused here; kept for interface parity
        rank_dropout_scale=False,
        constraint=0,
        rescaled=False,
        bypass_mode=None,
        **kwargs,
    ):
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in Diag-OFT algo.")

        out_dim = self.dim
        # Split out_dim into block_num blocks of size block_size
        # (block_num * block_size == out_dim).
        self.block_size, self.block_num = factorization(out_dim, lora_dim)

        self.rescaled = rescaled
        self.constraint = constraint * out_dim
        self.register_buffer("alpha", torch.tensor(constraint))
        # Zero init makes R = I at step 0 (see get_r / make_weight).
        self.oft_blocks = nn.Parameter(
            torch.zeros(self.block_num, self.block_size, self.block_size)
        )
        if rescaled:
            # Per-output-channel rescale, shaped to broadcast over the
            # original weight's remaining dimensions.
            self.rescale = nn.Parameter(
                torch.ones(out_dim, *(1 for _ in range(org_module.weight.dim() - 1)))
            )

        log_oft_factorize(
            dim=out_dim,
            factor=lora_dim,
            num=self.block_num,
            bdim=self.block_size,
        )

    @classmethod
    def algo_check(cls, state_dict, lora_name):
        if f"{lora_name}.oft_blocks" in state_dict:
            oft_blocks = state_dict[f"{lora_name}.oft_blocks"]
            # Diag-OFT checkpoints store 3D blocks:
            # (block_num, block_size, block_size).
            return oft_blocks.ndim == 3
        return False

    @classmethod
    def make_module_from_state_dict(
        cls, lora_name, orig_module, oft_blocks, rescale, alpha
    ):
        n, s, _ = oft_blocks.shape
        module = cls(
            lora_name,
            orig_module,
            1,
            lora_dim=s,
            constraint=float(alpha),
            rescaled=rescale is not None,
        )
        module.oft_blocks.copy_(oft_blocks)
        if rescale is not None:
            module.rescale.copy_(rescale)
        return module

    @property
    def I(self):
        return torch.eye(self.block_size, device=self.device)

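    # The Cayley transform R = (I + Q)(I - Q)^-1 maps a skew-symmetric Q
    # (Q^T = -Q) to an orthogonal R. With Q = 0 it yields R = I, so a
    # freshly initialized module leaves the weight untouched.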
    def get_r(self):
        I = self.I
        # Q = B - B^T is skew-symmetric by construction.
        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        normed_q = q
        if self.constraint > 0:
            q_norm = torch.norm(q) + 1e-8
            if q_norm > self.constraint:
                normed_q = q * self.constraint / q_norm
        # float() guards against dtypes for which inverse() is unsupported.
        r = (I + normed_q) @ (I - normed_q).float().inverse()
        return r

    def make_weight(self, scale=1, device=None, diff=False):
        r = self.get_r()
        _, *shape = self.org_weight.shape
        org_weight = self.org_weight.to(device, dtype=r.dtype)
        org_weight = org_weight.view(self.block_num, self.block_size, *shape)
        # R = I at init, so the identity added back here (skipped for the
        # diff weight) keeps the step-0 output equal to the original output.
        weight = torch.einsum(
            "k n m, k n ... -> k m ...",
            self.rank_drop(r * scale) - scale * self.I + (0 if diff else self.I),
            org_weight,
        ).view(-1, *shape)
        if self.rescaled:
            weight = self.rescale * weight
            if diff:
                # Full diff of a rescaled weight:
                # rescale * (RW - W) + (rescale - 1) * W.
                weight = weight + (self.rescale - 1) * org_weight.view(-1, *shape)
        return weight.to(self.oft_blocks.dtype)

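    # get_diff_weight returns the delta (merged - original) while
    # get_merged_weight returns the full rotated weight; both reuse
    # make_weight and differ only in the `diff` flag.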
    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        diff = self.make_weight(scale=multiplier, device=device, diff=True)
        if shape is not None:
            diff = diff.view(shape)
        return diff, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        weight = self.make_weight(scale=multiplier, device=device)
        if shape is not None:
            weight = weight.view(shape)
        return weight, None

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        orig_norm = self.oft_blocks.to(device).norm()
        # Clamp the working norm into [max_norm / 2, max_norm].
        norm = torch.clamp(orig_norm, max_norm / 2)
        desired = torch.clamp(norm, max=max_norm)
        ratio = desired / norm

        scaled = norm != desired
        if scaled:
            self.oft_blocks *= ratio

        return scaled, orig_norm * ratio

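    # Bypass mode leaves the original weights untouched and instead rotates
    # the module's *outputs*: org_forward(x) is reshaped into blocks and
    # multiplied by R on the output side.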
    def _bypass_forward(self, x, scale=1, diff=False):
        r = self.get_r()
        org_out = self.org_forward(x)
        if self.op in {F.conv2d, F.conv1d, F.conv3d}:
            # Move channels last so the blocks index the output dimension.
            org_out = org_out.transpose(1, -1)
        *shape, _ = org_out.shape
        org_out = org_out.view(*shape, self.block_num, self.block_size)
        mask = neg_mask = 1
        if self.dropout != 0 and self.training:
            mask = torch.ones_like(org_out)
            mask = self.drop(mask)
            neg_mask = torch.max(mask) - mask
        oft_out = torch.einsum(
            "k n m, ... k n -> ... k m",
            r * scale * mask + (1 - scale) * self.I * neg_mask,
            org_out,
        )
        if diff:
            # Subtract the original activation while both tensors are still
            # in block form.
            oft_out = oft_out - org_out
        out = oft_out.view(*shape, -1)
        if self.rescaled:
            out = self.rescale.transpose(-1, 0) * out
            if diff:
                out = out + (self.rescale.transpose(-1, 0) - 1) * org_out.view(
                    *shape, -1
                )
        if self.op in {F.conv2d, F.conv1d, F.conv3d}:
            out = out.transpose(1, -1)
        return out

    def bypass_forward_diff(self, x, scale=1):
        return self._bypass_forward(x, scale, diff=True)

    def bypass_forward(self, x, scale=1):
        return self._bypass_forward(x, scale, diff=False)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        scale = self.multiplier

        if self.bypass_mode:
            return self.bypass_forward(x, scale)
        else:
            w = self.make_weight(scale, x.device)
            kw_dict = self.kw_dict | {"weight": w, "bias": self.org_module[0].bias}
            return self.op(x, **kw_dict)
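

# Usage sketch (hypothetical; the layer sizes and wrapping flow below are
# only illustrative of how LycorisBaseModule subclasses are typically driven,
# and assume the default non-bypass path):
#
#     base = nn.Linear(64, 64)
#     oft = DiagOFTModule("test_oft", base, multiplier=1.0, lora_dim=4)
#     x = torch.randn(2, 64)
#     # oft_blocks is zero-initialized, so R = I and the wrapped module
#     # should reproduce the original layer's output at step 0:
#     assert torch.allclose(oft(x), base(x), atol=1e-5)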