| |
|
|
| import logging |
| from typing import Type, TypeVar |
|
|
| from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
# Module-level logger named after this module (instead of the root logger) so
# its output can be configured/filtered per module; records still propagate
# to the root logger's handlers.
logger = logging.getLogger(__name__)


# Generic type variable linking a dataclass type (Type[T]) to the instance (T)
# returned by the helpers in this module.
T = TypeVar("T")
|
|
|
|
def set_struct_recursively(cfg, strict: bool = True):
    """Set the OmegaConf struct flag on *cfg* and explicitly on every nested container.

    Args:
        cfg: An OmegaConf ``DictConfig`` or ``ListConfig`` (the root of the
            tree to update).
        strict: Value passed to ``OmegaConf.set_struct`` for *cfg* and for
            every nested ``DictConfig``/``ListConfig`` reachable from it.
    """
    OmegaConf.set_struct(cfg, strict)

    if isinstance(cfg, DictConfig):
        keys = list(cfg.keys())
    elif isinstance(cfg, ListConfig):
        keys = range(len(cfg))
    else:
        return

    for key in keys:
        # Reading a "???" (missing) node through cfg[key] raises
        # MissingMandatoryValue, and reading a "${...}" node forces the
        # interpolation to be resolved. Skip both. A missing node holds no
        # children, and an interpolation target is visited at its own
        # location in the tree.
        if OmegaConf.is_missing(cfg, key) or OmegaConf.is_interpolation(cfg, key):
            continue
        child = cfg[key]
        if isinstance(child, (DictConfig, ListConfig)):
            set_struct_recursively(child, strict)
|
|
|
|
def flatten_dict(d, parent_key="", sep="_"):
    """Flatten a nested ``dict`` into a single-level ``dict``.

    Nested keys are joined with *sep*, e.g. ``{"a": {"b": 1}}`` becomes
    ``{"a_b": 1}``. Only values that are ``dict`` instances are descended into.
    Every other value (lists included) is kept as a leaf. An empty nested
    dict contributes no keys.

    Args:
        d: Mapping to flatten (anything exposing ``items()``).
        parent_key: Prefix for every key produced from *d*. ``""`` (the
            default) or ``None`` means "no prefix".
        sep: Separator inserted between a parent key and a child key.

    Returns:
        A new flat ``dict``. If two paths flatten to the same key, the value
        visited later wins.
    """
    # Test for "no prefix" explicitly rather than by truthiness, so falsy
    # keys such as 0 still act as prefixes ({0: {"a": 1}} -> {"0_a": 1})
    # instead of silently colliding with their siblings.
    has_prefix = parent_key is not None and parent_key != ""
    flat = {}
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if has_prefix else k
        if isinstance(v, dict):
            flat.update(flatten_dict(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat
|
|
|
|
def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
    """Build a ``cls`` instance from a (possibly nested) dictionary.

    A default-constructed ``cls()`` is turned into a structured OmegaConf
    config. *data* is merged on top of it, and the merged config is converted
    back into a ``cls`` object, nested structures included.

    Args:
        cls: Dataclass type to build. It is called with no arguments, so all
            of its init parameters need defaults.
        data: Values overriding those defaults.
        strict: Struct flag applied to the default config before the merge.
            Presumably ``True`` rejects keys in *data* that ``cls`` does not
            declare (TODO confirm against OmegaConf struct-mode docs).

    Returns:
        The merged ``cls`` instance.
    """
    defaults = OmegaConf.structured(cls())
    OmegaConf.set_struct(defaults, strict)
    merged = OmegaConf.merge(defaults, OmegaConf.create(data))
    return OmegaConf.to_object(merged)
|
|
|
|
def dataclass_to_dict(dataclass_instance: T) -> dict:
    """Convert a dataclass instance into a plain (nested) dictionary.

    A ``dict`` argument is handed back unchanged (the same object, not a
    copy). Anything else is wrapped with ``OmegaConf.structured`` and turned
    into primitive containers with interpolations resolved.
    """
    if not isinstance(dataclass_instance, dict):
        structured_cfg = OmegaConf.structured(dataclass_instance)
        return OmegaConf.to_container(structured_cfg, resolve=True)
    return dataclass_instance
|
|
|
|
def load_config_file(config_file, dataclass_cls: Type[T], strict: bool = True) -> T:
    """Load a config file with OmegaConf and convert it into *dataclass_cls*.

    Args:
        config_file: Anything ``OmegaConf.load`` accepts (e.g. a file path).
        dataclass_cls: Dataclass type supplying the defaults and schema.
        strict: Forwarded to ``dataclass_from_dict``. The default ``True``
            matches this function's previous behaviour.

    Returns:
        A ``dataclass_cls`` instance with the file's values merged over its
        defaults.
    """
    # Resolve interpolations within the loaded file before merging, so
    # "${...}" references are evaluated against the file's own tree.
    raw_config = OmegaConf.to_container(OmegaConf.load(config_file), resolve=True)
    return dataclass_from_dict(dataclass_cls, raw_config, strict=strict)
|
|
|
|
def dump_config(config, path, log_config=True):
    """Write *config* to *path* as YAML, optionally echoing it to the log.

    Args:
        config: Object accepted by ``OmegaConf.structured`` (e.g. a dataclass
            instance).
        path: Destination file. Opened in ``"w"`` mode, so existing contents
            are replaced.
        log_config: When true, the YAML is also logged at INFO level.
    """
    yaml_dump = OmegaConf.to_yaml(OmegaConf.structured(config))
    # Config strings may contain non-ASCII characters. Use an explicit UTF-8
    # encoding instead of the platform default (e.g. cp1252 on Windows),
    # which cannot encode every character.
    with open(path, "w", encoding="utf-8") as f:
        if log_config:
            logger.info("Using the following config for this run:")
            logger.info(yaml_dump)
        f.write(yaml_dump)
|
|