|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from .base import LycorisBaseModule |
|
|
|
|
|
|
|
|
class IA3Module(LycorisBaseModule):
    """(IA)^3 adapter: one learned per-channel scaling vector.

    Instead of low-rank matrices, (IA)^3 rescales either the input
    channels or the output channels of the wrapped layer with a single
    trainable vector (Liu et al., "Few-Shot Parameter-Efficient
    Fine-Tuning is Better and Cheaper than In-Context Learning").
    The vector is zero-initialized, so the effective scale is
    ``weight * multiplier + 1`` and the module starts as an identity
    modification of the original layer.
    """

    name = "ia3"
    # Layer types this algorithm can wrap.
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    # State-dict keys this module saves and loads.
    weight_list = ["weight", "on_input"]
    # Keys whose presence identifies this algorithm in a state dict.
    weight_list_det = ["on_input"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        train_on_input=False,
        **kwargs,
    ):
        """Wrap ``org_module`` with an (IA)^3 scaling vector.

        Args:
            lora_name: unique name used to key this module's weights.
            org_module: the linear/conv layer being adapted.
            multiplier: global strength of the learned modification.
            train_on_input: if True the vector scales input channels,
                otherwise output channels (the default).

        The low-rank hyperparameters (``lora_dim``, ``alpha``,
        ``use_tucker``, ``use_scalar``, ``weight_decompose``,
        ``rs_lora``) are accepted for interface compatibility with the
        other LyCORIS algorithms but are unused by (IA)^3.

        Raises:
            ValueError: if the wrapped layer type is not supported.
        """
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in IA^3 algo.")

        if self.module_type.startswith("conv"):
            self.isconv = True
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            train_dim = in_dim if train_on_input else out_dim
            # Shape (1, C, 1, ..., 1): broadcastable over batch/spatial
            # dims of an activation and over the kernel dims of a weight.
            self.weight = nn.Parameter(
                torch.empty(1, train_dim, *(1 for _ in self.shape[2:]))
            )
        else:
            # Always define isconv so linear-wrapped instances have it too.
            self.isconv = False
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            train_dim = in_dim if train_on_input else out_dim
            self.weight = nn.Parameter(torch.empty(train_dim))

        # Zero init => learned scale starts at exactly 1 (identity).
        torch.nn.init.constant_(self.weight, 0)
        self.train_input = train_on_input
        # Persisted as a buffer so the input/output choice survives
        # save/load round-trips.
        self.register_buffer("on_input", torch.tensor(int(train_on_input)))

    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, weight):
        """Rebuild an IA3Module from a serialized scaling vector.

        The side being scaled (input vs. output channels) is inferred
        from the stored vector's length so that non-square layers
        trained with ``train_on_input=True`` reload correctly; when the
        layer is square the choice is ambiguous and we keep the previous
        default (scale on output).
        """
        if hasattr(orig_module, "in_channels"):
            in_dim = orig_module.in_channels
            out_dim = orig_module.out_channels
        else:
            in_dim = orig_module.in_features
            out_dim = orig_module.out_features
        # Conv vectors are stored as (1, C, 1, ...); linear ones as (C,).
        train_dim = weight.shape[1] if weight.dim() > 1 else weight.shape[0]
        on_input = train_dim == in_dim and train_dim != out_dim
        module = cls(
            lora_name,
            orig_module,
            1,
            train_on_input=on_input,
        )
        module.weight.data.copy_(weight)
        return module

    def apply_to(self):
        """Hijack the wrapped module's ``forward`` with this module's."""
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def make_weight(self, multiplier=1, shape=None, device=None, diff=False):
        """Materialize the merged (or delta) layer weight.

        Args:
            multiplier: scale applied to the learned vector.
            shape: optional shape to ``view`` the result into.
            device: optional device to move the result to.
            diff: if True return only the delta ``W * m*v``; otherwise
                return the merged weight ``W * (m*v + 1)``.
        """
        scale = self.weight * multiplier + int(not diff)
        if self.train_input:
            # Vector length matches dim 1 (conv) / last dim (linear) of
            # the original weight, so it broadcasts directly.
            merged = self.org_weight * scale
        else:
            # Swap the channel dims so the vector lines up with the
            # output channels, then swap back.
            merged = (self.org_weight.transpose(0, 1) * scale).transpose(0, 1)
        if shape is not None:
            merged = merged.view(shape)
        if device is not None:
            merged = merged.to(device)
        return merged

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        """Return (delta_weight, None); the delta is ``W * m*v``."""
        diff = self.make_weight(
            multiplier=multiplier, shape=shape, device=device, diff=True
        )
        return diff, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        """Return (merged_weight, None); merged is ``W * (m*v + 1)``."""
        merged = self.make_weight(multiplier=multiplier, shape=shape, device=device)
        return merged, None

    def _bypass_forward(self, x, scale=1, diff=False):
        """Apply the scaling to activations instead of weights.

        Scales the input before — or the output after — the original
        forward, depending on ``train_input``.
        """
        weight = self.weight * scale + int(not diff)
        if self.train_input:
            x = x * weight
        out = self.org_forward(x)
        if not self.train_input:
            out = out * weight
        return out

    def bypass_forward_diff(self, x, scale=1):
        """Bypass forward returning only the learned modification."""
        return self._bypass_forward(x, scale, diff=True)

    def bypass_forward(self, x, scale=1):
        """Bypass forward returning the full (identity + learned) output."""
        return self._bypass_forward(x, scale, diff=False)

    def forward(self, x, *args, **kwargs):
        """Forward pass that replaces the wrapped module's forward."""
        if self.module_dropout and self.training:
            # Stochastically skip the whole adapter during training.
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        else:
            # Merge the vector into the weight and run the functional op.
            weight = self.get_merged_weight(multiplier=self.multiplier)[0]
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
|
|
|