# nvan13's picture
# Upload folder using huggingface_hub
# ecadbd9 verified
import torch
import torch.nn as nn
from typing import Optional, Set
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
import einops
class SingleSharedBlockDiag(nn.Module):
    """Block-diagonal linear map whose diagonal blocks are shared in groups.

    The layer represents a block-diagonal matrix with
    ``num_unique_blocks * share_factor`` diagonal blocks, each of shape
    ``(block_size_r, block_size_c)``. Only ``num_unique_blocks`` distinct blocks
    are stored; each run of ``share_factor`` consecutive diagonal blocks reuses
    the same weights (diagonal block ``u * share_factor + f`` uses ``weights[u]``).

    Input last dim:  ``num_unique_blocks * share_factor * block_size_c``.
    Output last dim: ``num_unique_blocks * share_factor * block_size_r``.
    """

    def __init__(self, num_unique_blocks, share_factor, block_size_r: int = None, block_size_c: int = None, init_std=0.0):
        """
        Initialize the layer with shared diagonal weights.

        Args:
            num_unique_blocks (int): Number of distinct weight blocks (groups).
            share_factor (int): Number of consecutive diagonal positions that reuse
                each distinct block.
            block_size_r (int, optional): Output size of each block (row dimension).
                Falls back to ``block_size_c`` when omitted.
            block_size_c (int, optional): Input size of each block (column dimension).
                Falls back to ``block_size_r`` when omitted.
            init_std (float): Std of the normal initialization; a value ``<= 0``
                initializes every weight to zero.

        Raises:
            ValueError: if both ``block_size_r`` and ``block_size_c`` are None.
        """
        super().__init__()
        if block_size_r is None and block_size_c is None:
            raise ValueError("Block size r,c are not valid")
        if block_size_r is None:
            block_size_r = block_size_c
        elif block_size_c is None:
            block_size_c = block_size_r
        self.num_diag_blocks = num_unique_blocks * share_factor
        self.block_size_r = block_size_r
        self.block_size_c = block_size_c
        self.share_factor = share_factor
        self.num_unique_blocks = num_unique_blocks
        # Only the distinct diagonal blocks are stored: (num_unique_blocks, block_size_r, block_size_c)
        self.weights = nn.Parameter(torch.empty(self.num_unique_blocks, block_size_r, block_size_c))
        self.init_std = init_std
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the shared blocks: N(0, init_std) if ``init_std > 0``, else all zeros."""
        with torch.no_grad():
            if self.init_std > 0:
                nn.init.normal_(self.weights, mean=0, std=self.init_std)
            else:
                nn.init.zeros_(self.weights)

    def forward(self, x):
        """Apply the shared block-diagonal map to the last dimension of ``x``.

        The last dim of ``x`` is split into
        ``(num_unique_blocks, share_factor, block_size_c)``; each slice is
        multiplied by the transpose of its group's block (row-vector
        convention, ``x @ W.T``). Leading dimensions are preserved.
        """
        # Splitting only the last dim, so a view is always possible.
        x = x.view(*x.shape[:-1], self.num_unique_blocks, self.share_factor, self.block_size_c)
        # (..., unique, factor, c) x (unique, r, c) -> (..., unique, factor, r)
        output = torch.einsum("...ufc,urc->...ufr", x, self.weights)
        # Flatten (unique, factor, r) back into a single feature dimension.
        return output.reshape(*output.shape[:-3], -1)
class SharedMonarch(nn.Module):
    """Monarch-style factorized linear map ``x -> Pt(L(P(R(x))))`` with shared blocks.

    Sizes (constructor block sizes, abbreviated rR, cR, rL, cL):
      * input  dim N = cL * cR
      * output dim M = rL * rR

    Factors, applied to row vectors in this order:
      1. R:  block-diagonal with ``cL`` blocks of shape ``(rR, cR)``; each run of
             ``share_factor_R`` consecutive blocks shares weights.
      2. P:  reorder the intermediate (size cL * rR) from ``(cL, rR)`` to ``(rR, cL)``.
      3. L:  block-diagonal with ``rR`` blocks of shape ``(rL, cL)``; each run of
             ``share_factor_L`` consecutive blocks shares weights.
      4. Pt: reorder the result from ``(rR, rL)`` to ``(rL, rR)``.

    L is zero-initialized, so the map is the zero function at construction.
    """

    def __init__(self, share_factor_L, share_factor_R, block_size_rR, block_size_cR,
                 block_size_rL, block_size_cL):
        """
        Args:
            share_factor_L (int): Sharing factor of L's blocks; must divide ``block_size_rR``.
            share_factor_R (int): Sharing factor of R's blocks; must divide ``block_size_cL``.
            block_size_rR (int): Row (output) size of each R block; also L's block count.
            block_size_cR (int): Column (input) size of each R block.
            block_size_rL (int or None): Row size of each L block; None -> ``block_size_rR``.
            block_size_cL (int or None): Column size of each L block (also R's block count);
                None -> ``block_size_cR``.

        Raises:
            ValueError: if a share factor is not positive or does not divide the
                number of blocks of its factor.
        """
        super().__init__()
        if block_size_rL is None:
            block_size_rL = block_size_rR
        if block_size_cL is None:
            block_size_cL = block_size_cR
        self.block_size_rL = block_size_rL
        self.block_size_cL = block_size_cL
        self.block_size_rR = block_size_rR
        self.block_size_cR = block_size_cR
        if share_factor_L <= 0 or share_factor_R <= 0:
            raise ValueError(
                f"share factors must be positive, got share_factor_L={share_factor_L}, "
                f"share_factor_R={share_factor_R}"
            )
        # R has block_size_cL diagonal blocks, grouped share_factor_R at a time.
        if block_size_cL % share_factor_R != 0:
            raise ValueError(f"block_size_cL ({block_size_cL}) must be divisible by share_factor_R ({share_factor_R})")
        num_unique_blocksR = block_size_cL // share_factor_R
        # L has block_size_rR diagonal blocks, grouped share_factor_L at a time.
        if block_size_rR % share_factor_L != 0:
            raise ValueError(f"block_size_rR ({block_size_rR}) must be divisible by share_factor_L ({share_factor_L})")
        num_unique_blocksL = block_size_rR // share_factor_L
        self.share_factor_L = share_factor_L
        self.share_factor_R = share_factor_R
        # L starts at zero so the adapter initially contributes nothing.
        self.sama_L = SingleSharedBlockDiag(num_unique_blocks=num_unique_blocksL, share_factor=share_factor_L,
                                            block_size_r=block_size_rL, block_size_c=block_size_cL, init_std=0.0)
        self.sama_R = SingleSharedBlockDiag(num_unique_blocks=num_unique_blocksR, share_factor=share_factor_R,
                                            block_size_r=block_size_rR, block_size_c=block_size_cR, init_std=1e-3)

    def forward(self, x):
        """Apply ``Pt L P R`` to the last dimension of ``x`` (size cL * cR -> rL * rR)."""
        # R: (cL blocks of cR) -> (cL blocks of rR)
        out_r = self.sama_R(x)
        # P: (cL, rR) -> (rR, cL)
        out_P = einops.rearrange(out_r, '... (h w) -> ... (w h)', h=self.block_size_cL, w=self.block_size_rR)
        # L: (rR blocks of cL) -> (rR blocks of rL)
        out_l = self.sama_L(out_P)
        # Pt: (rR, rL) -> (rL, rR)
        out_Pt = einops.rearrange(out_l, '... (w h) -> ... (h w)', h=self.block_size_rL, w=self.block_size_rR)
        return out_Pt

    def get_delta_weight2(self):
        """Reference implementation of the dense delta weight.

        Builds the dense block-diagonal factors and explicit permutation
        matrices and multiplies them: ``W = P_out @ L @ P_mid @ R``. This uses far
        more memory and compute than ``get_delta_weight``; it is kept for checking.

        Returns:
            Delta weight matrix of shape ``(rL * rR, cL * cR)`` (i.e. ``(dout, din)``).
        """
        weights_R = self.sama_R.weights
        device, dtype = weights_R.device, weights_R.dtype
        # R: (unique, r, c) -> (unique, factor, r, c) -> (num_blocks, r, c) -> dense
        blocks_R = weights_R.unsqueeze(1).expand(-1, self.share_factor_R, -1, -1)
        blocks_R = blocks_R.reshape(-1, self.block_size_rR, self.block_size_cR)
        R_dense = torch.block_diag(*blocks_R)
        # P_mid: reorder the intermediate from (cL, rR) to (rR, cL).
        dim_mid = self.block_size_cL * self.block_size_rR
        perm_mid = torch.arange(dim_mid, device=device).view(self.block_size_cL, self.block_size_rR).t().reshape(-1)
        P_mid = torch.eye(dim_mid, device=device, dtype=dtype)[perm_mid]
        # L: same expansion as R.
        blocks_L = self.sama_L.weights.unsqueeze(1).expand(-1, self.share_factor_L, -1, -1)
        blocks_L = blocks_L.reshape(-1, self.block_size_rL, self.block_size_cL)
        L_dense = torch.block_diag(*blocks_L)
        # P_out: reorder the output from (rR, rL) to (rL, rR).
        dim_out = self.block_size_rL * self.block_size_rR
        perm_out = torch.arange(dim_out, device=device).view(self.block_size_rR, self.block_size_rL).t().reshape(-1)
        P_out = torch.eye(dim_out, device=device, dtype=dtype)[perm_out]
        return P_out @ (L_dense @ (P_mid @ R_dense))

    def get_delta_weight(self):
        """Return the dense matrix ``W`` of shape ``(rL * rR, cL * cR)`` with ``forward(x) == x @ W.T``.

        R and L are block-diagonal and P / Pt only reorder entries, so every
        entry of W is a single product of one L weight and one R weight:

            W[a * rR + b, i * cR + c] = L_b[a, i] * R_i[b, c]

        where ``L_b`` is the b-th expanded L block (b < rR) and ``R_i`` the i-th
        expanded R block (i < cL). One ``einsum`` computes this in O(M * N) time
        and memory, without materializing dense block-diagonal factors or a
        dense (M x mid) @ (mid x N) product.
        """
        # Expand shared blocks: (unique, r, c) -> (unique, factor, r, c) -> (num_blocks, r, c)
        blocks_R = self.sama_R.weights.unsqueeze(1).expand(-1, self.share_factor_R, -1, -1)
        blocks_R = blocks_R.reshape(-1, self.block_size_rR, self.block_size_cR)  # (cL, rR, cR)
        blocks_L = self.sama_L.weights.unsqueeze(1).expand(-1, self.share_factor_L, -1, -1)
        blocks_L = blocks_L.reshape(-1, self.block_size_rL, self.block_size_cL)  # (rR, rL, cL)
        # W4[a, b, i, c] = L_b[a, i] * R_i[b, c]; rows flatten as (rL, rR), cols as (cL, cR).
        W4 = torch.einsum("bai,ibc->abic", blocks_L, blocks_R)
        return W4.reshape(self.block_size_rL * self.block_size_rR,
                          self.block_size_cL * self.block_size_cR)
class SamaLayer(BaseTunerLayer):
    """
    Adapter mixin that attaches SaMA (shared Monarch) adapters to a base ``nn.Linear``.

    Per adapter name it keeps a ``SharedMonarch`` module in ``self._sama_layer``,
    a dropout module in ``self._sama_dropout``, a scalar multiplier in
    ``self.scaling`` and the raw hyper-parameters in ``self._adapter_config``.
    """
    adapter_layer_names: tuple[str, ...] = ("_sama_layer",)
    other_param_names: tuple[str, ...] = ("share_factor_L", "share_factor_R", "scaling",
                                          "col_L",
                                          "row_R",
                                          "drop_out",)

    def __init__(self, base_layer: nn.Module, **kwargs):
        """
        Args:
            base_layer: Layer being adapted; only ``nn.Linear`` is supported.
            **kwargs: Stored verbatim on ``self.kwargs``.

        Raises:
            NotImplementedError: if ``base_layer`` is not an ``nn.Linear``.
        """
        super().__init__()
        self.base_layer = base_layer
        self._sama_layer = nn.ModuleDict()    # adapter_name -> SharedMonarch
        self._sama_dropout = nn.ModuleDict()  # adapter_name -> nn.Dropout or nn.Identity
        self.scaling = {}                     # adapter_name -> multiplier on the adapter output
        self._adapter_config = {}             # adapter_name -> hyper-parameters given to update_layer
        self._disable_adapters = False
        self.merged_adapters: list[str] = []
        self._cast_input_dtype_enabled = True
        self.kwargs = kwargs
        if isinstance(base_layer, nn.Linear):
            self.in_features = base_layer.in_features
            self.out_features = base_layer.out_features
        else:
            raise NotImplementedError("SamaLayer only supports nn.Linear base layers for now.")

    @property
    def _available_adapters(self) -> set[str]:
        """Names of all adapters registered on this layer."""
        return set(self._sama_layer.keys())

    @property
    def disable_adapters(self) -> bool:
        return self._disable_adapters

    @property
    def merged(self) -> bool:
        """True if at least one adapter is currently merged into the base weight."""
        return bool(self.merged_adapters)

    @property
    def active_adapters(self) -> list[str]:
        """Adapters used in forward: those set via ``set_active_adapters``, else all registered ones."""
        # NOTE(review): this reads `_active_adapters` (plural), not the `_active_adapter`
        # attribute that Linear.__init__ assigns — confirm this matches how peft's
        # adapter-switching helpers are used with this layer.
        return getattr(self, "_active_adapters", list(self._sama_layer.keys()))

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Cast ``x`` to ``dtype`` unless casting has been disabled on this layer."""
        if not self._cast_input_dtype_enabled:
            return x
        return x.to(dtype)

    def update_layer(
        self,
        adapter_name: str,
        share_factor_L: int,
        share_factor_R: int,
        scaling: float,
        col_L: int,
        row_R: int,
        drop_out: float,
        **kwargs,
    ):
        """
        Add (or replace) the SaMA adapter ``adapter_name`` on this layer.

        The input dim is factored as ``in_features = col_L * col_R`` and the
        output dim as ``out_features = row_L * row_R``; ``col_R`` and ``row_L``
        are derived here.

        Args:
            adapter_name: Key under which the adapter is stored.
            share_factor_L, share_factor_R: Block-sharing factors of the L / R factors.
            scaling: Multiplier applied to the adapter output (used as-is).
            col_L: Column size of each L block; must divide ``in_features``.
            row_R: Row size of each R block; must divide ``out_features``.
            drop_out: Dropout probability on the adapter input; ``<= 0`` disables it.

        Raises:
            ValueError: if ``col_L``/``row_R`` are not positive or do not divide
                the input/output feature counts.
        """
        if col_L <= 0:
            raise ValueError(f"col_L must be a positive integer, got {col_L}")
        if row_R <= 0:
            raise ValueError(f"row_R must be a positive integer, got {row_R}")
        col_R = self.in_features // col_L
        if self.in_features % col_L != 0:
            raise ValueError(f'Input mismatches, col_L = {col_L} * col_R = {col_R} vs input = {self.in_features}')
        row_L = self.out_features // row_R
        if self.out_features % row_R != 0:
            raise ValueError(f'Output mismatches, row_L = {row_L} * row_R = {row_R} vs output = {self.out_features}')
        sama_adapter = SharedMonarch(share_factor_L=share_factor_L, share_factor_R=share_factor_R,
                                     block_size_rR=row_R, block_size_cR=col_R,
                                     block_size_cL=col_L, block_size_rL=row_L)
        self._sama_layer[adapter_name] = sama_adapter
        self.scaling[adapter_name] = scaling  # stored as-is, no division by a rank
        self._adapter_config[adapter_name] = {"scaling": scaling,
                                              "share_factor_L": share_factor_L, "share_factor_R": share_factor_R,
                                              "row_R": row_R, "col_L": col_L,
                                              "drop_out": drop_out}
        self._sama_dropout[adapter_name] = nn.Dropout(p=drop_out) if drop_out > 0.0 else nn.Identity()

    def set_active_adapters(self, adapters: Optional[list[str]]):
        """Restrict forward to ``adapters``; ``None`` restores the default (all registered adapters)."""
        if adapters is None:
            if hasattr(self, "_active_adapters"):
                delattr(self, "_active_adapters")
        else:
            self._active_adapters = adapters
class Linear(nn.Module, SamaLayer):
    """
    ``nn.Linear`` with SaMA adapters for parameter-efficient fine-tuning.

    Unmerged forward:
        ``base(x) + sum_a scaling[a] * sama[a](dropout[a](x))`` over active adapters ``a``.
    """

    def __init__(self,
                 base_layer: nn.Linear,
                 adapter_name: str,
                 share_factor_L: int,
                 share_factor_R: int,
                 scaling: float,
                 col_L: int,
                 row_R: int,
                 drop_out: float,
                 **kwargs):
        """
        Wrap ``base_layer`` and register its first adapter ``adapter_name``.

        Args:
            base_layer: The linear layer to adapt.
            adapter_name: Name of the adapter created here.
            share_factor_L, share_factor_R: Block-sharing factors of the L / R factors.
            scaling: Multiplier applied to the adapter output.
            col_L: Column size of each L block (must divide ``in_features``).
            row_R: Row size of each R block (must divide ``out_features``).
            drop_out: Dropout probability on the adapter input (``<= 0`` disables it).
            **kwargs: Forwarded to ``SamaLayer.__init__`` and ``update_layer``.
        """
        super().__init__()
        SamaLayer.__init__(self, base_layer=base_layer, **kwargs)
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name=adapter_name,
            share_factor_L=share_factor_L,
            share_factor_R=share_factor_R,
            scaling=scaling,
            col_L=col_L,
            row_R=row_R,
            drop_out=drop_out,
            **kwargs,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
        """
        Fold adapter deltas into the base weight: ``W <- W + scaling * (Pt L P R)``.

        The delta comes from ``SharedMonarch.get_delta_weight()`` and is laid out
        like the base weight, ``(out_features, in_features)``.

        Args:
            safe_merge: If True, check that each merged weight is finite before
                writing it into the base layer, and raise instead of writing.
            adapter_names: Adapters to merge, resolved through peft's
                ``check_adapters_to_merge``.

        Raises:
            ValueError: if ``safe_merge`` is set and a merged weight has NaN/inf.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return
        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype
        for active_adapter in adapter_names:
            if active_adapter not in self._available_adapters:
                continue
            delta_weight = self._sama_layer[active_adapter].get_delta_weight()
            scaling = self.scaling[active_adapter]
            merged_weight = (base_layer.weight.data + scaling * delta_weight).contiguous().to(orig_dtype)
            if safe_merge and not torch.isfinite(merged_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )
            base_layer.weight.data.copy_(merged_weight)
            # Record the merge so unmerge() can subtract it again.
            self.merged_adapters.append(active_adapter)

    def unmerge(self):
        """
        Undo merges in LIFO order by subtracting each merged adapter's scaled delta.

        The result matches the pre-merge weight only up to floating-point rounding
        (more noticeable for low-precision weights).
        """
        base_layer = self.get_base_layer()
        orig_dtype = base_layer.weight.dtype
        while self.merged_adapters:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self._available_adapters:
                continue
            delta_weight = self._sama_layer[active_adapter].get_delta_weight()
            scaling = self.scaling[active_adapter]
            restored = (base_layer.weight.data - scaling * delta_weight).contiguous().to(orig_dtype)
            base_layer.weight.data.copy_(restored)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Base output plus scaled SaMA outputs of the active adapters, returned in ``x``'s dtype."""
        x_dtype = x.dtype
        base_layer = self.get_base_layer()
        if self.disable_adapters:
            # Adapters off: undo any merge so the base layer reproduces the original behavior.
            if self.merged:
                self.unmerge()
            return base_layer(x, *args, **kwargs).to(x_dtype)
        if self.merged:
            # Adapter effect already lives in the base weight.
            return base_layer(x, *args, **kwargs).to(x_dtype)
        output = base_layer(x, *args, **kwargs)
        for active_adapter in self.active_adapters:
            if active_adapter not in self._sama_layer:
                continue
            sama_layer = self._sama_layer[active_adapter]
            scaling = self.scaling[active_adapter]
            # Cast from the original x for every adapter, so adapters with
            # different dtypes do not see each other's (possibly lossy) casts.
            sama_input = self._cast_input_dtype(x, sama_layer.sama_L.weights.dtype)
            if active_adapter in self._sama_dropout:
                sama_input = self._sama_dropout[active_adapter](sama_input)
            output = output + scaling * sama_layer(sama_input).to(output.dtype)
        return output.to(x_dtype)

    def __repr__(self):
        return f"sama.{super().__repr__()}"