|
|
from omegaconf import OmegaConf |
|
|
from dataclasses import dataclass, field |
|
|
from datetime import datetime |
|
|
|
|
|
import os |
|
|
|
|
|
def time_string(): |
|
|
return datetime.now().strftime("%m_%d_%Y_%H_%M_%S") |
|
|
|
|
|
def _get_shuffle(lst): |
|
|
if isinstance(lst, list): |
|
|
return True if len(lst) == 1 else False |
|
|
elif lst is None: |
|
|
return True |
|
|
else: |
|
|
raise ValueError("Input must be a list or None") |
|
|
|
|
|
OmegaConf.register_new_resolver("path_append", lambda a, b: os.path.join(a, b)) |
|
|
OmegaConf.register_new_resolver("get_trial_dir", lambda save_dir: os.path.join(os.getcwd(), save_dir, time_string())) |
|
|
OmegaConf.register_new_resolver("get_run_id", lambda save_dir: save_dir.split('/')[-1]) |
|
|
OmegaConf.register_new_resolver("get_shuffle", lambda lst: True if len(lst) == 1 else False) |
|
|
|
|
|
@dataclass |
|
|
class ExpCfg: |
|
|
name:str = 'default' |
|
|
save_dir:str = 'runs' |
|
|
trial_dir:str = None |
|
|
save_cfg_path:str = None |
|
|
seed:int = 0 |
|
|
|
|
|
data_type:str = 'BaseDataModule' |
|
|
data:dict = field(default_factory=dict) |
|
|
|
|
|
system_type:str = 'TinyNerf' |
|
|
system:dict = field(default_factory=dict) |
|
|
|
|
|
train:dict = field(default_factory=dict) |
|
|
|
|
|
def __post_init__(self): |
|
|
print('[INFO]: Experiment Configured') |
|
|
os.makedirs(self.trial_dir, exist_ok=True) |
|
|
print(f'[INFO]: Experiment Directory is created at {self.trial_dir}') |
|
|
self.dump(self.save_cfg_path) |
|
|
print(f'[INFO]: Experiment YAML Config is saved at {self.save_cfg_path}') |
|
|
|
|
|
def dump(self, path:str): |
|
|
with open(path, "w") as fp: |
|
|
OmegaConf.save(config=self, f=fp) |
|
|
|
|
|
def load_cfg(cfg_path: str): |
|
|
cfg = OmegaConf.load(cfg_path) |
|
|
OmegaConf.resolve(cfg) |
|
|
scfg = parse_structure(ExpCfg, cfg) |
|
|
|
|
|
return scfg |
|
|
|
|
|
def parse_structure(template, cfg): |
|
|
return OmegaConf.structured(template(**cfg)) |