abka03's picture
Deploy StyleSteer-VLM demo
e6f24ae verified
"""Abstract base class for all steering methods."""
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
class SteeringMethod(ABC):
    """Common interface shared by every activation-steering method.

    A concrete method turns paired positive/negative probing activations
    (``H_pos``, ``H_neg``) into a single steering direction by implementing
    :meth:`extract_vector`. Subclasses must also provide the :attr:`name`
    and :attr:`method_id` properties. Methods that need fitting should
    override :attr:`is_training_free` and :meth:`train`.
    """

    # ------------------------------------------------------------------
    # Identity
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def name(self) -> str:
        """Return a human-readable name for the method."""
        ...

    @property
    @abstractmethod
    def method_id(self) -> str:
        """Return the short method identifier (one of ``M0`` … ``M11``)."""
        ...

    @property
    def is_training_free(self) -> bool:
        """Report whether the method can be used without any training.

        The base class answers ``True``. Subclasses that must be fitted
        via :meth:`train` should override this.
        """
        return True

    # ------------------------------------------------------------------
    # Core API
    # ------------------------------------------------------------------

    @abstractmethod
    def extract_vector(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        **kwargs,
    ) -> Optional[np.ndarray]:
        """Derive a steering direction from contrasting probe activations.

        Args:
            h_pos: Hidden states of the positive examples, shape ``(N, d)``.
            h_neg: Hidden states of the negative examples, shape ``(N, d)``.
            **kwargs: Additional method-specific keyword arguments.

        Returns:
            A steering vector of shape ``(d,)``, or ``None`` for
            prompt-only methods that do not produce a vector.
        """
        ...

    def train(self, train_data: dict) -> None:
        """Fit the method on ``train_data`` (only needed if not training-free).

        The base implementation intentionally does nothing. Per the
        original note, M8, M9 and M10 override it.
        """

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (
            f"{cls_name}(id={self.method_id}, "
            f"training_free={self.is_training_free})"
        )