Spaces:
Running
Running
| """Activation steering hook — injects steering vectors into model residual stream. | |
| This module provides the core mechanism for all steering methods: | |
| a forward hook that adds α * v to the hidden state at a specified layer. | |
| """ | |
| import logging | |
| from typing import Callable, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
class SteeringHook:
    """Forward hook that adds a steering vector to the residual stream.

    Registers a forward hook on one decoder layer and adds alpha * v to
    that layer's hidden-state output on every forward pass.

    Usage:
        hook = SteeringHook(model, layer_idx=16, hf_id="llava-hf/llava-1.5-7b-hf")
        hook.set_vector(v, alpha=1.5)
        # model forward pass — steering is applied automatically
        hook.remove()
    """

    def __init__(
        self,
        model: nn.Module,
        layer_idx: int,
        hf_id: str = "llava-hf/llava-1.5-7b-hf",
    ):
        """Initialize the hook and resolve the target layer.

        Args:
            model: Backbone model containing the decoder layers.
            layer_idx: Index of the layer whose output is steered.
            hf_id: HuggingFace model identifier (used in error messages).

        Raises:
            ValueError: If no known attribute path resolves the layer.
        """
        self.model = model
        self.layer_idx = layer_idx
        self.hf_id = hf_id
        self._vector: Optional[torch.Tensor] = None
        self._alpha: float = 1.0
        self._handle = None  # RemovableHandle from register_forward_hook
        self._active = False
        # Resolve the target layer module up front so a bad layer_idx or an
        # unknown architecture fails fast at construction time.
        self._layer_module = self._resolve_layer(model, layer_idx, hf_id)

    @staticmethod
    def _resolve_layer(model: nn.Module, layer_idx: int, hf_id: str) -> nn.Module:
        """Resolve the layer module from model architecture.

        Declared as a @staticmethod: it reads no instance state.
        (BUG FIX: this was previously defined without ``self`` but invoked as
        an instance method, which shifted every argument by one — ``model``
        received the instance, ``layer_idx`` received the model, etc.)

        Handles different backbone architectures:
        - LLaVA-1.5: model.model.layers.{idx} (within LlavaForConditionalGeneration)
        - Qwen2.5-VL: model.language_model.layers.{idx}
        - Gemma-3: model.language_model.layers.{idx}

        Raises:
            ValueError: If none of the known attribute paths match.
        """
        # Try different paths (order matters — more specific first)
        paths_to_try = [
            f"model.language_model.layers.{layer_idx}",  # Gemma-3, Qwen2.5-VL (conditional generation)
            f"model.model.layers.{layer_idx}",  # LLaVA wrapped
            f"model.layers.{layer_idx}",  # Standard (Qwen2, etc.)
            f"transformer.h.{layer_idx}",  # GPT-style
        ]
        for path in paths_to_try:
            try:
                module = model
                for attr in path.split("."):
                    if attr.isdigit():
                        # Numeric segment indexes into a ModuleList
                        module = module[int(attr)]
                    else:
                        module = getattr(module, attr)
                logger.info("Resolved layer %d at path: %s", layer_idx, path)
                return module
            except (AttributeError, IndexError, TypeError):
                continue
        raise ValueError(
            f"Could not resolve layer {layer_idx} for {hf_id}. "
            f"Run print_layer_names() to verify the correct path."
        )

    def set_vector(self, vector: np.ndarray, alpha: float = 1.0):
        """Set the steering vector and magnitude, and (re)install the hook.

        Args:
            vector: (d,) steering direction (numpy array or torch tensor)
            alpha: steering magnitude (α)
        """
        if isinstance(vector, np.ndarray):
            vector = torch.from_numpy(vector)
        self._vector = vector.float()
        self._alpha = alpha
        self._install_hook()

    def _hook_fn(self, module, input, output):
        """Forward hook that adds α * v to the hidden states."""
        # No-op when no vector is set or the hook is deactivated.
        if self._vector is None or not self._active:
            return output
        # Handle different output formats: decoder layers often return a
        # tuple whose first element is the hidden states.
        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output
        device = hidden_states.device
        dtype = hidden_states.dtype
        # Move/cast lazily so the vector follows the model's device/dtype.
        v = self._vector.to(device=device, dtype=dtype)
        # Add steering vector to all token positions
        # hidden_states shape: (batch, seq_len, hidden_dim)
        hidden_states = hidden_states + self._alpha * v.unsqueeze(0).unsqueeze(0)
        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        return hidden_states

    def _install_hook(self):
        """Install the forward hook on the target layer (replacing any old one)."""
        if self._handle is not None:
            self._handle.remove()
        self._handle = self._layer_module.register_forward_hook(self._hook_fn)
        self._active = True
        logger.debug("Steering hook installed at layer %d (α=%s)", self.layer_idx, self._alpha)

    def activate(self):
        """Activate the steering hook."""
        self._active = True

    def deactivate(self):
        """Temporarily deactivate without removing."""
        self._active = False

    def remove(self):
        """Remove the hook entirely."""
        if self._handle is not None:
            self._handle.remove()
            self._handle = None
            self._active = False
            logger.debug("Steering hook removed from layer %d", self.layer_idx)

    def __del__(self):
        # Best-effort cleanup: guard against partially-initialized instances
        # and module teardown order during interpreter shutdown, where
        # attributes or globals referenced by remove() may already be gone.
        try:
            self.remove()
        except Exception:
            pass
def apply_steering(
    model: nn.Module,
    vector: Optional[np.ndarray],
    layer_idx: int,
    alpha: float,
    hf_id: str,
) -> "Optional[SteeringHook]":
    """Attach a steering hook to *model* in a single call.

    Args:
        model: The backbone model
        vector: Steering vector (d,) or None (for prompt-only methods)
        layer_idx: Target layer index
        alpha: Steering magnitude
        hf_id: HuggingFace model identifier

    Returns:
        SteeringHook (or None if vector is None)
    """
    # Prompt-only methods carry no activation-space direction: nothing to install.
    if vector is None:
        return None
    steering = SteeringHook(model, layer_idx, hf_id)
    steering.set_vector(vector, alpha)
    return steering