import os
from functools import lru_cache
from typing import Callable, Dict, Optional

from .base import Segmenter
from .sam3 import SAM3Segmenter

DEFAULT_SEGMENTER = "sam3"

# Maps segmenter names to factories. Factories are invoked with whatever
# keyword arguments the caller supplies (e.g. ``device=...`` from
# ``load_segmenter_on_device``), hence ``Callable[..., Segmenter]`` rather
# than a zero-argument callable.
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
    "sam3": SAM3Segmenter,
}


def _resolve_name(name: Optional[str]) -> str:
    """Return the segmenter name to use.

    Precedence: explicit *name*, then the ``SEGMENTER`` environment
    variable, then ``DEFAULT_SEGMENTER``. An unset *or empty* ``SEGMENTER``
    falls back to the default instead of being looked up as ``""``.
    """
    return name or os.getenv("SEGMENTER") or DEFAULT_SEGMENTER


def _create_segmenter(name: str, **kwargs) -> Segmenter:
    """Create a new, uncached segmenter instance.

    Args:
        name: Registry key of the segmenter to build.
        **kwargs: Forwarded unchanged to the registered factory.

    Returns:
        A freshly constructed segmenter.

    Raises:
        ValueError: If *name* is not a registered segmenter.
    """
    try:
        factory = _REGISTRY[name]
    except KeyError as exc:
        available = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unknown segmenter '{name}'. Available: {available}"
        ) from exc
    return factory(**kwargs)


@lru_cache(maxsize=None)
def _get_cached_segmenter(name: str) -> Segmenter:
    """Return the process-wide cached segmenter for *name*.

    The cache is keyed by name only, and the key space is bounded by
    ``_REGISTRY``. Construction failures (e.g. an unknown name) raise and
    are not cached by ``lru_cache``.
    """
    return _create_segmenter(name)


def load_segmenter(name: Optional[str] = None) -> Segmenter:
    """
    Load a segmenter by name.

    Args:
        name: Segmenter name. If falsy, the ``SEGMENTER`` environment
            variable is used, falling back to ``DEFAULT_SEGMENTER``
            ("sam3") when that is unset or empty.

    Returns:
        Cached segmenter instance (the same object for repeated calls
        with the same resolved name).

    Raises:
        ValueError: If the resolved name is not a registered segmenter.
    """
    return _get_cached_segmenter(_resolve_name(name))


def load_segmenter_on_device(name: Optional[str], device: str) -> Segmenter:
    """Create a new segmenter instance on *device*, bypassing the cache.

    Args:
        name: Segmenter name. If falsy, it is resolved the same way as in
            ``load_segmenter`` (``SEGMENTER`` env var, then the default).
        device: Passed to the registered factory as the ``device`` keyword
            argument, so the factory must accept it.

    Returns:
        A freshly constructed (uncached) segmenter.

    Raises:
        ValueError: If the resolved name is not a registered segmenter.
    """
    return _create_segmenter(_resolve_name(name), device=device)