"""Sparse adaptation for PAWN. Perturbs a random subset of frozen weight elements. Each selected weight gets an additive trainable delta (zero-initialized, so the model starts identical to the frozen backbone). Related to sparse fine-tuning ideas from the lottery ticket literature (`Frankle & Carbin, 2018 `_, ICLR 2019). output = linear(x) where W = W_frozen + delta * mask The binary mask is fixed at init. Effective trainable parameters equal the number of True entries in the mask, controlled by the density parameter. """ import torch import torch.nn as nn import torch.nn.functional as F from pawn.config import CLMConfig from pawn.model import PAWNCLM, Attention class SparseLinear(nn.Module): """Frozen linear with a sparse additive delta. output = F.linear(x, W_frozen + delta * mask, bias) """ mask: torch.Tensor def __init__(self, frozen_linear: nn.Linear, mask: torch.Tensor): super().__init__() self.frozen = frozen_linear self.delta = nn.Parameter(torch.zeros_like(frozen_linear.weight)) self.register_buffer("mask", mask) @property def n_active(self) -> int: return int(self.mask.sum().item()) def forward(self, x: torch.Tensor) -> torch.Tensor: w = self.frozen.weight + self.delta * self.mask return F.linear(x, w, self.frozen.bias) def _random_mask(shape: tuple[int, ...], density: float, generator: torch.Generator | None = None) -> torch.Tensor: """Create a random binary mask with the given density (fraction of True).""" return torch.rand(shape, generator=generator) < density _ATTN_TARGETS = ("wq", "wk", "wv", "wo") _FFN_TARGETS = ("w_gate", "w_up", "w_down") class SparseCLM(nn.Module): """Frozen PAWN backbone with sparse weight perturbation. A random subset of weight elements in attention (and optionally FFN) projections are made trainable. Only the masked delta values are effectively learned. """ def __init__( self, backbone: PAWNCLM, density: float = 0.01, attn_targets: tuple[str, ...] = _ATTN_TARGETS, adapt_ffn: bool = False, layers: tuple[int, ...] 


class SparseCLM(nn.Module):
    """Frozen PAWN backbone with sparse weight perturbation.

    A random subset of weight elements in attention (and optionally FFN)
    projections is made trainable. Only the masked delta values are
    effectively learned.
    """

    def __init__(
        self,
        backbone: PAWNCLM,
        density: float = 0.01,
        attn_targets: tuple[str, ...] = _ATTN_TARGETS,
        adapt_ffn: bool = False,
        layers: tuple[int, ...] | None = None,
        seed: int = 42,
    ):
        super().__init__()
        self.backbone = backbone
        self.density = density
        self.adapt_ffn = adapt_ffn
        self.attn_targets = tuple(attn_targets)
        n_layers = len(backbone.layers)
        self.adapted_layers = set(layers if layers is not None else range(n_layers))

        # Freeze the entire backbone; only the deltas created below train.
        for p in backbone.parameters():
            p.requires_grad = False

        gen = torch.Generator().manual_seed(seed)

        # Inject sparse adapters into the selected layers.
        for layer_idx in range(n_layers):
            if layer_idx not in self.adapted_layers:
                continue
            block = backbone.get_block(layer_idx)
            attn: Attention = block.attn
            for proj_name in self.attn_targets:
                original = getattr(attn, proj_name)
                mask = _random_mask(original.weight.shape, density, gen)
                setattr(attn, proj_name, SparseLinear(original, mask))
            if adapt_ffn:
                ffn = block.ffn
                for proj_name in _FFN_TARGETS:
                    original = getattr(ffn, proj_name)
                    mask = _random_mask(original.weight.shape, density, gen)
                    setattr(ffn, proj_name, SparseLinear(original, mask))

    @property
    def cfg(self) -> CLMConfig:
        return self.backbone.cfg

    def forward_hidden(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run backbone layers (with sparse deltas), return normed hidden states."""
        bb = self.backbone
        x = bb.embed(input_ids)
        T = input_ids.shape[1]
        if attention_mask is not None:
            causal = bb.causal_mask[:T, :T]
            padding = attention_mask.unsqueeze(1).unsqueeze(2)
            mask = causal.unsqueeze(0) & padding
        else:
            mask = None
        rope_cos = bb.rope_cos[:, :, :T, :]
        rope_sin = bb.rope_sin[:, :, :T, :]
        for layer in bb.layers:
            x = layer(x, rope_cos, rope_sin, mask)
        return bb.final_norm(x)

    def project_head(self, x: torch.Tensor) -> torch.Tensor:
        """Project hidden states through lm_head."""
        return self.backbone.lm_head(x)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Full forward pass. Returns logits (B, T, V)."""
        return self.project_head(self.forward_hidden(input_ids, attention_mask))

    def forward_generate(
        self,
        input_ids: torch.Tensor,
        kv_cache: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
        """Forward with KV-cache for autoregressive generation."""
        bb = self.backbone
        x = bb.embed(input_ids)
        T_new = input_ids.shape[1]
        if kv_cache is not None:
            # Offset RoPE by the number of already-cached positions.
            T_cached = kv_cache[0][0].shape[2]
            rope_cos = bb.rope_cos[:, :, T_cached:T_cached + T_new, :]
            rope_sin = bb.rope_sin[:, :, T_cached:T_cached + T_new, :]
        else:
            rope_cos = bb.rope_cos[:, :, :T_new, :]
            rope_sin = bb.rope_sin[:, :, :T_new, :]
        new_kv_cache = []
        for i in range(len(bb.layers)):
            layer_cache = kv_cache[i] if kv_cache is not None else None
            x, new_cache = bb.get_block(i).forward_kv(x, rope_cos, rope_sin, layer_cache)
            new_kv_cache.append(new_cache)
        # Only the last position is needed for next-token logits.
        x = bb.final_norm(x[:, -1:, :])
        logits = self.project_head(x)
        return logits, new_kv_cache

    # --- Parameter management ---

    def sparse_parameters(self) -> list[nn.Parameter]:
        """Return only the trainable sparse delta parameters."""
        return [p for p in self.parameters() if p.requires_grad]

    def n_active_params(self) -> int:
        """Count of actually active (masked-in) parameters."""
        total = 0
        for layer_idx in range(len(self.backbone.layers)):
            block = self.backbone.get_block(layer_idx)
            for proj_name in self.attn_targets:
                module = getattr(block.attn, proj_name)
                if isinstance(module, SparseLinear):
                    total += module.n_active
            if self.adapt_ffn:
                for proj_name in _FFN_TARGETS:
                    module = getattr(block.ffn, proj_name)
                    if isinstance(module, SparseLinear):
                        total += module.n_active
        return total

    def sparse_state_dict(self) -> dict[str, torch.Tensor]:
        """Extract the sparse delta weights for saving."""
        return {
            name: param.data.clone()
            for name, param in self.named_parameters()
            if param.requires_grad
        }

    def load_sparse_state_dict(self, state: dict[str, torch.Tensor]):
        """Load sparse delta weights. Masks are regenerated from the
        constructor seed, so this assumes the same seed, density, and
        targets as when the state was saved."""
        own = self.state_dict()
        for k, v in state.items():
            if k in own:
                own[k] = v
        self.load_state_dict(own, strict=False)

    def sparse_weight_report(self) -> dict[str, float]:
        """Per-layer sparse delta norms for monitoring."""
        report = {}
        for layer_idx in range(len(self.backbone.layers)):
            block = self.backbone.get_block(layer_idx)
            for proj_name in self.attn_targets:
                module = getattr(block.attn, proj_name)
                if isinstance(module, SparseLinear):
                    masked_delta = module.delta.data * module.mask
                    report[f"sparse/layer{layer_idx}.{proj_name}.delta"] = masked_delta.norm().item()
            if self.adapt_ffn:
                for proj_name in _FFN_TARGETS:
                    module = getattr(block.ffn, proj_name)
                    if isinstance(module, SparseLinear):
                        masked_delta = module.delta.data * module.mask
                        report[f"sparse/layer{layer_idx}.{proj_name}.delta"] = masked_delta.norm().item()
        return report
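

# --- Usage sketch (illustrative; a minimal sketch, not a test) ---
# Shows the intended flow: wrap a frozen backbone, train only the sparse
# deltas, then save/restore them. Assumes CLMConfig() has usable defaults,
# PAWNCLM(cfg) is the backbone constructor, and cfg.vocab_size exists; the
# batch shape and optimizer hyperparameters are arbitrary example values.
if __name__ == "__main__":
    cfg = CLMConfig()
    model = SparseCLM(PAWNCLM(cfg), density=0.01, adapt_ffn=True, seed=0)
    print(f"active sparse params: {model.n_active_params():,}")

    # Only the deltas are handed to the optimizer; the backbone stays frozen.
    opt = torch.optim.AdamW(model.sparse_parameters(), lr=1e-3)
    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
    logits = model(input_ids)
    # Standard next-token loss: predict position t+1 from position t.
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )
    loss.backward()
    opt.step()

    # Deltas round-trip independently of the backbone weights.
    state = model.sparse_state_dict()
    model.load_sparse_state_dict(state)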