Spaces:
Sleeping
Sleeping
| import os | |
| from functools import lru_cache | |
| from typing import Callable, Dict, Optional | |
| from .base import Segmenter | |
| from .sam3 import SAM3Segmenter | |
# Name used by load_segmenter when neither an explicit name nor the
# SEGMENTER environment variable supplies one.
DEFAULT_SEGMENTER = "sam3"

# Maps segmenter names to factories. _create_segmenter invokes them as
# ``factory(**kwargs)`` (load_segmenter_on_device passes ``device=...``),
# so factories must accept keyword arguments -- hence ``Callable[..., ...]``
# rather than a zero-argument ``Callable[[], ...]`` signature.
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
    "sam3": SAM3Segmenter,
}
def _create_segmenter(name: str, **kwargs) -> Segmenter:
    """Build a brand-new segmenter from the registry (no caching).

    Args:
        name: Registry key identifying which segmenter factory to use.
        **kwargs: Passed through unchanged to the registered factory
            (for example ``device=...``).

    Returns:
        Whatever the registered factory returns for these arguments.

    Raises:
        ValueError: If ``name`` is not a key of ``_REGISTRY``; the original
            ``KeyError`` is attached as the cause.
    """
    try:
        builder = _REGISTRY[name]
    except KeyError as missing:
        choices = ", ".join(sorted(_REGISTRY))
        message = f"Unknown segmenter '{name}'. Available: {choices}"
        raise ValueError(message) from missing
    else:
        # Kept outside the try body so errors raised by the factory itself
        # are never mistaken for an unknown-name lookup failure.
        return builder(**kwargs)
@lru_cache(maxsize=None)
def _get_cached_segmenter(name: str) -> Segmenter:
    """Return the shared segmenter instance for ``name``, building it once.

    The first call for a given name constructs the segmenter through
    ``_create_segmenter``; later calls with the same name return that same
    object instead of constructing a new one each time.

    ``lru_cache`` does not cache calls that raise, so an unknown name (which
    makes ``_create_segmenter`` raise ``ValueError``) never occupies a cache
    slot. The cache therefore only holds registered names and an unbounded
    ``maxsize`` cannot grow without limit.
    """
    return _create_segmenter(name)
def load_segmenter(name: Optional[str] = None) -> Segmenter:
    """
    Load a segmenter by name.

    The name is resolved in this order: the ``name`` argument when it is
    truthy, otherwise the ``SEGMENTER`` environment variable, otherwise
    ``DEFAULT_SEGMENTER``.

    Args:
        name: Segmenter name. ``None`` or an empty string triggers the
            fallback described above.

    Returns:
        The instance returned by ``_get_cached_segmenter`` for the resolved
        name (that helper is intended to hand back a shared, cached instance).
    """
    if name:
        resolved = name
    else:
        resolved = os.getenv("SEGMENTER", DEFAULT_SEGMENTER)
    return _get_cached_segmenter(resolved)
def load_segmenter_on_device(name: str, device: str) -> Segmenter:
    """Create a fresh segmenter for ``device``, bypassing the shared cache.

    Unlike ``load_segmenter``, this goes straight to ``_create_segmenter``,
    so every call constructs a new instance. ``device`` is forwarded to the
    registered factory as the ``device`` keyword argument; its accepted
    format is up to the segmenter implementation and is not validated here.

    Args:
        name: Registry key of the segmenter to build.
        device: Target device identifier handed to the factory.

    Raises:
        ValueError: If ``name`` is not a registered segmenter.
    """
    instance = _create_segmenter(name, device=device)
    return instance