dfc-crosscoder-qwen-ToolRL / dfc_crosscoder.py
antebe1's picture
Upload DFC CrossCoder model
ebdb5ae verified
"""
dfc.py — Dedicated Feature CrossCoder (DFC) model.
Feature layout in dict_size
────────────────────────────
┌─────────────────────┬─────────────────────┬──────────────────────────┐
│  A-exclusive (n_a)  │  B-exclusive (n_b)  │    Shared (n_shared)     │
└─────────────────────┴─────────────────────┴──────────────────────────┘
idx: 0 ─────── a_end ──────── b_end ───────────────────── dict_size
Constraints (enforced by gradient masking + _apply_masks every step)
──────────────────────────────────────────────────────────────────────
• Model A cannot encode/decode B-exclusive features
• Model B cannot encode/decode A-exclusive features
• Shared features are accessible to both
"""
from __future__ import annotations
import json
from pathlib import Path
from bitsandbytes import features
import torch
import torch.nn as nn
import torch.nn.functional as F
class DFCCrossCoder(nn.Module):
    """Top-k sparse crosscoder over two models with dedicated feature partitions.

    The dictionary axis is split by index into three contiguous partitions:

    * ``[0, a_end)``          A-exclusive features (``n_a`` of them)
    * ``[a_end, b_end)``      B-exclusive features (``n_b`` of them)
    * ``[b_end, dict_size)``  shared features (``n_shared`` of them)

    Model index 0 is "A" and model index 1 is "B". Encoder/decoder weights
    that would connect a model to the other model's exclusive features are
    multiplied by a 0/1 mask in ``encode``/``decode`` and can be zeroed in
    place with ``_apply_masks``.
    """

    def __init__(
        self,
        activation_dim: int,
        dict_size: int,
        k: int,
        model_a_exclusive_pct: float = 0.05,
        model_b_exclusive_pct: float = 0.05,
    ):
        """Build parameters and partition masks.

        Args:
            activation_dim: Width ``d`` of each model's activation vector.
            dict_size: Total number of dictionary features.
            k: Number of features kept per sample by the top-k in ``encode``.
            model_a_exclusive_pct: Fraction of ``dict_size`` reserved for
                A-exclusive features (truncated to an integer count).
            model_b_exclusive_pct: Same, for B-exclusive features.

        Raises:
            ValueError: If the fractions yield a negative partition size or
                the exclusive partitions together exceed ``dict_size``.
        """
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.k = k

        self.n_a = int(dict_size * model_a_exclusive_pct)
        self.n_b = int(dict_size * model_b_exclusive_pct)
        if self.n_a < 0 or self.n_b < 0 or self.n_a + self.n_b > dict_size:
            # Without this check the masks below would silently overlap or
            # be truncated and n_shared would go negative.
            raise ValueError(
                "exclusive partitions must be non-negative and fit in "
                f"dict_size={dict_size}; got n_a={self.n_a}, n_b={self.n_b}"
            )
        self.n_shared = dict_size - self.n_a - self.n_b
        self.a_end = self.n_a
        self.b_end = self.n_a + self.n_b

        print(
            f"[DFC] dict={dict_size} k={k} | "
            f"A-excl={self.n_a} B-excl={self.n_b} shared={self.n_shared}"
        )

        # Encoder: W_enc[model, d_in, dict_size]
        self.W_enc = nn.Parameter(
            torch.randn(2, activation_dim, dict_size) / (activation_dim ** 0.5)
        )
        self.b_enc = nn.Parameter(torch.zeros(dict_size))

        # Decoder: W_dec[dict_size, model, d_in]
        self.W_dec = nn.Parameter(
            torch.randn(dict_size, 2, activation_dim) / (dict_size ** 0.5)
        )
        self.b_dec = nn.Parameter(torch.zeros(2, activation_dim))

        # ── Partition masks (buffers, so they move with .to(device)) ──
        # enc_mask[model, dict_size]: 1 where the model may write a feature.
        enc_mask = torch.ones(2, dict_size)
        enc_mask[1, : self.a_end] = 0  # B cannot encode A-excl
        enc_mask[0, self.a_end : self.b_end] = 0  # A cannot encode B-excl
        self.register_buffer("enc_mask", enc_mask)

        # dec_mask[dict_size, model]: 1 where a feature may write to the model.
        dec_mask = torch.ones(dict_size, 2)
        dec_mask[: self.a_end, 1] = 0  # A-excl: B decoder = 0
        dec_mask[self.a_end : self.b_end, 0] = 0  # B-excl: A decoder = 0
        self.register_buffer("dec_mask", dec_mask)

        self._apply_masks()

    # ── Weight enforcement ────────────────────────────────────────────
    @torch.no_grad()
    def _apply_masks(self) -> None:
        """Zero forbidden weights in place. Call after every optimiser step."""
        # enc_mask (2, dict) -> (2, 1, dict) broadcasts over the input dim.
        self.W_enc.data.mul_(self.enc_mask.unsqueeze(1))
        # dec_mask (dict, 2) -> (dict, 2, 1) broadcasts over the output dim.
        self.W_dec.data.mul_(self.dec_mask.unsqueeze(-1))

    # ── Forward ───────────────────────────────────────────────────────
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, 2, d) → features: (B, dict_size), sparse top-k.

        Both models' activations are projected with the masked encoder and
        summed into one pre-activation per feature. After the ReLU only the
        ``k`` largest values per row are kept; all other entries are zero.
        """
        W = self.W_enc * self.enc_mask.unsqueeze(1)  # (2, d, dict)
        pre = torch.einsum("bmd,mdf->bf", x, W) + self.b_enc
        pre = F.relu(pre)
        topk_vals, topk_idx = torch.topk(pre, self.k, dim=-1)
        features = torch.zeros_like(pre)
        features.scatter_(-1, topk_idx, topk_vals)
        return features

    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """features: (B, dict_size) → reconstruction: (B, 2, d)."""
        W = self.W_dec * self.dec_mask.unsqueeze(-1)  # (dict, 2, d)
        return torch.einsum("bf,fmd->bmd", features, W) + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """x: (B, 2, d) → (reconstruction, features)."""
        features = self.encode(x)
        recon = self.decode(features)
        return recon, features

    @staticmethod
    def _mean_abs(t: torch.Tensor) -> torch.Tensor:
        """Mean absolute value of ``t``; 0 for an empty tensor instead of NaN."""
        if t.numel() == 0:
            return t.new_zeros(())
        return t.abs().mean()

    def loss(
        self,
        x: torch.Tensor,
        sparsity_coef: float = 1e-3,
        exclusive_sparsity_coef: float = 1e-3,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """MSE + weighted L1 sparsity.

        Args:
            x: Activations of shape (B, 2, d).
            sparsity_coef: Weight of the L1 term over shared features.
            exclusive_sparsity_coef: Weight of the per-model L1 term. It
                defaults to the same value as ``sparsity_coef``; pass a
                smaller value to penalise that term less.

        Returns:
            ``(total, mse, l1_shared, l1_exclusive)``.

        Note:
            ``l1_exclusive`` is the average of two means, each taken over
            the features one model can see (its exclusive partition
            concatenated with the shared partition). Shared features
            therefore contribute to both L1 terms.
        """
        recon, features = self.forward(x)
        mse = F.mse_loss(recon, x)

        shared = features[:, self.b_end :]
        # Features visible to A: A-exclusive + shared.
        visible_a = torch.cat([features[:, : self.a_end], shared], dim=-1)
        # Features visible to B: B-exclusive + shared.
        visible_b = torch.cat(
            [features[:, self.a_end : self.b_end], shared], dim=-1
        )

        # An empty partition contributes 0 rather than poisoning the loss
        # with the NaN that .mean() returns on an empty tensor.
        l1_shared = self._mean_abs(shared)
        l1_exclusive = (
            self._mean_abs(visible_a) + self._mean_abs(visible_b)
        ) / 2

        total = (
            mse
            + exclusive_sparsity_coef * l1_exclusive
            + sparsity_coef * l1_shared
        )
        return total, mse, l1_shared, l1_exclusive

    # ── Diagnostics ───────────────────────────────────────────────────
    @torch.no_grad()
    def verify_partition_integrity(self) -> dict[str, float]:
        """Max absolute value among weights that the masks require to be zero.

        Both checks multiply by the complement of the mask instead of
        slicing, so they also work when an exclusive partition is empty
        (slicing would call ``.max()`` on an empty tensor, which raises).
        """
        enc_forbidden = (1 - self.enc_mask).unsqueeze(1)  # (2, 1, dict)
        dec_forbidden = (1 - self.dec_mask).unsqueeze(-1)  # (dict, 2, 1)
        enc_viol = (self.W_enc.abs() * enc_forbidden).max().item()
        dec_viol = (self.W_dec.abs() * dec_forbidden).max().item()
        return {
            "enc_max_violation": enc_viol,
            "dec_max_violation": dec_viol,
        }

    @torch.no_grad()
    def feature_stats(self, features: torch.Tensor) -> dict[str, float]:
        """Partition-level activation stats for a batch of features.

        ``l0_*`` is the mean number of strictly positive entries per row;
        ``mean_*`` is the mean activation over the partition (NaN when the
        partition is empty, since it is the mean of an empty tensor).
        """
        fa = features[:, : self.a_end]
        fb = features[:, self.a_end : self.b_end]
        fs = features[:, self.b_end :]

        def l0(t: torch.Tensor) -> float:
            return (t > 0).float().sum(dim=-1).mean().item()

        return {
            "l0_total": l0(features),
            "l0_a_excl": l0(fa),
            "l0_b_excl": l0(fb),
            "l0_shared": l0(fs),
            "mean_a_excl": fa.mean().item(),
            "mean_b_excl": fb.mean().item(),
            "mean_shared": fs.mean().item(),
        }

    # ── Save / Load ───────────────────────────────────────────────────
    def save(self, path: str) -> None:
        """Write ``model.pt`` (state dict) and ``config.json`` into ``path``.

        The directory is created if it does not exist.
        """
        out_dir = Path(path)
        out_dir.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), out_dir / "model.pt")
        config = dict(
            activation_dim=self.activation_dim,
            dict_size=self.dict_size,
            k=self.k,
            n_a=self.n_a,
            n_b=self.n_b,
        )
        # Context manager guarantees the config is flushed and closed.
        with open(out_dir / "config.json", "w", encoding="utf-8") as fh:
            json.dump(config, fh, indent=2)
        print(f"[DFC] Saved β†’ {path}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "DFCCrossCoder":
        """Rebuild a model from a directory written by ``save``.

        Args:
            path: Directory containing ``config.json`` and ``model.pt``.
            device: Device the weights are mapped to and the model moved to.
        """
        src_dir = Path(path)
        with open(src_dir / "config.json", encoding="utf-8") as fh:
            cfg = json.load(fh)

        dict_size = cfg["dict_size"]
        # __init__ truncates dict_size * pct to an int. Passing n / dict_size
        # can land just below n after float round-off (e.g. 49 * (1 / 49) < 1)
        # and lose a feature; the half-unit offset makes truncation return
        # exactly the saved count.
        model = cls(
            activation_dim=cfg["activation_dim"],
            dict_size=dict_size,
            k=cfg["k"],
            model_a_exclusive_pct=(cfg["n_a"] + 0.5) / dict_size,
            model_b_exclusive_pct=(cfg["n_b"] + 0.5) / dict_size,
        )
        model.load_state_dict(
            torch.load(
                src_dir / "model.pt", map_location=device, weights_only=True
            )
        )
        return model.to(device)