Spaces:
Running on Zero
Running on Zero
File size: 1,973 Bytes
03022ee | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import os
import yaml
class Config:
    """Plain attribute container built from a configuration mapping.

    Every key of ``conf_dict`` becomes an instance attribute holding the
    corresponding value; nested mappings are stored as-is (not wrapped).
    """

    def __init__(self, conf_dict):
        attrs = vars(self)  # the instance's own __dict__
        for name, val in conf_dict.items():
            attrs[name] = val
def convert_to_yaml(overrides):
    """Convert command-line style overrides into a YAML mapping string.

    Each token starting with ``--`` begins a new ``key:`` line. A
    ``--key=value`` token is split on its *first* ``=`` only, so values
    that themselves contain ``=`` (URLs, expressions) are preserved intact.
    Every other token is appended, space-separated, to the value of the
    preceding key.

    Args:
        overrides: Iterable of argument strings, e.g.
            ``["--lr=0.1", "--exp_dir", "out"]``.

    Returns:
        A YAML document string such as ``"lr: 0.1\\nexp_dir: out"``, or an
        empty string when ``overrides`` is empty.

    >>> convert_to_yaml(["--lr=0.1", "--exp_dir", "out"])
    'lr: 0.1\\nexp_dir: out'
    """
    parts = []
    for arg in overrides:
        if arg.startswith("--"):
            # partition() splits on the first "=" only; the old
            # join-then-split approach broke values containing "=".
            key, sep, value = arg[len("--"):].partition("=")
            parts.append("\n" + key + ":")
            if sep:
                parts.append(" " + value)
        else:
            parts.append(" " + arg)
    return "".join(parts).strip()
def yaml_config_loader(conf_file, overrides=None):
    """Load a YAML config file, apply overrides and resolve ``<name>`` references.

    Args:
        conf_file: Path to the YAML file to read.
        overrides: Optional YAML text (e.g. produced by ``convert_to_yaml``).
            Its top-level mapping is merged over the file's top-level keys
            with ``dict.update``. This is a shallow merge: an overridden
            nested mapping is replaced wholesale. Empty override text is a
            no-op.

    Returns:
        The config as a dict (an empty dict for an empty file). Any string
        value of the exact form ``"<name>"``, at any depth inside dicts and
        lists, is replaced by the top-level scalar (int/float/str/bool)
        stored under ``name`` after overrides are applied. Keys starting
        with ``_`` are not offered as substitutions, and unknown names are
        left as the literal string. Substitution is single-pass: a variable
        whose own value is ``"<other>"`` is not expanded further.
    """
    # YAML streams are Unicode (UTF-8 by default); don't depend on the locale.
    with open(conf_file, 'r', encoding='utf-8') as f:
        # NOTE(review): FullLoader accepts some Python-specific tags; fine for
        # trusted local configs, consider SafeLoader if files come from elsewhere.
        config = yaml.load(f, Loader=yaml.FullLoader)
    # An empty YAML document loads as None; treat it as an empty mapping.
    if config is None:
        config = {}
    if overrides is not None:
        override_dict = yaml.load(overrides, Loader=yaml.FullLoader)
        # Empty override text (e.g. from an empty argument list) loads as
        # None, and dict.update(None) would raise TypeError.
        if override_dict is not None:
            config.update(override_dict)

    # Substitution table: top-level, non-private scalar entries only.
    variables = {
        k: v
        for k, v in config.items()
        if isinstance(k, str)
        and not k.startswith('_')
        and isinstance(v, (int, float, str, bool))
    }

    def resolve(x):
        """Recursively replace ``"<name>"`` strings with ``variables[name]``."""
        if isinstance(x, dict):
            return {k: resolve(v) for k, v in x.items()}
        elif isinstance(x, list):
            return [resolve(item) for item in x]
        elif isinstance(x, str) and x.startswith('<') and x.endswith('>'):
            key = x[1:-1]
            return variables.get(key, x)
        else:
            return x

    return resolve(config)
def build_config(config_file, overrides=None, copy=False):
    """Build a ``Config`` from a YAML file plus command-line overrides.

    Args:
        config_file: Path (str or path-like) to a ``.yaml`` or ``.yml`` file.
        overrides: Optional list of ``--key=value`` / ``--key value``
            strings. They are converted with ``convert_to_yaml`` and merged
            over the file's top-level keys.
        copy: When true and the resolved config has an ``exp_dir`` entry,
            create that directory and write the resolved config to
            ``<exp_dir>/config.yaml``.

    Returns:
        A ``Config`` whose attributes are the resolved top-level keys.

    Raises:
        ValueError: If ``config_file`` does not have a YAML extension.
    """
    # os.fspath lets pathlib.Path work too; ".yml" is the other common YAML
    # extension.
    if not os.fspath(config_file).endswith((".yaml", ".yml")):
        raise ValueError("Unknown config file format")
    if overrides is not None:
        overrides = convert_to_yaml(overrides)
    conf_dict = yaml_config_loader(config_file, overrides)
    if copy and 'exp_dir' in conf_dict:
        os.makedirs(conf_dict['exp_dir'], exist_ok=True)
        saved_path = os.path.join(conf_dict['exp_dir'], 'config.yaml')
        with open(saved_path, 'w', encoding='utf-8') as f:
            # Persist the resolved config (after overrides and substitution).
            yaml.dump(conf_dict, f)
    return Config(conf_dict)
|