File size: 2,746 Bytes
b30e7a3
 
6268ac2
b30e7a3
 
3fde4e4
b30e7a3
21c29ae
6268ac2
 
 
 
 
 
f89fa0b
 
 
6268ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b30e7a3
3fde4e4
6268ac2
 
b30e7a3
 
 
6268ac2
 
 
 
 
 
 
 
 
 
 
 
45eb65b
b30e7a3
 
 
 
 
 
 
 
45eb65b
b30e7a3
 
 
 
 
 
 
 
 
 
 
 
 
21c29ae
b30e7a3
 
 
 
 
 
45eb65b
 
1c6c619
45eb65b
1c6c619
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
from functools import lru_cache
from typing import Callable, Dict, Optional, Tuple

from .base import Segmenter
from .grounded_sam2 import GroundedSAM2Segmenter

DEFAULT_SEGMENTER = "GSAM2-L"
# Detector reported by get_segmenter_detector() for specs whose detector is None.
# NOTE(review): presumably mirrors GroundedSAM2Segmenter's own default — confirm.
_DEFAULT_SEGMENTER_DETECTOR = "grounding_dino"

# Public segmenter name -> (SAM2 model size, pinned detector or None).
# Names are "<family>-<size letter>", e.g. "GSAM2-L" or "YSAM2-S".
# A None detector means the factory does not pass detector_name itself.
# Key order is GSAM2-S/B/L followed by YSAM2-S/B/L.
_SEGMENTER_SPECS: Dict[str, Tuple[str, Optional[str]]] = {
    f"{family}-{size_letter}": (model_size, detector)
    for family, detector in (("GSAM2", None), ("YSAM2", "yolo11"))
    for size_letter, model_size in (("S", "small"), ("B", "base"), ("L", "large"))
}


def _build_factory(
    model_size: str, detector_name: Optional[str]
) -> Callable[..., Segmenter]:
    """Return a keyword-only factory producing a GroundedSAM2Segmenter.

    The factory always passes ``model_size``. When ``detector_name`` is not
    None it is pinned: any ``detector_name`` keyword supplied by the caller
    is discarded. When it is None, every caller keyword (including a
    ``detector_name``) is forwarded unchanged.
    """

    def _factory(**kw) -> Segmenter:
        # Keywords this spec fixes itself; empty for GSAM2-style specs.
        pinned = {} if detector_name is None else {"detector_name": detector_name}
        # YSAM2 keys own detector selection, so caller overrides of pinned
        # keys are silently dropped rather than raising a duplicate-kwarg error.
        forwarded = {key: value for key, value in kw.items() if key not in pinned}
        return GroundedSAM2Segmenter(model_size=model_size, **pinned, **forwarded)

    return _factory


# Public segmenter name -> factory; factories take keyword options only.
# Built once at import time from _SEGMENTER_SPECS, so it has the same keys.
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
    spec_name: _build_factory(*spec) for spec_name, spec in _SEGMENTER_SPECS.items()
}


def get_segmenter_detector(segmenter_name: str) -> str:
    """Return the detector key associated with a segmenter (for mission parsing).

    Specs without a pinned detector report ``_DEFAULT_SEGMENTER_DETECTOR``.

    Raises:
        ValueError: If ``segmenter_name`` is not a registered segmenter.
    """
    if segmenter_name not in _SEGMENTER_SPECS:
        choices = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unknown segmenter '{segmenter_name}'. Available: {choices}"
        )
    _model_size, pinned_detector = _SEGMENTER_SPECS[segmenter_name]
    # A None (unpinned) detector falls back to the module-wide default key.
    return pinned_detector or _DEFAULT_SEGMENTER_DETECTOR


def _create_segmenter(name: str, **kwargs) -> Segmenter:
    """Build a brand-new segmenter via the registry factory for ``name``.

    Keyword arguments are forwarded to the factory untouched.

    Raises:
        ValueError: If ``name`` is not registered (chained from the KeyError).
    """
    try:
        factory = _REGISTRY[name]
    except KeyError as missing:
        choices = ", ".join(sorted(_REGISTRY))
        message = f"Unknown segmenter '{name}'. Available: {choices}"
        raise ValueError(message) from missing
    else:
        # Kept outside the try body so a KeyError raised by the factory
        # itself is not misreported as an unknown segmenter name.
        return factory(**kwargs)


@lru_cache(maxsize=None)
def _get_cached_segmenter(name: str) -> Segmenter:
    """Return the shared segmenter instance for ``name``, building it on first use.

    Only successful results are memoized (lru_cache does not store raised
    exceptions), so the cache can hold at most one entry per registered name.
    """
    return _create_segmenter(name=name)


def load_segmenter(name: Optional[str] = None) -> Segmenter:
    """
    Load a shared (cached) segmenter by name.

    The name is resolved in this order: the ``name`` argument, then the
    ``SEGMENTER`` environment variable, then ``DEFAULT_SEGMENTER`` ("GSAM2-L").
    An empty value at either of the first two steps counts as unset.

    Args:
        name: Segmenter name such as "GSAM2-L" or "YSAM2-S" (optional).

    Returns:
        Cached segmenter instance; calls that resolve to the same name
        return the same object.

    Raises:
        ValueError: If the resolved name is not a registered segmenter.
    """
    # Chain with `or` instead of getenv's default argument. An exported but
    # empty SEGMENTER= must fall back to the default rather than looking up ''.
    segmenter_name = name or os.getenv("SEGMENTER") or DEFAULT_SEGMENTER
    return _get_cached_segmenter(segmenter_name)


def load_segmenter_on_device(name: str, device: str, **kwargs) -> Segmenter:
    """Build a fresh, uncached segmenter ``name``, passing ``device`` to its factory.

    Any extra keyword arguments are forwarded after ``device``. Every call
    constructs a new instance; nothing is shared with ``load_segmenter``.
    """
    factory_options = {"device": device, **kwargs}
    return _create_segmenter(name, **factory_options)