File size: 1,880 Bytes
148d42e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from omegaconf import OmegaConf
from dataclasses import dataclass, field
from datetime import datetime

import os

def time_string():
    """Return the current local time formatted as ``MM_DD_YYYY_HH_MM_SS``.

    The ``get_trial_dir`` resolver uses this as the last path component of a
    run directory, so the format contains only digits and underscores.
    """
    # datetime.__format__ delegates a non-empty spec to strftime.
    return f"{datetime.now():%m_%d_%Y_%H_%M_%S}"

def _get_shuffle(lst):
    if isinstance(lst, list):
        return True if len(lst) == 1 else False
    elif lst is None:
        return True
    else:
        raise ValueError("Input must be a list or None")

# Custom interpolation resolvers available to the experiment YAML files.
# replace=True lets this module be re-imported (e.g. notebook autoreload)
# without OmegaConf rejecting an already-registered resolver name.
OmegaConf.register_new_resolver(
    "path_append", lambda a, b: os.path.join(a, b), replace=True
)
# Absolute, timestamped run directory: <cwd>/<save_dir>/<time_string()>.
OmegaConf.register_new_resolver(
    "get_trial_dir",
    lambda save_dir: os.path.join(os.getcwd(), save_dir, time_string()),
    replace=True,
)
# Last path component of a directory. os.path understands the platform
# separator (get_trial_dir builds paths with os.path.join, i.e. '\\' on
# Windows), and normpath drops a trailing separator that would otherwise
# produce an empty run id.
OmegaConf.register_new_resolver(
    "get_run_id",
    lambda save_dir: os.path.basename(os.path.normpath(save_dir)),
    replace=True,
)
# True when the value is None or has exactly one item, matching _get_shuffle.
# There is no type check, so list-like OmegaConf nodes are accepted as-is.
OmegaConf.register_new_resolver(
    "get_shuffle",
    lambda lst: lst is None or len(lst) == 1,
    replace=True,
)

@dataclass
class ExpCfg:
    """Top-level experiment configuration.

    Constructing an instance has side effects: ``__post_init__`` creates
    ``trial_dir`` on disk and writes this configuration as YAML to
    ``save_cfg_path``. Both of those fields must therefore be set.
    """

    name: str = 'default'
    save_dir: str = 'runs'
    # Directory for this run; created in __post_init__. Must not be None.
    # NOTE(review): annotated as str although the default is None. The
    # annotation is left unchanged because OmegaConf uses it for validation.
    trial_dir: str = None
    # File the YAML dump of this config is written to in __post_init__.
    # Must not be None.
    save_cfg_path: str = None
    seed: int = 0

    # Presumably the name of the data-module class, with its kwargs in
    # ``data``. TODO confirm against the consumer.
    data_type: str = 'BaseDataModule'
    data: dict = field(default_factory=dict)

    # Presumably the name of the system/model class, with its kwargs in
    # ``system``. TODO confirm against the consumer.
    system_type: str = 'TinyNerf'
    system: dict = field(default_factory=dict)

    train: dict = field(default_factory=dict)

    def __post_init__(self):
        """Create the trial directory and save the config YAML.

        Raises:
            TypeError: if ``trial_dir`` or ``save_cfg_path`` is None.
        """
        # Validate before touching the filesystem. Otherwise a missing
        # save_cfg_path only fails inside open(), after the directory has
        # already been created.
        missing = [
            attr
            for attr in ("trial_dir", "save_cfg_path")
            if getattr(self, attr) is None
        ]
        if missing:
            raise TypeError(
                f"ExpCfg requires {', '.join(missing)} to be set (got None)"
            )
        print('[INFO]: Experiment Configured')
        os.makedirs(self.trial_dir, exist_ok=True)
        print(f'[INFO]: Experiment Directory is created at {self.trial_dir}')
        self.dump(self.save_cfg_path)
        print(f'[INFO]: Experiment YAML Config is saved at {self.save_cfg_path}')

    def dump(self, path: str):
        """Write this config as YAML to ``path``, overwriting any existing file."""
        # Explicit UTF-8 keeps the file independent of the platform's default
        # locale encoding.
        with open(path, "w", encoding="utf-8") as fp:
            OmegaConf.save(config=self, f=fp)

def load_cfg(cfg_path: str):
    """Read the YAML file at ``cfg_path`` and build the experiment config.

    Interpolations are resolved in place before construction, so ``ExpCfg``
    receives concrete values. Building ``ExpCfg`` creates ``trial_dir`` and
    writes the config YAML to ``save_cfg_path``.

    Returns:
        The structured config produced by ``parse_structure(ExpCfg, ...)``.
    """
    raw_cfg = OmegaConf.load(cfg_path)
    # Resolve once, up front. NOTE(review): this presumably also pins
    # time-dependent values such as ${get_trial_dir:...}, so every field that
    # references them sees the same timestamp. Confirm against the YAMLs.
    OmegaConf.resolve(raw_cfg)
    return parse_structure(ExpCfg, raw_cfg)

def parse_structure(template, cfg):
    """Instantiate ``template`` from ``cfg`` and wrap it as a structured config.

    ``cfg``'s top-level keys are passed as keyword arguments to ``template``.
    The resulting instance is converted with ``OmegaConf.structured``.
    """
    instance = template(**cfg)
    return OmegaConf.structured(instance)