from functools import cache
from math import log2
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .base import LycorisBaseModule
from ..functional import power2factorization
from ..logging import logger


@cache
def log_butterfly_factorize(dim, factor, result):
logger.info(
f"Use BOFT({int(log2(result[1]))}, {result[0]//2})"
f" (equivalent to factor={result[0]}) "
f"for {dim=} and {factor=}"
    )


def butterfly_factor(dimension: int, factor: int = -1) -> tuple[int, int]:
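    """Factorize ``dimension`` into a ``(block_size, block_num)`` pair for BOFT.

    ``power2factorization`` is expected to return a block size ``m`` and a
    power-of-two block count ``n``; ``n == 0`` means no valid decomposition
    exists for the requested factor.
    """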
m, n = power2factorization(dimension, factor)
if n == 0:
raise ValueError(
f"It is impossible to decompose {dimension} with factor {factor} under BOFT constraints."
)
log_butterfly_factorize(dimension, factor, (m, n))
    return m, n


class ButterflyOFTModule(LycorisBaseModule):
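    """Butterfly Orthogonal Fine-Tuning (BOFT) module.

    The adapted weight is the original weight multiplied by a product of
    ``boft_m`` butterfly-structured, block-orthogonal rotations; each block is
    obtained from a Cayley transform of a learned skew-symmetric matrix.
    """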
name = "boft"
support_module = {
"linear",
"conv1d",
"conv2d",
"conv3d",
}
weight_list = [
"oft_blocks",
"rescale",
"alpha",
]
    weight_list_det = ["oft_blocks"]

    def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.0,
rank_dropout=0.0,
module_dropout=0.0,
use_tucker=False,
use_scalar=False,
rank_dropout_scale=False,
constraint=0,
rescaled=False,
bypass_mode=None,
**kwargs,
):
super().__init__(
lora_name,
org_module,
multiplier,
dropout,
rank_dropout,
module_dropout,
rank_dropout_scale,
bypass_mode,
)
if self.module_type not in self.support_module:
raise ValueError(f"{self.module_type} is not supported in BOFT algo.")
out_dim = self.dim
b, m_exp = butterfly_factor(out_dim, lora_dim)
self.block_size = b
self.block_num = m_exp
# BOFT(m, b)
self.boft_b = b
        # number of butterfly levels: block_num is expected to be a power of two,
        # so popcount(block_num - 1) + 1 == log2(block_num) + 1
        self.boft_m = sum(int(i) for i in f"{m_exp-1:b}") + 1
# block_num > block_size
self.rescaled = rescaled
self.constraint = constraint * out_dim
self.register_buffer("alpha", torch.tensor(constraint))
self.oft_blocks = nn.Parameter(
torch.zeros(self.boft_m, self.block_num, self.block_size, self.block_size)
)
if rescaled:
self.rescale = nn.Parameter(
torch.ones(out_dim, *(1 for _ in range(org_module.weight.dim() - 1)))
            )

    @classmethod
def algo_check(cls, state_dict, lora_name):
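        """Return True if ``state_dict`` holds BOFT-shaped (4D) ``oft_blocks``."""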
if f"{lora_name}.oft_blocks" in state_dict:
oft_blocks = state_dict[f"{lora_name}.oft_blocks"]
if oft_blocks.ndim == 4:
return True
        return False

    @classmethod
def make_module_from_state_dict(
cls, lora_name, orig_module, oft_blocks, rescale, alpha
):
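        """Rebuild a BOFT module from checkpoint tensors (multiplier fixed to 1)."""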
m, n, s, _ = oft_blocks.shape
module = cls(
lora_name,
orig_module,
1,
lora_dim=s,
constraint=float(alpha),
rescaled=rescale is not None,
)
module.oft_blocks.copy_(oft_blocks)
if rescale is not None:
module.rescale.copy_(rescale)
        return module

    @property
def I(self):
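        """Identity matrix of size ``block_size`` on the module's device."""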
        return torch.eye(self.block_size, device=self.device)

    def get_r(self):
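        """Build per-block rotation matrices via the Cayley transform.

        ``oft_blocks`` is made skew-symmetric, optionally norm-constrained, and
        mapped to an orthogonal block R = (I + Q)(I - Q)^{-1}.
        """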
I = self.I
# for Q = -Q^T
q = self.oft_blocks - self.oft_blocks.transpose(-1, -2)
normed_q = q
        # Diag-OFT style norm constraint
if self.constraint > 0:
q_norm = torch.norm(q) + 1e-8
if q_norm > self.constraint:
normed_q = q * self.constraint / q_norm
        # run the inverse in float32, since inverse() may not support half/bfloat16
r = (I + normed_q) @ (I - normed_q).float().inverse()
        return r

    def make_weight(self, scale=1, device=None, diff=False):
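        """Apply the ``boft_m`` levels of butterfly rotations to the original weight.

        At level ``i`` the output dimension is regrouped with stride ``2**i * b // 2``
        so that each block mixes progressively farther-apart output channels; if
        ``diff`` is True, only the delta from the original weight is returned.
        """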
m = self.boft_m
b = self.boft_b
r_b = b // 2
r = self.get_r()
inp = org = self.org_weight.to(device, dtype=r.dtype)
for i in range(m):
bi = r[i] # b_num, b_size, b_size
g = 2
k = 2**i * r_b
if scale != 1:
bi = bi * scale + (1 - scale) * self.I
            # The block rotations act on the output dimension of the weight
            # (dim 0), matching the "b j ..." layout expected by the einsum:
            # regroup dim 0 with stride k, rotate each block, then undo it.
            inp = (
                inp.unflatten(0, (-1, g, k))
                .transpose(1, 2)
                .flatten(0, 2)
                .unflatten(0, (-1, b))
            )
            inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
            inp = (
                inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
            )
if self.rescaled:
inp = inp * self.rescale
if diff:
inp = inp - org
        return inp.to(self.oft_blocks.dtype)

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
diff = self.make_weight(scale=multiplier, device=device, diff=True)
if shape is not None:
diff = diff.view(shape)
        return diff, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
diff = self.make_weight(scale=multiplier, device=device)
if shape is not None:
diff = diff.view(shape)
        return diff, None

    @torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
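        """Rescale ``oft_blocks`` so its norm does not exceed ``max_norm``."""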
orig_norm = self.oft_blocks.to(device).norm()
norm = torch.clamp(orig_norm, max_norm / 2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired / norm
scaled = norm != desired
if scaled:
self.oft_blocks *= ratio
        return scaled, orig_norm * ratio

    def _bypass_forward(self, x, scale=1, diff=False):
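        """Apply the butterfly rotations to the layer output instead of the weight.

        Channels are moved to the last dimension for conv ops, rotated level by
        level (optionally rescaled), then moved back.
        """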
m = self.boft_m
b = self.boft_b
r_b = b // 2
r = self.get_r()
inp = org = self.org_forward(x)
if self.op in {F.conv2d, F.conv1d, F.conv3d}:
inp = inp.transpose(1, -1)
for i in range(m):
bi = r[i] # b_num, b_size, b_size
g = 2
k = 2**i * r_b
if scale != 1:
bi = bi * scale + (1 - scale) * self.I
inp = (
inp.unflatten(-1, (-1, g, k))
.transpose(-2, -1)
.flatten(-3)
.unflatten(-1, (-1, b))
)
inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
inp = (
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
)
if self.rescaled:
inp = inp * self.rescale.transpose(0, -1)
if self.op in {F.conv2d, F.conv1d, F.conv3d}:
inp = inp.transpose(1, -1)
if diff:
inp = inp - org
        return inp

    def bypass_forward_diff(self, x, scale=1):
        return self._bypass_forward(x, scale, diff=True)

    def bypass_forward(self, x, scale=1):
        return self._bypass_forward(x, scale, diff=False)

    def forward(self, x, *args, **kwargs):
if self.module_dropout and self.training:
if torch.rand(1) < self.module_dropout:
return self.org_forward(x)
scale = self.multiplier
if self.bypass_mode:
return self.bypass_forward(x, scale)
else:
w = self.make_weight(scale, x.device)
kw_dict = self.kw_dict | {"weight": w, "bias": self.org_module[0].bias}
return self.op(x, **kw_dict)
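

# A minimal sketch of the Cayley step used in `get_r`, assuming only plain
# torch (kept as comments so importing this module stays side-effect free):
#
#     q = torch.randn(4, 4)
#     q = q - q.T                                   # skew-symmetric: Q = -Q^T
#     eye = torch.eye(4)
#     r = (eye + q) @ torch.linalg.inv(eye - q)     # Cayley transform
#     torch.testing.assert_close(r @ r.T, eye, atol=1e-5, rtol=1e-5)  # orthogonal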