import torch
import torch.nn as nn
from typing import Optional, Set
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
import einops


class SingleSharedBlockDiag(nn.Module):
    """Block-diagonal linear map whose diagonal blocks are shared in groups.

    Only ``num_unique_blocks`` weight blocks are stored; each one is reused
    for ``share_factor`` consecutive diagonal positions, giving
    ``num_unique_blocks * share_factor`` diagonal blocks in total.
    """

    def __init__(
        self,
        num_unique_blocks,
        share_factor,
        block_size_r: Optional[int] = None,
        block_size_c: Optional[int] = None,
        init_std=0.0,
    ):
        """Initialize the layer with shared diagonal weights.

        Args:
            num_unique_blocks (int): Number of unique weight blocks (groups).
            share_factor (int): Number of times each unique block is
                repeated/shared.
            block_size_r (int, optional): Output size of each block (row
                dimension). Defaults to ``block_size_c`` when omitted.
            block_size_c (int, optional): Input size of each block (column
                dimension). Defaults to ``block_size_r`` when omitted.
            init_std (float): Standard deviation of the normal init; a value
                ``<= 0`` initializes the weights to zero.

        Raises:
            ValueError: If both ``block_size_r`` and ``block_size_c`` are None.
        """
        super().__init__()
        if block_size_r is None and block_size_c is None:
            raise ValueError("Block size r,c are not valid")
        if block_size_r is None:
            block_size_r = block_size_c
        elif block_size_c is None:
            block_size_c = block_size_r

        self.num_diag_blocks = num_unique_blocks * share_factor
        self.block_size_r = block_size_r
        self.block_size_c = block_size_c
        self.share_factor = share_factor
        self.num_unique_blocks = num_unique_blocks

        # Only the unique diagonal blocks are stored:
        # (num_unique_blocks, block_size_r, block_size_c).
        self.weights = nn.Parameter(
            torch.empty(self.num_unique_blocks, block_size_r, block_size_c)
        )
        self.init_std = init_std
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weights: N(0, init_std) if init_std > 0, else zeros."""
        with torch.no_grad():
            if self.init_std > 0:
                nn.init.normal_(self.weights, mean=0, std=self.init_std)
            else:
                nn.init.constant_(self.weights, 0)

    def forward(self, x):
        """Apply the shared block-diagonal map to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension equals
                ``num_unique_blocks * share_factor * block_size_c``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``num_unique_blocks * share_factor * block_size_r``.
        """
        # Split the feature axis into (unique, factor, size_c). ``reshape``
        # (rather than ``view``) also accepts non-contiguous inputs.
        x = x.reshape(
            *x.shape[:-1],
            self.num_unique_blocks,
            self.share_factor,
            self.block_size_c,
        )
        # Every block in group ``u`` is multiplied by the same weights[u].
        output = torch.einsum('...ufc, urc -> ...ufr', x, self.weights)
        # Merge (unique, factor, size_r) back into a single feature axis.
        return output.reshape(*output.shape[:-3], -1)


class SharedMonarch(nn.Module):
    """Monarch-style factorization ``Pt @ L @ P @ R`` with shared blocks.

    Input dimension  N = block_size_cL * block_size_cR
    Output dimension M = block_size_rL * block_size_rR

    R (right): block-diagonal with ``block_size_cL`` blocks, each mapping
        ``block_size_cR -> block_size_rR``.
    P (permute): transposes the ``(block_size_cL, block_size_rR)`` layout.
    L (left): block-diagonal with ``block_size_rR`` blocks, each mapping
        ``block_size_cL -> block_size_rL``.
    Pt: transposes the ``(block_size_rR, block_size_rL)`` layout.
    """

    def __init__(self, share_factor_L, share_factor_R, block_size_rR, block_size_cR,
                 block_size_rL, block_size_cL):
        """Build the two shared block-diagonal factors.

        Args:
            share_factor_L (int): Sharing factor of the left factor; must
                divide ``block_size_rR``.
            share_factor_R (int): Sharing factor of the right factor; must
                divide ``block_size_cL``.
            block_size_rR (int): Output size of each right block.
            block_size_cR (int): Input size of each right block.
            block_size_rL (int or None): Output size of each left block;
                ``None`` falls back to ``block_size_rR``.
            block_size_cL (int or None): Input size of each left block;
                ``None`` falls back to ``block_size_cR``.

        Raises:
            ValueError: If a share factor does not divide its block count.
        """
        super().__init__()
        if block_size_rL is None:
            block_size_rL = block_size_rR
        if block_size_cL is None:
            block_size_cL = block_size_cR

        self.block_size_rL = block_size_rL
        self.block_size_cL = block_size_cL
        self.block_size_rR = block_size_rR
        self.block_size_cR = block_size_cR

        if block_size_cL % share_factor_R != 0:
            raise ValueError(f"block_size_cL ({block_size_cL}) must be divisible by share_factor_R ({share_factor_R})")
        num_unique_blocksR = block_size_cL // share_factor_R

        if block_size_rR % share_factor_L != 0:
            raise ValueError(f"block_size_rR ({block_size_rR}) must be divisible by share_factor_L ({share_factor_L})")
        num_unique_blocksL = block_size_rR // share_factor_L

        self.share_factor_L = share_factor_L
        self.share_factor_R = share_factor_R

        # L starts at zero, so the full product (the delta weight) is zero
        # at initialization; R gets a small random init.
        self.sama_L = SingleSharedBlockDiag(
            num_unique_blocks=num_unique_blocksL,
            share_factor=share_factor_L,
            block_size_r=block_size_rL,
            block_size_c=block_size_cL,
            init_std=0.0,
        )
        self.sama_R = SingleSharedBlockDiag(
            num_unique_blocks=num_unique_blocksR,
            share_factor=share_factor_R,
            block_size_r=block_size_rR,
            block_size_c=block_size_cR,
            init_std=1e-3,
        )

    def forward(self, x):
        """Apply ``Pt L P R`` to the last dimension of ``x``."""
        # Right factor: (cL blocks) x (cR -> rR).
        out_r = self.sama_R(x)
        # Permutation P: (cL, rR) layout -> (rR, cL) layout.
        out_P = einops.rearrange(
            out_r, '... (h w) -> ... (w h)', h=self.block_size_cL, w=self.block_size_rR
        )
        # Left factor: (rR blocks) x (cL -> rL).
        out_l = self.sama_L(out_P)
        # Permutation Pt: (rR, rL) layout -> (rL, rR) layout.
        out_Pt = einops.rearrange(
            out_l, '... (w h) -> ... (h w)', h=self.block_size_rL, w=self.block_size_rR
        )
        return out_Pt

    @staticmethod
    def _dense_block_diag(weights: torch.Tensor, share_factor: int) -> torch.Tensor:
        """Expand shared blocks ``(unique, r, c)`` into a dense block-diagonal matrix."""
        # Each unique block is repeated ``share_factor`` times consecutively,
        # matching the (unique, factor) split used in the forward pass.
        blocks = weights.repeat_interleave(share_factor, dim=0)
        return torch.block_diag(*blocks.unbind(0))

    def get_delta_weight2(self):
        """Compute the delta weight using explicit permutation matrices.

        Reference implementation: it materializes identity-sized permutation
        matrices, so ``get_delta_weight`` is the cheaper way to obtain the
        same product.

        Returns:
            Delta weight matrix of shape (dout, din).
        """
        sama_R = self.sama_R.weights
        device, dtype = sama_R.device, sama_R.dtype

        # Right factor R.
        R_dense = self._dense_block_diag(sama_R, self.share_factor_R)

        # Permutation of the (cL, rR) layout into (rR, cL).
        intermediate_dim = self.block_size_cL * self.block_size_rR
        idx = torch.arange(intermediate_dim, device=device)
        perm_indices = idx.view(self.block_size_cL, self.block_size_rR).t().reshape(-1)
        eye_mid = torch.eye(intermediate_dim, device=device, dtype=dtype)
        P_mid = eye_mid[perm_indices]

        # Left factor L.
        L_dense = self._dense_block_diag(self.sama_L.weights, self.share_factor_L)

        # Permutation of the (rR, rL) layout into (rL, rR).
        dim_out = self.block_size_rL * self.block_size_rR
        idx_out = torch.arange(dim_out, device=device)
        perm_indices_out = idx_out.view(self.block_size_rR, self.block_size_rL).t().reshape(-1)
        eye_out = torch.eye(dim_out, device=device, dtype=dtype)
        P_out = eye_out[perm_indices_out]

        W_final = P_out @ (L_dense @ (P_mid @ R_dense))
        return W_final

    def get_delta_weight(self):
        """Compute the delta weight without building permutation matrices.

        The two permutations are applied as row gathers instead of matrix
        products with permuted identities.

        Returns:
            Delta weight matrix of shape (dout, din).
        """
        device = self.sama_R.weights.device

        # 1. Dense R: (cL * rR, cL * cR).
        R_dense = self._dense_block_diag(self.sama_R.weights, self.share_factor_R)

        # 2. Dense L: (rR * rL, rR * cL).
        L_dense = self._dense_block_diag(self.sama_L.weights, self.share_factor_L)

        # 3. P @ R_dense: gather rows so the (cL, rR) layout becomes (rR, cL).
        dim_mid = self.block_size_cL * self.block_size_rR
        idx_mid = torch.arange(dim_mid, device=device)
        perm_idx_mid = idx_mid.view(self.block_size_cL, self.block_size_rR).t().reshape(-1)
        R_permuted = R_dense[perm_idx_mid, :]

        # 4. L @ (P @ R).
        W_mid = L_dense @ R_permuted

        # 5. Pt @ W_mid: gather rows so the (rR, rL) layout becomes (rL, rR).
        dim_out = self.block_size_rL * self.block_size_rR
        idx_out = torch.arange(dim_out, device=device)
        perm_idx_out = idx_out.view(self.block_size_rR, self.block_size_rL).t().reshape(-1)
        W_final = W_mid[perm_idx_out, :]
        return W_final


class SamaLayer(BaseTunerLayer):
    """
    Adapter-like wrapper that attaches SharedMonarch (SaMA) modules to a base
    linear layer.
    """

    adapter_layer_names: tuple[str, ...] = ("_sama_layer",)
    other_param_names: tuple[str, ...] = (
        "share_factor_L",
        "share_factor_R",
        "scaling",
        "col_L",
        "row_R",
        "drop_out",
    )

    def __init__(self, base_layer: nn.Module, **kwargs):
        """Set up the adapter containers around ``base_layer``.

        Raises:
            NotImplementedError: If ``base_layer`` is not an ``nn.Linear``.
        """
        super().__init__()
        if not isinstance(base_layer, nn.Linear):
            raise NotImplementedError("SamaLayer only supports nn.Linear base layers for now.")

        # Base layer and per-adapter containers.
        self.base_layer = base_layer
        self._sama_layer = nn.ModuleDict()    # adapter_name -> SharedMonarch
        self._sama_dropout = nn.ModuleDict()  # adapter_name -> Dropout / Identity
        self.scaling = {}                     # adapter_name -> scaling factor
        self._adapter_config = {}             # adapter_name -> hyper-parameters

        # Flags.
        self._disable_adapters = False
        self.merged_adapters: list[str] = []
        self._cast_input_dtype_enabled = True
        self.kwargs = kwargs

        self.in_features = base_layer.in_features
        self.out_features = base_layer.out_features

    @property
    def _available_adapters(self) -> set[str]:
        """Names of all adapters registered on this layer."""
        return set(self._sama_layer.keys())

    @property
    def disable_adapters(self) -> bool:
        """Whether adapters are currently disabled."""
        return self._disable_adapters

    @property
    def merged(self) -> bool:
        """Whether at least one adapter is merged into the base weights."""
        return bool(self.merged_adapters)

    @property
    def active_adapters(self) -> list[str]:
        """Adapters applied in ``forward``.

        If ``set_active_adapters`` stored an explicit list, that list is
        returned; otherwise every registered adapter is active.
        """
        return getattr(self, "_active_adapters", list(self._sama_layer.keys()))

    def get_base_layer(self) -> nn.Module:
        """Return the wrapped base layer."""
        return self.base_layer

    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Cast ``x`` to ``dtype`` unless input casting has been disabled."""
        if not self._cast_input_dtype_enabled:
            return x
        return x.to(dtype)

    def update_layer(
        self,
        adapter_name: str,
        share_factor_L: int,
        share_factor_R: int,
        scaling: float,
        col_L: int,
        row_R: int,
        drop_out: float,
        **kwargs,
    ):
        """Add or replace the SaMA adapter named ``adapter_name``.

        Args:
            adapter_name: Key under which the adapter is stored.
            share_factor_L: Sharing factor of the left factor.
            share_factor_R: Sharing factor of the right factor.
            scaling: Multiplier applied to the adapter output.
            col_L: Input size of each left block; must divide ``in_features``.
            row_R: Output size of each right block; must divide
                ``out_features``.
            drop_out: Dropout probability for the adapter input; ``<= 0``
                installs an identity instead.

        Raises:
            ValueError: If ``col_L``/``row_R`` are not positive or do not
                divide the layer's input/output size.
        """
        if col_L <= 0 or row_R <= 0:
            raise ValueError(f"col_L and row_R must be positive, got col_L={col_L}, row_R={row_R}")

        col_R = self.in_features // col_L
        if self.in_features % col_L != 0:
            raise ValueError(f'Input mismatches, col_L = {col_L} * col_R = {col_R} vs input = {self.in_features}')

        row_L = self.out_features // row_R
        if self.out_features % row_R != 0:
            raise ValueError(f'Output mismatches, row_L = {row_L} * row_R = {row_R} vs output = {self.out_features}')

        sama_adapter = SharedMonarch(
            share_factor_L=share_factor_L,
            share_factor_R=share_factor_R,
            block_size_rR=row_R,
            block_size_cR=col_R,
            block_size_cL=col_L,
            block_size_rL=row_L,
        )
        self._sama_layer[adapter_name] = sama_adapter
        self.scaling[adapter_name] = scaling  # used as-is, not divided by a rank
        self._adapter_config[adapter_name] = {
            "scaling": scaling,
            "share_factor_L": share_factor_L,
            "share_factor_R": share_factor_R,
            "row_R": row_R,
            "col_L": col_L,
            "drop_out": drop_out,
        }

        if drop_out > 0.0:
            sama_dropout_layer = nn.Dropout(p=drop_out)
        else:
            sama_dropout_layer = nn.Identity()
        self._sama_dropout.update(nn.ModuleDict({adapter_name: sama_dropout_layer}))

    def set_active_adapters(self, adapters: Optional[list[str]]):
        """Set the active adapters explicitly; ``None`` restores the default (all)."""
        if adapters is None:
            if hasattr(self, "_active_adapters"):
                delattr(self, "_active_adapters")
        else:
            self._active_adapters = adapters


class Linear(nn.Module, SamaLayer):
    """
    A linear layer with SaMA layer for parameter-efficient fine-tuning.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        adapter_name: str,
        share_factor_L: int,
        share_factor_R: int,
        scaling: float,
        col_L: int,
        row_R: int,
        drop_out: float,
        **kwargs,
    ):
        """Wrap ``base_layer`` and register the first adapter ``adapter_name``."""
        # nn.Module must be initialized first so that SamaLayer.__init__ can
        # register its ModuleDicts.
        super().__init__()
        SamaLayer.__init__(self, base_layer=base_layer, **kwargs)

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name=adapter_name,
            share_factor_L=share_factor_L,
            share_factor_R=share_factor_R,
            scaling=scaling,
            col_L=col_L,
            row_R=row_R,
            drop_out=drop_out,
            **kwargs,
        )

    def _shifted_base_weight(self, adapter_name: str, sign: float) -> torch.Tensor:
        """Return ``W + sign * scaling * delta`` in the base weight's dtype."""
        base_weight = self.get_base_layer().weight.data
        # Merging is a pure weight update; no autograd graph is needed.
        with torch.no_grad():
            delta = self._sama_layer[adapter_name].get_delta_weight()
            shifted = base_weight + sign * self.scaling[adapter_name] * delta
            return shifted.contiguous().to(base_weight.dtype)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
        """
        Merge the adapter effect into the base layer weights:
        W_merged = W + scaling * (Pt L P R)

        Args:
            safe_merge: If True, verify the merged weight contains only
                finite values before writing it to the base layer.
            adapter_names: Adapters to merge; resolved by
                ``check_adapters_to_merge``.

        Raises:
            ValueError: If ``safe_merge`` is set and the merged weight
                contains NaN/Inf values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return

        base_layer = self.get_base_layer()
        for active_adapter in adapter_names:
            if active_adapter not in self._available_adapters:
                continue

            merged_weight = self._shifted_base_weight(active_adapter, sign=1.0)
            if safe_merge and not torch.isfinite(merged_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )
            base_layer.weight.data.copy_(merged_weight)

            # Record the merge so that unmerge can subtract the same delta.
            self.merged_adapters.append(active_adapter)

    def unmerge(self):
        """
        Reverse merges in LIFO order by subtracting each merged adapter's
        scaled delta weight from the base layer weights.
        """
        base_layer = self.get_base_layer()
        while self.merged_adapters:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self._available_adapters:
                continue
            restored_weight = self._shifted_base_weight(active_adapter, sign=-1.0)
            base_layer.weight.data.copy_(restored_weight)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the base layer and add the scaled output of each active adapter.

        When adapters are disabled, any merged adapters are unmerged first so
        the base layer produces its original output. When adapters are merged,
        the base layer alone is evaluated. The result is cast back to the
        dtype of ``x``.
        """
        x_dtype = x.dtype
        base_layer = self.get_base_layer()

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return base_layer(x, *args, **kwargs).to(x_dtype)

        if self.merged:
            # Adapter contribution already lives in the base weights.
            return base_layer(x, *args, **kwargs).to(x_dtype)

        output = base_layer(x, *args, **kwargs)
        for active_adapter in self.active_adapters:
            if active_adapter not in self._sama_layer:
                continue
            sama_layer = self._sama_layer[active_adapter]
            scaling = self.scaling[active_adapter]

            adapter_input = self._cast_input_dtype(x, sama_layer.sama_L.weights.dtype)
            # Apply the dropout configured in update_layer to the adapter
            # branch only (identity when drop_out <= 0 or in eval mode).
            if active_adapter in self._sama_dropout:
                adapter_input = self._sama_dropout[active_adapter](adapter_input)

            output = output + scaling * sama_layer(adapter_input).to(output.dtype)

        return output.to(x_dtype)

    def __repr__(self):
        return f"sama.{super().__repr__()}"