Spaces:
Paused
Paused
| import os | |
| from functools import lru_cache | |
| from typing import Callable, Dict, Optional, Tuple | |
| from .base import Segmenter | |
| from .grounded_sam2 import GroundedSAM2Segmenter | |
# Segmenter name used when the caller does not choose one.
DEFAULT_SEGMENTER = "GSAM2-L"

# Detector key reported for specs whose detector entry is None.
_DEFAULT_SEGMENTER_DETECTOR = "grounding_dino"

# Registry table: public segmenter name -> (model size, detector override).
# A detector override of None means the spec does not pin a detector.
_SEGMENTER_SPECS: Dict[str, Tuple[str, Optional[str]]] = {
    # GSAM2 family: no pinned detector.
    "GSAM2-S": ("small", None),
    "GSAM2-B": ("base", None),
    "GSAM2-L": ("large", None),
    # YSAM2 family: detector pinned to "yolo11".
    "YSAM2-S": ("small", "yolo11"),
    "YSAM2-B": ("base", "yolo11"),
    "YSAM2-L": ("large", "yolo11"),
}
def _build_factory(model_size: str, detector_name: Optional[str]) -> Callable[..., Segmenter]:
    """Build a keyword-only factory for one segmenter spec.

    Args:
        model_size: Forwarded to ``GroundedSAM2Segmenter`` as ``model_size``.
        detector_name: Detector pinned by the spec, or ``None`` when the spec
            does not pin one (caller keyword arguments pass through as-is).

    Returns:
        A callable that accepts keyword arguments and returns a new
        ``GroundedSAM2Segmenter``.
    """

    def _make(**options) -> Segmenter:
        if detector_name is not None:
            # The spec owns detector selection: discard any caller-supplied
            # ``detector_name`` so it cannot collide with the pinned value.
            options.pop("detector_name", None)
            return GroundedSAM2Segmenter(
                model_size=model_size,
                detector_name=detector_name,
                **options,
            )
        return GroundedSAM2Segmenter(model_size=model_size, **options)

    return _make
# Segmenter name -> factory callable, one entry per row of _SEGMENTER_SPECS.
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
    segmenter_name: _build_factory(*spec)
    for segmenter_name, spec in _SEGMENTER_SPECS.items()
}
def get_segmenter_detector(segmenter_name: str) -> str:
    """Look up the detector key tied to a segmenter (for mission parsing).

    Args:
        segmenter_name: A key of ``_SEGMENTER_SPECS``.

    Returns:
        The detector named by the spec, or ``_DEFAULT_SEGMENTER_DETECTOR``
        when the spec leaves the detector unset.

    Raises:
        ValueError: If ``segmenter_name`` is not a known segmenter.
    """
    if segmenter_name not in _SEGMENTER_SPECS:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unknown segmenter '{segmenter_name}'. Available: {known}"
        )

    pinned = _SEGMENTER_SPECS[segmenter_name][1]
    if pinned:
        return pinned
    return _DEFAULT_SEGMENTER_DETECTOR
def _create_segmenter(name: str, **kwargs) -> Segmenter:
    """Instantiate a brand-new segmenter from the registry.

    Args:
        name: Registry key identifying the segmenter to build.
        **kwargs: Forwarded unchanged to the registered factory.

    Returns:
        The segmenter produced by the factory.

    Raises:
        ValueError: If ``name`` is not registered; the originating
            ``KeyError`` is attached as the cause.
    """
    try:
        build = _REGISTRY[name]
    except KeyError as missing:
        choices = ", ".join(sorted(_REGISTRY))
        message = f"Unknown segmenter '{name}'. Available: {choices}"
        raise ValueError(message) from missing
    else:
        return build(**kwargs)
@lru_cache(maxsize=None)
def _get_cached_segmenter(name: str) -> Segmenter:
    """Return the shared segmenter for ``name``, creating it on first use.

    The result is memoized per ``name`` so repeated calls reuse one instance
    instead of constructing a new segmenter each time. Calls that raise are
    not cached, so an unknown name fails on every call and the cache only
    ever holds entries for names the registry accepts.

    Args:
        name: Registry key identifying the segmenter.

    Returns:
        The cached segmenter instance for ``name``.

    Raises:
        ValueError: If ``name`` is not a registered segmenter (raised by
            ``_create_segmenter``).
    """
    # NOTE(review): lru_cache keeps its own bookkeeping coherent across
    # threads but may invoke the factory more than once if several threads
    # race on the first call for the same name - confirm that is acceptable.
    return _create_segmenter(name)
def load_segmenter(name: Optional[str] = None) -> Segmenter:
    """Load a segmenter by name.

    The name is resolved in this order, skipping empty values:
    the ``name`` argument, the ``SEGMENTER`` environment variable, then
    ``DEFAULT_SEGMENTER``.

    Args:
        name: Segmenter name (default: GSAM2-L unless ``SEGMENTER`` is set).

    Returns:
        The segmenter instance returned by ``_get_cached_segmenter``.

    Raises:
        ValueError: If the resolved name is not a registered segmenter.
    """
    # ``os.getenv("SEGMENTER", default)`` would hand back "" for a variable
    # that is set but empty; chaining with ``or`` falls through to the
    # default instead of failing with "Unknown segmenter ''".
    segmenter_name = name or os.getenv("SEGMENTER") or DEFAULT_SEGMENTER
    return _get_cached_segmenter(segmenter_name)
def load_segmenter_on_device(name: str, device: str, **kwargs) -> Segmenter:
    """Build a fresh segmenter, forwarding ``device`` to its factory (no caching).

    Args:
        name: Registry key identifying the segmenter to build.
        device: Passed to the factory as the ``device`` keyword argument.
        **kwargs: Additional keyword arguments forwarded to the factory.

    Returns:
        A newly created segmenter instance.
    """
    options = {"device": device, **kwargs}
    return _create_segmenter(name, **options)