import os from training.losses import disc_losses from training.optimizers import optimizers from pathlib import Path from typing import Optional, List, Tuple, Dict from dataclasses import dataclass, field from omegaconf import OmegaConf, MISSING from utils.class_registry import ClassRegistry from models.methods import methods_registry from metrics.metrics import metrics_registry args = ClassRegistry() @args.add_to_registry("exp") @dataclass class ExperimentArgs: config_dir: str = str(Path(__file__).resolve().parent / "configs") config: str = MISSING exp_dir: str = "experiments" name: str = MISSING seed: int = 1 root: str = os.getenv("EXP_ROOT", ".") wandb: bool = True wandb_project: str = "sfe" domain: str = "human_faces" @args.add_to_registry("data") @dataclass class DataArgs: special_dir: str = MISSING transform: str = "face_1024" input_train_dir: str = MISSING input_val_dir: str = MISSING @args.add_to_registry("train") @dataclass class TrainingArgs: train_runner: str = "base_training_runner" encoder_optimizer: str = "ranger" disc_optimizer: str = "adam" resume_path: str = "" val_metrics: List[str] = field( default_factory=lambda: ["msssim", "lpips", "l2", "fid"] ) start_step: int = 0 steps: int = 300000 log_step: int = 500 checkpoint_step: int = 15000 val_step: int = 15000 train_dis: bool = False dis_train_start_step: int = 150000 bs_used_before_adv_loss: int = 8 disc_edits: List[str] = field( default_factory=lambda: [] ) @args.add_to_registry("model") @dataclass class ModelArgs: method: str = "fse_full" device: str = "0" batch_size: int = 4 workers: int = 4 checkpoint_path: str = "" @args.add_to_registry("encoder_losses") @dataclass class EncoderLossesArgs: l2: float = 0.0 lpips: float = 0.0 lpips_scale: float = 0.0 id: float = 0.0 moco: float = 0.0 adv: float = 0.0 feat_rec: float = 0.0 feat_rec_l1: float = 0.0 l2_latent: float = 0.0 id_vit: float = 0.0 MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs") args.add_to_registry("methods_args")(MethodsArgs) DiscLossesArgs = disc_losses.make_dataclass_from_args("DiscLossesArgs") args.add_to_registry("disc_losses")(DiscLossesArgs) OptimizersArgs = optimizers.make_dataclass_from_args("OptimizersArgs") args.add_to_registry("optimizers")(OptimizersArgs) MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs") args.add_to_registry("metrics")(MetricsArgs) Args = args.make_dataclass_from_classes("Args") def load_config(): config = OmegaConf.structured(Args) conf_cli = OmegaConf.from_cli() config.exp.config = conf_cli.exp.config config.exp.config_dir = conf_cli.exp.config_dir config_path = os.path.join(config.exp.config_dir, config.exp.config) conf_file = OmegaConf.load(config_path) config = OmegaConf.merge(config, conf_file) for method in list(config.methods_args.keys()): if method != config.model.method: config.methods_args.__delattr__(method) config = OmegaConf.merge(config, conf_cli) return config