Spaces:
Sleeping
Sleeping
| """Method registry — maps M0–M11 to their implementation classes.""" | |
| from typing import Dict, List, Optional, Type | |
# Method ID -> dotted import path of the implementing class.
# Paths are stored as strings (not class objects) so that no method module
# is imported until that method is actually requested.
_REGISTRY: Dict[str, str] = {
    "M0": "src.methods.prompt_base.PromptBase",
    "M1": "src.methods.captionsmiths.CaptionSmiths",
    "M2": "src.methods.diffmean.DiffMean",
    "M3": "src.methods.pca_dir.PCADir",
    "M4": "src.methods.lat.LAT",
    "M5": "src.methods.linear_probe.LinearProbe",
    "M6": "src.methods.sae.SAE",
    "M7": "src.methods.sae_auroc.SAEAUROC",
    "M8": "src.methods.reft.ReFT",
    "M9": "src.methods.reps.RePS",
    "M10": "src.methods.hypersteer.HyperSteer",
    "M11": "src.methods.snmf_diff.SNMFDiff",
}

# Every registered ID, in registry (insertion) order.
ALL_IDS = list(_REGISTRY)

# Subset of IDs treated as "training-free" by the selection helpers.
TRAINING_FREE_IDS = ["M0", "M2", "M3", "M4", "M5", "M11"]

# SAE methods need a pretrained checkpoint.
# NOTE(review): nothing in this section applies this list automatically;
# presumably callers pass it as ``skip=`` — confirm.
SKIP_DEFAULT = ["M6", "M7"]
def _import_class(dotted_path: str):
    """Dynamically import a class from a dotted path.

    Args:
        dotted_path: Fully qualified name of the form
            ``"package.module.ClassName"``; everything before the last dot
            is imported as a module and the final component is looked up
            on it.

    Returns:
        The attribute named by the last path component of the imported module.

    Raises:
        ValueError: If ``dotted_path`` has no module part before a final dot.
        ImportError: If the module part cannot be imported.
        AttributeError: If the imported module has no such attribute.
    """
    import importlib

    module_path, sep, class_name = dotted_path.rpartition(".")
    # Fail with a message naming the bad path instead of the opaque
    # "not enough values to unpack" that a bare two-way split would give.
    if not sep or not module_path:
        raise ValueError(
            f"Invalid dotted path {dotted_path!r}: expected 'package.module.ClassName'"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
def get_method_class(method_id: str):
    """Resolve a method ID to its implementing class (e.g., 'M2' → DiffMean).

    Args:
        method_id: A key of the registry, such as ``"M2"``.

    Returns:
        The class imported from the dotted path registered for ``method_id``.

    Raises:
        ValueError: If ``method_id`` is not a registered ID.
    """
    dotted_path = _REGISTRY.get(method_id)
    if dotted_path is None:
        raise ValueError(f"Unknown method: {method_id}. Available: {list(_REGISTRY.keys())}")
    # The module is only imported here, on first request for this ID.
    return _import_class(dotted_path)
def get_method(method_id: str, **kwargs):
    """Build an instance of the method registered under ``method_id``.

    Args:
        method_id: Registry key of the method to instantiate.
        **kwargs: Forwarded unchanged to the method class constructor.

    Returns:
        A new instance of the resolved method class.
    """
    return get_method_class(method_id)(**kwargs)
def get_methods(
    mode: str = "training_free",
    skip: Optional[List[str]] = None,
    include_only: Optional[List[str]] = None,
) -> List:
    """Instantiate a selection of registered methods.

    Args:
        mode: "training_free" selects ``TRAINING_FREE_IDS``; any other value
            selects ``ALL_IDS``. Ignored when ``include_only`` is given.
        skip: Method IDs to leave out (e.g., ["M6", "M7"]).
        include_only: If not None, exactly these IDs are considered
            (``skip`` is still applied to them).

    Returns:
        List of (method_id, method_instance) tuples, in selection order.
        Methods whose lookup or construction raises ``ImportError`` or
        ``NotImplementedError`` are logged as a warning and omitted.
    """
    if include_only is not None:
        candidate_ids = include_only
    else:
        candidate_ids = TRAINING_FREE_IDS if mode == "training_free" else ALL_IDS

    excluded = skip or []
    instantiated = []
    for mid in candidate_ids:
        if mid not in excluded:
            try:
                instance = get_method(mid)
            except (ImportError, NotImplementedError) as e:
                # Best-effort: an unavailable method is reported, not fatal.
                import logging

                logging.getLogger(__name__).warning(f"Skipping {mid}: {e}")
            else:
                instantiated.append((mid, instance))
    return instantiated
def list_methods(mode: str = "training_free") -> List[str]:
    """Return the method IDs selected by ``mode``.

    Args:
        mode: "training_free" selects ``TRAINING_FREE_IDS``; any other value
            selects ``ALL_IDS``.

    Returns:
        A fresh copy of the selected ID list, so callers can mutate it
        without affecting the module-level constant.
    """
    selected = TRAINING_FREE_IDS if mode == "training_free" else ALL_IDS
    return selected.copy()