Abigail99216's picture
Upload folder using huggingface_hub
f43af3c verified
from abc import abstractmethod
from typing import Any
from omegaconf import OmegaConf
from easy_tpp.utils import save_yaml_config, Registrable, logger
class Config(Registrable):
def save_to_yaml_file(self, config_dir):
"""Save the config into the yaml file 'config_dir'.
Args:
config_dir (str): Target filename.
Returns:
"""
yaml_config = self.get_yaml_config()
OmegaConf.save(yaml_config, config_dir)
@staticmethod
def build_from_yaml_file(yaml_dir, **kwargs):
"""Load yaml config file from disk.
Args:
yaml_dir (str): Path of the yaml config file.
Returns:
EasyTPP.Config: Config object corresponding to cls.
"""
config = OmegaConf.load(yaml_dir)
pipeline_config = config.get('pipeline_config_id')
config_cls = Config.by_name(pipeline_config.lower())
logger.critical(f'Load pipeline config class {config_cls.__name__}')
return config_cls.parse_from_yaml_config(config, **kwargs)
@abstractmethod
def get_yaml_config(self):
"""Get the yaml format config from self.
Returns:
"""
pass
@staticmethod
@abstractmethod
def parse_from_yaml_config(yaml_config):
"""Parse from the yaml to generate the config object.
Args:
yaml_config (dict): configs from yaml file.
Returns:
EasyTPP.Config: Config class for data.
"""
pass
@abstractmethod
def copy(self):
"""Get a same and freely modifiable copy of self.
Returns:
"""
pass
def __str__(self):
"""Str representation of the config.
Returns:
str: str representation of the dict format of the config.
"""
return str(self.get_yaml_config())
def update(self, config):
"""Update the config.
Args:
config (dict): config dict.
Returns:
EasyTPP.Config: Config class for data.
"""
logger.critical(f'Update config class {self.__class__.__name__}')
return self.parse_from_yaml_config(config)
def pop(self, key: str, default_var: Any):
"""pop out the key-value item from the config.
Args:
key (str): key name.
default_var (Any): default value to pop.
Returns:
Any: value to pop.
"""
return vars(self).pop(key) or default_var
def get(self, key: str, default_var: Any):
"""Retrieve the key-value item from the config.
Args:
key (str): key name.
default_var (Any): default value to pop.
Returns:
Any: value to get.
"""
return vars(self)[key] or default_var
def set(self, key: str, var_to_set: Any):
"""Set the key-value item from the config.
Args:
key (str): key name.
var_to_set (Any): default value to pop.
Returns:
Any: value to get.
"""
vars(self)[key] = var_to_set