Spaces:
Sleeping
Sleeping
File size: 1,994 Bytes
95b1715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
from pathlib import Path
from typing import Optional, List, Tuple, Dict
from dataclasses import dataclass, field
from omegaconf import OmegaConf, MISSING
from utils.class_registry import ClassRegistry
from models.methods import methods_registry
from metrics.metrics import metrics_registry
# Registry of the config-section dataclasses declared in this module. Each
# section is added under a short name via ``add_to_registry``, and the
# combined ``Args`` dataclass is generated from this registry further down.
args = ClassRegistry()
@args.add_to_registry("exp")
@dataclass
class ExperimentArgs:
    """Experiment-level options, registered under the ``exp`` key.

    ``load_config`` in this module reads ``exp.config_dir`` and ``exp.config``
    to locate the config file it merges into the defaults.
    """

    # Directory holding config files; defaults to the ``configs`` directory
    # that sits next to this source file.
    config_dir: str = str(Path(__file__).resolve().with_name("configs"))
    # Config file name, joined onto ``config_dir``. Mandatory: no default.
    config: str = MISSING
    output_dir: str = "results_dir"
    seed: int = 1  # presumably the random seed -- confirm against the consumer
    # Taken from the EXP_ROOT environment variable once, at import time;
    # falls back to the current directory when the variable is unset.
    root: str = os.environ.get("EXP_ROOT", ".")
    domain: str = "human_faces"
    wandb: bool = False  # presumably toggles Weights & Biases logging -- confirm
@args.add_to_registry("data")
@dataclass
class DataArgs:
    """Data-related options, registered under the ``data`` key."""

    # Empty by default. Presumably the directory of inputs to run inference
    # on -- confirm against the code that consumes it.
    inference_dir: str = ""
    # Name of a transform preset; how the name is resolved is not visible
    # in this module.
    transform: str = "face_1024"
@args.add_to_registry("inference")
@dataclass
class InferenceArgs:
    """Inference options, registered under the ``inference`` key."""

    # Name of the runner to use; presumably looked up in a runner registry
    # elsewhere in the project -- confirm.
    inference_runner: str = "base_inference_runner"
    # A fresh empty dict per instance (never a shared mutable default).
    editings_data: Dict = field(default_factory=dict)
@args.add_to_registry("model")
@dataclass
class ModelArgs:
    """Model options, registered under the ``model`` key."""

    # Name of the selected method. ``load_config`` keeps only the
    # ``methods_args`` entry whose key equals this value.
    method: str = "fse_full"
    # Stored as a string; presumably a GPU index -- confirm against the
    # code that consumes it.
    device: str = "0"
    batch_size: int = 4
    workers: int = 4  # presumably data-loader worker count -- confirm
    checkpoint_path: str = ""  # empty by default
# Dataclass generated by the methods registry and exposed as the
# ``methods_args`` section. ``load_config`` iterates its keys and compares
# them with ``model.method``, so it presumably has one entry per registered
# method -- TODO confirm against ``make_dataclass_from_args``.
MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs")
args.add_to_registry("methods_args")(MethodsArgs)
# Same pattern for the metrics registry, exposed as the ``metrics`` section.
MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs")
args.add_to_registry("metrics")(MetricsArgs)
# Top-level config schema, generated after all sections above have been
# registered; ``load_config`` passes it to ``OmegaConf.structured``.
Args = args.make_dataclass_from_classes("Args")
def load_config():
    """Build the run configuration from defaults, a config file and the CLI.

    Layers are applied in increasing priority:

    1. the structured defaults declared by ``Args``;
    2. the file ``<exp.config_dir>/<exp.config>``;
    3. command-line overrides parsed by ``OmegaConf.from_cli``.

    After the file is merged, every entry of ``methods_args`` other than the
    one named by ``model.method`` is removed, so only the selected method's
    options remain in the result.

    Returns:
        The merged OmegaConf configuration.

    Raises:
        omegaconf.errors.MissingMandatoryValue: if ``exp.config`` was not
            supplied on the command line (it has no default).
    """
    config = OmegaConf.structured(Args)
    conf_cli = OmegaConf.from_cli()

    # The config file location must be known before the file can be read, so
    # it is taken from the CLI ahead of the final merge. Only keys actually
    # present on the command line are copied: an omitted ``exp.config_dir``
    # keeps its dataclass default instead of failing on a missing CLI key.
    for key in ("config", "config_dir"):
        cli_value = OmegaConf.select(conf_cli, f"exp.{key}")
        if cli_value is not None:
            config.exp[key] = cli_value

    config_path = os.path.join(config.exp.config_dir, config.exp.config)
    config = OmegaConf.merge(config, OmegaConf.load(config_path))

    selected_method = config.model.method
    # Snapshot the keys first: entries are deleted while walking them.
    for method in list(config.methods_args.keys()):
        if method != selected_method:
            # Attribute deletion is used deliberately.
            # NOTE(review): item deletion (``del cfg[key]``) appears to be
            # rejected by OmegaConf on structured nodes -- confirm before
            # changing this.
            delattr(config.methods_args, method)

    # CLI values are merged last so they win over the file and the defaults.
    return OmegaConf.merge(config, conf_cli)
|