Spaces:
Sleeping
Sleeping
| import os | |
| from training.losses import disc_losses | |
| from training.optimizers import optimizers | |
| from pathlib import Path | |
| from typing import Optional, List, Tuple, Dict | |
| from dataclasses import dataclass, field | |
| from omegaconf import OmegaConf, MISSING | |
| from utils.class_registry import ClassRegistry | |
| from models.methods import methods_registry | |
| from metrics.metrics import metrics_registry | |
# Registry of top-level config sections; ``Args`` is later built from it
# via ``args.make_dataclass_from_classes("Args")``.
args = ClassRegistry()


@dataclass
class ExperimentArgs:
    """Experiment-level settings, exposed as the ``exp`` config section.

    ``load_config`` reads ``exp.config`` and ``exp.config_dir`` to locate the
    YAML file to merge. Fields defaulting to OmegaConf's ``MISSING`` are
    mandatory and must be provided by the YAML file or a CLI override.
    """

    # Directory containing YAML configs; defaults to ``configs/`` next to
    # this source file.
    config_dir: str = str(Path(__file__).resolve().parent / "configs")
    # YAML file name, joined onto ``config_dir`` by ``load_config``.
    config: str = MISSING
    exp_dir: str = "experiments"
    name: str = MISSING
    seed: int = 1
    # Taken from the EXP_ROOT environment variable when this module is
    # imported; falls back to the current directory.
    root: str = os.getenv("EXP_ROOT", ".")
    wandb: bool = True
    wandb_project: str = "sfe"
    domain: str = "human_faces"


# Register the section under "exp" -- the key ``load_config`` accesses as
# ``config.exp``. Explicit call (not decorator) mirrors the registrations
# of the generated sections below and does not depend on the return value.
args.add_to_registry("exp")(ExperimentArgs)
@dataclass
class DataArgs:
    """Dataset/input settings.

    All directory fields default to OmegaConf's ``MISSING`` and therefore
    must be supplied by the YAML config or a CLI override.
    """

    special_dir: str = MISSING
    transform: str = "face_1024"
    input_train_dir: str = MISSING
    input_val_dir: str = MISSING


# NOTE(review): section key "data" is not referenced in this file; confirm
# it matches what consumers read (e.g. ``config.data``).
args.add_to_registry("data")(DataArgs)
@dataclass
class TrainingArgs:
    """Training-loop settings.

    Nothing in this file reads these fields; their meaning is defined by the
    consuming code (presumably the runner named by ``train_runner``).
    """

    train_runner: str = "base_training_runner"
    encoder_optimizer: str = "ranger"
    disc_optimizer: str = "adam"
    resume_path: str = ""
    # ``default_factory`` gives every instance its own list. It only works
    # because the class is a dataclass; on a plain class the attribute
    # would be a ``dataclasses.Field`` object instead of a list.
    val_metrics: List[str] = field(
        default_factory=lambda: ["msssim", "lpips", "l2", "fid"]
    )
    start_step: int = 0
    steps: int = 300000
    log_step: int = 500
    checkpoint_step: int = 15000
    val_step: int = 15000
    train_dis: bool = False
    dis_train_start_step: int = 150000
    bs_used_before_adv_loss: int = 8
    disc_edits: List[str] = field(default_factory=list)


# NOTE(review): section key "train" is not referenced in this file; confirm
# it matches what consumers read (e.g. ``config.train``).
args.add_to_registry("train")(TrainingArgs)
@dataclass
class ModelArgs:
    """Model selection and loading settings, exposed as ``model``.

    ``load_config`` keeps only the ``methods_args`` entry whose key equals
    ``method``.
    """

    method: str = "fse_full"
    device: str = "0"  # presumably a device id string -- TODO confirm
    batch_size: int = 4
    workers: int = 4
    checkpoint_path: str = ""


# Register under "model" -- the key ``load_config`` reads as
# ``config.model.method``.
args.add_to_registry("model")(ModelArgs)
@dataclass
class EncoderLossesArgs:
    """Per-loss coefficients for the encoder, all defaulting to 0.0.

    NOTE(review): values look like loss weights (0.0 = term disabled);
    confirm against the loss-building code.
    """

    l2: float = 0.0
    lpips: float = 0.0
    lpips_scale: float = 0.0
    # Field name shadows the builtin ``id`` but is part of the config schema.
    id: float = 0.0
    moco: float = 0.0
    adv: float = 0.0
    feat_rec: float = 0.0
    feat_rec_l1: float = 0.0
    l2_latent: float = 0.0
    id_vit: float = 0.0


# NOTE(review): section key "encoder_losses" is not referenced in this
# file; confirm it matches what consumers read.
args.add_to_registry("encoder_losses")(EncoderLossesArgs)
def _register_generated_args(registry, class_name, section):
    """Generate an arguments dataclass from ``registry`` and add it to ``args``.

    Args:
        registry: object exposing ``make_dataclass_from_args``.
        class_name: name passed to ``make_dataclass_from_args``.
        section: key under which the generated class is registered in ``args``.

    Returns:
        Whatever ``registry.make_dataclass_from_args(class_name)`` returned.
    """
    generated = registry.make_dataclass_from_args(class_name)
    args.add_to_registry(section)(generated)
    return generated


# Sections generated from component registries; each one is created and
# registered before the next, in this order.
MethodsArgs = _register_generated_args(
    methods_registry, "MethodsArgs", "methods_args"
)
DiscLossesArgs = _register_generated_args(
    disc_losses, "DiscLossesArgs", "disc_losses"
)
OptimizersArgs = _register_generated_args(
    optimizers, "OptimizersArgs", "optimizers"
)
MetricsArgs = _register_generated_args(
    metrics_registry, "MetricsArgs", "metrics"
)

# Root config class built from the classes registered in ``args``;
# ``load_config`` passes it to ``OmegaConf.structured``.
Args = args.make_dataclass_from_classes("Args")
def load_config():
    """Build the run configuration from defaults, a YAML file and the CLI.

    Precedence, lowest to highest:
      1. defaults of the ``Args`` structured config;
      2. the YAML file at ``os.path.join(exp.config_dir, exp.config)``;
      3. every ``key=value`` override passed on the command line.

    After the YAML merge, ``methods_args`` entries other than the one named
    by ``model.method`` are removed; CLI overrides are applied afterwards.

    Returns:
        The merged OmegaConf config.

    Raises:
        ValueError: if ``exp.config`` is not given on the command line.
    """
    config = OmegaConf.structured(Args)
    conf_cli = OmegaConf.from_cli()

    # The YAML location must be known before the YAML can be loaded, so it
    # is taken from the CLI up front. ``exp.config_dir`` is optional: only
    # override the dataclass default when the user actually passed it
    # (reading it unconditionally fails whenever it is omitted).
    cli_exp = conf_cli.get("exp", None) or {}
    if "config" not in cli_exp:
        raise ValueError(
            "exp.config must be passed on the command line "
            "(e.g. exp.config=<file>.yaml)"
        )
    config.exp.config = cli_exp["config"]
    if "config_dir" in cli_exp:
        config.exp.config_dir = cli_exp["config_dir"]

    config_path = os.path.join(config.exp.config_dir, config.exp.config)
    conf_file = OmegaConf.load(config_path)
    config = OmegaConf.merge(config, conf_file)

    # Keep only the selected method's arguments. Iterate over a snapshot of
    # the keys because entries are deleted inside the loop.
    for method in list(config.methods_args.keys()):
        if method != config.model.method:
            delattr(config.methods_args, method)

    # CLI overrides take precedence over everything loaded above.
    config = OmegaConf.merge(config, conf_cli)
    return config