| import torch |
| import torch.nn as nn |
| from typing import Optional, Set |
|
|
| from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge |
| import einops |
|
|
class SingleSharedBlockDiag(nn.Module):
    """Block-diagonal linear map whose diagonal blocks are shared in groups.

    The layer stores ``num_unique_blocks`` weight blocks of shape
    ``(block_size_r, block_size_c)``. Each unique block is reused for
    ``share_factor`` consecutive diagonal positions, giving
    ``num_diag_blocks = num_unique_blocks * share_factor`` diagonal blocks.
    It maps ``num_diag_blocks * block_size_c`` input features to
    ``num_diag_blocks * block_size_r`` output features (no bias).
    """

    def __init__(
        self,
        num_unique_blocks: int,
        share_factor: int,
        block_size_r: Optional[int] = None,
        block_size_c: Optional[int] = None,
        init_std: float = 0.0,
    ):
        """
        Args:
            num_unique_blocks: Number of distinct weight blocks (groups).
            share_factor: Number of consecutive diagonal positions that reuse
                each unique block.
            block_size_r: Output size (rows) of each block; defaults to
                ``block_size_c`` when omitted.
            block_size_c: Input size (columns) of each block; defaults to
                ``block_size_r`` when omitted.
            init_std: Standard deviation of the normal initialisation. A value
                ``<= 0`` initialises all weights to zero.

        Raises:
            ValueError: If neither ``block_size_r`` nor ``block_size_c`` is given.
        """
        super().__init__()

        if block_size_r is None and block_size_c is None:
            raise ValueError("Block size r,c are not valid")
        # A single given size makes the blocks square.
        if block_size_r is None:
            block_size_r = block_size_c
        elif block_size_c is None:
            block_size_c = block_size_r

        self.num_unique_blocks = num_unique_blocks
        self.share_factor = share_factor
        self.num_diag_blocks = num_unique_blocks * share_factor
        self.block_size_r = block_size_r
        self.block_size_c = block_size_c
        self.init_std = init_std

        # Only the unique blocks are parameters; sharing happens in forward().
        self.weights = nn.Parameter(torch.empty(num_unique_blocks, block_size_r, block_size_c))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weights``: N(0, init_std^2) if ``init_std > 0``, else zeros."""
        with torch.no_grad():
            if self.init_std > 0:
                nn.init.normal_(self.weights, mean=0.0, std=self.init_std)
            else:
                nn.init.zeros_(self.weights)

    def forward(self, x):
        """Apply the shared block-diagonal map to the last dimension of ``x``.

        ``x`` must have last dimension ``num_diag_blocks * block_size_c``. The
        leading dimensions are preserved, and the last dimension of the result is
        ``num_diag_blocks * block_size_r``.
        """
        # Split the features into (unique block u, repeat f, column c). Diagonal
        # block u * share_factor + f therefore uses weights[u].
        x = x.view(*x.shape[:-1], self.num_unique_blocks, self.share_factor, self.block_size_c)
        out = torch.einsum('...ufc,urc->...ufr', x, self.weights)
        # Flatten (u, f, r) back into a single feature dimension.
        return out.reshape(*out.shape[:-3], -1)
|
|
|
|
class SharedMonarch(nn.Module):
    """Monarch-style product ``Pt @ L @ P @ R`` whose block-diagonal factors share blocks.

    With ``N = block_size_cL * block_size_cR`` input features and
    ``M = block_size_rL * block_size_rR`` output features:

    * ``R`` is block-diagonal with ``block_size_cL`` blocks of shape
      ``(block_size_rR, block_size_cR)``. Each unique block is repeated
      ``share_factor_R`` consecutive times.
    * ``P`` reorders the intermediate features from a ``(block_size_cL, block_size_rR)``
      layout to ``(block_size_rR, block_size_cL)``.
    * ``L`` is block-diagonal with ``block_size_rR`` blocks of shape
      ``(block_size_rL, block_size_cL)``. Each unique block is repeated
      ``share_factor_L`` consecutive times.
    * ``Pt`` reorders the output features from ``(block_size_rR, block_size_rL)``
      to ``(block_size_rL, block_size_rR)``.

    ``L`` is zero-initialised and ``R`` is drawn from N(0, 1e-3^2), so the
    module starts out as the zero map.
    """

    def __init__(self, share_factor_L, share_factor_R, block_size_rR, block_size_cR,
                 block_size_rL=None, block_size_cL=None):
        """
        Args:
            share_factor_L: Repetitions of each unique block in ``L``; must divide
                ``block_size_rR``.
            share_factor_R: Repetitions of each unique block in ``R``; must divide
                ``block_size_cL``.
            block_size_rR: Output size of each ``R`` block (also the number of ``L`` blocks).
            block_size_cR: Input size of each ``R`` block.
            block_size_rL: Output size of each ``L`` block; ``None`` falls back to
                ``block_size_rR``.
            block_size_cL: Input size of each ``L`` block (also the number of ``R``
                blocks); ``None`` falls back to ``block_size_cR``.

        Raises:
            ValueError: If a share factor is not positive or does not divide the
                corresponding block count.
        """
        super().__init__()

        if block_size_rL is None:
            block_size_rL = block_size_rR
        if block_size_cL is None:
            block_size_cL = block_size_cR

        self.block_size_rL = block_size_rL
        self.block_size_cL = block_size_cL
        self.block_size_rR = block_size_rR
        self.block_size_cR = block_size_cR

        # Guard against a zero/negative factor before it is used as a divisor.
        if share_factor_R < 1:
            raise ValueError(f"share_factor_R must be a positive integer, got {share_factor_R}")
        if share_factor_L < 1:
            raise ValueError(f"share_factor_L must be a positive integer, got {share_factor_L}")

        # R has block_size_cL diagonal blocks and L has block_size_rR. Each count
        # must split evenly into groups of identical (shared) blocks.
        if block_size_cL % share_factor_R != 0:
            raise ValueError(f"block_size_cL ({block_size_cL}) must be divisible by share_factor_R ({share_factor_R})")
        if block_size_rR % share_factor_L != 0:
            raise ValueError(f"block_size_rR ({block_size_rR}) must be divisible by share_factor_L ({share_factor_L})")

        self.share_factor_L = share_factor_L
        self.share_factor_R = share_factor_R

        # L starts at zero (so the adapter is initially a no-op); R gets a small random init.
        self.sama_L = SingleSharedBlockDiag(num_unique_blocks=block_size_rR // share_factor_L,
                                            share_factor=share_factor_L,
                                            block_size_r=block_size_rL, block_size_c=block_size_cL,
                                            init_std=0.0)
        self.sama_R = SingleSharedBlockDiag(num_unique_blocks=block_size_cL // share_factor_R,
                                            share_factor=share_factor_R,
                                            block_size_r=block_size_rR, block_size_c=block_size_cR,
                                            init_std=1e-3)

    def forward(self, x):
        """Compute ``Pt(L(P(R(x))))`` over the last dimension of ``x`` (size N -> M)."""
        out_r = self.sama_R(x)
        # P: (block_size_cL, block_size_rR) -> (block_size_rR, block_size_cL)
        out_p = einops.rearrange(out_r, '... (h w) -> ... (w h)', h=self.block_size_cL, w=self.block_size_rR)
        out_l = self.sama_L(out_p)
        # Pt: (block_size_rR, block_size_rL) -> (block_size_rL, block_size_rR)
        return einops.rearrange(out_l, '... (w h) -> ... (h w)', h=self.block_size_rL, w=self.block_size_rR)

    @staticmethod
    def _dense_block_diag(diag):
        """Materialise a SingleSharedBlockDiag as its dense block-diagonal matrix (out, in)."""
        # repeat_interleave keeps each unique block's copies adjacent, matching forward().
        blocks = diag.weights.repeat_interleave(diag.share_factor, dim=0)
        return torch.block_diag(*blocks.unbind(0))

    @staticmethod
    def _transpose_index(rows, cols, device):
        """Gather index that reorders a flattened (rows, cols) layout into (cols, rows) order."""
        return torch.arange(rows * cols, device=device).view(rows, cols).t().reshape(-1)

    def get_delta_weight2(self):
        """
        Reference implementation of :meth:`get_delta_weight` that uses explicit
        permutation matrices.

        Returns:
            Delta weight matrix of shape (dout, din) = (M, N). Slower and more
            memory-hungry than :meth:`get_delta_weight`; useful for cross-checking.
        """
        weights = self.sama_R.weights
        device, dtype = weights.device, weights.dtype

        r_dense = self._dense_block_diag(self.sama_R)
        l_dense = self._dense_block_diag(self.sama_L)

        dim_mid = self.block_size_cL * self.block_size_rR
        dim_out = self.block_size_rL * self.block_size_rR
        # Row-selecting an identity matrix by the gather index yields the permutation matrix.
        p_mid = torch.eye(dim_mid, device=device, dtype=dtype)[
            self._transpose_index(self.block_size_cL, self.block_size_rR, device)]
        p_out = torch.eye(dim_out, device=device, dtype=dtype)[
            self._transpose_index(self.block_size_rR, self.block_size_rL, device)]

        return p_out @ (l_dense @ (p_mid @ r_dense))

    def get_delta_weight(self):
        """
        Compute the dense weight ``W`` with ``forward(x) == x @ W.T``.

        The permutations are applied as row gathers, so no identity or
        permutation matrices are materialised.

        Returns:
            Tensor of shape (M, N), in the same orientation as ``nn.Linear.weight``.
        """
        device = self.sama_R.weights.device

        r_dense = self._dense_block_diag(self.sama_R)   # (cL * rR, cL * cR)
        l_dense = self._dense_block_diag(self.sama_L)   # (rR * rL, rR * cL)

        perm_mid = self._transpose_index(self.block_size_cL, self.block_size_rR, device)
        perm_out = self._transpose_index(self.block_size_rR, self.block_size_rL, device)

        # P @ R == R[perm_mid] and Pt @ X == X[perm_out].
        return (l_dense @ r_dense[perm_mid])[perm_out]
|
|
|
|
|
|
class SamaLayer(BaseTunerLayer):
    """
    Adapter-layer mixin that attaches SaMA (shared Monarch) adapters to a base ``nn.Linear``.

    Per-adapter state is stored in dicts keyed by adapter name:
    ``_sama_layer`` (the ``SharedMonarch`` module), ``_sama_dropout`` (applied to
    the adapter input), ``scaling`` and ``_adapter_config`` (the hyper-parameters
    the adapter was created with).
    """

    # Attributes holding the trainable per-adapter modules.
    adapter_layer_names: tuple[str, ...] = ("_sama_layer",)
    # Other per-adapter containers keyed by adapter name. PEFT's BaseTunerLayer
    # (e.g. delete_adapter) fetches each entry with getattr(self, name), so every
    # name listed here must be an existing attribute of the layer.
    other_param_names: tuple[str, ...] = ("scaling", "_sama_dropout", "_adapter_config")

    def __init__(self, base_layer: nn.Module, **kwargs):
        super().__init__()

        self.base_layer = base_layer
        self._sama_layer = nn.ModuleDict()
        self._sama_dropout = nn.ModuleDict()
        self.scaling = {}
        self._adapter_config = {}

        self._disable_adapters = False
        self.merged_adapters: list[str] = []
        self._cast_input_dtype_enabled = True
        self.kwargs = kwargs

        if isinstance(base_layer, nn.Linear):
            self.in_features = base_layer.in_features
            self.out_features = base_layer.out_features
        else:
            raise NotImplementedError("SamaLayer only supports nn.Linear base layers for now.")

    @property
    def _available_adapters(self) -> set[str]:
        return set(self._sama_layer.keys())

    @property
    def disable_adapters(self) -> bool:
        return self._disable_adapters

    @property
    def merged(self) -> bool:
        return bool(self.merged_adapters)

    @property
    def active_adapters(self) -> list[str]:
        # Adapters chosen via set_active_adapters(); otherwise every registered adapter.
        # NOTE(review): Linear.__init__ sets `_active_adapter`, which this property does
        # not read. Confirm whether PEFT's set_adapter() (which writes `_active_adapter`)
        # is expected to restrict the active set here.
        return getattr(self, "_active_adapters", list(self._sama_layer.keys()))

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Cast ``x`` to ``dtype`` unless input casting has been disabled."""
        if not self._cast_input_dtype_enabled:
            return x
        return x.to(dtype)

    def update_layer(
        self,
        adapter_name: str,
        share_factor_L: int,
        share_factor_R: int,
        scaling: float,
        col_L: int,
        row_R: int,
        drop_out: float,
        **kwargs,
    ):
        """
        Add (or replace) the SaMA adapter ``adapter_name`` on this layer.

        The base layer's ``in_features`` is factored as ``col_L * col_R`` and its
        ``out_features`` as ``row_L * row_R``.

        Args:
            adapter_name: Key under which the adapter is stored.
            share_factor_L / share_factor_R: Block-sharing factors passed to SharedMonarch.
            scaling: Multiplier applied to the adapter output (and to its delta when merging).
            col_L: Number of input-feature groups; must divide ``in_features``.
            row_R: Number of output-feature groups; must divide ``out_features``.
            drop_out: Dropout probability on the adapter input; ``<= 0`` disables dropout.

        Raises:
            ValueError: If ``col_L``/``row_R`` are not positive or do not divide the
                layer's feature sizes, or if SharedMonarch rejects the share factors.
        """
        # Validate before dividing so a zero does not surface as ZeroDivisionError.
        if col_L < 1:
            raise ValueError(f"col_L must be a positive integer, got {col_L}")
        if row_R < 1:
            raise ValueError(f"row_R must be a positive integer, got {row_R}")

        col_R = self.in_features // col_L
        if self.in_features % col_L != 0:
            raise ValueError(f'Input mismatches, col_L = {col_L} * col_R = {col_R} vs input = {self.in_features}')
        row_L = self.out_features // row_R
        if self.out_features % row_R != 0:
            raise ValueError(f'Output mismatches, row_L = {row_L} * row_R = {row_R} vs output = {self.out_features}')

        sama_adapter = SharedMonarch(share_factor_L=share_factor_L, share_factor_R=share_factor_R,
                                     block_size_rR=row_R, block_size_cR=col_R,
                                     block_size_cL=col_L, block_size_rL=row_L)
        # Keep the new adapter on the base layer's device so forward/merge do not mix
        # devices. Meta tensors hold no data, so there is nothing to follow there.
        base_weight = self.base_layer.weight
        if base_weight.device.type != "meta":
            sama_adapter = sama_adapter.to(base_weight.device)

        self._sama_layer[adapter_name] = sama_adapter
        self.scaling[adapter_name] = scaling
        self._adapter_config[adapter_name] = {"scaling": scaling,
                                              "share_factor_L": share_factor_L, "share_factor_R": share_factor_R,
                                              "row_R": row_R, "col_L": col_L,
                                              "drop_out": drop_out}
        self._sama_dropout[adapter_name] = nn.Dropout(p=drop_out) if drop_out > 0.0 else nn.Identity()

    def set_active_adapters(self, adapters: Optional[list[str]]):
        """Restrict the active adapters to ``adapters``; ``None`` re-activates all of them."""
        if adapters is None:
            if hasattr(self, "_active_adapters"):
                delattr(self, "_active_adapters")
        else:
            self._active_adapters = adapters
| |
| |
class Linear(nn.Module, SamaLayer):
    """
    ``nn.Linear`` wrapped with SaMA adapters for parameter-efficient fine-tuning.

    Unmerged forward computes
    ``base_layer(x) + sum_a scaling[a] * sama[a](dropout[a](x))`` over the active
    adapters ``a``. Merging instead folds ``scaling[a] * sama[a].get_delta_weight()``
    into the base weight.
    """

    def __init__(self,
                 base_layer: nn.Linear,
                 adapter_name: str,
                 share_factor_L: int,
                 share_factor_R: int,
                 scaling: float,
                 col_L: int,
                 row_R: int,
                 drop_out: float,
                 **kwargs):
        super().__init__()
        SamaLayer.__init__(self, base_layer=base_layer, **kwargs)
        self._active_adapter = adapter_name

        self.update_layer(
            adapter_name=adapter_name,
            share_factor_L=share_factor_L,
            share_factor_R=share_factor_R,
            scaling=scaling,
            col_L=col_L,
            row_R=row_R,
            drop_out=drop_out,
            **kwargs,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Fold adapters into the base weight: ``W <- W + scaling * delta``.

        Args:
            safe_merge: If True, check that the merged weight is finite before
                writing it, and raise instead of corrupting the base layer.
            adapter_names: Adapters to merge. The list is filtered by
                ``check_adapters_to_merge``; in PEFT, ``None`` defaults to the active
                adapters and already-merged adapters are dropped.

        Raises:
            ValueError: If ``safe_merge`` is set and the merged weight contains NaN/inf.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return

        weight = self.get_base_layer().weight
        orig_dtype = weight.dtype

        # Weight surgery must not build an autograd graph through the adapter parameters.
        with torch.no_grad():
            for active_adapter in adapter_names:
                if active_adapter not in self._available_adapters:
                    continue

                delta = self._sama_layer[active_adapter].get_delta_weight().to(weight.device)
                merged_weight = (weight.data + self.scaling[active_adapter] * delta).contiguous().to(orig_dtype)

                # merged_weight is a fresh tensor, so checking it before copy_ keeps W intact on failure.
                if safe_merge and not torch.isfinite(merged_weight).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                weight.data.copy_(merged_weight)
                self.merged_adapters.append(active_adapter)

    def unmerge(self):
        """
        Undo merges in LIFO order by subtracting each merged adapter's scaled delta weight.
        """
        weight = self.get_base_layer().weight
        orig_dtype = weight.dtype

        with torch.no_grad():
            while self.merged_adapters:
                active_adapter = self.merged_adapters.pop()
                if active_adapter not in self._available_adapters:
                    continue

                delta = self._sama_layer[active_adapter].get_delta_weight().to(weight.device)
                restored = (weight.data - self.scaling[active_adapter] * delta).contiguous().to(orig_dtype)
                weight.data.copy_(restored)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Base-layer output plus the scaled, dropout-regularised output of each active adapter."""
        x_dtype = x.dtype
        base_layer = self.get_base_layer()

        if self.disable_adapters:
            # A merged adapter would still affect the base weight, so remove it first.
            if self.merged:
                self.unmerge()
            return base_layer(x, *args, **kwargs).to(x_dtype)

        if self.merged:
            # Adapters are already folded into the base weight.
            return base_layer(x, *args, **kwargs).to(x_dtype)

        output = base_layer(x, *args, **kwargs)
        for active_adapter in self.active_adapters:
            if active_adapter not in self._sama_layer:
                continue
            sama_layer = self._sama_layer[active_adapter]
            dropout = self._sama_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            # Cast from the original x each time, so adapters with different dtypes
            # do not chain lossy casts.
            adapter_input = self._cast_input_dtype(x, sama_layer.sama_L.weights.dtype)
            output = output + scaling * sama_layer(dropout(adapter_input)).to(output.dtype)

        return output.to(x_dtype)

    def __repr__(self):
        return f"sama.{super().__repr__()}"
|
|