# Source: v2/src/steering.py
# (Hugging Face Hub page residue: uploaded by JulianHJR via huggingface_hub, commit e53f10b.)
"""
Stage 4: Activation steering via projection decay (Apr 2026 update).
NEW SEMANTICS:
h_new = h - (1 - alpha) * P · h
where P is either:
P = ŵ ŵ^T (rank-1 projector, single direction) → "v1_raw"
P = Q^T Q (rank-k projector, subspace) → "v_pca_subspace"
alpha (α) is the "ability level", i.e. the fraction of the projected component that is kept:
- alpha = 1: no change (baseline)
- alpha = 0: full removal of the cognitive subspace
- alpha < 0: over-suppression (rare, prone to collapse)
- alpha > 1: amplification (rare)
JOINT STEERING (anti-leak):
When suppressing one dimension, optionally also softly suppress the other
to prevent compensatory activation (e.g. suppressing planning causing a
spike in monitoring triggers). The coupling factor `beta` controls the strength:
h_new = h - (1-α) * P_target · h - (1-α) * β * P_other · h
(The hooks apply the two terms one after the other, so this equation is exact
when the two subspaces are orthogonal.)
Hook point: decoder layer output (post-layer residual stream).
"""
import torch
from typing import Dict, List, Optional, Union
from configs.model import MODEL_CONFIG, ANTI_LEAK_BETA
# ============================================================
# Helper: which alpha is "no-op"?
# ============================================================
NEUTRAL_ALPHA = 1.0  # alpha value at which steering leaves activations unchanged


def is_neutral_alpha(alpha: float, eps: float = 1e-5) -> bool:
    """
    Tell whether `alpha` is (numerically) the no-op steering value.

    Args:
        alpha: steering strength to test; None is never considered neutral.
        eps: absolute tolerance around NEUTRAL_ALPHA.

    Returns:
        True when |alpha - NEUTRAL_ALPHA| <= eps, otherwise False.
    """
    if alpha is not None:
        distance = abs(alpha - NEUTRAL_ALPHA)
        return distance <= eps
    return False
# ============================================================
# Projector construction
# ============================================================
def _make_projector(direction: torch.Tensor, device, dtype):
"""
Given a direction or subspace basis, return a function
proj(h) -> P · h
where h is (B, S, D) and the result is (B, S, D).
direction shapes:
(D,) : rank-1 projector ŵŵ^T
(k, D) : rank-k projector Q^T Q
"""
direction = direction.to(device=device, dtype=dtype)
if direction.dim() == 1:
# Normalize defensively
n = direction.norm()
if n < 1e-8:
return None
w = (direction / n).to(dtype)
def proj(h):
scalar = h @ w # (B, S)
return scalar.unsqueeze(-1) * w # (B, S, D)
return proj
elif direction.dim() == 2:
# Q is (k, D), assume row-orthonormal
if direction.shape[0] == 0 or direction.shape[1] == 0:
return None
Q = direction.to(dtype)
def proj(h):
# h @ Q^T -> (B, S, k); then @ Q -> (B, S, D)
coords = h @ Q.T # (B, S, k)
return coords @ Q # (B, S, D)
return proj
else:
return None
# ============================================================
# Single-dimension steerer (backward compatible)
# ============================================================
class ResidualSteerer:
    """
    Apply projection decay steering to post-layer residual at target layers.

    For each steered layer the hook rewrites the layer output as
        h_new = h - (1 - alpha) * P · h
    For single direction, P · h = (h · ŵ) ŵ.
    For subspace, P · h = Q^T Q · h.

    Args:
        model: model whose decoder layers are reachable as `model.model.layers`.
        directions: {layer_id: direction (D,) or basis (k, D)}.
        alpha: steering strength (1 = no change, 0 = full removal).

    Usage:
        steerer.start() ... steerer.stop(), or `with ResidualSteerer(...):`
        which guarantees the hooks are removed even if generation raises.
    """

    def __init__(
        self,
        model,
        directions: Dict[int, torch.Tensor],
        alpha: float = NEUTRAL_ALPHA,
    ):
        self.model = model
        self.directions = directions
        self.alpha = alpha
        self.handles = []
        # Device/dtype of the first parameter are used to store the projectors.
        first_param = next(model.parameters())
        self._device = first_param.device
        self._dtype = first_param.dtype

    def _make_hook(self, layer_id: int):
        """Build the forward hook for one layer (a pass-through when steering is a no-op)."""
        proj = _make_projector(self.directions[layer_id], self._device, self._dtype)
        scale = 1.0 - float(self.alpha)

        if proj is None or abs(scale) < 1e-9:
            # Unusable direction or neutral alpha: leave the output untouched.
            def noop(module, inputs, output):
                return output
            return noop

        def fn(module, inputs, output):
            # Layers may return a tuple whose first element is the hidden state;
            # only that element is steered, the rest is passed through.
            is_tuple = isinstance(output, tuple)
            h = output[0] if is_tuple else output
            h_new = h - scale * proj(h)
            if is_tuple:
                return (h_new,) + output[1:]
            return h_new

        return fn

    def start(self):
        """Register the steering hooks; safe to call again (old hooks are removed first)."""
        # Without this, a second start() would stack hooks and apply the decay
        # twice per layer.
        self.stop()
        try:
            for li in self.directions:
                layer = self.model.model.layers[li]
                self.handles.append(layer.register_forward_hook(self._make_hook(li)))
        except Exception:
            # Do not leave the model half-hooked if one layer fails.
            self.stop()
            raise

    def stop(self):
        """Remove every hook registered by start()."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always detach; returning None lets any exception propagate.
        self.stop()
# ============================================================
# Joint steerer with anti-leak coupling
# ============================================================
class JointResidualSteerer:
    """
    Apply joint steering on TWO dimensions (planning + monitoring) simultaneously.
    Used to prevent compensatory activation when suppressing one dimension.

    Steering equation:
        h_new = h - (1-α_target) * P_target · h
                  - (1-α_target) * β * P_other · h

    NOTE(review): the hook applies the two terms sequentially (the second
    projection is taken on the already target-steered state). That equals the
    equation above when the two subspaces are orthogonal; otherwise it differs
    by a cross term. Confirm which form is intended.

    Args:
        model: HF model (decoder layers reachable as `model.model.layers`)
        target_dirs: {layer_id: direction or basis} - dimension being primarily steered
        other_dirs: {layer_id: direction or basis} - dimension being coupled (anti-leak)
        alpha: steering strength for target (NEW SEMANTICS, 1=no change, 0=full)
        beta: coupling factor for the other dimension (default ANTI_LEAK_BETA)

    A layer may appear in only one of the two dicts; the missing term is then
    simply skipped for that layer. Can be used as a context manager so hooks
    are always removed.
    """

    def __init__(
        self,
        model,
        target_dirs: Dict[int, torch.Tensor],
        other_dirs: Dict[int, torch.Tensor],
        alpha: float = NEUTRAL_ALPHA,
        beta: float = ANTI_LEAK_BETA,
    ):
        self.model = model
        self.target_dirs = target_dirs
        self.other_dirs = other_dirs
        self.alpha = alpha
        self.beta = beta
        self.handles = []
        # Device/dtype of the first parameter are used to store the projectors.
        first_param = next(model.parameters())
        self._device = first_param.device
        self._dtype = first_param.dtype

    def _projector_for(self, dirs: Dict[int, torch.Tensor], layer_id: int):
        """Return the projector for `layer_id` from `dirs`, or None if the layer has no entry."""
        if layer_id not in dirs:
            return None
        return _make_projector(dirs[layer_id], self._device, self._dtype)

    def _make_hook(self, layer_id: int):
        """Build the forward hook for one layer (a pass-through when nothing is steered)."""
        # start() hooks the UNION of both dicts, so either lookup may miss;
        # indexing target_dirs directly raised KeyError for other-only layers.
        target_proj = self._projector_for(self.target_dirs, layer_id)
        other_proj = self._projector_for(self.other_dirs, layer_id)

        scale_target = 1.0 - float(self.alpha)
        scale_other = scale_target * float(self.beta)

        # Drop terms whose scale is negligible so the hook does no wasted work.
        if not abs(scale_target) > 1e-9:
            target_proj = None
        if not abs(scale_other) > 1e-9:
            other_proj = None

        if target_proj is None and other_proj is None:
            def noop(module, inputs, output):
                return output
            return noop

        def fn(module, inputs, output):
            # Layers may return a tuple whose first element is the hidden state;
            # only that element is steered, the rest is passed through.
            is_tuple = isinstance(output, tuple)
            h_new = output[0] if is_tuple else output
            if target_proj is not None:
                h_new = h_new - scale_target * target_proj(h_new)
            if other_proj is not None:
                h_new = h_new - scale_other * other_proj(h_new)
            if is_tuple:
                return (h_new,) + output[1:]
            return h_new

        return fn

    def start(self):
        """Register one hook per layer present in either dict; safe to call again."""
        # Remove hooks from a previous start() so steering is never applied twice.
        self.stop()
        # De-duplicated union of layer ids, in a deterministic order.
        layer_ids = dict.fromkeys([*self.target_dirs, *self.other_dirs])
        try:
            for li in layer_ids:
                layer = self.model.model.layers[li]
                self.handles.append(layer.register_forward_hook(self._make_hook(li)))
        except Exception:
            # Do not leave the model half-hooked if one layer fails.
            self.stop()
            raise

    def stop(self):
        """Remove every hook registered by start()."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always detach; returning None lets any exception propagate.
        self.stop()
# ============================================================
# Force-prompt mechanism (kept for ablation comparison)
# ============================================================
# System-prompt suffixes that tell the model NOT to exhibit a dimension.
FORCE_SUPPRESS_PROMPTS = {
    "planning": (
        "IMPORTANT: Solve this problem WITHOUT planning, WITHOUT stating strategies, "
        "WITHOUT outlining steps in advance. Just compute directly."
    ),
    "monitoring": (
        "IMPORTANT: Solve this problem without double-checking, without verifying, "
        "without saying 'wait' or 'let me check'. Just produce the answer directly."
    ),
}

# System-prompt suffixes that tell the model to exhibit a dimension explicitly.
FORCE_ENHANCE_PROMPTS = {
    "planning": (
        "IMPORTANT: Before starting, explicitly state your plan. Break the problem "
        "into clearly labeled steps. Discuss multiple strategies and choose one. "
        "Reference your plan as you execute."
    ),
    "monitoring": (
        "IMPORTANT: After each step, verify your work. Say 'wait, let me check'. "
        "Substitute values back to confirm. Consider alternative interpretations."
    ),
}


def build_force_prompt(base_system_prompt: str, dimension: str, mode: str) -> str:
    """
    Append the forcing instruction for `dimension` to a system prompt.

    Args:
        base_system_prompt: prompt to extend.
        dimension: key into the prompt tables ("planning" or "monitoring").
        mode: "suppress" or "enhance"; any other value returns the base
            prompt unchanged.

    Returns:
        The base prompt, a blank line, then the forcing instruction.

    Raises:
        KeyError: if `mode` is recognised but `dimension` is not in the table.
    """
    if not (mode == "suppress" or mode == "enhance"):
        return base_system_prompt
    table = FORCE_SUPPRESS_PROMPTS if mode == "suppress" else FORCE_ENHANCE_PROMPTS
    instruction = table[dimension]
    return f"{base_system_prompt}\n\n{instruction}"