# Smile_Changer/arguments/training_arguments.py
import os
from pathlib import Path
from typing import List
from dataclasses import dataclass, field

from omegaconf import OmegaConf, MISSING

from training.losses import disc_losses
from training.optimizers import optimizers
from utils.class_registry import ClassRegistry
from models.methods import methods_registry
from metrics.metrics import metrics_registry
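
# Top-level registry: each dataclass below registers itself under a config
# group name ("exp", "data", "train", ...); make_dataclass_from_classes at the
# bottom of this file combines the groups into a single structured schema.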
args = ClassRegistry()


@args.add_to_registry("exp")
@dataclass
class ExperimentArgs:
    config_dir: str = str(Path(__file__).resolve().parent / "configs")
    config: str = MISSING  # config file name; must be passed on the CLI (see load_config)
    exp_dir: str = "experiments"
    name: str = MISSING  # experiment name; must come from the config file or the CLI
    seed: int = 1
    root: str = os.getenv("EXP_ROOT", ".")
    wandb: bool = True
    wandb_project: str = "sfe"
    domain: str = "human_faces"


@args.add_to_registry("data")
@dataclass
class DataArgs:
    special_dir: str = MISSING
    transform: str = "face_1024"
    input_train_dir: str = MISSING
    input_val_dir: str = MISSING
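
# Note: the YAML file loaded in load_config below has to fill the MISSING
# data fields, e.g. (hypothetical paths):
#
#   data:
#     special_dir: /data/special
#     input_train_dir: /data/train
#     input_val_dir: /data/val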


@args.add_to_registry("train")
@dataclass
class TrainingArgs:
    train_runner: str = "base_training_runner"
    encoder_optimizer: str = "ranger"
    disc_optimizer: str = "adam"
    resume_path: str = ""  # checkpoint to resume from; empty string means none
    val_metrics: List[str] = field(
        default_factory=lambda: ["msssim", "lpips", "l2", "fid"]
    )
    start_step: int = 0
    steps: int = 300000
    log_step: int = 500
    checkpoint_step: int = 15000
    val_step: int = 15000
    train_dis: bool = False
    dis_train_start_step: int = 150000  # step at which discriminator training starts
    bs_used_before_adv_loss: int = 8  # batch size before the adversarial loss kicks in
    disc_edits: List[str] = field(default_factory=list)


@args.add_to_registry("model")
@dataclass
class ModelArgs:
    method: str = "fse_full"
    device: str = "0"
    batch_size: int = 4
    workers: int = 4
    checkpoint_path: str = ""


@args.add_to_registry("encoder_losses")
@dataclass
class EncoderLossesArgs:
    # Per-loss weights for encoder training; a weight of 0.0 disables the term.
    l2: float = 0.0
    lpips: float = 0.0
    lpips_scale: float = 0.0
    id: float = 0.0
    moco: float = 0.0
    adv: float = 0.0
    feat_rec: float = 0.0
    feat_rec_l1: float = 0.0
    l2_latent: float = 0.0
    id_vit: float = 0.0


# Argument groups generated from the sub-registries: each produced dataclass
# has one field per registered method / loss / optimizer / metric.
MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs")
args.add_to_registry("methods_args")(MethodsArgs)

DiscLossesArgs = disc_losses.make_dataclass_from_args("DiscLossesArgs")
args.add_to_registry("disc_losses")(DiscLossesArgs)

OptimizersArgs = optimizers.make_dataclass_from_args("OptimizersArgs")
args.add_to_registry("optimizers")(OptimizersArgs)

MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs")
args.add_to_registry("metrics")(MetricsArgs)

# Single top-level schema combining every registered argument group.
Args = args.make_dataclass_from_classes("Args")
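
# With the defaults above, OmegaConf.structured(Args) yields a tree shaped
# roughly like this (illustrative excerpt; MISSING fields render as `???`):
#
#   exp:
#     config: ???
#     name: ???
#     exp_dir: experiments
#     seed: 1
#   model:
#     method: fse_full
#     batch_size: 4
#   train:
#     steps: 300000
#     val_metrics: [msssim, lpips, l2, fid]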


def load_config():
    """Build the run config; precedence: CLI overrides > config file > defaults."""
    config = OmegaConf.structured(Args)
    conf_cli = OmegaConf.from_cli()

    # exp.config has no default, so it must be supplied on the command line;
    # exp.config_dir falls back to the default set in ExperimentArgs.
    config.exp.config = conf_cli.exp.config
    if conf_cli.exp.get("config_dir") is not None:
        config.exp.config_dir = conf_cli.exp.config_dir

    # Merge the YAML config file on top of the schema defaults.
    config_path = os.path.join(config.exp.config_dir, config.exp.config)
    conf_file = OmegaConf.load(config_path)
    config = OmegaConf.merge(config, conf_file)

    # Drop the argument groups of every method except the selected one.
    for method in list(config.methods_args.keys()):
        if method != config.model.method:
            delattr(config.methods_args, method)

    # CLI overrides take the highest priority.
    config = OmegaConf.merge(config, conf_cli)
    return config
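

if __name__ == "__main__":
    # Minimal smoke test (illustrative): the config file name below is
    # hypothetical and must exist inside exp.config_dir, e.g.
    #
    #   python -m arguments.training_arguments exp.config=fse_full.yaml exp.name=demo
    #
    config = load_config()
    print(OmegaConf.to_yaml(config))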