"""Annotation-driven configuration objects.

``Config`` subclasses declare fields as class annotations (optionally with a
class-level default). Keyword arguments passed to the constructor are converted
according to those annotations by ``_convert_value``, and ``as_json`` converts
an instance back into plain data via ``_to_json``.
"""

import copy
import dataclasses
import enum
import sys
import types
from typing import (
    Any,
    Dict,
    Mapping,
    Tuple,
    Type,
    TypeVar,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

import yaml

ConfigT = TypeVar("ConfigT", bound="Config")

# ``get_origin`` reports ``typing.Union`` for ``Union[...]``/``Optional[...]``
# and, on Python 3.10+, ``types.UnionType`` for PEP 604 ``X | Y`` annotations.
# Origins are compared by identity: ``str(types.UnionType)`` renders as
# "<class 'types.UnionType'>", so matching on the string form misses PEP 604
# unions.
_UNION_ORIGINS = tuple(
    origin
    for origin in (Union, getattr(types, "UnionType", None))
    if origin is not None
)


class Config:
    """Base class for annotation-driven configuration objects.

    Subclasses declare fields as class annotations, optionally with a
    class-level default value. For every annotated field, ``__init__`` takes the
    matching keyword argument, or else a deep copy of the class-level default,
    and converts it with ``_convert_value`` according to the annotation.
    Annotated fields with neither a keyword argument nor a default are left
    unset. Keyword arguments that are not annotated are stored unchanged.
    """

    def __init__(self, **kwargs: Any):
        annotations = _collect_annotations(type(self))
        for key, annotation in annotations.items():
            if key in kwargs:
                value = kwargs[key]
            elif hasattr(type(self), key):
                # Deep-copy class-level defaults so mutable defaults (lists,
                # dicts, nested configs) are never shared between instances.
                value = copy.deepcopy(getattr(type(self), key))
            else:
                continue
            setattr(self, key, _convert_value(annotation, value))
        for key, value in kwargs.items():
            if key not in annotations:
                setattr(self, key, value)
        if not hasattr(self, "pretrain_config"):
            self.pretrain_config = EmptyConfig()
        self.__post_init__()

    def __post_init__(self):
        """Hook run at the end of ``Config.__init__``; a no-op by default.

        Subclasses may override it, e.g. to validate or derive fields.
        """
        return

    @property
    def empty(self) -> bool:
        """True when the instance carries no attributes at all.

        ``Config.__init__`` always sets ``pretrain_config``, so only instances
        that bypass it (such as ``EmptyConfig``) can be empty.
        """
        return len(self.__dict__) == 0

    def as_json(self) -> Dict[str, Any]:
        """Return the instance's attributes as plain data (see ``_to_json``).

        ``pretrain_config`` is excluded from the output.
        """
        return {
            key: _to_json(value)
            for key, value in self.__dict__.items()
            if key != "pretrain_config"
        }

    @classmethod
    def from_dict(cls: Type[ConfigT], data: Mapping[str, Any]) -> ConfigT:
        """Build an instance from a mapping of field names to values."""
        return cls(**dict(data))

    @classmethod
    def from_yaml(cls: Type[ConfigT], file_path: str) -> ConfigT:
        """Build an instance from a YAML file whose top level is a mapping.

        The file is opened in binary mode so PyYAML detects the encoding itself
        (UTF-8/UTF-16 by BOM) instead of using the platform default. The file is
        parsed with ``yaml.safe_load``, so arbitrary Python objects cannot be
        constructed from it.

        Raises:
            TypeError: if the document's top level is not a mapping.
        """
        with open(file_path, "rb") as file:
            data = yaml.safe_load(file)
        if data is None:
            # An empty YAML document loads as None: treat it as "no fields".
            data = {}
        if not isinstance(data, Mapping):
            raise TypeError(
                f"{file_path}: expected a YAML mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls.from_dict(data)


class EmptyConfig(Config):
    """Config placeholder that stores its kwargs verbatim.

    ``__init__`` bypasses ``Config.__init__``: no annotation-driven conversion,
    no ``pretrain_config`` default and no ``__post_init__`` call.
    """

    def __init__(self, **kwargs: Any):
        self.__dict__.update(kwargs)


class Configurable:
    """Base class for objects constructed from a config object.

    ``Configurable[SomeConfig]`` evaluates to ``Configurable`` itself: the
    subscript is accepted for readability and otherwise ignored.
    """

    ConfigT = Any

    def __init__(self, config: Any):
        self.config = config

    @classmethod
    def __class_getitem__(cls, _):
        return cls


class Template:
    """Marker base class; ``Template[X]`` evaluates to ``Template`` itself."""

    ConfigT = Any

    @classmethod
    def __class_getitem__(cls, _):
        return cls


def _collect_annotations(cls: Type[Any]) -> Dict[str, Any]:
    """Return the resolved field annotations of ``cls`` and all its bases.

    Keys are ordered from the most basic class to ``cls``; a subclass
    re-annotating a field overrides its ancestors' entry.
    """
    try:
        # get_type_hints walks the MRO itself and evaluates each class's
        # annotations in that class's own module (plus its class namespace),
        # so forward references defined next to an ancestor still resolve.
        return dict(get_type_hints(cls))
    except Exception:
        pass
    # Fallback when some annotation cannot be resolved: resolve class by class,
    # keeping the raw annotations for any class that still fails.
    annotations: Dict[str, Any] = {}
    for base in reversed(cls.__mro__):
        if base is object:
            continue
        module = sys.modules.get(base.__module__)
        module_globals = vars(module) if module is not None else {}
        try:
            hints = get_type_hints(base, globalns=module_globals, localns=module_globals)
        except Exception:
            hints = getattr(base, "__annotations__", {})
        annotations.update(hints)
    return annotations


def _is_config_type(annotation: Any) -> bool:
    """True when ``annotation`` is ``Config`` or a subclass of it."""
    return isinstance(annotation, type) and issubclass(annotation, Config)


def _is_plain_class(annotation: Any) -> bool:
    """True for an ordinary class that ``_convert_value`` passes through as-is."""
    return (
        isinstance(annotation, type)
        and get_origin(annotation) is None
        and not issubclass(annotation, (Config, enum.Enum))
    )


def _convert_value(annotation: Any, value: Any) -> Any:
    """Convert ``value`` to match ``annotation``, recursing into containers.

    - ``None`` is returned unchanged.
    - ``Config`` subclass: a mapping is expanded into ``annotation(**value)``;
      an existing instance (or any other value) is returned unchanged.
    - ``Enum`` subclass: ``annotation(value)`` unless already a member; an
      invalid value raises ``ValueError`` from the enum.
    - ``List[X]`` / ``Tuple[...]`` / ``Dict[K, V]``: converted element-wise
      (lists stay lists, tuples stay tuples).
    - ``Union``/``Optional``/``X | Y``: see ``_convert_union``.
    - Anything else is returned unchanged.
    """
    if value is None:
        return None
    origin = get_origin(annotation)
    if origin is None:
        # Class checks only for non-generic annotations, so generic aliases
        # such as ``list[int]`` never reach ``issubclass``.
        if _is_config_type(annotation):
            if isinstance(value, annotation):
                return value
            if isinstance(value, Mapping):
                return annotation(**value)
            return value
        if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
            return value if isinstance(value, annotation) else annotation(value)
        return value
    args = get_args(annotation)
    if origin is list:
        item_type = args[0] if args else Any
        return [_convert_value(item_type, item) for item in value]
    if origin is tuple:
        return _convert_tuple_items(args, value)
    if origin is dict:
        key_type = args[0] if args else Any
        val_type = args[1] if len(args) > 1 else Any
        return {
            _convert_value(key_type, key): _convert_value(val_type, item)
            for key, item in value.items()
        }
    if origin in _UNION_ORIGINS:
        return _convert_union(args, value)
    return value


def _convert_tuple_items(args: Tuple[Any, ...], value: Any) -> Tuple[Any, ...]:
    """Convert the items of ``value`` according to the ``Tuple[...]`` args."""
    items = list(value)
    if len(args) == 2 and args[1] is Ellipsis:
        # Tuple[X, ...]: every item has type X.
        item_types = [args[0]] * len(items)
    elif args and len(args) == len(items):
        # Tuple[A, B, ...]: convert position by position.
        item_types = list(args)
    else:
        # Bare Tuple, or a length mismatch: use the first arg for every item.
        first = args[0] if args else Any
        item_types = [first] * len(items)
    return tuple(
        _convert_value(item_type, item) for item_type, item in zip(item_types, items)
    )


def _convert_union(args: Tuple[Any, ...], value: Any) -> Any:
    """Convert ``value`` using the first union member that accepts it.

    Members are tried in declaration order and ``NoneType`` is skipped.
    A plain class accepts the value only if it is already an instance of it
    (returned unchanged). ``Any`` accepts everything. ``Config``/``Enum`` and
    generic members accept the value if converting to them does not raise.
    If no member accepts the value, it is returned unchanged.
    """
    for arg in args:
        if arg is type(None):
            continue
        if arg is Any:
            return value
        if _is_plain_class(arg):
            try:
                if isinstance(value, arg):
                    return value
            except TypeError:
                # e.g. non-runtime-checkable protocols; not usable here.
                pass
            continue
        try:
            return _convert_value(arg, value)
        except Exception:
            # Deliberately best-effort: fall through to the next member.
            continue
    return value


def _to_json(value: Any) -> Any:
    """Recursively convert ``value`` into plain data for JSON/YAML dumping.

    Configs become dicts via ``as_json``, enum members become their
    ``.value``, dataclass instances become dicts of their (converted) fields,
    dicts are converted value-wise (enum keys become their ``.value``), and
    lists/tuples become lists. Other values are returned unchanged.
    """
    if isinstance(value, Config):
        return value.as_json()
    if isinstance(value, enum.Enum):
        return value.value
    # is_dataclass() is also true for dataclass *classes*; only convert instances.
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        return {
            field.name: _to_json(getattr(value, field.name))
            for field in dataclasses.fields(value)
        }
    if isinstance(value, dict):
        return {
            (key.value if isinstance(key, enum.Enum) else key): _to_json(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [_to_json(item) for item in value]
    return value