Zhen Ye
feat(inference): enable full multi-GPU support for all models
45eb65b
"""Registry and loader for depth estimators."""
from functools import lru_cache
from typing import Callable, Dict
from .base import DepthEstimator
from .depth_anything_v2 import DepthAnythingV2Estimator
# Registry of depth estimators, keyed by the public name used by the loaders.
# Each value is a factory (here, the estimator class itself) that returns a
# DepthEstimator. Factories are called with optional keyword arguments:
# `_create_depth_estimator` forwards **kwargs, and
# `load_depth_estimator_on_device` passes `device=...`. The type is therefore
# `Callable[..., DepthEstimator]` rather than the zero-argument
# `Callable[[], DepthEstimator]`.
_REGISTRY: Dict[str, Callable[..., DepthEstimator]] = {
    "depth": DepthAnythingV2Estimator,
}
@lru_cache(maxsize=None)
def _get_cached_depth_estimator(name: str) -> DepthEstimator:
    """
    Build the estimator registered under *name* once, then reuse it.

    The cache is keyed on ``name`` alone, so every caller asking for the same
    name shares one instance, constructed with no extra keyword arguments.
    If construction raises (e.g. ``KeyError`` for an unknown name), the
    exception propagates and nothing is memoized.

    Args:
        name: Registered depth estimator name (e.g., "depth")
    Returns:
        The shared depth estimator instance for ``name``
    """
    estimator = _create_depth_estimator(name)
    return estimator
def _create_depth_estimator(name: str, **kwargs: object) -> DepthEstimator:
    """
    Create a new depth estimator instance (no caching).

    Args:
        name: Registered depth estimator name
        **kwargs: Keyword arguments forwarded unchanged to the registered
            factory (e.g. ``device=...``)
    Returns:
        Freshly constructed depth estimator instance
    Raises:
        KeyError: If ``name`` is not present in the registry
    """
    # EAFP: do one dict lookup instead of a membership test followed by an
    # index. Only the lookup sits inside the ``try``, so a KeyError raised by
    # the estimator's own constructor is not misreported as "not found".
    try:
        estimator_class = _REGISTRY[name]
    except KeyError:
        # ``from None`` hides the bare internal KeyError in the traceback;
        # the message below already explains the failure.
        raise KeyError(
            f"Depth estimator '{name}' not found. Available: {list(_REGISTRY.keys())}"
        ) from None
    return estimator_class(**kwargs)
def load_depth_estimator(name: str = "depth") -> DepthEstimator:
    """
    Return the depth estimator registered as *name* from the module-level cache.

    Repeated calls with the same name return the very same instance.

    Args:
        name: Registered depth estimator name (default: "depth")
    Returns:
        Cached depth estimator instance
    Raises:
        KeyError: If ``name`` is not registered
    """
    cached_estimator = _get_cached_depth_estimator(name)
    return cached_estimator
def load_depth_estimator_on_device(name: str, device: str) -> DepthEstimator:
    """
    Create a brand-new depth estimator for ``device``, bypassing the cache.

    Every call constructs a separate instance. Callers that want one
    estimator per device must therefore keep their own references.

    Args:
        name: Registered depth estimator name
        device: Forwarded as the ``device=`` keyword to the estimator factory.
            Presumably a torch-style device string such as "cuda:1"; confirm
            against the estimator implementation.
    Returns:
        Newly created depth estimator instance
    Raises:
        KeyError: If ``name`` is not registered
    """
    estimator = _create_depth_estimator(name, device=device)
    return estimator
def list_depth_estimators() -> list[str]:
    """
    Return the names of all registered depth estimators, in registry order.

    A fresh list is built on each call, so mutating it does not affect the
    registry.
    """
    registered_names = list(_REGISTRY)
    return registered_names