# NOTE: repository-page header captured along with this file (not Python code):
#   abka03 — "Deploy StyleSteer-VLM demo" — commit e6f24ae (verified)
"""Activation steering hook — injects steering vectors into model residual stream.
This module provides the core mechanism for all steering methods:
a forward hook that adds α * v to the hidden state at a specified layer.
"""
import logging
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
import numpy as np
logger = logging.getLogger(__name__)
class SteeringHook:
    """Forward hook that adds a steering vector to the residual stream.

    The hook is registered on a single layer of ``model``. While it is
    active, ``alpha * vector`` is added to that layer's hidden-state output
    on every forward pass, at every token position.

    Note: once installed, the layer holds a reference to this object (via the
    registered callback), so steering stays in effect until ``remove()`` or
    ``deactivate()`` is called, even if the caller drops its own reference.

    Usage:
        hook = SteeringHook(model, layer_idx=16, hf_id="llava-hf/llava-1.5-7b-hf")
        hook.set_vector(v, alpha=1.5)
        # model forward pass — steering is applied automatically
        hook.remove()
    """

    def __init__(
        self,
        model: nn.Module,
        layer_idx: int,
        hf_id: str = "llava-hf/llava-1.5-7b-hf",
    ):
        """Locate the target layer; no hook is installed until ``set_vector``.

        Args:
            model: Model whose layer output should be steered.
            layer_idx: Index of the target layer. Negative values count from
                the end (``-1`` is the last layer).
            hf_id: HuggingFace model identifier (used in error messages only).

        Raises:
            ValueError: If no known layer path resolves for ``model``.
        """
        self.model = model
        self.layer_idx = layer_idx
        self.hf_id = hf_id
        self._vector: Optional[torch.Tensor] = None
        self._alpha: float = 1.0
        # Assigned before layer resolution so that cleanup in __del__ is safe
        # even when resolution raises.
        self._handle = None
        self._active = False
        # Get the target layer module
        self._layer_module = self._resolve_layer(model, layer_idx, hf_id)

    @staticmethod
    def _resolve_layer(model: nn.Module, layer_idx: int, hf_id: str) -> nn.Module:
        """Resolve the layer module from model architecture.

        Each candidate is the attribute path (relative to ``model``) of the
        indexable container holding the layers; the first container that
        exists and accepts ``layer_idx`` wins:
        - ``model.language_model.layers``  (Gemma-3, Qwen2.5-VL, LLaVA)
        - ``model.model.layers``           (LLaVA wrapped)
        - ``model.layers``                 (standard, e.g. Qwen2)
        - ``transformer.h``                (GPT-style)
        - ``language_model.model.layers``  (older conditional-generation
          wrappers that expose a full causal LM as ``language_model``)

        Args:
            model: Model to search.
            layer_idx: Layer index; negative values index from the end.
            hf_id: Model identifier, used only in the error message.

        Returns:
            The resolved layer module.

        Raises:
            ValueError: If none of the candidate paths resolves.
        """
        # Order matters — more specific first; the legacy fallback is last so
        # it can never shadow a path that resolved before it was added.
        container_paths = (
            "model.language_model.layers",
            "model.model.layers",
            "model.layers",
            "transformer.h",
            "language_model.model.layers",
        )

        for container_path in container_paths:
            try:
                container = model
                for attr in container_path.split("."):
                    container = getattr(container, attr)
                # Index the container directly instead of encoding the index
                # in a dotted path: this keeps negative indices working
                # ("-1".isdigit() is False, so a path-based lookup fails).
                module = container[int(layer_idx)]
            except (AttributeError, IndexError, TypeError, ValueError):
                continue
            logger.info(
                "Resolved layer %s at path: %s.%s",
                layer_idx,
                container_path,
                layer_idx,
            )
            return module

        raise ValueError(
            f"Could not resolve layer {layer_idx} for {hf_id}. "
            f"Run print_layer_names() to verify the correct path."
        )

    def set_vector(self, vector: np.ndarray, alpha: float = 1.0):
        """Set the steering vector and magnitude, then (re)install the hook.

        Calling this always leaves the hook installed and active.

        Args:
            vector: (d,) steering direction. A NumPy array is converted to a
                tensor; a ``torch.Tensor`` is also accepted as-is.
            alpha: steering magnitude (α)
        """
        if isinstance(vector, np.ndarray):
            vector = torch.from_numpy(vector)
        # Stored as float32; cast to the layer's dtype/device at hook time.
        self._vector = vector.float()
        self._alpha = alpha
        self._install_hook()

    def _hook_fn(self, module, input, output):
        """Forward hook that adds α * v to the hidden states.

        Returns ``output`` untouched when no vector is set or the hook is
        deactivated. Otherwise returns a new output in the same format as the
        one received: a tensor, or a tuple whose first element is replaced.
        """
        if self._vector is None or not self._active:
            return output

        # Handle different output formats: some layers return a tuple whose
        # first element is the hidden states, others return the tensor itself.
        is_tuple = isinstance(output, tuple)
        hidden_states = output[0] if is_tuple else output

        v = self._vector.to(device=hidden_states.device, dtype=hidden_states.dtype)

        # Add steering vector to all token positions.
        # hidden_states is expected to be (batch, seq_len, hidden_dim); the
        # (1, 1, d) view broadcasts over the batch and sequence axes.
        hidden_states = hidden_states + self._alpha * v.unsqueeze(0).unsqueeze(0)

        if is_tuple:
            return (hidden_states,) + output[1:]
        return hidden_states

    def _install_hook(self):
        """Install the forward hook on the target layer and activate it."""
        # Drop any previous registration so the vector is never added twice.
        if self._handle is not None:
            self._handle.remove()
        self._handle = self._layer_module.register_forward_hook(self._hook_fn)
        self._active = True
        logger.debug(
            "Steering hook installed at layer %s (α=%s)", self.layer_idx, self._alpha
        )

    def activate(self):
        """Activate the steering hook."""
        self._active = True

    def deactivate(self):
        """Temporarily deactivate without removing."""
        self._active = False

    def remove(self):
        """Remove the hook entirely. Safe to call more than once."""
        if self._handle is not None:
            self._handle.remove()
            self._handle = None
            logger.debug("Steering hook removed from layer %s", self.layer_idx)
        self._active = False

    def __del__(self):
        # A finalizer can run on an instance that never completed __init__
        # (e.g. created via __new__), where _handle does not exist; only
        # clean up when there is actually a hook to detach.
        if getattr(self, "_handle", None) is not None:
            self.remove()
def apply_steering(
    model: nn.Module,
    vector: Optional[np.ndarray],
    layer_idx: int,
    alpha: float,
    hf_id: str,
) -> Optional[SteeringHook]:
    """Build a SteeringHook for ``model`` and arm it in a single call.

    Args:
        model: The backbone model
        vector: Steering vector (d,), or None for prompt-only methods that
            need no activation steering
        layer_idx: Index of the layer whose output is steered
        alpha: Steering magnitude
        hf_id: HuggingFace model identifier

    Returns:
        The installed SteeringHook, or None when ``vector`` is None.
    """
    if vector is not None:
        steering = SteeringHook(model, layer_idx, hf_id)
        steering.set_vector(vector, alpha)
        return steering

    # Nothing to inject: the model is left untouched.
    return None