"""
Stage 4: Activation steering via projection decay (Apr 2026 update).

NEW SEMANTICS:

    h_new = h - (1 - alpha) * P · h

where P is either:

    P = ŵ ŵ^T   (rank-1 projector, single direction)  → "v1_raw"
    P = Q^T Q   (rank-k projector, subspace)          → "v_pca_subspace"

alpha represents the "ability level":
    - alpha = 1: no change (baseline)
    - alpha = 0: full removal of the cognitive subspace
    - alpha < 0: over-suppression (rare, prone to collapse)
    - alpha > 1: amplification (rare)

JOINT STEERING (anti-leak):
    When suppressing one dimension, optionally also softly suppress the
    other to prevent compensatory activation (e.g. suppressing planning can
    cause a spike in monitoring triggers). The coupling factor `beta`
    controls the strength:

        h_new = h - (1-α) * P_target · h - (1-α) * β * P_other · h

    Both projections are taken of the same, un-steered h.

Hook point: decoder layer output (post-layer residual stream).
"""
import torch
from typing import Dict, List, Optional, Union

from configs.model import MODEL_CONFIG, ANTI_LEAK_BETA


# ============================================================
# Helper: which alpha is "no-op"?
# ============================================================

NEUTRAL_ALPHA = 1.0


def is_neutral_alpha(alpha: Optional[float], eps: float = 1e-5) -> bool:
    """Return True if `alpha` is within `eps` of NEUTRAL_ALPHA (no-op steering).

    `None` is treated as "not neutral" and returns False.
    """
    if alpha is None:
        return False
    return abs(alpha - NEUTRAL_ALPHA) <= eps


# ============================================================
# Projector construction
# ============================================================

def _make_projector(direction: Optional[torch.Tensor], device, dtype):
    """
    Given a direction or subspace basis, return a function proj(h) -> P · h.

    direction shapes:
        (D,)   : rank-1 projector ŵŵ^T (the vector is normalized here)
        (k, D) : rank-k projector Q^T Q (rows are ASSUMED orthonormal;
                 this is not checked)

    Returns None when `direction` is None, is a 1-D vector with norm below
    1e-8, is a 2-D basis with an empty dimension, or has any other rank.

    h may have any number of leading dims (e.g. (B, S, D)); the result has
    the same shape as h. The basis is prepared on `device` and then cast
    lazily (and cached) to h's device and dtype on first use. This lets the
    hook work when layers are sharded across devices, or when the hidden
    state's dtype differs from the model's first parameter. `dtype` is kept
    for backward compatibility; h.dtype is what is used at call time.
    """
    if direction is None:
        return None
    direction = torch.as_tensor(direction).detach()

    if direction.dim() == 1:
        # Normalize in float32: a bf16/fp16 norm loses noticeable precision.
        d = direction.to(device=device, dtype=torch.float32)
        n = d.norm()
        if n < 1e-8:
            return None
        base = d / n
    elif direction.dim() == 2:
        if direction.shape[0] == 0 or direction.shape[1] == 0:
            return None
        base = direction.to(device=device)
    else:
        return None

    # (device, dtype) -> basis cast for that hidden-state placement.
    cache = {}

    def _basis_like(h):
        key = (h.device, h.dtype)
        b = cache.get(key)
        if b is None:
            b = base.to(device=h.device, dtype=h.dtype)
            cache[key] = b
        return b

    if base.dim() == 1:
        def proj(h):
            w = _basis_like(h)
            scalar = h @ w                    # (..., )
            return scalar.unsqueeze(-1) * w   # (..., D)
    else:
        def proj(h):
            Q = _basis_like(h)
            coords = h @ Q.T                  # (..., k)
            return coords @ Q                 # (..., D)
    return proj


def _map_hidden(output, fn):
    """Apply `fn` to the hidden state of a decoder-layer output.

    The output may be a bare tensor or a tuple whose first element is the
    hidden state; the remaining tuple elements are passed through unchanged.
    """
    if isinstance(output, tuple):
        return (fn(output[0]),) + output[1:]
    return fn(output)


def _noop_hook(module, inputs, output):
    """Forward hook that returns the layer output unchanged."""
    return output


class _LayerHookSteerer:
    """Shared bookkeeping: one forward hook per steered decoder layer.

    Subclasses set `self.model` and `self.handles`, and implement
    `_layer_ids()` and `_make_hook(layer_id)`. Hooks capture the steering
    parameters when `start()` runs; call `start()` again after changing them.
    Instances can also be used as a context manager (`with steerer: ...`),
    which removes the hooks even if the body raises.
    """

    def _layer_ids(self):
        raise NotImplementedError

    def _make_hook(self, layer_id: int):
        raise NotImplementedError

    def start(self):
        # Calling start() twice must not stack duplicate hooks, which would
        # apply the steering twice per forward pass.
        self.stop()
        try:
            for li in self._layer_ids():
                layer = self.model.model.layers[li]
                handle = layer.register_forward_hook(self._make_hook(li))
                self.handles.append(handle)
        except BaseException:
            # Don't leave a partial set of hooks registered on failure.
            self.stop()
            raise

    def stop(self):
        for h in self.handles:
            h.remove()
        self.handles = []

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.stop()
        return False


# ============================================================
# Single-dimension steerer (backward compatible)
# ============================================================

class ResidualSteerer(_LayerHookSteerer):
    """
    Apply projection decay steering to the post-layer residual at target layers.

        h_new = h - (1 - alpha) * P · h

    For a single direction, P · h = (h · ŵ) ŵ.
    For a subspace,         P · h = Q^T Q · h.

    Args:
        model: HF model; hooks are attached to `model.model.layers[i]`.
        directions: {layer_id: direction (D,) or basis (k, D)}.
        alpha: steering strength (1 = no change, 0 = full removal).
    """

    def __init__(
        self,
        model,
        directions: Dict[int, torch.Tensor],
        alpha: float = NEUTRAL_ALPHA,
    ):
        self.model = model
        self.directions = directions
        self.alpha = alpha
        self.handles = []
        self._device = next(model.parameters()).device
        self._dtype = next(model.parameters()).dtype

    def _layer_ids(self):
        return list(self.directions)

    def _make_hook(self, layer_id: int):
        proj = _make_projector(self.directions[layer_id], self._device, self._dtype)
        scale = 1.0 - float(self.alpha)
        if proj is None or abs(scale) < 1e-9:
            return _noop_hook

        def steer(h):
            return h - scale * proj(h)

        def fn(module, inputs, output):
            return _map_hidden(output, steer)

        return fn


# ============================================================
# Joint steerer with anti-leak coupling
# ============================================================

class JointResidualSteerer(_LayerHookSteerer):
    """
    Apply joint steering on TWO dimensions (planning + monitoring) simultaneously.

    Used to prevent compensatory activation when suppressing one dimension.

    Steering equation (both projections use the same un-steered h):

        h_new = h - (1-α_target) * P_target · h - (1-α_target) * β * P_other · h

    A layer may appear in only one of the two dicts; the missing term is
    simply skipped for that layer.

    Args:
        model: HF model; hooks are attached to `model.model.layers[i]`.
        target_dirs: {layer_id: direction or basis} - dimension being primarily steered
        other_dirs: {layer_id: direction or basis} - dimension being coupled (anti-leak)
        alpha: steering strength for target (NEW SEMANTICS, 1=no change, 0=full)
        beta: coupling factor for the other dimension (default ANTI_LEAK_BETA)
    """

    def __init__(
        self,
        model,
        target_dirs: Dict[int, torch.Tensor],
        other_dirs: Dict[int, torch.Tensor],
        alpha: float = NEUTRAL_ALPHA,
        beta: float = ANTI_LEAK_BETA,
    ):
        self.model = model
        self.target_dirs = target_dirs
        self.other_dirs = other_dirs
        self.alpha = alpha
        self.beta = beta
        self.handles = []
        self._device = next(model.parameters()).device
        self._dtype = next(model.parameters()).dtype

    def _layer_ids(self):
        return sorted(set(self.target_dirs) | set(self.other_dirs))

    def _make_hook(self, layer_id: int):
        def build(dirs):
            # A layer may be listed in only one of the two dicts.
            if layer_id not in dirs:
                return None
            return _make_projector(dirs[layer_id], self._device, self._dtype)

        target_proj = build(self.target_dirs)
        other_proj = build(self.other_dirs)
        scale_target = 1.0 - float(self.alpha)
        scale_other = scale_target * float(self.beta)

        terms = []
        if target_proj is not None and abs(scale_target) > 1e-9:
            terms.append((scale_target, target_proj))
        if other_proj is not None and abs(scale_other) > 1e-9:
            terms.append((scale_other, other_proj))
        if not terms:
            return _noop_hook

        def steer(h):
            # Every projection reads the original h, not the partially
            # steered state, so no P_other·P_target cross term appears.
            h_new = h
            for s, p in terms:
                h_new = h_new - s * p(h)
            return h_new

        def fn(module, inputs, output):
            return _map_hidden(output, steer)

        return fn
# ============================================================
# Force-prompt mechanism (kept for ablation comparison)
# ============================================================

# Extra system-prompt text that asks the model to avoid a behavior.
FORCE_SUPPRESS_PROMPTS = {
    "planning": (
        "IMPORTANT: Solve this problem WITHOUT planning, WITHOUT stating strategies, "
        "WITHOUT outlining steps in advance. Just compute directly."
    ),
    "monitoring": (
        "IMPORTANT: Solve this problem without double-checking, without verifying, "
        "without saying 'wait' or 'let me check'. Just produce the answer directly."
    ),
}

# Extra system-prompt text that asks the model to exhibit a behavior.
FORCE_ENHANCE_PROMPTS = {
    "planning": (
        "IMPORTANT: Before starting, explicitly state your plan. Break the problem "
        "into clearly labeled steps. Discuss multiple strategies and choose one. "
        "Reference your plan as you execute."
    ),
    "monitoring": (
        "IMPORTANT: After each step, verify your work. Say 'wait, let me check'. "
        "Substitute values back to confirm. Consider alternative interpretations."
    ),
}


def build_force_prompt(base_system_prompt: str, dimension: str, mode: str) -> str:
    """Append a prompt-level force instruction to a system prompt.

    Args:
        base_system_prompt: the system prompt to extend.
        dimension: key into the prompt tables (e.g. "planning", "monitoring").
        mode: "suppress" or "enhance"; any other value returns the base
            prompt unchanged (the dimension is then not looked up).

    Returns:
        The base prompt, followed by a blank line and the selected instruction.

    Raises:
        KeyError: if mode is "suppress"/"enhance" and `dimension` is unknown.
    """
    candidates = (
        ("suppress", FORCE_SUPPRESS_PROMPTS),
        ("enhance", FORCE_ENHANCE_PROMPTS),
    )
    for mode_name, prompts in candidates:
        if mode == mode_name:
            addition = prompts[dimension]
            return f"{base_system_prompt}\n\n{addition}"
    return base_system_prompt