"""Activation steering hook — injects steering vectors into model residual stream.
This module provides the core mechanism for all steering methods:
a forward hook that adds α * v to the hidden state at a specified layer.
"""
import logging
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
import numpy as np
logger = logging.getLogger(__name__)
class SteeringHook:
    """Forward hook that adds a steering vector to the residual stream.

    Adds ``α * v`` to the hidden states emitted by one transformer layer,
    at every token position — the shared mechanism behind the steering
    methods in this package.

    Usage:
        hook = SteeringHook(model, layer_idx=16, hf_id="llava-hf/llava-1.5-7b-hf")
        hook.set_vector(v, alpha=1.5)
        # model forward pass — steering is applied automatically
        hook.remove()
    """

    def __init__(
        self,
        model: nn.Module,
        layer_idx: int,
        hf_id: str = "llava-hf/llava-1.5-7b-hf",
    ):
        """
        Args:
            model: Backbone model containing the target layer.
            layer_idx: Index of the decoder layer whose output is steered.
            hf_id: HuggingFace model identifier (used in error messages).

        Raises:
            ValueError: if the layer cannot be resolved for this architecture.
        """
        self.model = model
        self.layer_idx = layer_idx
        self.hf_id = hf_id
        self._vector: Optional[torch.Tensor] = None  # (d,) direction, float32
        self._alpha: float = 1.0  # steering magnitude α
        self._handle = None  # torch RemovableHandle while installed
        self._active = False  # lets callers pause steering without re-installing
        # Resolve eagerly so a bad layer_idx / unknown architecture fails at
        # construction rather than mid-generation.
        self._layer_module = self._resolve_layer(model, layer_idx, hf_id)

    @staticmethod
    def _resolve_layer(model: nn.Module, layer_idx: int, hf_id: str) -> nn.Module:
        """Resolve the layer module from model architecture.

        Handles different backbone architectures:
        - LLaVA-1.5: model.model.layers.{idx} (within LlavaForConditionalGeneration)
        - Qwen2.5-VL: model.language_model.layers.{idx}
        - Gemma-3: model.language_model.layers.{idx}

        Raises:
            ValueError: if none of the known attribute paths exist on *model*.
        """
        # Try different paths (order matters — more specific first)
        paths_to_try = [
            f"model.language_model.layers.{layer_idx}",  # Gemma-3, Qwen2.5-VL (conditional generation)
            f"model.model.layers.{layer_idx}",  # LLaVA wrapped
            f"model.layers.{layer_idx}",  # Standard (Qwen2, etc.)
            f"transformer.h.{layer_idx}",  # GPT-style
        ]
        for path in paths_to_try:
            try:
                module = model
                for attr in path.split("."):
                    # Numeric components index into ModuleList containers.
                    if attr.isdigit():
                        module = module[int(attr)]
                    else:
                        module = getattr(module, attr)
                logger.info(f"Resolved layer {layer_idx} at path: {path}")
                return module
            except (AttributeError, IndexError, TypeError):
                continue
        raise ValueError(
            f"Could not resolve layer {layer_idx} for {hf_id}. "
            f"Run print_layer_names() to verify the correct path."
        )

    def set_vector(self, vector: "np.ndarray | torch.Tensor", alpha: float = 1.0):
        """Set the steering vector and magnitude, (re)installing the hook.

        Args:
            vector: (d,) steering direction (numpy array or torch tensor)
            alpha: steering magnitude (α)
        """
        if isinstance(vector, np.ndarray):
            vector = torch.from_numpy(vector)
        # Detach so a vector taken from a live computation graph does not keep
        # that autograd graph alive for the lifetime of the hook.
        self._vector = vector.detach().float()
        self._alpha = float(alpha)
        self._install_hook()

    def _hook_fn(self, module, inputs, output):
        """Forward hook that adds α * v to the hidden states."""
        # No-op passthrough when unset or deactivated.
        if self._vector is None or not self._active:
            return output
        # Decoder layers may return a tuple (hidden_states, ...) — steer
        # element 0 and re-wrap below.
        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output
        # Match the layer's device/dtype (model may run in fp16/bf16).
        device = hidden_states.device
        dtype = hidden_states.dtype
        v = self._vector.to(device=device, dtype=dtype)
        # Add steering vector to all token positions
        # hidden_states shape: (batch, seq_len, hidden_dim)
        hidden_states = hidden_states + self._alpha * v.unsqueeze(0).unsqueeze(0)
        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        return hidden_states

    def _install_hook(self):
        """Install the forward hook on the target layer (idempotent)."""
        # Drop any previous handle so the hook never fires twice per forward.
        if self._handle is not None:
            self._handle.remove()
        self._handle = self._layer_module.register_forward_hook(self._hook_fn)
        self._active = True
        logger.debug(f"Steering hook installed at layer {self.layer_idx} (α={self._alpha})")

    def activate(self):
        """Activate the steering hook."""
        self._active = True

    def deactivate(self):
        """Temporarily deactivate without removing."""
        self._active = False

    def remove(self):
        """Remove the hook entirely."""
        if self._handle is not None:
            self._handle.remove()
            self._handle = None
        self._active = False
        logger.debug(f"Steering hook removed from layer {self.layer_idx}")

    def __del__(self):
        # __del__ must never raise: during interpreter shutdown the logging
        # machinery or our own attributes may already be torn down.
        try:
            self.remove()
        except Exception:
            pass
def apply_steering(
model: nn.Module,
vector: Optional[np.ndarray],
layer_idx: int,
alpha: float,
hf_id: str,
) -> Optional[SteeringHook]:
"""Convenience function to apply a steering vector.
Args:
model: The backbone model
vector: Steering vector (d,) or None (for prompt-only methods)
layer_idx: Target layer index
alpha: Steering magnitude
hf_id: HuggingFace model identifier
Returns:
SteeringHook (or None if vector is None)
"""
if vector is None:
return None
hook = SteeringHook(model, layer_idx, hf_id)
hook.set_vector(vector, alpha)
return hook