| """Multi-head linear probe for SAE-feature classification. |
| |
| Architecture: LayerNorm(d_in) -> Dropout(p) -> N x Linear(d_in, 1) |
| |
| Each head is a binary classifier producing a single logit. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class MultiHeadProbe(nn.Module):
    """Multi-head binary probe on SAE features.

    Architecture: ``LayerNorm(d_in) -> Dropout(dropout) -> n_heads x Linear(d_in, 1)``.
    Every head sees the same normalized, dropped-out input and emits a single
    logit; the per-head logits are concatenated along the last dimension.

    Args:
        d_in: SAE feature dimension (e.g. 65536 for 65k width).
        head_names: Names of the heads, one per head, in output-column order.
            Any iterable of strings is accepted; it is copied into a list.
        dropout: Dropout probability applied after LayerNorm.
    """

    def __init__(self, d_in: int, head_names: list[str], dropout: float = 0.2):
        super().__init__()
        self.d_in = d_in
        # Materialize first so generators / one-shot iterables work; n_heads is
        # derived from the copy, not from the (possibly consumed) argument.
        self.head_names = list(head_names)
        self.n_heads = len(self.head_names)
        self.norm = nn.LayerNorm(d_in)
        self.drop = nn.Dropout(dropout)
        # Separate Linear(d_in, 1) modules (rather than one Linear(d_in, n_heads))
        # keep state_dict keys as "heads.<i>.weight" / "heads.<i>.bias", so
        # previously saved checkpoints keep loading.
        self.heads = nn.ModuleList(nn.Linear(d_in, 1) for _ in range(self.n_heads))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute per-head logits.

        Args:
            x: ``[..., d_in]`` float tensor of SAE features (typically
                ``[batch, d_in]``).

        Returns:
            ``[..., n_heads]`` logits (pre-sigmoid); column ``i`` belongs to
            ``head_names[i]``.
        """
        x = self.drop(self.norm(x))
        if not self.heads:
            # torch.cat rejects an empty list; a zero-head probe yields an
            # empty last dimension instead of crashing.
            return x.new_zeros((*x.shape[:-1], 0))
        return torch.cat([head(x) for head in self.heads], dim=-1)

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-head probabilities (sigmoid of the logits).

        Dropout is active while the module is in training mode; call
        ``.eval()`` first for deterministic probabilities.
        """
        # Invoke the module rather than .forward() so registered hooks run.
        return torch.sigmoid(self(x))

    @classmethod
    def from_checkpoint(
        cls,
        path: str,
        device: str | torch.device = "cpu",
        *,
        weights_only: bool = False,
    ) -> "MultiHeadProbe":
        """Load a probe from a checkpoint written by ``save_checkpoint``.

        Args:
            path: Path to the .pt checkpoint.
            device: Device to load onto.
            weights_only: Forwarded to ``torch.load``. Defaults to ``False``
                for compatibility, because ``save_checkpoint`` stores arbitrary
                ``extra_meta`` that weights-only loading may reject. Pass
                ``True`` when the checkpoint source is not fully trusted.

        Returns:
            Loaded MultiHeadProbe on ``device``, in eval mode.
        """
        # SECURITY: with weights_only=False, torch.load unpickles arbitrary
        # objects and can execute code; only load checkpoints you trust.
        ckpt = torch.load(path, map_location=device, weights_only=weights_only)
        probe = cls(
            d_in=ckpt["d_in"],
            head_names=ckpt["head_names"],
            dropout=ckpt.get("dropout", 0.2),
        )
        probe.load_state_dict(ckpt["model_state_dict"])
        probe.to(device)
        probe.eval()
        return probe

    def save_checkpoint(self, path: str, **extra_meta) -> None:
        """Save the probe weights plus the metadata needed to rebuild it.

        Args:
            path: Destination .pt file.
            extra_meta: Additional metadata to store in the checkpoint.

        Raises:
            ValueError: If ``extra_meta`` uses a key reserved for the probe's
                own state (``model_state_dict``, ``d_in``, ``head_names``,
                ``n_heads``, ``dropout``). Overwriting one silently would make
                ``from_checkpoint`` rebuild the wrong model.
        """
        payload = {
            "model_state_dict": self.state_dict(),
            "d_in": self.d_in,
            "head_names": self.head_names,
            "n_heads": self.n_heads,
            "dropout": self.drop.p,
        }
        clashes = payload.keys() & extra_meta.keys()
        if clashes:
            raise ValueError(
                f"extra_meta keys collide with reserved checkpoint keys: {sorted(clashes)}"
            )
        payload.update(extra_meta)
        torch.save(payload, path)
|
|