import math
import random

import torch
import torch.nn as nn

from .base import LycorisBaseModule
from ..utils import product


class DyLoraModule(LycorisBaseModule):
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        block_size=4,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        train_on_input=False,
        **kwargs,
    ):
| """if alpha == 0 or None, alpha is rank (no scaling).""" |
| super().__init__( |
| lora_name, |
| org_module, |
| multiplier, |
| dropout, |
| rank_dropout, |
| module_dropout, |
| rank_dropout_scale, |
| bypass_mode, |
| ) |
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in DyLoRA algo.")
        assert lora_dim % block_size == 0, "lora_dim must be a multiple of block_size"
        self.block_count = lora_dim // block_size
        self.block_size = block_size

        # Collapse the original weight to a 2D view:
        # (out_channels, in_channels * kernel_elements)
        shape = (
            self.shape[0],
            product(self.shape[1:]),
        )

        # DyLoRA stores the up/down projections as per-block parameter lists,
        # so any prefix of blocks forms a valid lower-rank adapter.
        self.lora_dim = lora_dim
        self.up_list = nn.ParameterList(
            [
                nn.Parameter(torch.empty(shape[0], self.block_size))
                for _ in range(self.block_count)
            ]
        )
        self.down_list = nn.ParameterList(
            [
                nn.Parameter(torch.empty(self.block_size, shape[1]))
                for _ in range(self.block_count)
            ]
        )

        if isinstance(alpha, torch.Tensor):
            # unwrap a 0-dim tensor alpha into a plain value
            alpha = alpha.detach().float().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        # e.g. alpha=1, lora_dim=4 -> scale = 0.25; alpha=0 or None -> scale = 1.0
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # Standard LoRA init: kaiming-uniform down projections and zeroed up
        # projections, so the initial weight delta is exactly zero.
        for v in self.down_list:
            torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5))
        for v in self.up_list:
            torch.nn.init.zeros_(v)

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        # Deliberate no-op: weights live in the per-block ParameterLists rather
        # than the merged lora_up/lora_down layout a state dict would provide.
        return

    def custom_state_dict(self):
        destination = {}
        destination["alpha"] = self.alpha
        destination["lora_up.weight"] = nn.Parameter(
            torch.concat(list(self.up_list), dim=1)
        )
        destination["lora_down.weight"] = nn.Parameter(
            torch.concat(list(self.down_list)).reshape(
                self.lora_dim, -1, *self.shape[2:]
            )
        )
        return destination

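    # Illustrative export shapes (a sketch, assuming a Linear(64, 32) wrapped
    # with lora_dim=8, block_size=4; the concrete numbers are examples only):
    #   up_list:   2 blocks of (32, 4) -> "lora_up.weight"   of shape (32, 8)
    #   down_list: 2 blocks of (4, 64) -> "lora_down.weight" of shape (8, 64)
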
    def get_weight(self, rank):
        b = math.ceil(rank / self.block_size)
        # Blocks 0..b-1 are concatenated detached (`.data`), so only block b,
        # the block that completes the sampled rank, receives gradients.
        down = torch.concat(
            list(i.data for i in self.down_list[:b]) + list(self.down_list[b : b + 1])
        )
        up = torch.concat(
            list(i.data for i in self.up_list[:b]) + list(self.up_list[b : b + 1]),
            dim=1,
        )
        return down, up, self.alpha / (b + 1)

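    # Worked example (illustrative numbers): with lora_dim=8 and block_size=4,
    # get_weight(rank=4) gives b = ceil(4 / 4) = 1, so block 0 is frozen,
    # block 1 stays trainable, the concatenated factors span 8 ranks, and the
    # returned scale is alpha / (b + 1) = alpha / 2.
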
    def get_random_rank_weight(self):
        # Sample a block index uniformly; the corresponding rank is a multiple
        # of block_size.
        b = random.randint(0, self.block_count - 1)
        return self.get_weight(b * self.block_size)

    def get_diff_weight(self, multiplier=1, shape=None, device=None, rank=None):
        if rank is None:
            down, up, scale = self.get_random_rank_weight()
        else:
            down, up, scale = self.get_weight(rank)
        # delta_W = up @ down, scaled by the rank-dependent gamma and multiplier
        w = up @ (down * (scale * multiplier))
        if device is not None:
            w = w.to(device)
        if shape is not None:
            w = w.view(shape)
        else:
            w = w.view(self.shape)
        return w, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None, rank=None):
        diff, _ = self.get_diff_weight(multiplier, shape, device, rank)
        return diff + self.org_weight, None

    def bypass_forward_diff(self, x, scale=1, rank=None):
        if rank is None:
            down, up, gamma = self.get_random_rank_weight()
        else:
            down, up, gamma = self.get_weight(rank)
        # The sampled rank can be smaller than lora_dim, so derive the rank
        # dimension from the returned blocks instead of hardcoding lora_dim.
        r = down.size(0)
        down = down.view(r, -1, *self.shape[2:])
        up = up.view(-1, r, *(1 for _ in self.shape[2:]))
        scale = scale * gamma
        return self.op(self.op(x, down, **self.kw_dict), up) * scale

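    # Shape walk-through (illustrative): for a Conv2d weight (out, in, kh, kw)
    # and a sampled rank r, `down` is applied as a rank-r convolution of shape
    # (r, in, kh, kw) with the original stride/padding from kw_dict, and `up`
    # as a 1x1 convolution (out, r, 1, 1) mixing the r channels back to out.
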
    def bypass_forward(self, x, scale=1, rank=None):
        return self.org_forward(x) + self.bypass_forward_diff(x, scale, rank)

    def forward(self, x, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        else:
            # Merge a randomly sampled rank into the original weight; a new
            # rank is drawn on every call, which is the core of DyLoRA training.
            weight = self.get_merged_weight(multiplier=self.multiplier)[0]
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
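
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not shipped API). It assumes the
# LycorisBaseModule constructor wires up `op`, `kw_dict`, `org_module`, and
# `org_forward` for a plain nn.Linear, and that this module is importable as
# shown; adjust both to the real package layout if they differ.
#
#     import torch
#     import torch.nn as nn
#     from lycoris.modules.dylora import DyLoraModule  # path is an assumption
#
#     base = nn.Linear(64, 32)
#     dylora = DyLoraModule("demo", base, lora_dim=8, alpha=4, block_size=4)
#     x = torch.randn(2, 64)
#     out = dylora(x)  # each call merges a randomly sampled rank
#     print(out.shape)  # torch.Size([2, 32])
# ---------------------------------------------------------------------------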