# razlapid's picture
# Upload SAE guard probes
# f377030 verified
"""Multi-head linear probe for SAE-feature classification.
Architecture: LayerNorm(d_in) -> Dropout(p) -> N x Linear(d_in, 1)
Each head is a binary classifier producing a single logit.
"""
from __future__ import annotations
import torch
import torch.nn as nn
class MultiHeadProbe(nn.Module):
    """Multi-head binary probe on SAE features.

    A shared ``LayerNorm`` followed by ``Dropout`` feeds one independent
    ``Linear(d_in, 1)`` head per entry in ``head_names``; each head emits a
    single logit.

    Args:
        d_in: SAE feature dimension (e.g. 65536 for 65k width).
        head_names: Head names, one per head (determines number of heads).
        dropout: Dropout probability applied after LayerNorm.
    """

    def __init__(self, d_in: int, head_names: list[str], dropout: float = 0.2):
        super().__init__()
        self.d_in = d_in
        # Store a private copy so the caller's list can change independently.
        self.head_names = list(head_names)
        # Measure the copy: works for any iterable, not only sized ones.
        self.n_heads = len(self.head_names)
        self.norm = nn.LayerNorm(d_in)
        self.drop = nn.Dropout(dropout)
        self.heads = nn.ModuleList(
            [nn.Linear(d_in, 1) for _ in range(self.n_heads)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: [batch, d_in] float tensor of SAE features.

        Returns:
            [batch, n_heads] logits (pre-sigmoid). With zero heads the last
            dimension has size 0.
        """
        x = self.norm(x)
        x = self.drop(x)
        if not self.heads:
            # torch.cat rejects an empty list; return a zero-width result
            # with the same leading dimensions, dtype and device instead.
            return x.new_empty((*x.shape[:-1], 0))
        return torch.cat([head(x) for head in self.heads], dim=-1)

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-head probabilities (sigmoid of logits).

        Note: this does not switch the module to eval mode or disable
        gradients; dropout is active if the module is in training mode.
        """
        # Call the module rather than .forward() so registered hooks run.
        return torch.sigmoid(self(x))

    @classmethod
    def from_checkpoint(
        cls,
        path: str,
        device: str = "cpu",
        *,
        weights_only: bool = False,
    ) -> "MultiHeadProbe":
        """Load a probe from a checkpoint file.

        Args:
            path: Path to the .pt checkpoint.
            device: Device to load onto.
            weights_only: Forwarded to ``torch.load``. The default (False)
                keeps the historical behaviour of full unpickling; pass True
                to restrict loading to tensors and primitive containers.

        Returns:
            Loaded MultiHeadProbe in eval mode.

        Raises:
            KeyError: If the checkpoint lacks ``d_in``, ``head_names`` or
                ``model_state_dict``.
        """
        # SECURITY: with weights_only=False, torch.load unpickles arbitrary
        # objects and can execute code embedded in the file. Only load
        # checkpoints from trusted sources, or pass weights_only=True.
        ckpt = torch.load(path, map_location=device, weights_only=weights_only)
        probe = cls(
            d_in=ckpt["d_in"],
            head_names=ckpt["head_names"],
            # Checkpoints without a "dropout" entry fall back to the default.
            dropout=ckpt.get("dropout", 0.2),
        )
        probe.load_state_dict(ckpt["model_state_dict"])
        probe.to(device).eval()
        return probe

    def save_checkpoint(self, path: str, **extra_meta) -> None:
        """Save probe checkpoint.

        Args:
            path: Destination .pt file.
            extra_meta: Additional metadata to store in the checkpoint.

        Raises:
            ValueError: If ``extra_meta`` uses a key reserved for the probe's
                own fields (e.g. ``model_state_dict``, ``d_in``).
        """
        payload = {
            "model_state_dict": self.state_dict(),
            "d_in": self.d_in,
            "head_names": self.head_names,
            "n_heads": self.n_heads,
            "dropout": self.drop.p,
        }
        # Refuse to let metadata clobber the fields from_checkpoint relies on;
        # a silent overwrite would yield a corrupt or misleading checkpoint.
        clashes = sorted(payload.keys() & extra_meta.keys())
        if clashes:
            raise ValueError(
                f"extra_meta must not override reserved checkpoint keys: {clashes}"
            )
        payload.update(extra_meta)
        torch.save(payload, path)