File size: 1,994 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os

from pathlib import Path
from typing import Optional, List, Tuple, Dict
from dataclasses import dataclass, field
from omegaconf import OmegaConf, MISSING
from utils.class_registry import ClassRegistry
from models.methods import methods_registry
from metrics.metrics import metrics_registry



# Registry of config-section dataclasses. Each class below is added under a
# section name via `args.add_to_registry(<name>)`, and the combined `Args`
# schema is built from this registry further down in the module.
args = ClassRegistry()


@args.add_to_registry("exp")
@dataclass
class ExperimentArgs:
    """Experiment-level settings, registered under the ``exp`` config section."""

    # Directory holding config files; defaults to ``configs`` next to this module.
    config_dir: str = str(Path(__file__).resolve().parent / "configs")
    # Config file name; ``load_config`` joins it with ``config_dir``.
    # Mandatory: declared as OmegaConf ``MISSING``, so it has no default.
    config: str = MISSING
    # presumably where run outputs are written — confirm against the runner
    output_dir: str = "results_dir"
    # presumably the random seed — confirm against the runner
    seed: int = 1
    # Taken from the EXP_ROOT environment variable when this module is
    # imported; falls back to the current directory (".").
    root: str = os.getenv("EXP_ROOT", ".")
    # NOTE(review): the set of valid domain names is not visible in this file.
    domain: str = "human_faces"
    # presumably toggles Weights & Biases logging — confirm
    wandb: bool = False


@args.add_to_registry("data")
@dataclass
class DataArgs:
    """Data settings, registered under the ``data`` config section."""

    # presumably the directory with inputs for inference; empty by default — confirm
    inference_dir: str = ""
    # Name of the transform to use; NOTE(review): valid names are defined
    # outside this file.
    transform: str = "face_1024"


@args.add_to_registry("inference")
@dataclass
class InferenceArgs:
    """Inference settings, registered under the ``inference`` config section."""

    # Name of the inference runner to use; NOTE(review): valid names are
    # defined outside this file.
    inference_runner: str = "base_inference_runner"
    # Free-form mapping describing the edits to apply; starts empty.
    # ``dict`` itself is the factory, so each instance gets its own fresh
    # mapping without a wrapper lambda.
    editings_data: Dict = field(default_factory=dict)


@args.add_to_registry("model")
@dataclass
class ModelArgs:
    """Model settings, registered under the ``model`` config section."""

    # Selected method name; ``load_config`` keeps only the ``methods_args``
    # entry whose key equals this value.
    method: str = "fse_full"
    # Stored as a string; presumably a device index such as a GPU id — confirm
    device: str = "0"
    # presumably the batch size used by the data loader — confirm
    batch_size: int = 4
    # presumably the number of data-loading worker processes — confirm
    workers: int = 4
    # presumably the path to model weights; empty by default — confirm
    checkpoint_path: str = ""



# Dataclass generated by the methods registry and exposed as the
# "methods_args" config section. `load_config` iterates its keys and compares
# them with `model.method`, so presumably there is one entry per registered
# method — confirm against ClassRegistry.make_dataclass_from_args.
MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs")
args.add_to_registry("methods_args")(MethodsArgs)

# Dataclass generated by the metrics registry and exposed as the "metrics"
# config section.
MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs")
args.add_to_registry("metrics")(MetricsArgs)



# Top-level config schema assembled from every section registered in `args`.
# It must be built only after all `add_to_registry` calls above have run.
Args = args.make_dataclass_from_classes("Args")


def load_config():
    """Build the run configuration from defaults, a config file and the CLI.

    Sources are layered with later ones winning: the ``Args`` structured
    defaults, the file ``<exp.config_dir>/<exp.config>``, and finally the
    command-line dotlist. Before the CLI layer is merged, ``methods_args``
    is pruned to the single entry matching the selected ``model.method``.

    Returns:
        The merged OmegaConf configuration.

    Raises:
        omegaconf.errors.MissingMandatoryValue: if ``exp.config`` was not
            supplied on the command line (the field is declared ``MISSING``).
    """
    config = OmegaConf.structured(Args)
    conf_cli = OmegaConf.from_cli()

    # Copy only the values actually passed on the command line, so that the
    # declared default of ``exp.config_dir`` stays in effect when it is
    # omitted. ``OmegaConf.select`` yields None for an absent key.
    cli_config_name = OmegaConf.select(conf_cli, "exp.config")
    if cli_config_name is not None:
        config.exp.config = cli_config_name
    cli_config_dir = OmegaConf.select(conf_cli, "exp.config_dir")
    if cli_config_dir is not None:
        config.exp.config_dir = cli_config_dir

    config_path = os.path.join(config.exp.config_dir, config.exp.config)
    config = OmegaConf.merge(config, OmegaConf.load(config_path))

    # The CLI layer is merged last and may override ``model.method``; honour
    # that override here so the arguments kept belong to the method that
    # ends up selected, not to the one named in the file.
    cli_method = OmegaConf.select(conf_cli, "model.method")
    selected_method = config.model.method if cli_method is None else str(cli_method)

    for method in list(config.methods_args.keys()):
        if method != selected_method:
            # NOTE(review): attribute deletion is kept from the original code;
            # item deletion (``del cfg[key]``) may be rejected on structured
            # configs — confirm before changing.
            delattr(config.methods_args, method)

    return OmegaConf.merge(config, conf_cli)