abka03's picture
Deploy StyleSteer-VLM demo
e6f24ae verified
"""M8 — ReFT: Rank-1 Representation Finetuning.
Lightly-trained method — trains a rank-1 intervention on 200 style-captioned pairs.
"""
import numpy as np
from src.methods.base import SteeringMethod
class ReFT(SteeringMethod):
    """Steering method M8: ReFT-r1 (rank-1 representation finetuning).

    The instance holds an optional learned direction. Until one is set,
    ``extract_vector`` delegates to the DiffMean method so callers still
    receive a usable vector.
    """

    def __init__(self, rank: int = 1, **kwargs):
        """Store the intervention rank and start in the untrained state.

        Args:
            rank: Rank of the intervention (defaults to 1).
            **kwargs: Accepted for signature compatibility; not used here.
        """
        self.rank = rank
        # No learned direction yet; extract_vector checks this for None.
        self._direction: "np.ndarray | None" = None
        # Flag initialised to False; nothing in this class flips it yet.
        self._trained = False

    @property
    def name(self) -> str:
        """Human-readable method name."""
        return "ReFT-r1"

    @property
    def method_id(self) -> str:
        """Short identifier of this method."""
        return "M8"

    @property
    def is_training_free(self) -> bool:
        """Always ``False``: this method is declared as requiring training."""
        return False

    def extract_vector(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        """Return the steering vector for this method.

        Args:
            h_pos: Activations for the positive examples.
            h_neg: Activations for the negative examples.
            **kwargs: Accepted but not forwarded to the fallback.

        Returns:
            The stored direction when one exists; otherwise whatever
            ``DiffMean().extract_vector(h_pos, h_neg)`` returns.
        """
        if self._direction is None:
            # Nothing learned yet: use the diff-mean direction as the
            # initialisation. Imported lazily, only on this path.
            from src.methods.diffmean import DiffMean

            return DiffMean().extract_vector(h_pos, h_neg)
        return self._direction

    def train(self, train_data: dict) -> None:
        """Train the ReFT intervention (currently a stub).

        Args:
            train_data: dict with keys:
                - "model": the backbone model
                - "processor": the tokenizer/processor
                - "train_pairs": list of (image, caption, style) tuples
                - "layer": target layer index
                - "epochs": number of training epochs

        Raises:
            NotImplementedError: Always; training is not implemented yet.
        """
        # TODO: Implement rank-1 low-rank linear subspace intervention,
        # following Wu et al. and the AxBench ReFT implementation.
        # NOTE(review): the earlier note cited Wu et al. as
        # "Representation Engineering"; presumably the ReFT paper
        # ("ReFT: Representation Finetuning for Language Models") was
        # meant — confirm the reference.
        message = (
            "ReFT training requires model access. "
            "Implement in full experiment run with GPU."
        )
        raise NotImplementedError(message)