File size: 3,207 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from training.losses import disc_losses
from training.optimizers import optimizers
from pathlib import Path
from typing import Optional, List, Tuple, Dict
from dataclasses import dataclass, field
from omegaconf import OmegaConf, MISSING
from utils.class_registry import ClassRegistry
from models.methods import methods_registry
from metrics.metrics import metrics_registry


# Global registry of config sections; each @args.add_to_registry("<name>")
# dataclass below becomes one top-level group of the final structured config.
args = ClassRegistry()


@args.add_to_registry("exp")
@dataclass
class ExperimentArgs:
    """Top-level experiment settings: config location, output paths, logging."""

    # Directory containing config files; defaults to ./configs next to this module.
    config_dir: str = str(Path(__file__).resolve().parent / "configs")
    config: str = MISSING  # config file name inside config_dir (required, set via CLI)
    exp_dir: str = "experiments"  # directory for experiment outputs
    name: str = MISSING  # experiment name (required)
    seed: int = 1  # RNG seed
    root: str = os.getenv("EXP_ROOT", ".")  # root dir, taken from $EXP_ROOT if set
    wandb: bool = True  # enable Weights & Biases logging
    wandb_project: str = "sfe"  # W&B project name
    domain: str = "human_faces"  # data domain identifier


@args.add_to_registry("data")
@dataclass
class DataArgs:
    """Dataset locations and preprocessing transform."""

    special_dir: str = MISSING  # extra data directory (required; purpose set by caller)
    transform: str = "face_1024"  # registered transform name applied to inputs
    input_train_dir: str = MISSING  # training images directory (required)
    input_val_dir: str = MISSING  # validation images directory (required)


@args.add_to_registry("train")
@dataclass
class TrainingArgs:
    """Training-loop schedule, optimizer choices, and validation settings."""

    train_runner: str = "base_training_runner"  # registered runner name
    encoder_optimizer: str = "ranger"  # optimizer registry key for the encoder
    disc_optimizer: str = "adam"  # optimizer registry key for the discriminator
    resume_path: str = ""  # checkpoint path to resume from; empty string = fresh run
    # Metrics computed during validation.
    val_metrics: List[str] = field(
        default_factory=lambda: ["msssim", "lpips", "l2", "fid"]
    )
    start_step: int = 0  # step counter starting value
    steps: int = 300000  # total number of training steps
    log_step: int = 500  # log every N steps
    checkpoint_step: int = 15000  # save a checkpoint every N steps
    val_step: int = 15000  # run validation every N steps
    train_dis: bool = False  # whether the discriminator is trained
    dis_train_start_step: int = 150000  # step at which discriminator training begins
    bs_used_before_adv_loss: int = 8  # batch size used before adversarial loss starts
    # Edits applied for the discriminator; empty by default.
    # Idiom fix: `default_factory=list` replaces the needless `lambda: []` wrapper.
    disc_edits: List[str] = field(default_factory=list)

@args.add_to_registry("model")
@dataclass
class ModelArgs:
    """Model/method selection and runtime execution settings."""

    method: str = "fse_full"  # method name; must be a key of methods_registry
    device: str = "0"  # device identifier (presumably a CUDA index — confirm with runner)
    batch_size: int = 4  # per-step batch size
    workers: int = 4  # dataloader worker count
    checkpoint_path: str = ""  # pretrained weights to load; empty = none


@args.add_to_registry("encoder_losses")
@dataclass
class EncoderLossesArgs:
    """Weights for each encoder loss term; 0.0 disables a term."""

    l2: float = 0.0  # pixel-space L2
    lpips: float = 0.0  # LPIPS perceptual loss
    lpips_scale: float = 0.0  # LPIPS at a different scale — confirm exact scale in loss impl
    id: float = 0.0  # identity-preservation loss
    moco: float = 0.0  # MoCo feature loss
    adv: float = 0.0  # adversarial loss
    feat_rec: float = 0.0  # feature reconstruction loss
    feat_rec_l1: float = 0.0  # L1 variant of feature reconstruction
    l2_latent: float = 0.0  # L2 in latent space
    id_vit: float = 0.0  # ViT-based identity loss


# Auto-generate one config section per registry: each registered class
# contributes its own argument dataclass, so new methods/losses/optimizers/
# metrics get config entries without editing this file.
MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs")
args.add_to_registry("methods_args")(MethodsArgs)

DiscLossesArgs = disc_losses.make_dataclass_from_args("DiscLossesArgs")
args.add_to_registry("disc_losses")(DiscLossesArgs)

OptimizersArgs = optimizers.make_dataclass_from_args("OptimizersArgs")
args.add_to_registry("optimizers")(OptimizersArgs)

MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs")
args.add_to_registry("metrics")(MetricsArgs)


# Root structured-config schema assembled from every section registered above.
Args = args.make_dataclass_from_classes("Args")


def load_config():
    """Build the run configuration from defaults, a YAML file, and the CLI.

    Merge order (lowest to highest priority): the structured schema
    ``Args``, the YAML file named by ``exp.config`` inside
    ``exp.config_dir``, then the command-line overrides.

    Returns:
        An OmegaConf config whose ``methods_args`` section is pruned to
        only the method selected by ``model.method``.
    """
    config = OmegaConf.structured(Args)

    # The config file location itself can only come from the CLI, so read
    # those two keys before anything else.
    conf_cli = OmegaConf.from_cli()
    config.exp.config = conf_cli.exp.config
    config.exp.config_dir = conf_cli.exp.config_dir

    config_path = os.path.join(config.exp.config_dir, config.exp.config)
    conf_file = OmegaConf.load(config_path)
    config = OmegaConf.merge(config, conf_file)

    # Drop args of every non-selected method so the final config does not
    # carry stale sections. Snapshot keys first: we delete while iterating.
    # Use the builtin delattr() rather than calling __delattr__ directly.
    for method in list(config.methods_args.keys()):
        if method != config.model.method:
            delattr(config.methods_args, method)

    # CLI overrides win over both the file and the defaults.
    config = OmegaConf.merge(config, conf_cli)

    return config