# Smile_Changer / arguments/inference_arguments.py
# Commit 95b1715 (LogicGoInfotechSpaces): "Bundle StyleFeatureEditor code
# packages in Space to fix ModuleNotFoundError"
import os
from pathlib import Path
from typing import Optional, List, Tuple, Dict
from dataclasses import dataclass, field
from omegaconf import OmegaConf, MISSING
from utils.class_registry import ClassRegistry
from models.methods import methods_registry
from metrics.metrics import metrics_registry
# Registry that collects the argument-group dataclasses defined below, each
# under a key ("exp", "data", "inference", "model", ...). The combined schema
# is later built from it with make_dataclass_from_classes("Args"), and
# load_config reads the groups back as config.exp, config.model, etc.
args = ClassRegistry()
@args.add_to_registry("exp")
@dataclass
class ExperimentArgs:
    """Experiment-level settings, registered under the ``exp`` key.

    ``config`` has no default (``MISSING``); ``load_config`` fills it in
    from the command line before loading the YAML file it names.
    """

    # Directory holding the YAML experiment configs: the ``configs`` folder
    # located next to this module.
    config_dir: str = str(Path(__file__).resolve().parent.joinpath("configs"))
    # File name of the YAML config inside ``config_dir`` (mandatory).
    config: str = MISSING
    output_dir: str = "results_dir"
    seed: int = 1
    # Read once, at import time, from the EXP_ROOT environment variable;
    # falls back to the current directory when it is unset.
    root: str = os.environ.get("EXP_ROOT", ".")
    domain: str = "human_faces"
    # NOTE(review): presumably toggles Weights & Biases logging — confirm.
    wandb: bool = False
@args.add_to_registry("data")
@dataclass
class DataArgs:
    """Input-data settings, registered under the ``data`` key."""

    # NOTE(review): presumably the directory of inputs to run inference
    # on — confirm. Empty string by default.
    inference_dir: str = ""
    # Name of the transform preset to apply; "face_1024" by default.
    transform: str = "face_1024"
@args.add_to_registry("inference")
@dataclass
class InferenceArgs:
    """Inference settings, registered under the ``inference`` key."""

    # Name of the inference runner to use.
    inference_runner: str = "base_inference_runner"
    # Free-form mapping of editing settings. The ``dict`` constructor is the
    # factory itself, so every instance gets its own fresh empty dict (a
    # shared mutable default is not allowed in dataclasses).
    editings_data: Dict = field(default_factory=dict)
@args.add_to_registry("model")
@dataclass
class ModelArgs:
    """Model settings, registered under the ``model`` key."""

    # Key of the method to run; load_config keeps only the ``methods_args``
    # entry whose name equals this value.
    method: str = "fse_full"
    # Kept as a string. NOTE(review): presumably a device index (e.g. a CUDA
    # id) — confirm.
    device: str = "0"
    batch_size: int = 4
    # NOTE(review): presumably the number of data-loader workers — confirm.
    workers: int = 4
    # Path to model weights; empty string by default.
    checkpoint_path: str = ""
# Dataclass generated from methods_registry (presumably one field per
# registered method); load_config later keeps only the entry whose key
# matches model.method.
MethodsArgs = methods_registry.make_dataclass_from_args("MethodsArgs")
# Register it under "methods_args" (decorator applied as a plain call).
args.add_to_registry("methods_args")(MethodsArgs)
# Dataclass generated from metrics_registry, registered under "metrics".
MetricsArgs = metrics_registry.make_dataclass_from_args("MetricsArgs")
args.add_to_registry("metrics")(MetricsArgs)
# Top-level config schema combining every class registered in `args`;
# load_config passes it to OmegaConf.structured. Must run after all
# registrations above so every group is included.
Args = args.make_dataclass_from_classes("Args")
def load_config():
    """Build the run configuration from defaults, a YAML file and the CLI.

    Precedence, lowest to highest: the structured defaults in ``Args``,
    the YAML file at ``<exp.config_dir>/<exp.config>``, then the command
    line (``key=value`` pairs parsed by ``OmegaConf.from_cli``).

    ``exp.config`` is required on the command line. ``exp.config_dir`` may
    be passed there as well; otherwise the ``ExperimentArgs`` default is
    used.

    After merging, every ``methods_args`` entry except the one named by the
    final (CLI-overridable) ``model.method`` is removed.

    Returns:
        The merged OmegaConf config.
    """
    config = OmegaConf.structured(Args)
    conf_cli = OmegaConf.from_cli()

    # The YAML location has to be known before the file can be loaded, so
    # these two keys are taken from the CLI up front.
    # exp.config is mandatory (its default is MISSING), so it is read
    # directly and OmegaConf errors out if it was not passed.
    config.exp.config = conf_cli.exp.config
    # exp.config_dir is optional: only override the default when the user
    # actually supplied it (select returns None for an absent key).
    cli_config_dir = OmegaConf.select(conf_cli, "exp.config_dir")
    if cli_config_dir is not None:
        config.exp.config_dir = cli_config_dir

    config_path = os.path.join(config.exp.config_dir, config.exp.config)
    conf_file = OmegaConf.load(config_path)

    # Merge the CLI last so it wins over the file, and do it *before*
    # pruning so that a CLI override of model.method decides which
    # methods_args entry survives.
    config = OmegaConf.merge(config, conf_file, conf_cli)

    selected = config.model.method
    # Iterate over a snapshot of the keys because entries are deleted.
    for method in list(config.methods_args.keys()):
        if method != selected:
            # NOTE(review): attribute-style deletion (DictConfig.__delattr__)
            # is kept on purpose; `del cfg[key]` may be rejected on
            # struct/typed configs — confirm before changing.
            delattr(config.methods_args, method)
    return config