| import yaml |
|
|
|
|
class Params:
    """Container for model hyper-parameters and dataset file paths.

    Every hyper-parameter receives a default in ``__init__``; ``load``
    overrides those defaults from a YAML config file and then resolves the
    dataset paths via ``init_data_paths``.

    NOTE(review): ``framework``, ``language`` and ``data_directory`` are NOT
    defaulted here — they are expected to come from the loaded config before
    ``init_data_paths`` runs, otherwise it raises ``AttributeError``.
    """

    def __init__(self):
        # Graph encoding mode and training schedule.
        self.graph_mode = "sequential"  # sequential | node-centric | labeled-edge
        self.accumulation_steps = 1
        self.activation = "relu"
        self.predict_intensity = False
        self.batch_size = 32
        self.beta_2 = 0.98
        self.blank_weight = 1.0
        self.char_embedding = True
        self.char_embedding_size = 128
        self.decoder_delay_steps = 0
        self.decoder_learning_rate = 6e-4
        self.decoder_weight_decay = 1.2e-6
        # Dropout rates for the individual model components.
        self.dropout_anchor = 0.5
        self.dropout_edge_label = 0.5
        self.dropout_edge_presence = 0.5
        self.dropout_label = 0.5
        self.dropout_transformer = 0.5
        self.dropout_transformer_attention = 0.1
        self.dropout_word = 0.1
        # Pretrained encoder settings and optimizer schedule.
        self.encoder = "xlm-roberta-base"
        self.encoder_delay_steps = 2000
        self.encoder_freeze_embedding = True
        self.encoder_learning_rate = 6e-5
        self.encoder_weight_decay = 1e-2
        self.lr_decay_multiplier = 100
        self.epochs = 100
        self.focal = True
        self.freeze_bert = False
        self.group_ops = False
        # Decoder transformer dimensions.
        self.hidden_size_ff = 4 * 768
        self.hidden_size_anchor = 128
        self.hidden_size_edge_label = 256
        self.hidden_size_edge_presence = 512
        self.layerwise_lr_decay = 1.0
        self.n_attention_heads = 8
        self.n_layers = 3
        self.query_length = 4
        self.pre_norm = True
        self.warmup_steps = 6000

    def init_data_paths(self):
        """Resolve train/dev/test file paths from the current settings.

        Requires ``graph_mode``, ``framework``, ``language`` and
        ``data_directory`` to be set (the last three typically come from the
        YAML config loaded via ``load``).

        Returns:
            self, for fluent chaining.

        Raises:
            KeyError: if ``graph_mode`` or the ``(framework, language)``
                combination is not one of the supported datasets.
        """
        graph_directory = {
            "sequential": "node_centric_mrp",
            "node-centric": "node_centric_mrp",
            "labeled-edge": "labeled_edge_mrp",
        }[self.graph_mode]
        dataset_directory = {
            ("darmstadt", "en"): "darmstadt_unis",
            ("mpqa", "en"): "mpqa",
            ("multibooked", "ca"): "multibooked_ca",
            ("multibooked", "eu"): "multibooked_eu",
            ("norec", "no"): "norec",
            ("opener", "en"): "opener_en",
            ("opener", "es"): "opener_es",
        }[(self.framework, self.language)]

        self.training_data = f"{self.data_directory}/{graph_directory}/{dataset_directory}/train.mrp"
        self.validation_data = f"{self.data_directory}/{graph_directory}/{dataset_directory}/dev.mrp"
        self.test_data = f"{self.data_directory}/{graph_directory}/{dataset_directory}/test.mrp"

        self.raw_training_data = f"{self.data_directory}/raw/{dataset_directory}/train.json"
        self.raw_validation_data = f"{self.data_directory}/raw/{dataset_directory}/dev.json"

        return self

    def load_state_dict(self, d):
        """Set every key in *d* as an attribute on this object; return self."""
        for k, v in d.items():
            setattr(self, k, v)
        return self

    def state_dict(self):
        """Return a dict of all non-callable, non-dunder instance attributes.

        Fix: iterate ``self.__dict__`` directly instead of filtering
        ``dir(self)`` — the old version looked every surviving name up in
        ``self.__dict__`` anyway and would raise ``KeyError`` for any
        non-callable class-level attribute; this keeps the same filtering
        semantics with a single robust pass.
        """
        return {
            k: v
            for k, v in self.__dict__.items()
            if not callable(v) and not k.startswith("__")
        }

    def load(self, args):
        """Load hyper-parameters from ``args.config`` (a YAML file) and
        resolve the dataset paths.

        Returns self for consistency with the other fluent methods
        (previously returned None; backward-compatible).
        """
        with open(args.config, "r", encoding="utf-8") as f:
            params = yaml.safe_load(f)
        self.load_state_dict(params)
        self.init_data_paths()
        return self

    def save(self, json_path):
        """Serialize the current state to *json_path*.

        NOTE(review): despite the parameter name, the output format is YAML,
        not JSON — the name is kept for interface compatibility.
        """
        with open(json_path, "w", encoding="utf-8") as f:
            yaml.dump(self.state_dict(), f)
|
|