Zhen Ye
refactor: rename hf_yolov8 → yolo11 across codebase
f89fa0b
import os
from functools import lru_cache
from typing import Callable, Dict, Optional, Tuple
from .base import Segmenter
from .grounded_sam2 import GroundedSAM2Segmenter
# Registry key used when neither an explicit name nor $SEGMENTER is given.
DEFAULT_SEGMENTER = "GSAM2-L"
# Detector key reported by get_segmenter_detector() for specs without one.
_DEFAULT_SEGMENTER_DETECTOR = "grounding_dino"

# Registry key -> (SAM2 model size, detector key or None).
# Generated keys, in order: GSAM2-S, GSAM2-B, GSAM2-L, YSAM2-S, YSAM2-B, YSAM2-L.
# A detector of None means the factory does not pin ``detector_name`` when
# constructing the segmenter; "yolo11" entries always pin it.
_SEGMENTER_SPECS: Dict[str, Tuple[str, Optional[str]]] = {
    f"{family}-{suffix}": (model_size, detector)
    for family, detector in (("GSAM2", None), ("YSAM2", "yolo11"))
    for suffix, model_size in (("S", "small"), ("B", "base"), ("L", "large"))
}
def _build_factory(model_size: str, detector_name: Optional[str]) -> Callable[..., Segmenter]:
    """Return a keyword-only factory producing a GroundedSAM2Segmenter.

    The factory always passes ``model_size``. When ``detector_name`` is set,
    it is pinned on every construction and any caller-supplied
    ``detector_name`` keyword is silently discarded. When it is None, every
    caller keyword (including ``detector_name``) is forwarded unchanged.
    """
    def _factory(**overrides) -> Segmenter:
        if detector_name is not None:
            # YSAM2 keys own detector selection; drop accidental caller overrides.
            overrides.pop("detector_name", None)
            return GroundedSAM2Segmenter(
                model_size=model_size,
                detector_name=detector_name,
                **overrides,
            )
        return GroundedSAM2Segmenter(model_size=model_size, **overrides)

    return _factory
# Registry key -> zero-positional-arg factory, one per entry in _SEGMENTER_SPECS.
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
    key: _build_factory(*spec) for key, spec in _SEGMENTER_SPECS.items()
}
def get_segmenter_detector(segmenter_name: str) -> str:
    """Return the detector key associated with a segmenter (for mission parsing).

    Specs that carry no explicit detector resolve to
    ``_DEFAULT_SEGMENTER_DETECTOR``.

    Raises:
        ValueError: If ``segmenter_name`` is not a registered segmenter.
    """
    if segmenter_name not in _SEGMENTER_SPECS:
        available = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unknown segmenter '{segmenter_name}'. Available: {available}"
        )
    pinned_detector = _SEGMENTER_SPECS[segmenter_name][1]
    return pinned_detector or _DEFAULT_SEGMENTER_DETECTOR
def _create_segmenter(name: str, **kwargs) -> Segmenter:
    """Build a brand-new segmenter via the registry factory for ``name``.

    Keyword arguments are forwarded verbatim to the factory.

    Raises:
        ValueError: If ``name`` is not a registered segmenter; the original
            KeyError is chained as the cause.
    """
    try:
        build = _REGISTRY[name]
    except KeyError as missing:
        available = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unknown segmenter '{name}'. Available: {available}"
        ) from missing
    else:
        # Outside the try body so errors raised by the factory propagate as-is.
        return build(**kwargs)
@lru_cache(maxsize=None)
def _get_cached_segmenter(name: str) -> Segmenter:
    """Return the shared segmenter for ``name``, creating it on first use.

    Backed by an unbounded ``lru_cache``, so each distinct ``name`` maps to a
    single instance until the cache is cleared.
    """
    segmenter = _create_segmenter(name)
    return segmenter
def load_segmenter(name: Optional[str] = None) -> Segmenter:
    """
    Load a shared, cached segmenter by name.

    The segmenter name is resolved in this order:
      1. ``name``, if it is non-empty;
      2. the ``SEGMENTER`` environment variable, if set to a non-blank value
         (surrounding whitespace is ignored);
      3. ``DEFAULT_SEGMENTER`` (GSAM2-L).

    Args:
        name: Segmenter registry key, e.g. "GSAM2-L" or "YSAM2-S". Optional.

    Returns:
        Cached segmenter instance; repeated calls that resolve to the same
        name return the same object.

    Raises:
        ValueError: If the resolved name is not a registered segmenter.
    """
    if name:
        segmenter_name = name
    else:
        # An empty or blank SEGMENTER (e.g. ``SEGMENTER=`` in an .env file)
        # means "not configured": fall back to the default instead of failing
        # with "Unknown segmenter ''".
        env_value = (os.getenv("SEGMENTER") or "").strip()
        segmenter_name = env_value or DEFAULT_SEGMENTER
    return _get_cached_segmenter(segmenter_name)
def load_segmenter_on_device(name: str, device: str, **kwargs) -> Segmenter:
    """Construct a fresh, uncached segmenter with ``device`` passed through.

    Unlike ``load_segmenter``, this bypasses the instance cache, so each call
    builds a new object. ``device`` and any extra keyword arguments are
    forwarded to the registry factory for ``name``.

    Raises:
        ValueError: If ``name`` is not a registered segmenter.
    """
    # ``device`` first, then caller extras, matching the original keyword order.
    options = {"device": device, **kwargs}
    return _create_segmenter(name, **options)