File size: 1,799 Bytes
b8fe2b6
 
 
 
 
 
1c4206e
b8fe2b6
 
 
 
5e832fe
b8fe2b6
 
 
 
 
 
 
 
 
5e832fe
b8fe2b6
 
 
 
 
 
 
45eb65b
b8fe2b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45eb65b
b8fe2b6
 
5e832fe
b8fe2b6
 
 
 
5e832fe
b8fe2b6
 
 
 
 
 
 
45eb65b
 
 
 
 
b8fe2b6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Registry and loader for depth estimators."""

from functools import lru_cache
from typing import Callable, Dict

from .base import DepthEstimator
from .depth_anything_v2 import DepthAnythingV2Estimator


# Registry of depth estimators: maps a public name to a factory (here the
# estimator class itself). ``_create_depth_estimator`` forwards keyword
# arguments such as ``device=...`` to the factory, so the value type is
# ``Callable[..., DepthEstimator]`` rather than a zero-argument callable.
_REGISTRY: Dict[str, Callable[..., DepthEstimator]] = {
    "depth": DepthAnythingV2Estimator,
}


@lru_cache(maxsize=None)
def _get_cached_depth_estimator(name: str) -> DepthEstimator:
    """
    Return a memoized depth estimator for ``name``.

    The first call for a given name builds the estimator through
    ``_create_depth_estimator`` with no extra constructor arguments; every
    later call with the same name returns that very same instance.

    A lookup failure propagates as ``KeyError`` and nothing is stored, so
    the cache only ever holds entries for registered names.

    Args:
        name: Depth estimator name (e.g., "depth")

    Returns:
        Shared depth estimator instance
    """
    instance = _create_depth_estimator(name)
    return instance


def _create_depth_estimator(name: str, **kwargs) -> DepthEstimator:
    """
    Create a new (uncached) depth estimator instance.

    Args:
        name: Depth estimator name; must be a key of ``_REGISTRY``.
        **kwargs: Forwarded unchanged to the registered estimator's
            constructor (e.g. ``device=...``).

    Returns:
        Depth estimator instance

    Raises:
        KeyError: If estimator not found in registry
    """
    # EAFP: a single dict lookup instead of a membership test plus indexing.
    try:
        estimator_class = _REGISTRY[name]
    except KeyError:
        # Replace the bare lookup error with one that lists valid names;
        # ``from None`` keeps the traceback free of the redundant inner error.
        raise KeyError(
            f"Depth estimator '{name}' not found. Available: {list(_REGISTRY.keys())}"
        ) from None

    # Constructed outside the ``try`` so a KeyError raised inside the
    # estimator's own constructor is not misreported as an unknown name.
    return estimator_class(**kwargs)


def load_depth_estimator(name: str = "depth") -> DepthEstimator:
    """
    Return the shared, cached depth estimator registered under ``name``.

    Repeated calls with the same name hand back the same instance; use
    ``load_depth_estimator_on_device`` when a separate instance is needed.

    Args:
        name: Registry key of the estimator (default: "depth")

    Returns:
        Cached depth estimator instance

    Raises:
        KeyError: If ``name`` is not a registered estimator.
    """
    estimator = _get_cached_depth_estimator(name)
    return estimator


def load_depth_estimator_on_device(name: str, device: str) -> DepthEstimator:
    """
    Build a fresh depth estimator bound to a specific device.

    Unlike ``load_depth_estimator`` the result is not cached: each call
    constructs and returns a brand-new instance.

    Args:
        name: Registry key of the estimator.
        device: Passed to the estimator constructor as the ``device`` keyword.

    Returns:
        Newly created depth estimator instance

    Raises:
        KeyError: If ``name`` is not a registered estimator.
    """
    return _create_depth_estimator(name=name, device=device)


def list_depth_estimators() -> list[str]:
    """
    Return the names of all registered depth estimators.

    The names come back in registry insertion order as a new list, so
    callers may mutate it without affecting the registry.
    """
    # Iterating a dict yields its keys, so this matches list(_REGISTRY.keys()).
    return list(_REGISTRY)