"""Registry of named segmenter configurations and helpers to load them."""

import os
from functools import lru_cache
from typing import Callable, Dict, Optional, Tuple

from .base import Segmenter
from .grounded_sam2 import GroundedSAM2Segmenter

DEFAULT_SEGMENTER = "GSAM2-L"
# Detector key reported by get_segmenter_detector() for specs whose detector
# entry is None (the GSAM2-* keys).
_DEFAULT_SEGMENTER_DETECTOR = "grounding_dino"

# Segmenter name -> (model_size, detector_name).
# A detector_name of None means the factory does not force a detector, so any
# caller-supplied ``detector_name`` kwarg (or, presumably, the
# GroundedSAM2Segmenter default — TODO confirm) applies.
_SEGMENTER_SPECS: Dict[str, Tuple[str, Optional[str]]] = {
    "GSAM2-S": ("small", None),
    "GSAM2-B": ("base", None),
    "GSAM2-L": ("large", None),
    "YSAM2-S": ("small", "yolo11"),
    "YSAM2-B": ("base", "yolo11"),
    "YSAM2-L": ("large", "yolo11"),
}


def _build_factory(
    model_size: str, detector_name: Optional[str]
) -> Callable[..., Segmenter]:
    """Return a factory that builds a GroundedSAM2Segmenter for one spec.

    Args:
        model_size: Value forwarded as the ``model_size`` keyword.
        detector_name: Detector key to force on every instance, or None to
            pass the caller's keyword arguments through untouched.

    Returns:
        A callable accepting arbitrary keyword arguments, which are forwarded
        to ``GroundedSAM2Segmenter``.
    """

    def _factory(**kw) -> Segmenter:
        if detector_name is None:
            return GroundedSAM2Segmenter(model_size=model_size, **kw)
        # YSAM2 keys own detector selection; drop accidental caller overrides.
        # ``kw`` is a fresh dict per call, so popping does not affect callers.
        kw.pop("detector_name", None)
        return GroundedSAM2Segmenter(
            model_size=model_size,
            detector_name=detector_name,
            **kw,
        )

    return _factory


_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
    name: _build_factory(model_size, detector_name)
    for name, (model_size, detector_name) in _SEGMENTER_SPECS.items()
}


def _unknown_segmenter_error(name: str) -> ValueError:
    """Build the ValueError used for any segmenter name not in the registry."""
    available = ", ".join(sorted(_REGISTRY))
    return ValueError(f"Unknown segmenter '{name}'. Available: {available}")


def get_segmenter_detector(segmenter_name: str) -> str:
    """Return the detector key associated with a segmenter (for mission parsing).

    Args:
        segmenter_name: A key of the segmenter registry, e.g. ``"YSAM2-L"``.

    Returns:
        The spec's detector key, or ``_DEFAULT_SEGMENTER_DETECTOR`` when the
        spec does not force one.

    Raises:
        ValueError: If ``segmenter_name`` is not a registered segmenter.
    """
    spec = _SEGMENTER_SPECS.get(segmenter_name)
    if spec is None:
        raise _unknown_segmenter_error(segmenter_name)
    detector_name = spec[1]
    return detector_name or _DEFAULT_SEGMENTER_DETECTOR


def _create_segmenter(name: str, **kwargs) -> Segmenter:
    """Create a new (uncached) segmenter instance.

    Args:
        name: A key of the segmenter registry.
        **kwargs: Forwarded to the registry factory for ``name``.

    Raises:
        ValueError: If ``name`` is not a registered segmenter.
    """
    try:
        factory = _REGISTRY[name]
    except KeyError as exc:
        raise _unknown_segmenter_error(name) from exc
    return factory(**kwargs)


@lru_cache(maxsize=None)
def _get_cached_segmenter(name: str) -> Segmenter:
    """Get or create the process-wide cached segmenter instance for ``name``.

    Unbounded caching is acceptable here: unknown names raise, and lru_cache
    does not cache exceptions, so the key space is limited to registry names.
    NOTE: lru_cache does not serialize concurrent first calls, so two threads
    racing on a cold key may each build an instance.
    """
    return _create_segmenter(name)


def load_segmenter(name: Optional[str] = None) -> Segmenter:
    """
    Load a segmenter by name.

    Args:
        name: Segmenter name. When falsy, the ``SEGMENTER`` environment
            variable is used; when that is unset or empty, DEFAULT_SEGMENTER
            (GSAM2-L) is used.

    Returns:
        Cached segmenter instance

    Raises:
        ValueError: If the resolved name is not a registered segmenter.
    """
    # `or` (rather than a getenv default) so an exported-but-empty SEGMENTER
    # falls back to the default instead of failing as an unknown name "".
    segmenter_name = name or os.getenv("SEGMENTER") or DEFAULT_SEGMENTER
    return _get_cached_segmenter(segmenter_name)


def load_segmenter_on_device(name: str, device: str, **kwargs) -> Segmenter:
    """Create a new segmenter instance on the specified device (no caching).

    Args:
        name: A key of the segmenter registry.
        device: Forwarded to the segmenter as the ``device`` keyword.
        **kwargs: Additional keyword arguments forwarded to the factory.

    Raises:
        ValueError: If ``name`` is not a registered segmenter.
    """
    return _create_segmenter(name, device=device, **kwargs)