|
|
from functools import cache |
|
|
from math import log2 |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from einops import rearrange |
|
|
|
|
|
from .base import LycorisBaseModule |
|
|
from ..functional import power2factorization |
|
|
from ..logging import logger |
|
|
|
|
|
|
|
|
@cache
def log_butterfly_factorize(dim, factor, result):
    """Log which BOFT factorization was selected for ``(dim, factor)``.

    ``result`` is the ``(block_size, block_count)`` pair chosen by the
    factorizer.  The ``@cache`` decorator deduplicates output: repeated
    calls with identical arguments emit no additional log lines.
    """
    block_size, block_count = result
    message = (
        f"Use BOFT({int(log2(block_count))}, {block_size//2})"
        f" (equivalent to factor={block_size}) "
        f"for {dim=} and {factor=}"
    )
    logger.info(message)
|
|
|
|
|
|
|
|
def butterfly_factor(dimension: int, factor: int = -1) -> tuple[int, int]:
    """Split ``dimension`` into a BOFT-compatible ``(block_size, block_count)``.

    Delegates to ``power2factorization``; a zero block count signals that no
    valid decomposition exists, which is reported as a ``ValueError``.  The
    chosen factorization is logged (once per unique input) before returning.
    """
    block_size, block_count = power2factorization(dimension, factor)
    if not block_count:
        raise ValueError(
            f"It is impossible to decompose {dimension} with factor {factor} under BOFT constraints."
        )
    factorization = (block_size, block_count)
    log_butterfly_factorize(dimension, factor, factorization)
    return factorization
|
|
|
|
|
|
|
|
class ButterflyOFTModule(LycorisBaseModule):
    """Butterfly Orthogonal Fine-Tuning (BOFT) module.

    Learns ``boft_m`` stages of block-diagonal orthogonal transforms, applied
    in a butterfly-interleaved pattern along the output dimension of the
    wrapped layer.  Each block is parameterized as a skew-symmetric matrix
    and mapped to an orthogonal matrix via the Cayley transform (see
    ``get_r``).  Supports both weight-merge and activation-bypass application.
    """

    # Algorithm identifier used by the LyCORIS framework.
    name = "boft"
    # Layer types this algorithm can wrap.
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }
    # Tensor names this module contributes to a state dict.
    weight_list = [
        "oft_blocks",
        "rescale",
        "alpha",
    ]
    # Tensor name(s) whose presence/shape identifies this algorithm
    # when loading a state dict (see ``algo_check``).
    weight_list_det = ["oft_blocks"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        use_scalar=False,
        rank_dropout_scale=False,
        constraint=0,
        rescaled=False,
        bypass_mode=None,
        **kwargs,
    ):
        """Wrap ``org_module`` with BOFT parameters.

        Args:
            lora_name: unique name for this adapter module.
            org_module: the layer being adapted (linear or conv1d/2d/3d).
            multiplier: global strength multiplier (handled by the base class).
            lora_dim: requested butterfly block-size factor; passed to
                ``butterfly_factor`` to pick the actual decomposition.
            alpha: accepted for signature compatibility but not used in this
                implementation (the buffer named ``alpha`` stores
                ``constraint`` instead).
            dropout, rank_dropout, module_dropout, rank_dropout_scale:
                dropout settings forwarded to / used with the base class.
            use_tucker, use_scalar: accepted for signature compatibility;
                unused by BOFT.
            constraint: norm-constraint coefficient; scaled by the output
                dimension and applied to the skew-symmetric blocks in
                ``get_r``.
            rescaled: if True, adds a learnable per-output-channel rescale.
            bypass_mode: forwarded to the base class; selects the
                activation-path implementation in ``forward``.

        Raises:
            ValueError: if the wrapped module type is unsupported, or the
                output dimension admits no BOFT factorization.
        """
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in BOFT algo.")

        out_dim = self.dim
        # Decompose the output dimension into (block size, number of blocks).
        b, m_exp = butterfly_factor(out_dim, lora_dim)
        self.block_size = b
        self.block_num = m_exp

        self.boft_b = b
        # Number of butterfly stages: popcount(m_exp - 1) + 1.  When m_exp is
        # a power of two this equals log2(m_exp) + 1.
        self.boft_m = sum(int(i) for i in f"{m_exp-1:b}") + 1

        self.rescaled = rescaled
        # The effective norm constraint scales with the output dimension.
        self.constraint = constraint * out_dim
        # NOTE: the buffer stores the raw ``constraint`` (not ``alpha``);
        # it is persisted so the constraint survives save/load round-trips.
        self.register_buffer("alpha", torch.tensor(constraint))
        # One (block_size x block_size) skew-symmetric seed per block per
        # stage; zeros mean "identity transform" at init.
        self.oft_blocks = nn.Parameter(
            torch.zeros(self.boft_m, self.block_num, self.block_size, self.block_size)
        )
        if rescaled:
            # Per-output-channel scale, shaped to broadcast against the
            # original weight tensor (trailing singleton dims for conv).
            self.rescale = nn.Parameter(
                torch.ones(out_dim, *(1 for _ in range(org_module.weight.dim() - 1)))
            )

    @classmethod
    def algo_check(cls, state_dict, lora_name):
        """Return True if ``state_dict`` holds BOFT weights for ``lora_name``.

        BOFT is distinguished from plain (diagonal) OFT by its 4-D
        ``oft_blocks`` tensor (stage, block, row, col).
        """
        if f"{lora_name}.oft_blocks" in state_dict:
            oft_blocks = state_dict[f"{lora_name}.oft_blocks"]
            if oft_blocks.ndim == 4:
                return True
        return False

    @classmethod
    def make_module_from_state_dict(
        cls, lora_name, orig_module, oft_blocks, rescale, alpha
    ):
        """Reconstruct a module from saved tensors and load them in place.

        ``oft_blocks`` fixes the block size (``s``); ``alpha`` is reused as
        the constraint coefficient; ``rescale`` being present toggles the
        rescaled variant.  Multiplier is set to 1.
        """
        m, n, s, _ = oft_blocks.shape
        module = cls(
            lora_name,
            orig_module,
            1,
            lora_dim=s,
            constraint=float(alpha),
            rescaled=rescale is not None,
        )
        module.oft_blocks.copy_(oft_blocks)
        if rescale is not None:
            module.rescale.copy_(rescale)
        return module

    @property
    def I(self):
        # Identity matrix of the block size, on this module's device.
        # Rebuilt on each access so it always matches the current device.
        return torch.eye(self.block_size, device=self.device)

    def get_r(self):
        """Build the orthogonal blocks via the Cayley transform.

        ``Q = B - B^T`` makes each block skew-symmetric, then
        ``R = (I + Q)(I - Q)^{-1}`` is orthogonal.  If a positive constraint
        is set, ``Q`` is rescaled so its (global) norm does not exceed it,
        bounding how far ``R`` can deviate from identity.
        """
        I = self.I

        q = self.oft_blocks - self.oft_blocks.transpose(-1, -2)
        normed_q = q

        if self.constraint > 0:
            # Epsilon guards against division by zero for all-zero blocks.
            q_norm = torch.norm(q) + 1e-8
            if q_norm > self.constraint:
                normed_q = q * self.constraint / q_norm

        # Cayley transform; inverse is computed in float for stability.
        r = (I + normed_q) @ (I - normed_q).float().inverse()
        return r

    def make_weight(self, scale=1, device=None, diff=False):
        """Apply the butterfly-orthogonal transform to the original weight.

        Runs ``boft_m`` stages; at stage ``i`` the last axis is regrouped
        with stride ``2**i * (b // 2)`` (the unflatten/transpose/flatten
        dance implements the butterfly interleave), the stage's orthogonal
        blocks are applied with einsum, and the grouping is undone.

        Args:
            scale: interpolation factor; each block is blended toward the
                identity as ``scale`` approaches 0.
            device: device to compute on.
            diff: if True, return (transformed - original) instead of the
                transformed weight.
        """
        m = self.boft_m
        b = self.boft_b
        r_b = b // 2
        r = self.get_r()
        inp = org = self.org_weight.to(device, dtype=r.dtype)

        for i in range(m):
            bi = r[i]
            g = 2
            # Interleave stride doubles each stage.
            k = 2**i * r_b
            if scale != 1:
                # Linear blend between the learned block and identity.
                bi = bi * scale + (1 - scale) * self.I
            # Regroup: split last axis into (…, g, k), swap, then re-split
            # into blocks of size b so each einsum block sees mixed halves.
            inp = (
                inp.unflatten(-1, (-1, g, k))
                .transpose(-2, -1)
                .flatten(-3)
                .unflatten(-1, (-1, b))
            )
            # Apply this stage's per-block orthogonal transform.
            inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
            # Undo the regrouping to restore the original layout.
            inp = (
                inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
            )

        if self.rescaled:
            inp = inp * self.rescale

        if diff:
            inp = inp - org

        return inp.to(self.oft_blocks.dtype)

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
        """Return (transformed - original) weight, optionally reshaped.

        Second element of the returned tuple is a bias diff (always None
        for BOFT, which does not alter the bias).
        """
        diff = self.make_weight(scale=multiplier, device=device, diff=True)
        if shape is not None:
            diff = diff.view(shape)
        return diff, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
        """Return the fully transformed (merged) weight, optionally reshaped.

        Second element of the returned tuple is a merged bias (always None
        for BOFT, which does not alter the bias).
        """
        diff = self.make_weight(scale=multiplier, device=device)
        if shape is not None:
            diff = diff.view(shape)
        return diff, None

    @torch.no_grad()
    def apply_max_norm(self, max_norm, device=None):
        """Scale ``oft_blocks`` down in place if its norm exceeds ``max_norm``.

        Returns:
            (scaled, norm): whether a rescale happened, and the resulting
            norm after any rescale.
        """
        orig_norm = self.oft_blocks.to(device).norm()
        # Floor the norm at max_norm / 2 so the ratio below stays bounded
        # even for near-zero parameter norms.
        norm = torch.clamp(orig_norm, max_norm / 2)
        desired = torch.clamp(norm, max=max_norm)
        ratio = desired / norm

        scaled = norm != desired
        if scaled:
            self.oft_blocks *= ratio

        return scaled, orig_norm * ratio

    def _bypass_forward(self, x, scale=1, diff=False):
        """Apply the butterfly transform to activations instead of weights.

        Mirrors ``make_weight`` stage-for-stage, but operates on the output
        of the original layer's forward pass.  Conv outputs are transposed
        so the channel axis is last (matching the weight-side layout) and
        transposed back afterwards.
        """
        m = self.boft_m
        b = self.boft_b
        r_b = b // 2
        r = self.get_r()
        inp = org = self.org_forward(x)
        if self.op in {F.conv2d, F.conv1d, F.conv3d}:
            # Move channels to the last axis for the butterfly regrouping.
            inp = inp.transpose(1, -1)

        for i in range(m):
            bi = r[i]
            g = 2
            k = 2**i * r_b
            if scale != 1:
                bi = bi * scale + (1 - scale) * self.I
            inp = (
                inp.unflatten(-1, (-1, g, k))
                .transpose(-2, -1)
                .flatten(-3)
                .unflatten(-1, (-1, b))
            )
            # Same stage transform as make_weight, with batch dims leading.
            inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
            inp = (
                inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
            )

        if self.rescaled:
            # rescale is stored weight-layout (channel-first); transpose to
            # broadcast against the channel-last activations.
            inp = inp * self.rescale.transpose(0, -1)

        if self.op in {F.conv2d, F.conv1d, F.conv3d}:
            # Restore the original channel-first layout.
            inp = inp.transpose(1, -1)

        if diff:
            inp = inp - org
        return inp

    def bypass_forward_diff(self, x, scale=1):
        """Bypass-path forward returning only the BOFT-induced difference."""
        return self._bypass_forward(x, scale, diff=True)

    def bypass_forward(self, x, scale=1):
        """Bypass-path forward returning the full transformed output."""
        return self._bypass_forward(x, scale, diff=False)

    def forward(self, x, *args, **kwargs):
        """Forward pass with module dropout and path selection.

        During training, with probability ``module_dropout`` the adapter is
        skipped entirely.  Otherwise either the bypass path (transform the
        activations) or the merge path (transform the weight, then run the
        original op with it) is used, per ``bypass_mode``.
        """
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        scale = self.multiplier

        if self.bypass_mode:
            return self.bypass_forward(x, scale)
        else:
            w = self.make_weight(scale, x.device)
            # Reuse the original op's kwargs with the transformed weight and
            # the original module's (untouched) bias.
            kw_dict = self.kw_dict | {"weight": w, "bias": self.org_module[0].bias}
            return self.op(x, **kw_dict)
|
|
|