Spaces:
Sleeping
Sleeping
"""Hook builders for TransformerLens — Red attack hooks and Blue defense hooks.

TransformerLens hook naming convention:
    blocks.{L}.hook_resid_post        — residual stream after layer L
    blocks.{L}.hook_resid_pre         — residual stream before layer L
    blocks.{L}.attn.hook_attn_scores  — attention logits (pre-softmax)
    blocks.{L}.hook_mlp_out           — MLP output
    hook_logits                       — final logit tensor

All hook functions follow the TransformerLens signature:
    hook_fn(value: Tensor, hook: HookPoint) -> Tensor
"""
| from __future__ import annotations | |
| import re | |
| from typing import Callable, Optional | |
| import torch | |
# ── Type alias ────────────────────────────────────────────────────────────────
# A hook callable: receives the activation tensor and the hook point object,
# and returns the tensor to use in its place.
HookFn = Callable[[torch.Tensor, object], torch.Tensor]
# A hook name paired with the callable to attach there.
HookPair = tuple[str, HookFn] # (hook_name, hook_fn)
# ── Hook name helpers ─────────────────────────────────────────────────────────
def resid_post(layer: int) -> str:
    """Return the hook name of the residual stream after block *layer*."""
    hook_name = f"blocks.{layer}.hook_resid_post"
    return hook_name
def resid_pre(layer: int) -> str:
    """Return the hook name of the residual stream before block *layer*."""
    hook_name = f"blocks.{layer}.hook_resid_pre"
    return hook_name
def attn_scores(layer: int) -> str:
    """Return the hook name of the pre-softmax attention scores in block *layer*."""
    hook_name = f"blocks.{layer}.attn.hook_attn_scores"
    return hook_name
def mlp_out(layer: int) -> str:
    """Return the hook name of the MLP output in block *layer*."""
    hook_name = f"blocks.{layer}.hook_mlp_out"
    return hook_name
# Hook name this module uses for the final logit tensor; unlike the helpers
# above it is not tied to a layer index.
LOGITS_HOOK = "hook_logits"
# ── Red: attack hooks ─────────────────────────────────────────────────────────
def make_steer_hook(direction: torch.Tensor, strength: float) -> HookFn:
    """Build a hook that shifts every token's activation along *direction*.

    *direction* is passed through ``_unit`` once, when the hook is built.
    Each time the hook fires it adds ``strength`` times that vector to the
    incoming tensor; the vector broadcasts over the leading dimensions.
    """
    steer_dir = _unit(direction)

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, d_model)
        # Move the direction to the activation's device/dtype before adding.
        offset = strength * steer_dir.to(value.device, value.dtype)
        return value + offset

    return hook
def make_amplify_attn_hook(head: int, scale: float) -> HookFn:
    """Scale the attention scores for a specific head (pre-softmax).

    scale > 1  → head attends more sharply
    scale < 1  → head attends more uniformly
    scale == 0 → head attends uniformly over its unmasked keys
                 (the head is flattened, not switched off)

    Only finite scores are scaled. Non-finite entries, such as ``-inf``
    written by a causal mask, are passed through unchanged: multiplying
    ``-inf`` by 0 would give NaN, and multiplying it by a negative scale
    would give ``+inf`` and un-mask the position.

    The returned hook works on a copy; the incoming tensor is not mutated.
    """
    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, heads, seq_q, seq_k)
        scaled = value.clone()
        head_scores = scaled[:, head, :, :]
        # Keep masked / non-finite entries exactly as they were.
        scaled[:, head, :, :] = torch.where(
            torch.isfinite(head_scores),
            head_scores * scale,
            head_scores,
        )
        return scaled

    return hook
def make_patch_hook(position: int, patch_value: torch.Tensor) -> HookFn:
    """Build a hook that overwrites the activation at *position*.

    The returned hook copies the incoming tensor, writes *patch_value*
    (moved to the tensor's device and dtype) into index *position* of the
    second dimension for every batch element, and returns the copy.
    """
    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, d_model)
        patched = value.clone()
        patched[:, position, :] = patch_value.to(value.device, value.dtype)
        return patched

    return hook
def make_logit_bias_hook(token_ids: list[int], bias: float) -> HookFn:
    """Build a hook that adds *bias* to the logits of *token_ids*.

    The last dimension of the incoming tensor is indexed by *token_ids*;
    all leading dimensions are left as they are. The hook works on a copy,
    so the tensor it receives is not modified.
    """
    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, vocab) or (batch, vocab)
        biased = value.clone()
        biased[..., token_ids] += bias
        return biased

    return hook
# ── Blue: defense hooks ───────────────────────────────────────────────────────
def make_ablate_hook(direction: torch.Tensor) -> HookFn:
    """Build a hook that removes the component along *direction*.

    *direction* is passed through ``_unit`` once, when the hook is built.
    For each token vector ``v`` the hook returns ``v - (v·d) * d``, where
    ``d`` is that vector moved to the activation's device and dtype.
    """
    unit_dir = _unit(direction)

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, d_model)
        d = unit_dir.to(value.device, value.dtype)
        coeff = value @ d  # per-token coefficient along d: (batch, seq)
        component = coeff.unsqueeze(-1) * d  # (batch, seq, d_model)
        return value - component

    return hook
def make_suppress_head_hook(head: int) -> HookFn:
    """Build a hook that zeroes every entry belonging to one attention head.

    The hook indexes the second dimension of a 4-D tensor with *head* and
    sets that slice to 0.0 on a copy of the input.
    NOTE(review): which hook point this is meant for is not visible here.
    Zeroing pre-softmax scores gives uniform attention rather than
    removing the head — confirm against the caller.
    """
    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        silenced = value.clone()
        silenced[:, head, :, :] = 0.0
        return silenced

    return hook
def make_clamp_hook(clamp_min: float, clamp_max: float) -> HookFn:
    """Build a hook that limits every element to [clamp_min, clamp_max].

    The hook returns a new tensor; the tensor it receives is not modified.
    """
    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        bounded = torch.clamp(value, min=clamp_min, max=clamp_max)
        return bounded

    return hook
def make_restore_hook(
    position: int,
    baseline_activation: torch.Tensor,
) -> HookFn:
    """Build a hook that writes the clean baseline back at *position*.

    The returned hook copies the incoming tensor, assigns
    *baseline_activation* (moved to the tensor's device and dtype) to index
    *position* of the second dimension, and returns the copy.
    """
    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        restored = value.clone()
        restored[:, position, :] = baseline_activation.to(
            value.device, value.dtype
        )
        return restored

    return hook
def make_logit_filter_hook(token_ids: list[int]) -> HookFn:
    """Build a hook that forces the logits of *token_ids* to -inf.

    The last dimension of the incoming tensor is indexed by *token_ids*.
    The hook works on a copy, so the tensor it receives is not modified.
    """
    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        filtered = value.clone()
        blocked = float("-inf")
        filtered[..., token_ids] = blocked
        return filtered

    return hook
# ── Detection utility ─────────────────────────────────────────────────────────
def make_detection_hook(
    ref_direction: torch.Tensor,
    threshold: float,
    detected_layers: list[int],
    layer_id: int,
) -> HookFn:
    """Blue monitoring hook: records *layer_id* when alignment is too high.

    Each time the hook fires it computes, for every token, the cosine
    similarity between the activation vector itself and `ref_direction`
    (passed through ``_unit`` when the hook is built). It takes the
    absolute value and averages over all tokens. If that mean exceeds
    *threshold*, *layer_id* is appended to *detected_layers*.

    *detected_layers* is mutated in place and is not de-duplicated, so the
    same layer can be appended once per firing. The activation is returned
    unchanged.
    """
    unit_ref = _unit(ref_direction)

    def hook(value: torch.Tensor, hook) -> torch.Tensor:  # noqa: ANN001
        # value: (batch, seq, d_model)
        # Floor the norms so all-zero token vectors do not divide by zero.
        token_norms = value.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        normalised = value / token_norms
        alignment = normalised @ unit_ref.to(value.device, value.dtype)  # (batch, seq)
        score = alignment.abs().mean().item()
        if score > threshold:
            detected_layers.append(layer_id)
        return value  # no modification — pure observation

    return hook
# ── Internals ─────────────────────────────────────────────────────────────────
| def _unit(v: torch.Tensor) -> torch.Tensor: | |
| norm = v.norm() | |
| return v / norm if norm > 1e-8 else v | |