| | |
| |
|
| | import collections.abc as abc |
| | import dataclasses |
| | import logging |
| | from typing import Any |
| |
|
| | from detectron2.utils.registry import _convert_target_to_string, locate |
| |
|
| | __all__ = ["dump_dataclass", "instantiate"] |
| |
|
| |
|
def _is_dataclass_instance(value: Any) -> bool:
    """Return True only for dataclass *instances*, not dataclass classes."""
    # dataclasses.is_dataclass() is also True for the class object itself,
    # so the extra isinstance check is needed to tell the two apart.
    return dataclasses.is_dataclass(value) and not isinstance(value, type)


def dump_dataclass(obj: Any):
    """
    Dump a dataclass recursively into a dict that can be later instantiated.

    The returned dict holds a "_target_" entry computed from the type of
    ``obj`` plus one entry per dataclass field. Field values that are
    dataclass instances are dumped recursively. A list/tuple field value is
    emitted as a list whose dataclass-instance elements are dumped. Every
    other value, including a dataclass *class* object, is stored unchanged.

    Args:
        obj: a dataclass object (an instance, not a class)

    Returns:
        dict

    Raises:
        AssertionError: if ``obj`` is not an instance of a dataclass.
    """
    assert _is_dataclass_instance(obj), "dump_dataclass() requires an instance of a dataclass."
    ret = {"_target_": _convert_target_to_string(type(obj))}
    for f in dataclasses.fields(obj):
        v = getattr(obj, f.name)
        if _is_dataclass_instance(v):
            # Only instances are recursed into; a field that stores a
            # dataclass class would otherwise fail the assertion above.
            v = dump_dataclass(v)
        elif isinstance(v, (list, tuple)):
            # Tuples are emitted as lists, same as list values.
            v = [dump_dataclass(x) if _is_dataclass_instance(x) else x for x in v]
        ret[f.name] = v
    return ret
| |
|
| |
|
def instantiate(cfg):
    """
    Recursively build the objects described by "_target_" dictionaries.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments. Lists (plain or omegaconf
            ``ListConfig``) are processed element by element; anything else
            is returned untouched.

    Returns:
        object instantiated by cfg
    """
    from omegaconf import ListConfig, DictConfig, OmegaConf

    # ListConfig is checked before the plain-list case and is rebuilt as a
    # ListConfig that is allowed to hold arbitrary objects.
    if isinstance(cfg, ListConfig):
        built = [instantiate(item) for item in cfg]
        return ListConfig(built, flags={"allow_objects": True})

    if isinstance(cfg, list):
        return [instantiate(item) for item in cfg]

    # A DictConfig whose recorded object type is a dataclass is converted
    # by omegaconf itself.
    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    # Anything that is not a mapping carrying "_target_" passes through as-is.
    if not (isinstance(cfg, abc.Mapping) and "_target_" in cfg):
        return cfg

    # Instantiate every value first, then split the target from its kwargs.
    kwargs = {key: instantiate(value) for key, value in cfg.items()}
    target = instantiate(kwargs.pop("_target_"))

    if isinstance(target, str):
        # A string target is resolved to an object through locate().
        target_name = target
        target = locate(target_name)
        assert target is not None, target_name
    else:
        try:
            target_name = target.__module__ + "." + target.__qualname__
        except Exception:
            # The target may lack these attributes; fall back to str().
            target_name = str(target)
    assert callable(target), f"_target_ {target} does not define a callable object"

    try:
        return target(**kwargs)
    except TypeError:
        # Log which target failed, then let the original error propagate.
        logger = logging.getLogger(__name__)
        logger.error(f"Error when instantiating {target_name}!")
        raise
| |
|