"""Registry and loader for depth estimators."""

from functools import lru_cache
from typing import Callable, Dict

from .base import DepthEstimator
from .depth_anything_v2 import DepthAnythingV2Estimator

# Registry mapping a public estimator name to a factory (currently the
# estimator class itself). Factories are invoked with optional keyword
# arguments -- ``load_depth_estimator_on_device`` passes ``device=...`` --
# so they are typed ``Callable[..., DepthEstimator]``, not zero-argument.
_REGISTRY: Dict[str, Callable[..., DepthEstimator]] = {
    "depth": DepthAnythingV2Estimator,
}


@lru_cache(maxsize=None)
def _get_cached_depth_estimator(name: str) -> DepthEstimator:
    """
    Create and cache a depth estimator instance, one per name.

    The cache is unbounded, but it can only hold entries for names present
    in ``_REGISTRY``: a lookup of an unknown name raises, and ``lru_cache``
    does not cache exceptions.

    Args:
        name: Depth estimator name (e.g., "depth")

    Returns:
        The memoized depth estimator instance for ``name``

    Raises:
        KeyError: If ``name`` is not in the registry
    """
    return _create_depth_estimator(name)


def _create_depth_estimator(name: str, **kwargs) -> DepthEstimator:
    """
    Create a new depth estimator instance (never cached).

    Args:
        name: Depth estimator name; must be a key of ``_REGISTRY``
        **kwargs: Forwarded unchanged to the registered factory
            (e.g. ``device=...``)

    Returns:
        A freshly constructed depth estimator instance

    Raises:
        KeyError: If estimator not found in registry
    """
    try:
        estimator_factory = _REGISTRY[name]
    except KeyError:
        # Re-raise with the list of valid names; ``from None`` suppresses the
        # bare inner KeyError so the traceback shows only the helpful message.
        raise KeyError(
            f"Depth estimator '{name}' not found. Available: {list(_REGISTRY.keys())}"
        ) from None
    return estimator_factory(**kwargs)


def load_depth_estimator(name: str = "depth") -> DepthEstimator:
    """
    Load depth estimator by name (with caching).

    Repeated calls with the same ``name`` return the same instance.

    Args:
        name: Depth estimator name (default: "depth")

    Returns:
        Cached depth estimator instance

    Raises:
        KeyError: If ``name`` is not in the registry
    """
    return _get_cached_depth_estimator(name)


def load_depth_estimator_on_device(name: str, device: str) -> DepthEstimator:
    """
    Create a new depth estimator instance on the specified device (no caching).

    Every call constructs a fresh instance, independent of the cached one
    returned by ``load_depth_estimator``.

    Args:
        name: Depth estimator name
        device: Passed to the estimator factory as the ``device`` keyword

    Returns:
        A new depth estimator instance

    Raises:
        KeyError: If ``name`` is not in the registry
    """
    return _create_depth_estimator(name, device=device)


def list_depth_estimators() -> list[str]:
    """Return list of available depth estimator names, in registration order."""
    return list(_REGISTRY.keys())