"""Multi-head linear probe for SAE-feature classification. Architecture: LayerNorm(d_in) -> Dropout(p) -> N x Linear(d_in, 1) Each head is a binary classifier producing a single logit. """ from __future__ import annotations import torch import torch.nn as nn class MultiHeadProbe(nn.Module): """Multi-head binary probe on SAE features. Args: d_in: SAE feature dimension (e.g. 65536 for 65k width). head_names: List of head names (determines number of heads). dropout: Dropout probability applied after LayerNorm. """ def __init__(self, d_in: int, head_names: list[str], dropout: float = 0.2): super().__init__() self.d_in = d_in self.head_names = list(head_names) self.n_heads = len(head_names) self.norm = nn.LayerNorm(d_in) self.drop = nn.Dropout(dropout) self.heads = nn.ModuleList([nn.Linear(d_in, 1) for _ in range(self.n_heads)]) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. Args: x: [batch, d_in] float tensor of SAE features. Returns: [batch, n_heads] logits (pre-sigmoid). """ x = self.norm(x) x = self.drop(x) return torch.cat([h(x) for h in self.heads], dim=-1) def predict_proba(self, x: torch.Tensor) -> torch.Tensor: """Return per-head probabilities (sigmoid of logits).""" return torch.sigmoid(self.forward(x)) @classmethod def from_checkpoint(cls, path: str, device: str = "cpu") -> "MultiHeadProbe": """Load a probe from a checkpoint file. Args: path: Path to the .pt checkpoint. device: Device to load onto. Returns: Loaded MultiHeadProbe in eval mode. """ ckpt = torch.load(path, map_location=device, weights_only=False) probe = cls( d_in=ckpt["d_in"], head_names=ckpt["head_names"], dropout=ckpt.get("dropout", 0.2), ) probe.load_state_dict(ckpt["model_state_dict"]) probe.to(device).eval() return probe def save_checkpoint(self, path: str, **extra_meta): """Save probe checkpoint. Args: path: Destination .pt file. extra_meta: Additional metadata to store in the checkpoint. """ payload = { "model_state_dict": self.state_dict(), "d_in": self.d_in, "head_names": self.head_names, "n_heads": self.n_heads, "dropout": self.drop.p, } payload.update(extra_meta) torch.save(payload, path)