| """ |
| dfc.py — Dedicated Feature CrossCoder (DFC) model. |
| |
| Feature layout in dict_size |
| ─────────────────────────── |
| ┌─────────────────────┬─────────────────────┬──────────────────────────┐ |
| │  A-exclusive (n_a)  │  B-exclusive (n_b)  │    Shared (n_shared)     │ |
| └─────────────────────┴─────────────────────┴──────────────────────────┘ |
| idx: 0 ────────────── a_end ─────────────── b_end ─────────────── dict_size |
| |
| Constraints (enforced by gradient masking + _apply_masks every step) |
| ───────────────────────────────────────────────────────────────────── |
| • Model A cannot encode/decode B-exclusive features |
| • Model B cannot encode/decode A-exclusive features |
| • Shared features are accessible to both |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| from bitsandbytes import features |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class DFCCrossCoder(nn.Module):
    """Top-k crosscoder over two models with dedicated (exclusive) features.

    The ``dict_size`` features are split into three contiguous index ranges::

        [0, a_end)          A-exclusive  (n_a features)
        [a_end, b_end)      B-exclusive  (n_b features)
        [b_end, dict_size)  shared       (n_shared features)

    ``enc_mask`` (2, dict_size) and ``dec_mask`` (dict_size, 2) are 0 where a
    model must not read/write a feature and 1 elsewhere.  The masks are
    applied inside ``encode``/``decode`` (so forbidden weights receive no
    gradient) and by ``_apply_masks`` (which zeroes the stored weights).
    """

    def __init__(
        self,
        activation_dim: int,
        dict_size: int,
        k: int,
        model_a_exclusive_pct: float = 0.05,
        model_b_exclusive_pct: float = 0.05,
    ):
        """Build the parameters and the partition masks.

        Args:
            activation_dim: width ``d`` of each model's activation vector.
            dict_size: total number of dictionary features.
            k: number of features kept per sample by the top-k in ``encode``.
            model_a_exclusive_pct: fraction of ``dict_size`` reserved for
                model A (truncated to an integer count).
            model_b_exclusive_pct: same, for model B.

        Raises:
            ValueError: if the exclusive counts are negative or exceed
                ``dict_size``, or if ``k`` is outside ``[0, dict_size]``.
        """
        super().__init__()

        n_a = int(dict_size * model_a_exclusive_pct)
        n_b = int(dict_size * model_b_exclusive_pct)
        # Validate the integer counts (not the raw fractions) so that a
        # partition with n_shared == 0 is still accepted.
        if n_a < 0 or n_b < 0 or n_a + n_b > dict_size:
            raise ValueError(
                f"invalid partition: A-excl={n_a} B-excl={n_b} "
                f"does not fit in dict_size={dict_size}"
            )
        if not 0 <= k <= dict_size:
            raise ValueError(f"k={k} must be in [0, dict_size={dict_size}]")

        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.k = k

        self.n_a = n_a
        self.n_b = n_b
        self.n_shared = dict_size - n_a - n_b
        self.a_end = n_a
        self.b_end = n_a + n_b

        print(
            f"[DFC] dict={dict_size} k={k} | "
            f"A-excl={self.n_a} B-excl={self.n_b} shared={self.n_shared}"
        )

        # Encoder: one (d, dict_size) matrix per model, shared bias.
        self.W_enc = nn.Parameter(
            torch.randn(2, activation_dim, dict_size) / (activation_dim ** 0.5)
        )
        self.b_enc = nn.Parameter(torch.zeros(dict_size))

        # Decoder: one (2, d) output slab per feature, per-model bias.
        self.W_dec = nn.Parameter(
            torch.randn(dict_size, 2, activation_dim) / (dict_size ** 0.5)
        )
        self.b_dec = nn.Parameter(torch.zeros(2, activation_dim))

        # Row 0 = model A, row 1 = model B; 0 marks a forbidden feature.
        enc_mask = torch.ones(2, dict_size)
        enc_mask[1, : self.a_end] = 0
        enc_mask[0, self.a_end : self.b_end] = 0
        self.register_buffer("enc_mask", enc_mask)

        # Same constraint for the decoder, laid out as (feature, model).
        dec_mask = torch.ones(dict_size, 2)
        dec_mask[: self.a_end, 1] = 0
        dec_mask[self.a_end : self.b_end, 0] = 0
        self.register_buffer("dec_mask", dec_mask)

        self._apply_masks()

    @torch.no_grad()
    def _apply_masks(self) -> None:
        """Zero forbidden weights in place. Call after every optimiser step."""
        # (2, 1, F) broadcasts over (2, d, F).
        self.W_enc.mul_(self.enc_mask.unsqueeze(1))
        # (F, 2, 1) broadcasts over (F, 2, d).
        self.W_dec.mul_(self.dec_mask.unsqueeze(-1))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` (B, 2, d) into sparse top-k features (B, dict_size)."""
        # Multiplying by the mask here also blocks gradients to forbidden weights.
        masked_w = self.W_enc * self.enc_mask.unsqueeze(1)
        pre = torch.einsum("bmd,mdf->bf", x, masked_w) + self.b_enc
        pre = F.relu(pre)
        topk_vals, topk_idx = torch.topk(pre, self.k, dim=-1)
        sparse = torch.zeros_like(pre)
        sparse.scatter_(-1, topk_idx, topk_vals)
        return sparse

    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """Decode features (B, dict_size) into reconstructions (B, 2, d)."""
        masked_w = self.W_dec * self.dec_mask.unsqueeze(-1)
        return torch.einsum("bf,fmd->bmd", features, masked_w) + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, features)`` for ``x`` of shape (B, 2, d)."""
        features = self.encode(x)
        recon = self.decode(features)
        return recon, features

    @staticmethod
    def _mean_abs(t: torch.Tensor) -> torch.Tensor:
        """Mean absolute value of ``t``; 0 (not NaN) when ``t`` is empty."""
        if t.numel() == 0:
            return t.new_zeros(())
        return t.abs().mean()

    def loss(
        self,
        x: torch.Tensor,
        sparsity_coef: float = 1e-3,
        exclusive_sparsity_coef: float = 1e-3,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """MSE + weighted L1 sparsity.

        Returns:
            ``(total, mse, l1_shared, l1_exclusive)`` where
            ``total = mse + exclusive_sparsity_coef * l1_exclusive
            + sparsity_coef * l1_shared``.

        ``l1_exclusive`` is the average of two means: over the features
        model A can use (A-exclusive + shared) and over those model B can use
        (B-exclusive + shared).
        NOTE(review): shared features therefore also contribute to the
        "exclusive" term — confirm this is intended.
        """
        recon, features = self.forward(x)
        mse = F.mse_loss(recon, x)

        a_excl = features[:, : self.a_end]
        b_excl = features[:, self.a_end : self.b_end]
        shared = features[:, self.b_end :]

        feats_a = torch.cat([a_excl, shared], dim=-1)
        feats_b = torch.cat([b_excl, shared], dim=-1)

        # _mean_abs keeps the loss finite when a partition has no features.
        l1_shared = self._mean_abs(shared)
        l1_exclusive = (self._mean_abs(feats_a) + self._mean_abs(feats_b)) / 2
        total = mse + exclusive_sparsity_coef * l1_exclusive + sparsity_coef * l1_shared

        return total, mse, l1_shared, l1_exclusive

    @torch.no_grad()
    def verify_partition_integrity(self) -> dict[str, float]:
        """Max absolute value among weights that the masks require to be zero.

        Both checks go through the masks rather than slicing the partitions,
        so they also work when a partition is empty (n_a == 0 or n_b == 0).
        """
        enc_forbidden = (1 - self.enc_mask).unsqueeze(1)
        dec_forbidden = (1 - self.dec_mask).unsqueeze(-1)
        enc_viol = (self.W_enc.abs() * enc_forbidden).max().item()
        dec_viol = (self.W_dec.abs() * dec_forbidden).max().item()
        return {
            "enc_max_violation": enc_viol,
            "dec_max_violation": dec_viol,
        }

    @torch.no_grad()
    def feature_stats(self, features: torch.Tensor) -> dict[str, float]:
        """Partition-level activation stats for a batch of features.

        ``l0_*`` is the mean number of strictly positive features per sample;
        ``mean_*`` is the plain mean activation of the partition (NaN if the
        partition is empty).
        """

        def l0(part: torch.Tensor) -> float:
            return (part > 0).float().sum(dim=-1).mean().item()

        a_excl = features[:, : self.a_end]
        b_excl = features[:, self.a_end : self.b_end]
        shared = features[:, self.b_end :]
        return {
            "l0_total": l0(features),
            "l0_a_excl": l0(a_excl),
            "l0_b_excl": l0(b_excl),
            "l0_shared": l0(shared),
            "mean_a_excl": a_excl.mean().item(),
            "mean_b_excl": b_excl.mean().item(),
            "mean_shared": shared.mean().item(),
        }

    def save(self, path: str) -> None:
        """Write ``model.pt`` (state dict) and ``config.json`` into ``path``."""
        out_dir = Path(path)
        out_dir.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), out_dir / "model.pt")
        config = dict(
            activation_dim=self.activation_dim,
            dict_size=self.dict_size,
            k=self.k,
            n_a=self.n_a,
            n_b=self.n_b,
        )
        # Context manager so the config is flushed and the handle closed.
        with open(out_dir / "config.json", "w", encoding="utf-8") as fh:
            json.dump(config, fh, indent=2)
        print(f"[DFC] Saved → {path}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "DFCCrossCoder":
        """Rebuild a model from a directory written by ``save``.

        Raises:
            ValueError: if the rebuilt partition does not match the saved
                ``n_a`` / ``n_b`` counts.
        """
        src_dir = Path(path)
        with open(src_dir / "config.json", encoding="utf-8") as fh:
            cfg = json.load(fh)

        dict_size = cfg["dict_size"]
        # __init__ truncates dict_size * pct with int().  Passing n / dict_size
        # can round-trip to n - epsilon and lose a feature (e.g. 1/49*49 < 1),
        # so aim for the middle of the integer bucket instead.
        model = cls(
            activation_dim=cfg["activation_dim"],
            dict_size=dict_size,
            k=cfg["k"],
            model_a_exclusive_pct=(cfg["n_a"] + 0.5) / dict_size,
            model_b_exclusive_pct=(cfg["n_b"] + 0.5) / dict_size,
        )
        if (model.n_a, model.n_b) != (cfg["n_a"], cfg["n_b"]):
            raise ValueError(
                f"partition mismatch: rebuilt ({model.n_a}, {model.n_b}) "
                f"but config has ({cfg['n_a']}, {cfg['n_b']})"
            )

        state = torch.load(src_dir / "model.pt", map_location=device, weights_only=True)
        model.load_state_dict(state)
        return model.to(device)
|
|