"""Activation steering hook — injects steering vectors into model residual stream.

This module provides the core mechanism for all steering methods: a forward
hook that adds α * v to the hidden state at a specified layer.
"""

import logging
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
import numpy as np

logger = logging.getLogger(__name__)

# Attribute paths (relative to the model passed to SteeringHook) of the
# container holding the decoder layers. They are tried in order and the first
# one that resolves wins, so more specific paths come first. The layer index is
# applied to the container separately, which also makes negative indices work.
_LAYER_CONTAINER_PATHS: Tuple[str, ...] = (
    "model.language_model.layers",  # Gemma-3, Qwen2.5-VL (conditional generation)
    "model.model.layers",  # LLaVA wrapped
    "model.layers",  # Standard (Qwen2, etc.)
    "transformer.h",  # GPT-style
    # Fallback for wrappers whose `language_model` attribute is itself a
    # *ForCausalLM (presumably LLaVA on older transformers releases) — verify
    # against the installed transformers version.
    "language_model.model.layers",
)


class SteeringHook:
    """Forward hook that adds a steering vector to the residual stream.

    Usage:
        hook = SteeringHook(model, layer_idx=16, hf_id="llava-hf/llava-1.5-7b-hf")
        hook.set_vector(v, alpha=1.5)
        # model forward pass — steering is applied automatically
        hook.remove()

    The hook can also be used as a context manager, which calls ``remove()``
    on exit:

        with SteeringHook(model, layer_idx=16) as hook:
            hook.set_vector(v, alpha=1.5)
            ...  # forward passes are steered here
    """

    def __init__(
        self,
        model: nn.Module,
        layer_idx: int,
        hf_id: str = "llava-hf/llava-1.5-7b-hf",
    ):
        self.model = model
        self.layer_idx = layer_idx
        self.hf_id = hf_id
        # float32 master copy of the steering direction; None until set_vector().
        self._vector: Optional[torch.Tensor] = None
        self._alpha: float = 1.0
        # Handle returned by register_forward_hook; None while no hook is installed.
        self._handle = None
        self._active = False

        # Get the target layer module (raises ValueError if it cannot be found).
        self._layer_module = self._resolve_layer(model, layer_idx, hf_id)

    @staticmethod
    def _resolve_layer(model: nn.Module, layer_idx: int, hf_id: str) -> nn.Module:
        """Return the decoder-layer module at ``layer_idx`` inside ``model``.

        Each entry of ``_LAYER_CONTAINER_PATHS`` is walked with ``getattr`` and
        the resulting container is indexed with ``layer_idx``. Negative values
        count from the end, as for any Python sequence / ``nn.ModuleList``.
        Paths that do not exist on this model, or whose container is too short,
        are skipped.

        Args:
            model: The backbone model.
            layer_idx: Index of the target layer.
            hf_id: HuggingFace model identifier (used only in the error message).

        Raises:
            ValueError: if no known path yields a layer at ``layer_idx``.
        """
        idx = int(layer_idx)
        for container_path in _LAYER_CONTAINER_PATHS:
            try:
                container = model
                for attr in container_path.split("."):
                    container = getattr(container, attr)
                module = container[idx]
            except (AttributeError, IndexError, TypeError):
                # Missing attribute, index out of range, or a non-indexable
                # object on this path — try the next candidate.
                continue
            logger.info(f"Resolved layer {layer_idx} at path: {container_path}.{layer_idx}")
            return module

        raise ValueError(
            f"Could not resolve layer {layer_idx} for {hf_id}. "
            f"Run print_layer_names() to verify the correct path."
        )

    def set_vector(
        self,
        vector: Union[np.ndarray, torch.Tensor],
        alpha: float = 1.0,
    ):
        """Set the steering vector and magnitude, and (re)install the hook.

        The vector is copied into a detached float32 tensor. Later in-place
        changes to the caller's array therefore do not alter the steering.

        Args:
            vector: (d,) steering direction — numpy array, torch tensor, or any
                sequence accepted by ``torch.as_tensor``. Singleton dimensions
                (e.g. shape (1, d)) are squeezed away.
            alpha: steering magnitude (α)

        Raises:
            ValueError: if the vector is not 1-D after squeezing.
        """
        tensor = torch.as_tensor(vector)
        original_shape = tuple(tensor.shape)
        if tensor.dim() != 1:
            tensor = tensor.squeeze()
        if tensor.dim() != 1:
            raise ValueError(
                f"Steering vector must be 1-D with shape (d,), got shape {original_shape}"
            )
        # detach: never drag an autograd graph into the forward hook;
        # clone: torch.as_tensor may share memory with a numpy input.
        self._vector = tensor.detach().to(dtype=torch.float32).clone()
        self._alpha = alpha
        self._install_hook()

    def _hook_fn(self, module, inputs, output):
        """Forward hook that adds α * v to the hidden states."""
        vec = self._vector
        if vec is None or not self._active:
            return output

        # Layers may return either a bare tensor or a tuple whose first
        # element is the hidden states.
        hidden_states = output[0] if isinstance(output, tuple) else output

        if vec.device != hidden_states.device:
            # Move the float32 master copy once, so later forward calls (e.g.
            # every step of generation) skip the host->device transfer.
            vec = vec.to(device=hidden_states.device)
            self._vector = vec

        # Scale in float32 and round to the activation dtype only once.
        delta = (self._alpha * vec).to(dtype=hidden_states.dtype)

        # The (d,) delta broadcasts over all leading dimensions — every token
        # position of (batch, seq_len, d) or (seq_len, d) — without changing
        # the rank of the hidden states.
        steered = hidden_states + delta

        if isinstance(output, tuple):
            return (steered,) + output[1:]
        return steered

    def _install_hook(self):
        """Install the forward hook on the target layer."""
        if self._handle is not None:
            # Avoid stacking multiple hooks when set_vector() is called again.
            self._handle.remove()
        self._handle = self._layer_module.register_forward_hook(self._hook_fn)
        self._active = True
        logger.debug(f"Steering hook installed at layer {self.layer_idx} (α={self._alpha})")

    def activate(self):
        """Activate the steering hook."""
        self._active = True

    def deactivate(self):
        """Temporarily deactivate without removing."""
        self._active = False

    def remove(self):
        """Remove the hook entirely."""
        if self._handle is not None:
            self._handle.remove()
            self._handle = None
        self._active = False
        logger.debug(f"Steering hook removed from layer {self.layer_idx}")

    def __enter__(self) -> "SteeringHook":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always detach the hook; exceptions from the with-body propagate.
        self.remove()

    def __del__(self):
        # Best-effort cleanup: __init__ may not have completed, and module
        # globals such as `logger` may already be torn down at interpreter
        # shutdown. Exceptions escaping __del__ could not be handled anyway.
        try:
            self.remove()
        except Exception:
            pass


def apply_steering(
    model: nn.Module,
    vector: Optional[Union[np.ndarray, torch.Tensor]],
    layer_idx: int,
    alpha: float,
    hf_id: str,
) -> Optional[SteeringHook]:
    """Convenience function to apply a steering vector.

    Args:
        model: The backbone model
        vector: Steering vector (d,) or None (for prompt-only methods)
        layer_idx: Target layer index
        alpha: Steering magnitude
        hf_id: HuggingFace model identifier

    Returns:
        An installed, active SteeringHook, or None if vector is None.
    """
    if vector is None:
        return None

    hook = SteeringHook(model, layer_idx, hf_id)
    hook.set_vector(vector, alpha)
    return hook