stylsteer-vlm / src /utils /registry.py
abka03's picture
Deploy StyleSteer-VLM demo
e6f24ae verified
"""Method registry — maps M0–M11 to their implementation classes."""
from typing import Dict, List, Optional, Type
# Lazy registry: each method ID maps to the dotted "module.ClassName" path of
# its implementation. Storing strings (rather than the classes themselves)
# means no method module is imported until get_method_class() resolves it.
_REGISTRY: Dict[str, str] = {
    "M0": "src.methods.prompt_base.PromptBase",
    "M1": "src.methods.captionsmiths.CaptionSmiths",
    "M2": "src.methods.diffmean.DiffMean",
    "M3": "src.methods.pca_dir.PCADir",
    "M4": "src.methods.lat.LAT",
    "M5": "src.methods.linear_probe.LinearProbe",
    "M6": "src.methods.sae.SAE",
    "M7": "src.methods.sae_auroc.SAEAUROC",
    "M8": "src.methods.reft.ReFT",
    "M9": "src.methods.reps.RePS",
    "M10": "src.methods.hypersteer.HyperSteer",
    "M11": "src.methods.snmf_diff.SNMFDiff",
}

# IDs selected by the default "training_free" mode of get_methods/list_methods.
TRAINING_FREE_IDS = ["M0", "M2", "M3", "M4", "M5", "M11"]

# Every registered ID, in registry (insertion) order.
ALL_IDS = list(_REGISTRY)

# SAE methods need pretrained checkpoint.
# NOTE(review): get_methods does not apply this automatically; pass it
# explicitly as ``skip=SKIP_DEFAULT`` when those checkpoints are unavailable.
SKIP_DEFAULT = ["M6", "M7"]
def _import_class(dotted_path: str):
    """Import and return the object named by ``dotted_path``.

    ``"pkg.mod.Name"`` is split at its last dot: ``pkg.mod`` is imported with
    :func:`importlib.import_module` and its ``Name`` attribute is returned.
    A module that cannot be found raises ``ModuleNotFoundError`` (an
    ``ImportError``); a missing attribute raises ``AttributeError``.
    """
    target_module, attr_name = dotted_path.rsplit(".", 1)
    import importlib

    return getattr(importlib.import_module(target_module), attr_name)
def get_method_class(method_id: str):
    """Resolve a method ID (e.g. ``'M2'``) to its implementation class.

    The implementing module is only imported at this point, via the dotted
    path stored in the registry.

    Raises:
        ValueError: if ``method_id`` is not a registered ID.
    """
    dotted_path = _REGISTRY.get(method_id)
    if dotted_path is None:
        raise ValueError(f"Unknown method: {method_id}. Available: {list(_REGISTRY.keys())}")
    return _import_class(dotted_path)
def get_method(method_id: str, **kwargs):
    """Create an instance of the method registered under ``method_id``.

    All keyword arguments are forwarded unchanged to the class constructor.

    Raises:
        ValueError: if ``method_id`` is not registered (from get_method_class).
    """
    return get_method_class(method_id)(**kwargs)
def get_methods(
    mode: str = "training_free",
    skip: Optional[List[str]] = None,
    include_only: Optional[List[str]] = None,
) -> List:
    """Instantiate a batch of registered methods.

    Methods whose construction raises ``ImportError`` or
    ``NotImplementedError`` are logged as a warning and left out of the
    result, so one unavailable method does not abort the whole batch.

    Args:
        mode: ``"training_free"`` selects ``TRAINING_FREE_IDS``; any other
            value selects every registered ID (``ALL_IDS``). Ignored when
            ``include_only`` is given.
        skip: Method IDs to leave out (e.g. ``["M6", "M7"]``). A single ID
            string is also accepted and is matched exactly.
        include_only: If set, exactly these method IDs (in this order) are
            considered instead of the ``mode`` selection. A single ID string
            is also accepted.

    Returns:
        List of ``(method_id, method_instance)`` tuples, in selection order.

    Raises:
        ValueError: propagated from ``get_method`` for an unregistered ID.
    """
    import logging

    logger = logging.getLogger(__name__)

    if include_only is not None:
        # A bare string would otherwise be iterated character by character.
        ids = [include_only] if isinstance(include_only, str) else include_only
    elif mode == "training_free":
        ids = TRAINING_FREE_IDS
    else:
        ids = ALL_IDS

    # Normalize to a set: exact-ID matching (a bare string would make
    # `"M1" in "M10"` a substring hit) and O(1) membership tests.
    if skip is None:
        skipped = set()
    elif isinstance(skip, str):
        skipped = {skip}
    else:
        skipped = set(skip)

    results = []
    for mid in ids:
        if mid in skipped:
            continue
        try:
            method = get_method(mid)
        except (ImportError, NotImplementedError) as e:
            # Best-effort: report and move on to the next method.
            logger.warning("Skipping %s: %s", mid, e)
            continue
        results.append((mid, method))
    return results
def list_methods(mode: str = "training_free") -> List[str]:
    """Return a fresh list of method IDs for ``mode``.

    ``"training_free"`` yields ``TRAINING_FREE_IDS``; any other value yields
    ``ALL_IDS``. The result is a new list, so callers may mutate it without
    affecting the module-level ID lists.
    """
    selected = TRAINING_FREE_IDS if mode == "training_free" else ALL_IDS
    return list(selected)