# Copyright (c) Meta Platforms, Inc. and affiliates. import functools from typing import Any, Callable, Union from omegaconf import DictConfig, ListConfig, OmegaConf from hydra.utils import instantiate TargetType = Union[str, type, Callable[..., Any]] ClassOrCallableType = Union[type, Callable[..., Any]] def dump_config(config: DictConfig, path: str = "./config.yaml"): txt = OmegaConf.to_yaml(config, sort_keys=True) with open(path, "w") as f: f.write(txt) def locate(path: str) -> Any: if path == "": raise ImportError("Empty path") import builtins from importlib import import_module parts = [part for part in path.split(".") if part] # load module part module = None for n in reversed(range(len(parts))): try: mod = ".".join(parts[:n]) module = import_module(mod) except Exception as e: if n == 0: raise ImportError(f"Error loading module '{path}'") from e continue if module: break if module: obj = module else: obj = builtins # load object path in module for part in parts[n:]: mod = mod + "." + part if not hasattr(obj, part): try: import_module(mod) except Exception as e: raise ImportError( f"Encountered error: `{e}` when loading module '{path}'" ) from e obj = getattr(obj, part) return obj def full_instance_name(instance: Any) -> str: return full_class_name(instance.__class__) def full_class_name(klass: Any) -> str: module = klass.__module__ if module == "builtins": return klass.__qualname__ # avoid outputs like 'builtins.str' return module + "." + klass.__qualname__ def ensure_is_subclass(child_class: type, parent_class: type) -> None: if not issubclass(child_class, parent_class): raise RuntimeError( f"class {full_class_name(child_class)} should be a subclass of {full_class_name(parent_class)}" ) def find_class_or_callable_from_target( target: TargetType, ) -> ClassOrCallableType: if isinstance(target, str): obj = locate(target) else: obj = target if (not isinstance(obj, type)) and (not callable(obj)): raise ValueError(f"Invalid type ({type(obj)}) found for {target}") return obj def find_and_ensure_is_subclass(target: TargetType, type_: type) -> ClassOrCallableType: klass = find_class_or_callable_from_target(target) ensure_is_subclass(klass, type_) return klass class StrictPartial: # remark : the `/` will handle the `path` argument name conflict (e.g. calling StrictPartial("a.b.c", ..., path="/a/b/c")) def __init__(self, path, /, *args, **kwargs): class_or_callable = find_class_or_callable_from_target(path) self._partial = functools.partial(class_or_callable, *args, **kwargs) def __call__(self, *args: Any, **kwargs: Any) -> Any: return self._partial(*args, **kwargs) class RecursivePartial: @staticmethod def replace_keys(config, key_mapping): def recurse(data): if isinstance(data, DictConfig): new_data = { key_mapping[k] if k in key_mapping else k: recurse(v) for k, v in data.items() } new_data = DictConfig(new_data) elif isinstance(data, ListConfig): new_data = ListConfig([recurse(item) for item in data]) elif type(data) in {bool, str, int, float, type(None)}: new_data = data else: raise RuntimeError(f"unknow type found : {type(data)}") return new_data return recurse(config) def __init__(self, config): self.config = RecursivePartial.replace_keys( config, {"_rpartial_target_": "_target_"} ) def __call__(self, *args: Any, **kwargs: Any) -> Any: return instantiate(self.config) class Partial(StrictPartial): # remark : allow `path` argument to be exposed for easier use def __init__(self, path, *args, **kwargs): super().__init__(path, *args, **kwargs) def subkey(mapping, key): return mapping[key] def make_set(*args): return set(args) def make_tuple(*args): return tuple(args) def make_list_from_kwargs(**kwargs): # Filter out None/null values to avoid issues with callbacks return [v for v in kwargs.values() if v is not None] def make_string(value): return str(value) def make_dict(**kwargs): return dict(kwargs) def get_item(data, key: str): return data[key] def get_attr(data, key: str): return getattr(data, key)