siege / interp_arena /model /hooks.py
BART-ender's picture
Upload folder using huggingface_hub
433f30e verified
"""Hook builders for TransformerLens β€” Red attack hooks and Blue defense hooks.
TransformerLens hook naming convention:
blocks.{L}.hook_resid_post β†’ residual stream after layer L
blocks.{L}.hook_resid_pre β†’ residual stream before layer L
blocks.{L}.attn.hook_attn_scores β†’ attention logits (pre-softmax)
blocks.{L}.hook_mlp_out β†’ MLP output
hook_logits β†’ final logit tensor
All hook functions follow the TransformerLens signature:
hook_fn(value: Tensor, hook: HookPoint) -> Tensor
"""
from __future__ import annotations
import re
from typing import Callable, Optional
import torch
# ── Type alias ────────────────────────────────────────────────────────────────
# Signature shared by every hook built in this module:
# (activation tensor, hook point object) -> activation tensor.
HookFn = Callable[[torch.Tensor, object], torch.Tensor]
# A hook function paired with the name of the hook point it targets.
HookPair = tuple[str, HookFn] # (hook_name, hook_fn)
# ── Hook name helpers ─────────────────────────────────────────────────────────
def resid_post(layer: int) -> str:
    """Return the hook-point name for the residual stream after block *layer*."""
    hook_name = f"blocks.{layer}.hook_resid_post"
    return hook_name
def resid_pre(layer: int) -> str:
    """Return the hook-point name for the residual stream before block *layer*."""
    hook_name = f"blocks.{layer}.hook_resid_pre"
    return hook_name
def attn_scores(layer: int) -> str:
    """Return the hook-point name for the attention scores of block *layer*."""
    hook_name = f"blocks.{layer}.attn.hook_attn_scores"
    return hook_name
def mlp_out(layer: int) -> str:
    """Return the hook-point name for the MLP output of block *layer*."""
    hook_name = f"blocks.{layer}.hook_mlp_out"
    return hook_name
# Hook-point name for the final logit tensor, as listed in the module
# docstring's naming table.
# NOTE(review): confirm this hook point exists on the model class in use.
LOGITS_HOOK = "hook_logits"
# ── Red: attack hooks ─────────────────────────────────────────────────────────
def make_steer_hook(direction: torch.Tensor, strength: float) -> HookFn:
    """Build a hook that shifts the activation along *direction*.

    *direction* is passed through ``_unit`` once, when the hook is built, so
    for any direction with a non-negligible norm the added vector has length
    ``abs(strength)``.  The same offset is broadcast onto every token.
    """
    steer_dir = _unit(direction)

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, d_model)
        # Match the activation's device and dtype before adding.
        offset = steer_dir.to(value.device, value.dtype) * strength
        return value + offset

    return hook
def make_amplify_attn_hook(head: int, scale: float) -> HookFn:
    """Scale the attention scores for a specific head (pre-softmax).

    scale > 1      → head attends more sharply
    0 < scale < 1  → head attends more uniformly
    scale == 0     → every finite score becomes 0, so the softmax is uniform
                     over the keys that are not masked (the head is
                     flattened, not silenced)

    Non-finite scores are passed through unchanged.  Masked positions are
    commonly stored as ``-inf``; multiplying them would produce NaN for
    ``scale == 0`` and ``+inf`` for a negative scale, either of which corrupts
    the softmax that follows.

    Args:
        head: Index of the head along dimension 1 of the score tensor.
        scale: Multiplier applied to that head's finite scores.

    Returns:
        A hook that returns a modified copy; the input tensor is not mutated.
    """

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, heads, seq_q, seq_k)
        scores = value.clone()
        head_scores = scores[:, head, :, :]
        # Only rescale finite entries so masked (-inf) positions stay masked.
        scores[:, head, :, :] = torch.where(
            torch.isfinite(head_scores),
            head_scores * scale,
            head_scores,
        )
        return scores

    return hook
def make_patch_hook(position: int, patch_value: torch.Tensor) -> HookFn:
    """Build a hook that overwrites the activation at *position*.

    The hook works on a copy, so the tensor it receives is left untouched.
    *patch_value* is moved to the activation's device and dtype on each call
    and assigned to every batch element at that sequence index.
    """

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, d_model)
        patched = value.clone()
        patched[:, position, :] = patch_value.to(value.device, value.dtype)
        return patched

    return hook
def make_logit_bias_hook(token_ids: list[int], bias: float) -> HookFn:
    """Build a hook that adds *bias* to the logits of *token_ids*.

    The last dimension of the incoming tensor is indexed by token id.  The
    hook returns a modified copy and leaves its input untouched.
    """

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, vocab) or (batch, vocab)
        biased = value.clone()
        biased[..., token_ids] = biased[..., token_ids] + bias
        return biased

    return hook
# ── Blue: defense hooks ───────────────────────────────────────────────────────
def make_ablate_hook(direction: torch.Tensor) -> HookFn:
    """Build a hook that removes the component of the activation along *direction*.

    *direction* is passed through ``_unit`` once, when the hook is built.  For
    each token vector ``v`` the hook returns ``v - (v·d) * d``, where ``d`` is
    that vector moved to the activation's device and dtype.
    """
    unit_dir = _unit(direction)

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, d_model)
        d = unit_dir.to(value.device, value.dtype)
        # Per-token coefficient along d: (batch, seq)
        coeff = torch.matmul(value, d)
        # Re-expand to (batch, seq, d_model) and subtract the projection.
        return value - coeff.unsqueeze(-1) * d

    return hook
def make_suppress_head_hook(head: int) -> HookFn:
    """Build a hook that zeroes the slice belonging to one attention head.

    The hook indexes its input as ``[:, head, :, :]``, so it expects a
    four-dimensional tensor whose second dimension enumerates heads.  It
    returns a modified copy and leaves its input untouched.
    """

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        silenced = value.clone()
        silenced[:, head, :, :] = 0.0
        return silenced

    return hook
def make_clamp_hook(clamp_min: float, clamp_max: float) -> HookFn:
    """Build a hook that limits every element to [clamp_min, clamp_max]."""

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        bounded = torch.clamp(value, min=clamp_min, max=clamp_max)
        return bounded

    return hook
def make_restore_hook(
    position: int,
    baseline_activation: torch.Tensor,
) -> HookFn:
    """Build a hook that writes the clean baseline back at *position*.

    The hook works on a copy, so the tensor it receives is left untouched.
    *baseline_activation* is moved to the activation's device and dtype on
    each call and assigned to every batch element at that sequence index.
    """

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        restored = value.clone()
        restored[:, position, :] = baseline_activation.to(value.device, value.dtype)
        return restored

    return hook
def make_logit_filter_hook(token_ids: list[int]) -> HookFn:
    """Build a hook that sets the logits of prohibited *token_ids* to -inf.

    The last dimension of the incoming tensor is indexed by token id.  The
    hook returns a modified copy and leaves its input untouched.
    """

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        filtered = value.clone()
        blocked = float("-inf")
        filtered[..., token_ids] = blocked
        return filtered

    return hook
# ── Detection utility ─────────────────────────────────────────────────────────
def make_detection_hook(
    ref_direction: torch.Tensor,
    threshold: float,
    detected_layers: list[int],
    layer_id: int,
) -> HookFn:
    """Blue monitoring hook: records *layer_id* when alignment looks suspicious.

    On every call the hook normalises each token vector (norms are floored at
    1e-8), takes its dot product with the ``_unit``-normalised
    *ref_direction*, and averages the absolute values over all tokens.  If
    that mean exceeds *threshold*, *layer_id* is appended to
    *detected_layers*.

    The list is mutated in place and is appended to on every triggering call,
    so the same layer id can appear more than once.  The activation itself is
    returned unmodified.
    """
    ref_unit = _unit(ref_direction)

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, d_model)
        r = ref_unit.to(value.device, value.dtype)
        token_norms = value.norm(dim=-1, keepdim=True).clamp_min(1e-8)
        # Cosine of each token vector against the reference: (batch, seq)
        alignment = torch.matmul(value / token_norms, r)
        score = alignment.abs().mean().item()
        if score > threshold:
            detected_layers.append(layer_id)
        # Pure observation: hand the activation back untouched.
        return value

    return hook
# ── Internals ─────────────────────────────────────────────────────────────────
def _unit(v: torch.Tensor) -> torch.Tensor:
norm = v.norm()
return v / norm if norm > 1e-8 else v