File size: 3,987 Bytes
0343ccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Model Factory for MapAnything
"""

import importlib.util
import logging
import warnings

import numpy as np
from omegaconf import DictConfig, OmegaConf

# Core models that are always available
from mapanything.models.any4d.models.model import Any4D

# Quieten noise attributed to DINOv2 (per the original author's note):
# - raise the "dinov2" logger threshold so records below WARNING are dropped;
# - ignore UserWarnings whose message starts with either xFormers availability
#   notice (the `message` argument is a regex matched against the start of the
#   warning text).
logging.getLogger("dinov2").setLevel(logging.WARNING)
warnings.filterwarnings("ignore", message="xFormers is available", category=UserWarning)
warnings.filterwarnings(
    "ignore", message="xFormers is not available", category=UserWarning
)


def resolve_special_float(value):
    """
    Translate a special-float token into the matching numpy constant.

    Args:
        value: Token to translate; only "inf" and "-inf" are recognised.

    Returns:
        float: ``np.inf`` for "inf", ``-np.inf`` for "-inf".

    Raises:
        ValueError: If the token is anything else.
    """
    # Table-driven lookup; tokens are compared with == in declaration order.
    known_tokens = (("inf", np.inf), ("-inf", -np.inf))
    for token, number in known_tokens:
        if value == token:
            return number
    raise ValueError(f"Unknown special float value: {value}")


def init_model(
    model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False
):
    """
    Build a model from an OmegaConf configuration.

    The "special_float" resolver is registered on first use so that the
    configuration can be resolved, then the resolved configuration is expanded
    into keyword arguments for ``model_factory``.

    Args:
        model_str (str): Name of the model to create, as understood by ``model_factory``.
        model_config (DictConfig): OmegaConf model configuration; its resolved
            entries become keyword arguments.
        torch_hub_force_reload (bool): Forwarded unchanged as the
            ``torch_hub_force_reload`` keyword argument.

    Returns:
        Whatever ``model_factory`` returns for the requested model.
    """
    resolver_name = "special_float"
    # Registering the same resolver twice is avoided by checking first.
    if not OmegaConf.has_resolver(resolver_name):
        OmegaConf.register_new_resolver(resolver_name, resolve_special_float)

    # Resolve interpolations and convert to plain Python containers.
    constructor_kwargs = OmegaConf.to_container(model_config, resolve=True)

    return model_factory(
        model_str, torch_hub_force_reload=torch_hub_force_reload, **constructor_kwargs
    )


# Registry mapping a model name to how its class is obtained.
# model_factory() accepts two entry shapes:
#   - {"class": <class object>}: the class is already imported in this module.
#   - {"module": "<dotted module path>", "class_name": "<attribute name>"}: the
#     class is imported dynamically when the model is requested.
MODEL_CONFIGS = {
    "any4d": {
        "class": Any4D,
    },
}


def check_module_exists(module_path):
    """
    Check whether a module can be located by the import system.

    The target module itself is not executed. Note, however, that for a dotted
    path ``importlib.util.find_spec`` imports the parent package(s) in order to
    search them.

    Args:
        module_path (str): The dotted path of the module to check.

    Returns:
        bool: True if a module spec is found, False if the module (or a parent
        package needed to locate it) cannot be found.
    """
    try:
        return importlib.util.find_spec(module_path) is not None
    except ModuleNotFoundError:
        # For dotted paths find_spec raises instead of returning None when a
        # parent package is missing (or is not a package); treat as "absent".
        return False


def model_factory(model_str: str, **kwargs):
    """
    Model factory for MapAnything.

    Looks up ``model_str`` in ``MODEL_CONFIGS``, obtains the model class (either
    directly from the registry entry or through a dynamic import) and
    instantiates it with ``kwargs``.

    Args:
        model_str (str): Name of the model to create.
        **kwargs: Additional keyword arguments to pass to the model constructor.
            For the special name "org_dust3r", ``model_eval_str`` must hold a
            Python expression that is evaluated to build the model.

    Returns:
       nn.Module: An instance of the specified model.

    Raises:
        ValueError: If ``model_str`` is not registered, if its registry entry
            has neither a "class" nor a "module" key, or if "org_dust3r" is
            requested without ``model_eval_str``.
        ImportError: If an external model's module is not installed or the
            class cannot be imported from it.
    """
    if model_str not in MODEL_CONFIGS:
        raise ValueError(
            f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}"
        )

    model_config = MODEL_CONFIGS[model_str]

    # Handle core models directly
    if "class" in model_config:
        model_class = model_config["class"]
    # Handle external models with dynamic imports
    elif "module" in model_config:
        module_path = model_config["module"]
        class_name = model_config["class_name"]

        # Check if the module can be imported
        if not check_module_exists(module_path):
            raise ImportError(
                f"Model '{model_str}' requires module '{module_path}' which is not installed. "
                f"Please install the corresponding submodule or package."
            )

        # Dynamically import the module and get the class
        try:
            module = importlib.import_module(module_path)
            model_class = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            # Chain the original error so the root cause stays in the traceback.
            raise ImportError(
                f"Failed to import {class_name} from {module_path}: {str(e)}"
            ) from e
    else:
        raise ValueError(f"Invalid model configuration for {model_str}")

    print(f"Initializing {model_class} with kwargs: {kwargs}")
    if model_str != "org_dust3r":
        return model_class(**kwargs)

    eval_str = kwargs.get("model_eval_str", None)
    if eval_str is None:
        # Fail with a clear message instead of the opaque TypeError from eval(None).
        raise ValueError(
            "Model 'org_dust3r' requires a 'model_eval_str' keyword argument."
        )
    # SECURITY: eval() executes arbitrary Python taken from the configuration;
    # only use this path with trusted configs. Local names above are kept
    # unchanged because the evaluated expression can see this function's scope.
    return eval(eval_str)


def get_available_models() -> list:
    """
    List the model names registered in ``MODEL_CONFIGS``.

    Returns:
        list: The registered model names, in registry order.
    """
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    return list(MODEL_CONFIGS)


# Names exported by `from <this module> import *`.
# NOTE(review): `init_model` is defined above but not listed here, so a star
# import will not expose it — confirm whether that is intentional.
__all__ = ["model_factory", "get_available_models"]