import os from pathlib import Path from typing import Optional, List, Tuple, Dict from dataclasses import dataclass, field from omegaconf import OmegaConf, MISSING from utils.class_registry import ClassRegistry from models.methods import methods_registry from metrics.metrics import metrics_registry args = ClassRegistry() @args.add_to_registry("exp") @dataclass class ExperimentArgs: config_dir: str = str(Path(__file__).resolve().parent / "configs") config: str = MISSING output_dir: str = "results_dir" seed: int = 1 root: str = os.getenv("EXP_ROOT", ".") domain: str = "human_faces" wandb: bool = False @args.add_to_registry("data") @dataclass class DataArgs: inference_dir: str = "" transform: str = "face_1024" @args.add_to_registry("inference") @dataclass class InferenceArgs: inference_runner: str = "base_inference_runner" editings_data: Dict = field(default_factory=lambda: {}) @args.add_to_registry("model") @dataclass class ModelArgs: method: str = "fse_full" device: str = "0" batch_size: int = 4 workers: int = 4 checkpoint_path: str = "" MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs") args.add_to_registry("methods_args")(MethodsArgs) MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs") args.add_to_registry("metrics")(MetricsArgs) Args = args.make_dataclass_from_classes("Args") def load_config(): config = OmegaConf.structured(Args) conf_cli = OmegaConf.from_cli() config.exp.config = conf_cli.exp.config config.exp.config_dir = conf_cli.exp.config_dir config_path = os.path.join(config.exp.config_dir, config.exp.config) conf_file = OmegaConf.load(config_path) config = OmegaConf.merge(config, conf_file) for method in list(config.methods_args.keys()): if method != config.model.method: config.methods_args.__delattr__(method) config = OmegaConf.merge(config, conf_cli) return config