Spaces:
Running
Running
File size: 1,502 Bytes
"""Abstract base class for all steering methods."""
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
class SteeringMethod(ABC):
    """Common interface shared by every activation steering method.

    A concrete method supplies `name`, `method_id` and `extract_vector()`.
    The latter receives the probing data (H_pos, H_neg) and produces the
    steering direction vector.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the human-readable name of the method."""

    @property
    @abstractmethod
    def method_id(self) -> str:
        """Return the method identifier (M0–M11)."""

    @property
    def is_training_free(self) -> bool:
        """Tell whether the method works without any training.

        The base implementation answers True; subclasses that need
        training are expected to override this property.
        """
        return True

    @abstractmethod
    def extract_vector(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        **kwargs,
    ) -> Optional[np.ndarray]:
        """Derive the steering vector from the probing data.

        Args:
            h_pos: Positive hidden states, shape (N, d).
            h_neg: Negative hidden states, shape (N, d).
            **kwargs: Extra method-specific options; unused by the base
                class.

        Returns:
            The steering vector of shape (d,), or None for prompt-only
            methods.
        """

    def train(self, train_data: dict) -> None:
        """Fit the method on `train_data` (non-training-free methods only).

        The base implementation does nothing. M8, M9 and M10 override it.
        """
        return None

    def __repr__(self) -> str:
        # Read the attributes in the same order they appear in the output.
        cls_name = self.__class__.__name__
        ident = self.method_id
        no_training = self.is_training_free
        return f"{cls_name}(id={ident}, training_free={no_training})"
|