File size: 5,306 Bytes
f43af3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from easy_tpp.config_factory.config import Config
class DataSpecConfig(Config):
    """Data specs of a dataset: number of event types, padding and truncation settings."""

    def __init__(self, **kwargs):
        """Initialize the data-spec config from keyword arguments.

        Args:
            **kwargs: recognised keys are ``num_event_types``, ``pad_token_id``,
                ``padding_side``, ``truncation_side``, ``padding_strategy``,
                ``truncation_strategy``, ``max_len`` and ``model_input_names``.
                Any recognised key that is absent defaults to None; other keys
                are ignored.

        Raises:
            ValueError: if ``padding_side`` or ``truncation_side`` is given and
                is neither 'right' nor 'left'.
        """
        self.num_event_types = kwargs.get('num_event_types')
        self.pad_token_id = kwargs.get('pad_token_id')
        self.padding_side = kwargs.get('padding_side')
        self.truncation_side = kwargs.get('truncation_side')
        self.padding_strategy = kwargs.get('padding_strategy')
        self.max_len = kwargs.get('max_len')
        self.truncation_strategy = kwargs.get('truncation_strategy')
        # Presumably one extra type id reserved for padding. It is only derivable
        # when num_event_types is known; leaving it None lets a bare
        # DataSpecConfig() be constructed instead of raising on `None + 1`.
        self.num_event_types_pad = (
            self.num_event_types + 1 if self.num_event_types is not None else None
        )
        self.model_input_names = kwargs.get('model_input_names')

        if self.padding_side is not None and self.padding_side not in ["right", "left"]:
            raise ValueError(
                f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
            )

        if self.truncation_side is not None and self.truncation_side not in ["right", "left"]:
            raise ValueError(
                f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}"
            )

    def get_yaml_config(self):
        """Return the config in dict (yaml compatible) format.

        The keys are exactly those read by ``__init__``, so
        ``DataSpecConfig(**spec.get_yaml_config())`` reproduces the settings.
        ``num_event_types_pad`` is not included because it is derived.

        Returns:
            dict: config of the data specs in dict format.
        """
        return {
            'num_event_types': self.num_event_types,
            'pad_token_id': self.pad_token_id,
            'padding_side': self.padding_side,
            'truncation_side': self.truncation_side,
            'padding_strategy': self.padding_strategy,
            'truncation_strategy': self.truncation_strategy,
            'max_len': self.max_len,
            'model_input_names': self.model_input_names,
        }

    @staticmethod
    def parse_from_yaml_config(yaml_config):
        """Parse from the yaml to generate the config object.

        Args:
            yaml_config (dict or None): configs from yaml file. None (for
                example, a missing yaml section) is treated as an empty mapping.

        Returns:
            DataSpecConfig: Config class for data specs.
        """
        return DataSpecConfig(**(yaml_config or {}))

    def copy(self):
        """Copy the config.

        The copy is built from ``get_yaml_config()`` so it carries every field
        that ``__init__`` reads, including ``pad_token_id``. Mutable values such
        as ``model_input_names`` are shared with the original, not deep-copied.

        Returns:
            DataSpecConfig: a copy of current config.
        """
        return DataSpecConfig(**self.get_yaml_config())
@Config.register('data_config')
class DataConfig(Config):
    """Locations, file format and data specs of the train/valid/test splits."""

    def __init__(self, train_dir, valid_dir, test_dir, data_format=None, specs=None):
        """Initialize the DataConfig object.

        Args:
            train_dir (str): dir of the train set.
            valid_dir (str): dir of the valid set.
            test_dir (str): dir of the test set.
            data_format (str, optional): format of the data files. When None, it
                is taken from the text after the last '.' in ``train_dir``, or
                left as None if ``train_dir`` is also None. Defaults to None.
            specs (DataSpecConfig or dict, optional): specs of the dataset. A dict
                is converted with ``DataSpecConfig(**specs)``. A falsy value
                yields a default ``DataSpecConfig()``. Defaults to None.
        """
        self.train_dir = train_dir
        self.valid_dir = valid_dir
        self.test_dir = test_dir
        if isinstance(specs, dict):
            # get_yaml_config() calls data_specs.get_yaml_config(), so a raw
            # mapping must be turned into a DataSpecConfig first.
            specs = DataSpecConfig(**specs)
        self.data_specs = specs or DataSpecConfig()
        if data_format is None and train_dir is not None:
            # Infer from the file extension, e.g. 'train.json' -> 'json'.
            data_format = train_dir.split('.')[-1]
        self.data_format = data_format

    def get_yaml_config(self):
        """Return the config in dict (yaml compatible) format.

        Returns:
            dict: config of the data in dict format.
        """
        return {
            'train_dir': self.train_dir,
            'valid_dir': self.valid_dir,
            'test_dir': self.test_dir,
            'data_format': self.data_format,
            'data_specs': self.data_specs.get_yaml_config(),
        }

    @staticmethod
    def parse_from_yaml_config(yaml_config):
        """Parse from the yaml to generate the config object.

        Args:
            yaml_config (dict): configs from yaml file.

        Returns:
            EasyTPP.DataConfig: Config class for data.
        """
        return DataConfig(
            train_dir=yaml_config.get('train_dir'),
            valid_dir=yaml_config.get('valid_dir'),
            test_dir=yaml_config.get('test_dir'),
            data_format=yaml_config.get('data_format'),
            # A missing 'data_specs' section is parsed as an empty mapping, not None.
            specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs') or {})
        )

    def copy(self):
        """Copy the config.

        ``data_format`` is passed explicitly so the copy keeps it as-is. The data
        specs are deep-copied so mutating the copy does not affect this config.

        Returns:
            EasyTPP.DataConfig: a copy of current config.
        """
        import copy as copy_module

        return DataConfig(train_dir=self.train_dir,
                          valid_dir=self.valid_dir,
                          test_dir=self.test_dir,
                          data_format=self.data_format,
                          specs=copy_module.deepcopy(self.data_specs))

    def get_data_dir(self, split):
        """Get the dir of the source raw data.

        Args:
            split (str): dataset split notation, compared case-insensitively:
                'train', 'dev' or 'valid', 'test'. Any other value falls back to
                the test dir.

        Returns:
            str: dir of the source raw data file.
        """
        split = split.lower()
        if split == 'train':
            return self.train_dir
        elif split in ['dev', 'valid']:
            return self.valid_dir
        else:
            return self.test_dir
|