abka03's picture
Deploy StyleSteer-VLM demo
e6f24ae verified
"""M6 — SAE: Top mean-activation Sparse Autoencoder feature.
Stub implementation — requires pretrained SAE for LLaVA-1.5.
Currently always raises NotImplementedError from extract_vector: feature
extraction is not implemented yet, even when a SAE checkpoint is provided.
"""
from typing import Optional
import numpy as np
from src.methods.base import SteeringMethod
class SAE(SteeringMethod):
    """Placeholder for the M6 steering method based on a Sparse Autoencoder.

    The instance only records the checkpoint path it was given; no checkpoint
    is loaded here. ``extract_vector`` raises ``NotImplementedError`` in every
    case, with a message that depends on whether a checkpoint was supplied.
    """

    def __init__(self, sae_checkpoint: Optional[str] = None, **kwargs):
        """Store the checkpoint path and warn when none was given.

        Args:
            sae_checkpoint: Path (or identifier) of a pretrained SAE
                checkpoint, or ``None`` when unavailable.
            **kwargs: Accepted and ignored by this stub.
        """
        self.sae_checkpoint = sae_checkpoint
        if sae_checkpoint is not None:
            return

        # Imported lazily: logging is only needed on the no-checkpoint path.
        import logging

        log = logging.getLogger(__name__)
        log.warning(
            "SAE (M6) requires a pretrained SAE checkpoint. "
            "Skipped unless --sae-checkpoint is provided."
        )

    @property
    def name(self) -> str:
        """Human-readable method name."""
        return "SAE"

    @property
    def method_id(self) -> str:
        """Short method identifier."""
        return "M6"

    def extract_vector(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        """Always raise: SAE feature extraction is not implemented.

        Args:
            h_pos: Unused by this stub (presumably activations for the
                positive class; confirm against ``SteeringMethod``).
            h_neg: Unused by this stub (presumably activations for the
                negative class; confirm against ``SteeringMethod``).
            **kwargs: Accepted and ignored.

        Raises:
            NotImplementedError: Always. The message says a checkpoint is
                required when ``self.sae_checkpoint`` is ``None``, and that
                extraction is not yet implemented otherwise.
        """
        if self.sae_checkpoint is not None:
            # TODO: Implement SAE feature selection when checkpoint available
            not_ready = "SAE feature extraction not yet implemented."
            raise NotImplementedError(not_ready)

        missing_checkpoint = (
            "SAE (M6) requires a pretrained SAE checkpoint for LLaVA-1.5. "
            "Provide --sae-checkpoint or skip M6."
        )
        raise NotImplementedError(missing_checkpoint)