Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import copy | |
| import importlib | |
| import inspect | |
| import logging | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, Optional | |
| from attrdict import AttrDict | |
| from torch import nn | |
| logger: logging.Logger = logging.getLogger(__name__) | |
def load_module(
    module_name: str, class_name: Optional[str] = None, silent: bool = False
):
    """
    Import a module from the ``visualize`` package and optionally pick one
    attribute (typically a class) out of it.

    Example:

    .. code-block:: python

        eye_geo = load_module("path.to.module", "ClassName")

    Args:
        module_name: str
            Dotted module path. It is imported as ``visualize.<module_name>``.

        class_name: str
            Name of the attribute to fetch from the imported module. When it is
            empty or None the module object itself is returned.

        silent: bool
            If True, return None instead of raising when the module cannot be
            found or the attribute does not exist.

    Returns:
        object:
            The imported module, the requested attribute, or None when the
            lookup failed and ``silent`` is set.

    Raises:
        ModuleNotFoundError: the module cannot be imported and ``silent`` is False.
        AttributeError: the attribute lookup failed and ``silent`` is False.
    """
    try:
        loaded = importlib.import_module(f"visualize.{module_name}")
        return getattr(loaded, class_name) if class_name else loaded
    except ModuleNotFoundError:
        if not silent:
            logger.error(f"Module not found: {module_name}", exc_info=True)
            raise
    except AttributeError:
        if not silent:
            logger.error(
                f"Can not locate class: {class_name} in {module_name}.", exc_info=True
            )
            raise
    # Only reached when a lookup failed and the caller asked for silence.
    return None
| # pyre-ignore[3] | |
def make_module(mod_config: AttrDict, *args: Any, **kwargs: Any) -> Any:
    """
    A shortcut for making an object given the config and arguments

    Args:
        mod_config: AttrDict
            Config. Every key except ``args`` is forwarded to ``load_module``
            (so it should contain ``module_name`` and ``class_name``); the
            optional ``args`` entry holds keyword arguments for the constructor.

        *args
            Positional arguments passed to the constructor.

        **kwargs
            Default keyword arguments. Overwritten by content from mod_config.args

    Returns:
        object:
            The object built by calling the loaded class.
    """
    mod_config_dict = dict(mod_config)
    # ``dict(mod_config)`` is only a shallow copy, so the popped ``args`` value is
    # the very dict stored in the caller's config. Copy it before merging the
    # defaults, otherwise every call would leak ``kwargs`` into that config.
    # ``or {}`` also tolerates an explicit ``args: null`` entry.
    mod_args = dict(mod_config_dict.pop("args", None) or {})
    for key, value in kwargs.items():
        # Values coming from the config win over the supplied defaults.
        mod_args.setdefault(key, value)
    mod_class = load_module(**mod_config_dict)
    return mod_class(*args, **mod_args)
def get_full_name(mod: object) -> str:
    """
    Build the fully qualified name of the class of ``mod``.

    Args:
        mod: object
            Any object; only its class is inspected.

    Returns:
        str:
            ``<module>.<qualified class name>`` of the object's class.
    """
    cls = mod.__class__
    return "{}.{}".format(cls.__module__, cls.__qualname__)
| # pyre-fixme[3]: Return type must be annotated. | |
def load_class(class_name: str):
    """
    Resolve a fully qualified class name to the class object.

    Example:

    .. code-block:: python

        class_instance = load_class("module.path.ClassName")

    Args:
        class_name: str
            The full class name, i.e. the module path followed by a dot and the
            class name. A name without any dot is handed to ``load_module`` as a
            module name only.

    Returns:
        Whatever ``load_module`` returns for the split name (normally a class).
    """
    # Split on the last dot only: the left part is the module path, the right
    # part (if present) is the attribute to fetch from that module.
    name_parts = class_name.rsplit(".", 1)
    # This is a false-positive, pyre doesn't understand rsplit(..., 1) can only have 1-2 elements
    # pyre-fixme[6]: In call `load_module`, for 1st positional only parameter expected `bool` but got `str`.
    return load_module(*name_parts)
@dataclass
class ObjectSpec:
    """
    Description of an object to instantiate dynamically.

    The ``@dataclass`` decorator is required: it generates the keyword
    constructor used as ``ObjectSpec(**config)`` and turns the
    ``field(default_factory=dict)`` declaration into a fresh dict per instance.

    Args:
        class_name: str
            The full class name including the full path of the module relative to
            the root directory or just the name of the class within the module to
            load when module name is also provided.

        module_name: str
            The full path of the module relative to the root directory. Ex: ``utils.module_loader``

        kwargs: dict
            Keyword arguments for initializing the object.
    """

    class_name: str
    module_name: Optional[str] = None
    kwargs: Dict[str, Any] = field(default_factory=dict)
| # pyre-fixme[3]: Return type must be annotated. | |
def load_object(spec: ObjectSpec, **kwargs: Any):
    """
    Build an object from a spec, optionally overriding its init arguments.

    Example:

    .. code-block:: python

        my_model = load_object(ObjectSpec(**my_model_config), in_channels=3)

    Args:
        spec: ObjectSpec
            Names the class to load (``class_name`` alone as a full dotted name,
            or together with ``module_name``) and carries its init arguments.

        kwargs: dict
            Extra keyword arguments for initialization. On a key clash they
            take precedence over ``spec.kwargs``.

    Returns:
        The object produced by calling the loaded class.
    """
    object_class = (
        load_class(spec.class_name)
        if spec.module_name is None
        else load_module(spec.module_name, spec.class_name)
    )

    # Leave a debug trace for every spec argument that the caller overrides.
    for key, value in kwargs.items():
        if key in spec.kwargs:
            logger.debug(f"Overriding {key} as {value} in {spec}.")

    # Merge into a new dict so that spec.kwargs itself is left untouched.
    init_kwargs = dict(spec.kwargs)
    init_kwargs.update(kwargs)
    return object_class(**init_kwargs)
| # From DaaT merge. Fix here T145981161 | |
| # pyre-fixme[2]: parameter must be annotated. | |
| # pyre-fixme[3]: Return type must be annotated. | |
def load_from_config(config: AttrDict, **kwargs):
    """Instantiate an object given a config and arguments.

    Args:
        config: mapping that must contain ``class_name`` (a full dotted class
            name) and must not contain ``module_name``. All remaining entries
            are passed to the constructor as keyword arguments.
        **kwargs: additional keyword arguments for the constructor.

    Returns:
        The object produced by calling the loaded class.
    """
    assert "class_name" in config and "module_name" not in config
    # Work on a deep copy so that popping ``class_name`` does not alter the
    # caller's config.
    init_args = copy.deepcopy(config)
    object_class = load_class(init_args.pop("class_name"))
    return object_class(**init_args, **kwargs)
| # From DaaT merge. Fix here T145981161 | |
| # pyre-fixme[2]: parameter must be annotated. | |
| # pyre-fixme[3]: Return type must be annotated. | |
def forward_parameter_names(module):
    """Get the names arguments of the forward pass for the module.

    Args:
        module: a class or an instance with a `forward()` method

    Returns:
        The list of parameter names of `forward()`, in declaration order and
        without the leading `self`.

    Raises:
        ValueError: if `forward()` declares `*args` or `**kwargs`.
    """
    forward = module.forward
    params = list(inspect.signature(forward).parameters.values())
    # The signature of a bound method (``instance.forward``) already omits
    # ``self``. Only drop the first parameter when ``forward`` is not bound,
    # e.g. when it was looked up on the class.
    if not inspect.ismethod(forward):
        params = params[1:]

    # ``inspect`` reports variadic parameters by kind; their names carry no
    # leading stars, so comparing names against "*args"/"**kwargs" never matches.
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    names = []
    for p in params:
        if p.kind in variadic_kinds:
            raise ValueError("*args and **kwargs are not supported")
        names.append(p.name)
    return names
| # From DaaT merge. Fix here T145981161 | |
def build_optimizer(config, model):
    """Build an optimizer given optimizer config and a model.

    Args:
        config: DictConfig
            Optimizer config. It is forwarded to ``load_from_config`` after the
            optional ``per_module`` entry has been removed. ``per_module`` maps a
            submodule name to the options of that submodule's parameter group.
        model: nn.Module|Dict[str,nn.Module]
            A single module, or a dict of named modules. For a dict,
            ``per_module`` is mandatory.

    Returns:
        The optimizer object created by ``load_from_config``.
    """
    # Deep copy: ``per_module`` is popped below and the caller's config must
    # stay intact.
    config = copy.deepcopy(config)
    if isinstance(model, nn.Module):
        if "per_module" in config:
            params = []
            for name, value in config.per_module.items():
                if not hasattr(model, name):
                    logger.warning(
                        f"model {model.__class__} does not have a submodule {name}, skipping"
                    )
                    continue
                params.append(
                    dict(
                        params=getattr(model, name).parameters(),
                        **value,
                    )
                )
            # Warn about direct children that own parameters but have no group.
            defined_names = set(config.per_module.keys())
            for name, module in model.named_children():
                n_params = len(list(module.named_parameters()))
                if name not in defined_names and n_params:
                    logger.warning(
                        f"not going to optimize module {name} which has {n_params} parameters"
                    )
            config.pop("per_module")
        else:
            params = model.parameters()
    else:
        assert "per_module" in config
        assert isinstance(model, dict)
        # A single pass is enough. Initialising ``params`` before the loop also
        # keeps it defined when ``per_module`` is empty.
        params = []
        for name, value in config.per_module.items():
            if name not in model:
                logger.warning(f"not aware of {name}, skipping")
                continue
            params.append(
                dict(
                    params=model[name].parameters(),
                    **value,
                )
            )
        # As in the nn.Module branch, ``per_module`` must not be forwarded to the
        # optimizer constructor as a keyword argument.
        config.pop("per_module")
    return load_from_config(config, params=params)
| # From DaaT merge. Fix here T145981161 | |
class ForwardFilter:
    """A module that filters out arguments for the `forward()`.

    Wraps a callable module and, on each call, forwards only those keyword
    arguments whose names appear in the signature of the module's `forward()`.
    """

    # pyre-ignore
    def __init__(self, module, optional: bool = False) -> None:
        """Store the module and record the parameter names of its `forward()`.

        Args:
            module: the wrapped module; it is called with keyword arguments.
            optional: accepted but not used by this class.
        """
        # pyre-ignore
        self.module = module
        # pyre-ignore
        self.input_names = set(forward_parameter_names(module))

    # pyre-ignore
    def __call__(self, **kwargs):
        """Call the wrapped module with the accepted subset of ``kwargs``."""
        accepted = {}
        for name, value in kwargs.items():
            if name in self.input_names:
                accepted[name] = value
        return self.module(**accepted)