# (removed non-Python upload-page residue: "tonyshark's picture / Upload 119 files / 0bb1a82 verified")
import torch
import torch.nn as nn
from .base import LycorisBaseModule
class IA3Module(LycorisBaseModule):
    """(IA)^3 adapter: learns a single per-channel rescaling vector that is
    applied to either the input or the output of the wrapped layer.

    The trained parameter is stored as a delta around 1, i.e. the effective
    scale is ``1 + weight * multiplier`` (see ``make_weight`` /
    ``_bypass_forward``).
    """

    name = "ia3"
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    # Keys serialized into the state dict; "on_input" alone is enough to
    # re-detect the module configuration.
    weight_list = ["weight", "on_input"]
    weight_list_det = ["on_input"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        train_on_input=False,
        **kwargs,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in IA^3 algo.")

        if self.module_type.startswith("conv"):
            self.isconv = True
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            scale_dim = in_dim if train_on_input else out_dim
            # One scale per channel, with trailing singleton dims so it
            # broadcasts over the spatial axes of the conv weight/activations.
            trailing = tuple(1 for _ in self.shape[2:])
            self.weight = nn.Parameter(torch.empty(1, scale_dim, *trailing))
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            scale_dim = in_dim if train_on_input else out_dim
            self.weight = nn.Parameter(torch.empty(scale_dim))

        # Need more experiences on init method
        torch.nn.init.constant_(self.weight, 0)
        self.train_input = train_on_input
        # Persisted so the input/output choice survives save/load.
        self.register_buffer("on_input", torch.tensor(int(train_on_input)))

    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, weight):
        """Rebuild a module around ``orig_module`` and load ``weight`` into it."""
        module = cls(lora_name, orig_module, 1)
        module.weight.data.copy_(weight)
        return module

    def apply_to(self):
        # Hijack the wrapped module's forward so this adapter runs instead.
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def make_weight(self, multiplier=1, shape=None, device=None, diff=False):
        """Build the (merged or differential) effective weight tensor.

        With ``diff=True`` the result is the pure delta (scale centered at 0);
        otherwise the scale is centered at 1 and yields the merged weight.
        """
        scale = self.weight * multiplier + int(not diff)
        if self.train_input:
            # Scale broadcasts along the input-channel axis directly.
            result = self.org_weight * scale
        else:
            # Swap dims so the scale lines up with the output-channel axis,
            # then swap back.
            result = (self.org_weight.transpose(0, 1) * scale).transpose(0, 1)
        if shape is not None:
            result = result.view(shape)
        if device is not None:
            result = result.to(device)
        return result

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        """Return (delta_weight, None) relative to the original weight."""
        return (
            self.make_weight(multiplier=multiplier, shape=shape, device=device, diff=True),
            None,
        )

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        """Return (original_weight * learned_scale, None)."""
        return (
            self.make_weight(multiplier=multiplier, shape=shape, device=device),
            None,
        )

    def _bypass_forward(self, x, scale=1, diff=False):
        # Same centering convention as make_weight: delta when diff, else 1+delta.
        gate = self.weight * scale + int(not diff)
        if self.train_input:
            return self.org_forward(x * gate)
        return self.org_forward(x) * gate

    def bypass_forward_diff(self, x, scale=1):
        return self._bypass_forward(x, scale, diff=True)

    def bypass_forward(self, x, scale=1):
        return self._bypass_forward(x, scale, diff=False)

    def forward(self, x, *args, **kwargs):
        # Stochastically skip the adapter entirely during training.
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        # Merged path: run the base op with the rescaled weight.
        weight = self.get_merged_weight(multiplier=self.multiplier)[0]
        org = self.org_module[0]
        bias = None if org.bias is None else org.bias.data
        return self.op(x, weight, bias, **self.kw_dict)