Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- stanza/stanza/models/common/__init__.py +0 -0
- stanza/stanza/models/common/bert_embedding.py +509 -0
- stanza/stanza/models/common/biaffine.py +80 -0
- stanza/stanza/models/common/build_short_name_to_treebank.py +78 -0
- stanza/stanza/models/common/char_model.py +362 -0
- stanza/stanza/models/common/chuliu_edmonds.py +281 -0
- stanza/stanza/models/common/constant.py +550 -0
- stanza/stanza/models/common/count_ner_coverage.py +38 -0
- stanza/stanza/models/common/count_pretrain_coverage.py +41 -0
- stanza/stanza/models/common/crf.py +149 -0
- stanza/stanza/models/common/data.py +155 -0
- stanza/stanza/models/common/doc.py +1741 -0
- stanza/stanza/models/common/dropout.py +75 -0
- stanza/stanza/models/common/exceptions.py +15 -0
- stanza/stanza/models/common/foundation_cache.py +148 -0
- stanza/stanza/models/common/hlstm.py +124 -0
- stanza/stanza/models/common/large_margin_loss.py +68 -0
- stanza/stanza/models/common/loss.py +134 -0
- stanza/stanza/models/common/maxout_linear.py +42 -0
- stanza/stanza/models/common/packed_lstm.py +105 -0
- stanza/stanza/models/common/peft_config.py +119 -0
- stanza/stanza/models/common/seq2seq_constant.py +17 -0
- stanza/stanza/models/common/seq2seq_model.py +364 -0
- stanza/stanza/models/common/seq2seq_utils.py +121 -0
- stanza/stanza/models/common/short_name_to_treebank.py +619 -0
- stanza/stanza/models/common/trainer.py +20 -0
- stanza/stanza/models/common/utils.py +816 -0
- stanza/stanza/models/common/vocab.py +298 -0
- stanza/stanza/models/constituency/base_model.py +532 -0
- stanza/stanza/models/constituency/base_trainer.py +153 -0
- stanza/stanza/models/constituency/ensemble.py +486 -0
- stanza/stanza/models/constituency/in_order_compound_oracle.py +327 -0
- stanza/stanza/models/constituency/in_order_oracle.py +1029 -0
- stanza/stanza/models/constituency/lstm_model.py +1178 -0
- stanza/stanza/models/constituency/parse_tree.py +591 -0
- stanza/stanza/models/constituency/positional_encoding.py +89 -0
- stanza/stanza/models/constituency/retagging.py +130 -0
- stanza/stanza/models/constituency/state.py +144 -0
- stanza/stanza/models/constituency/top_down_oracle.py +757 -0
- stanza/stanza/models/constituency/trainer.py +306 -0
- stanza/stanza/models/constituency/transformer_tree_stack.py +198 -0
- stanza/stanza/models/constituency/transition_sequence.py +186 -0
- stanza/stanza/models/constituency/tree_embedding.py +135 -0
- stanza/stanza/models/coref/config.py +66 -0
- stanza/stanza/models/coref/coref_config.toml +285 -0
- stanza/stanza/models/coref/dataset.py +61 -0
- stanza/stanza/models/coref/pairwise_encoder.py +94 -0
- stanza/stanza/models/coref/rough_scorer.py +61 -0
- stanza/stanza/models/coref/utils.py +35 -0
- stanza/stanza/models/depparse/model.py +265 -0
stanza/stanza/models/common/__init__.py
ADDED
|
File without changes
|
stanza/stanza/models/common/bert_embedding.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger('stanza')
|
| 10 |
+
|
| 11 |
+
# Per-model keyword arguments passed through to AutoTokenizer.from_pretrained.
# The phobert tokenizers specifically need the fast implementation.
BERT_ARGS = {
    "vinai/phobert-base": { "use_fast": True },
    "vinai/phobert-large": { "use_fast": True },
}
|
| 15 |
+
|
| 16 |
+
class TextTooLongError(ValueError):
    """
    A text was too long for the underlying model (possibly BERT)
    """
    def __init__(self, length, max_len, line_num, text):
        # note: message format kept exactly as shipped (callers may match on it)
        message = "Found a text of length %d (possibly after tokenizing). Maximum handled length is %d Error occurred at line %d" % (length, max_len, line_num)
        super().__init__(message)
        # index of the offending input, kept so callers can report it
        self.line_num = line_num
        # the raw text which was too long
        self.text = text
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def update_max_length(model_name, tokenizer):
    """Patch tokenizer.model_max_length for checkpoints known to report a bad value.

    Every model listed here is a 512-position model, but its HF config does
    not carry a usable model_max_length, so we set it explicitly.
    """
    known_512_models = {
        'hf-internal-testing/tiny-bert',
        'google/muril-base-cased',
        'google/muril-large-cased',
        'airesearch/wangchanberta-base-att-spm-uncased',
        'camembert/camembert-large',
        'hfl/chinese-electra-180g-large-discriminator',
        'NYTK/electra-small-discriminator-hungarian',
    }
    if model_name in known_512_models:
        tokenizer.model_max_length = 512
|
| 35 |
+
|
| 36 |
+
def load_tokenizer(model_name, tokenizer_kwargs=None, local_files_only=False):
    """Build an AutoTokenizer for model_name, or return None if no name is given.

    Per-model defaults come from BERT_ARGS, then tokenizer_kwargs overrides,
    then local_files_only is forced; finally known-bad max lengths are fixed
    via update_max_length.

    Raises ImportError with an install hint if transformers is missing.
    """
    if not model_name:
        return None
    # note that use_fast is the default
    try:
        from transformers import AutoTokenizer
    except ImportError:
        raise ImportError("Please install transformers library for BERT support! Try `pip install transformers`.")
    # BUGFIX: copy the defaults - previously the shared dict stored in
    # BERT_ARGS was mutated in place for phobert models, so settings such as
    # local_files_only leaked into every later call for the same model
    bert_args = dict(BERT_ARGS.get(model_name, dict()))
    if not model_name.startswith("vinai/phobert"):
        # needed by roberta-style tokenizers when words are pre-split
        bert_args["add_prefix_space"] = True
    if tokenizer_kwargs:
        bert_args.update(tokenizer_kwargs)
    bert_args['local_files_only'] = local_files_only
    bert_tokenizer = AutoTokenizer.from_pretrained(model_name, **bert_args)
    update_max_length(model_name, bert_tokenizer)
    return bert_tokenizer
|
| 53 |
+
|
| 54 |
+
def load_bert(model_name, tokenizer_kwargs=None, local_files_only=False):
    """Load a transformer and its tokenizer; returns (None, None) for no name.

    model_name is a HF hub id such as "vinai/phobert-base".
    Raises ImportError with an install hint if transformers is missing.
    """
    if not model_name:
        return None, None
    try:
        from transformers import AutoModel
    except ImportError:
        raise ImportError("Please install transformers library for BERT support! Try `pip install transformers`.")
    model = AutoModel.from_pretrained(model_name, local_files_only=local_files_only)
    tokenizer = load_tokenizer(model_name, tokenizer_kwargs=tokenizer_kwargs, local_files_only=local_files_only)
    return model, tokenizer
|
| 65 |
+
|
| 66 |
+
def tokenize_manual(model_name, sent, tokenizer):
    """
    Tokenize a sentence manually, used for checking long sentences and PhoBERT.

    Returns (pieces, ids) where pieces are the tokenizer's subword strings
    and ids are their vocab ids wrapped in bos/eos.
    """
    if model_name.startswith("vinai/phobert"):
        # PhoBERT expects _ between syllables, so replace \xa0 (or any space) with _
        words = [word.replace("\xa0", "_").replace(" ", "_") for word in sent]
    else:
        words = [word.replace("\xa0", " ") for word in sent]

    # join into one string and run the transformer tokenizer over it
    pieces = tokenizer.tokenize(' '.join(words))

    # map pieces to vocab ids, then wrap with the sentence delimiter tokens
    ids = tokenizer.convert_tokens_to_ids(pieces)
    wrapped_ids = [tokenizer.bos_token_id] + ids + [tokenizer.eos_token_id]

    return pieces, wrapped_ids
|
| 86 |
+
|
| 87 |
+
def filter_data(model_name, data, tokenizer = None, log_level=logging.DEBUG):
    """
    Filter out the (NER, POS) data that is too long for BERT model.

    data items are sentences: lists of either strings or (word, tag) tuples.
    Returns the sentences whose tokenization fits in the model.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer(model_name)
    kept = []
    # keep only sentences short enough for the transformer
    for sent in data:
        words = [word if isinstance(word, str) else word[0] for word in sent]
        _, tokenized_sent = tokenize_manual(model_name, words, tokenizer)
        if len(tokenized_sent) <= tokenizer.model_max_length - 2:
            kept.append(sent)

    logger.log(log_level, "Eliminated %d of %d datapoints because their length is over maximum size of BERT model.", (len(data)-len(kept)), len(data))

    return kept
|
| 107 |
+
|
| 108 |
+
def needs_length_filter(model_name):
    """
    Whether data must be pre-filtered by length before using this transformer.

    TODO: we were lazy and didn't implement any form of length fudging for models other than bert/roberta/electra
    """
    return ('bart' in model_name
            or 'xlnet' in model_name
            or model_name.startswith("vinai/phobert"))
|
| 117 |
+
|
| 118 |
+
def cloned_feature(feature, num_layers, detach=True):
    """
    Combine hidden-state layers into a single tensor, optionally clone & detach.

    feature is the tuple of per-layer hidden states (features.hidden_states
    for most models; bartpho uses features.decoder_hidden_states instead -
    older transformers' xlnet did not work with feature[2]).

    With num_layers given, the last num_layers layers are stacked on a new
    trailing axis.  With num_layers=None, layers -4..-2 are summed and
    divided by 4 - averaging 3 of the last 4 layers worked well for non-VI
    languages (note: divisor 4 over 3 layers is kept exactly as shipped;
    presumably an intentional scaling - TODO confirm).
    """
    if num_layers is None:
        combined = torch.stack(feature[-4:-1], axis=3).sum(axis=3) / 4
    else:
        combined = torch.stack(feature[-num_layers:], axis=3)
    if detach:
        return combined.clone().detach()
    return combined
|
| 136 |
+
|
| 137 |
+
def extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Handles vi-bart. May need testing before using on other bart

    https://github.com/VinAIResearch/BARTpho

    data: list of list of string (the text tokens)
    Returns one tensor of word representations per sentence, taken from the
    decoder hidden states.
    """
    processed = [] # final product, returns the list of list of word representation

    sentences = [" ".join([word.replace(" ", "_") for word in sentence]) for sentence in data]
    tokenized = tokenizer(sentences, return_tensors='pt', padding=True, return_attention_mask=True)
    input_ids = tokenized['input_ids'].to(device)
    attention_mask = tokenized['attention_mask'].to(device)

    for i in range(int(math.ceil(len(sentences)/128))):
        start_sentence = i * 128
        end_sentence = min(start_sentence + 128, len(sentences))
        # BUGFIX: slice into batch-local tensors.  Previously input_ids and
        # attention_mask were reassigned to the slice, so every batch after
        # the first sliced an already-sliced tensor and got the wrong rows.
        batch_ids = input_ids[start_sentence:end_sentence]
        batch_mask = attention_mask[start_sentence:end_sentence]

        if detach:
            with torch.no_grad():
                features = model(batch_ids, attention_mask=batch_mask, output_hidden_states=True)
                features = cloned_feature(features.decoder_hidden_states, num_layers, detach)
        else:
            features = model(batch_ids, attention_mask=batch_mask, output_hidden_states=True)
            features = cloned_feature(features.decoder_hidden_states, num_layers, detach)

        # BUGFIX: zip against this batch's sentences, not the whole dataset -
        # the original zip(features, data) paired batch i's features with the
        # first len(features) sentences every time through the loop
        for feature, sentence in zip(features, data[start_sentence:end_sentence]):
            # +2 for the endpoints
            feature = feature[:len(sentence)+2]
            if not keep_endpoints:
                feature = feature[1:-1]
            processed.append(feature)

    return processed
|
| 172 |
+
|
| 173 |
+
def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Extract transformer embeddings using a method specifically for phobert

    Since phobert doesn't have the is_split_into_words / tokenized.word_ids(batch_index=0)
    capability, we instead look for @@ to denote a continued token.
    data: list of list of string (the text tokens)
    Returns a list of tensors, one per sentence.
    """
    list_tokenized = []   # the subword pieces per sentence, needed to locate word boundaries
    tokenized_sents = []  # one id tensor per sentence, with bos/eos attached
    for idx, sent in enumerate(data):
        tokenized, tokenized_sent = tokenize_manual(model_name, sent, tokenizer)
        list_tokenized.append(tokenized)

        if len(tokenized_sent) > tokenizer.model_max_length:
            logger.error("Invalid size, max size: %d, got %d %s", tokenizer.model_max_length, len(tokenized_sent), data[idx])
            raise TextTooLongError(len(tokenized_sent), tokenizer.model_max_length, idx, " ".join(data[idx]))

        tokenized_sents.append(torch.tensor(tokenized_sent).detach())

    size = len(tokenized_sents)

    # pad so everything can go through the transformer as one batch
    tokenized_sents_padded = torch.nn.utils.rnn.pad_sequence(tokenized_sents, batch_first=True, padding_value=tokenizer.pad_token_id)

    features = []

    # Feed into PhoBERT 128 at a time in a batch fashion. In testing, the loop was
    # run only 1 time as the batch size for the outer model was less than that
    # (30 for conparser, for example)
    for i in range(int(math.ceil(size/128))):
        padded_input = tokenized_sents_padded[128*i:128*i+128]
        start_sentence = 128 * i
        end_sentence = start_sentence + padded_input.shape[0]
        attention_mask = torch.zeros(end_sentence - start_sentence, padded_input.shape[1], device=device)
        for sent_idx, sent_ids in enumerate(tokenized_sents[start_sentence:end_sentence]):
            attention_mask[sent_idx, :len(sent_ids)] = 1
        if detach:
            with torch.no_grad():
                # TODO: is the clone().detach() necessary?
                feature = model(padded_input.clone().detach().to(device), attention_mask=attention_mask, output_hidden_states=True)
                features += cloned_feature(feature.hidden_states, num_layers, detach)
        else:
            feature = model(padded_input.to(device), attention_mask=attention_mask, output_hidden_states=True)
            features += cloned_feature(feature.hidden_states, num_layers, detach)

    assert len(features) == size

    # original note: only take the vector of the last word piece of a word
    # (could instead use the first piece, or average).  The condition keeps
    # piece positions whose predecessor does not end with @@, since a
    # trailing @@ marks a word continued in the next piece.
    # the +1 compensates for the start token at the start of a sentence
    offsets = []
    for pieces in list_tokenized:
        word_positions = [idx2 + 1 for idx2, _ in enumerate(pieces)
                          if idx2 == 0 or not pieces[idx2-1].endswith("@@")]
        offsets.append(word_positions)
    if keep_endpoints:
        # [0] and [-1] grab the start and end representations as well
        offsets = [[0] + off + [-1] for off in offsets]
    processed = [feature[offset] for feature, offset in zip(features, offsets)]

    # This is a list of tensors
    # Each tensor holds the representation of a sentence extracted from phobert
    return processed
|
| 245 |
+
|
| 246 |
+
# Tokenizers known to vaporize certain characters into empty words;
# sentences using them go through fix_blank_tokens
BAD_TOKENIZERS = (
    'bert-base-german-cased',
    # the dbmdz tokenizers turn one or more types of characters into empty words
    # for example, from PoSTWITA:
    #   ewww — in viaggio Roma
    # the character which may not be rendering properly is 0xFE4FA
    # https://github.com/dbmdz/berts/issues/48
    'dbmdz/bert-base-german-cased',
    'dbmdz/bert-base-italian-xxl-cased',
    'dbmdz/bert-base-italian-cased',
    'dbmdz/electra-base-italian-xxl-cased-discriminator',
    # each of these (perhaps using similar tokenizers?)
    # does not digest the script-flip-mark \u200f
    'avichr/heBERT',
    'onlplab/alephbert-base',
    'imvladikon/alephbertgimmel-base-512',
    # these indonesian models fail on a sentence in the Indonesian GSD dataset:
    #   'Tak', 'dapat', 'disangkal', 'jika', '\u200e', 'kemenangan', ...
    # weirdly some other indonesian models (even by the same group) don't have that problem
    'cahya/bert-base-indonesian-1.5G',
    'indolem/indobert-base-uncased',
    'google/muril-base-cased',
    'l3cube-pune/marathi-roberta',
)
|
| 268 |
+
|
| 269 |
+
def fix_blank_tokens(tokenizer, data):
    """Patch bert tokenizers with missing characters

    Some tokenizers (eg the German ones noted in BAD_TOKENIZERS) tokenize
    soft hyphens or other unknown characters into nothing.  A word which
    tokenizes to only the two delimiter tokens (len == 2) would be vaporized
    entirely and lose its embedding, so we substitute a regular "-".

    Recently even the Bert / Electra tokenizers were seen doing this for
    "words" that are a single special character, so this is always run.
    """
    patched = []
    for sentence in data:
        tokenized = tokenizer(sentence, is_split_into_words=False).input_ids
        patched.append([word if len(token) > 2 else "-"
                        for word, token in zip(sentence, tokenized)])
    return patched
|
| 291 |
+
|
| 292 |
+
def extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """Extract word embeddings from an xlnet model.

    data: list of list of string (the text tokens)
    Returns one tensor of word representations per sentence.
    """
    # using attention masks makes contextual embeddings much more useful for downstream tasks
    tokenized = tokenizer(data, is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=False)
    #tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)

    list_offsets = [[None] * (len(sentence)+2) for sentence in data]
    for idx in range(len(data)):
        offsets = tokenized.word_ids(batch_index=idx)
        list_offsets[idx][0] = 0
        for pos, offset in enumerate(offsets):
            if offset is None:
                break
            # the last token piece of each word wins by overwriting the
            # previous value; positions are shifted by one for the token
            # added at the start of each sentence
            list_offsets[idx][offset+1] = pos + 1
        list_offsets[idx][-1] = list_offsets[idx][-2] + 1
        if any(x is None for x in list_offsets[idx]):
            raise ValueError("OOPS, hit None when preparing to use Bert\ndata[idx]: {}\noffsets: {}\nlist_offsets[idx]: {}".format(data[idx], offsets, list_offsets[idx], tokenized))

        if len(offsets) > tokenizer.model_max_length - 2:
            logger.error("Invalid size, max size: %d, got %d %s", tokenizer.model_max_length, len(offsets), data[idx])
            raise TextTooLongError(len(offsets), tokenizer.model_max_length, idx, " ".join(data[idx]))

    def run_batch(batch_token_ids):
        # TODO: find a suitable representation for attention masks for xlnet
        # xlnet base on WSJ:
        #   sep_token_id at beginning, cls_token_id at end: 0.9441
        #   bos_token_id at beginning, eos_token_id at end: 0.9463
        #   bos_token_id at beginning, sep_token_id at end: 0.9459
        #   bos_token_id at beginning, cls_token_id at end: 0.9457
        #   bos_token_id at beginning, sep/cls at end:      0.9454
        #   xlnet tokenization with words at end, begin token is last pad,
        #     end token is sep, no mask: 0.9463; same with masks: 0.9440
        input_ids = [[tokenizer.bos_token_id] + x[:-2] + [tokenizer.eos_token_id] for x in batch_token_ids]
        max_len = max(len(x) for x in input_ids)
        attention_mask = torch.zeros(len(input_ids), max_len, dtype=torch.long, device=device)
        for row_idx, input_row in enumerate(input_ids):
            attention_mask[row_idx, :len(input_row)] = 1
            if len(input_row) < max_len:
                input_row.extend([tokenizer.pad_token_id] * (max_len - len(input_row)))
        id_tensor = torch.tensor(input_ids, device=device)
        # feature[2] is the same for bert, but it didn't work for
        # older versions of transformers for xlnet
        feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
        return cloned_feature(feature.hidden_states, num_layers, detach)

    features = []
    for i in range(int(math.ceil(len(data)/128))):
        batch_token_ids = tokenized['input_ids'][128*i:128*i+128]
        if detach:
            with torch.no_grad():
                features += run_batch(batch_token_ids)
        else:
            features += run_batch(batch_token_ids)

    if not keep_endpoints:
        # remove the bos and eos positions
        list_offsets = [sent[1:-1] for sent in list_offsets]
    processed = [feature[offsets] for feature, offsets in zip(features, list_offsets)]

    return processed
|
| 361 |
+
|
| 362 |
+
def build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device):
    """
    Extract an embedding from the given transformer for a certain attention mask and tokens range

    If the tokens are longer than the max length the model supports, the
    range is split into overlapping slices and the pieces are concatenated,
    dropping the overlapping prefix of each later slice.  No idea if this is
    actually any good, but at least it returns something instead of
    horribly failing

    TODO: at least two upgrades are very relevant
    1) cut off some overlap at the end as well
    2) use this on the phobert, bart, and xln versions as well
    """
    max_len = tokenizer.model_max_length
    if attention_tensor.shape[1] <= max_len:
        outputs = model(id_tensor, attention_mask=attention_tensor, output_hidden_states=True)
        return cloned_feature(outputs.hidden_states, num_layers, detach)

    slice_len = max(max_len - 20, max_len // 2)
    prefix_len = max_len - slice_len
    if slice_len < 5:
        raise RuntimeError("Really tiny tokenizer!")

    pieces = []
    remaining_attention = attention_tensor
    remaining_ids = id_tensor
    while True:
        outputs = model(remaining_ids[:, :max_len],
                        attention_mask=remaining_attention[:, :max_len],
                        output_hidden_states=True)
        chunk = cloned_feature(outputs.hidden_states, num_layers, detach)
        if pieces:
            # the first prefix_len positions were already produced by the
            # previous (overlapping) slice
            chunk = chunk[:, prefix_len:, :]
        pieces.append(chunk)
        if remaining_attention.shape[1] <= max_len:
            break
        remaining_attention = remaining_attention[:, slice_len:]
        remaining_ids = remaining_ids[:, slice_len:]
    return torch.cat(pieces, axis=1)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def convert_to_position_list(sentence, offsets):
    """
    Convert a transformers-tokenized sentence's offsets to a list of word to position

    offsets is the word_ids() output: for each token piece, the index of the
    word it belongs to, or None for special tokens.  The result has
    len(sentence)+2 entries: 0, then one position per word (the word's last
    token piece), then one past the last used piece.
    """
    # +2 for the beginning and end
    positions = [None] * (len(sentence) + 2)
    for pos, word_idx in enumerate(offsets):
        if word_idx is not None:
            # the last token piece of each word wins by overwriting
            positions[word_idx + 1] = pos
    positions[0] = 0
    for value in positions[-2::-1]:
        # count backwards in case the last position was
        # a word or character that got erased by the tokenizer
        # this loop always finds something - we just set positions[0] to 0
        if value is not None:
            positions[-1] = value + 1
            break
    return positions
|
| 425 |
+
|
| 426 |
+
def extract_base_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach):
    """Extract word embeddings via the generic HF tokenizer/model interface.

    data: list of list of string (the text tokens)
    Returns one tensor of word representations per sentence.
    Raises ValueError if a word still tokenizes to nothing after patching.
    """
    # add_prefix_space=True was already set at tokenizer load time (needed for RoBerTa)
    # using attention masks makes contextual embeddings much more useful for downstream tasks
    tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)
    list_offsets = [convert_to_position_list(data[idx], tokenized.word_ids(batch_index=idx))
                    for idx in range(len(data))]

    if any(any(x is None for x in converted_offsets) for converted_offsets in list_offsets):
        # at least one of the tokens in the data is composed entirely of characters the tokenizer doesn't know about
        # one possible approach would be to retokenize only those sentences
        # however, in that case the attention mask might be of a different length,
        # as would the token ids, and it would be a pain to fix those
        # easiest to just retokenize the whole thing, hopefully a rare event
        data = fix_blank_tokens(tokenizer, data)

        tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)
        list_offsets = [convert_to_position_list(data[idx], tokenized.word_ids(batch_index=idx))
                        for idx in range(len(data))]

        if any(any(x is None for x in converted_offsets) for converted_offsets in list_offsets):
            # BUGFIX: the original message referenced an undefined variable
            # (offsets), so this path raised NameError instead of the
            # intended ValueError; report the first offending sentence
            bad_idx = next(idx for idx, off in enumerate(list_offsets) if any(x is None for x in off))
            raise ValueError("OOPS, hit None when preparing to use Bert\ndata[idx]: {}\nlist_offsets[idx]: {}".format(data[bad_idx], list_offsets[bad_idx]))

    features = []
    for i in range(int(math.ceil(len(data)/128))):
        attention_tensor = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)
        id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)
        if detach:
            with torch.no_grad():
                features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)
        else:
            features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)

    if not keep_endpoints:
        # remove the bos and eos positions
        list_offsets = [sent[1:-1] for sent in list_offsets]
    return [feature[offsets] for feature, offsets in zip(features, list_offsets)]
|
| 477 |
+
|
| 478 |
+
def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers=None, detach=True, peft_name=None):
    """
    Extract transformer embeddings using a generic roberta extraction

    data: list of list of string (the text tokens)
    num_layers: how many layers to return.  If None, the average of -2, -3, -4 is returned
    """
    # switch the requested PEFT adapter on (or make sure adapters are off)
    # before running the transformer
    # TODO: can maybe cache this value for a model and save some time
    # TODO: too bad it isn't thread safe, but then again, who does?
    if peft_name is not None:
        model.enable_adapters()
        model.set_adapter(peft_name)
    elif model._hf_peft_config_loaded:
        model.disable_adapters()

    # some model families need their own specialized extraction path
    if model_name.startswith("vinai/phobert"):
        return extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    if 'bart' in model_name:
        # this should work with "vinai/bartpho-word"
        # not sure this works with any other Bart
        return extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    if isinstance(data, tuple):
        data = list(data)

    if "xlnet" in model_name:
        return extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    return extract_base_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)
|
| 509 |
+
|
stanza/stanza/models/common/biaffine.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class PairwiseBilinear(nn.Module):
    ''' A bilinear module that deals with broadcasting for efficient memory usage.

    Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)
    Output: tensor of size (N x L1 x L2 x O)'''
    def __init__(self, input1_size, input2_size, output_size, bias=True):
        super().__init__()

        self.input1_size = input1_size
        self.input2_size = input2_size
        self.output_size = output_size

        # Initialize the weight to zeros instead of torch.Tensor(...), which
        # returns UNINITIALIZED memory and would make a standalone module
        # nondeterministic.  Callers that want another init (e.g. the scorers
        # in this file, which zero it anyway) can still overwrite .data.
        self.weight = nn.Parameter(torch.zeros(input1_size, input2_size, output_size))
        # NOTE(review): this bias is registered but never applied in forward();
        # it is kept only for state_dict compatibility with existing checkpoints.
        self.bias = nn.Parameter(torch.zeros(output_size)) if bias else 0

    def forward(self, input1, input2):
        input1_size = list(input1.size())
        input2_size = list(input2.size())
        output_size = [input1_size[0], input1_size[1], input2_size[1], self.output_size]

        # ((N x L1) x D1) * (D1 x (D2 x O)) -> (N x L1) x (D2 x O)
        intermediate = torch.mm(input1.view(-1, input1_size[-1]), self.weight.view(-1, self.input2_size * self.output_size))
        # (N x L2 x D2) -> (N x D2 x L2)
        input2 = input2.transpose(1, 2)
        # (N x (L1 x O) x D2) * (N x D2 x L2) -> (N x (L1 x O) x L2)
        output = intermediate.view(input1_size[0], input1_size[1] * self.output_size, input2_size[2]).bmm(input2)
        # (N x (L1 x O) x L2) -> (N x L1 x L2 x O)
        output = output.view(input1_size[0], input1_size[1], self.output_size, input2_size[1]).transpose(2, 3)

        return output
|
| 34 |
+
|
| 35 |
+
class BiaffineScorer(nn.Module):
    """Scores pairs of aligned inputs with a bilinear form.

    A constant-1 column is appended to each input so the bilinear weight
    also captures linear terms in either input plus a global bias.
    """
    def __init__(self, input1_size, input2_size, output_size):
        super().__init__()
        self.W_bilin = nn.Bilinear(input1_size + 1, input2_size + 1, output_size)

        # start from an all-zero transform
        nn.init.zeros_(self.W_bilin.weight)
        nn.init.zeros_(self.W_bilin.bias)

    def forward(self, input1, input2):
        # append the constant-1 "bias feature" to the last dimension
        ones1 = input1.new_ones(*input1.size()[:-1], 1)
        ones2 = input2.new_ones(*input2.size()[:-1], 1)
        augmented1 = torch.cat([input1, ones1], dim=-1)
        augmented2 = torch.cat([input2, ones2], dim=-1)
        return self.W_bilin(augmented1, augmented2)
|
| 47 |
+
|
| 48 |
+
class PairwiseBiaffineScorer(nn.Module):
    """All-pairs biaffine scorer: (N x L1 x D1), (N x L2 x D2) -> (N x L1 x L2 x O).

    Same contract as BiaffineScorer, but scores every pair of positions
    using the broadcasting-friendly PairwiseBilinear.
    """
    def __init__(self, input1_size, input2_size, output_size):
        super().__init__()
        self.W_bilin = PairwiseBilinear(input1_size + 1, input2_size + 1, output_size)

        # start from an all-zero transform
        nn.init.zeros_(self.W_bilin.weight)
        nn.init.zeros_(self.W_bilin.bias)

    def forward(self, input1, input2):
        # append the constant-1 "bias feature" to the last dimension
        ones1 = input1.new_ones(*input1.size()[:-1], 1)
        ones2 = input2.new_ones(*input2.size()[:-1], 1)
        augmented1 = torch.cat([input1, ones1], dim=-1)
        augmented2 = torch.cat([input2, ones2], dim=-1)
        return self.W_bilin(augmented1, augmented2)
|
| 60 |
+
|
| 61 |
+
class DeepBiaffineScorer(nn.Module):
    """Transforms both inputs with a one-layer MLP, then applies a biaffine scorer.

    With pairwise=True the scorer produces all-pairs scores
    (N x L1 x L2 x O); otherwise positions are scored in lockstep.
    """
    def __init__(self, input1_size, input2_size, hidden_size, output_size, hidden_func=F.relu, dropout=0, pairwise=True):
        super().__init__()
        self.W1 = nn.Linear(input1_size, hidden_size)
        self.W2 = nn.Linear(input2_size, hidden_size)
        self.hidden_func = hidden_func
        scorer_cls = PairwiseBiaffineScorer if pairwise else BiaffineScorer
        self.scorer = scorer_cls(hidden_size, hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input1, input2):
        hidden1 = self.dropout(self.hidden_func(self.W1(input1)))
        hidden2 = self.dropout(self.hidden_func(self.W2(input2)))
        return self.scorer(hidden1, hidden2)
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
    # Smoke test.  The default pairwise scorer routes through
    # PairwiseBilinear, which expects batched sequences (N x L x D);
    # 2-d inputs would crash on its transpose(1, 2), so use 3-d tensors.
    x1 = torch.randn(2, 3, 4)
    x2 = torch.randn(2, 5, 5)
    scorer = DeepBiaffineScorer(4, 5, 6, 7)
    print(scorer(x1, x2))
|
stanza/stanza/models/common/build_short_name_to_treebank.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
import os

from stanza.models.common.constant import treebank_to_short_name, UnknownLanguageError, treebank_special_cases
from stanza.utils import default_paths

# Regenerates short_name_to_treebank.py from the UD_* treebank
# directories found under the configured UDBASE path.
paths = default_paths.get_default_paths()
udbase = paths["UDBASE"]

directories = glob.glob(udbase + "/UD_*")
directories.sort()

output_name = os.path.join(os.path.split(__file__)[0], "short_name_to_treebank.py")
ud_names = [os.path.split(ud_path)[1] for ud_path in directories]
short_names = []

# check that all languages are known in the language map
# use that language map to come up with a shortname for these treebanks
for directory, ud_name in zip(directories, ud_names):
    try:
        short_names.append(treebank_to_short_name(ud_name))
    except UnknownLanguageError as e:
        raise UnknownLanguageError("Could not find language short name for dataset %s, path %s" % (ud_name, directory)) from e

# Norwegian and Chinese datasets are ambiguous (Bokmaal vs Nynorsk,
# simplified vs traditional), so each one needs a manual entry in
# treebank_special_cases before we can generate the mapping
for directory, ud_name in zip(directories, ud_names):
    if ud_name.startswith("UD_Norwegian"):
        if ud_name not in treebank_special_cases:
            raise ValueError("Please figure out if dataset %s is NN or NB, then add to treebank_special_cases" % ud_name)
    if ud_name.startswith("UD_Chinese"):
        if ud_name not in treebank_special_cases:
            # this previously reused the Norwegian NN/NB message, which was misleading for Chinese
            raise ValueError("Please figure out if dataset %s is simplified or traditional Chinese, then add to treebank_special_cases" % ud_name)

max_len = max(len(x) for x in short_names) + 8
line_format = "    %-" + str(max_len) + "s '%s',\n"


print("Writing to %s" % output_name)
with open(output_name, "w") as fout:
    fout.write("# This module is autogenerated by build_short_name_to_treebank.py\n")
    fout.write("# Please do not edit\n")
    fout.write("\n")
    fout.write("SHORT_NAMES = {\n")
    for short_name, ud_name in zip(short_names, ud_names):
        fout.write(line_format % ("'" + short_name + "':", ud_name))

        # also register alternate aliases: zh <-> zh-hans/zh-hant and nb -> no
        if short_name.startswith("zh_"):
            short_name = "zh-hans_" + short_name[3:]
            fout.write(line_format % ("'" + short_name + "':", ud_name))
        elif short_name.startswith("zh-hans_") or short_name.startswith("zh-hant_"):
            short_name = "zh_" + short_name[8:]
            fout.write(line_format % ("'" + short_name + "':", ud_name))
        elif short_name == 'nb_bokmaal':
            short_name = 'no_bokmaal'
            fout.write(line_format % ("'" + short_name + "':", ud_name))

    fout.write("}\n")

    fout.write("""

def short_name_to_treebank(short_name):
    return SHORT_NAMES[short_name]


""")

    max_len = max(len(x) for x in ud_names) + 5
    line_format = "    %-" + str(max_len) + "s '%s',\n"
    fout.write("CANONICAL_NAMES = {\n")
    for ud_name in ud_names:
        fout.write(line_format % ("'" + ud_name.lower() + "':", ud_name))
    fout.write("}\n")
    fout.write("""

def canonical_treebank_name(ud_name):
    if ud_name in SHORT_NAMES:
        return SHORT_NAMES[ud_name]
    return CANONICAL_NAMES.get(ud_name.lower(), ud_name)
""")
|
stanza/stanza/models/common/char_model.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Based on
|
| 3 |
+
|
| 4 |
+
@inproceedings{akbik-etal-2018-contextual,
|
| 5 |
+
title = "Contextual String Embeddings for Sequence Labeling",
|
| 6 |
+
author = "Akbik, Alan and
|
| 7 |
+
Blythe, Duncan and
|
| 8 |
+
Vollgraf, Roland",
|
| 9 |
+
booktitle = "Proceedings of the 27th International Conference on Computational Linguistics",
|
| 10 |
+
month = aug,
|
| 11 |
+
year = "2018",
|
| 12 |
+
address = "Santa Fe, New Mexico, USA",
|
| 13 |
+
publisher = "Association for Computational Linguistics",
|
| 14 |
+
url = "https://aclanthology.org/C18-1139",
|
| 15 |
+
pages = "1638--1649",
|
| 16 |
+
}
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from collections import Counter
|
| 20 |
+
from operator import itemgetter
|
| 21 |
+
import os
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, pack_padded_sequence, PackedSequence
|
| 26 |
+
|
| 27 |
+
from stanza.models.common.data import get_long_tensor
|
| 28 |
+
from stanza.models.common.packed_lstm import PackedLSTM
|
| 29 |
+
from stanza.models.common.utils import open_read_text, tensor_unsort, unsort
|
| 30 |
+
from stanza.models.common.dropout import SequenceUnitDropout
|
| 31 |
+
from stanza.models.common.vocab import UNK_ID, CharVocab
|
| 32 |
+
|
| 33 |
+
class CharacterModel(nn.Module):
    """Character-level encoder producing one summary vector per word.

    Characters are embedded, run through an (optionally bidirectional)
    LSTM, and summarized either with a learned sigmoid gate over all
    characters (attention=True) or by taking the LSTM's final states.
    """
    def __init__(self, args, vocab, pad=False, bidirectional=False, attention=True):
        super().__init__()
        self.args = args
        # if True, forward() returns a padded tensor instead of a PackedSequence
        self.pad = pad
        self.num_dir = 2 if bidirectional else 1
        self.attn = attention

        # char embeddings
        self.char_emb = nn.Embedding(len(vocab['char']), self.args['char_emb_dim'], padding_idx=0)
        if self.attn:
            # scalar gate per character; zero init so sigmoid starts at 0.5 for everything
            self.char_attn = nn.Linear(self.num_dir * self.args['char_hidden_dim'], 1, bias=False)
            self.char_attn.weight.data.zero_()

        # modules
        # inter-layer dropout is disabled for a single-layer LSTM, matching nn.LSTM semantics
        self.charlstm = PackedLSTM(self.args['char_emb_dim'], self.args['char_hidden_dim'], self.args['char_num_layers'], batch_first=True, \
                dropout=0 if self.args['char_num_layers'] == 1 else args['dropout'], rec_dropout = self.args['char_rec_dropout'], bidirectional=bidirectional)
        # learned initial hidden/cell states, expanded to the batch in forward()
        self.charlstm_h_init = nn.Parameter(torch.zeros(self.num_dir * self.args['char_num_layers'], 1, self.args['char_hidden_dim']))
        self.charlstm_c_init = nn.Parameter(torch.zeros(self.num_dir * self.args['char_num_layers'], 1, self.args['char_hidden_dim']))

        self.dropout = nn.Dropout(args['dropout'])

    def forward(self, chars, chars_mask, word_orig_idx, sentlens, wordlens):
        """Encode a batch of words given as padded character id sequences.

        chars_mask is accepted for interface compatibility but not used here.
        Words are assumed sorted by decreasing length (as pack_padded_sequence
        requires); word_orig_idx undoes that sort and sentlens regroups the
        words back into sentences.
        """
        embs = self.dropout(self.char_emb(chars))
        batch_size = embs.size(0)
        embs = pack_padded_sequence(embs, wordlens, batch_first=True)
        output = self.charlstm(embs, wordlens, hx=(\
                self.charlstm_h_init.expand(self.num_dir * self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous(), \
                self.charlstm_c_init.expand(self.num_dir * self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous()))

        # apply attention, otherwise take final states
        if self.attn:
            # gate every character's representation, then sum over characters
            char_reps = output[0]
            weights = torch.sigmoid(self.char_attn(self.dropout(char_reps.data)))
            char_reps = PackedSequence(char_reps.data * weights, char_reps.batch_sizes)
            char_reps, _ = pad_packed_sequence(char_reps, batch_first=True)
            res = char_reps.sum(1)
        else:
            # NOTE(review): h[-2:] is the final layer's two directions when
            # bidirectional; for a unidirectional multi-layer LSTM it takes the
            # top two layers instead -- confirm this is intended
            h, c = output[1]
            res = h[-2:].transpose(0,1).contiguous().view(batch_size, -1)

        # recover character order and word separation
        res = tensor_unsort(res, word_orig_idx)
        res = pack_sequence(res.split(sentlens))
        if self.pad:
            res = pad_packed_sequence(res, batch_first=True)[0]

        return res
|
| 81 |
+
|
| 82 |
+
def build_charlm_vocab(path, cutoff=0):
    """
    Build a vocab for a CharacterLanguageModel

    Requires a large amount of memory, but only need to build once

    here we need some trick to deal with excessively large files
    for each file we accumulate the counter of characters, and
    at the end we simply pass a list of chars to the vocab builder
    """
    counter = Counter()
    # accept either a directory of training files or a single file
    if os.path.isdir(path):
        base_dir = path
        filenames = sorted(os.listdir(path))
    else:
        base_dir, single_file = os.path.split(path)
        filenames = [single_file]

    for filename in filenames:
        with open_read_text(os.path.join(base_dir, filename)) as fin:
            for line in fin:
                counter.update(line)

    if not counter:
        raise ValueError("Training data was empty!")
    # remove infrequent characters from vocab
    counter = Counter({ch: freq for ch, freq in counter.items() if freq >= cutoff})
    # a singleton list of all characters
    surviving_chars = sorted(ch for ch, _ in counter.most_common())
    if not surviving_chars:
        raise ValueError("All characters in the training data were less frequent than --cutoff!")
    # skip cutoff argument because this has been dealt with
    return CharVocab([surviving_chars])
|
| 117 |
+
|
| 118 |
+
# Sentinel characters used when feeding text to the character LM:
# a newline marks the start of a sequence, and a space is appended
# after each word (it doubles as the padding character elsewhere).
CHARLM_START = "\n"
CHARLM_END = " "
| 120 |
+
|
| 121 |
+
class CharacterLanguageModel(nn.Module):
    """A unidirectional character-level language model.

    The LSTM over characters is trained with a decoder predicting the
    next character; downstream models use its hidden states as
    contextualized character embeddings.  For a backward LM
    (is_forward_lm=False) callers feed the text reversed.
    """

    def __init__(self, args, vocab, pad=False, is_forward_lm=True):
        super().__init__()
        self.args = args
        self.vocab = vocab
        self.is_forward_lm = is_forward_lm
        # if True, representation methods return padded tensors instead of PackedSequences
        self.pad = pad
        self.finetune = True # always finetune unless otherwise specified

        # char embeddings
        self.char_emb = nn.Embedding(len(self.vocab['char']), self.args['char_emb_dim'], padding_idx=None) # we use space as padding, so padding_idx is not necessary

        # modules
        self.charlstm = PackedLSTM(self.args['char_emb_dim'], self.args['char_hidden_dim'], self.args['char_num_layers'], batch_first=True, \
                dropout=0 if self.args['char_num_layers'] == 1 else args['char_dropout'], rec_dropout = self.args['char_rec_dropout'], bidirectional=False)
        # learned initial hidden/cell states, expanded to the batch in forward()
        self.charlstm_h_init = nn.Parameter(torch.zeros(self.args['char_num_layers'], 1, self.args['char_hidden_dim']))
        self.charlstm_c_init = nn.Parameter(torch.zeros(self.args['char_num_layers'], 1, self.args['char_hidden_dim']))

        # decoder: projects hidden states back onto the character vocabulary
        self.decoder = nn.Linear(self.args['char_hidden_dim'], len(self.vocab['char']))
        self.dropout = nn.Dropout(args['char_dropout'])
        # randomly replaces input character ids with UNK_ID as regularization
        self.char_dropout = SequenceUnitDropout(args.get('char_unit_dropout', 0), UNK_ID)

    def forward(self, chars, charlens, hidden=None):
        """Run the LSTM over padded character id sequences.

        Returns (padded hidden states, final (h, c) pair, decoder logits).
        """
        chars = self.char_dropout(chars)
        embs = self.dropout(self.char_emb(chars))
        batch_size = embs.size(0)
        embs = pack_padded_sequence(embs, charlens, batch_first=True)
        if hidden is None:
            hidden = (self.charlstm_h_init.expand(self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous(),
                      self.charlstm_c_init.expand(self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous())
        output, hidden = self.charlstm(embs, charlens, hx=hidden)
        output = self.dropout(pad_packed_sequence(output, batch_first=True)[0])
        decoded = self.decoder(output)
        return output, hidden, decoded

    def get_representation(self, chars, charoffsets, charlens, char_orig_idx):
        """Extract hidden states at the given per-sequence offsets.

        char_orig_idx restores the caller's original ordering; no
        gradients flow through this path.
        """
        with torch.no_grad():
            output, _, _ = self.forward(chars, charlens)
            res = [output[i, offsets] for i, offsets in enumerate(charoffsets)]
            res = unsort(res, char_orig_idx)
            res = pack_sequence(res)
            if self.pad:
                res = pad_packed_sequence(res, batch_first=True)[0]
        return res

    def per_char_representation(self, words):
        """Return one hidden state per character for each word in words."""
        device = next(self.parameters()).device
        vocab = self.char_vocab()

        # sort by length, as required by pack_padded_sequence in forward()
        all_data = [(vocab.map(word), len(word), idx) for idx, word in enumerate(words)]
        all_data.sort(key=itemgetter(1), reverse=True)
        chars = [x[0] for x in all_data]
        char_lens = [x[1] for x in all_data]
        char_tensor = get_long_tensor(chars, len(chars), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)
        with torch.no_grad():
            output, _, _ = self.forward(char_tensor, char_lens)
            # trim padding, then restore the original word order
            output = [x[:y, :] for x, y in zip(output, char_lens)]
            output = unsort(output, [x[2] for x in all_data])
        return output

    def build_char_representation(self, sentences):
        """
        Return values from this charlm for a list of list of words

        input: [[str]]
          K sentences, each of length Ki (can be different for each sentence)
        output: [tensor(Ki x dim)]
          list of tensors, each one with shape Ki by the dim of the character model

        Values are taken from the last character in a word for each word.
        The words are effectively treated as if they are whitespace separated
        (which may actually be somewhat inaccurate for languages such as Chinese or for MWT)
        """
        forward = self.is_forward_lm
        vocab = self.char_vocab()
        device = next(self.parameters()).device

        all_data = []
        for idx, words in enumerate(sentences):
            # a backward LM reads each word, and the sentence, reversed
            if not forward:
                words = [x[::-1] for x in reversed(words)]

            chars = [CHARLM_START]
            offsets = []
            for w in words:
                chars.extend(w)
                chars.append(CHARLM_END)
                # record the position of the end-of-word character for each word
                offsets.append(len(chars) - 1)
            if not forward:
                offsets.reverse()
            chars = vocab.map(chars)
            all_data.append((chars, offsets, len(chars), len(all_data)))

        # sort longest first, as required by pack_padded_sequence in forward()
        all_data.sort(key=itemgetter(2), reverse=True)
        chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data))
        # TODO: can this be faster?
        chars = get_long_tensor(chars, len(all_data), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)

        with torch.no_grad():
            output, _, _ = self.forward(chars, char_lens)
            res = [output[i, offsets] for i, offsets in enumerate(char_offsets)]
            res = unsort(res, orig_idx)

        return res

    def hidden_dim(self):
        """Dimension of the hidden states this model produces."""
        return self.args['char_hidden_dim']

    def char_vocab(self):
        """The CharVocab mapping characters to ids."""
        return self.vocab['char']

    def train(self, mode=True):
        """
        Override the default train() function, so that when self.finetune == False, the training mode
        won't be impacted by the parent models' status change.
        """
        if not mode: # eval() is always allowed, regardless of finetune status
            super().train(mode)
        else:
            if self.finetune: # only set to training mode in finetune status
                super().train(mode)

    def full_state(self):
        """Everything needed to reconstruct this model with from_full_state()."""
        state = {
            'vocab': self.vocab['char'].state_dict(),
            'args': self.args,
            'state_dict': self.state_dict(),
            'pad': self.pad,
            'is_forward_lm': self.is_forward_lm
        }
        return state

    def save(self, filename):
        """Save the model to filename, creating parent directories if needed."""
        os.makedirs(os.path.split(filename)[0], exist_ok=True)
        state = self.full_state()
        # legacy (non-zipfile) format so older torch versions can still load it
        torch.save(state, filename, _use_new_zipfile_serialization=False)

    @classmethod
    def from_full_state(cls, state, finetune=False):
        """Rebuild a model from a full_state() dict; returned in eval mode."""
        vocab = {'char': CharVocab.load_state_dict(state['vocab'])}
        model = cls(state['args'], vocab, state['pad'], state['is_forward_lm'])
        model.load_state_dict(state['state_dict'])
        model.eval()
        model.finetune = finetune # set finetune status
        return model

    @classmethod
    def load(cls, filename, finetune=False):
        """Load a saved charlm; set finetune=True if it will be trained further."""
        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        # allow saving just the Model object,
        # and allow for old charlms to still work
        if 'state_dict' in state:
            return cls.from_full_state(state, finetune)
        return cls.from_full_state(state['model'], finetune)
|
| 277 |
+
|
| 278 |
+
class CharacterLanguageModelWordAdapter(nn.Module):
    """
    Adapts a character model to return embeddings for each character in a word
    """
    def __init__(self, charlms):
        super().__init__()
        self.charlms = charlms

    def forward(self, words):
        # frame every word with the charlm's start/end sentinel characters
        framed = [CHARLM_START + word + CHARLM_END for word in words]
        per_model = []
        for charlm in self.charlms:
            reps = charlm.per_char_representation(framed)
            # zero-pad each word's representation out to the longest word
            longest = max(rep.shape[0] for rep in reps)
            padded = torch.zeros(len(reps), longest, reps[0].shape[1], dtype=reps[0].dtype, device=reps[0].device)
            for word_idx, rep in enumerate(reps):
                padded[word_idx, :rep.shape[0], :] = rep
            per_model.append(padded)
        # concatenate the charlms' outputs along the feature dimension
        return torch.cat(per_model, dim=2)

    def hidden_dim(self):
        return sum(charlm.hidden_dim() for charlm in self.charlms)
|
| 300 |
+
|
| 301 |
+
class CharacterLanguageModelTrainer():
    """Bundles a CharacterLanguageModel with its optimizer, loss, and LR scheduler."""
    def __init__(self, model, params, optimizer, criterion, scheduler, epoch=1, global_step=0):
        self.model = model
        # the subset of model parameters with requires_grad
        self.params = params
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        # 1-based epoch counter
        self.epoch = epoch
        # number of training steps taken so far
        self.global_step = global_step

    def save(self, filename, full=True):
        """Save the model and (if full) optimizer/criterion/scheduler state.

        full=False keeps only the model and progress counters, enough for
        inference but not for resuming training with identical dynamics.
        """
        os.makedirs(os.path.split(filename)[0], exist_ok=True)
        state = {
            'model': self.model.full_state(),
            'epoch': self.epoch,
            'global_step': self.global_step,
        }
        if full and self.optimizer is not None:
            state['optimizer'] = self.optimizer.state_dict()
        if full and self.criterion is not None:
            state['criterion'] = self.criterion.state_dict()
        if full and self.scheduler is not None:
            state['scheduler'] = self.scheduler.state_dict()
        # legacy (non-zipfile) format so older torch versions can still load it
        torch.save(state, filename, _use_new_zipfile_serialization=False)

    @classmethod
    def from_new_model(cls, args, vocab):
        """Build a trainer around a freshly initialized model."""
        model = CharacterLanguageModel(args, vocab, is_forward_lm=True if args['direction'] == 'forward' else False)
        model = model.to(args['device'])
        params = [param for param in model.parameters() if param.requires_grad]
        optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
        criterion = torch.nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=args['anneal'], patience=args['patience'])
        return cls(model, params, optimizer, criterion, scheduler)


    @classmethod
    def load(cls, args, filename, finetune=False):
        """
        Load the model along with any other saved state for training

        Note that you MUST set finetune=True if planning to continue training
        Otherwise the only benefit you will get will be a warm GPU
        """
        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        model = CharacterLanguageModel.from_full_state(state['model'], finetune)
        model = model.to(args['device'])

        params = [param for param in model.parameters() if param.requires_grad]
        optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
        if 'optimizer' in state: optimizer.load_state_dict(state['optimizer'])

        criterion = torch.nn.CrossEntropyLoss()
        if 'criterion' in state: criterion.load_state_dict(state['criterion'])

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=args['anneal'], patience=args['patience'])
        if 'scheduler' in state: scheduler.load_state_dict(state['scheduler'])

        # checkpoints saved without full training state fall back to defaults
        epoch = state.get('epoch', 1)
        global_step = state.get('global_step', 0)
        return cls(model, params, optimizer, criterion, scheduler, epoch, global_step)
|
| 362 |
+
|
stanza/stanza/models/common/chuliu_edmonds.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from Tim's code here: https://github.com/tdozat/Parser-v3/blob/master/scripts/chuliu_edmonds.py
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def tarjan(tree):
    """Finds the cycles in a dependency graph

    The input should be a numpy array of integers,
    where in the standard use case,
    tree[i] is the head of node i.

    tree[0] == 0 to represent the root

    so for example, for the English sentence "This is a test",
    the input is

    [0 4 4 4 0]

    "Arthritis makes my hip hurt"

    [0 2 0 4 2 2]

    The return is a list of cycles, where in cycle has True if the
    node at that index is participating in the cycle.
    So, for example, the previous examples both return empty lists,
    whereas an input of
    np.array([0, 3, 1, 2])
    has an output of
    [np.array([False, True, True, True])]
    """
    # Tarjan SCC bookkeeping: discovery index, lowest reachable index,
    # and whether each node is currently on the SCC stack
    indices = -np.ones_like(tree)
    lowlinks = -np.ones_like(tree)
    onstack = np.zeros_like(tree, dtype=bool)
    stack = list()
    # mutable cell so the nested closures can increment the counter
    _index = [0]
    cycles = []
    #-------------------------------------------------------------
    def maybe_pop_cycle(i):
        # i is an SCC root exactly when its lowlink equals its own index
        if lowlinks[i] == indices[i]:
            # There's a cycle!
            cycle = np.zeros_like(indices, dtype=bool)
            while stack[-1] != i:
                j = stack.pop()
                onstack[j] = False
                cycle[j] = True
            stack.pop()
            onstack[i] = False
            cycle[i] = True
            # singleton SCCs are not cycles; only record size > 1
            if cycle.sum() > 1:
                cycles.append(cycle)

    def initialize_strong_connect(i):
        # assign the next discovery index to node i and push it
        _index[0] += 1
        index = _index[-1]
        indices[i] = lowlinks[i] = index - 1
        stack.append(i)
        onstack[i] = True

    def strong_connect(i):
        # this ridiculous atrocity is because somehow people keep
        # coming up with graphs which overflow python's call stack
        # so instead we make our own call stack and turn the recursion
        # into a loop
        # see for example
        # https://github.com/stanfordnlp/stanza/issues/962
        # https://github.com/spraakbanken/sparv-pipeline/issues/166
        # in an ideal world this block of code would look like this
        # initialize_strong_connect(i)
        # dependents = iter(np.where(np.equal(tree, i))[0])
        # for j in dependents:
        #     if indices[j] == -1:
        #         strong_connect(j)
        #         lowlinks[i] = min(lowlinks[i], lowlinks[j])
        #     elif onstack[j]:
        #         lowlinks[i] = min(lowlinks[i], indices[j])
        #
        # maybe_pop_cycle(i)
        # each frame is (node, its partially-consumed dependents iterator,
        # the dependent we recursed into); (i, None, None) means "fresh call"
        call_stack = [(i, None, None)]
        while len(call_stack) > 0:
            i, dependents_iterator, j = call_stack.pop()
            if dependents_iterator is None: # first time getting here for this i
                initialize_strong_connect(i)
                dependents_iterator = iter(np.where(np.equal(tree, i))[0])
            else: # been here before. j was the dependent we were just considering
                lowlinks[i] = min(lowlinks[i], lowlinks[j])
            for j in dependents_iterator:
                if indices[j] == -1:
                    # have to remember where we were...
                    # put the current iterator & its state on the "call stack"
                    # we will come back to it later
                    call_stack.append((i, dependents_iterator, j))
                    # also, this is what we do next...
                    call_stack.append((j, None, None))
                    # this will break this iterator for now
                    # the next time through, we will continue progressing this iterator
                    break
                elif onstack[j]:
                    lowlinks[i] = min(lowlinks[i], indices[j])
            else:
                # this is an intended use of for/else
                # please stop filing git issues on obscure language features
                # we finished iterating without a break
                # and can finally resolve any possible cycles
                maybe_pop_cycle(i)
            # at this point, there are two cases:
            #
            # we iterated all the way through an iterator (the else in the for/else)
            # and have resolved any possible cycles. can then proceed to the previous
            # iterator we were considering (or finish, if there are no others)
            # OR
            # we have hit a break in the iteration over the dependents
            # for a node
            # and we need to dig deeper into the graph and resolve the dependent's dependents
            # before we can continue the previous node
            #
            # either way, we check to see if there are unfinished subtrees
            # when that is finally done, we can return

    #-------------------------------------------------------------
    for i in range(len(tree)):
        if indices[i] == -1:
            strong_connect(i)
    return cycles
|
| 124 |
+
|
| 125 |
+
def process_cycle(tree, cycle, scores):
    """
    Build a subproblem with one cycle broken

    tree: current (cyclic) head assignment, tree[i] = head of i
    cycle: boolean mask over nodes, True for members of the cycle
    scores: (t x t) score matrix, scores[i, j] = score of j heading i

    Returns (subscores, cycle_locs, noncycle_locs, metanode_heads,
    metanode_deps): the contracted (n+1 x n+1) score matrix where the
    whole cycle is collapsed into one extra "metanode" at index -1,
    plus the bookkeeping needed by expand_contracted_tree to undo the
    contraction.  Shape comments below use t = len(tree),
    c = cycle size, n = non-cycle size.
    """
    # indices of cycle in original tree; (c) in t
    cycle_locs = np.where(cycle)[0]
    # heads of cycle in original tree; (c) in t
    cycle_subtree = tree[cycle]
    # scores of cycle in original tree; (c) in R
    cycle_scores = scores[cycle, cycle_subtree]
    # total score of cycle; () in R
    cycle_score = cycle_scores.sum()

    # locations of noncycle; (t) in [0,1]
    noncycle = np.logical_not(cycle)
    # indices of noncycle in original tree; (n) in t
    noncycle_locs = np.where(noncycle)[0]
    #print(cycle_locs, noncycle_locs)

    # scores of cycle's potential heads; (c x n) - (c) + () -> (n x c) in R
    # score of attaching the cycle via a given member = total cycle score,
    # minus that member's in-cycle edge, plus the new external edge
    metanode_head_scores = scores[cycle][:,noncycle] - cycle_scores[:,None] + cycle_score
    # scores of cycle's potential dependents; (n x c) in R
    metanode_dep_scores = scores[noncycle][:,cycle]
    # best noncycle head for each cycle dependent; (n) in c
    metanode_heads = np.argmax(metanode_head_scores, axis=0)
    # best cycle head for each noncycle dependent; (n) in c
    metanode_deps = np.argmax(metanode_dep_scores, axis=1)

    # scores of noncycle graph; (n x n) in R
    subscores = scores[noncycle][:,noncycle]
    # pad to contracted graph; (n+1 x n+1) in R
    subscores = np.pad(subscores, ( (0,1) , (0,1) ), 'constant')
    # set the contracted graph scores of cycle's potential heads; (c x n)[:, (n) in n] in R -> (n) in R
    subscores[-1, :-1] = metanode_head_scores[metanode_heads, np.arange(len(noncycle_locs))]
    # set the contracted graph scores of cycle's potential dependents; (n x c)[(n) in n] in R-> (n) in R
    subscores[:-1,-1] = metanode_dep_scores[np.arange(len(noncycle_locs)), metanode_deps]
    return subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps):
    """
    Given a partially solved tree with a cycle and a solved subproblem
    for the cycle, build a larger solution without the cycle

    tree: the original (cyclic) head assignment over all t nodes
    contracted_tree: solution on the contracted graph, where the last
        index is the metanode standing in for the whole cycle
    cycle_locs / noncycle_locs: original indices of the (non)cycle nodes
    metanode_heads / metanode_deps: best in/out attachments recorded by
        process_cycle

    Returns a new (t)-length head array with the cycle broken at its
    chosen root.  Shape comments use t / c / n as in process_cycle.
    """
    # head of the cycle; () in n
    #print(contracted_tree)
    cycle_head = contracted_tree[-1]
    # fixed tree: (n) in n+1
    contracted_tree = contracted_tree[:-1]
    # initialize new tree; (t) in 0
    new_tree = -np.ones_like(tree)
    #print(0, new_tree)
    # fixed tree with no heads coming from the cycle: (n) in [0,1]
    contracted_subtree = contracted_tree < len(contracted_tree)
    # add the nodes to the new tree (t)[(n)[(n) in [0,1]] in t] in t = (n)[(n)[(n) in [0,1]] in n] in t
    new_tree[noncycle_locs[contracted_subtree]] = noncycle_locs[contracted_tree[contracted_subtree]]
    #print(1, new_tree)
    # fixed tree with heads coming from the cycle: (n) in [0,1]
    contracted_subtree = np.logical_not(contracted_subtree)
    # add the nodes to the tree (t)[(n)[(n) in [0,1]] in t] in t = (c)[(n)[(n) in [0,1]] in c] in t
    new_tree[noncycle_locs[contracted_subtree]] = cycle_locs[metanode_deps[contracted_subtree]]
    #print(2, new_tree)
    # add the old cycle to the tree; (t)[(c) in t] in t = (t)[(c) in t] in t
    new_tree[cycle_locs] = tree[cycle_locs]
    #print(3, new_tree)
    # root of the cycle; (n)[() in n] in c = () in c
    cycle_root = metanode_heads[cycle_head]
    # add the root of the cycle to the new tree; (t)[(c)[() in c] in t] = (c)[() in c]
    # this is the single edge that breaks the cycle
    new_tree[cycle_locs[cycle_root]] = noncycle_locs[cycle_head]
    #print(4, new_tree)
    return new_tree
|
| 196 |
+
|
| 197 |
+
def prepare_scores(scores):
    """
    Alter the scores matrix to avoid self loops and handle the root

    Mutates ``scores`` in place: every self-edge and every edge into
    the root row is forbidden with -inf, then the root's self-edge is
    restored as the only allowed attachment for node 0.
    """
    neg_inf = -float('inf')
    # no token may be its own head
    np.fill_diagonal(scores, neg_inf)
    # the root (node 0) may not attach to anything...
    scores[0, :] = neg_inf
    # ...except itself
    scores[0, 0] = 0
|
| 205 |
+
|
| 206 |
+
def chuliu_edmonds(scores):
    """Chu-Liu/Edmonds maximum spanning arborescence over a dense score matrix.

    scores[i, j] is the score of node j being the head of node i; node 0
    is the root.  NOTE: mutates ``scores`` in place via prepare_scores.
    Returns a numpy array ``tree`` where tree[i] is the chosen head of i
    and tree[0] == 0.
    """
    # saved outer problems, one per contracted cycle, unwound at the end
    subtree_stack = []

    prepare_scores(scores)
    # greedy best head for every node; may contain cycles
    tree = np.argmax(scores, axis=1)
    cycles = tarjan(tree)

    #print(scores)
    #print(cycles)

    # recursive implementation:
    #if cycles:
    #    # t = len(tree); c = len(cycle); n = len(noncycle)
    #    # cycles.pop(): locations of cycle; (t) in [0,1]
    #    subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps = process_cycle(tree, cycles.pop(), scores)
    #    # MST with contraction; (n+1) in n+1
    #    contracted_tree = chuliu_edmonds(subscores)
    #    tree = expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps)
    # unfortunately, while the recursion is simpler to understand, it can get too deep for python's stack limit
    # so instead we make our own recursion, with blackjack and (you know how it goes)

    while cycles:
        # t = len(tree); c = len(cycle); n = len(noncycle)
        # cycles.pop(): locations of cycle; (t) in [0,1]
        subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps = process_cycle(tree, cycles.pop(), scores)
        subtree_stack.append((tree, cycles, scores, subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps))

        # descend into the contracted subproblem
        scores = subscores
        prepare_scores(scores)
        tree = np.argmax(scores, axis=1)
        cycles = tarjan(tree)

    # unwind: expand each contracted cycle back into its outer problem
    while len(subtree_stack) > 0:
        contracted_tree = tree
        (tree, cycles, scores, subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps) = subtree_stack.pop()
        tree = expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps)

    return tree
|
| 244 |
+
|
| 245 |
+
#===============================================================
|
| 246 |
+
def chuliu_edmonds_one_root(scores):
    """Run Chu-Liu/Edmonds, then enforce that exactly one node attaches to root.

    scores[i, j] is the score of j heading i; node 0 is the root.  If the
    unconstrained MST already has a single root dependent, it is returned
    directly.  Otherwise each candidate root is tried in turn (forcing it
    to be the only child of node 0) and the highest scoring tree wins.

    Returns a numpy int array ``tree`` with tree[i] = head of i.
    Raises ValueError if no candidate produces a valid tree.
    """
    scores = scores.astype(np.float64)
    tree = chuliu_edmonds(scores)
    # nodes whose head is the root (excluding the root's self-edge)
    roots_to_try = np.where(np.equal(tree[1:], 0))[0]+1
    if len(roots_to_try) == 1:
        return tree

    #-------------------------------------------------------------
    def set_root(scores, root):
        # copy of scores where `root` is forced to be the only child of
        # node 0; returns the copy and the score of the forced root edge
        root_score = scores[root,0]
        scores = np.array(scores)
        scores[1:,0] = -float('inf')
        scores[root] = -float('inf')
        scores[root,0] = 0
        return scores, root_score
    #-------------------------------------------------------------

    best_score, best_tree = -np.inf, None
    for root in roots_to_try:
        _scores, root_score = set_root(scores, root)
        _tree = chuliu_edmonds(_scores)
        # per-node scores of the chosen edges; if any edge is -inf the
        # candidate tree used a forbidden attachment and is rejected
        tree_probs = _scores[np.arange(len(_scores)), _tree]
        tree_score = tree_probs.sum() + root_score if (tree_probs > -np.inf).all() else -np.inf
        if tree_score > best_score:
            best_score = tree_score
            best_tree = _tree

    if best_tree is None:
        # Previously this was `assert` + a bare `except:` that wrote a
        # debug.log; when roots_to_try was empty the except block hit a
        # NameError on unbound loop variables, masking the real failure.
        # Raise a proper error with the relevant context instead.
        raise ValueError("Could not find a valid one-root tree: tree=%s scores=%s roots_to_try=%s"
                         % (tree, scores, roots_to_try))
    return best_tree
|
stanza/stanza/models/common/constant.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Global constants.
|
| 3 |
+
|
| 4 |
+
These language codes mirror UD language codes when possible
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
class UnknownLanguageError(ValueError):
    """Raised when a language name or code cannot be mapped to a known langcode."""
    pass
|
| 11 |
+
|
| 12 |
+
# tuples in a list so we can assert that the langcodes are all unique
# When applicable, we favor the UD decision over any other possible
# language code or language name
# ISO 639-1 is out of date, but many of the UD datasets are labeled
# using the two letter abbreviations, so we add those for non-UD
# languages in the hopes that we've guessed right if those languages
# are eventually processed
# entries are (langcode, Language_Name); names use _ instead of spaces
lcode2lang_raw = [
    ("abq", "Abaza"),
    ("ab", "Abkhazian"),
    ("aa", "Afar"),
    ("af", "Afrikaans"),
    ("ak", "Akan"),
    ("akk", "Akkadian"),
    ("aqz", "Akuntsu"),
    ("sq", "Albanian"),
    ("am", "Amharic"),
    ("grc", "Ancient_Greek"),
    ("hbo", "Ancient_Hebrew"),
    ("apu", "Apurina"),
    ("ar", "Arabic"),
    ("arz", "Egyptian_Arabic"),
    ("an", "Aragonese"),
    ("hy", "Armenian"),
    ("as", "Assamese"),
    ("aii", "Assyrian"),
    ("ast", "Asturian"),
    ("av", "Avaric"),
    ("ae", "Avestan"),
    ("ay", "Aymara"),
    ("az", "Azerbaijani"),
    ("bm", "Bambara"),
    ("ba", "Bashkir"),
    ("eu", "Basque"),
    ("bar", "Bavarian"),
    ("bej", "Beja"),
    ("be", "Belarusian"),
    ("bn", "Bengali"),
    ("bho", "Bhojpuri"),
    ("bpy", "Bishnupriya_Manipuri"),
    ("bi", "Bislama"),
    ("bor", "Bororo"),
    ("bs", "Bosnian"),
    ("br", "Breton"),
    ("bg", "Bulgarian"),
    ("bxr", "Buryat"),
    ("yue", "Cantonese"),
    ("cpg", "Cappadocian"),
    ("ca", "Catalan"),
    ("ceb", "Cebuano"),
    ("km", "Central_Khmer"),
    ("ch", "Chamorro"),
    ("ce", "Chechen"),
    ("ny", "Chichewa"),
    ("ckt", "Chukchi"),
    ("cv", "Chuvash"),
    ("xcl", "Classical_Armenian"),
    ("lzh", "Classical_Chinese"),
    ("cop", "Coptic"),
    ("kw", "Cornish"),
    ("co", "Corsican"),
    ("cr", "Cree"),
    ("hr", "Croatian"),
    ("cs", "Czech"),
    ("da", "Danish"),
    ("dar", "Dargwa"),
    ("dv", "Dhivehi"),
    ("nl", "Dutch"),
    ("dz", "Dzongkha"),
    ("egy", "Egyptian"),
    ("en", "English"),
    ("myv", "Erzya"),
    ("eo", "Esperanto"),
    ("et", "Estonian"),
    ("ee", "Ewe"),
    ("ext", "Extremaduran"),
    ("fo", "Faroese"),
    ("fj", "Fijian"),
    ("fi", "Finnish"),
    ("fon", "Fon"),
    ("fr", "French"),
    ("qfn", "Frisian_Dutch"),
    ("ff", "Fulah"),
    ("gl", "Galician"),
    ("lg", "Ganda"),
    ("ka", "Georgian"),
    ("de", "German"),
    ("aln", "Gheg"),
    ("bbj", "Ghomálá'"),
    ("got", "Gothic"),
    ("el", "Greek"),
    ("kl", "Greenlandic"),
    ("gub", "Guajajara"),
    ("gn", "Guarani"),
    ("gu", "Gujarati"),
    ("gwi", "Gwichin"),
    ("ht", "Haitian"),
    ("ha", "Hausa"),
    ("he", "Hebrew"),
    ("hz", "Herero"),
    ("azz", "Highland_Puebla_Nahuatl"),
    ("hil", "Hiligaynon"),
    ("hi", "Hindi"),
    ("qhe", "Hindi_English"),
    ("ho", "Hiri_Motu"),
    ("hit", "Hittite"),
    ("hu", "Hungarian"),
    ("is", "Icelandic"),
    ("io", "Ido"),
    ("ig", "Igbo"),
    ("ilo", "Ilocano"),
    ("arc", "Imperial_Aramaic"),
    ("id", "Indonesian"),
    ("iu", "Inuktitut"),
    ("ik", "Inupiaq"),
    ("ga", "Irish"),
    ("it", "Italian"),
    ("ja", "Japanese"),
    ("jv", "Javanese"),
    ("urb", "Kaapor"),
    ("kab", "Kabyle"),
    ("xnr", "Kangri"),
    ("kn", "Kannada"),
    ("kr", "Kanuri"),
    ("pam", "Kapampangan"),
    ("krl", "Karelian"),
    ("arr", "Karo"),
    ("ks", "Kashmiri"),
    ("kk", "Kazakh"),
    ("kfm", "Khunsari"),
    ("quc", "Kiche"),
    ("cgg", "Kiga"),
    ("ki", "Kikuyu"),
    ("rw", "Kinyarwanda"),
    ("ky", "Kyrgyz"),
    ("kv", "Komi"),
    ("koi", "Komi_Permyak"),
    ("kpv", "Komi_Zyrian"),
    ("kg", "Kongo"),
    ("ko", "Korean"),
    ("ku", "Kurdish"),
    ("kmr", "Kurmanji"),
    ("kj", "Kwanyama"),
    ("lad", "Ladino"),
    ("lo", "Lao"),
    ("ltg", "Latgalian"),
    ("la", "Latin"),
    ("lv", "Latvian"),
    ("lij", "Ligurian"),
    ("li", "Limburgish"),
    ("ln", "Lingala"),
    ("lt", "Lithuanian"),
    ("liv", "Livonian"),
    ("olo", "Livvi"),
    ("nds", "Low_Saxon"),
    ("lu", "Luba_Katanga"),
    ("lb", "Luxembourgish"),
    ("mk", "Macedonian"),
    ("jaa", "Madi"),
    ("mag", "Magahi"),
    ("qaf", "Maghrebi_Arabic_French"),
    ("mai", "Maithili"),
    ("mpu", "Makurap"),
    ("mg", "Malagasy"),
    ("ms", "Malay"),
    ("ml", "Malayalam"),
    ("mt", "Maltese"),
    ("mjl", "Mandyali"),
    ("gv", "Manx"),
    ("mi", "Maori"),
    ("mr", "Marathi"),
    ("mh", "Marshallese"),
    ("mzn", "Mazandarani"),
    ("gun", "Mbya_Guarani"),
    ("enm", "Middle_English"),
    ("frm", "Middle_French"),
    ("min", "Minangkabau"),
    ("xmf", "Mingrelian"),
    ("mwl", "Mirandese"),
    ("mdf", "Moksha"),
    ("mn", "Mongolian"),
    ("mos", "Mossi"),
    ("myu", "Munduruku"),
    ("my", "Myanmar"),
    ("nqo", "N'Ko"),
    ("nah", "Nahuatl"),
    ("pcm", "Naija"),
    ("na", "Nauru"),
    ("nv", "Navajo"),
    ("nyq", "Nayini"),
    ("ng", "Ndonga"),
    ("nap", "Neapolitan"),
    ("ne", "Nepali"),
    ("new", "Newar"),
    ("yrl", "Nheengatu"),
    ("nyn", "Nkore"),
    ("frr", "North_Frisian"),
    ("nd", "North_Ndebele"),
    ("sme", "North_Sami"),
    ("nso", "Northern_Sotho"),
    ("gya", "Northwest_Gbaya"),
    ("nb", "Norwegian_Bokmaal"),
    ("nn", "Norwegian_Nynorsk"),
    ("ii", "Nuosu"),
    ("oc", "Occitan"),
    ("or", "Odia"),
    ("oj", "Ojibwa"),
    ("cu", "Old_Church_Slavonic"),
    ("orv", "Old_East_Slavic"),
    ("ang", "Old_English"),
    ("fro", "Old_French"),
    ("sga", "Old_Irish"),
    ("ojp", "Old_Japanese"),
    ("otk", "Old_Turkish"),
    ("om", "Oromo"),
    ("os", "Ossetian"),
    ("ota", "Ottoman_Turkish"),
    ("pi", "Pali"),
    ("ps", "Pashto"),
    ("pad", "Paumari"),
    ("fa", "Persian"),
    ("pay", "Pesh"),
    ("xpg", "Phrygian"),
    ("pbv", "Pnar"),
    ("pl", "Polish"),
    ("qpm", "Pomak"),
    ("pnt", "Pontic"),
    ("pt", "Portuguese"),
    ("pra", "Prakrit"),
    ("pa", "Punjabi"),
    ("qu", "Quechua"),
    ("rhg", "Rohingya"),
    ("ro", "Romanian"),
    ("rm", "Romansh"),
    ("rn", "Rundi"),
    ("ru", "Russian"),
    ("sm", "Samoan"),
    ("sg", "Sango"),
    ("sa", "Sanskrit"),
    ("skr", "Saraiki"),
    ("sc", "Sardinian"),
    ("sco", "Scots"),
    ("gd", "Scottish_Gaelic"),
    ("sr", "Serbian"),
    ("sn", "Shona"),
    ("zh-hans", "Simplified_Chinese"),
    ("sd", "Sindhi"),
    ("si", "Sinhala"),
    ("sms", "Skolt_Sami"),
    ("sk", "Slovak"),
    ("sl", "Slovenian"),
    ("soj", "Soi"),
    ("so", "Somali"),
    ("ckb", "Sorani"),
    ("ajp", "South_Levantine_Arabic"),
    ("nr", "South_Ndebele"),
    ("st", "Southern_Sotho"),
    ("es", "Spanish"),
    ("ssp", "Spanish_Sign_Language"),
    ("su", "Sundanese"),
    ("sw", "Swahili"),
    ("ss", "Swati"),
    ("sv", "Swedish"),
    ("swl", "Swedish_Sign_Language"),
    ("gsw", "Swiss_German"),
    ("syr", "Syriac"),
    ("tl", "Tagalog"),
    ("ty", "Tahitian"),
    ("tg", "Tajik"),
    ("ta", "Tamil"),
    ("tt", "Tatar"),
    ("eme", "Teko"),
    ("te", "Telugu"),
    ("qte", "Telugu_English"),
    ("th", "Thai"),
    ("bo", "Tibetan"),
    ("ti", "Tigrinya"),
    ("to", "Tonga"),
    ("zh-hant", "Traditional_Chinese"),
    ("ts", "Tsonga"),
    ("tn", "Tswana"),
    ("tpn", "Tupinamba"),
    ("tr", "Turkish"),
    ("qtd", "Turkish_German"),
    ("tk", "Turkmen"),
    ("tw", "Twi"),
    ("uk", "Ukrainian"),
    ("xum", "Umbrian"),
    ("hsb", "Upper_Sorbian"),
    ("ur", "Urdu"),
    ("ug", "Uyghur"),
    ("uz", "Uzbek"),
    ("ve", "Venda"),
    ("vep", "Veps"),
    ("vi", "Vietnamese"),
    ("vo", "Volapük"),
    ("wa", "Walloon"),
    ("war", "Waray"),
    ("wbp", "Warlpiri"),
    ("cy", "Welsh"),
    ("hyw", "Western_Armenian"),
    ("fy", "Western_Frisian"),
    ("nhi", "Western_Sierra_Puebla_Nahuatl"),
    ("wo", "Wolof"),
    ("xav", "Xavante"),
    ("xh", "Xhosa"),
    ("sjo", "Xibe"),
    ("sah", "Yakut"),
    ("yi", "Yiddish"),
    ("yo", "Yoruba"),
    ("ess", "Yupik"),
    ("say", "Zaar"),
    ("zza", "Zazaki"),
    ("zea", "Zeelandic"),
    ("za", "Zhuang"),
    ("zu", "Zulu"),
]
|
| 329 |
+
|
| 330 |
+
# build the dictionary, checking for duplicate language codes
lcode2lang = {}
for code, language in lcode2lang_raw:
    assert code not in lcode2lang
    lcode2lang[code] = language

# invert the dictionary, checking for possible duplicate language names
lang2lcode = {}
for code, language in lcode2lang_raw:
    assert language not in lang2lcode
    lang2lcode[language] = code

# check that nothing got clobbered
# (redundant with the per-entry asserts above, but cheap and explicit)
assert len(lcode2lang_raw) == len(lcode2lang)
assert len(lcode2lang_raw) == len(lang2lcode)
|
| 345 |
+
|
| 346 |
+
# some of the two letter langcodes get used elsewhere as three letters
# for example, Wolof is abbreviated "wo" in UD, but "wol" in Masakhane NER
two_to_three_letters_raw = (
    ("bm", "bam"),
    ("ee", "ewe"),
    ("ha", "hau"),
    ("ig", "ibo"),
    ("rw", "kin"),
    ("lg", "lug"),
    ("ny", "nya"),
    ("sn", "sna"),
    ("sw", "swa"),
    ("tn", "tsn"),
    ("tw", "twi"),
    ("wo", "wol"),
    ("xh", "xho"),
    ("yo", "yor"),
    ("zu", "zul"),

    # this is a weird case where a 2 letter code was available,
    # but UD used the 3 letter code instead
    ("se", "sme"),
)

# register each alias pair: whichever code is already known stays
# canonical, and the other code is added as a lookup alias for the
# same language.  (The original also asserted the membership that the
# if/elif branch had just tested, which was a tautology; removed.)
for two, three in two_to_three_letters_raw:
    if two in lcode2lang:
        # the two letter code is canonical; the three letter one must be new
        assert three not in lcode2lang
        assert three not in lang2lcode
        lang2lcode[three] = two
        lcode2lang[three] = lcode2lang[two]
    elif three in lcode2lang:
        # the three letter code is canonical; the two letter one must be new
        assert two not in lcode2lang
        assert two not in lang2lcode
        lang2lcode[two] = three
        lcode2lang[two] = lcode2lang[three]
    else:
        raise AssertionError("Found a proposed alias %s -> %s when neither code was already known" % (two, three))

# direct alias lookup tables in both directions
two_to_three_letters = {
    two: three for two, three in two_to_three_letters_raw
}

three_to_two_letters = {
    three: two for two, three in two_to_three_letters_raw
}

# sanity check: no duplicate keys on either side of the alias table
assert len(two_to_three_letters) == len(two_to_three_letters_raw)
assert len(three_to_two_letters) == len(two_to_three_letters_raw)
|
| 396 |
+
|
| 397 |
+
# additional useful code to language mapping
# added after dict invert to avoid conflict
# NOTE: these overwrite the strict one-to-one mapping above on purpose
lcode2lang['nb'] = 'Norwegian' # Norwegian Bokmall mapped to default norwegian
lcode2lang['no'] = 'Norwegian'
lcode2lang['zh'] = 'Simplified_Chinese'

# alternate spellings / names for languages already known by code;
# these only extend lang2lcode, never lcode2lang
extra_lang_to_lcodes = [
    ("ab", "Abkhaz"),
    ("gsw", "Alemannic"),
    ("my", "Burmese"),
    ("ckb", "Central_Kurdish"),
    ("ny", "Chewa"),
    ("zh", "Chinese"),
    ("za", "Chuang"),
    ("dv", "Divehi"),
    ("eme", "Emerillon"),
    ("lij", "Genoese"),
    ("ga", "Gaelic"),
    ("ne", "Gorkhali"),
    ("ht", "Haitian_Creole"),
    ("ilo", "Ilokano"),
    ("nr", "isiNdebele"),
    ("xh", "isiXhosa"),
    ("zu", "isiZulu"),
    ("jaa", "Jamamadí"),
    ("kab", "Kabylian"),
    ("kl", "Kalaallisut"),
    ("km", "Khmer"),
    ("ky", "Kirghiz"),
    ("lb", "Letzeburgesch"),
    ("lg", "Luganda"),
    ("jaa", "Madí"),
    ("dv", "Maldivian"),
    ("mjl", "Mandeali"),
    ("skr", "Multani"),
    ("nb", "Norwegian"),
    ("ny", "Nyanja"),
    ("sga", "Old_Gaelic"),
    ("or", "Oriya"),
    ("arr", "Ramarama"),
    ("sah", "Sakha"),
    ("nso", "Sepedi"),
    ("tn", "Setswana"),
    ("ii", "Sichuan_Yi"),
    ("si", "Sinhalese"),
    ("ss", "Siswati"),
    ("soj", "Sohi"),
    ("st", "Sesotho"),
    ("ve", "Tshivenda"),
    ("ts", "Xitsonga"),
    ("fy", "West_Frisian"),
    ("zza", "Zaza"),
]

for code, language in extra_lang_to_lcodes:
    # each alternate name must be new, and must point at a known code
    assert language not in lang2lcode
    assert code in lcode2lang
    lang2lcode[language] = code

# treebank names changed from Old Russian to Old East Slavic in 2.8
lang2lcode['Old_Russian'] = 'orv'
|
| 458 |
+
|
| 459 |
+
# build a lowercase map from language to langcode, so lookups can be
# case-insensitive
langlower2lcode = {name.lower(): code for name, code in lang2lcode.items()}
|
| 463 |
+
|
| 464 |
+
# UD treebank names whose short names cannot be derived mechanically
# (Chinese script variants and the Norwegian written standards)
treebank_special_cases = {
    "UD_Chinese-Beginner": "zh-hans_beginner",
    "UD_Chinese-GSDSimp": "zh-hans_gsdsimp",
    "UD_Chinese-GSD": "zh-hant_gsd",
    "UD_Chinese-HK": "zh-hant_hk",
    "UD_Chinese-CFL": "zh-hans_cfl",
    "UD_Chinese-PatentChar": "zh-hans_patentchar",
    "UD_Chinese-PUD": "zh-hant_pud",
    "UD_Norwegian-Bokmaal": "nb_bokmaal",
    "UD_Norwegian-Nynorsk": "nn_nynorsk",
    "UD_Norwegian-NynorskLIA": "nn_nynorsklia",
}

# matches a short name already in "<langcode>_<corpus>" form
SHORTNAME_RE = re.compile("^[a-z-]+_[a-z0-9-_]+$")
|
| 478 |
+
|
| 479 |
+
def langcode_to_lang(lcode):
    """Map a language code to its full language name.

    Falls back to a case-insensitive lookup; an unknown code is
    returned unchanged.
    """
    if lcode in lcode2lang:
        return lcode2lang[lcode]
    lowered = lcode.lower()
    if lowered in lcode2lang:
        return lcode2lang[lowered]
    return lcode
|
| 486 |
+
|
| 487 |
+
def pretty_langcode_to_lang(lcode):
    """Convert a language code to a human-friendly display name.

    Underscores become spaces, and the Chinese variants are rewritten
    in "Chinese (...)" form.
    """
    pretty = langcode_to_lang(lcode).replace("_", " ")
    display_overrides = {
        'Simplified Chinese': 'Chinese (Simplified)',
        'Traditional Chinese': 'Chinese (Traditional)',
    }
    return display_overrides.get(pretty, pretty)
|
| 495 |
+
|
| 496 |
+
def lang_to_langcode(lang):
    """Return the language code for a language name or code.

    Accepts either a full language name or a code, trying exact matches
    first and then case-insensitive ones.

    Raises UnknownLanguageError if nothing matches.
    """
    if lang in lang2lcode:
        return lang2lcode[lang]
    lowered = lang.lower()
    if lowered in langlower2lcode:
        return langlower2lcode[lowered]
    # the caller may already have passed in a code rather than a name
    if lang in lcode2lang:
        return lang
    if lowered in lcode2lang:
        return lowered
    raise UnknownLanguageError("Unable to find language code for %s" % lang)
|
| 508 |
+
|
| 509 |
+
# language codes for languages written right-to-left (see is_right_to_left)
RIGHT_TO_LEFT = set(["ar", "arc", "az", "ckb", "dv", "ff", "he", "ku", "mzn", "nqo", "ps", "fa", "rhg", "sd", "syr", "ur"])
|
| 510 |
+
|
| 511 |
+
def is_right_to_left(lang):
    """
    Covers all the RtL languages we support, as well as many we don't.

    If a language is left out, please let us know!
    """
    return lang_to_langcode(lang) in RIGHT_TO_LEFT
|
| 519 |
+
|
| 520 |
+
def treebank_to_short_name(treebank):
    """ Convert treebank name to short code.

    Handles three input shapes:
      - names in treebank_special_cases get a hand-picked short name
      - names already in short form ("en_ewt") have the language part normalized
      - "UD_Language-Corpus" names are split and converted
    """
    if treebank in treebank_special_cases:
        return treebank_special_cases.get(treebank)
    if SHORTNAME_RE.match(treebank):
        # already short, eg "en_ewt"; normalize the language half only
        lang, corpus = treebank.split("_", 1)
        lang = lang_to_langcode(lang)
        return lang + "_" + corpus

    if treebank.startswith('UD_'):
        treebank = treebank[3:]
    # special case starting with zh in case the input is an already-converted ZH treebank
    if treebank.startswith("zh-hans") or treebank.startswith("zh-hant"):
        splits = (treebank[:len("zh-hans")], treebank[len("zh-hans")+1:])
    else:
        splits = treebank.split('-')
        if len(splits) == 1:
            splits = treebank.split("_", 1)
    assert len(splits) == 2, "Unable to process %s" % treebank
    lang, corpus = splits

    lcode = lang_to_langcode(lang)

    short = "{}_{}".format(lcode, corpus.lower())
    return short
|
| 545 |
+
|
| 546 |
+
def treebank_to_langid(treebank):
    """ Convert treebank name to langid """
    return treebank_to_short_name(treebank).split("_", 1)[0]
|
| 550 |
+
|
stanza/stanza/models/common/count_ner_coverage.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stanza.models.common import pretrain
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
def parse_args():
    """Parse the command line: a list of NER data files plus a pretrain path."""
    parser = argparse.ArgumentParser()
    parser.add_argument('ners', type=str, nargs='*', help='Which treebanks to run on')
    parser.add_argument('--pretrain', type=str, default="/home/john/stanza_resources/hi/pretrain/hdtb.pt", help='Which pretrain to use')
    # datasets to score when none are given on the command line
    parser.set_defaults(ners=["/home/john/stanza/data/ner/hi_fire2013.train.csv",
                              "/home/john/stanza/data/ner/hi_fire2013.dev.csv"])
    return parser.parse_args()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def read_ner(filename):
    """Collect the words with non-O tags from a tab separated NER file.

    Each non-blank line is expected to look like "word<TAB>tag".
    Words whose tag is 'O' are skipped.

    Returns a list of the remaining words, in file order.
    """
    words = []
    # use a context manager so the file handle is closed
    # (the original open(...).readlines() leaked the handle)
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            pieces = line.split("\t")
            if pieces[1] == 'O':
                continue
            words.append(pieces[0])
    return words
|
| 24 |
+
|
| 25 |
+
def count_coverage(pretrain, words):
    """Return the fraction of words found in the pretrain's vocabulary.

    pretrain: any object with a `vocab` supporting `in`
    words: list of word strings

    Returns 0.0 for an empty word list (the original divided by zero).
    """
    if not words:
        return 0.0
    count = sum(1 for w in words if w in pretrain.vocab)
    return count / len(words)
|
| 31 |
+
|
| 32 |
+
# Script entry: load the pretrain once, then report for each NER dataset the
# fraction of its tagged words covered by the pretrain's vocabulary.
args = parse_args()
pt = pretrain.Pretrain(args.pretrain)
for dataset in args.ners:
    words = read_ner(dataset)
    print(dataset)
    print(count_coverage(pt, words))
    print()
|
stanza/stanza/models/common/count_pretrain_coverage.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A simple script to count the fraction of words in a UD dataset which are in a particular pretrain.
|
| 2 |
+
|
| 3 |
+
For example, this script shows that the word2vec Armenian vectors,
|
| 4 |
+
truncated at 250K words, have 75% coverage of the Western Armenian
|
| 5 |
+
dataset, whereas the vectors available here have 88% coverage:
|
| 6 |
+
|
| 7 |
+
https://github.com/ispras-texterra/word-embeddings-eval-hy
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from stanza.models.common import pretrain
|
| 11 |
+
from stanza.utils.conll import CoNLL
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
|
| 15 |
+
def parse_args():
    """Parse the command line: conllu treebank files plus a pretrain path."""
    parser = argparse.ArgumentParser()
    parser.add_argument('treebanks', type=str, nargs='*', help='Which treebanks to run on')
    parser.add_argument('--pretrain', type=str, default="/home/john/extern_data/wordvec/glove/armenian.pt", help='Which pretrain to use')
    # treebanks to score when none are given on the command line
    parser.set_defaults(treebanks=["/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Western_Armenian-ArmTDP/hyw_armtdp-ud-train.conllu",
                                   "/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Armenian-ArmTDP/hy_armtdp-ud-train.conllu"])
    return parser.parse_args()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Script entry: load the pretrain, print its size, then for each treebank
# report what fraction of its word tokens appear in the pretrain's vocabulary.
args = parse_args()
pt = pretrain.Pretrain(args.pretrain)
pt.load()
print("Pretrain stats: {} vectors, {} dim".format(len(pt.vocab), pt.emb[0].shape[0]))

for treebank in args.treebanks:
    print(treebank)
    found = 0
    total = 0
    doc = CoNLL.conll2doc(treebank)
    for sentence in doc.sentences:
        for word in sentence.words:
            total = total + 1
            if word.text in pt.vocab:
                found = found + 1

    print (found / total)
|
stanza/stanza/models/common/crf.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CRF loss and viterbi decoding.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from numbers import Number
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
import torch.nn.init as init
|
| 11 |
+
|
| 12 |
+
class CRFLoss(nn.Module):
    """
    Calculate log-space crf loss, given unary potentials, a transition matrix
    and gold tag sequences.
    """
    def __init__(self, num_tag, batch_average=True):
        super().__init__()
        # _transitions[i, j] is the learned score of moving from tag i to tag j
        self._transitions = nn.Parameter(torch.zeros(num_tag, num_tag))
        self._batch_average = batch_average # if not batch average, average on all tokens

    def forward(self, inputs, masks, tag_indices):
        """
        Compute the CRF negative log likelihood of the gold tag sequences.

        inputs: batch_size x seq_len x num_tags
        masks: batch_size x seq_len
        tag_indices: batch_size x seq_len

        @return:
        loss: CRF negative log likelihood on all instances.
        transitions: the transition matrix
        """
        # TODO: handle <start> and <end> tags
        input_bs, input_sl, input_nc = inputs.size()
        # numerator of the likelihood: gold emissions + gold transitions
        unary_scores = self.crf_unary_score(inputs, masks, tag_indices, input_bs, input_sl, input_nc)
        binary_scores = self.crf_binary_score(inputs, masks, tag_indices, input_bs, input_sl, input_nc)
        # denominator: log partition over all tag sequences
        log_norm = self.crf_log_norm(inputs, masks, tag_indices)
        log_likelihood = unary_scores + binary_scores - log_norm # batch_size
        loss = torch.sum(-log_likelihood)
        if self._batch_average:
            loss = loss / input_bs
        else:
            # NOTE(review): masks appear to use nonzero/True for padding, so
            # eq(0) counts the real tokens -- confirm against the callers
            total = masks.eq(0).sum()
            loss = loss / (total + 1e-8)
        return loss, self._transitions

    def crf_unary_score(self, inputs, masks, tag_indices, input_bs, input_sl, input_nc):
        """
        Sum the per-token emission scores of the gold tags.

        @return:
        unary_scores: batch_size
        """
        flat_inputs = inputs.view(input_bs, -1)
        # offset each position's tag index into the flattened seq_len*num_tags axis
        flat_tag_indices = tag_indices + torch.arange(input_sl, device=tag_indices.device).long().unsqueeze(0) * input_nc
        unary_scores = torch.gather(flat_inputs, 1, flat_tag_indices).view(input_bs, -1)
        # padded positions contribute nothing to the score
        unary_scores.masked_fill_(masks, 0)
        return unary_scores.sum(dim=1)

    def crf_binary_score(self, inputs, masks, tag_indices, input_bs, input_sl, input_nc):
        """
        Sum the transition scores between consecutive gold tags.

        @return:
        binary_scores: batch_size
        """
        # get number of transitions
        nt = tag_indices.size(-1) - 1
        start_indices = tag_indices[:, :nt]
        end_indices = tag_indices[:, 1:]
        # flat matrices
        flat_transition_indices = start_indices * input_nc + end_indices
        flat_transition_indices = flat_transition_indices.view(-1)
        flat_transition_matrix = self._transitions.view(-1)
        binary_scores = torch.gather(flat_transition_matrix, 0, flat_transition_indices)\
                .view(input_bs, -1)
        # a transition into a padded position contributes nothing
        score_masks = masks[:, 1:]
        binary_scores.masked_fill_(score_masks, 0)
        return binary_scores.sum(dim=1)

    def crf_log_norm(self, inputs, masks, tag_indices):
        """
        Calculate the CRF partition in log space for each instance, following:
        http://www.cs.columbia.edu/~mcollins/fb.pdf
        @return:
        log_norm: batch_size
        """
        start_inputs = inputs[:,0,:] # bs x nc
        rest_inputs = inputs[:,1:,:]
        # TODO: technically we need to pay attention to the initial
        # value being masked. Currently we do compensate for the
        # entire row being masked at the end of the operation
        rest_masks = masks[:,1:]
        alphas = start_inputs # bs x nc
        trans = self._transitions.unsqueeze(0) # 1 x nc x nc
        # accumulate alphas in log space
        for i in range(rest_inputs.size(1)):
            transition_scores = alphas.unsqueeze(2) + trans # bs x nc x nc
            new_alphas = rest_inputs[:,i,:] + log_sum_exp(transition_scores, dim=1)
            m = rest_masks[:,i].unsqueeze(1).expand_as(new_alphas) # bs x nc, 1 for padding idx
            # apply masks: keep the previous alphas wherever this position is padding
            new_alphas.masked_scatter_(m, alphas.masked_select(m))
            alphas = new_alphas
        log_norm = log_sum_exp(alphas, dim=1)

        # if any row was entirely masked, we just turn its log denominator to 0
        # eg, the empty summation for the denominator will be 1, and its log will be 0
        all_masked = torch.all(masks, dim=1)
        log_norm = log_norm * torch.logical_not(all_masked)
        return log_norm
|
| 106 |
+
|
| 107 |
+
def viterbi_decode(scores, transition_params):
    """
    Decode a tag sequence with viterbi algorithm.
    scores: seq_len x num_tags (numpy array)
    transition_params: num_tags x num_tags (numpy array)
    @return:
    viterbi: a list of tag ids with highest score
    viterbi_score: the highest score
    """
    trellis = np.zeros_like(scores)
    backpointers = np.zeros_like(scores, dtype=np.int32)
    # base case: best length-1 paths are just the first emission scores
    trellis[0] = scores[0]

    # forward pass: trellis[t][j] = best score of any path ending in tag j at time t
    for step in range(1, scores.shape[0]):
        candidates = np.expand_dims(trellis[step - 1], 1) + transition_params
        trellis[step] = scores[step] + np.max(candidates, 0)
        backpointers[step] = np.argmax(candidates, 0)

    # backward pass: start at the best final tag and follow the backpointers
    best_path = [np.argmax(trellis[-1])]
    for pointer_row in backpointers[:0:-1]:
        best_path.append(pointer_row[best_path[-1]])
    best_path.reverse()
    return best_path, np.max(trellis[-1])
|
| 131 |
+
|
| 132 |
+
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation

    value.exp().sum(dim, keepdim).log()
    """
    if dim is None:
        # reduce over every element of the tensor
        m = torch.max(value)
        sum_exp = torch.sum(torch.exp(value - m))
        if isinstance(sum_exp, Number):
            return m + math.log(sum_exp)
        return m + torch.log(sum_exp)
    # reduce along one dimension, shifting by the max for stability
    m, _ = torch.max(value, dim=dim, keepdim=True)
    shifted = value - m
    if not keepdim:
        m = m.squeeze(dim)
    return m + torch.log(torch.sum(torch.exp(shifted), dim=dim, keepdim=keepdim))
|
stanza/stanza/models/common/data.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for data transformations.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 11 |
+
from stanza.models.common.doc import HEAD, ID, UPOS
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger('stanza')
|
| 14 |
+
|
| 15 |
+
def map_to_ids(tokens, vocab):
    """Look up each token's id in vocab, substituting UNK_ID for unknown tokens."""
    return [vocab[token] if token in vocab else constant.UNK_ID for token in tokens]
|
| 18 |
+
|
| 19 |
+
def get_long_tensor(tokens_list, batch_size, pad_id=constant.PAD_ID):
    """ Convert (list of )+ tokens to a padded LongTensor. """
    # find the maximum length at each nesting level of the (possibly
    # nested) token lists
    dims = []
    level = tokens_list
    while isinstance(level[0], list):
        dims.append(max(len(item) for item in level))
        level = [element for item in level for element in item]
    # TODO: pass in a device parameter and put it directly on the relevant device?
    # that might be faster than creating it and then moving it
    padded = torch.LongTensor(batch_size, *dims).fill_(pad_id)
    for row, seq in enumerate(tokens_list):
        padded[row, :len(seq)] = torch.LongTensor(seq)
    return padded
|
| 32 |
+
|
| 33 |
+
def get_float_tensor(features_list, batch_size):
    """Pad a batch of per-token feature lists into a zero-filled FloatTensor.

    Returns None when no features are provided.
    """
    if features_list is None or features_list[0] is None:
        return None
    max_len = max(len(feats) for feats in features_list)
    n_feats = len(features_list[0][0])
    padded = torch.FloatTensor(batch_size, max_len, n_feats).zero_()
    for row, feats in enumerate(features_list):
        padded[row, :len(feats), :] = torch.FloatTensor(feats)
    return padded
|
| 42 |
+
|
| 43 |
+
def sort_all(batch, lens):
    """ Sort all fields by descending order of lens, and return the original indices. """
    if batch == [[]]:
        return [[]], []
    # bundle each position's length, original index and field values together,
    # sort descending, then unzip back into columns
    combined = list(zip(lens, range(len(lens)), *batch))
    combined.sort(reverse=True)
    columns = [list(col) for col in zip(*combined)]
    return columns[2:], columns[1]
|
| 50 |
+
|
| 51 |
+
def get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate, desired_ratio=0.1, max_ratio=0.5):
    """
    Returns X so that if you randomly select X * N sentences, you get 10%

    The ratio will be chosen in the assumption that the final dataset
    is of size N rather than N + X * N.

    should_augment_predicate: returns True if the sentence has some
      feature which we may want to change occasionally.  for example,
      depparse sentences which end in punct
    can_augment_predicate: in the depparse sentences example, it is
      technically possible for the punct at the end to be the parent
      of some other word in the sentence.  in that case, the sentence
      should not be chosen.  should be at least as restrictive as
      should_augment_predicate
    """
    total = len(train_data)
    num_should = sum(should_augment_predicate(sent) for sent in train_data)
    num_can = sum(can_augment_predicate(sent) for sent in train_data)
    num_inconsistent = sum(can_augment_predicate(sent) and not should_augment_predicate(sent)
                           for sent in train_data)
    if num_inconsistent > 0:
        raise AssertionError("can_augment_predicate allowed sentences not allowed by should_augment_predicate")

    if num_can == 0:
        logger.warning("Found no sentences which matched can_augment_predicate {}".format(can_augment_predicate))
        return 0.0
    # sentences which already lack the feature count toward the target fraction
    needed = total * desired_ratio - (total - num_should)
    # if we want 10%, for example, and more than 10% already matches, we can skip
    if needed < 0:
        return 0.0
    return min(needed / num_can, max_ratio)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def should_augment_nopunct_predicate(sentence):
    """True if the sentence's final word is tagged PUNCT (a candidate for augmentation)."""
    last_word = sentence[-1]
    return last_word.get(UPOS, None) == 'PUNCT'

def can_augment_nopunct_predicate(sentence):
    """
    Check that the sentence ends with PUNCT and also doesn't have any words which depend on the last word
    """
    last_word = sentence[-1]
    if last_word.get(UPOS, None) != 'PUNCT':
        return False
    # don't cut off MWT
    if len(last_word[ID]) > 1:
        return False
    # the final punct may not be the head of any other word
    if any(len(word[ID]) == 1 and word[HEAD] == last_word[ID][0] for word in sentence):
        return False
    return True

def augment_punct(train_data, augment_ratio,
                  should_augment_predicate=should_augment_nopunct_predicate,
                  can_augment_predicate=can_augment_nopunct_predicate,
                  keep_original_sentences=True):
    """
    Adds extra training data to compensate for some models having all sentences end with PUNCT

    Some of the models (for example, UD_Hebrew-HTB) have the flaw that
    all of the training sentences end with PUNCT.  The model therefore
    learns to finish every sentence with punctuation, even if it is
    given a sentence with non-punct at the end.

    One simple way to fix this is to train on some fraction of training data with punct.

    Params:
    train_data: list of list of dicts, eg a conll doc
    augment_ratio: the fraction to augment.  if None, a best guess is made to get to 10%

    should_augment_predicate: a function which returns T/F if a sentence already ends with not PUNCT
    can_augment_predicate: a function which returns T/F if it makes sense to remove the last PUNCT

    TODO: do this dynamically, as part of the DataLoader or elsewhere?
    One complication is the data comes back from the DataLoader as
    tensors & indices, so it is much more complicated to manipulate
    """
    if len(train_data) == 0:
        return []

    if augment_ratio is None:
        augment_ratio = get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate)

    if augment_ratio <= 0:
        if keep_original_sentences:
            return list(train_data)
        else:
            return []

    new_data = []
    for sentence in train_data:
        if can_augment_predicate(sentence):
            if random.random() < augment_ratio and len(sentence) > 1:
                # todo: could deep copy the words
                #   or not deep copy any of this
                new_sentence = list(sentence[:-1])
                new_data.append(new_sentence)
            elif keep_original_sentences:
                # BUG FIX: previously this appended new_sentence, which is
                # unbound on the first iteration (NameError) or stale from an
                # earlier iteration; the untouched sentence is what belongs here
                new_data.append(sentence)
        elif keep_original_sentences:
            # BUG FIX: sentences which cannot be augmented were silently
            # dropped even though the caller asked to keep originals
            new_data.append(sentence)

    return new_data
|
stanza/stanza/models/common/doc.py
ADDED
|
@@ -0,0 +1,1741 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic data structures
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import io
|
| 6 |
+
from itertools import repeat
|
| 7 |
+
import re
|
| 8 |
+
import json
|
| 9 |
+
import pickle
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
from enum import Enum
|
| 13 |
+
|
| 14 |
+
import networkx as nx
|
| 15 |
+
|
| 16 |
+
from stanza.models.common.stanza_object import StanzaObject
|
| 17 |
+
from stanza.models.common.utils import misc_to_space_after, space_after_to_misc, misc_to_space_before, space_before_to_misc
|
| 18 |
+
from stanza.models.ner.utils import decode_from_bioes
|
| 19 |
+
from stanza.models.constituency import tree_reader
|
| 20 |
+
from stanza.models.coref.coref_chain import CorefMention, CorefChain, CorefAttachment
|
| 21 |
+
|
| 22 |
+
class MWTProcessingType(Enum):
    """How Document.set_mwt_expansions should treat a single token."""
    FLATTEN = 0 # flatten the current token into one ID instead of MWT
    PROCESS = 1 # process the current token as an MWT and expand it as such
    SKIP = 2 # do nothing on this token, simply increment IDs
|
| 26 |
+
|
| 27 |
+
# matches a CoNLL-U multi-word token id of the form "start-end", e.g. "4-5"
multi_word_token_id = re.compile(r"([0-9]+)-([0-9]+)")
# matches a MISC column value which flags a token as a multi-word token
multi_word_token_misc = re.compile(r".*MWT=Yes.*")

# canonical key names for the dict entries stored on tokens / words / sentences
MEXP = 'manual_expansion'
ID = 'id'
TEXT = 'text'
LEMMA = 'lemma'
UPOS = 'upos'
XPOS = 'xpos'
FEATS = 'feats'
HEAD = 'head'
DEPREL = 'deprel'
DEPS = 'deps'
MISC = 'misc'
NER = 'ner'
MULTI_NER = 'multi_ner' # will represent tags from multiple NER models
START_CHAR = 'start_char'
END_CHAR = 'end_char'
TYPE = 'type'
SENTIMENT = 'sentiment'
CONSTITUENCY = 'constituency'
COREF_CHAINS = 'coref_chains'

# field indices when converting the document to conll
FIELD_TO_IDX = {ID: 0, TEXT: 1, LEMMA: 2, UPOS: 3, XPOS: 4, FEATS: 5, HEAD: 6, DEPREL: 7, DEPS: 8, MISC: 9}
FIELD_NUM = len(FIELD_TO_IDX)
|
| 53 |
+
|
| 54 |
+
class DocJSONEncoder(json.JSONEncoder):
    """JSON encoder which additionally knows how to serialize coref objects.

    CorefMention instances are emitted as their attribute dict;
    CorefAttachment instances supply their own conversion via to_json().
    Anything else falls through to the base encoder (which raises TypeError).
    """

    def default(self, obj):
        if isinstance(obj, CorefMention):
            # vars() is the attribute __dict__ of the mention
            return vars(obj)
        if isinstance(obj, CorefAttachment):
            return obj.to_json()
        # unrecognized type: let the standard encoder raise TypeError
        return super().default(obj)
|
| 61 |
+
|
| 62 |
+
class Document(StanzaObject):
    """ A document class that stores attributes of a document and carries a list of sentences.
    """

    def __init__(self, sentences, text=None, comments=None, empty_sentences=None):
        """ Construct a document given a list of sentences in the form of lists of CoNLL-U dicts.

        Args:
            sentences: a list of sentences, which being a list of token entry, in the form of a CoNLL-U dict.
            text: the raw text of the document.
            comments: A list of list of strings to use as comments on the sentences, either None or the same length as sentences
            empty_sentences: optionally, a list parallel to sentences with the
                empty-word entries for each sentence (forwarded to Sentence)
        """
        self._sentences = []
        self._lang = None
        self._text = text
        self._num_tokens = 0
        self._num_words = 0

        self._process_sentences(sentences, comments, empty_sentences)
        self._ents = []
        self._coref = []
        # entities and whitespace annotations both need the raw text
        if self._text is not None:
            self.build_ents()
            self.mark_whitespace()

    def mark_whitespace(self):
        """Record the raw whitespace between tokens on the tokens themselves.

        Uses start_char/end_char offsets into self._text to set spaces_after
        on every token (including across sentence boundaries and after the
        last token) and spaces_before on the very first token.
        """
        for sentence in self._sentences:
            # TODO: pairwise, once we move to minimum 3.10
            for prev_token, next_token in zip(sentence.tokens[:-1], sentence.tokens[1:]):
                whitespace = self._text[prev_token.end_char:next_token.start_char]
                prev_token.spaces_after = whitespace
        # whitespace spanning the gap between consecutive sentences
        for prev_sentence, next_sentence in zip(self._sentences[:-1], self._sentences[1:]):
            prev_token = prev_sentence.tokens[-1]
            next_token = next_sentence.tokens[0]
            whitespace = self._text[prev_token.end_char:next_token.start_char]
            prev_token.spaces_after = whitespace
        # trailing whitespace after the very last token
        if len(self._sentences) > 0 and len(self._sentences[-1].tokens) > 0:
            final_token = self._sentences[-1].tokens[-1]
            whitespace = self._text[final_token.end_char:]
            final_token.spaces_after = whitespace
        # leading whitespace before the very first token
        if len(self._sentences) > 0 and len(self._sentences[0].tokens) > 0:
            first_token = self._sentences[0].tokens[0]
            whitespace = self._text[:first_token.start_char]
            first_token.spaces_before = whitespace


    @property
    def lang(self):
        """ Access the language of this document """
        return self._lang

    @lang.setter
    def lang(self, value):
        """ Set the language of this document """
        self._lang = value

    @property
    def text(self):
        """ Access the raw text for this document. """
        return self._text

    @text.setter
    def text(self, value):
        """ Set the raw text for this document. """
        self._text = value

    @property
    def sentences(self):
        """ Access the list of sentences for this document. """
        return self._sentences

    @sentences.setter
    def sentences(self, value):
        """ Set the list of tokens for this document. """
        self._sentences = value

    @property
    def num_tokens(self):
        """ Access the number of tokens for this document. """
        return self._num_tokens

    @num_tokens.setter
    def num_tokens(self, value):
        """ Set the number of tokens for this document. """
        self._num_tokens = value

    @property
    def num_words(self):
        """ Access the number of words for this document. """
        return self._num_words

    @num_words.setter
    def num_words(self, value):
        """ Set the number of words for this document. """
        self._num_words = value

    @property
    def ents(self):
        """ Access the list of entities in this document. """
        return self._ents

    @ents.setter
    def ents(self, value):
        """ Set the list of entities in this document. """
        self._ents = value

    @property
    def entities(self):
        """ Access the list of entities. This is just an alias of `ents`. """
        return self._ents

    @entities.setter
    def entities(self, value):
        """ Set the list of entities in this document. """
        self._ents = value

    def _process_sentences(self, sentences, comments=None, empty_sentences=None):
        """Build Sentence objects from raw CoNLL-U dicts and attach comments/ids."""
        self.sentences = []
        if empty_sentences is None:
            # repeat([]) yields an (aliased) empty list for every sentence;
            # safe here since the empty lists are only read, never mutated
            empty_sentences = repeat([])
        for sent_idx, (tokens, empty_words) in enumerate(zip(sentences, empty_sentences)):
            try:
                sentence = Sentence(tokens, doc=self, empty_words=empty_words)
            except IndexError as e:
                raise IndexError("Could not process document at sentence %d" % sent_idx) from e
            except ValueError as e:
                tokens = ["|%s|" % t for t in tokens]
                tokens = ", ".join(tokens)
                raise ValueError("Could not process document at sentence %d\n Raw tokens: %s" % (sent_idx, tokens)) from e
            self.sentences.append(sentence)
            begin_idx, end_idx = sentence.tokens[0].start_char, sentence.tokens[-1].end_char
            if all((self.text is not None, begin_idx is not None, end_idx is not None)): sentence.text = self.text[begin_idx: end_idx]
            sentence.index = sent_idx

        self._count_words()

        # Add a #text comment to each sentence in a doc if it doesn't already exist
        if not comments:
            comments = [[] for x in self.sentences]
        else:
            # copy so we never mutate the caller's comment lists
            comments = [list(x) for x in comments]
        for sentence, sentence_comments in zip(self.sentences, comments):
            # the space after text can occur in treebanks such as the Naija-NSC treebank,
            # which extensively uses `# text_en =` and `# text_ortho`
            if sentence.text and not any(comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text=") for comment in sentence_comments):
                # split/join to handle weird whitespace, especially newlines
                sentence_comments.append("# text = " + ' '.join(sentence.text.split()))
            elif not sentence.text:
                # no raw text known for this sentence: recover it from a
                # "# text =" comment if one was supplied
                for comment in sentence_comments:
                    if comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text="):
                        sentence.text = comment.split("=", 1)[-1].strip()
                        break

            for comment in sentence_comments:
                sentence.add_comment(comment)

            # look for sent_id in the comments
            # if it's there, overwrite the sent_idx id from above
            for comment in sentence_comments:
                if comment.startswith("# sent_id"):
                    sentence.sent_id = comment.split("=", 1)[-1].strip()
                    break
            else:
                # no sent_id found. add a comment with our enumerated id
                # setting the sent_id on the sentence will automatically add the comment
                sentence.sent_id = str(sentence.index)

    def _count_words(self):
        """
        Count the number of tokens and words
        """
        self.num_tokens = sum([len(sentence.tokens) for sentence in self.sentences])
        self.num_words = sum([len(sentence.words) for sentence in self.sentences])

    def get(self, fields, as_sentences=False, from_token=False):
        """ Get fields from a list of field names.
        If only one field name (string or singleton list) is provided,
        return a list of that field; if more than one, return a list of list.
        Note that all returned fields are after multi-word expansion.

        Args:
            fields: name of the fields as a list or a single string
            as_sentences: if True, return the fields as a list of sentences; otherwise as a whole list
            from_token: if True, get the fields from Token; otherwise from Word

        Returns:
            All requested fields.
        """
        if isinstance(fields, str):
            fields = [fields]
        assert isinstance(fields, list), "Must provide field names as a list."
        assert len(fields) >= 1, "Must have at least one field."

        results = []
        for sentence in self.sentences:
            cursent = []
            # decide word or token
            if from_token:
                units = sentence.tokens
            else:
                units = sentence.words
            for unit in units:
                if len(fields) == 1:
                    cursent += [getattr(unit, fields[0])]
                else:
                    cursent += [[getattr(unit, field) for field in fields]]

            # decide whether append the results as a sentence or a whole list
            if as_sentences:
                results.append(cursent)
            else:
                results += cursent
        return results

    def set(self, fields, contents, to_token=False, to_sentence=False):
        """Set fields based on contents. If only one field (string or
        singleton list) is provided, then a list of content will be
        expected; otherwise a list of list of contents will be expected.

        Args:
            fields: name of the fields as a list or a single string
            contents: field values to set; total length should be equal to number of words/tokens
            to_token: if True, set field values to tokens; otherwise to words
            to_sentence: if True, set field values on the sentences themselves
                (one content item per sentence); mutually exclusive with to_token
        """
        if isinstance(fields, str):
            fields = [fields]
        assert isinstance(fields, (tuple, list)), "Must provide field names as a list."
        assert isinstance(contents, (tuple, list)), "Must provide contents as a list (one item per line)."
        assert len(fields) >= 1, "Must have at least one field."

        assert not to_sentence or not to_token, "Both to_token and to_sentence set to True, which is very confusing"

        if to_sentence:
            assert len(self.sentences) == len(contents), \
                "Contents must have the same length as the sentences"
            for sentence, content in zip(self.sentences, contents):
                if len(fields) == 1:
                    setattr(sentence, fields[0], content)
                else:
                    for field, piece in zip(fields, content):
                        setattr(sentence, field, piece)
        else:
            assert (to_token and self.num_tokens == len(contents)) or self.num_words == len(contents), \
                "Contents must have the same length as the original file."

            cidx = 0
            for sentence in self.sentences:
                # decide word or token
                if to_token:
                    units = sentence.tokens
                else:
                    units = sentence.words
                for unit in units:
                    if len(fields) == 1:
                        setattr(unit, fields[0], contents[cidx])
                    else:
                        for field, content in zip(fields, contents[cidx]):
                            setattr(unit, field, content)
                    cidx += 1

    def set_mwt_expansions(self, expansions,
                           fake_dependencies=False,
                           process_manual_expanded=None):
        """ Extend the multi-word tokens annotated by tokenizer. A list of list of expansions
        will be expected for each multi-word token. Use `process_manual_expanded` to limit
        processing for tokens marked manually expanded:

        There are two types of MWT expansions: those with `misc`: `MWT=True`, and those with
        `manual_expansion`: True. The latter of which means that it is an expansion which the
        user manually specified through a postprocessor; the former means that it is a MWT
        which the detector picked out, but needs to be automatically expanded.

        process_manual_expanded = None - default; doesn't process manually expanded tokens
                                = True - process only manually expanded tokens (with `manual_expansion`: True)
                                = False - process only tokens explicitly tagged as MWT (`misc`: `MWT=True`)
        """

        idx_e = 0   # index into expansions; must equal len(expansions) at the end
        for sentence in self.sentences:
            idx_w = 0   # running 1-based word id within the sentence
            for token in sentence.tokens:
                idx_w += 1
                is_multi = (len(token.id) > 1)
                is_mwt = (multi_word_token_misc.match(token.misc) if token.misc is not None else None)
                is_manual_expansion = token.manual_expansion

                # decide how this token is handled based on the mode and its flags
                perform_mwt_processing = MWTProcessingType.FLATTEN

                if (process_manual_expanded and is_manual_expansion):
                    perform_mwt_processing = MWTProcessingType.PROCESS
                elif (process_manual_expanded==False and is_mwt):
                    perform_mwt_processing = MWTProcessingType.PROCESS
                elif (process_manual_expanded==False and is_manual_expansion):
                    perform_mwt_processing = MWTProcessingType.SKIP
                elif (process_manual_expanded==None and (is_mwt or is_multi)):
                    perform_mwt_processing = MWTProcessingType.PROCESS

                if perform_mwt_processing == MWTProcessingType.FLATTEN:
                    for word in token.words:
                        token.id = (idx_w, )
                        # delete dependency information
                        word.deps = None
                        word.head, word.deprel = None, None
                        word.id = idx_w
                elif perform_mwt_processing == MWTProcessingType.PROCESS:
                    expanded = [x for x in expansions[idx_e].split(' ') if len(x) > 0]
                    # in the event the MWT annotator only split the
                    # Token into a single Word, we preserve its text
                    # otherwise the Token's text is different from its
                    # only Word's text
                    if len(expanded) == 1:
                        expanded = [token.text]
                    idx_e += 1
                    idx_w_end = idx_w + len(expanded) - 1
                    if token.misc: # None can happen when using a prebuilt doc
                        # drop the MWT=Yes flag now that the expansion is done
                        token.misc = None if token.misc == 'MWT=Yes' else '|'.join([x for x in token.misc.split('|') if x != 'MWT=Yes'])
                    token.id = (idx_w, idx_w_end) if len(expanded) > 1 else (idx_w,)
                    token.words = []
                    for i, e_word in enumerate(expanded):
                        token.words.append(Word(sentence, {ID: idx_w + i, TEXT: e_word}))
                    idx_w = idx_w_end
                elif perform_mwt_processing == MWTProcessingType.SKIP:
                    # shift ids past the expansions consumed so far
                    token.id = tuple(orig_id + idx_e for orig_id in token.id)
                    for i in token.words:
                        i.id += idx_e
                    idx_w = token.id[-1]
                # NOTE(review): indentation reconstructed — clearing the flag on every
                # token (not only SKIP) matches how get_mwt_expansions reads it; confirm
                token.manual_expansion = None

            # reprocess the words using the new tokens
            sentence.words = []
            for token in sentence.tokens:
                token.sent = sentence
                for word in token.words:
                    word.sent = sentence
                    word.parent = token
                    sentence.words.append(word)
                # character offsets can only be distributed to the words when the
                # words' concatenated text exactly reproduces the token text
                if token.start_char is not None and token.end_char is not None and "".join(word.text for word in token.words) == token.text:
                    start_char = token.start_char
                    for word in token.words:
                        end_char = start_char + len(word.text)
                        word.start_char = start_char
                        word.end_char = end_char
                        start_char = end_char

            if fake_dependencies:
                sentence.build_fake_dependencies()
            else:
                sentence.rebuild_dependencies()

        self._count_words() # update number of words & tokens
        assert idx_e == len(expansions), "{} {}".format(idx_e, len(expansions))
        return

    def get_mwt_expansions(self, evaluation=False):
        """ Get the multi-word tokens. For training, return a list of
        (multi-word token, extended multi-word token); otherwise, return a list of
        multi-word token only. By default doesn't skip already expanded tokens, but
        `skip_already_expanded` will return only tokens marked as MWT.
        """
        expansions = []
        for sentence in self.sentences:
            for token in sentence.tokens:
                is_multi = (len(token.id) > 1)
                is_mwt = multi_word_token_misc.match(token.misc) if token.misc is not None else None
                is_manual_expansion = token.manual_expansion
                # manual expansions are excluded unless explicitly flagged MWT=Yes
                if (is_multi and not is_manual_expansion) or is_mwt:
                    src = token.text
                    dst = ' '.join([word.text for word in token.words])
                    expansions.append([src, dst])
        if evaluation: expansions = [e[0] for e in expansions]
        return expansions

    def build_ents(self):
        """ Build the list of entities by iterating over all words. Return all entities as a list. """
        self.ents = []
        for s in self.sentences:
            s_ents = s.build_ents()
            self.ents += s_ents
        return self.ents

    def sort_features(self):
        """ Sort the features on all the words... useful for prototype treebanks, for example """
        for sentence in self.sentences:
            for word in sentence.words:
                if not word.feats:
                    continue
                pieces = word.feats.split("|")
                pieces = sorted(pieces)
                word.feats = "|".join(pieces)

    def iter_words(self):
        """ An iterator that returns all of the words in this Document. """
        for s in self.sentences:
            yield from s.words

    def iter_tokens(self):
        """ An iterator that returns all of the tokens in this Document. """
        for s in self.sentences:
            yield from s.tokens

    def sentence_comments(self):
        """ Returns a list of list of comments for the sentences """
        return [[comment for comment in sentence.comments] for sentence in self.sentences]

    @property
    def coref(self):
        """
        Access the coref lists of the document
        """
        return self._coref

    @coref.setter
    def coref(self, chains):
        """ Set the document's coref lists """
        self._coref = chains
        self._attach_coref_mentions(chains)

    def _attach_coref_mentions(self, chains):
        """Attach a CorefAttachment to every word covered by a mention in chains."""
        # reset any previous attachments first
        for sentence in self.sentences:
            for word in sentence.words:
                word.coref_chains = []

        for chain in chains:
            for mention_idx, mention in enumerate(chain.mentions):
                sentence = self.sentences[mention.sentence]
                # end_word is exclusive, hence range(start, end)
                for word_idx in range(mention.start_word, mention.end_word):
                    is_start = word_idx == mention.start_word
                    is_end = word_idx == mention.end_word - 1
                    is_representative = mention_idx == chain.representative_index
                    attachment = CorefAttachment(chain, is_start, is_end, is_representative)
                    sentence.words[word_idx].coref_chains.append(attachment)

    def reindex_sentences(self, start_index):
        """Renumber every sentence's sent_id consecutively starting at start_index."""
        for sent_id, sentence in zip(range(start_index, start_index + len(self.sentences)), self.sentences):
            sentence.sent_id = str(sent_id)

    def to_dict(self):
        """ Dumps the whole document into a list of list of dictionary for each token in each sentence in the doc.
        """
        return [sentence.to_dict() for sentence in self.sentences]

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def __format__(self, spec):
        # 'c'/'C' delegate to the sentences' own conll-style formatting
        if spec == 'c':
            return "\n\n".join("{:c}".format(s) for s in self.sentences)
        elif spec == 'C':
            return "\n\n".join("{:C}".format(s) for s in self.sentences)
        else:
            return str(self)

    def to_serialized(self):
        """ Dumps the whole document including text to a byte array containing a list of list of dictionaries for each token in each sentence in the doc.
        """
        return pickle.dumps((self.text, self.to_dict(), self.sentence_comments()))

    @classmethod
    def from_serialized(cls, serialized_string):
        """ Create and initialize a new document from a serialized string generated by Document.to_serialized_string():

        SECURITY NOTE: pickle.loads executes arbitrary code when given
        untrusted data; only deserialize strings produced by to_serialized.
        """
        # NOTE(review): pickle.loads is called again below even though `stuff`
        # already holds the result; the second loads could reuse `stuff`
        stuff = pickle.loads(serialized_string)
        if not isinstance(stuff, tuple):
            raise TypeError("Serialized data was not a tuple when building a Document")
        if len(stuff) == 2:
            # older serialized format without per-sentence comments
            text, sentences = pickle.loads(serialized_string)
            doc = cls(sentences, text)
        else:
            text, sentences, comments = pickle.loads(serialized_string)
            doc = cls(sentences, text, comments)
        return doc
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class Sentence(StanzaObject):
|
| 537 |
+
""" A sentence class that stores attributes of a sentence and carries a list of tokens.
|
| 538 |
+
"""
|
| 539 |
+
|
| 540 |
+
def __init__(self, tokens, doc=None, empty_words=None):
|
| 541 |
+
""" Construct a sentence given a list of tokens in the form of CoNLL-U dicts.
|
| 542 |
+
"""
|
| 543 |
+
self._tokens = []
|
| 544 |
+
self._words = []
|
| 545 |
+
self._dependencies = []
|
| 546 |
+
self._text = None
|
| 547 |
+
self._ents = []
|
| 548 |
+
self._doc = doc
|
| 549 |
+
self._constituency = None
|
| 550 |
+
self._sentiment = None
|
| 551 |
+
# comments are a list of comment lines occurring before the
|
| 552 |
+
# sentence in a CoNLL-U file. Can be empty
|
| 553 |
+
self._comments = []
|
| 554 |
+
self._doc_id = None
|
| 555 |
+
|
| 556 |
+
# enhanced_dependencies represents the DEPS column
|
| 557 |
+
# this is a networkx MultiDiGraph
|
| 558 |
+
# with edges from the parent to the dependent
|
| 559 |
+
# however, we set it to None until needed, as it is somewhat slow
|
| 560 |
+
self._enhanced_dependencies = None
|
| 561 |
+
self._process_tokens(tokens)
|
| 562 |
+
|
| 563 |
+
if empty_words is not None:
|
| 564 |
+
self._empty_words = [Word(self, entry) for entry in empty_words]
|
| 565 |
+
else:
|
| 566 |
+
self._empty_words = []
|
| 567 |
+
|
| 568 |
+
def _process_tokens(self, tokens):
|
| 569 |
+
st, en = -1, -1
|
| 570 |
+
self.tokens, self.words = [], []
|
| 571 |
+
for i, entry in enumerate(tokens):
|
| 572 |
+
if ID not in entry: # manually set a 1-based id for word if not exist
|
| 573 |
+
entry[ID] = (i+1, )
|
| 574 |
+
if isinstance(entry[ID], int):
|
| 575 |
+
entry[ID] = (entry[ID], )
|
| 576 |
+
if len(entry.get(ID)) > 1: # if this token is a multi-word token
|
| 577 |
+
st, en = entry[ID]
|
| 578 |
+
self.tokens.append(Token(self, entry))
|
| 579 |
+
else: # else this token is a word
|
| 580 |
+
new_word = Word(self, entry)
|
| 581 |
+
if len(self.words) > 0 and self.words[-1].id == new_word.id:
|
| 582 |
+
# this can happen in the following context:
|
| 583 |
+
# a document was created with MWT=Yes to mark that a token should be split
|
| 584 |
+
# and then there was an MWT "expansion" with a single word after that token
|
| 585 |
+
# we replace the Word in the Token assuming that the expansion token might
|
| 586 |
+
# have more information than the Token dict did
|
| 587 |
+
# note that a single word MWT like that can be detected with something like
|
| 588 |
+
# multi_word_token_misc.match(entry.get(MISC)) if entry.get(MISC, None)
|
| 589 |
+
self.words[-1] = new_word
|
| 590 |
+
self.tokens[-1].words[-1] = new_word
|
| 591 |
+
continue
|
| 592 |
+
self.words.append(new_word)
|
| 593 |
+
idx = entry.get(ID)[0]
|
| 594 |
+
if idx <= en:
|
| 595 |
+
self.tokens[-1].words.append(new_word)
|
| 596 |
+
else:
|
| 597 |
+
self.tokens.append(Token(self, entry, words=[new_word]))
|
| 598 |
+
new_word.parent = self.tokens[-1]
|
| 599 |
+
|
| 600 |
+
# put all of the whitespace annotations (if any) on the Tokens instead of the Words
|
| 601 |
+
for token in self.tokens:
|
| 602 |
+
token.consolidate_whitespace()
|
| 603 |
+
self.rebuild_dependencies()
|
| 604 |
+
|
| 605 |
+
def has_enhanced_dependencies(self):
|
| 606 |
+
"""
|
| 607 |
+
Whether or not the enhanced dependencies are part of this sentence
|
| 608 |
+
"""
|
| 609 |
+
return self._enhanced_dependencies is not None and len(self._enhanced_dependencies) > 0
|
| 610 |
+
|
| 611 |
+
@property
|
| 612 |
+
def index(self):
|
| 613 |
+
"""
|
| 614 |
+
Access the index of this sentence within the doc.
|
| 615 |
+
|
| 616 |
+
If multiple docs were processed together,
|
| 617 |
+
the sentence index will continue counting across docs.
|
| 618 |
+
"""
|
| 619 |
+
return self._index
|
| 620 |
+
|
| 621 |
+
@index.setter
|
| 622 |
+
def index(self, value):
|
| 623 |
+
""" Set the sentence's index value. """
|
| 624 |
+
self._index = value
|
| 625 |
+
|
| 626 |
+
@property
|
| 627 |
+
def id(self):
|
| 628 |
+
"""
|
| 629 |
+
Access the index of this sentence within the doc.
|
| 630 |
+
|
| 631 |
+
If multiple docs were processed together,
|
| 632 |
+
the sentence index will continue counting across docs.
|
| 633 |
+
"""
|
| 634 |
+
warnings.warn("Use of sentence.id is deprecated. Please use sentence.index instead", stacklevel=2)
|
| 635 |
+
return self._index
|
| 636 |
+
|
| 637 |
+
@id.setter
|
| 638 |
+
def id(self, value):
|
| 639 |
+
""" Set the sentence's index value. """
|
| 640 |
+
warnings.warn("Use of sentence.id is deprecated. Please use sentence.index instead", stacklevel=2)
|
| 641 |
+
self._index = value
|
| 642 |
+
|
| 643 |
+
@property
|
| 644 |
+
def sent_id(self):
|
| 645 |
+
""" conll-style sent_id Will be set from index if unknown """
|
| 646 |
+
return self._sent_id
|
| 647 |
+
|
| 648 |
+
@sent_id.setter
|
| 649 |
+
def sent_id(self, value):
|
| 650 |
+
""" Set the sentence's sent_id value. """
|
| 651 |
+
self._sent_id = value
|
| 652 |
+
sent_id_comment = "# sent_id = " + str(value)
|
| 653 |
+
for comment_idx, comment in enumerate(self._comments):
|
| 654 |
+
if comment.startswith("# sent_id = "):
|
| 655 |
+
self._comments[comment_idx] = sent_id_comment
|
| 656 |
+
break
|
| 657 |
+
else: # this is intended to be a for/else loop
|
| 658 |
+
self._comments.append(sent_id_comment)
|
| 659 |
+
|
| 660 |
+
@property
|
| 661 |
+
def doc_id(self):
|
| 662 |
+
""" conll-style doc_id Can be left blank if unknown """
|
| 663 |
+
return self._doc_id
|
| 664 |
+
|
| 665 |
+
@doc_id.setter
|
| 666 |
+
def doc_id(self, value):
|
| 667 |
+
""" Set the sentence's doc_id value. """
|
| 668 |
+
self._doc_id = value
|
| 669 |
+
doc_id_comment = "# doc_id = " + str(value)
|
| 670 |
+
for comment_idx, comment in enumerate(self._comments):
|
| 671 |
+
if comment.startswith("# doc_id = "):
|
| 672 |
+
self._comments[comment_idx] = doc_id_comment
|
| 673 |
+
break
|
| 674 |
+
else: # this is intended to be a for/else loop
|
| 675 |
+
self._comments.append(doc_id_comment)
|
| 676 |
+
|
| 677 |
+
@property
|
| 678 |
+
def doc(self):
|
| 679 |
+
""" Access the parent doc of this span. """
|
| 680 |
+
return self._doc
|
| 681 |
+
|
| 682 |
+
@doc.setter
|
| 683 |
+
def doc(self, value):
|
| 684 |
+
""" Set the parent doc of this span. """
|
| 685 |
+
self._doc = value
|
| 686 |
+
|
| 687 |
+
@property
|
| 688 |
+
def text(self):
|
| 689 |
+
""" Access the raw text for this sentence. """
|
| 690 |
+
return self._text
|
| 691 |
+
|
| 692 |
+
@text.setter
|
| 693 |
+
def text(self, value):
|
| 694 |
+
""" Set the raw text for this sentence. """
|
| 695 |
+
self._text = value
|
| 696 |
+
|
| 697 |
+
@property
|
| 698 |
+
def dependencies(self):
|
| 699 |
+
""" Access list of dependencies for this sentence. """
|
| 700 |
+
return self._dependencies
|
| 701 |
+
|
| 702 |
+
@dependencies.setter
|
| 703 |
+
def dependencies(self, value):
|
| 704 |
+
""" Set the list of dependencies for this sentence. """
|
| 705 |
+
self._dependencies = value
|
| 706 |
+
|
| 707 |
+
@property
|
| 708 |
+
def tokens(self):
|
| 709 |
+
""" Access the list of tokens for this sentence. """
|
| 710 |
+
return self._tokens
|
| 711 |
+
|
| 712 |
+
@tokens.setter
|
| 713 |
+
def tokens(self, value):
|
| 714 |
+
""" Set the list of tokens for this sentence. """
|
| 715 |
+
self._tokens = value
|
| 716 |
+
|
| 717 |
+
@property
|
| 718 |
+
def words(self):
|
| 719 |
+
""" Access the list of words for this sentence. """
|
| 720 |
+
return self._words
|
| 721 |
+
|
| 722 |
+
@words.setter
|
| 723 |
+
def words(self, value):
|
| 724 |
+
""" Set the list of words for this sentence. """
|
| 725 |
+
self._words = value
|
| 726 |
+
|
| 727 |
+
    @property
    def empty_words(self):
        """ Access the list of empty (null) words for this sentence.

        Empty words are the CoNLL-U nodes with decimal ids such as 1.1.
        """
        return self._empty_words

    @empty_words.setter
    def empty_words(self, value):
        """ Set the list of empty words for this sentence. """
        self._empty_words = value
|
| 736 |
+
|
| 737 |
+
    @property
    def ents(self):
        """ Access the list of entities in this sentence. """
        return self._ents

    @ents.setter
    def ents(self, value):
        """ Set the list of entities in this sentence. """
        self._ents = value
|
| 746 |
+
|
| 747 |
+
    @property
    def entities(self):
        """ Access the list of entities. This is just an alias of `ents`. """
        return self._ents

    @entities.setter
    def entities(self, value):
        """ Set the list of entities in this sentence (same backing list as `ents`). """
        self._ents = value
|
| 756 |
+
|
| 757 |
+
def build_ents(self):
|
| 758 |
+
""" Build the list of entities by iterating over all tokens. Return all entities as a list.
|
| 759 |
+
|
| 760 |
+
Note that unlike other attributes, since NER requires raw text, the actual tagging are always
|
| 761 |
+
performed at and attached to the `Token`s, instead of `Word`s.
|
| 762 |
+
"""
|
| 763 |
+
self.ents = []
|
| 764 |
+
tags = [w.ner for w in self.tokens]
|
| 765 |
+
decoded = decode_from_bioes(tags)
|
| 766 |
+
for e in decoded:
|
| 767 |
+
ent_tokens = self.tokens[e['start']:e['end']+1]
|
| 768 |
+
self.ents.append(Span(tokens=ent_tokens, type=e['type'], doc=self.doc, sent=self))
|
| 769 |
+
return self.ents
|
| 770 |
+
|
| 771 |
+
@property
|
| 772 |
+
def sentiment(self):
|
| 773 |
+
""" Returns the sentiment value for this sentence """
|
| 774 |
+
return self._sentiment
|
| 775 |
+
|
| 776 |
+
@sentiment.setter
|
| 777 |
+
def sentiment(self, value):
|
| 778 |
+
""" Set the sentiment value """
|
| 779 |
+
self._sentiment = value
|
| 780 |
+
sentiment_comment = "# sentiment = " + str(value)
|
| 781 |
+
for comment_idx, comment in enumerate(self._comments):
|
| 782 |
+
if comment.startswith("# sentiment = "):
|
| 783 |
+
self._comments[comment_idx] = sentiment_comment
|
| 784 |
+
break
|
| 785 |
+
else: # this is intended to be a for/else loop
|
| 786 |
+
self._comments.append(sentiment_comment)
|
| 787 |
+
|
| 788 |
+
@property
|
| 789 |
+
def constituency(self):
|
| 790 |
+
""" Returns the constituency tree for this sentence """
|
| 791 |
+
return self._constituency
|
| 792 |
+
|
| 793 |
+
@constituency.setter
|
| 794 |
+
def constituency(self, value):
|
| 795 |
+
"""
|
| 796 |
+
Set the constituency tree
|
| 797 |
+
|
| 798 |
+
This incidentally updates the #constituency comment if it already exists,
|
| 799 |
+
or otherwise creates a new comment # constituency = ...
|
| 800 |
+
"""
|
| 801 |
+
self._constituency = value
|
| 802 |
+
constituency_comment = "# constituency = " + str(value)
|
| 803 |
+
constituency_comment = constituency_comment.replace("\n", "*NL*").replace("\r", "")
|
| 804 |
+
for comment_idx, comment in enumerate(self._comments):
|
| 805 |
+
if comment.startswith("# constituency = "):
|
| 806 |
+
self._comments[comment_idx] = constituency_comment
|
| 807 |
+
break
|
| 808 |
+
else: # this is intended to be a for/else loop
|
| 809 |
+
self._comments.append(constituency_comment)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
    @property
    def comments(self):
        """ Returns CoNLL-style comments for this sentence.

        The live list is returned; mutating it mutates the sentence's comments.
        """
        return self._comments
|
| 816 |
+
|
| 817 |
+
    def add_comment(self, comment):
        """ Adds a single comment to this sentence.

        If the comment does not already have # at the start, it will be added.

        Recognized comments (constituency, sentiment, sent_id, doc_id) also
        update the corresponding attribute, and any previous comment of the
        same kind is dropped so the new one is not duplicated.
        """
        if not comment.startswith("#"):
            comment = "# " + comment
        if comment.startswith("# constituency ="):
            _, tree_text = comment.split("=", 1)
            tree = tree_reader.read_trees(tree_text)
            if len(tree) > 1:
                raise ValueError("Multiple constituency trees for one sentence: %s" % tree_text)
            self._constituency = tree[0]
            # remove any stale constituency comment before appending the new one
            self._comments = [x for x in self._comments if not x.startswith("# constituency =")]
        elif comment.startswith("# sentiment ="):
            _, sentiment = comment.split("=", 1)
            sentiment = int(sentiment.strip())
            self._sentiment = sentiment
            self._comments = [x for x in self._comments if not x.startswith("# sentiment =")]
        elif comment.startswith("# sent_id ="):
            _, sent_id = comment.split("=", 1)
            sent_id = sent_id.strip()
            self._sent_id = sent_id
            self._comments = [x for x in self._comments if not x.startswith("# sent_id =")]
        elif comment.startswith("# doc_id ="):
            _, doc_id = comment.split("=", 1)
            doc_id = doc_id.strip()
            self._doc_id = doc_id
            self._comments = [x for x in self._comments if not x.startswith("# doc_id =")]
        self._comments.append(comment)
|
| 847 |
+
|
| 848 |
+
def rebuild_dependencies(self):
|
| 849 |
+
# rebuild dependencies if there is dependency info
|
| 850 |
+
is_complete_dependencies = all(word.head is not None and word.deprel is not None for word in self.words)
|
| 851 |
+
is_complete_words = (len(self.words) >= len(self.tokens)) and (len(self.words) == self.words[-1].id)
|
| 852 |
+
if is_complete_dependencies and is_complete_words: self.build_dependencies()
|
| 853 |
+
|
| 854 |
+
    def build_dependencies(self):
        """ Build the dependency graph for this sentence. Each dependency graph entry is
        a list of (head, deprel, word).

        Raises IndexError if a word's head index is out of range, and
        ValueError if the retrieved head's id does not match the head index.
        """
        self.dependencies = []
        for word in self.words:
            if word.head == 0:
                # make a word for the ROOT
                word_entry = {ID: 0, TEXT: "ROOT"}
                head = Word(self, word_entry)
            else:
                # id is index in words list + 1
                try:
                    head = self.words[word.head - 1]
                except IndexError as e:
                    raise IndexError("Word head {} is not a valid word index for word {}".format(word.head, word.id)) from e
                if word.head != head.id:
                    raise ValueError("Dependency tree is incorrectly constructed")
            self.dependencies.append((head, word.deprel, word))
|
| 873 |
+
|
| 874 |
+
    def build_fake_dependencies(self):
        """ Attach trivial left-branching dependencies: each word's head is the previous word.

        NOTE(review): entries here are (int index, deprel, word), whereas
        build_dependencies stores a Word object as the first element;
        print_dependencies calls .id on that element, so it would fail on
        these fake entries -- confirm whether this asymmetry is intended.
        """
        self.dependencies = []
        for word_idx, word in enumerate(self.words):
            word.head = word_idx # note that this goes one previous to the index
            word.deprel = "root" if word_idx == 0 else "dep"
            word.deps = "%d:%s" % (word.head, word.deprel)
            self.dependencies.append((word_idx, word.deprel, word))
|
| 881 |
+
|
| 882 |
+
def print_dependencies(self, file=None):
|
| 883 |
+
""" Print the dependencies for this sentence. """
|
| 884 |
+
for dep_edge in self.dependencies:
|
| 885 |
+
print((dep_edge[2].text, dep_edge[0].id, dep_edge[1]), file=file)
|
| 886 |
+
|
| 887 |
+
def dependencies_string(self):
|
| 888 |
+
""" Dump the dependencies for this sentence into string. """
|
| 889 |
+
dep_string = io.StringIO()
|
| 890 |
+
self.print_dependencies(file=dep_string)
|
| 891 |
+
return dep_string.getvalue().strip()
|
| 892 |
+
|
| 893 |
+
def print_tokens(self, file=None):
|
| 894 |
+
""" Print the tokens for this sentence. """
|
| 895 |
+
for tok in self.tokens:
|
| 896 |
+
print(tok.pretty_print(), file=file)
|
| 897 |
+
|
| 898 |
+
def tokens_string(self):
|
| 899 |
+
""" Dump the tokens for this sentence into string. """
|
| 900 |
+
toks_string = io.StringIO()
|
| 901 |
+
self.print_tokens(file=toks_string)
|
| 902 |
+
return toks_string.getvalue().strip()
|
| 903 |
+
|
| 904 |
+
def print_words(self, file=None):
|
| 905 |
+
""" Print the words for this sentence. """
|
| 906 |
+
for word in self.words:
|
| 907 |
+
print(word.pretty_print(), file=file)
|
| 908 |
+
|
| 909 |
+
def words_string(self):
|
| 910 |
+
""" Dump the words for this sentence into string. """
|
| 911 |
+
wrds_string = io.StringIO()
|
| 912 |
+
self.print_words(file=wrds_string)
|
| 913 |
+
return wrds_string.getvalue().strip()
|
| 914 |
+
|
| 915 |
+
def to_dict(self):
|
| 916 |
+
""" Dumps the sentence into a list of dictionary for each token in the sentence.
|
| 917 |
+
"""
|
| 918 |
+
ret = []
|
| 919 |
+
empty_idx = 0
|
| 920 |
+
for token_idx, token in enumerate(self.tokens):
|
| 921 |
+
while empty_idx < len(self._empty_words) and self._empty_words[empty_idx].id[0] < token.id[0]:
|
| 922 |
+
ret.append(self._empty_words[empty_idx].to_dict())
|
| 923 |
+
empty_idx += 1
|
| 924 |
+
ret += token.to_dict()
|
| 925 |
+
for empty_word in self._empty_words[empty_idx:]:
|
| 926 |
+
ret.append(empty_word.to_dict())
|
| 927 |
+
return ret
|
| 928 |
+
|
| 929 |
+
    def __repr__(self):
        # pretty-printed JSON of to_dict(); DocJSONEncoder handles values json
        # cannot serialize natively -- presumably id tuples etc., TODO confirm
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)
|
| 931 |
+
|
| 932 |
+
    def __format__(self, spec):
        """ Format the sentence as CoNLL text.

        'c' yields just the token lines; 'C' additionally prepends the
        sentence comments.  Any other spec falls back to str(self).
        """
        if spec != 'c' and spec != 'C':
            return str(self)

        pieces = []
        empty_idx = 0
        # interleave empty (null) words with tokens by the integer part of their id
        for token_idx, token in enumerate(self.tokens):
            while empty_idx < len(self._empty_words) and self._empty_words[empty_idx].id[0] < token.id[0]:
                pieces.append(self._empty_words[empty_idx].to_conll_text())
                empty_idx += 1
            pieces.append(token.to_conll_text())
        for empty_word in self._empty_words[empty_idx:]:
            pieces.append(empty_word.to_conll_text())

        if spec == 'c':
            return "\n".join(pieces)
        elif spec == 'C':
            tokens = "\n".join(pieces)
            if len(self.comments) > 0:
                text = "\n".join(self.comments)
                return text + "\n" + tokens
            return tokens
|
| 954 |
+
|
| 955 |
+
def init_from_misc(unit):
    """Create attributes by parsing from the `misc` field.

    Entries whose key matches an existing private attribute on *unit* are
    consumed (start_char / end_char become ints); consumed entries are removed
    from the misc string so the information is not stored twice.
    """
    kept = []
    for item in unit._misc.split('|'):
        parts = item.split('=', 1)
        if len(parts) != 2:
            # some entries cannot be split into key=value; keep them verbatim
            kept.append(item)
            continue
        key, value = parts
        if key in (START_CHAR, END_CHAR):
            # character offsets are stored as ints
            value = int(value)
        if hasattr(unit, f'_{key}'):
            setattr(unit, f'_{key}', value)
        elif key != NER:
            # NER is skipped for Words, since there is no Word NER field;
            # any other unrecognized key stays in misc
            kept.append(item)
    unit._misc = "|".join(kept)
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
def dict_to_conll_text(token_dict, id_connector="-"):
    """Convert one token/word dict into a single tab-separated CoNLL line.

    Keys with no CoNLL column of their own (start_char, end_char, ner, coref
    chains) are folded into the MISC column.  `id_connector` joins the pieces
    of a tuple id (e.g. "-" for MWT ranges).
    """
    token_conll = ['_' for i in range(FIELD_NUM)]
    misc = []
    for key in token_dict:
        if key == START_CHAR or key == END_CHAR:
            misc.append("{}={}".format(key, token_dict[key]))
        elif key == NER:
            # TODO: potentially need to escape =|\ in the NER
            misc.append("{}={}".format(key, token_dict[key]))
        elif key == COREF_CHAINS:
            chains = token_dict[key]
            if len(chains) > 0:
                misc_chains = []
                for chain in chains:
                    # encode this word's position within the mention span
                    if chain.is_start and chain.is_end:
                        coref_position = "unit-"
                    elif chain.is_start:
                        coref_position = "start-"
                    elif chain.is_end:
                        coref_position = "end-"
                    else:
                        coref_position = "middle-"
                    is_representative = "repr-" if chain.is_representative else ""
                    misc_chains.append("%s%sid%d" % (coref_position, is_representative, chain.chain.index))
                misc.append("{}={}".format(key, ",".join(misc_chains)))
        elif key == MISC:
            # avoid appending a blank misc entry.
            # otherwise the resulting misc field in the conll doc will wind up being blank text
            # TODO: potentially need to escape =|\ in the MISC as well
            if token_dict[key]:
                misc.append(token_dict[key])
        elif key == ID:
            token_conll[FIELD_TO_IDX[key]] = id_connector.join([str(x) for x in token_dict[key]]) if isinstance(token_dict[key], tuple) else str(token_dict[key])
        elif key in FIELD_TO_IDX:
            token_conll[FIELD_TO_IDX[key]] = str(token_dict[key])
    if misc:
        token_conll[FIELD_TO_IDX[MISC]] = "|".join(misc)
    else:
        token_conll[FIELD_TO_IDX[MISC]] = '_'
    # when a word (not mwt token) without head is found, we insert dummy head as required by the UD eval script
    if '-' not in token_conll[FIELD_TO_IDX[ID]] and '.' not in token_conll[FIELD_TO_IDX[ID]] and HEAD not in token_dict:
        token_conll[FIELD_TO_IDX[HEAD]] = str(int(token_dict[ID] if isinstance(token_dict[ID], int) else token_dict[ID][0]) - 1) # evaluation script requires head: int
    return "\t".join(token_conll)
|
| 1025 |
+
|
| 1026 |
+
|
| 1027 |
+
class Token(StanzaObject):
    """ A token class that stores attributes of a token and carries a list of words. A token corresponds to a unit in the raw
    text. In some languages such as English, a token has a one-to-one mapping to a word, while in other languages such as French,
    a (multi-word) token might be expanded into multiple words that carry syntactic annotations.
    """

    def __init__(self, sentence, token_entry, words=None):
        """
        Construct a token given a dictionary format token entry. Optionally link itself to the corresponding words.
        The owning sentence must be passed in.

        Raises ValueError if the entry has no id or no text.
        """
        self._id = token_entry.get(ID)
        self._text = token_entry.get(TEXT)
        if not self._id:
            raise ValueError('id not included for the token')
        if not self._text:
            raise ValueError('text not included for the token')
        self._misc = token_entry.get(MISC, None)
        self._ner = token_entry.get(NER, None)
        self._multi_ner = token_entry.get(MULTI_NER, None)
        self._words = words if words is not None else []
        self._start_char = token_entry.get(START_CHAR, None)
        self._end_char = token_entry.get(END_CHAR, None)
        self._sent = sentence
        self._mexp = token_entry.get(MEXP, None)
        # whitespace around this token; a single space after is the default
        self._spaces_before = ""
        self._spaces_after = " "

        if self._misc is not None:
            # unpack start_char / end_char / ner etc. out of the misc field
            init_from_misc(self)

    @property
    def id(self):
        """ Access the index of this token. """
        return self._id

    @id.setter
    def id(self, value):
        """ Set the token's id value. """
        self._id = value

    @property
    def manual_expansion(self):
        """ Access whether this token was manually expanded. """
        return self._mexp

    @manual_expansion.setter
    def manual_expansion(self, value):
        """ Set whether this token was manually expanded. """
        self._mexp = value

    @property
    def text(self):
        """ Access the text of this token. Example: 'The' """
        return self._text

    @text.setter
    def text(self, value):
        """ Set the token's text value. Example: 'The' """
        self._text = value

    @property
    def misc(self):
        """ Access the miscellaneousness of this token. """
        return self._misc

    @misc.setter
    def misc(self, value):
        """ Set the token's miscellaneousness value. '_' (the CoNLL null) is stored as None. """
        self._misc = value if not self._is_null(value) else None

    def consolidate_whitespace(self):
        """
        Remove whitespace misc annotations from the Words and mark the whitespace on the Tokens
        """
        found_after = False
        found_before = False
        num_words = len(self.words)
        for word_idx, word in enumerate(self.words):
            misc = word.misc
            if not misc:
                continue
            pieces = misc.split("|")
            # SpacesBefore only makes sense on the first word of the token
            if word_idx == 0:
                if any(piece.startswith("SpacesBefore=") for piece in pieces):
                    self.spaces_before = misc_to_space_before(misc)
                    found_before = True
            else:
                if any(piece.startswith("SpacesBefore=") for piece in pieces):
                    warnings.warn("Found a SpacesBefore MISC annotation on a Word that was not the first Word in a Token")
            # SpaceAfter / SpacesAfter only makes sense on the last word
            if word_idx == num_words - 1:
                if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
                    self.spaces_after = misc_to_space_after(misc)
                    found_after = True
            else:
                if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
                    unexpected_space_after = misc_to_space_after(misc)
                    if unexpected_space_after == "":
                        warnings.warn("Unexpected SpaceAfter=No annotation on a word in the middle of an MWT")
                    else:
                        warnings.warn("Unexpected SpacesAfter on a word in the middle on an MWT")
            # strip the whitespace annotations from the word's misc
            pieces = [x for x in pieces if not x.startswith("SpacesAfter=") and not x.startswith("SpaceAfter=") and not x.startswith("SpacesBefore=")]
            word.misc = "|".join(pieces)

        # now consolidate whitespace annotations on the token itself,
        # checking for conflicts with what the words declared
        misc = self.misc
        if misc:
            pieces = misc.split("|")
            if any(piece.startswith("SpacesBefore=") for piece in pieces):
                spaces_before = misc_to_space_before(misc)
                if found_before:
                    if spaces_before != self.spaces_before:
                        warnings.warn("Found conflicting SpacesBefore on a token and its word!")
                else:
                    self.spaces_before = spaces_before
            if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
                spaces_after = misc_to_space_after(misc)
                if found_after:
                    if spaces_after != self.spaces_after:
                        warnings.warn("Found conflicting SpaceAfter / SpacesAfter on a token and its word!")
                else:
                    self.spaces_after = spaces_after
            pieces = [x for x in pieces if not x.startswith("SpacesAfter=") and not x.startswith("SpaceAfter=") and not x.startswith("SpacesBefore=")]
            self.misc = "|".join(pieces)

    @property
    def spaces_before(self):
        """ SpacesBefore for the token. Translated from the MISC fields """
        return self._spaces_before

    @spaces_before.setter
    def spaces_before(self, value):
        """ Set the whitespace preceding this token. """
        self._spaces_before = value

    @property
    def spaces_after(self):
        """ SpaceAfter or SpacesAfter for the token. Translated from the MISC field """
        return self._spaces_after

    @spaces_after.setter
    def spaces_after(self, value):
        """ Set the whitespace following this token. """
        self._spaces_after = value

    @property
    def words(self):
        """ Access the list of syntactic words underlying this token. """
        return self._words

    @words.setter
    def words(self, value):
        """ Set this token's list of underlying syntactic words, reparenting each word. """
        self._words = value
        for w in self._words:
            w.parent = self

    @property
    def start_char(self):
        """ Access the start character index for this token in the raw text. """
        return self._start_char

    @property
    def end_char(self):
        """ Access the end character index for this token in the raw text. """
        return self._end_char

    @property
    def ner(self):
        """ Access the NER tag of this token. Example: 'B-ORG'"""
        return self._ner

    @ner.setter
    def ner(self, value):
        """ Set the token's NER tag. Example: 'B-ORG'"""
        self._ner = value if not self._is_null(value) else None

    @property
    def multi_ner(self):
        """ Access the MULTI_NER tag of this token. Example: '(B-ORG, B-DISEASE)'"""
        return self._multi_ner

    @multi_ner.setter
    def multi_ner(self, value):
        """ Set the token's MULTI_NER tag. Example: '(B-ORG, B-DISEASE)'"""
        self._multi_ner = value if not self._is_null(value) else None

    @property
    def sent(self):
        """ Access the pointer to the sentence that this token belongs to. """
        return self._sent

    @sent.setter
    def sent(self, value):
        """ Set the pointer to the sentence that this token belongs to. """
        self._sent = value

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def __format__(self, spec):
        """ 'C' formats as CoNLL text, 'P' pretty-prints; anything else falls back to str(). """
        if spec == 'C':
            # to_conll_text() already returns the complete newline-joined text;
            # wrapping it in another "\n".join(...) would iterate the string
            # and insert a newline between every character
            return self.to_conll_text()
        elif spec == 'P':
            return self.pretty_print()
        else:
            return str(self)

    def to_conll_text(self):
        """ Render this token as CoNLL lines: the MWT range line (if any) plus one line per word. """
        return "\n".join(dict_to_conll_text(x) for x in self.to_dict())

    def to_dict(self, fields=(ID, TEXT, MISC, START_CHAR, END_CHAR, NER, MULTI_NER, MEXP)):
        """ Dumps the token into a list of dictionary for this token with its extended words
        if the token is a multi-word token.

        `fields` selects which token attributes are exported (a tuple default
        avoids the shared-mutable-default pitfall; any iterable of field names works).
        """
        ret = []
        if len(self.id) > 1:
            # multi-word token: emit a separate dict for the token range line
            token_dict = {}
            for field in fields:
                if getattr(self, field) is not None:
                    token_dict[field] = getattr(self, field)
            if MISC in fields:
                # fold non-default whitespace back into MISC
                spaces_after = self.spaces_after
                if spaces_after is not None and spaces_after != ' ':
                    space_misc = space_after_to_misc(spaces_after)
                    if token_dict.get(MISC):
                        token_dict[MISC] = token_dict[MISC] + "|" + space_misc
                    else:
                        token_dict[MISC] = space_misc

                spaces_before = self.spaces_before
                if spaces_before is not None and spaces_before != '':
                    space_misc = space_before_to_misc(spaces_before)
                    if token_dict.get(MISC):
                        token_dict[MISC] = token_dict[MISC] + "|" + space_misc
                    else:
                        token_dict[MISC] = space_misc

            ret.append(token_dict)
        for word in self.words:
            word_dict = word.to_dict()
            if len(self.id) == 1 and NER in fields and getattr(self, NER) is not None: # propagate NER label to Word if it is a single-word token
                word_dict[NER] = getattr(self, NER)
            if len(self.id) == 1 and MULTI_NER in fields and getattr(self, MULTI_NER) is not None: # propagate MULTI_NER label to Word if it is a single-word token
                word_dict[MULTI_NER] = getattr(self, MULTI_NER)
            if len(self.id) == 1 and MISC in fields:
                # single-word token: whitespace annotations go on the word itself
                spaces_after = self.spaces_after
                if spaces_after is not None and spaces_after != ' ':
                    space_misc = space_after_to_misc(spaces_after)
                    if word_dict.get(MISC):
                        word_dict[MISC] = word_dict[MISC] + "|" + space_misc
                    else:
                        word_dict[MISC] = space_misc

                spaces_before = self.spaces_before
                if spaces_before is not None and spaces_before != '':
                    space_misc = space_before_to_misc(spaces_before)
                    if word_dict.get(MISC):
                        word_dict[MISC] = word_dict[MISC] + "|" + space_misc
                    else:
                        word_dict[MISC] = space_misc
            ret.append(word_dict)
        return ret

    def pretty_print(self):
        """ Print this token with its extended words in one line. """
        return f"<{self.__class__.__name__} id={'-'.join([str(x) for x in self.id])};words=[{', '.join([word.pretty_print() for word in self.words])}]>"

    def _is_null(self, value):
        """ True for values that represent the CoNLL null ('_' or None). """
        return (value is None) or (value == '_')

    def is_mwt(self):
        """ True if this token was expanded into multiple syntactic words. """
        return len(self.words) > 1
|
| 1297 |
+
|
| 1298 |
+
class Word(StanzaObject):
|
| 1299 |
+
""" A word class that stores attributes of a word.
|
| 1300 |
+
"""
|
| 1301 |
+
|
| 1302 |
+
    def __init__(self, sentence, word_entry):
        """ Construct a word given a dictionary format word entry.

        The owning sentence must be passed in.  MISC annotations, if present,
        are unpacked into attributes via init_from_misc, and the deps entry is
        routed through the deps setter so it lands on the sentence's
        enhanced-dependencies graph.
        """
        self._id = word_entry.get(ID, None)
        if isinstance(self._id, tuple):
            # a one-element id tuple is collapsed to a plain scalar id
            if len(self._id) == 1:
                self._id = self._id[0]
        self._text = word_entry.get(TEXT, None)

        # NOTE(review): assert is stripped under python -O, yet it is the only
        # validation of required fields here -- confirm whether that matters
        assert self._id is not None and self._text is not None, 'id and text should be included for the word. {}'.format(word_entry)

        self._lemma = word_entry.get(LEMMA, None)
        self._upos = word_entry.get(UPOS, None)
        self._xpos = word_entry.get(XPOS, None)
        self._feats = word_entry.get(FEATS, None)
        self._head = word_entry.get(HEAD, None)
        self._deprel = word_entry.get(DEPREL, None)
        self._misc = word_entry.get(MISC, None)
        self._start_char = word_entry.get(START_CHAR, None)
        self._end_char = word_entry.get(END_CHAR, None)
        self._parent = None
        self._sent = sentence
        self._mexp = word_entry.get(MEXP, None)
        self._coref_chains = None

        if self._misc is not None:
            init_from_misc(self)

        # use the setter, which will go up to the sentence and set the
        # dependencies on that graph
        self.deps = word_entry.get(DEPS, None)
|
| 1333 |
+
|
| 1334 |
+
    @property
    def manual_expansion(self):
        """ Access whether this word came from a manually specified expansion. """
        return self._mexp

    @manual_expansion.setter
    def manual_expansion(self, value):
        """ Set whether this word came from a manually specified expansion. """
        self._mexp = value
|
| 1343 |
+
|
| 1344 |
+
@property
|
| 1345 |
+
def id(self):
|
| 1346 |
+
""" Access the index of this word. """
|
| 1347 |
+
return self._id
|
| 1348 |
+
|
| 1349 |
+
@id.setter
|
| 1350 |
+
def id(self, value):
|
| 1351 |
+
""" Set the word's index value. """
|
| 1352 |
+
self._id = value
|
| 1353 |
+
|
| 1354 |
+
    @property
    def text(self):
        """ Access the text of this word. Example: 'The'"""
        return self._text

    @text.setter
    def text(self, value):
        """ Set the word's text value. Example: 'The'"""
        self._text = value
|
| 1363 |
+
|
| 1364 |
+
    @property
    def lemma(self):
        """ Access the lemma of this word. """
        return self._lemma

    @lemma.setter
    def lemma(self, value):
        """ Set the word's lemma value.

        '_' is normally treated as the CoNLL null, except when the word's
        text itself is '_' -- presumably so a literal underscore word keeps
        its underscore lemma; TODO confirm.
        """
        self._lemma = value if self._is_null(value) == False or self._text == '_' else None
|
| 1373 |
+
|
| 1374 |
+
    @property
    def upos(self):
        """ Access the universal part-of-speech of this word. Example: 'NOUN'"""
        return self._upos

    @upos.setter
    def upos(self, value):
        """ Set the word's universal part-of-speech value ('_' is stored as None). Example: 'NOUN'"""
        self._upos = value if self._is_null(value) == False else None
|
| 1383 |
+
|
| 1384 |
+
    @property
    def xpos(self):
        """ Access the treebank-specific part-of-speech of this word. Example: 'NNP'"""
        return self._xpos

    @xpos.setter
    def xpos(self, value):
        """ Set the word's treebank-specific part-of-speech value ('_' is stored as None). Example: 'NNP'"""
        self._xpos = value if self._is_null(value) == False else None
|
| 1393 |
+
|
| 1394 |
+
    @property
    def feats(self):
        """ Access the morphological features of this word. Example: 'Gender=Fem'"""
        return self._feats

    @feats.setter
    def feats(self, value):
        """ Set this word's morphological features ('_' is stored as None). Example: 'Gender=Fem'"""
        self._feats = value if self._is_null(value) == False else None
|
| 1403 |
+
|
| 1404 |
+
    @property
    def head(self):
        """ Access the id of the governor of this word. """
        return self._head

    @head.setter
    def head(self, value):
        """ Set the word's governor id value, coerced to int ('_' is stored as None). """
        self._head = int(value) if self._is_null(value) == False else None
|
| 1413 |
+
|
| 1414 |
+
    @property
    def deprel(self):
        """ Access the dependency relation of this word. Example: 'nmod'"""
        return self._deprel

    @deprel.setter
    def deprel(self, value):
        """ Set the word's dependency relation value ('_' is stored as None). Example: 'nmod'"""
        self._deprel = value if self._is_null(value) == False else None
|
| 1423 |
+
|
| 1424 |
+
    @property
    def deps(self):
        """ Access the enhanced dependencies of this word.

        The string is reconstructed from the sentence's enhanced-dependency
        graph in "head:rel|head:rel" form; empty-word heads render as "i.j".
        Returns None when the graph is absent or has no edges into this word.
        """
        graph = self._sent._enhanced_dependencies
        if graph is None or not graph.has_node(self.id):
            return None

        data = []
        # normalize int ids to 1-tuples so int and tuple ids sort together
        predecessors = sorted(list(graph.predecessors(self.id)), key=lambda x: x if isinstance(x, tuple) else (x,))
        for parent in predecessors:
            deps = sorted(list(graph.get_edge_data(parent, self.id)))
            for dep in deps:
                if isinstance(parent, int):
                    data.append("%d:%s" % (parent, dep))
                else:
                    # empty-word parent: (major, minor) id renders as major.minor
                    data.append("%d.%d:%s" % (parent[0], parent[1], dep))
        if not data:
            return None

        return "|".join(data)

    @deps.setter
    def deps(self, value):
        """ Set the word's dependencies value.

        Accepts a "head:rel|head:rel" string, a list of such strings, or a
        list of (head, rel) pairs.  Existing in-edges for this word are
        replaced; passing None just clears them.
        """
        graph = self._sent._enhanced_dependencies
        # if we don't have a graph, and we aren't trying to set any actual
        # dependencies, we can save the time of doing anything else
        if graph is None and value is None:
            return

        if graph is None:
            graph = nx.MultiDiGraph()
            self._sent._enhanced_dependencies = graph
        # need to make a new list: cannot iterate and delete at the same time
        if graph.has_node(self.id):
            in_edges = list(graph.in_edges(self.id))
            graph.remove_edges_from(in_edges)

        if value is None:
            return

        if isinstance(value, str):
            value = value.split("|")
        if all(isinstance(x, str) for x in value):
            value = [x.split(":", maxsplit=1) for x in value]
        for parent, dep in value:
            # we have to match the format of the IDs. since the IDs
            # of the words are int if they aren't empty words, we need
            # to convert single int IDs into int instead of tuple
            parent = tuple(map(int, parent.split(".", maxsplit=1)))
            if len(parent) == 1:
                parent = parent[0]
            graph.add_edge(parent, self.id, dep)
|
| 1477 |
+
|
| 1478 |
+
@property
|
| 1479 |
+
def misc(self):
|
| 1480 |
+
""" Access the miscellaneousness of this word. """
|
| 1481 |
+
return self._misc
|
| 1482 |
+
|
| 1483 |
+
@misc.setter
|
| 1484 |
+
def misc(self, value):
|
| 1485 |
+
""" Set the word's miscellaneousness value. """
|
| 1486 |
+
self._misc = value if self._is_null(value) == False else None
|
| 1487 |
+
|
| 1488 |
+
@property
|
| 1489 |
+
def start_char(self):
|
| 1490 |
+
""" Access the start character index for this token in the raw text. """
|
| 1491 |
+
return self._start_char
|
| 1492 |
+
|
| 1493 |
+
@start_char.setter
|
| 1494 |
+
def start_char(self, value):
|
| 1495 |
+
self._start_char = value
|
| 1496 |
+
|
| 1497 |
+
@property
|
| 1498 |
+
def end_char(self):
|
| 1499 |
+
""" Access the end character index for this token in the raw text. """
|
| 1500 |
+
return self._end_char
|
| 1501 |
+
|
| 1502 |
+
@end_char.setter
|
| 1503 |
+
def end_char(self, value):
|
| 1504 |
+
self._end_char = value
|
| 1505 |
+
|
| 1506 |
+
@property
|
| 1507 |
+
def parent(self):
|
| 1508 |
+
""" Access the parent token of this word. In the case of a multi-word token, a token can be the parent of
|
| 1509 |
+
multiple words. Note that this should return a reference to the parent token object.
|
| 1510 |
+
"""
|
| 1511 |
+
return self._parent
|
| 1512 |
+
|
| 1513 |
+
@parent.setter
|
| 1514 |
+
def parent(self, value):
|
| 1515 |
+
""" Set this word's parent token. In the case of a multi-word token, a token can be the parent of
|
| 1516 |
+
multiple words. Note that value here should be a reference to the parent token object.
|
| 1517 |
+
"""
|
| 1518 |
+
self._parent = value
|
| 1519 |
+
|
| 1520 |
+
@property
|
| 1521 |
+
def pos(self):
|
| 1522 |
+
""" Access the universal part-of-speech of this word. Example: 'NOUN'"""
|
| 1523 |
+
return self._upos
|
| 1524 |
+
|
| 1525 |
+
@pos.setter
|
| 1526 |
+
def pos(self, value):
|
| 1527 |
+
""" Set the word's universal part-of-speech value. Example: 'NOUN'"""
|
| 1528 |
+
self._upos = value if self._is_null(value) == False else None
|
| 1529 |
+
|
| 1530 |
+
@property
|
| 1531 |
+
def coref_chains(self):
|
| 1532 |
+
"""
|
| 1533 |
+
coref_chains points to a list of CorefChain namedtuple, which has a list of mentions and a representative mention.
|
| 1534 |
+
|
| 1535 |
+
Useful for disambiguating words such as "him" (in languages where coref is available)
|
| 1536 |
+
|
| 1537 |
+
Theoretically it is possible for multiple corefs to occur at the same word. For example,
|
| 1538 |
+
"Chris Manning's NLP Group"
|
| 1539 |
+
could have "Chris Manning" and "Chris Manning's NLP Group" as overlapping entities
|
| 1540 |
+
"""
|
| 1541 |
+
return self._coref_chains
|
| 1542 |
+
|
| 1543 |
+
@coref_chains.setter
|
| 1544 |
+
def coref_chains(self, chain):
|
| 1545 |
+
""" Set the backref for the coref chains """
|
| 1546 |
+
self._coref_chains = chain
|
| 1547 |
+
|
| 1548 |
+
@property
|
| 1549 |
+
def sent(self):
|
| 1550 |
+
""" Access the pointer to the sentence that this word belongs to. """
|
| 1551 |
+
return self._sent
|
| 1552 |
+
|
| 1553 |
+
@sent.setter
|
| 1554 |
+
def sent(self, value):
|
| 1555 |
+
""" Set the pointer to the sentence that this word belongs to. """
|
| 1556 |
+
self._sent = value
|
| 1557 |
+
|
| 1558 |
+
def __repr__(self):
|
| 1559 |
+
return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)
|
| 1560 |
+
|
| 1561 |
+
def __format__(self, spec):
|
| 1562 |
+
if spec == 'C':
|
| 1563 |
+
return self.to_conll_text()
|
| 1564 |
+
elif spec == 'P':
|
| 1565 |
+
return self.pretty_print()
|
| 1566 |
+
else:
|
| 1567 |
+
return str(self)
|
| 1568 |
+
|
| 1569 |
+
def to_conll_text(self):
|
| 1570 |
+
"""
|
| 1571 |
+
Turn a word into a conll representation (10 column tab separated)
|
| 1572 |
+
"""
|
| 1573 |
+
token_dict = self.to_dict()
|
| 1574 |
+
return dict_to_conll_text(token_dict, '.')
|
| 1575 |
+
|
| 1576 |
+
def to_dict(self, fields=[ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, MEXP, COREF_CHAINS]):
|
| 1577 |
+
""" Dumps the word into a dictionary.
|
| 1578 |
+
"""
|
| 1579 |
+
word_dict = {}
|
| 1580 |
+
for field in fields:
|
| 1581 |
+
if getattr(self, field) is not None:
|
| 1582 |
+
word_dict[field] = getattr(self, field)
|
| 1583 |
+
return word_dict
|
| 1584 |
+
|
| 1585 |
+
def pretty_print(self):
|
| 1586 |
+
""" Print the word in one line. """
|
| 1587 |
+
features = [ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL]
|
| 1588 |
+
feature_str = ";".join(["{}={}".format(k, getattr(self, k)) for k in features if getattr(self, k) is not None])
|
| 1589 |
+
return f"<{self.__class__.__name__} {feature_str}>"
|
| 1590 |
+
|
| 1591 |
+
def _is_null(self, value):
|
| 1592 |
+
return (value is None) or (value == '_')
|
| 1593 |
+
|
| 1594 |
+
|
| 1595 |
+
class Span(StanzaObject):
    """ A span class that stores attributes of a textual span. A span can be typed.
    A range of objects (e.g., entity mentions) can be represented as spans.
    """

    def __init__(self, span_entry=None, tokens=None, type=None, doc=None, sent=None):
        """ Construct a span given a span entry or a list of tokens. A valid reference to a doc
        must be provided to construct a span (otherwise the text of the span cannot be initialized).
        """
        assert span_entry is not None or (tokens is not None and type is not None), \
            'Either a span_entry or a token list needs to be provided to construct a span.'
        assert doc is not None, 'A parent doc must be provided to construct a span.'
        # scalar attributes start unset; one of the init_from_* paths fills them in
        self._text, self._type, self._start_char, self._end_char = [None] * 4
        self._tokens = []
        self._words = []
        self._doc = doc
        self._sent = sent

        if span_entry is not None:
            self.init_from_entry(span_entry)

        if tokens is not None:
            self.init_from_tokens(tokens, type)

    def init_from_entry(self, span_entry):
        """Populate text/type/offsets from a dict-like span entry."""
        self.text = span_entry.get(TEXT, None)
        self.type = span_entry.get(TYPE, None)
        self.start_char = span_entry.get(START_CHAR, None)
        self.end_char = span_entry.get(END_CHAR, None)

    def init_from_tokens(self, tokens, type):
        """Populate the span from a non-empty list of token objects.

        Derives the character offsets from the first and last tokens and
        extracts the span text either from the document text or, failing
        that, from the (single) sentence containing the tokens.
        """
        assert isinstance(tokens, list), 'Tokens must be provided as a list to construct a span.'
        assert len(tokens) > 0, "Tokens of a span cannot be an empty list."
        self.tokens = tokens
        self.type = type
        # load start and end char offsets from tokens
        self.start_char = self.tokens[0].start_char
        self.end_char = self.tokens[-1].end_char
        if self.doc is not None and self.doc.text is not None:
            self.text = self.doc.text[self.start_char:self.end_char]
        elif tokens[0].sent is tokens[-1].sent:
            # no document text: recover the span from the sentence text by
            # shifting the offsets to be sentence-relative
            sentence = tokens[0].sent
            text_start = tokens[0].start_char - sentence.tokens[0].start_char
            text_end = tokens[-1].end_char - sentence.tokens[0].start_char
            self.text = sentence.text[text_start:text_end]
        else:
            # TODO: do any spans ever cross sentences?
            raise RuntimeError("Document text does not exist, and the span tested crosses two sentences, so it is impossible to extract the entity text!")
        # collect the words of the span following tokens
        self.words = [w for t in tokens for w in t.words]
        # set the sentence back-pointer to point to the sentence of the first token
        self.sent = tokens[0].sent

    @property
    def doc(self):
        """ Access the parent doc of this span. """
        return self._doc

    @doc.setter
    def doc(self, value):
        """ Set the parent doc of this span. """
        self._doc = value

    @property
    def text(self):
        """ Access the text of this span. Example: 'Stanford University'"""
        return self._text

    @text.setter
    def text(self, value):
        """ Set the span's text value. Example: 'Stanford University'"""
        self._text = value

    @property
    def tokens(self):
        """ Access reference to a list of tokens that correspond to this span. """
        return self._tokens

    @tokens.setter
    def tokens(self, value):
        """ Set the span's list of tokens. """
        self._tokens = value

    @property
    def words(self):
        """ Access reference to a list of words that correspond to this span. """
        return self._words

    @words.setter
    def words(self, value):
        """ Set the span's list of words. """
        self._words = value

    @property
    def type(self):
        """ Access the type of this span. Example: 'PERSON'"""
        return self._type

    @type.setter
    def type(self, value):
        """ Set the type of this span. """
        self._type = value

    @property
    def start_char(self):
        """ Access the start character offset of this span. """
        return self._start_char

    @start_char.setter
    def start_char(self, value):
        """ Set the start character offset of this span. """
        self._start_char = value

    @property
    def end_char(self):
        """ Access the end character offset of this span. """
        return self._end_char

    @end_char.setter
    def end_char(self, value):
        """ Set the end character offset of this span. """
        self._end_char = value

    @property
    def sent(self):
        """ Access the pointer to the sentence that this span belongs to. """
        return self._sent

    @sent.setter
    def sent(self, value):
        """ Set the pointer to the sentence that this span belongs to. """
        self._sent = value

    def to_dict(self):
        """ Dumps the span into a dictionary. """
        attrs = ['text', 'type', 'start_char', 'end_char']
        span_dict = dict([(attr_name, getattr(self, attr_name)) for attr_name in attrs])
        return span_dict

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def pretty_print(self):
        """ Print the span in one line. """
        span_dict = self.to_dict()
        feature_str = ";".join(["{}={}".format(k,v) for k,v in span_dict.items()])
        return f"<{self.__class__.__name__} {feature_str}>"
|
stanza/stanza/models/common/dropout.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class WordDropout(nn.Module):
    """Drop whole embedded units (e.g. LSTM inputs) rather than single features.

    One Bernoulli draw is made per unit, so all values along the last
    (hidden) dimension of that unit are dropped together.  A dropped unit
    becomes zeros, or the given replacement state when one is supplied.
    """
    def __init__(self, dropprob):
        super().__init__()
        self.dropprob = dropprob

    def forward(self, x, replacement=None):
        # identity outside training, or when dropout is disabled
        if not self.training or self.dropprob == 0:
            return x

        # mask shape = input shape with the hidden dim collapsed to 1,
        # so the mask broadcasts across the whole unit
        mask_shape = list(x.size())
        mask_shape[-1] = 1
        drop_mask = torch.rand(*mask_shape, device=x.device) < self.dropprob

        out = x.masked_fill(drop_mask, 0)
        if replacement is not None:
            # substitute the replacement state at the dropped positions
            out = out + drop_mask.float() * replacement

        return out

    def extra_repr(self):
        return 'p={}'.format(self.dropprob)
class LockedDropout(nn.Module):
    """
    Variational ("locked") dropout: one mask is sampled per sequence and
    reused at every time step, instead of resampling per step.
    This implementation was modified from the LockedDropout implementation
    in the flair library (https://github.com/zalandoresearch/flair).
    """
    def __init__(self, dropprob, batch_first=True):
        super().__init__()
        self.dropprob = dropprob
        self.batch_first = batch_first

    def forward(self, x):
        # identity outside training, or when dropout is disabled
        if not self.training or self.dropprob == 0:
            return x

        keep = 1 - self.dropprob
        # collapse the time dimension to 1 so the same mask applies at
        # every step: time is dim 1 when batch_first, dim 0 otherwise
        if self.batch_first:
            m = x.new_empty(x.size(0), 1, x.size(2), requires_grad=False).bernoulli_(keep)
        else:
            m = x.new_empty(1, x.size(1), x.size(2), requires_grad=False).bernoulli_(keep)

        # rescale by 1/keep so expected activations are unchanged
        return m.div(keep).expand_as(x) * x

    def extra_repr(self):
        return 'p={}'.format(self.dropprob)
class SequenceUnitDropout(nn.Module):
    """Dropout over discrete sequence units (word indices, char indices, ...).

    During training each index is independently replaced by
    ``replacement_id`` (usually the <UNK> id) with probability ``dropprob``.
    """
    def __init__(self, dropprob, replacement_id):
        super().__init__()
        self.dropprob = dropprob
        self.replacement_id = replacement_id

    def forward(self, x):
        """:param: x must be a LongTensor of unit indices."""
        # identity outside training, or when dropout is disabled
        if not self.training or self.dropprob == 0:
            return x
        drop_mask = torch.rand(x.size(), device=x.device) < self.dropprob
        return x.masked_fill(drop_mask, self.replacement_id)

    def extra_repr(self):
        return 'p={}, replacement_id={}'.format(self.dropprob, self.replacement_id)
stanza/stanza/models/common/exceptions.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A couple more specific FileNotFoundError exceptions
|
| 3 |
+
|
| 4 |
+
The idea being, the caller can catch it and report a more useful error resolution
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import errno
|
| 8 |
+
|
| 9 |
+
class ForwardCharlmNotFoundError(FileNotFoundError):
    """Raised when a forward character language model file cannot be found."""
    def __init__(self, msg, filename):
        # supply ENOENT and the offending path so str(e) reads like a
        # regular FileNotFoundError
        super().__init__(errno.ENOENT, msg, filename)
class BackwardCharlmNotFoundError(FileNotFoundError):
    """Raised when a backward character language model file cannot be found."""
    def __init__(self, msg, filename):
        # supply ENOENT and the offending path so str(e) reads like a
        # regular FileNotFoundError
        super().__init__(errno.ENOENT, msg, filename)
stanza/stanza/models/common/foundation_cache.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Keeps BERT, charlm, word embedings in a cache to save memory
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import namedtuple
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
import logging
|
| 8 |
+
import threading
|
| 9 |
+
|
| 10 |
+
from stanza.models.common import bert_embedding
|
| 11 |
+
from stanza.models.common.char_model import CharacterLanguageModel
|
| 12 |
+
from stanza.models.common.pretrain import Pretrain
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger('stanza')
|
| 15 |
+
|
| 16 |
+
BertRecord = namedtuple('BertRecord', ['model', 'tokenizer', 'peft_ids'])
|
| 17 |
+
|
| 18 |
+
class FoundationCache:
    """Caches transformers, charlms, and pretrained embeddings so each file is loaded at most once."""
    def __init__(self, other=None, local_files_only=False):
        # when another cache is passed in, share its maps and lock so both
        # views stay consistent; otherwise start empty
        if other is None:
            self.bert = {}
            self.charlms = {}
            self.pretrains = {}
            # future proof the module by using a lock for the glorious day
            # when the GIL is finally gone
            self.lock = threading.Lock()
        else:
            self.bert = other.bert
            self.charlms = other.charlms
            self.pretrains = other.pretrains
            self.lock = other.lock
        self.local_files_only=local_files_only

    def load_bert(self, transformer_name, local_files_only=None):
        # convenience wrapper: same as load_bert_with_peft without an adapter
        m, t, _ = self.load_bert_with_peft(transformer_name, None, local_files_only=local_files_only)
        return m, t

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        """
        Load a transformer only once

        Uses a lock for thread safety

        Returns (model, tokenizer, peft_name); repeated requests for the
        same peft_name get a fresh "<name>_<counter>" id so different
        adapters on the shared model do not collide.
        """
        if transformer_name is None:
            return None, None, None
        with self.lock:
            if transformer_name not in self.bert:
                # local_files_only=None means "use the cache-wide default"
                if local_files_only is None:
                    local_files_only = self.local_files_only
                model, tokenizer = bert_embedding.load_bert(transformer_name, local_files_only=local_files_only)
                self.bert[transformer_name] = BertRecord(model, tokenizer, {})
            else:
                logger.debug("Reusing bert %s", transformer_name)

            bert_record = self.bert[transformer_name]
            if not peft_name:
                return bert_record.model, bert_record.tokenizer, None
            # bump the per-name counter to produce a unique adapter id
            if peft_name not in bert_record.peft_ids:
                bert_record.peft_ids[peft_name] = 0
            else:
                bert_record.peft_ids[peft_name] = bert_record.peft_ids[peft_name] + 1
            peft_name = "%s_%d" % (peft_name, bert_record.peft_ids[peft_name])
            return bert_record.model, bert_record.tokenizer, peft_name

    def load_charlm(self, filename):
        # a falsy filename (None or "") means "no charlm requested"
        if not filename:
            return None

        with self.lock:
            if filename not in self.charlms:
                logger.debug("Loading charlm from %s", filename)
                # finetune=False: a cached model must keep fixed weights
                self.charlms[filename] = CharacterLanguageModel.load(filename, finetune=False)
            else:
                logger.debug("Reusing charlm from %s", filename)

            return self.charlms[filename]

    def load_pretrain(self, filename):
        """
        Load a pretrained word embedding only once

        Uses a lock for thread safety
        """
        if filename is None:
            return None
        with self.lock:
            if filename not in self.pretrains:
                logger.debug("Loading pretrain %s", filename)
                self.pretrains[filename] = Pretrain(filename)
            else:
                logger.debug("Reusing pretrain %s", filename)

            return self.pretrains[filename]
class NoTransformerFoundationCache(FoundationCache):
    """
    Uses the underlying FoundationCache, but hiding the transformer.

    Useful for when loading a downstream model such as POS which has a
    finetuned transformer, and we don't want the transformer reused
    since it will then have the finetuned weights for other models
    which don't want them
    """
    def load_bert(self, transformer_name, local_files_only=None):
        # bypass the cache entirely: call the module-level loader so a
        # fresh, unshared transformer is returned
        return load_bert(transformer_name, local_files_only=self.local_files_only if local_files_only is None else local_files_only)

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        # bypass the cache entirely (see load_bert above)
        return load_bert_with_peft(transformer_name, peft_name, local_files_only=self.local_files_only if local_files_only is None else local_files_only)
def load_bert(model_name, foundation_cache=None, local_files_only=None):
    """
    Load a bert, possibly using a foundation cache, ignoring the cache if None
    """
    if foundation_cache is not None:
        return foundation_cache.load_bert(model_name, local_files_only=local_files_only)
    return bert_embedding.load_bert(model_name, local_files_only=local_files_only)
def load_bert_with_peft(model_name, peft_name, foundation_cache=None, local_files_only=None):
    """Load a bert plus tokenizer, returning (model, tokenizer, peft_name).

    Goes through the foundation cache when one is provided; otherwise the
    requested peft_name is passed back unchanged.
    """
    if foundation_cache is not None:
        return foundation_cache.load_bert_with_peft(model_name, peft_name, local_files_only=local_files_only)
    model, tokenizer = bert_embedding.load_bert(model_name, local_files_only=local_files_only)
    return model, tokenizer, peft_name
def load_charlm(charlm_file, foundation_cache=None, finetune=False):
    """Load a character language model, reusing the cache when possible.

    Finetuned charlms are never cached: their weights will change, and the
    numbers would then be wrong for other users of the cached model.
    """
    if not charlm_file:
        return None

    if finetune:
        return CharacterLanguageModel.load(charlm_file, finetune=True)

    if foundation_cache is None:
        logger.debug("Loading charlm from %s", charlm_file)
        return CharacterLanguageModel.load(charlm_file, finetune=False)

    return foundation_cache.load_charlm(charlm_file)
def load_pretrain(filename, foundation_cache=None):
    """Load a pretrained word embedding, reusing the cache when one is given."""
    if not filename:
        return None

    if foundation_cache is None:
        logger.debug("Loading pretrain from %s", filename)
        return Pretrain(filename)

    return foundation_cache.load_pretrain(filename)
stanza/stanza/models/common/hlstm.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
|
| 5 |
+
|
| 6 |
+
from stanza.models.common.packed_lstm import PackedLSTM
|
| 7 |
+
|
| 8 |
+
class HLSTMCell(nn.modules.rnn.RNNCellBase):
    """
    A Highway LSTM Cell as proposed in Zhang et al. (2018) Highway Long Short-Term Memory RNNs for
    Distant Speech Recognition.

    Besides the standard LSTM gates, a highway gate mixes the previous
    layer's cell state (c_l_minus_one) directly into this layer's cell.
    """
    def __init__(self, input_size, hidden_size, bias=True):
        # NOTE(review): newer torch releases require arguments to
        # RNNCellBase.__init__ and have removed check_forward_input --
        # confirm the torch version this file targets
        super(HLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # LSTM parameters: input/forget/output gates and candidate state,
        # each computed from [input; h_prev]
        self.Wi = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wf = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wo = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wg = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)

        # highway gate parameters
        self.gate = nn.Linear(input_size + 2 * hidden_size, hidden_size, bias=bias)

    def forward(self, input, c_l_minus_one=None, hx=None):
        """Run one time step.

        input: (batch, input_size)
        c_l_minus_one: previous layer's cell state, zeros when absent
        hx: (h, c) from the previous time step, zeros when absent
        Returns the new (h, c).
        """
        self.check_forward_input(input)
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            hx = (hx, hx)
        if c_l_minus_one is None:
            c_l_minus_one = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)

        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        self.check_forward_hidden(input, c_l_minus_one, 'c_l_minus_one')

        # vanilla LSTM computation; torch.sigmoid/torch.tanh replace the
        # long-deprecated F.sigmoid/F.tanh
        rec_input = torch.cat([input, hx[0]], 1)
        i = torch.sigmoid(self.Wi(rec_input))
        f = torch.sigmoid(self.Wf(rec_input))
        o = torch.sigmoid(self.Wo(rec_input))
        g = torch.tanh(self.Wg(rec_input))

        # highway gates
        gate = torch.sigmoid(self.gate(torch.cat([c_l_minus_one, hx[1], input], 1)))

        # cell mixes the previous layer's cell (highway) with the usual
        # forget/input contributions
        c = gate * c_l_minus_one + f * hx[1] + i * g
        h = o * torch.tanh(c)

        return h, c
# Highway LSTM network, does NOT use the HLSTMCell above
class HighwayLSTM(nn.Module):
    """
    A Highway LSTM network, as used in the original Tensorflow version of the Dozat parser. Note that this
    is independent from the HLSTMCell above.

    Each layer is a 1-layer PackedLSTM; its output is summed with a gated
    linear ("highway") transform of the layer's own input.
    """
    def __init__(self, input_size, hidden_size,
            num_layers=1, bias=True, batch_first=False,
            dropout=0, bidirectional=False, rec_dropout=0, highway_func=None, pad=False):
        super(HighwayLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.dropout_state = {}
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.highway_func = highway_func
        # pad=True -> forward() returns a padded tensor instead of a PackedSequence
        self.pad = pad

        self.lstm = nn.ModuleList()
        self.highway = nn.ModuleList()
        self.gate = nn.ModuleList()
        self.drop = nn.Dropout(dropout, inplace=True)

        in_size = input_size
        for l in range(num_layers):
            # one LSTM plus one highway transform + gate per layer
            self.lstm.append(PackedLSTM(in_size, hidden_size, num_layers=1, bias=bias,
                batch_first=batch_first, dropout=0, bidirectional=bidirectional, rec_dropout=rec_dropout))
            self.highway.append(nn.Linear(in_size, hidden_size * self.num_directions))
            self.gate.append(nn.Linear(in_size, hidden_size * self.num_directions))
            # zero biases: the gate starts at sigmoid(Wx), unbiased
            self.highway[-1].bias.data.zero_()
            self.gate[-1].bias.data.zero_()
            in_size = hidden_size * self.num_directions

    def forward(self, input, seqlens, hx=None):
        """Run the stacked highway LSTM.

        input: a padded tensor (packed here using seqlens) or an
        already-packed PackedSequence.
        Returns (output, (h_n, c_n)); output is padded when self.pad is set,
        otherwise a PackedSequence.
        """
        # identity highway transform unless the caller supplied one
        highway_func = (lambda x: x) if self.highway_func is None else self.highway_func

        hs = []
        cs = []

        if not isinstance(input, PackedSequence):
            input = pack_padded_sequence(input, seqlens, batch_first=self.batch_first)

        for l in range(self.num_layers):
            if l > 0:
                # inter-layer dropout applied to the packed data
                input = PackedSequence(self.drop(input.data), input.batch_sizes, input.sorted_indices, input.unsorted_indices)
            # slice out this layer's directions from the stacked initial state
            layer_hx = (hx[0][l * self.num_directions:(l+1)*self.num_directions], hx[1][l * self.num_directions:(l+1)*self.num_directions]) if hx is not None else None
            h, (ht, ct) = self.lstm[l](input, seqlens, layer_hx)

            hs.append(ht)
            cs.append(ct)

            # layer output = LSTM output + gate * highway(input)
            input = PackedSequence(h.data + torch.sigmoid(self.gate[l](input.data)) * highway_func(self.highway[l](input.data)), input.batch_sizes, input.sorted_indices, input.unsorted_indices)

        if self.pad:
            input = pad_packed_sequence(input, batch_first=self.batch_first)[0]
        return input, (torch.cat(hs, 0), torch.cat(cs, 0))
if __name__ == "__main__":
    # smoke test: run a 2-layer bidirectional HighwayLSTM on random input
    T = 10
    bidir = True
    num_dir = 2 if bidir else 1
    rnn = HighwayLSTM(10, 20, num_layers=2, bidirectional=True)
    input = torch.randn(T, 3, 10)
    hx = torch.randn(2 * num_dir, 3, 20)
    cx = torch.randn(2 * num_dir, 3, 20)
    # forward() takes (input, seqlens, hx); the previous call mistakenly
    # passed the hidden state tuple in the seqlens position
    output = rnn(input, [T] * 3, (hx, cx))
    print(output)
stanza/stanza/models/common/large_margin_loss.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LargeMarginInSoftmax, from the article
|
| 3 |
+
|
| 4 |
+
@inproceedings{kobayashi2019bmvc,
|
| 5 |
+
title={Large Margin In Softmax Cross-Entropy Loss},
|
| 6 |
+
author={Takumi Kobayashi},
|
| 7 |
+
booktitle={Proceedings of the British Machine Vision Conference (BMVC)},
|
| 8 |
+
year={2019}
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
implementation from
|
| 12 |
+
|
| 13 |
+
https://github.com/tk1980/LargeMarginInSoftmax
|
| 14 |
+
|
| 15 |
+
There is no license specifically chosen; they just ask people to cite the paper if the work is useful.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.init as init
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LargeMarginInSoftmaxLoss(nn.CrossEntropyLoss):
    r"""
    This combines the Softmax Cross-Entropy Loss (nn.CrossEntropyLoss) and the large-margin inducing
    regularization proposed in
       T. Kobayashi, "Large-Margin In Softmax Cross-Entropy Loss." In BMVC2019.

    This loss function inherits the parameters from nn.CrossEntropyLoss except for `reg_lambda` and `deg_logit`.

    Args:
         reg_lambda (float, optional): a regularization parameter. (default: 0.3)
         deg_logit (bool, optional): underestimate (degrade) the target logit by -1 or not. (default: False)
                    If True, it realizes the method that incorporates the modified loss into ours
                    as described in the above paper (Table 4).
    """
    def __init__(self, reg_lambda=0.3, deg_logit=None,
                 weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
        super(LargeMarginInSoftmaxLoss, self).__init__(weight=weight, size_average=size_average,
                                                       ignore_index=ignore_index, reduce=reduce, reduction=reduction)
        self.reg_lambda = reg_lambda
        self.deg_logit = deg_logit

    def forward(self, input, target):
        """
        Cross entropy on (input, target) plus reg_lambda times the
        large-margin regularizer computed over the non-target classes.

        input: [N, C] logits
        target: [N] gold class ids
        """
        N = input.size(0)  # number of samples
        C = input.size(1)  # number of classes
        # one-hot mask marking each sample's target class
        Mask = torch.zeros_like(input, requires_grad=False)
        Mask[range(N), target] = 1

        if self.deg_logit is not None:
            # degrade (lower) the target logit by deg_logit before the loss
            input = input - self.deg_logit * Mask

        loss = F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

        # push the target logit to -inf so softmax is over non-target classes only
        X = input - 1.e6 * Mask  # [N x C], excluding the target class
        # regularizer from the paper: compares the softmax over non-target
        # classes to the uniform distribution 1/(C-1)
        reg = 0.5 * ((F.softmax(X, dim=1) - 1.0/(C-1)) * F.log_softmax(X, dim=1) * (1.0-Mask)).sum(dim=1)
        # apply the same reduction as the cross entropy term
        if self.reduction == 'sum':
            reg = reg.sum()
        elif self.reduction == 'mean':
            reg = reg.mean()
        elif self.reduction == 'none':
            reg = reg

        return loss + self.reg_lambda * reg
|
stanza/stanza/models/common/loss.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Different loss functions.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger('stanza')
|
| 13 |
+
|
| 14 |
+
def SequenceLoss(vocab_size):
    """Build an NLLLoss over a vocabulary of the given size, giving the padding id zero weight."""
    class_weights = torch.ones(vocab_size)
    class_weights[constant.PAD_ID] = 0
    return nn.NLLLoss(class_weights)
|
| 19 |
+
|
| 20 |
+
def weighted_cross_entropy_loss(labels, log_dampened=False):
    """
    Either return a loss function which reweights all examples so the
    classes have the same effective weight, or dampened reweighting
    using log() so that the biggest class has some priority

    labels: list or array-like of gold class ids; class weights are
        derived from the observed counts of each class
    log_dampened: if True, soften the reweighting to 1 + log(weight)

    Returns an nn.CrossEntropyLoss with a per-class weight tensor.
    """
    # np.unique handles both lists and arrays, so no conversion is needed
    # (a previous version built an unused np.array copy of list input)
    _, weights = np.unique(labels, return_counts=True)
    # counts -> frequencies -> inverse frequencies
    weights = weights / float(np.sum(weights))
    weights = np.sum(weights) / weights
    if log_dampened:
        weights = 1 + np.log(weights)
    logger.debug("Reweighting cross entropy by {}".format(weights))
    loss = nn.CrossEntropyLoss(
        weight=torch.from_numpy(weights).type('torch.FloatTensor')
    )
    return loss
|
| 38 |
+
|
| 39 |
+
class FocalLoss(nn.Module):
    """
    Multi-category focal loss.

    The per-item cross entropy is rescaled by how confident the model
    already was in the gold label, so confidently-correct items
    contribute little and hard errors dominate.

    From "Focal Loss for Dense Object Detection"
    https://arxiv.org/abs/1708.02002
    """
    def __init__(self, reduction='mean', gamma=2.0):
        super().__init__()
        if reduction not in ('sum', 'none', 'mean'):
            raise ValueError("Unknown reduction: %s" % reduction)

        self.reduction = reduction
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
        self.gamma = gamma

    def forward(self, inputs, targets):
        """
        Weight the loss using the model's assessment of the correct answer.

        inputs: [N, C]
        targets: [N]
        """
        in_rank, tgt_rank = inputs.dim(), targets.dim()
        if in_rank == 2 and tgt_rank == 1:
            if inputs.shape[0] != targets.shape[0]:
                raise ValueError("Expected inputs N,C and targets N, but got {} and {}".format(inputs.shape, targets.shape))
        elif in_rank == 1 and tgt_rank == 0:
            raise NotImplementedError("This would be a reasonable thing to implement, but we haven't done it yet")
        else:
            raise ValueError("Expected inputs N,C and targets N, but got {} and {}".format(inputs.shape, targets.shape))

        per_item = self.ce_loss(inputs, targets)
        assert per_item.dim() == 1 and per_item.shape[0] == inputs.shape[0]

        # exp(-ce) is the probability assigned to the gold label, so each
        # item is scaled by (1 - p_gold) ** gamma
        # https://www.tutorialexample.com/implement-focal-loss-for-multi-label-classification-in-pytorch-pytorch-tutorial/
        scaled = per_item * ((1 - torch.exp(-per_item)) ** self.gamma)
        assert scaled.dim() == 1 and scaled.shape[0] == inputs.shape[0]

        if self.reduction == 'mean':
            return scaled.mean()
        if self.reduction == 'sum':
            return scaled.sum()
        if self.reduction == 'none':
            return scaled
        raise AssertionError("unknown reduction! how did this happen??")
|
| 87 |
+
|
| 88 |
+
class MixLoss(nn.Module):
    """
    A mixture of SequenceLoss and CrossEntropyLoss.
    Loss = SequenceLoss + alpha * CELoss
    """
    def __init__(self, vocab_size, alpha):
        super().__init__()
        assert alpha >= 0
        self.alpha = alpha
        self.seq_loss = SequenceLoss(vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, seq_inputs, seq_targets, class_inputs, class_targets):
        """Sequence loss on the sequence pair plus alpha times the classification loss."""
        seq_term = self.seq_loss(seq_inputs, seq_targets)
        class_term = self.ce_loss(class_inputs, class_targets)
        return seq_term + self.alpha * class_term
|
| 105 |
+
|
| 106 |
+
class MaxEntropySequenceLoss(nn.Module):
    """
    A max entropy loss that encourages the model to have large entropy,
    therefore giving more diverse outputs.

    Loss = NLLLoss + alpha * EntropyLoss
    """
    def __init__(self, vocab_size, alpha):
        super().__init__()
        token_weights = torch.ones(vocab_size)
        token_weights[constant.PAD_ID] = 0
        self.nll = nn.NLLLoss(token_weights)
        self.alpha = alpha

    def forward(self, inputs, targets):
        """
        inputs: [N, C]
        targets: [N]
        """
        assert inputs.size(0) == targets.size(0)
        base_loss = self.nll(inputs, targets)

        # entropy term: zero out rows whose target is padding, then
        # accumulate p * log(p) and average over the minibatch
        pad_rows = targets.eq(constant.PAD_ID).unsqueeze(1).expand_as(inputs)
        log_p = inputs.clone().masked_fill_(pad_rows, 0.0)
        entropy_term = torch.exp(log_p).mul(log_p).sum() / inputs.size(0)  # average over minibatch
        return base_loss + self.alpha * entropy_term
|
| 134 |
+
|
stanza/stanza/models/common/maxout_linear.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A layer which implements maxout from the "Maxout Networks" paper
|
| 3 |
+
|
| 4 |
+
https://arxiv.org/pdf/1302.4389v4.pdf
|
| 5 |
+
Goodfellow, Warde-Farley, Mirza, Courville, Bengio
|
| 6 |
+
|
| 7 |
+
or a simpler explanation here:
|
| 8 |
+
|
| 9 |
+
https://stats.stackexchange.com/questions/129698/what-is-maxout-in-neural-network/298705#298705
|
| 10 |
+
|
| 11 |
+
The implementation here:
|
| 12 |
+
for k layers of maxout, in -> out channels, we make a single linear
|
| 13 |
+
map of size in -> out*k
|
| 14 |
+
then we reshape the end to be (..., k, out)
|
| 15 |
+
and return the max over the k layers
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
class MaxoutLinear(nn.Module):
    """
    Maxout over maxout_k parallel linear maps, implemented as one
    oversized nn.Linear of size in_channels -> out_channels * maxout_k.
    """
    def __init__(self, in_channels, out_channels, maxout_k):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.maxout_k = maxout_k

        # one big linear stands in for maxout_k separate linear maps
        self.linear = nn.Linear(in_channels, out_channels * maxout_k)

    def forward(self, inputs):
        """
        Apply the oversized linear map, then take the max over the k pieces.

        A single matrix multiply is simpler and easier for pytorch to
        parallelize than maxout_k separate linear layers.
        """
        projected = self.linear(inputs)
        # split the last dim into (maxout_k, out_channels) and reduce over k
        stacked = projected.view(*projected.shape[:-1], self.maxout_k, self.out_channels)
        return stacked.max(dim=-2).values
|
| 42 |
+
|
stanza/stanza/models/common/packed_lstm.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
|
| 5 |
+
|
| 6 |
+
class PackedLSTM(nn.Module):
    """
    An LSTM wrapper that packs padded input on the fly.

    With rec_dropout == 0 the fast native nn.LSTM is used; otherwise the
    slower cell-by-cell LSTMwRecDropout with recurrent dropout is used.
    """
    def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):
        super().__init__()

        self.batch_first = batch_first
        self.pad = pad
        if rec_dropout == 0:
            # no recurrent dropout requested: the fast, native LSTM suffices
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias,
                                batch_first=batch_first, dropout=dropout,
                                bidirectional=bidirectional)
        else:
            self.lstm = LSTMwRecDropout(input_size, hidden_size, num_layers, bias=bias,
                                        batch_first=batch_first, dropout=dropout,
                                        bidirectional=bidirectional, rec_dropout=rec_dropout)

    def forward(self, input, lengths, hx=None):
        """Pack the input if needed, run the LSTM, optionally pad the output back out."""
        if isinstance(input, PackedSequence):
            packed = input
        else:
            packed = pack_padded_sequence(input, lengths, batch_first=self.batch_first)

        output, hidden = self.lstm(packed, hx)
        if self.pad:
            output = pad_packed_sequence(output, batch_first=self.batch_first)[0]
        return output, hidden
|
| 26 |
+
|
| 27 |
+
class LSTMwRecDropout(nn.Module):
    """ An LSTM implementation that supports recurrent dropout """
    def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):
        super().__init__()
        self.batch_first = batch_first
        # NOTE(review): `pad` is stored but never read in this class's forward —
        # presumably kept for signature parity with PackedLSTM; confirm
        self.pad = pad
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.dropout = dropout
        # inter-layer dropout (applied to layer inputs for layers > 0)
        self.drop = nn.Dropout(dropout, inplace=True)
        # recurrent dropout (applied to the hidden state between timesteps)
        self.rec_drop = nn.Dropout(rec_dropout, inplace=True)

        self.num_directions = 2 if bidirectional else 1

        # one LSTMCell per (layer, direction), indexed l * num_directions + d
        self.cells = nn.ModuleList()
        for l in range(num_layers):
            in_size = input_size if l == 0 else self.num_directions * hidden_size
            for d in range(self.num_directions):
                self.cells.append(nn.LSTMCell(in_size, hidden_size, bias=bias))

    def forward(self, input, hx=None):
        """
        Run the stacked (optionally bidirectional) LSTM over a PackedSequence.

        input: a PackedSequence
        hx: optional (h0, c0), each indexable by (layer * num_directions + d)
        Returns (PackedSequence of outputs, (h_n, c_n)) where h_n/c_n stack
        the final states of every (layer, direction).
        """
        def rnn_loop(x, batch_sizes, cell, inits, reverse=False):
            # RNN loop for one layer in one direction with recurrent dropout
            # Assumes input is PackedSequence, returns PackedSequence as well
            batch_size = batch_sizes[0].item()
            # per-sequence state slices so shrinking batch sizes can be handled
            states = [list(init.split([1] * batch_size)) for init in inits]
            # one dropout mask per sequence, reused across all timesteps
            # (variational / recurrent dropout)
            h_drop_mask = x.new_ones(batch_size, self.hidden_size)
            h_drop_mask = self.rec_drop(h_drop_mask)
            resh = []

            if not reverse:
                # walk the packed data forward; bs sequences are active per step
                st = 0
                for bs in batch_sizes:
                    s1 = cell(x[st:st+bs], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))
                    resh.append(s1[0])
                    # write the updated states back for the active sequences
                    for j in range(bs):
                        states[0][j] = s1[0][j].unsqueeze(0)
                        states[1][j] = s1[1][j].unsqueeze(0)
                    st += bs
            else:
                # walk the packed data from the end so the shortest sequences
                # join the computation at the right (late) timesteps
                en = x.size(0)
                for i in range(batch_sizes.size(0)-1, -1, -1):
                    bs = batch_sizes[i]
                    s1 = cell(x[en-bs:en], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))
                    resh.append(s1[0])
                    for j in range(bs):
                        states[0][j] = s1[0][j].unsqueeze(0)
                        states[1][j] = s1[1][j].unsqueeze(0)
                    en -= bs
                # restore chronological order of the outputs
                resh = list(reversed(resh))

            return torch.cat(resh, 0), tuple(torch.cat(s, 0) for s in states)

        all_states = [[], []]  # collected (h, c) final states per (layer, direction)
        inputdata, batch_sizes = input.data, input.batch_sizes
        for l in range(self.num_layers):
            new_input = []

            # inter-layer dropout on the inputs of every layer after the first
            if self.dropout > 0 and l > 0:
                inputdata = self.drop(inputdata)
            for d in range(self.num_directions):
                idx = l * self.num_directions + d
                cell = self.cells[idx]
                # use the provided initial state if any, otherwise zeros
                out, states = rnn_loop(inputdata, batch_sizes, cell, (hx[i][idx] for i in range(2)) if hx is not None else (input.data.new_zeros(input.batch_sizes[0].item(), self.hidden_size, requires_grad=False) for _ in range(2)), reverse=(d == 1))

                new_input.append(out)
                all_states[0].append(states[0].unsqueeze(0))
                all_states[1].append(states[1].unsqueeze(0))

            if self.num_directions > 1:
                # concatenate both directions
                inputdata = torch.cat(new_input, 1)
            else:
                inputdata = new_input[0]

            input = PackedSequence(inputdata, batch_sizes)

        return input, tuple(torch.cat(x, 0) for x in all_states)
|
stanza/stanza/models/common/peft_config.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Set a few common flags for peft usage
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Per-transformer overrides for the LoRA hyperparameters, keyed by the
# bert_model name.  resolve_peft_args falls back to the DEFAULT_* value
# when a model has no entry.
TRANSFORMER_LORA_RANK = {}
DEFAULT_LORA_RANK = 64

TRANSFORMER_LORA_ALPHA = {}
DEFAULT_LORA_ALPHA = 128

TRANSFORMER_LORA_DROPOUT = {}
DEFAULT_LORA_DROPOUT = 0.1

# comma separated module names which LoRA will target
TRANSFORMER_LORA_TARGETS = {}
DEFAULT_LORA_TARGETS = "query,value,output.dense,intermediate.dense"

# comma separated module names to fully tune (and save) alongside the adapter
TRANSFORMER_LORA_SAVE = {}
DEFAULT_LORA_SAVE = ""
|
| 20 |
+
|
| 21 |
+
def add_peft_args(parser):
    """
    Add common default flags to an argparse
    """
    rank_help = "Rank of a LoRA approximation. Default will be %d or a model-specific parameter" % DEFAULT_LORA_RANK
    alpha_help = "Alpha of a LoRA approximation. Default will be %d or a model-specific parameter" % DEFAULT_LORA_ALPHA
    dropout_help = "Dropout for the LoRA approximation. Default will be %s or a model-specific parameter" % DEFAULT_LORA_DROPOUT
    targets_help = "Comma separated list of LoRA targets. Default will be '%s' or a model-specific parameter" % DEFAULT_LORA_TARGETS
    save_help = "Comma separated list of modules to save (eg, fully tune) when using LoRA. Default will be '%s' or a model-specific parameter" % DEFAULT_LORA_SAVE

    # defaults are None so resolve_peft_args can tell "unset" from a user choice
    parser.add_argument('--lora_rank', type=int, default=None, help=rank_help)
    parser.add_argument('--lora_alpha', type=int, default=None, help=alpha_help)
    parser.add_argument('--lora_dropout', type=float, default=None, help=dropout_help)
    parser.add_argument('--lora_target_modules', type=str, default=None, help=targets_help)
    parser.add_argument('--lora_modules_to_save', type=str, default=None, help=save_help)

    parser.add_argument('--use_peft', default=False, action='store_true', help="Finetune Bert using peft")
|
| 32 |
+
|
| 33 |
+
def pop_peft_args(args):
    """
    Pop all of the peft-related arguments from a given dict

    Useful for making sure a model loaded from disk is recreated with
    the right shapes, for example
    """
    peft_keys = ("lora_rank", "lora_alpha", "lora_dropout",
                 "lora_target_modules", "lora_modules_to_save",
                 "use_peft")
    for key in peft_keys:
        # missing keys are silently ignored
        args.pop(key, None)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def resolve_peft_args(args, logger, check_bert_finetune=True):
    """
    Fill in any unset (None) peft arguments on args.

    Each value falls back to a transformer-specific override keyed on
    args.bert_model, then to the module-level default.  The comma
    separated module-list strings are converted into actual lists.
    If the args namespace has no bert_model at all, nothing is changed.
    """
    if not hasattr(args, 'bert_model'):
        return

    if args.lora_rank is None:
        args.lora_rank = TRANSFORMER_LORA_RANK.get(args.bert_model, DEFAULT_LORA_RANK)

    if args.lora_alpha is None:
        args.lora_alpha = TRANSFORMER_LORA_ALPHA.get(args.bert_model, DEFAULT_LORA_ALPHA)

    if args.lora_dropout is None:
        args.lora_dropout = TRANSFORMER_LORA_DROPOUT.get(args.bert_model, DEFAULT_LORA_DROPOUT)

    if args.lora_target_modules is None:
        args.lora_target_modules = TRANSFORMER_LORA_TARGETS.get(args.bert_model, DEFAULT_LORA_TARGETS)
    # turn the comma separated string (default or user supplied) into a list;
    # a blank string means "no target modules"
    if not args.lora_target_modules.strip():
        args.lora_target_modules = []
    else:
        args.lora_target_modules = args.lora_target_modules.split(",")

    if args.lora_modules_to_save is None:
        args.lora_modules_to_save = TRANSFORMER_LORA_SAVE.get(args.bert_model, DEFAULT_LORA_SAVE)
    # same string -> list conversion for the modules to fully tune/save
    if not args.lora_modules_to_save.strip():
        args.lora_modules_to_save = []
    else:
        args.lora_modules_to_save = args.lora_modules_to_save.split(",")

    if check_bert_finetune and hasattr(args, 'bert_finetune'):
        # peft tuning implies the transformer is being finetuned
        if args.use_peft and not args.bert_finetune:
            logger.info("--use_peft set. setting --bert_finetune as well")
            args.bert_finetune = True
|
| 80 |
+
|
| 81 |
+
def build_peft_config(args, logger):
    """Construct a peft LoraConfig from the lora_* entries of an args dict."""
    # Hide import so that the peft dependency is optional
    from peft import LoraConfig
    logger.debug("Creating lora adapter with rank %d and alpha %d", args['lora_rank'], args['lora_alpha'])
    return LoraConfig(inference_mode=False,
                      r=args['lora_rank'],
                      target_modules=args['lora_target_modules'],
                      lora_alpha=args['lora_alpha'],
                      lora_dropout=args['lora_dropout'],
                      modules_to_save=args['lora_modules_to_save'],
                      bias="none")
|
| 93 |
+
|
| 94 |
+
def build_peft_wrapper(bert_model, args, logger, adapter_name="default"):
    """Wrap a transformer in a freshly built peft (LoRA) adapter and activate it."""
    # Hide import so that the peft dependency is optional
    from peft import get_peft_model
    config = build_peft_config(args, logger)

    wrapped = get_peft_model(bert_model, config, adapter_name=adapter_name)
    # apparently get_peft_model doesn't actually mark that
    # peft configs are loaded, making it impossible to turn off (or on)
    # the peft adapter later
    bert_model._hf_peft_config_loaded = True
    wrapped._hf_peft_config_loaded = True
    wrapped.set_adapter(adapter_name)
    return wrapped
|
| 107 |
+
|
| 108 |
+
def load_peft_wrapper(bert_model, lora_params, args, logger, adapter_name):
    """Attach saved LoRA weights to a transformer under the given adapter name."""
    config = build_peft_config(args, logger)

    try:
        bert_model.load_adapter(adapter_name=adapter_name, peft_config=config, adapter_state_dict=lora_params)
    except (ValueError, TypeError):
        # this can happen if the adapter already exists...
        # in that case, try setting the adapter weights?
        from peft import set_peft_model_state_dict
        set_peft_model_state_dict(bert_model, lora_params, adapter_name=adapter_name)
    bert_model.set_adapter(adapter_name)
    return bert_model
|
stanza/stanza/models/common/seq2seq_constant.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Constants for seq2seq models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
PAD = '<PAD>'
|
| 6 |
+
PAD_ID = 0
|
| 7 |
+
UNK = '<UNK>'
|
| 8 |
+
UNK_ID = 1
|
| 9 |
+
SOS = '<SOS>'
|
| 10 |
+
SOS_ID = 2
|
| 11 |
+
EOS = '<EOS>'
|
| 12 |
+
EOS_ID = 3
|
| 13 |
+
|
| 14 |
+
VOCAB_PREFIX = [PAD, UNK, SOS, EOS]
|
| 15 |
+
|
| 16 |
+
EMB_INIT_RANGE = 1.0
|
| 17 |
+
INFINITY_NUMBER = 1e12
|
stanza/stanza/models/common/seq2seq_model.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The full encoder-decoder model, built on top of the base seq2seq modules.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 12 |
+
from stanza.models.common import utils
|
| 13 |
+
from stanza.models.common.seq2seq_modules import LSTMAttention
|
| 14 |
+
from stanza.models.common.beam import Beam
|
| 15 |
+
from stanza.models.common.seq2seq_constant import UNK_ID
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger('stanza')
|
| 18 |
+
|
| 19 |
+
class Seq2SeqModel(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
A complete encoder-decoder model, with optional attention.
|
| 22 |
+
|
| 23 |
+
A parent class which makes use of the contextual_embedding (such as a charlm)
|
| 24 |
+
can make use of unsaved_modules when saving.
|
| 25 |
+
"""
|
| 26 |
+
def __init__(self, args, emb_matrix=None, contextual_embedding=None):
|
| 27 |
+
super().__init__()
|
| 28 |
+
|
| 29 |
+
self.unsaved_modules = []
|
| 30 |
+
|
| 31 |
+
self.vocab_size = args['vocab_size']
|
| 32 |
+
self.emb_dim = args['emb_dim']
|
| 33 |
+
self.hidden_dim = args['hidden_dim']
|
| 34 |
+
self.nlayers = args['num_layers'] # encoder layers, decoder layers = 1
|
| 35 |
+
self.emb_dropout = args.get('emb_dropout', 0.0)
|
| 36 |
+
self.dropout = args['dropout']
|
| 37 |
+
self.pad_token = constant.PAD_ID
|
| 38 |
+
self.max_dec_len = args['max_dec_len']
|
| 39 |
+
self.top = args.get('top', 1e10)
|
| 40 |
+
self.args = args
|
| 41 |
+
self.emb_matrix = emb_matrix
|
| 42 |
+
self.add_unsaved_module("contextual_embedding", contextual_embedding)
|
| 43 |
+
|
| 44 |
+
logger.debug("Building an attentional Seq2Seq model...")
|
| 45 |
+
logger.debug("Using a Bi-LSTM encoder")
|
| 46 |
+
self.num_directions = 2
|
| 47 |
+
self.enc_hidden_dim = self.hidden_dim // 2
|
| 48 |
+
self.dec_hidden_dim = self.hidden_dim
|
| 49 |
+
|
| 50 |
+
self.use_pos = args.get('pos', False)
|
| 51 |
+
self.pos_dim = args.get('pos_dim', 0)
|
| 52 |
+
self.pos_vocab_size = args.get('pos_vocab_size', 0)
|
| 53 |
+
self.pos_dropout = args.get('pos_dropout', 0)
|
| 54 |
+
self.edit = args.get('edit', False)
|
| 55 |
+
self.num_edit = args.get('num_edit', 0)
|
| 56 |
+
self.copy = args.get('copy', False)
|
| 57 |
+
|
| 58 |
+
self.emb_drop = nn.Dropout(self.emb_dropout)
|
| 59 |
+
self.drop = nn.Dropout(self.dropout)
|
| 60 |
+
self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
|
| 61 |
+
self.input_dim = self.emb_dim
|
| 62 |
+
if self.contextual_embedding is not None:
|
| 63 |
+
self.input_dim += self.contextual_embedding.hidden_dim()
|
| 64 |
+
self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
|
| 65 |
+
bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)
|
| 66 |
+
self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \
|
| 67 |
+
batch_first=True, attn_type=self.args['attn_type'])
|
| 68 |
+
self.dec2vocab = nn.Linear(self.dec_hidden_dim, self.vocab_size)
|
| 69 |
+
if self.use_pos and self.pos_dim > 0:
|
| 70 |
+
logger.debug("Using POS in encoder")
|
| 71 |
+
self.pos_embedding = nn.Embedding(self.pos_vocab_size, self.pos_dim, self.pad_token)
|
| 72 |
+
self.pos_drop = nn.Dropout(self.pos_dropout)
|
| 73 |
+
if self.edit:
|
| 74 |
+
edit_hidden = self.hidden_dim//2
|
| 75 |
+
self.edit_clf = nn.Sequential(
|
| 76 |
+
nn.Linear(self.hidden_dim, edit_hidden),
|
| 77 |
+
nn.ReLU(),
|
| 78 |
+
nn.Linear(edit_hidden, self.num_edit))
|
| 79 |
+
|
| 80 |
+
if self.copy:
|
| 81 |
+
self.copy_gate = nn.Linear(self.dec_hidden_dim, 1)
|
| 82 |
+
|
| 83 |
+
SOS_tensor = torch.LongTensor([constant.SOS_ID])
|
| 84 |
+
self.register_buffer('SOS_tensor', SOS_tensor)
|
| 85 |
+
|
| 86 |
+
self.init_weights()
|
| 87 |
+
|
| 88 |
+
def add_unsaved_module(self, name, module):
|
| 89 |
+
self.unsaved_modules += [name]
|
| 90 |
+
setattr(self, name, module)
|
| 91 |
+
|
| 92 |
+
def init_weights(self):
|
| 93 |
+
# initialize embeddings
|
| 94 |
+
init_range = constant.EMB_INIT_RANGE
|
| 95 |
+
if self.emb_matrix is not None:
|
| 96 |
+
if isinstance(self.emb_matrix, np.ndarray):
|
| 97 |
+
self.emb_matrix = torch.from_numpy(self.emb_matrix)
|
| 98 |
+
assert self.emb_matrix.size() == (self.vocab_size, self.emb_dim), \
|
| 99 |
+
"Input embedding matrix must match size: {} x {}".format(self.vocab_size, self.emb_dim)
|
| 100 |
+
self.embedding.weight.data.copy_(self.emb_matrix)
|
| 101 |
+
else:
|
| 102 |
+
self.embedding.weight.data.uniform_(-init_range, init_range)
|
| 103 |
+
# decide finetuning
|
| 104 |
+
if self.top <= 0:
|
| 105 |
+
logger.debug("Do not finetune embedding layer.")
|
| 106 |
+
self.embedding.weight.requires_grad = False
|
| 107 |
+
elif self.top < self.vocab_size:
|
| 108 |
+
logger.debug("Finetune top {} embeddings.".format(self.top))
|
| 109 |
+
self.embedding.weight.register_hook(lambda x: utils.keep_partial_grad(x, self.top))
|
| 110 |
+
else:
|
| 111 |
+
logger.debug("Finetune all embeddings.")
|
| 112 |
+
# initialize pos embeddings
|
| 113 |
+
if self.use_pos:
|
| 114 |
+
self.pos_embedding.weight.data.uniform_(-init_range, init_range)
|
| 115 |
+
|
| 116 |
+
def zero_state(self, inputs):
|
| 117 |
+
batch_size = inputs.size(0)
|
| 118 |
+
device = self.SOS_tensor.device
|
| 119 |
+
h0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
|
| 120 |
+
c0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
|
| 121 |
+
return h0, c0
|
| 122 |
+
|
| 123 |
+
def encode(self, enc_inputs, lens):
    """Run the bidirectional encoder over a padded batch.

    Returns the padded per-step outputs together with the final (hn, cn)
    states, where the last layer's forward and backward states are
    concatenated along the feature axis.
    """
    initial_state = self.zero_state(enc_inputs)

    packed = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
    packed_out, (hn, cn) = self.encoder(packed, initial_state)
    h_in, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
    # merge the final forward (-1) and backward (-2) direction states
    hn = torch.cat((hn[-1], hn[-2]), 1)
    cn = torch.cat((cn[-1], cn[-2]), 1)
    return h_in, (hn, cn)
def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None, never_decode_unk=False):
    """ Decode a step, based on context encoding and source context states."""
    # Returns (log_probs, dec_hidden).  When self.copy is enabled, the vocab
    # distribution is mixed in log space with a copy distribution over the
    # source positions (gated by self.copy_gate); in that case the last axis
    # of log_probs may grow past vocab_size to make room for source ids that
    # were unseen at training time.
    dec_hidden = (hn, cn)
    decoder_output = self.decoder(dec_inputs, dec_hidden, ctx, ctx_mask, return_logattn=self.copy)
    if self.copy:
        h_out, dec_hidden, log_attn = decoder_output
    else:
        h_out, dec_hidden = decoder_output

    # project decoder states to vocab logits, flattening (batch, steps) first
    h_out_reshape = h_out.contiguous().view(h_out.size(0) * h_out.size(1), -1)
    decoder_logits = self.dec2vocab(h_out_reshape)
    decoder_logits = decoder_logits.view(h_out.size(0), h_out.size(1), -1)
    log_probs = self.get_log_prob(decoder_logits)

    if self.copy:
        copy_logit = self.copy_gate(h_out)
        if self.use_pos:
            # can't copy the UPOS
            log_attn = log_attn[:, :, 1:]

        # renormalize
        log_attn = torch.log_softmax(log_attn, -1)
        # calculate copy probability for each word in the vocab
        log_copy_prob = torch.nn.functional.logsigmoid(copy_logit) + log_attn
        # scatter logsumexp
        # (subtract the per-row max so the exp below is numerically safe)
        mx = log_copy_prob.max(-1, keepdim=True)[0]
        log_copy_prob = log_copy_prob - mx
        # here we make space in the log probs for vocab items
        # which might be copied from the encoder side, but which
        # were not known at training time
        # note that such an item cannot possibly be predicted by
        # the model as a raw output token
        # however, the copy gate might score high on copying a
        # previously unknown vocab item
        copy_prob = torch.exp(log_copy_prob)
        copied_vocab_shape = list(log_probs.size())
        if torch.max(src) >= copied_vocab_shape[-1]:
            copied_vocab_shape[-1] = torch.max(src) + 1
        copied_vocab_prob = log_probs.new_zeros(copied_vocab_shape)
        scattered_copy = src.unsqueeze(1).expand(src.size(0), copy_prob.size(1), src.size(1))
        # fill in the copy tensor with the copy probs of each character
        # the rest of the copy tensor will be filled with -largenumber
        copied_vocab_prob = copied_vocab_prob.scatter_add(-1, scattered_copy, copy_prob)
        zero_mask = (copied_vocab_prob == 0)
        # back into log space, restoring the max that was subtracted above;
        # vocab items that received no copy mass get -1e12 (effectively -inf)
        log_copied_vocab_prob = torch.log(copied_vocab_prob.masked_fill(zero_mask, 1e-12)) + mx
        log_copied_vocab_prob = log_copied_vocab_prob.masked_fill(zero_mask, -1e12)

        # combine with normal vocab probability
        # log(1 - sigmoid(x)) == -log(1 + exp(x))
        log_nocopy_prob = -torch.log(1 + torch.exp(copy_logit))
        if log_probs.shape[-1] < copied_vocab_shape[-1]:
            # for previously unknown vocab items which are in the encoder,
            # we reuse the UNK_ID prediction
            # this gives a baseline number which we can combine with
            # the copy gate prediction
            # technically this makes log_probs no longer represent
            # a probability distribution when looking at unknown vocab
            # this is probably not a serious problem
            # an example of this usage is in the Lemmatizer, such as a
            # plural word in English with the character "ã" in it instead of "a"
            # if "ã" is not known in the training data, the lemmatizer would
            # ordinarily be unable to output it, and thus the seq2seq model
            # would have no chance to depluralize "ãntennae" -> "ãntenna"
            # however, if we temporarily add "ã" to the encoder vocab,
            # then let the copy gate accept that letter, we find the Lemmatizer
            # seq2seq model will want to copy that particular vocab item
            # this allows the Lemmatizer to produce "ã" instead of requiring
            # that it produces UNK, then going back to the input text to
            # figure out which UNK it intended to produce
            new_log_probs = log_probs.new_zeros(copied_vocab_shape)
            new_log_probs[:, :, :log_probs.shape[-1]] = log_probs
            new_log_probs[:, :, log_probs.shape[-1]:] = new_log_probs[:, :, UNK_ID].unsqueeze(2)
            log_probs = new_log_probs
        log_probs = log_probs + log_nocopy_prob
        log_probs = torch.logsumexp(torch.stack([log_copied_vocab_prob, log_probs]), 0)

    if never_decode_unk:
        log_probs[:, :, UNK_ID] = float("-inf")
    return log_probs, dec_hidden
def embed(self, src, src_mask, pos, raw):
    """Build the encoder inputs for a batch.

    Replaces out-of-vocab ids in ``src`` with UNK before embedding, optionally
    prepends a POS embedding as an extra first "token" (extending the mask to
    match), and optionally concatenates contextual features along the feature
    axis.  Returns (enc_inputs, batch_size, src_lens, src_mask).
    """
    embed_src = src.clone()
    # ids >= vocab_size can appear via the copy mechanism; embed them as UNK
    embed_src[embed_src >= self.vocab_size] = UNK_ID
    enc_inputs = self.emb_drop(self.embedding(embed_src))
    batch_size = enc_inputs.size(0)
    if self.use_pos:
        assert pos is not None, "Missing POS input for seq2seq lemmatizer."
        pos_inputs = self.pos_drop(self.pos_embedding(pos))
        # POS goes in front as a pseudo-token; its mask column is zeros
        enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)
        pos_src_mask = src_mask.new_zeros([batch_size, 1])
        src_mask = torch.cat([pos_src_mask, src_mask], dim=1)
    if raw is not None and self.contextual_embedding is not None:
        raw_inputs = self.contextual_embedding(raw)
        if self.use_pos:
            # pad contextual features by one position to match the POS column
            # NOTE(review): the zero column is appended at the END while the
            # POS pseudo-token was prepended at the FRONT -- confirm this
            # alignment is intended
            raw_zeros = raw_inputs.new_zeros((raw_inputs.shape[0], 1, raw_inputs.shape[2]))
            raw_inputs = torch.cat([raw_inputs, raw_zeros], dim=1)
        enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2)
    # NOTE(review): lengths count positions where src_mask == PAD_ID; this
    # assumes the mask marks real tokens with PAD_ID (0) -- verify at callers
    src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
    return enc_inputs, batch_size, src_lens, src_mask
def forward(self, src, src_mask, tgt_in, pos=None, raw=None):
    """Teacher-forced training pass.

    Embeds and encodes the source, optionally classifies an edit operation
    from the final encoder state, then decodes the gold target prefix.
    Returns (log_probs, edit_logits); edit_logits is None when the edit
    classifier is disabled.
    """
    enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
    h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

    edit_logits = self.edit_clf(hn) if self.edit else None

    dec_inputs = self.emb_drop(self.embedding(tgt_in))
    log_probs, _ = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src)
    return log_probs, edit_logits
def get_log_prob(self, logits):
    """Apply log-softmax over the vocabulary axis, preserving the input's shape."""
    flat = logits.view(-1, self.vocab_size)
    log_probs = F.log_softmax(flat, dim=1)
    if logits.dim() != 2:
        log_probs = log_probs.view(logits.size(0), logits.size(1), logits.size(2))
    return log_probs
def predict_greedy(self, src, src_mask, pos=None, raw=None, never_decode_unk=False):
    """ Predict with greedy decoding. """
    # Returns (output_seqs, edit_logits): one token-id list per batch item
    # (EOS is consumed, not emitted), plus the edit classifier logits
    # (None when editing is disabled).  Decoding stops when every sequence
    # has produced EOS or max_dec_len steps have run.
    enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

    # encode source
    h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

    if self.edit:
        edit_logits = self.edit_clf(hn)
    else:
        edit_logits = None

    # greedy decode by step, starting every sequence from the SOS embedding
    dec_inputs = self.embedding(self.SOS_tensor)
    dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1))

    done = [False for _ in range(batch_size)]
    total_done = 0
    max_len = 0
    output_seqs = [[] for _ in range(batch_size)]

    while total_done < batch_size and max_len < self.max_dec_len:
        log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
        assert log_probs.size(1) == 1, "Output must have 1-step of output."
        _, preds = log_probs.squeeze(1).max(1, keepdim=True)
        # if a unlearned character is predicted via the copy mechanism,
        # use the UNK embedding for it
        dec_inputs = preds.clone()
        dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
        dec_inputs = self.embedding(dec_inputs) # update decoder inputs
        max_len += 1
        for i in range(batch_size):
            if not done[i]:
                token = preds.data[i][0].item()
                if token == constant.EOS_ID:
                    done[i] = True
                    total_done += 1
                else:
                    output_seqs[i].append(token)
    return output_seqs, edit_logits
def predict(self, src, src_mask, pos=None, beam_size=5, raw=None, never_decode_unk=False):
    """ Predict with beam search. """
    # Returns (all_hyp, edit_logits): the best hypothesis (list of token ids)
    # per batch item, plus edit classifier logits (None if editing is off).
    # beam_size == 1 short-circuits to greedy decoding.
    if beam_size == 1:
        return self.predict_greedy(src, src_mask, pos, raw, never_decode_unk=never_decode_unk)

    enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

    # (1) encode source
    h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

    if self.edit:
        edit_logits = self.edit_clf(hn)
    else:
        edit_logits = None

    # (2) set up beam
    # tile encoder outputs / hidden states beam_size times so every beam
    # hypothesis decodes against the same source encoding
    with torch.no_grad():
        h_in = h_in.data.repeat(beam_size, 1, 1) # repeat data for beam search
        src_mask = src_mask.repeat(beam_size, 1)
        # repeat decoder hidden states
        hn = hn.data.repeat(beam_size, 1)
        cn = cn.data.repeat(beam_size, 1)
    device = self.SOS_tensor.device
    beam = [Beam(beam_size, device) for _ in range(batch_size)]

    def update_state(states, idx, positions, beam_size):
        """ Select the states according to back pointers. """
        # mutates the (hn, cn) tensors in place through a view, reordering the
        # beam dimension for batch item idx to follow the surviving hypotheses
        for e in states:
            br, d = e.size()
            s = e.contiguous().view(beam_size, br // beam_size, d)[:,idx]
            s.data.copy_(s.data.index_select(0, positions))

    # (3) main loop
    for i in range(self.max_dec_len):
        # current frontier tokens of every beam, flattened to [beam*batch, 1]
        dec_inputs = torch.stack([b.get_current_state() for b in beam]).t().contiguous().view(-1, 1)
        # if a unlearned character is predicted via the copy mechanism,
        # use the UNK embedding for it
        dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
        dec_inputs = self.embedding(dec_inputs)
        log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
        log_probs = log_probs.view(beam_size, batch_size, -1).transpose(0,1).contiguous() # [batch, beam, V]

        # advance each beam
        done = []
        for b in range(batch_size):
            is_done = beam[b].advance(log_probs.data[b])
            if is_done:
                done += [b]
            # update beam state
            update_state((hn, cn), b, beam[b].get_current_origin(), beam_size)

        if len(done) == batch_size:
            break

    # back trace and find hypothesis
    all_hyp, all_scores = [], []
    for b in range(batch_size):
        scores, ks = beam[b].sort_best()
        all_scores += [scores[0]]
        k = ks[0]
        hyp = beam[b].get_hyp(k)
        hyp = utils.prune_hyp(hyp)
        hyp = [i.item() for i in hyp]
        all_hyp += [hyp]

    return all_hyp, edit_logits
stanza/stanza/models/common/seq2seq_utils.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utils for seq2seq models.
|
| 3 |
+
"""
|
| 4 |
+
from collections import Counter
|
| 5 |
+
import random
|
| 6 |
+
import json
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 10 |
+
|
| 11 |
+
# torch utils
|
| 12 |
+
def get_optimizer(name, parameters, lr):
    """Construct a torch optimizer by name.

    'sgd' and 'adagrad' use the supplied *lr*; 'adam' and 'adamax'
    deliberately keep the torch defaults (original behavior preserved).

    Raises ValueError for unknown names.  (Previously raised a bare
    Exception, which callers could not catch narrowly; ValueError is a
    subclass of Exception so existing handlers still work.)
    """
    if name == 'sgd':
        return torch.optim.SGD(parameters, lr=lr)
    elif name == 'adagrad':
        return torch.optim.Adagrad(parameters, lr=lr)
    elif name == 'adam':
        return torch.optim.Adam(parameters) # use default lr
    elif name == 'adamax':
        return torch.optim.Adamax(parameters) # use default lr
    else:
        raise ValueError("Unsupported optimizer: {}".format(name))
def change_lr(optimizer, new_lr):
    """Set the learning rate of every param group in *optimizer* to *new_lr*."""
    for group in optimizer.param_groups:
        group['lr'] = new_lr
def flatten_indices(seq_lens, width):
    """Map per-row valid positions onto flat offsets row * width + col.

    For each row i with length seq_lens[i], emits i*width + j for every
    j in range(seq_lens[i]).
    """
    return [i * width + j for i, l in enumerate(seq_lens) for j in range(l)]
def keep_partial_grad(grad, topk):
    """Zero out every row of *grad* past the first *topk*, in place, and return it.

    Used as a gradient hook so only the top-k embedding rows get finetuned.
    """
    n_rows = grad.size(0)
    assert topk < n_rows
    grad.data[topk:].zero_()
    return grad
def save_config(config, path, verbose=True):
    """Write *config* to *path* as indented JSON and return the config."""
    with open(path, 'w') as fout:
        json.dump(config, fout, indent=2)
    if verbose:
        print("Config saved to file {}".format(path))
    return config
def load_config(path, verbose=True):
    """Read and return a JSON config from *path*."""
    with open(path) as fin:
        config = json.load(fin)
    if verbose:
        print("Config loaded from file {}".format(path))
    return config
def unmap_with_copy(indices, src_tokens, vocab):
    """
    Unmap a list of list of indices, by optionally copying from src_tokens.

    Non-negative ids are looked up in vocab.id2word; a negative id -k-1
    denotes a copy of source token k.
    """
    result = []
    for ind, tokens in zip(indices, src_tokens):
        words = [vocab.id2word[idx] if idx >= 0 else tokens[-idx - 1]
                 for idx in ind]
        result.append(words)
    return result
def prune_decoded_seqs(seqs):
    """
    Prune decoded sequences after EOS token.

    Each sequence containing constant.EOS is truncated just before its first
    occurrence; sequences without EOS pass through unchanged.
    """
    out = []
    for s in seqs:
        if constant.EOS in s:
            # BUGFIX: previously indexed with constant.EOS_TOKEN, a name that
            # does not exist in seq2seq_constant -- any sequence actually
            # containing EOS would raise AttributeError.  Use the same
            # constant.EOS that the membership test uses.
            out.append(s[:s.index(constant.EOS)])
        else:
            out.append(s)
    return out
def prune_hyp(hyp):
    """
    Prune a decoded hypothesis: truncate just before the first EOS id,
    or return the hypothesis unchanged when no EOS is present.
    """
    # EAFP: list.index raises ValueError when EOS_ID is absent
    try:
        return hyp[:hyp.index(constant.EOS_ID)]
    except ValueError:
        return hyp
def prune(data_list, lens):
    """Truncate each sequence in data_list to its paired length in lens."""
    assert len(data_list) == len(lens)
    return [d[:l] for d, l in zip(data_list, lens)]
def sort(packed, ref, reverse=True):
    """
    Sort a series of packed list, according to a ref list.
    Also return the original index before the sort.

    Prepends the ref values and original indices as key columns, sorts the
    rows, then transposes back; the returned tuple is (orig_idx, *sorted
    versions of the packed lists) with the ref column dropped.
    """
    assert isinstance(packed, (tuple, list)) and isinstance(ref, list)
    columns = [ref, range(len(ref))] + list(packed)
    rows = sorted(zip(*columns), reverse=reverse)
    transposed = [list(col) for col in zip(*rows)]
    return tuple(transposed[1:])
def unsort(sorted_list, oidx):
    """
    Unsort a sorted list, based on the original idx.

    oidx[i] is the original position of sorted_list[i]; the result places
    each element back at its original position.
    """
    assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices."
    # BUGFIX: the previous tuple-unpacking of zip(*sorted(...)) raised
    # ValueError on empty input; building the list directly returns []
    return [item for _, item in sorted(zip(oidx, sorted_list))]
|
stanza/stanza/models/common/short_name_to_treebank.py
ADDED
|
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module is autogenerated by build_short_name_to_treebank.py
|
| 2 |
+
# Please do not edit
|
| 3 |
+
|
| 4 |
+
SHORT_NAMES = {
|
| 5 |
+
'abq_atb': 'UD_Abaza-ATB',
|
| 6 |
+
'ab_abnc': 'UD_Abkhaz-AbNC',
|
| 7 |
+
'af_afribooms': 'UD_Afrikaans-AfriBooms',
|
| 8 |
+
'akk_pisandub': 'UD_Akkadian-PISANDUB',
|
| 9 |
+
'akk_riao': 'UD_Akkadian-RIAO',
|
| 10 |
+
'aqz_tudet': 'UD_Akuntsu-TuDeT',
|
| 11 |
+
'sq_staf': 'UD_Albanian-STAF',
|
| 12 |
+
'sq_tsa': 'UD_Albanian-TSA',
|
| 13 |
+
'am_att': 'UD_Amharic-ATT',
|
| 14 |
+
'grc_proiel': 'UD_Ancient_Greek-PROIEL',
|
| 15 |
+
'grc_ptnk': 'UD_Ancient_Greek-PTNK',
|
| 16 |
+
'grc_perseus': 'UD_Ancient_Greek-Perseus',
|
| 17 |
+
'hbo_ptnk': 'UD_Ancient_Hebrew-PTNK',
|
| 18 |
+
'apu_ufpa': 'UD_Apurina-UFPA',
|
| 19 |
+
'ar_nyuad': 'UD_Arabic-NYUAD',
|
| 20 |
+
'ar_padt': 'UD_Arabic-PADT',
|
| 21 |
+
'ar_pud': 'UD_Arabic-PUD',
|
| 22 |
+
'hy_armtdp': 'UD_Armenian-ArmTDP',
|
| 23 |
+
'hy_bsut': 'UD_Armenian-BSUT',
|
| 24 |
+
'aii_as': 'UD_Assyrian-AS',
|
| 25 |
+
'az_tuecl': 'UD_Azerbaijani-TueCL',
|
| 26 |
+
'bm_crb': 'UD_Bambara-CRB',
|
| 27 |
+
'eu_bdt': 'UD_Basque-BDT',
|
| 28 |
+
'bar_maibaam': 'UD_Bavarian-MaiBaam',
|
| 29 |
+
'bej_autogramm': 'UD_Beja-Autogramm',
|
| 30 |
+
'be_hse': 'UD_Belarusian-HSE',
|
| 31 |
+
'bn_bru': 'UD_Bengali-BRU',
|
| 32 |
+
'bho_bhtb': 'UD_Bhojpuri-BHTB',
|
| 33 |
+
'bor_bdt': 'UD_Bororo-BDT',
|
| 34 |
+
'br_keb': 'UD_Breton-KEB',
|
| 35 |
+
'bg_btb': 'UD_Bulgarian-BTB',
|
| 36 |
+
'bxr_bdt': 'UD_Buryat-BDT',
|
| 37 |
+
'yue_hk': 'UD_Cantonese-HK',
|
| 38 |
+
'cpg_amgic': 'UD_Cappadocian-AMGiC',
|
| 39 |
+
'cpg_tuecl': 'UD_Cappadocian-TueCL',
|
| 40 |
+
'ca_ancora': 'UD_Catalan-AnCora',
|
| 41 |
+
'ceb_gja': 'UD_Cebuano-GJA',
|
| 42 |
+
'zh-hans_beginner': 'UD_Chinese-Beginner',
|
| 43 |
+
'zh_beginner': 'UD_Chinese-Beginner',
|
| 44 |
+
'zh-hans_cfl': 'UD_Chinese-CFL',
|
| 45 |
+
'zh_cfl': 'UD_Chinese-CFL',
|
| 46 |
+
'zh-hant_gsd': 'UD_Chinese-GSD',
|
| 47 |
+
'zh_gsd': 'UD_Chinese-GSD',
|
| 48 |
+
'zh-hans_gsdsimp': 'UD_Chinese-GSDSimp',
|
| 49 |
+
'zh_gsdsimp': 'UD_Chinese-GSDSimp',
|
| 50 |
+
'zh-hant_hk': 'UD_Chinese-HK',
|
| 51 |
+
'zh_hk': 'UD_Chinese-HK',
|
| 52 |
+
'zh-hant_pud': 'UD_Chinese-PUD',
|
| 53 |
+
'zh_pud': 'UD_Chinese-PUD',
|
| 54 |
+
'zh-hans_patentchar': 'UD_Chinese-PatentChar',
|
| 55 |
+
'zh_patentchar': 'UD_Chinese-PatentChar',
|
| 56 |
+
'ckt_hse': 'UD_Chukchi-HSE',
|
| 57 |
+
'xcl_caval': 'UD_Classical_Armenian-CAVaL',
|
| 58 |
+
'lzh_kyoto': 'UD_Classical_Chinese-Kyoto',
|
| 59 |
+
'lzh_tuecl': 'UD_Classical_Chinese-TueCL',
|
| 60 |
+
'cop_scriptorium': 'UD_Coptic-Scriptorium',
|
| 61 |
+
'hr_set': 'UD_Croatian-SET',
|
| 62 |
+
'cs_cac': 'UD_Czech-CAC',
|
| 63 |
+
'cs_cltt': 'UD_Czech-CLTT',
|
| 64 |
+
'cs_fictree': 'UD_Czech-FicTree',
|
| 65 |
+
'cs_pdt': 'UD_Czech-PDT',
|
| 66 |
+
'cs_pud': 'UD_Czech-PUD',
|
| 67 |
+
'cs_poetry': 'UD_Czech-Poetry',
|
| 68 |
+
'da_ddt': 'UD_Danish-DDT',
|
| 69 |
+
'nl_alpino': 'UD_Dutch-Alpino',
|
| 70 |
+
'nl_lassysmall': 'UD_Dutch-LassySmall',
|
| 71 |
+
'egy_ujaen': 'UD_Egyptian-UJaen',
|
| 72 |
+
'en_atis': 'UD_English-Atis',
|
| 73 |
+
'en_ctetex': 'UD_English-CTeTex',
|
| 74 |
+
'en_eslspok': 'UD_English-ESLSpok',
|
| 75 |
+
'en_ewt': 'UD_English-EWT',
|
| 76 |
+
'en_gentle': 'UD_English-GENTLE',
|
| 77 |
+
'en_gum': 'UD_English-GUM',
|
| 78 |
+
'en_gumreddit': 'UD_English-GUMReddit',
|
| 79 |
+
'en_lines': 'UD_English-LinES',
|
| 80 |
+
'en_pud': 'UD_English-PUD',
|
| 81 |
+
'en_partut': 'UD_English-ParTUT',
|
| 82 |
+
'en_pronouns': 'UD_English-Pronouns',
|
| 83 |
+
'myv_jr': 'UD_Erzya-JR',
|
| 84 |
+
'et_edt': 'UD_Estonian-EDT',
|
| 85 |
+
'et_ewt': 'UD_Estonian-EWT',
|
| 86 |
+
'fo_farpahc': 'UD_Faroese-FarPaHC',
|
| 87 |
+
'fo_oft': 'UD_Faroese-OFT',
|
| 88 |
+
'fi_ftb': 'UD_Finnish-FTB',
|
| 89 |
+
'fi_ood': 'UD_Finnish-OOD',
|
| 90 |
+
'fi_pud': 'UD_Finnish-PUD',
|
| 91 |
+
'fi_tdt': 'UD_Finnish-TDT',
|
| 92 |
+
'fr_fqb': 'UD_French-FQB',
|
| 93 |
+
'fr_gsd': 'UD_French-GSD',
|
| 94 |
+
'fr_pud': 'UD_French-PUD',
|
| 95 |
+
'fr_partut': 'UD_French-ParTUT',
|
| 96 |
+
'fr_parisstories': 'UD_French-ParisStories',
|
| 97 |
+
'fr_rhapsodie': 'UD_French-Rhapsodie',
|
| 98 |
+
'fr_sequoia': 'UD_French-Sequoia',
|
| 99 |
+
'qfn_fame': 'UD_Frisian_Dutch-Fame',
|
| 100 |
+
'gl_ctg': 'UD_Galician-CTG',
|
| 101 |
+
'gl_pud': 'UD_Galician-PUD',
|
| 102 |
+
'gl_treegal': 'UD_Galician-TreeGal',
|
| 103 |
+
'ka_glc': 'UD_Georgian-GLC',
|
| 104 |
+
'de_gsd': 'UD_German-GSD',
|
| 105 |
+
'de_hdt': 'UD_German-HDT',
|
| 106 |
+
'de_lit': 'UD_German-LIT',
|
| 107 |
+
'de_pud': 'UD_German-PUD',
|
| 108 |
+
'aln_gps': 'UD_Gheg-GPS',
|
| 109 |
+
'got_proiel': 'UD_Gothic-PROIEL',
|
| 110 |
+
'el_gdt': 'UD_Greek-GDT',
|
| 111 |
+
'el_gud': 'UD_Greek-GUD',
|
| 112 |
+
'gub_tudet': 'UD_Guajajara-TuDeT',
|
| 113 |
+
'gn_oldtudet': 'UD_Guarani-OldTuDeT',
|
| 114 |
+
'gu_gujtb': 'UD_Gujarati-GujTB',
|
| 115 |
+
'gwi_tuecl': 'UD_Gwichin-TueCL',
|
| 116 |
+
'ht_autogramm': 'UD_Haitian_Creole-Autogramm',
|
| 117 |
+
'ha_northernautogramm': 'UD_Hausa-NorthernAutogramm',
|
| 118 |
+
'ha_southernautogramm': 'UD_Hausa-SouthernAutogramm',
|
| 119 |
+
'he_htb': 'UD_Hebrew-HTB',
|
| 120 |
+
'he_iahltknesset': 'UD_Hebrew-IAHLTknesset',
|
| 121 |
+
'he_iahltwiki': 'UD_Hebrew-IAHLTwiki',
|
| 122 |
+
'azz_itml': 'UD_Highland_Puebla_Nahuatl-ITML',
|
| 123 |
+
'hi_hdtb': 'UD_Hindi-HDTB',
|
| 124 |
+
'hi_pud': 'UD_Hindi-PUD',
|
| 125 |
+
'hit_hittb': 'UD_Hittite-HitTB',
|
| 126 |
+
'hu_szeged': 'UD_Hungarian-Szeged',
|
| 127 |
+
'is_gc': 'UD_Icelandic-GC',
|
| 128 |
+
'is_icepahc': 'UD_Icelandic-IcePaHC',
|
| 129 |
+
'is_modern': 'UD_Icelandic-Modern',
|
| 130 |
+
'is_pud': 'UD_Icelandic-PUD',
|
| 131 |
+
'id_csui': 'UD_Indonesian-CSUI',
|
| 132 |
+
'id_gsd': 'UD_Indonesian-GSD',
|
| 133 |
+
'id_pud': 'UD_Indonesian-PUD',
|
| 134 |
+
'ga_cadhan': 'UD_Irish-Cadhan',
|
| 135 |
+
'ga_idt': 'UD_Irish-IDT',
|
| 136 |
+
'ga_twittirish': 'UD_Irish-TwittIrish',
|
| 137 |
+
'it_isdt': 'UD_Italian-ISDT',
|
| 138 |
+
'it_markit': 'UD_Italian-MarkIT',
|
| 139 |
+
'it_old': 'UD_Italian-Old',
|
| 140 |
+
'it_pud': 'UD_Italian-PUD',
|
| 141 |
+
'it_partut': 'UD_Italian-ParTUT',
|
| 142 |
+
'it_parlamint': 'UD_Italian-ParlaMint',
|
| 143 |
+
'it_postwita': 'UD_Italian-PoSTWITA',
|
| 144 |
+
'it_twittiro': 'UD_Italian-TWITTIRO',
|
| 145 |
+
'it_vit': 'UD_Italian-VIT',
|
| 146 |
+
'it_valico': 'UD_Italian-Valico',
|
| 147 |
+
'ja_bccwj': 'UD_Japanese-BCCWJ',
|
| 148 |
+
'ja_bccwjluw': 'UD_Japanese-BCCWJLUW',
|
| 149 |
+
'ja_gsd': 'UD_Japanese-GSD',
|
| 150 |
+
'ja_gsdluw': 'UD_Japanese-GSDLUW',
|
| 151 |
+
'ja_pud': 'UD_Japanese-PUD',
|
| 152 |
+
'ja_pudluw': 'UD_Japanese-PUDLUW',
|
| 153 |
+
'jv_csui': 'UD_Javanese-CSUI',
|
| 154 |
+
'urb_tudet': 'UD_Kaapor-TuDeT',
|
| 155 |
+
'xnr_kdtb': 'UD_Kangri-KDTB',
|
| 156 |
+
'krl_kkpp': 'UD_Karelian-KKPP',
|
| 157 |
+
'arr_tudet': 'UD_Karo-TuDeT',
|
| 158 |
+
'kk_ktb': 'UD_Kazakh-KTB',
|
| 159 |
+
'kfm_aha': 'UD_Khunsari-AHA',
|
| 160 |
+
'quc_iu': 'UD_Kiche-IU',
|
| 161 |
+
'koi_uh': 'UD_Komi_Permyak-UH',
|
| 162 |
+
'kpv_ikdp': 'UD_Komi_Zyrian-IKDP',
|
| 163 |
+
'kpv_lattice': 'UD_Komi_Zyrian-Lattice',
|
| 164 |
+
'ko_gsd': 'UD_Korean-GSD',
|
| 165 |
+
'ko_ksl': 'UD_Korean-KSL',
|
| 166 |
+
'ko_kaist': 'UD_Korean-Kaist',
|
| 167 |
+
'ko_pud': 'UD_Korean-PUD',
|
| 168 |
+
'kmr_mg': 'UD_Kurmanji-MG',
|
| 169 |
+
'ky_ktmu': 'UD_Kyrgyz-KTMU',
|
| 170 |
+
'ky_tuecl': 'UD_Kyrgyz-TueCL',
|
| 171 |
+
'ltg_cairo': 'UD_Latgalian-Cairo',
|
| 172 |
+
'la_circse': 'UD_Latin-CIRCSE',
|
| 173 |
+
'la_ittb': 'UD_Latin-ITTB',
|
| 174 |
+
'la_llct': 'UD_Latin-LLCT',
|
| 175 |
+
'la_proiel': 'UD_Latin-PROIEL',
|
| 176 |
+
'la_perseus': 'UD_Latin-Perseus',
|
| 177 |
+
'la_udante': 'UD_Latin-UDante',
|
| 178 |
+
'lv_cairo': 'UD_Latvian-Cairo',
|
| 179 |
+
'lv_lvtb': 'UD_Latvian-LVTB',
|
| 180 |
+
'lij_glt': 'UD_Ligurian-GLT',
|
| 181 |
+
'lt_alksnis': 'UD_Lithuanian-ALKSNIS',
|
| 182 |
+
'lt_hse': 'UD_Lithuanian-HSE',
|
| 183 |
+
'olo_kkpp': 'UD_Livvi-KKPP',
|
| 184 |
+
'nds_lsdc': 'UD_Low_Saxon-LSDC',
|
| 185 |
+
'lb_luxbank': 'UD_Luxembourgish-LuxBank',
|
| 186 |
+
'mk_mtb': 'UD_Macedonian-MTB',
|
| 187 |
+
'jaa_jarawara': 'UD_Madi-Jarawara',
|
| 188 |
+
'qaf_arabizi': 'UD_Maghrebi_Arabic_French-Arabizi',
|
| 189 |
+
'mpu_tudet': 'UD_Makurap-TuDeT',
|
| 190 |
+
'ml_ufal': 'UD_Malayalam-UFAL',
|
| 191 |
+
'mt_mudt': 'UD_Maltese-MUDT',
|
| 192 |
+
'gv_cadhan': 'UD_Manx-Cadhan',
|
| 193 |
+
'mr_ufal': 'UD_Marathi-UFAL',
|
| 194 |
+
'gun_dooley': 'UD_Mbya_Guarani-Dooley',
|
| 195 |
+
'gun_thomas': 'UD_Mbya_Guarani-Thomas',
|
| 196 |
+
'frm_profiterole': 'UD_Middle_French-PROFITEROLE',
|
| 197 |
+
'mdf_jr': 'UD_Moksha-JR',
|
| 198 |
+
'myu_tudet': 'UD_Munduruku-TuDeT',
|
| 199 |
+
'pcm_nsc': 'UD_Naija-NSC',
|
| 200 |
+
'nyq_aha': 'UD_Nayini-AHA',
|
| 201 |
+
'nap_rb': 'UD_Neapolitan-RB',
|
| 202 |
+
'yrl_complin': 'UD_Nheengatu-CompLin',
|
| 203 |
+
'sme_giella': 'UD_North_Sami-Giella',
|
| 204 |
+
'gya_autogramm': 'UD_Northwest_Gbaya-Autogramm',
|
| 205 |
+
'nb_bokmaal': 'UD_Norwegian-Bokmaal',
|
| 206 |
+
'no_bokmaal': 'UD_Norwegian-Bokmaal',
|
| 207 |
+
'nn_nynorsk': 'UD_Norwegian-Nynorsk',
|
| 208 |
+
'cu_proiel': 'UD_Old_Church_Slavonic-PROIEL',
|
| 209 |
+
'orv_birchbark': 'UD_Old_East_Slavic-Birchbark',
|
| 210 |
+
'orv_rnc': 'UD_Old_East_Slavic-RNC',
|
| 211 |
+
'orv_ruthenian': 'UD_Old_East_Slavic-Ruthenian',
|
| 212 |
+
'orv_torot': 'UD_Old_East_Slavic-TOROT',
|
| 213 |
+
'fro_profiterole': 'UD_Old_French-PROFITEROLE',
|
| 214 |
+
'sga_dipsgg': 'UD_Old_Irish-DipSGG',
|
| 215 |
+
'sga_dipwbg': 'UD_Old_Irish-DipWBG',
|
| 216 |
+
'otk_clausal': 'UD_Old_Turkish-Clausal',
|
| 217 |
+
'ota_boun': 'UD_Ottoman_Turkish-BOUN',
|
| 218 |
+
'ota_dudu': 'UD_Ottoman_Turkish-DUDU',
|
| 219 |
+
'ps_sikaram': 'UD_Pashto-Sikaram',
|
| 220 |
+
'pad_tuecl': 'UD_Paumari-TueCL',
|
| 221 |
+
'fa_perdt': 'UD_Persian-PerDT',
|
| 222 |
+
'fa_seraji': 'UD_Persian-Seraji',
|
| 223 |
+
'pay_chibergis': 'UD_Pesh-ChibErgIS',
|
| 224 |
+
'xpg_kul': 'UD_Phrygian-KUL',
|
| 225 |
+
'pl_lfg': 'UD_Polish-LFG',
|
| 226 |
+
'pl_pdb': 'UD_Polish-PDB',
|
| 227 |
+
'pl_pud': 'UD_Polish-PUD',
|
| 228 |
+
'qpm_philotis': 'UD_Pomak-Philotis',
|
| 229 |
+
'pt_bosque': 'UD_Portuguese-Bosque',
|
| 230 |
+
'pt_cintil': 'UD_Portuguese-CINTIL',
|
| 231 |
+
'pt_dantestocks': 'UD_Portuguese-DANTEStocks',
|
| 232 |
+
'pt_gsd': 'UD_Portuguese-GSD',
|
| 233 |
+
'pt_pud': 'UD_Portuguese-PUD',
|
| 234 |
+
'pt_petrogold': 'UD_Portuguese-PetroGold',
|
| 235 |
+
'pt_porttinari': 'UD_Portuguese-Porttinari',
|
| 236 |
+
'ro_art': 'UD_Romanian-ArT',
|
| 237 |
+
'ro_nonstandard': 'UD_Romanian-Nonstandard',
|
| 238 |
+
'ro_rrt': 'UD_Romanian-RRT',
|
| 239 |
+
'ro_simonero': 'UD_Romanian-SiMoNERo',
|
| 240 |
+
'ro_tuecl': 'UD_Romanian-TueCL',
|
| 241 |
+
'ru_gsd': 'UD_Russian-GSD',
|
| 242 |
+
'ru_pud': 'UD_Russian-PUD',
|
| 243 |
+
'ru_poetry': 'UD_Russian-Poetry',
|
| 244 |
+
'ru_syntagrus': 'UD_Russian-SynTagRus',
|
| 245 |
+
'ru_taiga': 'UD_Russian-Taiga',
|
| 246 |
+
'sa_ufal': 'UD_Sanskrit-UFAL',
|
| 247 |
+
'sa_vedic': 'UD_Sanskrit-Vedic',
|
| 248 |
+
'gd_arcosg': 'UD_Scottish_Gaelic-ARCOSG',
|
| 249 |
+
'sr_set': 'UD_Serbian-SET',
|
| 250 |
+
'si_stb': 'UD_Sinhala-STB',
|
| 251 |
+
'sms_giellagas': 'UD_Skolt_Sami-Giellagas',
|
| 252 |
+
'sk_snk': 'UD_Slovak-SNK',
|
| 253 |
+
'sl_ssj': 'UD_Slovenian-SSJ',
|
| 254 |
+
'sl_sst': 'UD_Slovenian-SST',
|
| 255 |
+
'soj_aha': 'UD_Soi-AHA',
|
| 256 |
+
'ajp_madar': 'UD_South_Levantine_Arabic-MADAR',
|
| 257 |
+
'es_ancora': 'UD_Spanish-AnCora',
|
| 258 |
+
'es_coser': 'UD_Spanish-COSER',
|
| 259 |
+
'es_gsd': 'UD_Spanish-GSD',
|
| 260 |
+
'es_pud': 'UD_Spanish-PUD',
|
| 261 |
+
'ssp_lse': 'UD_Spanish_Sign_Language-LSE',
|
| 262 |
+
'sv_lines': 'UD_Swedish-LinES',
|
| 263 |
+
'sv_pud': 'UD_Swedish-PUD',
|
| 264 |
+
'sv_talbanken': 'UD_Swedish-Talbanken',
|
| 265 |
+
'swl_sslc': 'UD_Swedish_Sign_Language-SSLC',
|
| 266 |
+
'gsw_uzh': 'UD_Swiss_German-UZH',
|
| 267 |
+
'tl_trg': 'UD_Tagalog-TRG',
|
| 268 |
+
'tl_ugnayan': 'UD_Tagalog-Ugnayan',
|
| 269 |
+
'ta_mwtt': 'UD_Tamil-MWTT',
|
| 270 |
+
'ta_ttb': 'UD_Tamil-TTB',
|
| 271 |
+
'tt_nmctt': 'UD_Tatar-NMCTT',
|
| 272 |
+
'eme_tudet': 'UD_Teko-TuDeT',
|
| 273 |
+
'te_mtg': 'UD_Telugu-MTG',
|
| 274 |
+
'qte_tect': 'UD_Telugu_English-TECT',
|
| 275 |
+
'th_pud': 'UD_Thai-PUD',
|
| 276 |
+
'tn_popapolelo': 'UD_Tswana-Popapolelo',
|
| 277 |
+
'tpn_tudet': 'UD_Tupinamba-TuDeT',
|
| 278 |
+
'tr_atis': 'UD_Turkish-Atis',
|
| 279 |
+
'tr_boun': 'UD_Turkish-BOUN',
|
| 280 |
+
'tr_framenet': 'UD_Turkish-FrameNet',
|
| 281 |
+
'tr_gb': 'UD_Turkish-GB',
|
| 282 |
+
'tr_imst': 'UD_Turkish-IMST',
|
| 283 |
+
'tr_kenet': 'UD_Turkish-Kenet',
|
| 284 |
+
'tr_pud': 'UD_Turkish-PUD',
|
| 285 |
+
'tr_penn': 'UD_Turkish-Penn',
|
| 286 |
+
'tr_tourism': 'UD_Turkish-Tourism',
|
| 287 |
+
'qtd_sagt': 'UD_Turkish_German-SAGT',
|
| 288 |
+
'uk_iu': 'UD_Ukrainian-IU',
|
| 289 |
+
'uk_parlamint': 'UD_Ukrainian-ParlaMint',
|
| 290 |
+
'xum_ikuvina': 'UD_Umbrian-IKUVINA',
|
| 291 |
+
'hsb_ufal': 'UD_Upper_Sorbian-UFAL',
|
| 292 |
+
'ur_udtb': 'UD_Urdu-UDTB',
|
| 293 |
+
'ug_udt': 'UD_Uyghur-UDT',
|
| 294 |
+
'uz_ut': 'UD_Uzbek-UT',
|
| 295 |
+
'vep_vwt': 'UD_Veps-VWT',
|
| 296 |
+
'vi_tuecl': 'UD_Vietnamese-TueCL',
|
| 297 |
+
'vi_vtb': 'UD_Vietnamese-VTB',
|
| 298 |
+
'wbp_ufal': 'UD_Warlpiri-UFAL',
|
| 299 |
+
'cy_ccg': 'UD_Welsh-CCG',
|
| 300 |
+
'hyw_armtdp': 'UD_Western_Armenian-ArmTDP',
|
| 301 |
+
'nhi_itml': 'UD_Western_Sierra_Puebla_Nahuatl-ITML',
|
| 302 |
+
'wo_wtb': 'UD_Wolof-WTB',
|
| 303 |
+
'xav_xdt': 'UD_Xavante-XDT',
|
| 304 |
+
'sjo_xdt': 'UD_Xibe-XDT',
|
| 305 |
+
'sah_yktdt': 'UD_Yakut-YKTDT',
|
| 306 |
+
'yo_ytb': 'UD_Yoruba-YTB',
|
| 307 |
+
'ess_sli': 'UD_Yupik-SLI',
|
| 308 |
+
'say_autogramm': 'UD_Zaar-Autogramm',
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def short_name_to_treebank(short_name):
    """Map a short name such as 'en_ewt' to its full UD treebank name.

    Raises KeyError if the short name is unknown.
    """
    treebank = SHORT_NAMES[short_name]
    return treebank
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# Maps the lowercased form of a UD treebank directory name (for example
# 'ud_english-ewt', as produced by calling .lower() on 'UD_English-EWT')
# back to the canonically capitalized treebank name.  This supports
# case-insensitive lookup of treebank names that are not short names.
CANONICAL_NAMES = {
    'ud_abaza-atb': 'UD_Abaza-ATB',
    'ud_abkhaz-abnc': 'UD_Abkhaz-AbNC',
    'ud_afrikaans-afribooms': 'UD_Afrikaans-AfriBooms',
    'ud_akkadian-pisandub': 'UD_Akkadian-PISANDUB',
    'ud_akkadian-riao': 'UD_Akkadian-RIAO',
    'ud_akuntsu-tudet': 'UD_Akuntsu-TuDeT',
    'ud_albanian-staf': 'UD_Albanian-STAF',
    'ud_albanian-tsa': 'UD_Albanian-TSA',
    'ud_amharic-att': 'UD_Amharic-ATT',
    'ud_ancient_greek-proiel': 'UD_Ancient_Greek-PROIEL',
    'ud_ancient_greek-ptnk': 'UD_Ancient_Greek-PTNK',
    'ud_ancient_greek-perseus': 'UD_Ancient_Greek-Perseus',
    'ud_ancient_hebrew-ptnk': 'UD_Ancient_Hebrew-PTNK',
    'ud_apurina-ufpa': 'UD_Apurina-UFPA',
    'ud_arabic-nyuad': 'UD_Arabic-NYUAD',
    'ud_arabic-padt': 'UD_Arabic-PADT',
    'ud_arabic-pud': 'UD_Arabic-PUD',
    'ud_armenian-armtdp': 'UD_Armenian-ArmTDP',
    'ud_armenian-bsut': 'UD_Armenian-BSUT',
    'ud_assyrian-as': 'UD_Assyrian-AS',
    'ud_azerbaijani-tuecl': 'UD_Azerbaijani-TueCL',
    'ud_bambara-crb': 'UD_Bambara-CRB',
    'ud_basque-bdt': 'UD_Basque-BDT',
    'ud_bavarian-maibaam': 'UD_Bavarian-MaiBaam',
    'ud_beja-autogramm': 'UD_Beja-Autogramm',
    'ud_belarusian-hse': 'UD_Belarusian-HSE',
    'ud_bengali-bru': 'UD_Bengali-BRU',
    'ud_bhojpuri-bhtb': 'UD_Bhojpuri-BHTB',
    'ud_bororo-bdt': 'UD_Bororo-BDT',
    'ud_breton-keb': 'UD_Breton-KEB',
    'ud_bulgarian-btb': 'UD_Bulgarian-BTB',
    'ud_buryat-bdt': 'UD_Buryat-BDT',
    'ud_cantonese-hk': 'UD_Cantonese-HK',
    'ud_cappadocian-amgic': 'UD_Cappadocian-AMGiC',
    'ud_cappadocian-tuecl': 'UD_Cappadocian-TueCL',
    'ud_catalan-ancora': 'UD_Catalan-AnCora',
    'ud_cebuano-gja': 'UD_Cebuano-GJA',
    'ud_chinese-beginner': 'UD_Chinese-Beginner',
    'ud_chinese-cfl': 'UD_Chinese-CFL',
    'ud_chinese-gsd': 'UD_Chinese-GSD',
    'ud_chinese-gsdsimp': 'UD_Chinese-GSDSimp',
    'ud_chinese-hk': 'UD_Chinese-HK',
    'ud_chinese-pud': 'UD_Chinese-PUD',
    'ud_chinese-patentchar': 'UD_Chinese-PatentChar',
    'ud_chukchi-hse': 'UD_Chukchi-HSE',
    'ud_classical_armenian-caval': 'UD_Classical_Armenian-CAVaL',
    'ud_classical_chinese-kyoto': 'UD_Classical_Chinese-Kyoto',
    'ud_classical_chinese-tuecl': 'UD_Classical_Chinese-TueCL',
    'ud_coptic-scriptorium': 'UD_Coptic-Scriptorium',
    'ud_croatian-set': 'UD_Croatian-SET',
    'ud_czech-cac': 'UD_Czech-CAC',
    'ud_czech-cltt': 'UD_Czech-CLTT',
    'ud_czech-fictree': 'UD_Czech-FicTree',
    'ud_czech-pdt': 'UD_Czech-PDT',
    'ud_czech-pud': 'UD_Czech-PUD',
    'ud_czech-poetry': 'UD_Czech-Poetry',
    'ud_danish-ddt': 'UD_Danish-DDT',
    'ud_dutch-alpino': 'UD_Dutch-Alpino',
    'ud_dutch-lassysmall': 'UD_Dutch-LassySmall',
    'ud_egyptian-ujaen': 'UD_Egyptian-UJaen',
    'ud_english-atis': 'UD_English-Atis',
    'ud_english-ctetex': 'UD_English-CTeTex',
    'ud_english-eslspok': 'UD_English-ESLSpok',
    'ud_english-ewt': 'UD_English-EWT',
    'ud_english-gentle': 'UD_English-GENTLE',
    'ud_english-gum': 'UD_English-GUM',
    'ud_english-gumreddit': 'UD_English-GUMReddit',
    'ud_english-lines': 'UD_English-LinES',
    'ud_english-pud': 'UD_English-PUD',
    'ud_english-partut': 'UD_English-ParTUT',
    'ud_english-pronouns': 'UD_English-Pronouns',
    'ud_erzya-jr': 'UD_Erzya-JR',
    'ud_estonian-edt': 'UD_Estonian-EDT',
    'ud_estonian-ewt': 'UD_Estonian-EWT',
    'ud_faroese-farpahc': 'UD_Faroese-FarPaHC',
    'ud_faroese-oft': 'UD_Faroese-OFT',
    'ud_finnish-ftb': 'UD_Finnish-FTB',
    'ud_finnish-ood': 'UD_Finnish-OOD',
    'ud_finnish-pud': 'UD_Finnish-PUD',
    'ud_finnish-tdt': 'UD_Finnish-TDT',
    'ud_french-fqb': 'UD_French-FQB',
    'ud_french-gsd': 'UD_French-GSD',
    'ud_french-pud': 'UD_French-PUD',
    'ud_french-partut': 'UD_French-ParTUT',
    'ud_french-parisstories': 'UD_French-ParisStories',
    'ud_french-rhapsodie': 'UD_French-Rhapsodie',
    'ud_french-sequoia': 'UD_French-Sequoia',
    'ud_frisian_dutch-fame': 'UD_Frisian_Dutch-Fame',
    'ud_galician-ctg': 'UD_Galician-CTG',
    'ud_galician-pud': 'UD_Galician-PUD',
    'ud_galician-treegal': 'UD_Galician-TreeGal',
    'ud_georgian-glc': 'UD_Georgian-GLC',
    'ud_german-gsd': 'UD_German-GSD',
    'ud_german-hdt': 'UD_German-HDT',
    'ud_german-lit': 'UD_German-LIT',
    'ud_german-pud': 'UD_German-PUD',
    'ud_gheg-gps': 'UD_Gheg-GPS',
    'ud_gothic-proiel': 'UD_Gothic-PROIEL',
    'ud_greek-gdt': 'UD_Greek-GDT',
    'ud_greek-gud': 'UD_Greek-GUD',
    'ud_guajajara-tudet': 'UD_Guajajara-TuDeT',
    'ud_guarani-oldtudet': 'UD_Guarani-OldTuDeT',
    'ud_gujarati-gujtb': 'UD_Gujarati-GujTB',
    'ud_gwichin-tuecl': 'UD_Gwichin-TueCL',
    'ud_haitian_creole-autogramm': 'UD_Haitian_Creole-Autogramm',
    'ud_hausa-northernautogramm': 'UD_Hausa-NorthernAutogramm',
    'ud_hausa-southernautogramm': 'UD_Hausa-SouthernAutogramm',
    'ud_hebrew-htb': 'UD_Hebrew-HTB',
    'ud_hebrew-iahltknesset': 'UD_Hebrew-IAHLTknesset',
    'ud_hebrew-iahltwiki': 'UD_Hebrew-IAHLTwiki',
    'ud_highland_puebla_nahuatl-itml': 'UD_Highland_Puebla_Nahuatl-ITML',
    'ud_hindi-hdtb': 'UD_Hindi-HDTB',
    'ud_hindi-pud': 'UD_Hindi-PUD',
    'ud_hittite-hittb': 'UD_Hittite-HitTB',
    'ud_hungarian-szeged': 'UD_Hungarian-Szeged',
    'ud_icelandic-gc': 'UD_Icelandic-GC',
    'ud_icelandic-icepahc': 'UD_Icelandic-IcePaHC',
    'ud_icelandic-modern': 'UD_Icelandic-Modern',
    'ud_icelandic-pud': 'UD_Icelandic-PUD',
    'ud_indonesian-csui': 'UD_Indonesian-CSUI',
    'ud_indonesian-gsd': 'UD_Indonesian-GSD',
    'ud_indonesian-pud': 'UD_Indonesian-PUD',
    'ud_irish-cadhan': 'UD_Irish-Cadhan',
    'ud_irish-idt': 'UD_Irish-IDT',
    'ud_irish-twittirish': 'UD_Irish-TwittIrish',
    'ud_italian-isdt': 'UD_Italian-ISDT',
    'ud_italian-markit': 'UD_Italian-MarkIT',
    'ud_italian-old': 'UD_Italian-Old',
    'ud_italian-pud': 'UD_Italian-PUD',
    'ud_italian-partut': 'UD_Italian-ParTUT',
    'ud_italian-parlamint': 'UD_Italian-ParlaMint',
    'ud_italian-postwita': 'UD_Italian-PoSTWITA',
    'ud_italian-twittiro': 'UD_Italian-TWITTIRO',
    'ud_italian-vit': 'UD_Italian-VIT',
    'ud_italian-valico': 'UD_Italian-Valico',
    'ud_japanese-bccwj': 'UD_Japanese-BCCWJ',
    'ud_japanese-bccwjluw': 'UD_Japanese-BCCWJLUW',
    'ud_japanese-gsd': 'UD_Japanese-GSD',
    'ud_japanese-gsdluw': 'UD_Japanese-GSDLUW',
    'ud_japanese-pud': 'UD_Japanese-PUD',
    'ud_japanese-pudluw': 'UD_Japanese-PUDLUW',
    'ud_javanese-csui': 'UD_Javanese-CSUI',
    'ud_kaapor-tudet': 'UD_Kaapor-TuDeT',
    'ud_kangri-kdtb': 'UD_Kangri-KDTB',
    'ud_karelian-kkpp': 'UD_Karelian-KKPP',
    'ud_karo-tudet': 'UD_Karo-TuDeT',
    'ud_kazakh-ktb': 'UD_Kazakh-KTB',
    'ud_khunsari-aha': 'UD_Khunsari-AHA',
    'ud_kiche-iu': 'UD_Kiche-IU',
    'ud_komi_permyak-uh': 'UD_Komi_Permyak-UH',
    'ud_komi_zyrian-ikdp': 'UD_Komi_Zyrian-IKDP',
    'ud_komi_zyrian-lattice': 'UD_Komi_Zyrian-Lattice',
    'ud_korean-gsd': 'UD_Korean-GSD',
    'ud_korean-ksl': 'UD_Korean-KSL',
    'ud_korean-kaist': 'UD_Korean-Kaist',
    'ud_korean-pud': 'UD_Korean-PUD',
    'ud_kurmanji-mg': 'UD_Kurmanji-MG',
    'ud_kyrgyz-ktmu': 'UD_Kyrgyz-KTMU',
    'ud_kyrgyz-tuecl': 'UD_Kyrgyz-TueCL',
    'ud_latgalian-cairo': 'UD_Latgalian-Cairo',
    'ud_latin-circse': 'UD_Latin-CIRCSE',
    'ud_latin-ittb': 'UD_Latin-ITTB',
    'ud_latin-llct': 'UD_Latin-LLCT',
    'ud_latin-proiel': 'UD_Latin-PROIEL',
    'ud_latin-perseus': 'UD_Latin-Perseus',
    'ud_latin-udante': 'UD_Latin-UDante',
    'ud_latvian-cairo': 'UD_Latvian-Cairo',
    'ud_latvian-lvtb': 'UD_Latvian-LVTB',
    'ud_ligurian-glt': 'UD_Ligurian-GLT',
    'ud_lithuanian-alksnis': 'UD_Lithuanian-ALKSNIS',
    'ud_lithuanian-hse': 'UD_Lithuanian-HSE',
    'ud_livvi-kkpp': 'UD_Livvi-KKPP',
    'ud_low_saxon-lsdc': 'UD_Low_Saxon-LSDC',
    'ud_luxembourgish-luxbank': 'UD_Luxembourgish-LuxBank',
    'ud_macedonian-mtb': 'UD_Macedonian-MTB',
    'ud_madi-jarawara': 'UD_Madi-Jarawara',
    'ud_maghrebi_arabic_french-arabizi': 'UD_Maghrebi_Arabic_French-Arabizi',
    'ud_makurap-tudet': 'UD_Makurap-TuDeT',
    'ud_malayalam-ufal': 'UD_Malayalam-UFAL',
    'ud_maltese-mudt': 'UD_Maltese-MUDT',
    'ud_manx-cadhan': 'UD_Manx-Cadhan',
    'ud_marathi-ufal': 'UD_Marathi-UFAL',
    'ud_mbya_guarani-dooley': 'UD_Mbya_Guarani-Dooley',
    'ud_mbya_guarani-thomas': 'UD_Mbya_Guarani-Thomas',
    'ud_middle_french-profiterole': 'UD_Middle_French-PROFITEROLE',
    'ud_moksha-jr': 'UD_Moksha-JR',
    'ud_munduruku-tudet': 'UD_Munduruku-TuDeT',
    'ud_naija-nsc': 'UD_Naija-NSC',
    'ud_nayini-aha': 'UD_Nayini-AHA',
    'ud_neapolitan-rb': 'UD_Neapolitan-RB',
    'ud_nheengatu-complin': 'UD_Nheengatu-CompLin',
    'ud_north_sami-giella': 'UD_North_Sami-Giella',
    'ud_northwest_gbaya-autogramm': 'UD_Northwest_Gbaya-Autogramm',
    'ud_norwegian-bokmaal': 'UD_Norwegian-Bokmaal',
    'ud_norwegian-nynorsk': 'UD_Norwegian-Nynorsk',
    'ud_old_church_slavonic-proiel': 'UD_Old_Church_Slavonic-PROIEL',
    'ud_old_east_slavic-birchbark': 'UD_Old_East_Slavic-Birchbark',
    'ud_old_east_slavic-rnc': 'UD_Old_East_Slavic-RNC',
    'ud_old_east_slavic-ruthenian': 'UD_Old_East_Slavic-Ruthenian',
    'ud_old_east_slavic-torot': 'UD_Old_East_Slavic-TOROT',
    'ud_old_french-profiterole': 'UD_Old_French-PROFITEROLE',
    'ud_old_irish-dipsgg': 'UD_Old_Irish-DipSGG',
    'ud_old_irish-dipwbg': 'UD_Old_Irish-DipWBG',
    'ud_old_turkish-clausal': 'UD_Old_Turkish-Clausal',
    'ud_ottoman_turkish-boun': 'UD_Ottoman_Turkish-BOUN',
    'ud_ottoman_turkish-dudu': 'UD_Ottoman_Turkish-DUDU',
    'ud_pashto-sikaram': 'UD_Pashto-Sikaram',
    'ud_paumari-tuecl': 'UD_Paumari-TueCL',
    'ud_persian-perdt': 'UD_Persian-PerDT',
    'ud_persian-seraji': 'UD_Persian-Seraji',
    'ud_pesh-chibergis': 'UD_Pesh-ChibErgIS',
    'ud_phrygian-kul': 'UD_Phrygian-KUL',
    'ud_polish-lfg': 'UD_Polish-LFG',
    'ud_polish-pdb': 'UD_Polish-PDB',
    'ud_polish-pud': 'UD_Polish-PUD',
    'ud_pomak-philotis': 'UD_Pomak-Philotis',
    'ud_portuguese-bosque': 'UD_Portuguese-Bosque',
    'ud_portuguese-cintil': 'UD_Portuguese-CINTIL',
    'ud_portuguese-dantestocks': 'UD_Portuguese-DANTEStocks',
    'ud_portuguese-gsd': 'UD_Portuguese-GSD',
    'ud_portuguese-pud': 'UD_Portuguese-PUD',
    'ud_portuguese-petrogold': 'UD_Portuguese-PetroGold',
    'ud_portuguese-porttinari': 'UD_Portuguese-Porttinari',
    'ud_romanian-art': 'UD_Romanian-ArT',
    'ud_romanian-nonstandard': 'UD_Romanian-Nonstandard',
    'ud_romanian-rrt': 'UD_Romanian-RRT',
    'ud_romanian-simonero': 'UD_Romanian-SiMoNERo',
    'ud_romanian-tuecl': 'UD_Romanian-TueCL',
    'ud_russian-gsd': 'UD_Russian-GSD',
    'ud_russian-pud': 'UD_Russian-PUD',
    'ud_russian-poetry': 'UD_Russian-Poetry',
    'ud_russian-syntagrus': 'UD_Russian-SynTagRus',
    'ud_russian-taiga': 'UD_Russian-Taiga',
    'ud_sanskrit-ufal': 'UD_Sanskrit-UFAL',
    'ud_sanskrit-vedic': 'UD_Sanskrit-Vedic',
    'ud_scottish_gaelic-arcosg': 'UD_Scottish_Gaelic-ARCOSG',
    'ud_serbian-set': 'UD_Serbian-SET',
    'ud_sinhala-stb': 'UD_Sinhala-STB',
    'ud_skolt_sami-giellagas': 'UD_Skolt_Sami-Giellagas',
    'ud_slovak-snk': 'UD_Slovak-SNK',
    'ud_slovenian-ssj': 'UD_Slovenian-SSJ',
    'ud_slovenian-sst': 'UD_Slovenian-SST',
    'ud_soi-aha': 'UD_Soi-AHA',
    'ud_south_levantine_arabic-madar': 'UD_South_Levantine_Arabic-MADAR',
    'ud_spanish-ancora': 'UD_Spanish-AnCora',
    'ud_spanish-coser': 'UD_Spanish-COSER',
    'ud_spanish-gsd': 'UD_Spanish-GSD',
    'ud_spanish-pud': 'UD_Spanish-PUD',
    'ud_spanish_sign_language-lse': 'UD_Spanish_Sign_Language-LSE',
    'ud_swedish-lines': 'UD_Swedish-LinES',
    'ud_swedish-pud': 'UD_Swedish-PUD',
    'ud_swedish-talbanken': 'UD_Swedish-Talbanken',
    'ud_swedish_sign_language-sslc': 'UD_Swedish_Sign_Language-SSLC',
    'ud_swiss_german-uzh': 'UD_Swiss_German-UZH',
    'ud_tagalog-trg': 'UD_Tagalog-TRG',
    'ud_tagalog-ugnayan': 'UD_Tagalog-Ugnayan',
    'ud_tamil-mwtt': 'UD_Tamil-MWTT',
    'ud_tamil-ttb': 'UD_Tamil-TTB',
    'ud_tatar-nmctt': 'UD_Tatar-NMCTT',
    'ud_teko-tudet': 'UD_Teko-TuDeT',
    'ud_telugu-mtg': 'UD_Telugu-MTG',
    'ud_telugu_english-tect': 'UD_Telugu_English-TECT',
    'ud_thai-pud': 'UD_Thai-PUD',
    'ud_tswana-popapolelo': 'UD_Tswana-Popapolelo',
    'ud_tupinamba-tudet': 'UD_Tupinamba-TuDeT',
    'ud_turkish-atis': 'UD_Turkish-Atis',
    'ud_turkish-boun': 'UD_Turkish-BOUN',
    'ud_turkish-framenet': 'UD_Turkish-FrameNet',
    'ud_turkish-gb': 'UD_Turkish-GB',
    'ud_turkish-imst': 'UD_Turkish-IMST',
    'ud_turkish-kenet': 'UD_Turkish-Kenet',
    'ud_turkish-pud': 'UD_Turkish-PUD',
    'ud_turkish-penn': 'UD_Turkish-Penn',
    'ud_turkish-tourism': 'UD_Turkish-Tourism',
    'ud_turkish_german-sagt': 'UD_Turkish_German-SAGT',
    'ud_ukrainian-iu': 'UD_Ukrainian-IU',
    'ud_ukrainian-parlamint': 'UD_Ukrainian-ParlaMint',
    'ud_umbrian-ikuvina': 'UD_Umbrian-IKUVINA',
    'ud_upper_sorbian-ufal': 'UD_Upper_Sorbian-UFAL',
    'ud_urdu-udtb': 'UD_Urdu-UDTB',
    'ud_uyghur-udt': 'UD_Uyghur-UDT',
    'ud_uzbek-ut': 'UD_Uzbek-UT',
    'ud_veps-vwt': 'UD_Veps-VWT',
    'ud_vietnamese-tuecl': 'UD_Vietnamese-TueCL',
    'ud_vietnamese-vtb': 'UD_Vietnamese-VTB',
    'ud_warlpiri-ufal': 'UD_Warlpiri-UFAL',
    'ud_welsh-ccg': 'UD_Welsh-CCG',
    'ud_western_armenian-armtdp': 'UD_Western_Armenian-ArmTDP',
    'ud_western_sierra_puebla_nahuatl-itml': 'UD_Western_Sierra_Puebla_Nahuatl-ITML',
    'ud_wolof-wtb': 'UD_Wolof-WTB',
    'ud_xavante-xdt': 'UD_Xavante-XDT',
    'ud_xibe-xdt': 'UD_Xibe-XDT',
    'ud_yakut-yktdt': 'UD_Yakut-YKTDT',
    'ud_yoruba-ytb': 'UD_Yoruba-YTB',
    'ud_yupik-sli': 'UD_Yupik-SLI',
    'ud_zaar-autogramm': 'UD_Zaar-Autogramm',
}
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def canonical_treebank_name(ud_name):
    """Normalize a treebank identifier to its canonical UD_... form.

    Short names (such as 'en_ewt') are resolved via SHORT_NAMES; otherwise
    the argument is matched case-insensitively against the known canonical
    names.  Completely unknown names are returned unchanged.
    """
    try:
        return SHORT_NAMES[ud_name]
    except KeyError:
        pass
    return CANONICAL_NAMES.get(ud_name.lower(), ud_name)
|
stanza/stanza/models/common/trainer.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
class Trainer:
    """Base class for model trainers.

    Subclasses are expected to provide ``self.model``, ``self.optimizer``,
    and ``self.args`` before using these helpers.
    """

    def change_lr(self, new_lr):
        """Set the learning rate of every optimizer param group to new_lr."""
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

    def save(self, filename):
        """Serialize the model and optimizer state dicts to filename."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(checkpoint, filename)

    def load(self, filename):
        """Restore model weights from filename.

        The optimizer state is restored only when running in 'train' mode.
        """
        # map everything onto CPU storage; weights_only guards deserialization
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)

        self.model.load_state_dict(checkpoint['model'])
        if self.args['mode'] == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer'])
|
stanza/stanza/models/common/utils.py
ADDED
|
@@ -0,0 +1,816 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from collections import Counter
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
import gzip
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import lzma
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
import re
|
| 15 |
+
import sys
|
| 16 |
+
import unicodedata
|
| 17 |
+
import zipfile
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
from stanza.models.common.constant import lcode2lang
|
| 23 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 24 |
+
from stanza.resources.default_packages import TRANSFORMER_NICKNAMES
|
| 25 |
+
import stanza.utils.conll18_ud_eval as ud_eval
|
| 26 |
+
from stanza.utils.conll18_ud_eval import UDError
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger('stanza')
|
| 29 |
+
|
| 30 |
+
# filenames
|
| 31 |
+
def get_wordvec_file(wordvec_dir, shorthand, wordvec_type=None):
    """ Lookup the name of the word vectors file, given a directory and the language shorthand.
    """
    lcode, _ = shorthand.split('_', 1)
    lang = lcode2lang[lcode]
    # candidate language folders, in preference order
    word2vec_dir = os.path.join(wordvec_dir, 'word2vec', lang)
    fasttext_dir = os.path.join(wordvec_dir, 'fasttext', lang)
    if wordvec_type is not None:
        # an explicitly requested type must exist
        lang_dir = os.path.join(wordvec_dir, wordvec_type, lang)
        if not os.path.exists(lang_dir):
            raise FileNotFoundError("Word vector type {} was specified, but directory {} does not exist".format(wordvec_type, lang_dir))
    elif os.path.exists(word2vec_dir):
        lang_dir = word2vec_dir
    elif os.path.exists(fasttext_dir):
        lang_dir = fasttext_dir
    else:
        raise FileNotFoundError("Cannot locate word vector directory for language: {} Looked in {} and {}".format(lang, word2vec_dir, fasttext_dir))
    # prefer compressed vectors if present, then plain text, else the bare name
    base = os.path.join(lang_dir, '{}.vectors'.format(lcode))
    for suffix in (".xz", ".txt"):
        if os.path.exists(base + suffix):
            return base + suffix
    return base
|
| 57 |
+
|
| 58 |
+
@contextmanager
def output_stream(filename=None):
    """
    Context manager yielding an output stream.

    Yields an opened utf-8 file when a filename is given (closed on exit),
    or sys.stdout when filename is None.
    """
    if filename is not None:
        with open(filename, "w", encoding="utf-8") as fout:
            yield fout
    else:
        yield sys.stdout
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@contextmanager
def open_read_text(filename, encoding="utf-8"):
    """
    Context manager that opens a text file, transparently handling compression.

    Files ending in .xz or .gz are decompressed on the fly; anything else is
    read as plain text.  The file is closed when the context exits.

    eg:
      with open_read_text(filename) as fin:
        do stuff
    """
    if filename.endswith(".xz"):
        fin = lzma.open(filename, mode='rt', encoding=encoding)
    elif filename.endswith(".gz"):
        fin = gzip.open(filename, mode='rt', encoding=encoding)
    else:
        fin = open(filename, encoding=encoding)
    with fin:
        yield fin
|
| 94 |
+
|
| 95 |
+
@contextmanager
def open_read_binary(filename):
    """
    Opens a file as an .xz file or .gz if it ends with .xz or .gz, or regular binary file otherwise.

    If a .zip file is given, it can be read if there is a single file in there

    Use as a context

    eg:
      with open_read_binary(filename) as fin:
        do stuff

    File will be closed once the context exits

    Raises ValueError for a zip archive which is empty or has multiple members.
    """
    if filename.endswith(".xz"):
        with lzma.open(filename, mode='rb') as fin:
            yield fin
    elif filename.endswith(".gz"):
        with gzip.open(filename, mode='rb') as fin:
            yield fin
    elif filename.endswith(".zip"):
        with zipfile.ZipFile(filename) as zin:
            input_names = zin.namelist()
            if len(input_names) == 0:
                raise ValueError("Empty zip archive")
            if len(input_names) > 1:
                # bug fix: the %s placeholder was previously never filled in,
                # so the error message did not say which file was at fault
                raise ValueError("zip file %s has more than one file in it" % filename)
            with zin.open(input_names[0]) as fin:
                yield fin
    else:
        with open(filename, mode='rb') as fin:
            yield fin
|
| 128 |
+
|
| 129 |
+
# training schedule
|
| 130 |
+
def get_adaptive_eval_interval(cur_dev_size, thres_dev_size, base_interval):
    """ Adjust the evaluation interval adaptively.
    If cur_dev_size <= thres_dev_size, return base_interval;
    else, linearly increase the interval (round to integer times of base interval).
    """
    if cur_dev_size <= thres_dev_size:
        return base_interval
    # scale the interval by how much larger the dev set is than the threshold
    multiplier = round(cur_dev_size / thres_dev_size)
    return base_interval * multiplier
|
| 140 |
+
|
| 141 |
+
# ud utils
|
| 142 |
+
def ud_scores(gold_conllu_file, system_conllu_file):
    """Run the conll18 UD evaluation of a system file against a gold file.

    Returns the evaluation dictionary produced by ud_eval.evaluate.
    Raises UDError, naming the offending file, if either input cannot be read.
    """
    def load(path):
        # wrap the parse error so the message says which file failed
        try:
            return ud_eval.load_conllu_file(path)
        except UDError as e:
            raise UDError("Could not read %s" % path) from e

    gold_ud = load(gold_conllu_file)
    system_ud = load(system_conllu_file)
    return ud_eval.evaluate(gold_ud, system_ud)
|
| 155 |
+
|
| 156 |
+
def harmonic_mean(a, weights=None):
    """Harmonic mean of the values in a, optionally weighted.

    Returns 0 if any value is 0 (the limit of the harmonic mean).
    """
    if any(value == 0 for value in a):
        return 0
    assert weights is None or len(weights) == len(a), 'Weights has length {} which is different from that of the array ({}).'.format(len(weights), len(a))
    if weights is None:
        return len(a) / sum(1 / value for value in a)
    return sum(weights) / sum(w / value for value, w in zip(a, weights))
|
| 165 |
+
|
| 166 |
+
# torch utils
|
| 167 |
+
def dispatch_optimizer(name, parameters, opt_logger, lr=None, betas=None, eps=None, momentum=None, **extra_args):
    """Build an optimizer over ``parameters`` selected by name.

    Supported names: amsgrad, amsgradw, sgd, adagrad, adam, adamw, adamax,
    adadelta, adabelief (requires adabelief-pytorch), madgrad and
    mirror_madgrad (require the madgrad package).  Any other name raises
    ValueError.  Extra keyword args are forwarded to the optimizer and
    included in the debug logging.
    """
    if extra_args:
        extra_logging = ", " + ", ".join("%s=%s" % (k, v) for k, v in extra_args.items())
    else:
        extra_logging = ""

    if name == 'amsgrad':
        opt_logger.debug("Building Adam w/ amsgrad with lr=%f, betas=%s, eps=%f%s", lr, betas, eps, extra_logging)
        return torch.optim.Adam(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args)
    if name == 'amsgradw':
        opt_logger.debug("Building AdamW w/ amsgrad with lr=%f, betas=%s, eps=%f%s", lr, betas, eps, extra_logging)
        return torch.optim.AdamW(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args)
    if name == 'sgd':
        opt_logger.debug("Building SGD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return torch.optim.SGD(parameters, lr=lr, momentum=momentum, **extra_args)
    if name == 'adagrad':
        opt_logger.debug("Building Adagrad with lr=%f%s", lr, extra_logging)
        return torch.optim.Adagrad(parameters, lr=lr, **extra_args)
    if name == 'adam':
        opt_logger.debug("Building Adam with lr=%f, betas=%s, eps=%f%s", lr, betas, eps, extra_logging)
        return torch.optim.Adam(parameters, lr=lr, betas=betas, eps=eps, **extra_args)
    if name == 'adamw':
        opt_logger.debug("Building AdamW with lr=%f, betas=%s, eps=%f%s", lr, betas, eps, extra_logging)
        return torch.optim.AdamW(parameters, lr=lr, betas=betas, eps=eps, **extra_args)
    if name == 'adamax':
        # deliberately uses the torch default learning rate
        opt_logger.debug("Building Adamax%s", extra_logging)
        return torch.optim.Adamax(parameters, **extra_args)
    if name == 'adadelta':
        opt_logger.debug("Building Adadelta with lr=%f%s", lr, extra_logging)
        return torch.optim.Adadelta(parameters, lr=lr, **extra_args)
    if name == 'adabelief':
        try:
            from adabelief_pytorch import AdaBelief
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create adabelief optimizer. Perhaps the adabelief-pytorch package is not installed") from e
        opt_logger.debug("Building AdaBelief with lr=%f, eps=%f%s", lr, eps, extra_logging)
        # TODO: add weight_decouple and rectify as extra args?
        return AdaBelief(parameters, lr=lr, eps=eps, weight_decouple=True, rectify=True, **extra_args)
    if name == 'madgrad':
        try:
            import madgrad
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create madgrad optimizer. Perhaps the madgrad package is not installed") from e
        opt_logger.debug("Building MADGRAD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return madgrad.MADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)
    if name == 'mirror_madgrad':
        try:
            import madgrad
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create mirror_madgrad optimizer. Perhaps the madgrad package is not installed") from e
        opt_logger.debug("Building MirrorMADGRAD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return madgrad.MirrorMADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)
    raise ValueError("Unsupported optimizer: {}".format(name))
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def get_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0, weight_decay=None, bert_learning_rate=0.0, bert_weight_decay=None, charlm_learning_rate=0.0, is_peft=False, bert_finetune_layers=None, opt_logger=None):
    """
    Build a single optimizer over the model's parameters, split into param groups.

    Parameters are divided into three named groups:
      - 'base': everything except the transformer (bert_model.*) and charlm weights
      - 'charlm': charmodel_forward.* / charmodel_backward.*, only added when
        charlm_learning_rate > 0 (lr is a multiplier on the base lr)
      - 'bert': transformer weights, only added when bert_learning_rate > 0;
        optionally restricted to the last bert_finetune_layers layers and
        optionally given its own weight decay

    The actual optimizer class is chosen by `name` via dispatch_optimizer.
    """
    opt_logger = opt_logger if opt_logger is not None else logger
    base_parameters = [p for n, p in model.named_parameters()
                       if p.requires_grad and not n.startswith("bert_model.")
                       and not n.startswith("charmodel_forward.") and not n.startswith("charmodel_backward.")]
    parameters = [{'param_group_name': 'base', 'params': base_parameters}]

    charlm_parameters = [p for n, p in model.named_parameters()
                         if p.requires_grad and (n.startswith("charmodel_forward.") or n.startswith("charmodel_backward."))]
    if len(charlm_parameters) > 0 and charlm_learning_rate > 0:
        parameters.append({'param_group_name': 'charlm', 'params': charlm_parameters, 'lr': lr * charlm_learning_rate})

    if not is_peft:
        bert_parameters = [p for n, p in model.named_parameters() if p.requires_grad and n.startswith("bert_model.")]

        # bert_finetune_layers limits the bert finetuning to the *last* N layers of the model
        if len(bert_parameters) > 0 and bert_finetune_layers is not None:
            num_layers = model.bert_model.config.num_hidden_layers
            start_layer = num_layers - bert_finetune_layers
            bert_parameters = []
            for layer_num in range(start_layer, num_layers):
                # matches parameter names by the "layer.N." substring convention
                bert_parameters.extend([param for name, param in model.named_parameters()
                                        if param.requires_grad and name.startswith("bert_model.") and "layer.%d." % layer_num in name])

        if len(bert_parameters) > 0 and bert_learning_rate > 0:
            opt_logger.debug("Finetuning %d bert parameters with LR %s and WD %s", len(bert_parameters), lr * bert_learning_rate, bert_weight_decay)
            parameters.append({'param_group_name': 'bert', 'params': bert_parameters, 'lr': lr * bert_learning_rate})
            if bert_weight_decay is not None:
                parameters[-1]['weight_decay'] = bert_weight_decay
    else:
        # some optimizers seem to train some even with a learning rate of 0...
        if bert_learning_rate > 0:
            # because PEFT handles what to hand to an optimizer, we don't want to touch that
            parameters.append({'param_group_name': 'bert', 'params': model.bert_model.parameters(), 'lr': lr * bert_learning_rate})
            if bert_weight_decay is not None:
                parameters[-1]['weight_decay'] = bert_weight_decay

    extra_args = {}
    if weight_decay is not None:
        extra_args["weight_decay"] = weight_decay

    return dispatch_optimizer(name, parameters, opt_logger=opt_logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
|
| 264 |
+
|
| 265 |
+
def get_split_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0, weight_decay=None, bert_learning_rate=0.0, bert_weight_decay=None, charlm_learning_rate=0.0, is_peft=False, bert_finetune_layers=None):
    """Same as `get_optimizer`, but splits the optimizer for Bert into a separate optimizer

    Returns a dict with key "general_optimizer" and, when transformer
    finetuning is active (bert params present and bert_learning_rate > 0),
    an additional "bert_optimizer".
    """
    base_parameters = [p for n, p in model.named_parameters()
                       if p.requires_grad and not n.startswith("bert_model.")
                       and not n.startswith("charmodel_forward.") and not n.startswith("charmodel_backward.")]
    parameters = [{'param_group_name': 'base', 'params': base_parameters}]

    charlm_parameters = [p for n, p in model.named_parameters()
                         if p.requires_grad and (n.startswith("charmodel_forward.") or n.startswith("charmodel_backward."))]
    if len(charlm_parameters) > 0 and charlm_learning_rate > 0:
        # charlm lr is a multiplier on the base lr
        parameters.append({'param_group_name': 'charlm', 'params': charlm_parameters, 'lr': lr * charlm_learning_rate})

    bert_parameters = None
    if not is_peft:
        trainable_parameters = [p for n, p in model.named_parameters() if p.requires_grad and n.startswith("bert_model.")]

        # bert_finetune_layers limits the bert finetuning to the *last* N layers of the model
        if len(trainable_parameters) > 0 and bert_finetune_layers is not None:
            num_layers = model.bert_model.config.num_hidden_layers
            start_layer = num_layers - bert_finetune_layers
            trainable_parameters = []
            for layer_num in range(start_layer, num_layers):
                trainable_parameters.extend([param for name, param in model.named_parameters()
                                             if param.requires_grad and name.startswith("bert_model.") and "layer.%d." % layer_num in name])

        if len(trainable_parameters) > 0:
            bert_parameters = [{'param_group_name': 'bert', 'params': trainable_parameters, 'lr': lr * bert_learning_rate}]
    else:
        # because PEFT handles what to hand to an optimizer, we don't want to touch that
        bert_parameters = [{'param_group_name': 'bert', 'params': model.bert_model.parameters(), 'lr': lr * bert_learning_rate}]

    extra_args = {}
    if weight_decay is not None:
        extra_args["weight_decay"] = weight_decay

    optimizers = {
        "general_optimizer": dispatch_optimizer(name, parameters, opt_logger=logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
    }
    if bert_parameters is not None and bert_learning_rate > 0.0:
        # NOTE: extra_args is deliberately reused here; weight_decay is
        # overridden with bert_weight_decay for the bert optimizer only
        if bert_weight_decay is not None:
            extra_args['weight_decay'] = bert_weight_decay
        optimizers["bert_optimizer"] = dispatch_optimizer(name, bert_parameters, opt_logger=logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
    return optimizers
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def change_lr(optimizer, new_lr):
    """Set the learning rate of every param group in the optimizer to new_lr."""
    for group in optimizer.param_groups:
        group['lr'] = new_lr
|
| 313 |
+
|
| 314 |
+
def flatten_indices(seq_lens, width):
    """Return flat indices row*width+col for each row of length l (col < l)."""
    return [row * width + col
            for row, length in enumerate(seq_lens)
            for col in range(length)]
|
| 320 |
+
|
| 321 |
+
def keep_partial_grad(grad, topk):
    """
    Keep only the topk rows of grads.

    Zeroes out rows topk..end in place and returns the same tensor.
    """
    assert topk < grad.size(0)
    grad.data[topk:] = 0
    return grad
|
| 328 |
+
|
| 329 |
+
# other utils
|
| 330 |
+
def ensure_dir(d, verbose=True):
    """Create directory d (and parents) if it does not already exist."""
    if os.path.exists(d):
        return
    if verbose:
        logger.info("Directory {} does not exist; creating...".format(d))
    # exist_ok: guard against race conditions
    os.makedirs(d, exist_ok=True)
|
| 336 |
+
|
| 337 |
+
def save_config(config, path, verbose=True):
    """Write config to path as JSON (indent=2); returns the config unchanged."""
    with open(path, 'w') as fout:
        json.dump(config, fout, indent=2)
    if verbose:
        print("Config saved to file {}".format(path))
    return config
|
| 343 |
+
|
| 344 |
+
def load_config(path, verbose=True):
    """Read and return a JSON config from path."""
    with open(path) as fin:
        loaded = json.load(fin)
    if verbose:
        print("Config loaded from file {}".format(path))
    return loaded
|
| 350 |
+
|
| 351 |
+
def print_config(config):
    """Log all key : value pairs of config at INFO level."""
    entries = ["\t{} : {}\n".format(key, str(value)) for key, value in config.items()]
    info = "Running with the following configs:\n" + "".join(entries)
    logger.info("\n" + info + "\n")
|
| 356 |
+
|
| 357 |
+
def normalize_text(text):
    """Return text normalized to Unicode NFD (canonical decomposition)."""
    return unicodedata.normalize('NFD', text)
|
| 359 |
+
|
| 360 |
+
def unmap_with_copy(indices, src_tokens, vocab):
    """
    Unmap a list of list of indices, by optionally copying from src_tokens.

    Non-negative ids are looked up in vocab.id2word; a negative id -k maps to
    source token k-1 of the corresponding sentence.
    """
    result = []
    for sent_ids, sent_tokens in zip(indices, src_tokens):
        words = [vocab.id2word[i] if i >= 0 else sent_tokens[-i - 1]
                 for i in sent_ids]
        result.append(words)
    return result
|
| 375 |
+
|
| 376 |
+
def prune_decoded_seqs(seqs):
    """
    Prune decoded sequences after EOS token.

    Each sequence containing the EOS token is truncated just before its
    first occurrence; other sequences are kept unchanged.
    """
    out = []
    for s in seqs:
        if constant.EOS in s:
            # fix: the index lookup previously used constant.EOS_TOKEN while
            # the membership test used constant.EOS; use the same symbol for both
            idx = s.index(constant.EOS)
            out += [s[:idx]]
        else:
            out += [s]
    return out
|
| 388 |
+
|
| 389 |
+
def prune_hyp(hyp):
    """
    Prune a decoded hypothesis

    Truncates at the first EOS_ID if one is present; otherwise returns hyp as-is.
    """
    try:
        return hyp[:hyp.index(constant.EOS_ID)]
    except ValueError:
        return hyp
|
| 398 |
+
|
| 399 |
+
def prune(data_list, lens):
    """Truncate each element of data_list to the corresponding length in lens."""
    assert len(data_list) == len(lens)
    return [seq[:length] for seq, length in zip(data_list, lens)]
|
| 405 |
+
|
| 406 |
+
def sort(packed, ref, reverse=True):
    """
    Sort a series of packed list, according to a ref list.
    Also return the original index before the sort.
    """
    assert (isinstance(packed, tuple) or isinstance(packed, list)) and isinstance(ref, list)
    # prepend ref (sort key) and the original positions, sort row-wise,
    # then transpose back into columns
    rows = sorted(zip(ref, range(len(ref)), *packed), reverse=reverse)
    columns = [list(col) for col in zip(*rows)]
    # drop the ref column; the first returned column is the original indices
    return tuple(columns[1:])
|
| 415 |
+
|
| 416 |
+
def unsort(sorted_list, oidx):
    """
    Unsort a sorted list, based on the original idx.
    """
    assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices."
    if not sorted_list:
        return []
    pairs = sorted(zip(oidx, sorted_list))
    return [item for _, item in pairs]
|
| 425 |
+
|
| 426 |
+
def sort_with_indices(data, key=None, reverse=False):
    """
    Sort data and return both the data and the original indices.

    One useful application is to sort by length, which can be done with key=len
    Returns the data as a sorted tuple, then the indices of the original list.
    """
    if not data:
        return [], []
    # sort (index, value) pairs by the value (optionally through key)
    if key:
        sort_key = lambda pair: key(pair[1])
    else:
        sort_key = lambda pair: pair[1]
    ordered = sorted(enumerate(data), key=sort_key, reverse=reverse)
    indices, values = zip(*ordered)
    return values, indices
|
| 442 |
+
|
| 443 |
+
def split_into_batches(data, batch_size):
    """
    Returns a list of intervals so that each interval is either <= batch_size or one element long.

    Long elements are not dropped from the intervals.
    data is a list of lists
    batch_size is how long to make each batch
    return value is a list of pairs, start_idx end_idx
    """
    intervals = []
    start = 0
    accumulated = 0
    for idx, item in enumerate(data):
        item_len = len(item)
        if item_len > batch_size:
            # oversized item becomes its own singleton interval;
            # flush any pending interval first
            if accumulated > 0:
                intervals.append((start, idx))
            intervals.append((idx, idx + 1))
            start = idx + 1
            accumulated = 0
        elif accumulated + item_len > batch_size:
            # adding this item would overflow the current batch:
            # close it and begin a new one with this item
            intervals.append((start, idx))
            start = idx
            accumulated = item_len
        else:
            accumulated += item_len
    if accumulated > 0:
        # there's some leftover
        intervals.append((start, len(data)))
    return intervals
|
| 474 |
+
|
| 475 |
+
def tensor_unsort(sorted_tensor, oidx):
    """
    Unsort a sorted tensor on its 0-th dimension, based on the original idx.
    """
    assert sorted_tensor.size(0) == len(oidx), "Number of list elements must match with original indices."
    # argsort of oidx gives, for each original position, where it ended up
    back = sorted(range(len(oidx)), key=lambda pos: oidx[pos])
    return sorted_tensor[back]
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def set_random_seed(seed):
    """
    Set a random seed on all of the things which might need it.
    torch, np, python random, and torch.cuda

    If seed is None, a seed is chosen at random and returned so the
    run can be reproduced later.
    """
    if seed is None:
        seed = random.randint(0, 1000000000)

    # fix: torch.manual_seed was previously called twice; once is enough
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        # manual_seed seeds the current device; manual_seed_all covers multi-GPU
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    return seed
|
| 501 |
+
|
| 502 |
+
def find_missing_tags(known_tags, test_tags):
    """Return the sorted tags from test_tags that do not appear in known_tags.

    Either argument may be a list of lists, which is flattened first."""
    if isinstance(known_tags, list) and isinstance(known_tags[0], list):
        known_tags = {tag for tags in known_tags for tag in tags}
    if isinstance(test_tags, list) and isinstance(test_tags[0], list):
        test_tags = sorted({tag for tags in test_tags for tag in tags})
    return sorted(tag for tag in test_tags if tag not in known_tags)
|
| 509 |
+
|
| 510 |
+
def warn_missing_tags(known_tags, test_tags, test_set_name):
    """
    Print a warning if any tags present in the second list are not in the first list.

    Can also handle a list of lists.

    Returns True if anything was missing, False otherwise.
    """
    missing = find_missing_tags(known_tags, test_tags)
    if not missing:
        return False
    logger.warning("Found tags in {} missing from the expected tag set: {}".format(test_set_name, missing))
    return True
|
| 521 |
+
|
| 522 |
+
def checkpoint_name(save_dir, save_name, checkpoint_name):
    """
    Will return a recommended checkpoint name for the given dir, save_name, optional checkpoint_name

    For example, can pass in args['save_dir'], args['save_name'], args['checkpoint_save_name']
    """
    if checkpoint_name:
        # an explicit checkpoint name wins; anchor it under save_dir
        # unless its directory is already save_dir
        if os.path.split(checkpoint_name)[0] == save_dir:
            return checkpoint_name
        return os.path.join(save_dir, checkpoint_name)

    if os.path.split(save_name)[0] != save_dir:
        save_name = os.path.join(save_dir, save_name)
    # derive a _checkpoint name from the model file name
    if save_name.endswith(".pt"):
        return save_name[:-3] + "_checkpoint.pt"
    return save_name + "_checkpoint"
|
| 541 |
+
|
| 542 |
+
def default_device():
    """
    Pick a default device based on what's available on this system
    """
    return 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 549 |
+
|
| 550 |
+
def add_device_args(parser):
    """
    Add args which specify cpu, cuda, or arbitrary device

    --device takes any torch device string; --cuda and --cpu are
    shortcuts which store a constant into the same 'device' dest,
    so the last one given on the command line wins.
    """
    parser.add_argument('--device', type=str, default=default_device(), help='Which device to run on - use a torch device string name')
    parser.add_argument('--cuda', dest='device', action='store_const', const='cuda', help='Run on CUDA')
    parser.add_argument('--cpu', dest='device', action='store_const', const='cpu', help='Ignore CUDA and run on CPU')
|
| 557 |
+
|
| 558 |
+
def load_elmo(elmo_model):
    """
    Load an elmoformanylangs Embedder for the given model path.

    The import is done here so that Elmo integration can be treated
    as an optional feature.
    """
    import elmoformanylangs

    # fix: use lazy %-style logger args instead of eager string formatting
    logger.info("Loading elmo: %s", elmo_model)
    elmo_model = elmoformanylangs.Embedder(elmo_model)
    return elmo_model
|
| 566 |
+
|
| 567 |
+
def log_training_args(args, args_logger, name="training"):
    """
    For record keeping purposes, log the arguments when training
    """
    if isinstance(args, argparse.Namespace):
        args = vars(args)
    # one "key: value" line per arg, sorted by key for stable output
    log_lines = ['%s: %s' % (key, args[key]) for key in sorted(args)]
    args_logger.info('ARGS USED AT %s TIME:\n%s\n', name.upper(), '\n'.join(log_lines))
|
| 576 |
+
|
| 577 |
+
def embedding_name(args):
    """
    Return the generic name of the biggest embedding used by a model.

    Used by POS and depparse, for example.

    TODO: Probably will make the transformer names a bit more informative,
    such as electra, roberta, etc. Maybe even phobert for VI, for example
    """
    # precedence, lowest to highest: nopretrain < nocharlm < charlm < transformer
    if args['wordvec_pretrain_file'] is None and args['wordvec_file'] is None:
        embedding = "nopretrain"
    else:
        embedding = "nocharlm"
    if args.get('charlm', True) and (args['charlm_forward_file'] or args['charlm_backward_file']):
        embedding = "charlm"
    if args['bert_model']:
        if args['bert_model'] in TRANSFORMER_NICKNAMES:
            embedding = TRANSFORMER_NICKNAMES[args['bert_model']]
        else:
            embedding = "transformer"
    return embedding
|
| 598 |
+
|
| 599 |
+
def standard_model_file_name(args, model_type, **kwargs):
    """
    Returns a model file name based on some common args found in the various models.

    The expectation is that the args will have something like

    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_parser.pt", help="File name to save the model")

    Then the model shorthand, embedding type, and other args will be
    turned into arguments in a format string
    """
    embedding = embedding_name(args)

    finetune = ""
    transformer_lr = ""
    if args.get("bert_finetune", False):
        finetune = "finetuned"
        if "bert_learning_rate" in args:
            transformer_lr = "{}".format(args["bert_learning_rate"])

    use_peft = "nopeft"
    if args.get("bert_finetune", False) and args.get("use_peft", False):
        use_peft = "peft"

    bert_finetuning = ""
    if args.get("bert_finetune", False):
        if args.get("use_peft", False):
            bert_finetuning = "peft"
        else:
            bert_finetuning = "ft"

    seed = args.get('seed', None)
    if seed is None:
        seed = ""
    else:
        seed = str(seed)

    format_args = {
        "batch_size": args['batch_size'],
        "bert_finetuning": bert_finetuning,
        "embedding": embedding,
        "finetune": finetune,
        "peft": use_peft,
        "seed": seed,
        "shorthand": args['shorthand'],
        "transformer_lr": transformer_lr,
    }
    format_args.update(**kwargs)
    model_file = args['save_name'].format(**format_args)
    # collapse runs of _ left behind by empty format fields
    model_file = re.sub("_+", "_", model_file)

    # (fix: removed an unused model_dir = os.path.split(...) computation here)
    # prefer an already existing file outside save_dir; otherwise
    # anchor the file under save_dir
    if not os.path.exists(os.path.join(args['save_dir'], model_file)) and os.path.exists(model_file):
        return model_file
    return os.path.join(args['save_dir'], model_file)
|
| 655 |
+
|
| 656 |
+
def escape_misc_space(space):
    """Escape whitespace and separator characters for storage in a MISC field.

    '|' and '\\' are escaped because they collide with MISC syntax;
    \\u00A0 is the non-breaking space."""
    replacements = {
        ' ': '\\s',
        '\t': '\\t',
        '\r': '\\r',
        '\n': '\\n',
        '|': '\\p',
        '\\': '\\\\',
        '\u00A0': '\\u00A0',
    }
    return "".join(replacements.get(char, char) for char in space)
|
| 677 |
+
|
| 678 |
+
def unescape_misc_space(misc_space):
    """Invert escape_misc_space: decode the MISC escapes back into characters."""
    # all two-character escapes; \u00A0 (non-breaking space) is the
    # single six-character escape and cannot collide with these
    two_char = {
        '\\s': ' ',
        '\\t': '\t',
        '\\r': '\r',
        '\\n': '\n',
        '\\p': '|',
        '\\\\': '\\',
    }
    out = []
    pos = 0
    length = len(misc_space)
    while pos < length:
        chunk = misc_space[pos:pos+2]
        if chunk in two_char:
            out.append(two_char[chunk])
            pos += 2
        elif misc_space[pos:pos+6] == '\\u00A0':
            out.append('\u00A0')
            pos += 6
        else:
            out.append(misc_space[pos])
            pos += 1
    return "".join(out)
|
| 708 |
+
|
| 709 |
+
def space_before_to_misc(space):
    """
    Convert whitespace to SpacesBefore specifically for the start of a document.

    In general, UD datasets do not have both SpacesAfter on a token and SpacesBefore on the next token.

    The space(s) are only marked on one of the tokens.

    Only at the very beginning of a document is it necessary to mark what spaces occurred before the actual text,
    and the default assumption is that there is no space if there is no SpacesBefore annotation.
    """
    if not space:
        return ""
    return "SpacesBefore=%s" % escape_misc_space(space)
|
| 724 |
+
|
| 725 |
+
def space_after_to_misc(space):
    """
    Convert whitespace back to the escaped format - either SpaceAfter=No or SpacesAfter=...
    """
    if not space:
        return "SpaceAfter=No"
    if space == " ":
        # a single plain space is the default and needs no annotation
        return ""
    return "SpacesAfter=%s" % escape_misc_space(space)
|
| 735 |
+
|
| 736 |
+
def misc_to_space_before(misc):
    """
    Find any SpacesBefore annotation in the MISC column and turn it into a space value
    """
    if not misc:
        return ""
    for piece in misc.split("|"):
        # case-insensitive match on the annotation name
        if piece.lower().startswith("spacesbefore="):
            return unescape_misc_space(piece.split("=", maxsplit=1)[1])
    return ""
|
| 749 |
+
|
| 750 |
+
def misc_to_space_after(misc):
    """
    Convert either SpaceAfter=No or the SpacesAfter annotation

    see https://universaldependencies.org/misc.html#spacesafter

    We compensate for some treebanks using SpaceAfter=\n instead of SpacesAfter=\n
    On the way back, though, those annotations will be turned into SpacesAfter
    """
    if not misc:
        return " "
    pieces = misc.split("|")
    # SpaceAfter=No may appear in any case
    if any(piece.lower() == "spaceafter=no" for piece in pieces):
        return ""
    if "SpaceAfter=Yes" in pieces:
        # as of UD 2.11, the Cantonese treebank had this as a misc feature
        return " "
    if "SpaceAfter=No~" in pieces:
        # as of UD 2.11, a weird typo in the Russian Taiga dataset
        return ""
    for piece in pieces:
        if piece.startswith(("SpaceAfter=", "SpacesAfter=")):
            return unescape_misc_space(piece.split("=", maxsplit=1)[1])
    return " "
|
| 775 |
+
|
| 776 |
+
def log_norms(model):
    """
    Log a table of (name, L2 norm, element count) for every trainable parameter.
    """
    lines = ["NORMS FOR MODEL PARAMETERS"]
    pieces = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # %.6g keeps the norm column compact regardless of magnitude
            pieces.append((name, "%.6g" % torch.norm(param).item(), "%d" % param.numel()))
    if pieces:
        # fix: guard the max() calls - previously a model with no trainable
        # parameters raised ValueError on the empty sequences
        name_len = max(len(x[0]) for x in pieces)
        norm_len = max(len(x[1]) for x in pieces)
        line_format = " %-" + str(name_len) + "s %" + str(norm_len) + "s %s"
        for line in pieces:
            lines.append(line_format % line)
    logger.info("\n".join(lines))
|
| 788 |
+
|
| 789 |
+
def attach_bert_model(model, bert_model, bert_tokenizer, use_peft, force_bert_saved):
    """
    Attach a transformer and its tokenizer to the given model.

    use_peft: register the transformer as unsaved (PEFT has its own save path)
      and put it in train mode
    force_bert_saved: store the transformer directly so it is saved with the model
    otherwise: register it as unsaved and freeze all of its parameters
    """
    if use_peft:
        # we use a peft-specific pathway for saving peft weights
        model.add_unsaved_module('bert_model', bert_model)
        model.bert_model.train()
    elif force_bert_saved:
        model.bert_model = bert_model
    elif bert_model is not None:
        model.add_unsaved_module('bert_model', bert_model)
        for _, parameter in bert_model.named_parameters():
            parameter.requires_grad = False
    else:
        model.bert_model = None
    model.add_unsaved_module('bert_tokenizer', bert_tokenizer)
|
| 803 |
+
|
| 804 |
+
def build_save_each_filename(base_filename):
    """
    If the given name doesn't have %d in it, add %04d at the end of the filename

    This way, there's something to count how many models have been saved
    """
    try:
        # if this succeeds, the name already contains a usable %d-style slot
        base_filename % 1
    except TypeError:
        # so models.pt -> models_%04d.pt, etc
        # fix: this previously referenced an undefined name
        # 'model_save_each_file', raising NameError instead of splitting
        pieces = os.path.splitext(base_filename)
        base_filename = pieces[0] + "_%04d" + pieces[1]
    return base_filename
|
stanza/stanza/models/common/vocab.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import copy
|
| 2 |
+
from collections import Counter, OrderedDict
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
|
| 7 |
+
# Special vocabulary entries shared by all vocabs.  The *_ID values must
# line up with the positions in VOCAB_PREFIX, which is prepended to every
# vocab so these ids are stable across models.
PAD = '<PAD>'
PAD_ID = 0
UNK = '<UNK>'
UNK_ID = 1
EMPTY = '<EMPTY>'
EMPTY_ID = 2
ROOT = '<ROOT>'
ROOT_ID = 3
VOCAB_PREFIX = [PAD, UNK, EMPTY, ROOT]
VOCAB_PREFIX_SIZE = len(VOCAB_PREFIX)
|
| 17 |
+
|
| 18 |
+
class BaseVocab:
|
| 19 |
+
""" A base class for common vocabulary operations. Each subclass should at least
|
| 20 |
+
implement its own build_vocab() function."""
|
| 21 |
+
def __init__(self, data=None, lang="", idx=0, cutoff=0, lower=False):
|
| 22 |
+
self.data = data
|
| 23 |
+
self.lang = lang
|
| 24 |
+
self.idx = idx
|
| 25 |
+
self.cutoff = cutoff
|
| 26 |
+
self.lower = lower
|
| 27 |
+
if data is not None:
|
| 28 |
+
self.build_vocab()
|
| 29 |
+
self.state_attrs = ['lang', 'idx', 'cutoff', 'lower', '_unit2id', '_id2unit']
|
| 30 |
+
|
| 31 |
+
def build_vocab(self):
|
| 32 |
+
raise NotImplementedError("This BaseVocab does not have build_vocab implemented. This method should create _id2unit and _unit2id")
|
| 33 |
+
|
| 34 |
+
def state_dict(self):
|
| 35 |
+
""" Returns a dictionary containing all states that are necessary to recover
|
| 36 |
+
this vocab. Useful for serialization."""
|
| 37 |
+
state = OrderedDict()
|
| 38 |
+
for attr in self.state_attrs:
|
| 39 |
+
if hasattr(self, attr):
|
| 40 |
+
state[attr] = getattr(self, attr)
|
| 41 |
+
return state
|
| 42 |
+
|
| 43 |
+
@classmethod
|
| 44 |
+
def load_state_dict(cls, state_dict):
|
| 45 |
+
""" Returns a new Vocab instance constructed from a state dict. """
|
| 46 |
+
new = cls()
|
| 47 |
+
for attr, value in state_dict.items():
|
| 48 |
+
setattr(new, attr, value)
|
| 49 |
+
return new
|
| 50 |
+
|
| 51 |
+
def normalize_unit(self, unit):
|
| 52 |
+
# be sure to look in subclasses for other normalization being done
|
| 53 |
+
# especially PretrainWordVocab
|
| 54 |
+
if unit is None:
|
| 55 |
+
return unit
|
| 56 |
+
if self.lower:
|
| 57 |
+
return unit.lower()
|
| 58 |
+
return unit
|
| 59 |
+
|
| 60 |
+
def unit2id(self, unit):
|
| 61 |
+
unit = self.normalize_unit(unit)
|
| 62 |
+
if unit in self._unit2id:
|
| 63 |
+
return self._unit2id[unit]
|
| 64 |
+
else:
|
| 65 |
+
return self._unit2id[UNK]
|
| 66 |
+
|
| 67 |
+
def id2unit(self, id):
|
| 68 |
+
return self._id2unit[id]
|
| 69 |
+
|
| 70 |
+
def map(self, units):
|
| 71 |
+
return [self.unit2id(x) for x in units]
|
| 72 |
+
|
| 73 |
+
def unmap(self, ids):
|
| 74 |
+
return [self.id2unit(x) for x in ids]
|
| 75 |
+
|
| 76 |
+
def __str__(self):
|
| 77 |
+
lang_str = "(%s)" % self.lang if self.lang else ""
|
| 78 |
+
name = str(type(self)) + lang_str
|
| 79 |
+
return "<%s: %s>" % (name, self._id2unit)
|
| 80 |
+
|
| 81 |
+
def __len__(self):
|
| 82 |
+
return len(self._id2unit)
|
| 83 |
+
|
| 84 |
+
def __getitem__(self, key):
|
| 85 |
+
if isinstance(key, str):
|
| 86 |
+
return self.unit2id(key)
|
| 87 |
+
elif isinstance(key, int) or isinstance(key, list):
|
| 88 |
+
return self.id2unit(key)
|
| 89 |
+
else:
|
| 90 |
+
raise TypeError("Vocab key must be one of str, list, or int")
|
| 91 |
+
|
| 92 |
+
def __contains__(self, key):
|
| 93 |
+
return self.normalize_unit(key) in self._unit2id
|
| 94 |
+
|
| 95 |
+
@property
def size(self):
    """Number of units in the vocab; alias for len(self)."""
    return len(self)
|
| 98 |
+
|
| 99 |
+
class DeltaVocab(BaseVocab):
    """
    A vocab that starts off with a BaseVocab, then possibly adds more tokens based on the text in the given data

    Currently meant only for characters, such as built by MWT or Lemma

    Expected data format is either a list of strings, or a list of list of strings
    """
    def __init__(self, data, orig_vocab):
        self.orig_vocab = orig_vocab
        super().__init__(data=data, lang=orig_vocab.lang, idx=orig_vocab.idx, cutoff=orig_vocab.cutoff, lower=orig_vocab.lower)

    def build_vocab(self):
        # data is either a flat list of strings, or a list of sentences of words
        if all(isinstance(token, str) for token in self.data):
            text = "".join(self.data)
        else:
            text = "".join(word for sentence in self.data for word in sentence)

        known = self.orig_vocab._unit2id
        # characters seen in the data which the original vocab does not know
        extra = sorted({ch for ch in text if ch not in known})
        if extra:
            self._id2unit = self.orig_vocab._id2unit + extra
            self._unit2id = dict(known)
            for ch in extra:
                self._unit2id[ch] = len(self._unit2id)
        else:
            # nothing new was found - share the original vocab's tables
            self._id2unit = self.orig_vocab._id2unit
            self._unit2id = known
|
| 127 |
+
|
| 128 |
+
class CompositeVocab(BaseVocab):
    ''' Vocabulary class that handles parsing and printing composite values such as
    compositional XPOS and universal morphological features (UFeats).

    Two key options are `keyed` and `sep`. `sep` specifies the separator used between
    different parts of the composite values, which is `|` for UFeats, for example.
    If `keyed` is `True`, then the incoming value is treated similarly to UFeats, where
    each part is a key/value pair separated by an equal sign (`=`). There is no inherent
    order to the keys, and we sort them alphabetically for serialization and deserialization.
    Whenever a part is absent, its internal value is a special `<EMPTY>` symbol that will
    be treated accordingly when generating the output. If `keyed` is `False`, then the parts
    are treated as positioned values, and `<EMPTY>` is used to pad parts at the end when the
    incoming value is not long enough.'''

    def __init__(self, data=None, lang="", idx=0, sep="", keyed=False):
        self.sep = sep
        self.keyed = keyed
        super().__init__(data, lang, idx=idx)
        self.state_attrs += ['sep', 'keyed']

    def unit2parts(self, unit):
        # unpack parts of a unit
        # an empty sep means each character is its own part
        if not self.sep:
            parts = [x for x in unit]
        else:
            parts = unit.split(self.sep)
        if self.keyed:
            # '_' is the conventional marker for "no features"
            if len(parts) == 1 and parts[0] == '_':
                return dict()
            parts = [x.split('=') for x in parts]
            if any(len(x) != 2 for x in parts):
                raise ValueError('Received "%s" for a dictionary which is supposed to be keyed, eg the entries should all be of the form key=value and separated by %s' % (unit, self.sep))

            # Just treat multi-valued properties values as one possible value
            parts = dict(parts)
        elif unit == '_':
            parts = []
        return parts

    def unit2id(self, unit):
        """Return one id per sub-vocab, using UNK_ID for unknown values and EMPTY_ID for absent parts."""
        parts = self.unit2parts(unit)
        if self.keyed:
            # treat multi-valued properties as singletons
            return [self._unit2id[k].get(parts[k], UNK_ID) if k in parts else EMPTY_ID for k in self._unit2id]
        else:
            return [self._unit2id[i].get(parts[i], UNK_ID) if i < len(parts) else EMPTY_ID for i in range(len(self._unit2id))]

    def id2unit(self, id):
        """Rebuild the composite string (or list, if sep is None) from a sequence of ids."""
        # special case: allow single ids for vocabs with length 1
        if len(self._id2unit) == 1 and not isinstance(id, Iterable):
            id = (id,)
        items = []
        for v, k in zip(id, self._id2unit.keys()):
            # EMPTY parts are simply omitted from the output
            if v == EMPTY_ID: continue
            if self.keyed:
                items.append("{}={}".format(k, self._id2unit[k][v]))
            else:
                items.append(self._id2unit[k][v])
        if self.sep is not None:
            res = self.sep.join(items)
            if res == "":
                res = "_"
            return res
        else:
            return items

    def build_vocab(self):
        """Build one sub-vocab per key (keyed) or per position (positional) from self.data."""
        allunits = [w[self.idx] for sent in self.data for w in sent]
        if self.keyed:
            self._id2unit = dict()

            for u in allunits:
                parts = self.unit2parts(u)
                for key in parts:
                    if key not in self._id2unit:
                        self._id2unit[key] = copy(VOCAB_PREFIX)

                    # treat multi-valued properties as singletons
                    if parts[key] not in self._id2unit[key]:
                        self._id2unit[key].append(parts[key])

            # special handle for the case where upos/xpos/ufeats are always empty
            if len(self._id2unit) == 0:
                self._id2unit['_'] = copy(VOCAB_PREFIX) # use an arbitrary key

        else:
            self._id2unit = dict()

            for u in allunits:
                # i < len(parts) always holds inside enumerate(parts), so no extra
                # bounds check is needed here
                for i, p in enumerate(self.unit2parts(u)):
                    if i not in self._id2unit:
                        self._id2unit[i] = copy(VOCAB_PREFIX)
                    if p not in self._id2unit[i]:
                        self._id2unit[i].append(p)

            # special handle for the case where upos/xpos/ufeats are always empty
            if len(self._id2unit) == 0:
                self._id2unit[0] = copy(VOCAB_PREFIX) # use an arbitrary key

        self._id2unit = OrderedDict([(k, self._id2unit[k]) for k in sorted(self._id2unit.keys())])
        self._unit2id = {k: {w:i for i, w in enumerate(self._id2unit[k])} for k in self._id2unit}

    def lens(self):
        """Sizes of each sub-vocab, in key order."""
        return [len(self._unit2id[k]) for k in self._unit2id]

    def items(self, idx):
        """The unit list for one sub-vocab key."""
        return self._id2unit[idx]

    def __str__(self):
        pieces = ["[" + ",".join(x) + "]" for _, x in self._id2unit.items()]
        rep = "<{}:\n {}>".format(type(self), "\n ".join(pieces))
        return rep
|
| 243 |
+
|
| 244 |
+
class BaseMultiVocab:
    """A convenient container mapping names to multiple BaseVocab instances.

    Supports safe serialization of all contained vocabs via state dicts.
    Each subclass should implement load_state_dict() to specify how a
    saved state dict is loaded back.
    """
    def __init__(self, vocab_dict=None):
        self._vocabs = OrderedDict()
        if vocab_dict is None:
            return
        # every value handed to the constructor must be a real vocab
        assert all(isinstance(v, BaseVocab) for v in vocab_dict.values())
        self._vocabs.update(vocab_dict)

    def __setitem__(self, key, item):
        self._vocabs[key] = item

    def __getitem__(self, key):
        return self._vocabs[key]

    def __str__(self):
        return "<{}: [{}]>".format(type(self), ", ".join(self._vocabs.keys()))

    def __contains__(self, key):
        return key in self._vocabs

    def keys(self):
        return self._vocabs.keys()

    def state_dict(self):
        """ Build a state dict by iteratively calling state_dict() of all vocabs. """
        return OrderedDict((name, vocab.state_dict()) for name, vocab in self._vocabs.items())

    @classmethod
    def load_state_dict(cls, state_dict):
        """ Construct a MultiVocab by reading from a state dict."""
        raise NotImplementedError
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class CharVocab(BaseVocab):
    """Character vocab built either from DataLoader-style data or from char LM text."""
    def build_vocab(self):
        if isinstance(self.data[0][0], (list, tuple)): # general data from DataLoader
            counter = Counter(c for sent in self.data for w in sent for c in w[self.idx])
            # drop characters rarer than the cutoff
            rare = [ch for ch, count in counter.items() if count < self.cutoff]
            for ch in rare:
                del counter[ch]
        else: # special data from Char LM
            counter = Counter(c for sent in self.data for c in sent)
        # most frequent characters first; ties broken by the character itself
        self._id2unit = VOCAB_PREFIX + sorted(counter, key=lambda k: (counter[k], k), reverse=True)
        self._unit2id = {w: i for i, w in enumerate(self._id2unit)}
|
| 298 |
+
|
stanza/stanza/models/constituency/base_model.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The BaseModel is passed to the transitions so that the transitions
|
| 3 |
+
can operate on a parsing state without knowing the exact
|
| 4 |
+
representation used in the model.
|
| 5 |
+
|
| 6 |
+
For example, a SimpleModel simply looks at the top of the various stacks in the state.
|
| 7 |
+
|
| 8 |
+
A model with LSTM representations for the different transitions may
|
| 9 |
+
attach the hidden and output states of the LSTM to the word /
|
| 10 |
+
constituent / transition stacks.
|
| 11 |
+
|
| 12 |
+
Reminder: the parsing state is a list of words to parse, the
|
| 13 |
+
transitions used to build a (possibly incomplete) parse, and the
|
| 14 |
+
constituent(s) built so far by those transitions. Each of these
|
| 15 |
+
components are represented using stacks to improve the efficiency
|
| 16 |
+
of operations such as "combine the most recent 4 constituents"
|
| 17 |
+
or "turn the next input word into a constituent"
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from abc import ABC, abstractmethod
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from stanza.models.common import utils
|
| 27 |
+
from stanza.models.constituency import transition_sequence
|
| 28 |
+
from stanza.models.constituency.parse_transitions import TransitionScheme, CloseConstituent
|
| 29 |
+
from stanza.models.constituency.parse_tree import Tree
|
| 30 |
+
from stanza.models.constituency.state import State
|
| 31 |
+
from stanza.models.constituency.tree_stack import TreeStack
|
| 32 |
+
from stanza.server.parser_eval import ParseResult, ScoredTree
|
| 33 |
+
|
| 34 |
+
# default unary limit. some treebanks may have longer chains (CTB, for example)
UNARY_LIMIT = 4

# named after the trainer so parser messages appear under the same logger hierarchy
logger = logging.getLogger('stanza.constituency.trainer')
|
| 38 |
+
|
| 39 |
+
class BaseModel(ABC):
|
| 40 |
+
"""
|
| 41 |
+
This base class defines abstract methods for manipulating a State.
|
| 42 |
+
|
| 43 |
+
Applying transitions may change important metadata about a State
|
| 44 |
+
such as the vectors associated with LSTM hidden states, for example.
|
| 45 |
+
|
| 46 |
+
The constructor forwards all unused arguments to other classes in the
|
| 47 |
+
constructor sequence, so put this before other classes such as nn.Module
|
| 48 |
+
"""
|
| 49 |
+
def __init__(self, transition_scheme, unary_limit, reverse_sentence, root_labels, *args, **kwargs):
    """Stash the parsing configuration shared by all model implementations.

    Unused arguments are forwarded to the next class in the constructor
    sequence, so this should come before classes such as nn.Module.
    """
    super().__init__(*args, **kwargs)

    self._transition_scheme = transition_scheme
    self._unary_limit = unary_limit
    self._reverse_sentence = reverse_sentence
    self._root_labels = sorted(root_labels)

    top_down_schemes = (TransitionScheme.TOP_DOWN,
                        TransitionScheme.TOP_DOWN_UNARY,
                        TransitionScheme.TOP_DOWN_COMPOUND)
    self._is_top_down = self._transition_scheme in top_down_schemes
|
| 60 |
+
|
| 61 |
+
@abstractmethod
def initial_word_queues(self, tagged_word_lists):
    """
    For each list of tagged words, builds a TreeStack of word nodes

    The word lists should be backwards so that the first word is the last word put on the stack (LIFO)

    Concrete models must implement this.
    """
|
| 68 |
+
|
| 69 |
+
@abstractmethod
def initial_transitions(self):
    """
    Builds an initial transition stack with whatever values need to go into first position

    Concrete models must implement this.
    """
|
| 74 |
+
|
| 75 |
+
@abstractmethod
def initial_constituents(self):
    """
    Builds an initial constituent stack with whatever values need to go into first position

    Concrete models must implement this.
    """
|
| 80 |
+
|
| 81 |
+
@abstractmethod
def get_word(self, word_node):
    """
    Get the word corresponding to this position in the word queue

    Concrete models must implement this.
    """
|
| 86 |
+
|
| 87 |
+
@abstractmethod
def transform_word_to_constituent(self, state):
    """
    Transform the top node of word_queue to something that can push on the constituent stack

    Concrete models must implement this.
    """
|
| 92 |
+
|
| 93 |
+
@abstractmethod
def dummy_constituent(self, dummy):
    """
    When using a dummy node as a sentinel, transform it to something usable by this model

    Concrete models must implement this.
    """
|
| 98 |
+
|
| 99 |
+
@abstractmethod
def build_constituents(self, labels, children_lists):
    """
    Build multiple constituents at once. This gives the opportunity for batching operations

    Concrete models must implement this.
    """
|
| 104 |
+
|
| 105 |
+
@abstractmethod
def push_constituents(self, constituent_stacks, constituents):
    """
    Add multiple constituents to multiple constituent_stacks

    Useful to factor this out in case batching will help
    """
|
| 112 |
+
|
| 113 |
+
@abstractmethod
def get_top_constituent(self, constituents):
    """
    Get the first constituent from the constituent stack

    For example, a model might want to remove embeddings and LSTM state vectors
    """
|
| 120 |
+
|
| 121 |
+
@abstractmethod
def push_transitions(self, transition_stacks, transitions):
    """
    Add multiple transitions to multiple transition_stacks

    Useful to factor this out in case batching will help
    """
|
| 128 |
+
|
| 129 |
+
@abstractmethod
def get_top_transition(self, transitions):
    """
    Get the first transition from the transition stack

    For example, a model might want to remove transition embeddings before returning the transition
    """
|
| 136 |
+
|
| 137 |
+
@property
def root_labels(self):
    """The ROOT labels for this model - probably ROOT, TOP, or both (Danish uses 's', though)."""
    return self._root_labels
|
| 145 |
+
|
| 146 |
+
def unary_limit(self):
    """Return the maximum number of consecutive unary transitions allowed."""
    return self._unary_limit
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def transition_scheme(self):
    """Return the transition scheme in use - see parse_transitions."""
    return self._transition_scheme
|
| 158 |
+
|
| 159 |
+
def has_unary_transitions(self):
    """True if this model's transition scheme includes unary transitions."""
    return self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY
|
| 164 |
+
|
| 165 |
+
@property
def is_top_down(self):
    """True if this model uses one of the TOP_DOWN transition schemes."""
    return self._is_top_down
|
| 171 |
+
|
| 172 |
+
@property
def reverse_sentence(self):
    """True if this model is built to parse the sentence back to front."""
    return self._reverse_sentence
|
| 178 |
+
|
| 179 |
+
def predict(self, states, is_legal=True):
    """Model-scored transition choice; only trainable models provide this."""
    raise NotImplementedError("LSTMModel can predict, but SimpleModel cannot")
|
| 181 |
+
|
| 182 |
+
def weighted_choice(self, states):
    """Sampled transition choice; only trainable models provide this."""
    raise NotImplementedError("LSTMModel can weighted_choice, but SimpleModel cannot")
|
| 184 |
+
|
| 185 |
+
def predict_gold(self, states, is_legal=True):
    """
    For each State, return the next transition in its gold_sequence.

    When is_legal is set, each gold transition is checked against the
    current state, raising RuntimeError on an illegal transition.
    """
    chosen = [state.gold_sequence[state.num_transitions] for state in states]
    if is_legal:
        for trans, state in zip(chosen, states):
            if not trans.is_legal(state, self):
                raise RuntimeError("Transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(state.num_transitions, trans, state.gold_tree, state.gold_sequence))
    return None, chosen, None
|
| 195 |
+
|
| 196 |
+
def initial_state_from_preterminals(self, preterminal_lists, gold_trees, gold_sequences):
    """
    Build one initial State per sentence from lists of preterminals.

    preterminal_lists: a list of list of preterminals
    gold_trees / gold_sequences: if given, attached to the corresponding states
    """
    word_queues = self.initial_word_queues(preterminal_lists)
    # this is the bottom of the TreeStack and will be the same for each State
    transitions = self.initial_transitions()
    constituents = self.initial_constituents()
    # the unused enumerate index has been dropped - only the queues matter here
    states = [State(sentence_length=len(wq)-2, # -2 because it starts and ends with a sentinel
                    num_opens=0,
                    word_queue=wq,
                    gold_tree=None,
                    gold_sequence=None,
                    transitions=transitions,
                    constituents=constituents,
                    word_position=0,
                    score=0.0)
              for wq in word_queues]
    if gold_trees:
        states = [state._replace(gold_tree=gold_tree) for gold_tree, state in zip(gold_trees, states)]
    if gold_sequences:
        states = [state._replace(gold_sequence=gold_sequence) for gold_sequence, state in zip(gold_sequences, states)]
    return states
|
| 219 |
+
|
| 220 |
+
def initial_state_from_words(self, word_lists):
    """Build initial states from lists of (word, tag) pairs, with no gold annotations."""
    preterminal_lists = [[Tree(tag, Tree(word)) for word, tag in sentence]
                         for sentence in word_lists]
    return self.initial_state_from_preterminals(preterminal_lists, gold_trees=None, gold_sequences=None)
|
| 224 |
+
|
| 225 |
+
def initial_state_from_gold_trees(self, trees, gold_sequences=None):
    """Build initial states whose preterminals are copied from the given gold trees."""
    preterminal_lists = [[Tree(preterminal.label, Tree(preterminal.children[0].label))
                          for preterminal in tree.yield_preterminals()]
                         for tree in trees]
    return self.initial_state_from_preterminals(preterminal_lists, gold_trees=trees, gold_sequences=gold_sequences)
|
| 230 |
+
|
| 231 |
+
def build_batch_from_trees(self, batch_size, data_iterator):
    """
    Pull up to batch_size trees from data_iterator and convert them to new parsing states.

    Returns an empty list once the iterator is exhausted.
    """
    gold_trees = []
    for _ in range(batch_size):
        tree = next(data_iterator, None)
        if tree is None:
            break
        gold_trees.append(tree)

    if not gold_trees:
        return gold_trees
    return self.initial_state_from_gold_trees(gold_trees)
|
| 245 |
+
|
| 246 |
+
def build_batch_from_trees_with_gold_sequence(self, batch_size, data_iterator):
    """
    Same as build_batch_from_trees, but also attach each tree's gold transition sequence to its state.
    """
    state_batch = self.build_batch_from_trees(batch_size, data_iterator)
    if not state_batch:
        return state_batch

    gold_sequences = transition_sequence.build_treebank([state.gold_tree for state in state_batch], self.transition_scheme(), self.reverse_sentence)
    return [state._replace(gold_sequence=sequence) for state, sequence in zip(state_batch, gold_sequences)]
|
| 257 |
+
|
| 258 |
+
def build_batch_from_tagged_words(self, batch_size, data_iterator):
    """
    Pull up to batch_size tagged sentences from data_iterator and turn them into new parsing states.

    Expects the iterator to yield lists of (word, tag) pairs.
    """
    sentences = []
    for _ in range(batch_size):
        sentence = next(data_iterator, None)
        if sentence is None:
            break
        sentences.append(sentence)

    if not sentences:
        return sentences
    return self.initial_state_from_words(sentences)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
    """
    Repeat transitions to build a list of trees from the input batches.

    The data_iterator should be anything which returns the data for a parse task via next()
    build_batch_fn is a function that turns that data into State objects
    This will be called to generate batches of size batch_size until the data is exhausted

    The return is a list of tuples: (gold_tree, [(predicted, score) ...])
    gold_tree will be left blank if the data did not include gold trees
    if keep_scores is true, the score will be the sum of the values
    returned by the model for each transition

    transition_choice: which method of the model to use for choosing the next transition
    predict for predicting the transition based on the model
    predict_gold to just extract the gold transition from the sequence
    """
    treebank = []
    treebank_indices = []
    state_batch = build_batch_fn(batch_size, data_iterator)
    # used to track which indices we are currently parsing
    # since the parses get finished at different times, this will let us unsort after
    batch_indices = list(range(len(state_batch)))
    # "horizon" states are pre-built states waiting to refill the batch
    # as sentences finish; starts empty and is refilled lazily below
    horizon_iterator = iter([])

    if keep_constituents:
        constituents = defaultdict(list)

    while len(state_batch) > 0:
        pred_scores, transitions, scores = transition_choice(state_batch)
        if keep_scores and scores is not None:
            # accumulate the model's per-transition score on each state
            state_batch = [state._replace(score=state.score + score) for state, score in zip(state_batch, scores)]
        state_batch = self.bulk_apply(state_batch, transitions)

        if keep_constituents:
            for t_idx, transition in enumerate(transitions):
                if isinstance(transition, CloseConstituent):
                    # constituents is a TreeStack with information on how to build the next state of the LSTM or attn
                    # constituents.value is the TreeStack node
                    # constituents.value.value is the Constituent itself (with the tree and the embedding)
                    constituents[batch_indices[t_idx]].append(state_batch[t_idx].constituents.value.value)

        remove = set()
        for idx, state in enumerate(state_batch):
            if state.finished(self):
                predicted_tree = state.get_tree(self)
                if self.reverse_sentence:
                    # the model parsed back to front, so flip the finished tree
                    predicted_tree = predicted_tree.reverse()
                gold_tree = state.gold_tree
                treebank.append(ParseResult(gold_tree, [ScoredTree(predicted_tree, state.score)], state if keep_state else None, constituents[batch_indices[idx]] if keep_constituents else None))
                treebank_indices.append(batch_indices[idx])
                remove.add(idx)

        if len(remove) > 0:
            # drop finished sentences, keeping states and their indices aligned
            state_batch = [state for idx, state in enumerate(state_batch) if idx not in remove]
            batch_indices = [batch_idx for idx, batch_idx in enumerate(batch_indices) if idx not in remove]

        # top the batch back up from the horizon, building a new horizon
        # batch from the data iterator whenever the current one runs dry
        for _ in range(batch_size - len(state_batch)):
            horizon_state = next(horizon_iterator, None)
            if not horizon_state:
                horizon_batch = build_batch_fn(batch_size, data_iterator)
                if len(horizon_batch) == 0:
                    break
                horizon_iterator = iter(horizon_batch)
                horizon_state = next(horizon_iterator, None)

            state_batch.append(horizon_state)
            # each new index is larger than any index assigned so far,
            # so data order is preserved for the final unsort
            batch_indices.append(len(treebank) + len(state_batch))

    treebank = utils.unsort(treebank, treebank_indices)
    return treebank
|
| 347 |
+
|
| 348 |
+
def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
    """
    Same as parse_sentences, but run under torch.no_grad().

    Skipping gradient bookkeeping makes inference faster and uses less memory.
    """
    with torch.no_grad():
        return self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice, keep_state, keep_constituents, keep_scores)
|
| 357 |
+
|
| 358 |
+
def analyze_trees(self, trees, batch_size=None, keep_state=True, keep_constituents=True, keep_scores=True):
    """
    Return a ParseResult for each tree in the trees list.

    The transitions run are exactly those represented by each gold tree,
    and the output layers are available in result.state for each result.

    keep_state defaults to True here since a caller which keeps the grad
    is likely to want the resulting state as well.
    """
    if batch_size is None:
        # TODO: refactor?
        batch_size = self.args['eval_batch_size']
    tree_iterator = iter(trees)
    return self.parse_sentences(tree_iterator, self.build_batch_from_trees_with_gold_sequence, batch_size, self.predict_gold, keep_state, keep_constituents, keep_scores=keep_scores)
|
| 374 |
+
|
| 375 |
+
def parse_tagged_words(self, words, batch_size):
    """
    This parses tagged words and returns a list of trees.

    `parse_tagged_words` is useful at Pipeline time -
    it takes words & tags and processes that into trees.

    The tagged words should be represented:
      one list per sentence
      each sentence is a list of (word, tag)
    The return value is a list of ParseTree objects
    """
    logger.debug("Processing %d sentences", len(words))
    # switch off training-time behavior such as dropout
    self.eval()

    treebank = self.parse_sentences_no_grad(iter(words), self.build_batch_from_tagged_words, batch_size, self.predict, keep_state=False, keep_constituents=False)

    # keep only the top (first) prediction for each sentence
    return [result.predictions[0].tree for result in treebank]
|
| 395 |
+
|
| 396 |
+
def bulk_apply(self, state_batch, transitions, fail=False):
    """
    Apply the given list of Transitions to the given list of States, using the model as a reference

    state_batch: list of States
    transitions: list of transitions, one per state
    fail: throw an exception on a failed transition, as opposed to skipping the tree
    """
    skipped = set()

    word_positions = []
    constituents = []
    new_constituents = []
    # groups constituents which can share one batched NN computation
    callbacks = defaultdict(list)

    for pos, (state, transition) in enumerate(zip(state_batch, transitions)):
        if not transition:
            error = "Got stuck and couldn't find a legal transition on the following gold tree:\n{}\n\nFinal state:\n{}".format(state.gold_tree, state.to_string(self))
            if fail:
                raise ValueError(error)
            logger.error(error)
            skipped.add(pos)
            continue

        # too many transitions
        # x20 is somewhat empirically chosen based on certain
        # treebanks having deep unary structures, especially early
        # on when the model is fumbling around
        if state.num_transitions >= len(state.word_queue) * 20:
            if state.gold_tree:
                error = "Went infinite on the following gold tree:\n{}\n\nFinal state:\n{}".format(state.gold_tree, state.to_string(self))
            else:
                error = "Went infinite!:\nFinal state:\n{}".format(state.to_string(self))
            if fail:
                raise ValueError(error)
            logger.error(error)
            skipped.add(pos)
            continue

        wq, con, new_con, callback = transition.update_state(state, self)

        word_positions.append(wq)
        constituents.append(con)
        new_constituents.append(new_con)
        if callback:
            # not `pos` in case something was removed
            callbacks[callback].append(len(new_constituents) - 1)

    for builder, positions in callbacks.items():
        built = builder.build_constituents(self, [new_constituents[p] for p in positions])
        for p, constituent in zip(positions, built):
            new_constituents[p] = constituent

    if skipped:
        state_batch = [s for pos, s in enumerate(state_batch) if pos not in skipped]
        transitions = [t for pos, t in enumerate(transitions) if pos not in skipped]

    if not state_batch:
        return state_batch

    new_transitions = self.push_transitions([s.transitions for s in state_batch], transitions)
    new_constituents = self.push_constituents(constituents, new_constituents)

    return [state._replace(num_opens=state.num_opens + transition.delta_opens(),
                           word_position=word_position,
                           transitions=transition_stack,
                           constituents=constituent_stack)
            for (state, transition, word_position, transition_stack, constituent_stack)
            in zip(state_batch, transitions, word_positions, new_transitions, new_constituents)]
|
| 471 |
+
|
| 472 |
+
class SimpleModel(BaseModel):
    """
    A model which allows pushing and popping with no extra data.

    Primarily used for testing operations which don't need the NN's
    weights, and for rebuilding trees from transitions when the NN
    state is irrelevant - this class is much faster than the NN there.
    """
    def __init__(self, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, unary_limit=UNARY_LIMIT, reverse_sentence=False, root_labels=("ROOT",)):
        super().__init__(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse_sentence, root_labels=root_labels)

    def initial_word_queues(self, tagged_word_lists):
        queues = []
        for tagged_words in tagged_word_lists:
            # sentinel None at each end of the queue
            queue = [None] + list(tagged_words) + [None]
            if self.reverse_sentence:
                queue.reverse()
            queues.append(queue)
        return queues

    def initial_transitions(self):
        return TreeStack(value=None, parent=None, length=1)

    def initial_constituents(self):
        return TreeStack(value=None, parent=None, length=1)

    def get_word(self, word_node):
        return word_node

    def transform_word_to_constituent(self, state):
        return state.get_word(state.word_position)

    def dummy_constituent(self, dummy):
        return dummy

    def build_constituents(self, labels, children_lists):
        results = []
        for label, children in zip(labels, children_lists):
            # a bare string label acts as a one-element label chain
            chain = (label,) if isinstance(label, str) else label
            node = children
            for value in reversed(chain):
                node = Tree(label=value, children=node)
            results.append(node)
        return results

    def push_constituents(self, constituent_stacks, constituents):
        return [stack.push(con) for stack, con in zip(constituent_stacks, constituents)]

    def get_top_constituent(self, constituents):
        return constituents.value

    def push_transitions(self, transition_stacks, transitions):
        return [stack.push(trans) for stack, trans in zip(transition_stacks, transitions)]

    def get_top_transition(self, transitions):
        return transitions.value
|
stanza/stanza/models/constituency/base_trainer.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from pickle import UnpicklingError
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger('stanza')
|
| 11 |
+
|
| 12 |
+
class ModelType(Enum):
    """Which flavor of constituency trainer a checkpoint holds."""
    # values are persisted (by name) in saved checkpoints; do not renumber
    LSTM = 1
    ENSEMBLE = 2
|
| 15 |
+
|
| 16 |
+
class BaseTrainer:
    """
    Base class for the constituency parser trainers.

    Bundles a model with its optimizer, scheduler, and training
    progress statistics so that training can be checkpointed and
    resumed via save() / load().
    """
    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        # keeping track of the epochs trained will be useful
        # for adjusting the learning scheme
        self.epochs_trained = epochs_trained
        self.batches_trained = batches_trained
        self.best_f1 = best_f1
        self.best_epoch = best_epoch
        self.first_optimizer = first_optimizer

    def save(self, filename, save_optimizer=True):
        """
        Save model params and training statistics to filename.

        If save_optimizer is True and an optimizer is present, the
        optimizer and scheduler states are included in the checkpoint.
        """
        params = self.model.get_params()
        checkpoint = {
            'params': params,
            'epochs_trained': self.epochs_trained,
            'batches_trained': self.batches_trained,
            'best_f1': self.best_f1,
            'best_epoch': self.best_epoch,
            'model_type': self.model_type.name,
            'first_optimizer': self.first_optimizer,
        }
        checkpoint["bert_lora"] = self.get_peft_params()
        if save_optimizer and self.optimizer is not None:
            checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        torch.save(checkpoint, filename, _use_new_zipfile_serialization=False)
        logger.info("Model saved to %s", filename)

    def log_norms(self):
        self.model.log_norms()

    def log_shapes(self):
        self.model.log_shapes()

    @property
    def transitions(self):
        return self.model.transitions

    @property
    def root_labels(self):
        return self.model.root_labels

    @property
    def device(self):
        return next(self.model.parameters()).device

    def train(self):
        return self.model.train()

    def eval(self):
        return self.model.eval()

    # TODO: make ABC with methods such as model_from_params?
    @staticmethod
    def load(filename, args=None, load_optimizer=False, foundation_cache=None, peft_name=None):
        """
        Load back a model and possibly its optimizer.

        Raises FileNotFoundError if the model is not found, either at
        filename directly or under args['save_dir'].
        """
        # hide the import here to avoid circular imports
        from stanza.models.constituency.ensemble import EnsembleTrainer
        from stanza.models.constituency.trainer import Trainer

        if not os.path.exists(filename):
            # guard: args defaults to None; previously args.get() was
            # called unconditionally, raising AttributeError instead of
            # the intended FileNotFoundError
            save_dir = args.get('save_dir', None) if args else None
            if save_dir is None:
                raise FileNotFoundError("Cannot find model in {} and args['save_dir'] is None".format(filename))
            elif os.path.exists(os.path.join(save_dir, filename)):
                filename = os.path.join(save_dir, filename)
            else:
                raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(save_dir, filename)))
        try:
            # prefer weights_only=True; old checkpoints which pickled
            # Enum, set, unsanitized Transitions, etc need the
            # weights_only=False fallback below
            try:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
            except UnpicklingError:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
                warnings.warn("The saved constituency parser has an old format using Enum, set, unsanitized Transitions, etc. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the constituency parser using this version ASAP.")
        except BaseException:
            logger.exception("Cannot load model from %s", filename)
            raise
        logger.debug("Loaded model from %s", filename)

        params = checkpoint['params']

        if 'model_type' not in checkpoint:
            # old checkpoints did not record a model type;
            # they are always LSTM models
            # TODO: can remove this after 1.10
            checkpoint['model_type'] = ModelType.LSTM
        if isinstance(checkpoint['model_type'], str):
            checkpoint['model_type'] = ModelType[checkpoint['model_type']]
        if checkpoint['model_type'] == ModelType.LSTM:
            clazz = Trainer
        elif checkpoint['model_type'] == ModelType.ENSEMBLE:
            clazz = EnsembleTrainer
        else:
            raise ValueError("Unexpected model type: %s" % checkpoint['model_type'])
        model = clazz.model_from_params(params, checkpoint.get('bert_lora', None), args, foundation_cache, peft_name)

        epochs_trained = checkpoint['epochs_trained']
        batches_trained = checkpoint.get('batches_trained', 0)
        best_f1 = checkpoint['best_f1']
        best_epoch = checkpoint['best_epoch']

        if 'first_optimizer' not in checkpoint:
            # this will only apply to old (LSTM) Trainers
            # EnsembleTrainers will always have this value saved
            # so here we can compensate by looking at the old training statistics...
            # we use params['config'] here instead of model.args
            # because the args might have a different training
            # mechanism, but in order to reload the optimizer, we need
            # to match the optimizer we build with the one that was
            # used at training time
            build_simple_adadelta = params['config']['multistage'] and epochs_trained < params['config']['epochs'] // 2
            checkpoint['first_optimizer'] = build_simple_adadelta
        first_optimizer = checkpoint['first_optimizer']

        if load_optimizer:
            optimizer = clazz.load_optimizer(model, checkpoint, first_optimizer, filename)
            scheduler = clazz.load_scheduler(model, optimizer, checkpoint, first_optimizer)
        else:
            optimizer = None
            scheduler = None

        if checkpoint['model_type'] == ModelType.LSTM:
            logger.debug("-- MODEL CONFIG --")
            for k in model.args.keys():
                logger.debug(" --%s: %s", k, model.args[k])
            return Trainer(model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch, first_optimizer=first_optimizer)
        elif checkpoint['model_type'] == ModelType.ENSEMBLE:
            return EnsembleTrainer(ensemble=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch, first_optimizer=first_optimizer)
        else:
            raise ValueError("Unexpected model type: %s" % checkpoint['model_type'])
|
| 153 |
+
|
stanza/stanza/models/constituency/ensemble.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prototype of ensembling N models together on the same dataset
|
| 3 |
+
|
| 4 |
+
The main inference method is to run the normal transition sequence,
|
| 5 |
+
but sum the scores for the N models and use that to choose the highest
|
| 6 |
+
scoring transition
|
| 7 |
+
|
| 8 |
+
Example of how to run it to build a silver dataset
|
| 9 |
+
(or just parse a text file in general):
|
| 10 |
+
|
| 11 |
+
# first, use this tool to build a saved ensemble
|
| 12 |
+
python3 stanza/models/constituency/ensemble.py
|
| 13 |
+
saved_models/constituency/wsj_inorder_?.pt
|
| 14 |
+
--save_name saved_models/constituency/en_ensemble.pt
|
| 15 |
+
|
| 16 |
+
# then use the ensemble directly as a model in constituency_parser.py
|
| 17 |
+
python3 stanza/models/constituency_parser.py
|
| 18 |
+
--save_name saved_models/constituency/en_ensemble.pt
|
| 19 |
+
--mode parse_text
|
| 20 |
+
--tokenized_file /nlp/scr/horatio/en_silver/en_split_100
|
| 21 |
+
--predict_file /nlp/scr/horatio/en_silver/en_split_100.inorder.mrg
|
| 22 |
+
--retag_package en_combined_bert
|
| 23 |
+
--lang en
|
| 24 |
+
|
| 25 |
+
then, ideally, run a second time with a set of topdown models,
|
| 26 |
+
then take the trees which match from the files
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import copy
|
| 32 |
+
import logging
|
| 33 |
+
import os
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
|
| 38 |
+
from stanza.models.common import utils
|
| 39 |
+
from stanza.models.common.foundation_cache import FoundationCache
|
| 40 |
+
from stanza.models.constituency.base_trainer import BaseTrainer, ModelType
|
| 41 |
+
from stanza.models.constituency.state import MultiState
|
| 42 |
+
from stanza.models.constituency.trainer import Trainer
|
| 43 |
+
from stanza.models.constituency.utils import build_optimizer, build_scheduler
|
| 44 |
+
from stanza.server.parser_eval import ParseResult, ScoredTree
|
| 45 |
+
|
| 46 |
+
logger = logging.getLogger('stanza.constituency.trainer')
|
| 47 |
+
|
| 48 |
+
class Ensemble(nn.Module):
|
| 49 |
+
def __init__(self, args, filenames=None, models=None, foundation_cache=None):
    """
    Build an ensemble, either by loading each model in filenames or
    from an already constructed list of models.

    Exactly one of filenames / models must be given.

    If foundation_cache is None, we build one on our own,
    as the expectation is the models will reuse modules
    such as pretrain, charlm, bert
    """
    super().__init__()

    self.args = args
    if filenames:
        if models:
            raise ValueError("both filenames and models set when making the Ensemble")

        if foundation_cache is None:
            foundation_cache = FoundationCache()

        if isinstance(filenames, str):
            filenames = [filenames]
        logger.info("Models used for ensemble:\n %s", "\n ".join(filenames))
        models = [Trainer.load(filename, args, load_optimizer=False, foundation_cache=foundation_cache).model for filename in filenames]
    elif not models:
        raise ValueError("filenames and models both not set!")

    self.models = nn.ModuleList(models)

    # names are used only for error reporting; previously the error
    # messages indexed filenames directly, which raised a TypeError
    # (masking the real incompatibility) when the ensemble was built
    # from models= and filenames was None
    names = filenames if filenames else ["model %d" % idx for idx in range(len(self.models))]

    first = self.models[0]
    for model_idx, model in enumerate(self.models):
        if first.transition_scheme() != model.transition_scheme():
            raise ValueError("Models {} and {} are incompatible. {} vs {}".format(names[0], names[model_idx], first.transition_scheme(), model.transition_scheme()))
        if first.transitions != model.transitions:
            raise ValueError(f"Models {names[0]} and {names[model_idx]} are incompatible: different transitions\n{names[0]}:\n{first.transitions}\n{names[model_idx]}:\n{model.transitions}")
        if first.constituents != model.constituents:
            raise ValueError("Models %s and %s are incompatible: different constituents" % (names[0], names[model_idx]))
        if first.root_labels != model.root_labels:
            raise ValueError("Models %s and %s are incompatible: different root_labels" % (names[0], names[model_idx]))
        if first.uses_xpos() != model.uses_xpos():
            raise ValueError("Models %s and %s are incompatible: different uses_xpos" % (names[0], names[model_idx]))
        if first.reverse_sentence != model.reverse_sentence:
            raise ValueError("Models %s and %s are incompatible: different reverse_sentence" % (names[0], names[model_idx]))

    self._reverse_sentence = first.reverse_sentence

    # submodels are not trained (so far)
    self.detach_submodels()

    logger.debug("Number of models in the Ensemble: %d", len(self.models))
    self.register_parameter('weighted_sum', torch.nn.Parameter(torch.zeros(len(self.models), len(self.transitions), requires_grad=True)))
|
| 97 |
+
|
| 98 |
+
def detach_submodels(self):
    """Turn off gradient tracking for every submodel parameter."""
    # submodels are not trained (so far)
    for submodel in self.models:
        for parameter in submodel.parameters():
            parameter.requires_grad = False
|
| 103 |
+
|
| 104 |
+
def train(self, mode=True):
    """Switch train/eval mode, re-detaching submodels when training."""
    super().train(mode)
    # peft has a weird interaction where it turns requires_grad back on
    # even if it was previously off
    if mode:
        self.detach_submodels()
|
| 110 |
+
|
| 111 |
+
@property
def transitions(self):
    """Delegate to the first submodel (all submodels match)."""
    return self.models[0].transitions

@property
def root_labels(self):
    """Delegate to the first submodel (all submodels match)."""
    return self.models[0].root_labels

@property
def device(self):
    """Device the ensemble's parameters currently live on."""
    return next(self.parameters()).device

def unary_limit(self):
    """
    Limit on the number of consecutive unary transitions
    """
    return min(model.unary_limit() for model in self.models)

def transition_scheme(self):
    """Delegate to the first submodel (all submodels match)."""
    return self.models[0].transition_scheme()

def has_unary_transitions(self):
    """Delegate to the first submodel."""
    return self.models[0].has_unary_transitions()

@property
def is_top_down(self):
    """Delegate to the first submodel."""
    return self.models[0].is_top_down

@property
def reverse_sentence(self):
    """Whether the submodels process sentences in reverse."""
    return self._reverse_sentence

@property
def retag_method(self):
    # TODO: make the method an enum
    return self.models[0].args['retag_method']

def uses_xpos(self):
    """Delegate to the first submodel."""
    return self.models[0].uses_xpos()

def get_top_constituent(self, constituents):
    """Delegate to the first submodel."""
    return self.models[0].get_top_constituent(constituents)

def get_top_transition(self, transitions):
    """Delegate to the first submodel."""
    return self.models[0].get_top_transition(transitions)
|
| 156 |
+
|
| 157 |
+
def log_norms(self):
    """Log parameter norms for the ensemble weights, then each submodel."""
    lines = ["NORMS FOR MODEL PARAMETERS"]
    for name, param in self.named_parameters():
        if not param.requires_grad or name.startswith("models."):
            continue
        zeros = torch.sum(param.abs() < 0.000001).item()
        norm = "%.6g" % torch.norm(param).item()
        lines.append("%s %s %d %d" % (name, norm, zeros, param.nelement()))
    for model_idx, model in enumerate(self.models):
        sublines = model.get_norms()
        if sublines:
            lines.append(" ---- MODEL %d ----" % model_idx)
            lines.extend(sublines)
    logger.info("\n".join(lines))
|
| 170 |
+
|
| 171 |
+
def log_shapes(self):
    """Log the shape of every trainable parameter."""
    # header previously said NORMS, copy-pasted from log_norms
    lines = ["SHAPES FOR MODEL PARAMETERS"]
    for name, param in self.named_parameters():
        if param.requires_grad:
            lines.append("{} {}".format(name, param.shape))
    logger.info("\n".join(lines))
|
| 177 |
+
|
| 178 |
+
def get_params(self):
    """
    Collect the saveable state: ensemble-level weights plus each
    child's own params.

    The children's weights are filtered out of base_params; they are
    saved through their own get_params instead.
    """
    base = {name: value for name, value in self.state_dict().items()
            if not name.startswith("models.")}
    return {
        "base_params": base,
        "children_params": [child.get_params() for child in self.models],
    }
|
| 186 |
+
|
| 187 |
+
def initial_state_from_preterminals(self, preterminal_lists, gold_trees, gold_sequences):
    """Wrap each submodel's initial states into one MultiState per sentence."""
    per_model = [model.initial_state_from_preterminals(preterminal_lists, gold_trees, gold_sequences)
                 for model in self.models]
    # transpose num_models x batch into batch x num_models
    per_sentence = zip(*per_model)
    return [MultiState(states, gold_tree, gold_sequence, 0.0)
            for states, gold_tree, gold_sequence in zip(per_sentence, gold_trees, gold_sequences)]
|
| 193 |
+
|
| 194 |
+
def build_batch_from_tagged_words(self, batch_size, data_iterator):
    """
    Read from the data_iterator batch_size tagged sentences and turn them into new parsing states

    Expects a list of list of (word, tag)
    """
    sentences = []
    while len(sentences) < batch_size:
        sentence = next(data_iterator, None)
        if sentence is None:
            break
        sentences.append(sentence)

    if not sentences:
        return sentences
    per_model = [model.initial_state_from_words(sentences) for model in self.models]
    # transpose to one tuple of per-model states per sentence
    return [MultiState(states, None, None, 0.0) for states in zip(*per_model)]
|
| 212 |
+
|
| 213 |
+
def build_batch_from_trees(self, batch_size, data_iterator):
    """
    Read from the data_iterator batch_size trees and turn them into N lists of parsing states
    """
    trees = []
    while len(trees) < batch_size:
        gold_tree = next(data_iterator, None)
        if gold_tree is None:
            break
        trees.append(gold_tree)

    if not trees:
        return trees
    per_model = [model.initial_state_from_gold_trees(trees) for model in self.models]
    # transpose to one tuple of per-model states per tree
    return [MultiState(states, None, None, 0.0) for states in zip(*per_model)]
|
| 229 |
+
|
| 230 |
+
def predict(self, states, is_legal=True):
    """
    Score each state with every submodel and choose the best transition.

    The combined score is the sum of the raw model scores plus a
    learned per-(model, transition) weighted contribution.
    """
    per_model_states = list(zip(*[multistate.states for multistate in states]))
    predictions = [model.forward(state_batch)
                   for model, state_batch in zip(self.models, per_model_states)]

    # batch X num transitions X num models
    predictions = torch.stack(predictions, dim=2)

    weighted = torch.einsum("BTM,MT->BT", predictions, self.weighted_sum)
    predictions = torch.sum(predictions, dim=2) + weighted

    model = self.models[0]

    # TODO: possibly refactor with lstm_model.predict
    pred_max = torch.argmax(predictions, dim=1)
    scores = torch.take_along_dim(predictions, pred_max.unsqueeze(1), dim=1)
    pred_max = pred_max.detach().cpu()

    pred_trans = [model.transitions[pred_max[idx]] for idx in range(len(per_model_states[0]))]
    if is_legal:
        for idx, (state, trans) in enumerate(zip(per_model_states[0], pred_trans)):
            if trans.is_legal(state, model):
                continue
            # fall back to the best-scoring legal transition
            _, indices = predictions[idx, :].sort(descending=True)
            for index in indices:
                if model.transitions[index].is_legal(state, model):
                    pred_trans[idx] = model.transitions[index]
                    scores[idx] = predictions[idx, index]
                    break
            else: # yeah, else on a for loop, deal with it
                # NOTE(review): assigning None into a tensor looks
                # suspicious - confirm this path is ever reached
                pred_trans[idx] = None
                scores[idx] = None

    return predictions, pred_trans, scores.squeeze(1)
|
| 262 |
+
|
| 263 |
+
def bulk_apply(self, state_batch, transitions, fail=False):
    """
    Apply one transition per MultiState, advancing each submodel's state.

    Returns the updated list of MultiStates.
    (An unused `new_states` local was removed.)
    """
    # transpose the batch of MultiStates into per-model state batches
    per_model = list(zip(*[multistate.states for multistate in state_batch]))
    per_model = [model.bulk_apply(model_states, transitions, fail=fail)
                 for model, model_states in zip(self.models, per_model)]
    # transpose back and store the updated states on each MultiState
    updated = list(zip(*per_model))
    return [multistate._replace(states=states)
            for multistate, states in zip(state_batch, updated)]
|
| 271 |
+
|
| 272 |
+
def parse_tagged_words(self, words, batch_size):
    """
    Parse pre-tagged sentences and return a list of ParseTree objects.

    `parse_tagged_words` is useful at Pipeline time - it takes words &
    tags and processes that into trees.

    Input representation: one list per sentence, each sentence being a
    list of (word, tag).

    TODO: this really ought to be refactored with base_model
    """
    logger.debug("Processing %d sentences", len(words))
    self.eval()

    parsed = self.parse_sentences_no_grad(iter(words),
                                          self.build_batch_from_tagged_words,
                                          batch_size,
                                          self.predict,
                                          keep_state=False,
                                          keep_constituents=False)
    return [result.predictions[0].tree for result in parsed]
|
| 294 |
+
|
| 295 |
+
def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
|
| 296 |
+
"""
|
| 297 |
+
Repeat transitions to build a list of trees from the input batches.
|
| 298 |
+
|
| 299 |
+
The data_iterator should be anything which returns the data for a parse task via next()
|
| 300 |
+
build_batch_fn is a function that turns that data into State objects
|
| 301 |
+
This will be called to generate batches of size batch_size until the data is exhausted
|
| 302 |
+
|
| 303 |
+
The return is a list of tuples: (gold_tree, [(predicted, score) ...])
|
| 304 |
+
gold_tree will be left blank if the data did not include gold trees
|
| 305 |
+
currently score is always 1.0, but the interface may be expanded
|
| 306 |
+
to get a score from the result of the parsing
|
| 307 |
+
|
| 308 |
+
transition_choice: which method of the model to use for
|
| 309 |
+
choosing the next transition
|
| 310 |
+
|
| 311 |
+
TODO: refactor with base_model
|
| 312 |
+
"""
|
| 313 |
+
treebank = []
|
| 314 |
+
treebank_indices = []
|
| 315 |
+
# this will produce tuples of states
|
| 316 |
+
# batch size lists of num models tuples
|
| 317 |
+
state_batch = build_batch_fn(batch_size, data_iterator)
|
| 318 |
+
batch_indices = list(range(len(state_batch)))
|
| 319 |
+
horizon_iterator = iter([])
|
| 320 |
+
|
| 321 |
+
if keep_constituents:
|
| 322 |
+
constituents = defaultdict(list)
|
| 323 |
+
|
| 324 |
+
while len(state_batch) > 0:
|
| 325 |
+
pred_scores, transitions, scores = transition_choice(state_batch)
|
| 326 |
+
# num models lists of batch size states
|
| 327 |
+
state_batch = self.bulk_apply(state_batch, transitions)
|
| 328 |
+
|
| 329 |
+
remove = set()
|
| 330 |
+
for idx, states in enumerate(state_batch):
|
| 331 |
+
if states.finished(self):
|
| 332 |
+
predicted_tree = states.get_tree(self)
|
| 333 |
+
if self.reverse_sentence:
|
| 334 |
+
predicted_tree = predicted_tree.reverse()
|
| 335 |
+
gold_tree = states.gold_tree
|
| 336 |
+
# TODO: could easily store the score here
|
| 337 |
+
# not sure what it means to store the state,
|
| 338 |
+
# since each model is tracking its own state
|
| 339 |
+
treebank.append(ParseResult(gold_tree, [ScoredTree(predicted_tree, None)], None, None))
|
| 340 |
+
treebank_indices.append(batch_indices[idx])
|
| 341 |
+
remove.add(idx)
|
| 342 |
+
|
| 343 |
+
if len(remove) > 0:
|
| 344 |
+
state_batch = [state for idx, state in enumerate(state_batch) if idx not in remove]
|
| 345 |
+
batch_indices = [batch_idx for idx, batch_idx in enumerate(batch_indices) if idx not in remove]
|
| 346 |
+
|
| 347 |
+
for _ in range(batch_size - len(state_batch)):
|
| 348 |
+
horizon_state = next(horizon_iterator, None)
|
| 349 |
+
if not horizon_state:
|
| 350 |
+
horizon_batch = build_batch_fn(batch_size, data_iterator)
|
| 351 |
+
if len(horizon_batch) == 0:
|
| 352 |
+
break
|
| 353 |
+
horizon_iterator = iter(horizon_batch)
|
| 354 |
+
horizon_state = next(horizon_iterator, None)
|
| 355 |
+
|
| 356 |
+
state_batch.append(horizon_state)
|
| 357 |
+
batch_indices.append(len(treebank) + len(state_batch))
|
| 358 |
+
|
| 359 |
+
treebank = utils.unsort(treebank, treebank_indices)
|
| 360 |
+
return treebank
|
| 361 |
+
|
| 362 |
+
def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
|
| 363 |
+
with torch.no_grad():
|
| 364 |
+
return self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice, keep_state, keep_constituents, keep_scores)
|
| 365 |
+
|
| 366 |
+
class EnsembleTrainer(BaseTrainer):
|
| 367 |
+
"""
|
| 368 |
+
Stores a list of constituency models, useful for combining their results into one stronger model
|
| 369 |
+
"""
|
| 370 |
+
def __init__(self, ensemble, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
|
| 371 |
+
super().__init__(ensemble, optimizer, scheduler, epochs_trained, batches_trained, best_f1, best_epoch, first_optimizer)
|
| 372 |
+
|
| 373 |
+
@staticmethod
|
| 374 |
+
def from_files(args, filenames, foundation_cache=None):
|
| 375 |
+
ensemble = Ensemble(args, filenames, foundation_cache=foundation_cache)
|
| 376 |
+
ensemble = ensemble.to(args.get('device', None))
|
| 377 |
+
return EnsembleTrainer(ensemble)
|
| 378 |
+
|
| 379 |
+
def get_peft_params(self):
|
| 380 |
+
params = []
|
| 381 |
+
for model in self.model.models:
|
| 382 |
+
if model.args.get('use_peft', False):
|
| 383 |
+
from peft import get_peft_model_state_dict
|
| 384 |
+
params.append(get_peft_model_state_dict(model.bert_model, adapter_name=model.peft_name))
|
| 385 |
+
else:
|
| 386 |
+
params.append(None)
|
| 387 |
+
|
| 388 |
+
return params
|
| 389 |
+
|
| 390 |
+
@property
|
| 391 |
+
def model_type(self):
|
| 392 |
+
return ModelType.ENSEMBLE
|
| 393 |
+
|
| 394 |
+
def log_num_words_known(self, words):
|
| 395 |
+
nwk = [m.num_words_known(words) for m in self.model.models]
|
| 396 |
+
if all(x == nwk[0] for x in nwk):
|
| 397 |
+
logger.info("Number of words in the training set known to each sub-model: %d out of %d", nwk[0], len(words))
|
| 398 |
+
else:
|
| 399 |
+
logger.info("Number of words in the training set known to the sub-models:\n %s" % "\n ".join(["%d/%d" % (x, len(words)) for x in nwk]))
|
| 400 |
+
|
| 401 |
+
@staticmethod
|
| 402 |
+
def build_optimizer(args, model, first_optimizer):
|
| 403 |
+
def fake_named_parameters():
|
| 404 |
+
for n, p in model.named_parameters():
|
| 405 |
+
if not n.startswith("models."):
|
| 406 |
+
yield n, p
|
| 407 |
+
|
| 408 |
+
# TODO: there has to be a cleaner way to do this, like maybe a "keep" callback
|
| 409 |
+
# TODO: if we finetune the underlying models, we will want a series of optimizers
|
| 410 |
+
# so that they can have a different learning rate from the ensemble's fields
|
| 411 |
+
fake_model = copy.copy(model)
|
| 412 |
+
fake_model.named_parameters = fake_named_parameters
|
| 413 |
+
optimizer = build_optimizer(args, fake_model, first_optimizer)
|
| 414 |
+
return optimizer
|
| 415 |
+
|
| 416 |
+
@staticmethod
|
| 417 |
+
def load_optimizer(model, checkpoint, first_optimizer, filename):
|
| 418 |
+
optimizer = EnsembleTrainer.build_optimizer(model.models[0].args, model, first_optimizer)
|
| 419 |
+
if checkpoint.get('optimizer_state_dict', None) is not None:
|
| 420 |
+
try:
|
| 421 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 422 |
+
except ValueError as e:
|
| 423 |
+
raise ValueError("Failed to load optimizer from %s" % filename) from e
|
| 424 |
+
else:
|
| 425 |
+
logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
|
| 426 |
+
return optimizer
|
| 427 |
+
|
| 428 |
+
@staticmethod
|
| 429 |
+
def load_scheduler(model, optimizer, checkpoint, first_optimizer):
|
| 430 |
+
scheduler = build_scheduler(model.models[0].args, optimizer, first_optimizer=first_optimizer)
|
| 431 |
+
if 'scheduler_state_dict' in checkpoint:
|
| 432 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 433 |
+
return scheduler
|
| 434 |
+
|
| 435 |
+
@staticmethod
|
| 436 |
+
def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None):
|
| 437 |
+
# TODO: no need for the if/else once the models are rebuilt
|
| 438 |
+
children_params = params["children_params"] if isinstance(params, dict) else params
|
| 439 |
+
base_params = params["base_params"] if isinstance(params, dict) else {}
|
| 440 |
+
|
| 441 |
+
# TODO: fill in peft_name
|
| 442 |
+
if peft_params is None:
|
| 443 |
+
peft_params = [None] * len(children_params)
|
| 444 |
+
if peft_name is None:
|
| 445 |
+
peft_name = [None] * len(children_params)
|
| 446 |
+
|
| 447 |
+
if len(children_params) != len(peft_params):
|
| 448 |
+
raise ValueError("Model file had params length %d and peft params length %d" % (len(params), len(peft_params)))
|
| 449 |
+
if len(children_params) != len(peft_name):
|
| 450 |
+
raise ValueError("Model file had params length %d and peft name length %d" % (len(params), len(peft_name)))
|
| 451 |
+
|
| 452 |
+
models = [Trainer.model_from_params(model_param, peft_param, args, foundation_cache, peft_name=pname)
|
| 453 |
+
for model_param, peft_param, pname in zip(children_params, peft_params, peft_name)]
|
| 454 |
+
ensemble = Ensemble(args, models=models)
|
| 455 |
+
ensemble.load_state_dict(base_params, strict=False)
|
| 456 |
+
ensemble = ensemble.to(args.get('device', None))
|
| 457 |
+
return ensemble
|
| 458 |
+
|
| 459 |
+
def parse_args(args=None):
|
| 460 |
+
parser = argparse.ArgumentParser()
|
| 461 |
+
|
| 462 |
+
parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
|
| 463 |
+
parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
|
| 464 |
+
parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
|
| 465 |
+
|
| 466 |
+
utils.add_device_args(parser)
|
| 467 |
+
|
| 468 |
+
parser.add_argument('--lang', default='en', help='Language to use')
|
| 469 |
+
|
| 470 |
+
parser.add_argument('models', type=str, nargs='+', default=None, help="Which model(s) to load")
|
| 471 |
+
|
| 472 |
+
parser.add_argument('--save_name', type=str, default=None, required=True, help='Where to save the combined ensemble')
|
| 473 |
+
|
| 474 |
+
args = vars(parser.parse_args())
|
| 475 |
+
|
| 476 |
+
return args
|
| 477 |
+
|
| 478 |
+
def main(args=None):
|
| 479 |
+
args = parse_args(args)
|
| 480 |
+
foundation_cache = FoundationCache()
|
| 481 |
+
|
| 482 |
+
ensemble = EnsembleTrainer.from_files(args, args['models'], foundation_cache)
|
| 483 |
+
ensemble.save(args['save_name'], save_optimizer=False)
|
| 484 |
+
|
| 485 |
+
if __name__ == "__main__":
|
| 486 |
+
main()
|
stanza/stanza/models/constituency/in_order_compound_oracle.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency.dynamic_oracle import advance_past_constituents, find_in_order_constituent_end, find_previous_open, DynamicOracle
|
| 4 |
+
from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, CompoundUnary, Finalize
|
| 5 |
+
|
| 6 |
+
def fix_missing_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 7 |
+
"""
|
| 8 |
+
A CompoundUnary transition was missed after a Shift, but the sequence was continued correctly otherwise
|
| 9 |
+
"""
|
| 10 |
+
if not isinstance(gold_transition, CompoundUnary):
|
| 11 |
+
return None
|
| 12 |
+
|
| 13 |
+
if pred_transition != gold_sequence[gold_index + 1]:
|
| 14 |
+
return None
|
| 15 |
+
if isinstance(pred_transition, Finalize):
|
| 16 |
+
# this can happen if the entire tree is a single word
|
| 17 |
+
# but it can't be fixed if it means the parser missed the ROOT transition
|
| 18 |
+
return None
|
| 19 |
+
|
| 20 |
+
return gold_sequence[:gold_index] + gold_sequence[gold_index+1:]
|
| 21 |
+
|
| 22 |
+
def fix_wrong_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 23 |
+
if not isinstance(gold_transition, CompoundUnary):
|
| 24 |
+
return None
|
| 25 |
+
|
| 26 |
+
if not isinstance(pred_transition, CompoundUnary):
|
| 27 |
+
return None
|
| 28 |
+
|
| 29 |
+
assert gold_transition != pred_transition
|
| 30 |
+
|
| 31 |
+
return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:]
|
| 32 |
+
|
| 33 |
+
def fix_spurious_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 34 |
+
if isinstance(gold_transition, CompoundUnary):
|
| 35 |
+
return None
|
| 36 |
+
|
| 37 |
+
if not isinstance(pred_transition, CompoundUnary):
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:]
|
| 41 |
+
|
| 42 |
+
def fix_open_shift_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 43 |
+
"""
|
| 44 |
+
Fix a missed Open constituent where we predicted a Shift and the next transition was a Shift
|
| 45 |
+
|
| 46 |
+
In fact, the subsequent transition MUST be a Shift with this transition scheme
|
| 47 |
+
"""
|
| 48 |
+
if not isinstance(gold_transition, OpenConstituent):
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
if not isinstance(pred_transition, Shift):
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
#if not isinstance(gold_sequence[gold_index+1], Shift):
|
| 55 |
+
# return None
|
| 56 |
+
assert isinstance(gold_sequence[gold_index+1], Shift)
|
| 57 |
+
|
| 58 |
+
# close_index represents the Close for the missing Open
|
| 59 |
+
close_index = advance_past_constituents(gold_sequence, gold_index+1)
|
| 60 |
+
assert close_index is not None
|
| 61 |
+
return gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + gold_sequence[close_index+1:]
|
| 62 |
+
|
| 63 |
+
def fix_open_open_two_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 64 |
+
if gold_transition == pred_transition:
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
if not isinstance(gold_transition, OpenConstituent):
|
| 68 |
+
return None
|
| 69 |
+
if not isinstance(pred_transition, OpenConstituent):
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
|
| 73 |
+
if isinstance(gold_sequence[block_end], Shift):
|
| 74 |
+
# this is a multiple subtrees version of this error
|
| 75 |
+
# we are only skipping the two subtrees errors for now
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
# no fix is possible, so we just return here
|
| 79 |
+
return RepairType.OPEN_OPEN_TWO_SUBTREES_ERROR, None
|
| 80 |
+
|
| 81 |
+
def fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three):
|
| 82 |
+
if gold_transition == pred_transition:
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
if not isinstance(gold_transition, OpenConstituent):
|
| 86 |
+
return None
|
| 87 |
+
if not isinstance(pred_transition, OpenConstituent):
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
|
| 91 |
+
if not isinstance(gold_sequence[block_end], Shift):
|
| 92 |
+
# this is a multiple subtrees version of this error
|
| 93 |
+
# we are only skipping the two subtrees errors for now
|
| 94 |
+
return None
|
| 95 |
+
|
| 96 |
+
next_block_end = find_in_order_constituent_end(gold_sequence, block_end+1)
|
| 97 |
+
if exactly_three and isinstance(gold_sequence[next_block_end], Shift):
|
| 98 |
+
# for exactly three subtrees,
|
| 99 |
+
# we can put back the missing open transition
|
| 100 |
+
# and now we have no recall error, only precision error
|
| 101 |
+
# for more than three, we separate that out as an ambiguous choice
|
| 102 |
+
return None
|
| 103 |
+
elif not exactly_three and isinstance(gold_sequence[next_block_end], CloseConstituent):
|
| 104 |
+
# this is ambiguous, but we can still try this fix
|
| 105 |
+
return None
|
| 106 |
+
|
| 107 |
+
# at this point, we build a new sequence with the origin constituent inserted
|
| 108 |
+
return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:block_end] + [CloseConstituent(), gold_transition] + gold_sequence[block_end:]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def fix_open_open_three_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 112 |
+
return fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three=True)
|
| 113 |
+
|
| 114 |
+
def fix_open_open_many_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 115 |
+
return fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three=False)
|
| 116 |
+
|
| 117 |
+
def fix_open_close_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 118 |
+
"""
|
| 119 |
+
Find the closed bracket, reopen it
|
| 120 |
+
|
| 121 |
+
The Open we just missed must be forgotten - it cannot be reopened
|
| 122 |
+
"""
|
| 123 |
+
if not isinstance(gold_transition, OpenConstituent):
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
if not isinstance(pred_transition, CloseConstituent):
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
# find the appropriate Open so we can reopen it
|
| 130 |
+
open_idx = find_previous_open(gold_sequence, gold_index)
|
| 131 |
+
# actually, if the Close is legal, this can't happen
|
| 132 |
+
# but it might happen in a unit test which doesn't check legality
|
| 133 |
+
if open_idx is None:
|
| 134 |
+
return None
|
| 135 |
+
|
| 136 |
+
# also, since we are punting on the missed Open, we need to skip
|
| 137 |
+
# the Close which would have closed it
|
| 138 |
+
close_idx = advance_past_constituents(gold_sequence, gold_index+1)
|
| 139 |
+
|
| 140 |
+
return gold_sequence[:gold_index] + [pred_transition, gold_sequence[open_idx]] + gold_sequence[gold_index+1:close_idx] + gold_sequence[close_idx+1:]
|
| 141 |
+
|
| 142 |
+
def fix_shift_close_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 143 |
+
"""
|
| 144 |
+
Find the closed bracket, reopen it
|
| 145 |
+
"""
|
| 146 |
+
if not isinstance(gold_transition, Shift):
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
if not isinstance(pred_transition, CloseConstituent):
|
| 150 |
+
return None
|
| 151 |
+
|
| 152 |
+
# don't do this at the start or immediately after opening
|
| 153 |
+
if gold_index == 0 or isinstance(gold_sequence[gold_index - 1], OpenConstituent):
|
| 154 |
+
return None
|
| 155 |
+
|
| 156 |
+
open_idx = find_previous_open(gold_sequence, gold_index)
|
| 157 |
+
assert open_idx is not None
|
| 158 |
+
|
| 159 |
+
return gold_sequence[:gold_index] + [pred_transition, gold_sequence[open_idx]] + gold_sequence[gold_index:]
|
| 160 |
+
|
| 161 |
+
def fix_shift_open_unambiguous_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 162 |
+
if not isinstance(gold_transition, Shift):
|
| 163 |
+
return None
|
| 164 |
+
|
| 165 |
+
if not isinstance(pred_transition, OpenConstituent):
|
| 166 |
+
return None
|
| 167 |
+
|
| 168 |
+
bracket_end = find_in_order_constituent_end(gold_sequence, gold_index)
|
| 169 |
+
assert bracket_end is not None
|
| 170 |
+
if isinstance(gold_sequence[bracket_end], Shift):
|
| 171 |
+
# this is an ambiguous error
|
| 172 |
+
# multiple possible places to end the wrong constituent
|
| 173 |
+
return None
|
| 174 |
+
assert isinstance(gold_sequence[bracket_end], CloseConstituent)
|
| 175 |
+
|
| 176 |
+
return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:bracket_end] + [CloseConstituent()] + gold_sequence[bracket_end:]
|
| 177 |
+
|
| 178 |
+
def fix_close_shift_unambiguous_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
|
| 179 |
+
if not isinstance(gold_transition, CloseConstituent):
|
| 180 |
+
return None
|
| 181 |
+
|
| 182 |
+
if not isinstance(pred_transition, Shift):
|
| 183 |
+
return None
|
| 184 |
+
if not isinstance(gold_sequence[gold_index+1], Shift):
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
bracket_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
|
| 188 |
+
assert bracket_end is not None
|
| 189 |
+
if isinstance(gold_sequence[bracket_end], Shift):
|
| 190 |
+
# this is an ambiguous error
|
| 191 |
+
# multiple possible places to end the wrong constituent
|
| 192 |
+
return None
|
| 193 |
+
assert isinstance(gold_sequence[bracket_end], CloseConstituent)
|
| 194 |
+
|
| 195 |
+
return gold_sequence[:gold_index] + gold_sequence[gold_index+1:bracket_end] + [CloseConstituent()] + gold_sequence[bracket_end:]
|
| 196 |
+
|
| 197 |
+
class RepairType(Enum):
|
| 198 |
+
"""
|
| 199 |
+
Keep track of which repair is used, if any, on an incorrect transition
|
| 200 |
+
|
| 201 |
+
Effects of different repair types:
|
| 202 |
+
no oracle: 0.9251 0.9226
|
| 203 |
+
+missing_unary: 0.9246 0.9214
|
| 204 |
+
+wrong_unary: 0.9236 0.9213
|
| 205 |
+
+spurious_unary: 0.9247 0.9229
|
| 206 |
+
+open_shift_error: 0.9258 0.9226
|
| 207 |
+
+open_open_two_subtrees: 0.9256 0.9215 # nothing changes with this one...
|
| 208 |
+
+open_open_three_subtrees: 0.9256 0.9226
|
| 209 |
+
+open_open_many_subtrees: 0.9257 0.9234
|
| 210 |
+
+shift_close: 0.9267 0.9250
|
| 211 |
+
+shift_open: 0.9273 0.9247
|
| 212 |
+
+close_shift: 0.9266 0.9229
|
| 213 |
+
+open_close: 0.9267 0.9256
|
| 214 |
+
"""
|
| 215 |
+
def __new__(cls, fn, correct=False, debug=False):
|
| 216 |
+
"""
|
| 217 |
+
Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error
|
| 218 |
+
"""
|
| 219 |
+
value = len(cls.__members__)
|
| 220 |
+
obj = object.__new__(cls)
|
| 221 |
+
obj._value_ = value + 1
|
| 222 |
+
obj.fn = fn
|
| 223 |
+
obj.correct = correct
|
| 224 |
+
obj.debug = debug
|
| 225 |
+
return obj
|
| 226 |
+
|
| 227 |
+
@property
|
| 228 |
+
def is_correct(self):
|
| 229 |
+
return self.correct
|
| 230 |
+
|
| 231 |
+
# The correct sequence went Shift - Unary - Stuff
|
| 232 |
+
# but the CompoundUnary was missed and Stuff predicted
|
| 233 |
+
# so now we just proceed as if nothing happened
|
| 234 |
+
# note that CompoundUnary happens immediately after a Shift
|
| 235 |
+
# complicated nodes are created with single Open transitions
|
| 236 |
+
MISSING_UNARY_ERROR = (fix_missing_unary_error,)
|
| 237 |
+
|
| 238 |
+
# Predicted a wrong CompoundUnary. No way to fix this, so just keep going
|
| 239 |
+
WRONG_UNARY_ERROR = (fix_wrong_unary_error,)
|
| 240 |
+
|
| 241 |
+
# The correct sequence went Shift - Stuff
|
| 242 |
+
# but instead we predicted a CompoundUnary
|
| 243 |
+
# again, we just keep going
|
| 244 |
+
SPURIOUS_UNARY_ERROR = (fix_spurious_unary_error,)
|
| 245 |
+
|
| 246 |
+
# Were supposed to open a new constituent,
|
| 247 |
+
# but instead shifted an item onto the stack
|
| 248 |
+
#
|
| 249 |
+
# The missed Open cannot be recovered
|
| 250 |
+
#
|
| 251 |
+
# One could ask, is it possible to open a bigger constituent later,
|
| 252 |
+
# but if the constituent patterns go
|
| 253 |
+
# X (good open) Y (missed open) Z
|
| 254 |
+
# when we eventually close Y and Z, because of the missed Open,
|
| 255 |
+
# it is guaranteed to capture X as well
|
| 256 |
+
# since it will grab constituents until one left of the previous Open before Y
|
| 257 |
+
#
|
| 258 |
+
# Therefore, in this case, we must simply forget about this Open (recall error)
|
| 259 |
+
OPEN_SHIFT_ERROR = (fix_open_shift_error,)
|
| 260 |
+
|
| 261 |
+
# With this transition scheme, it is not possible to fix the following pattern:
|
| 262 |
+
# T1 O_x T2 C -> T1 O_y T2 C
|
| 263 |
+
# seeing as how there are no unary transitions
|
| 264 |
+
# so whatever precision & recall errors are caused by substituting O_x -> O_y
|
| 265 |
+
# (which could include multiple transitions)
|
| 266 |
+
# those errors are unfixable in any way
|
| 267 |
+
OPEN_OPEN_TWO_SUBTREES_ERROR = (fix_open_open_two_subtrees_error,)
|
| 268 |
+
|
| 269 |
+
# With this transition scheme, a three subtree branch with a wrong Open
|
| 270 |
+
# has a non-ambiguous fix
|
| 271 |
+
# T1 O_x T2 T3 C -> T1 O_y T2 T3 C
|
| 272 |
+
# this can become
|
| 273 |
+
# T1 O_y T2 C O_x T3 C
|
| 274 |
+
# now there are precision errors from the incorrectly added transition(s),
|
| 275 |
+
# but the correctly replaced transitions are unambiguous
|
| 276 |
+
OPEN_OPEN_THREE_SUBTREES_ERROR = (fix_open_open_three_subtrees_error,)
|
| 277 |
+
|
| 278 |
+
# We were supposed to shift a new item onto the stack,
|
| 279 |
+
# but instead we closed the previous constituent
|
| 280 |
+
# This causes a precision error, but we can avoid the recall error
|
| 281 |
+
# by immediately reopening the closed constituent.
|
| 282 |
+
SHIFT_CLOSE_ERROR = (fix_shift_close_error,)
|
| 283 |
+
|
| 284 |
+
# We opened a new constituent instead of shifting
|
| 285 |
+
# In the event that the next constituent ends with a close,
|
| 286 |
+
# rather than building another new constituent,
|
| 287 |
+
# then there is no ambiguity
|
| 288 |
+
SHIFT_OPEN_UNAMBIGUOUS_ERROR = (fix_shift_open_unambiguous_error,)
|
| 289 |
+
|
| 290 |
+
# Suppose we were supposed to Close, then Shift
|
| 291 |
+
# but instead we just did a Shift
|
| 292 |
+
# Similar to shift_open_unambiguous, we now have an opened
|
| 293 |
+
# constituent which shouldn't be there
|
| 294 |
+
# We can scroll past the next constituent created to see
|
| 295 |
+
# if the outer constituents close at that point
|
| 296 |
+
# If so, we can close this constituent as well in an unambiguous manner
|
| 297 |
+
# TODO: analyze the case where we were supposed to Close, Open
|
| 298 |
+
# but instead did a Shift
|
| 299 |
+
CLOSE_SHIFT_UNAMBIGUOUS_ERROR = (fix_close_shift_unambiguous_error,)
|
| 300 |
+
|
| 301 |
+
# Supposed to open a new constituent,
|
| 302 |
+
# instead closed an existing constituent
|
| 303 |
+
#
|
| 304 |
+
# X (good open) Y (open -> close) Z
|
| 305 |
+
#
|
| 306 |
+
# the constituent that should contain Y, Z is unfortunately lost
|
| 307 |
+
# since now the stack has
|
| 308 |
+
#
|
| 309 |
+
# XY ...
|
| 310 |
+
#
|
| 311 |
+
# furthermore, there is now a precision error for the extra XY
|
| 312 |
+
# constituent that should not exist
|
| 313 |
+
# however, what we can do to minimize further errors is
|
| 314 |
+
# to at least reopen the label between X and Y
|
| 315 |
+
OPEN_CLOSE_ERROR = (fix_open_close_error,)
|
| 316 |
+
|
| 317 |
+
# this is ambiguous, but we can still try the same fix as three_subtrees (see above)
|
| 318 |
+
OPEN_OPEN_MANY_SUBTREES_ERROR = (fix_open_open_many_subtrees_error,)
|
| 319 |
+
|
| 320 |
+
CORRECT = (None, True)
|
| 321 |
+
|
| 322 |
+
UNKNOWN = None
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class InOrderCompoundOracle(DynamicOracle):
|
| 326 |
+
def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
|
| 327 |
+
super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)
|
stanza/stanza/models/constituency/in_order_oracle.py
ADDED
|
@@ -0,0 +1,1029 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency.dynamic_oracle import advance_past_constituents, find_in_order_constituent_end, find_previous_open, score_candidates, DynamicOracle, RepairEnum
|
| 4 |
+
from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent
|
| 5 |
+
|
| 6 |
+
def fix_wrong_open_root_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an open/open error specifically at the ROOT.

    The wrong open is kept, immediately closed, and then the gold
    sequence is replayed from the error point.
    """
    if gold_transition == pred_transition:
        return None

    both_opens = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent)
    if not both_opens or gold_transition.top_label not in root_labels:
        return None

    prefix = gold_sequence[:gold_index]
    suffix = gold_sequence[gold_index:]
    return prefix + [pred_transition, CloseConstituent()] + suffix
|
| 17 |
+
|
| 18 |
+
def fix_wrong_open_unary_chain(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a wrong open/open in a unary chain by removing the skipped unary transitions

    Only applies if the wrong pred transition is a transition found higher up in the unary chain

    Returns the repaired transition sequence, or None if this fix does not apply.
    """
    # useful to have this check here in case the call is made independently in a unit test
    if gold_transition == pred_transition:
        return None

    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent):
        cur_index = gold_index + 1 # This is now a Close if we are in this particular context
        # walk up the unary chain: each link is a Close immediately followed by an Open
        while cur_index + 1 < len(gold_sequence) and isinstance(gold_sequence[cur_index], CloseConstituent) and isinstance(gold_sequence[cur_index+1], OpenConstituent):
            cur_index = cur_index + 1 # advance to the next Open
            if gold_sequence[cur_index] == pred_transition:
                # the predicted Open exists higher in the chain:
                # drop the intermediate unary transitions and resume there
                return gold_sequence[:gold_index] + gold_sequence[cur_index:]
            cur_index = cur_index + 1 # advance to the next Close

    return None
|
| 37 |
+
|
| 38 |
+
def fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, more_than_two):
    """
    Fix a wrong open/open by closing the predicted bracket after the next block, then reopening the gold bracket.

    The more_than_two flag selects which bracket arity this fixer handles:
    False handles exactly-two-subtree brackets, True handles larger ones.
    Returns the repaired sequence, or None if this fix does not apply.
    """
    if gold_transition == pred_transition:
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # if Close, the gold was a unary
        return None
    assert not isinstance(gold_sequence[gold_index+1], OpenConstituent)
    assert isinstance(gold_sequence[gold_index+1], Shift)

    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
    assert block_end is not None

    # the transition at block_end distinguishes two-subtree brackets
    # (next is Close) from wider brackets (next is Shift); each variant
    # of this fixer only handles one of those cases
    if more_than_two and isinstance(gold_sequence[block_end], CloseConstituent):
        return None
    if not more_than_two and isinstance(gold_sequence[block_end], Shift):
        return None

    # keep the predicted open, close it once the next block is built,
    # then reopen the gold constituent and continue
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:block_end] + [CloseConstituent(), gold_transition] + gold_sequence[block_end:]
|
| 62 |
+
|
| 63 |
+
def fix_wrong_open_two_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Wrong open/open where the bracket holds exactly two subtrees; delegates to the shared fixer."""
    return fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence,
                                   gold_index, root_labels, more_than_two=False)
|
| 65 |
+
|
| 66 |
+
def fix_wrong_open_multiple_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Wrong open/open where the bracket holds more than two subtrees; delegates to the shared fixer."""
    return fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence,
                                   gold_index, root_labels, more_than_two=True)
|
| 68 |
+
|
| 69 |
+
def advance_past_unaries(gold_sequence, cur_index):
    """
    Advance cur_index past any Open/Close unary pairs and return the new index.
    """
    limit = len(gold_sequence)
    while cur_index + 2 < limit:
        is_unary = (isinstance(gold_sequence[cur_index], OpenConstituent)
                    and isinstance(gold_sequence[cur_index+1], CloseConstituent))
        if not is_unary:
            break
        cur_index += 2
    return cur_index
|
| 73 |
+
|
| 74 |
+
def fix_wrong_open_stuff_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a wrong open/open when there is an intervening constituent and then the guessed NT

    This happens when the correct pattern is
      stuff_1 NT_X stuff_2 close NT_Y ...
    and instead of guessing the gold transition NT_X,
    the prediction was NT_Y

    Returns the repaired sequence, or None if this fix does not apply.
    """
    if gold_transition == pred_transition:
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None
    # TODO: Here we could advance past unary transitions while
    # watching for hitting pred_transition.  However, that is an open
    # question... is it better to try to keep such an Open as part of
    # the sequence, or is it better to skip them and attach the inner
    # nodes to the upper level
    stuff_start = gold_index + 1
    if not isinstance(gold_sequence[stuff_start], Shift):
        return None
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    if stuff_end is None:
        return None
    # at this point, stuff_end points to the Close which occurred after stuff_2
    # also, stuff_start points to the first transition which makes stuff_2, the Shift
    cur_index = stuff_end + 1
    # scan the Opens after the Close, following unary links, looking for
    # the Open the model actually predicted
    while isinstance(gold_sequence[cur_index], OpenConstituent):
        if gold_sequence[cur_index] == pred_transition:
            # found it: keep the predicted open, rebuild stuff_2 inside it,
            # and continue after the matched Open
            return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index+1:]
        # this was an OpenConstituent, but not the OpenConstituent we guessed
        # maybe there's a unary transition which lets us try again
        if cur_index + 2 < len(gold_sequence) and isinstance(gold_sequence[cur_index + 1], CloseConstituent):
            cur_index = cur_index + 2
        else:
            break

    # oh well, none of this worked
    return None
|
| 116 |
+
|
| 117 |
+
def fix_wrong_open_general(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a general wrong open/open transition by accepting the open and continuing

    A couple other open/open patterns have already been carved out

    TODO: negative checks for the previous patterns, in case we turn those off
    """
    if gold_transition == pred_transition:
        return None

    both_opens = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent)
    if not both_opens:
        return None
    # Swapping a ROOT open for a non-ROOT open would create an illegal
    # transition sequence; the ROOT case is handled elsewhere anyway
    if gold_transition.top_label in root_labels:
        return None

    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:]
|
| 138 |
+
|
| 139 |
+
def fix_missed_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a missed unary which is followed by an otherwise correct transition

    (also handles multiple missed unary transitions)
    """
    if gold_transition == pred_transition:
        return None

    # skip past any unary Open/Close pairs; if the transition which
    # follows them is exactly what the model predicted, drop the unaries
    resume_index = advance_past_unaries(gold_sequence, gold_index)
    if gold_sequence[resume_index] != pred_transition:
        return None
    return gold_sequence[:gold_index] + gold_sequence[resume_index:]
|
| 153 |
+
|
| 154 |
+
def fix_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix an Open replaced with a Shift

    Suppose we were supposed to guess NT_X and instead did S

    We derive the repair as follows.

    For simplicity, assume the open is not a unary for now

    Since we know an Open was legal, there must be stuff
      stuff NT_X
    Shift is also legal, so there must be other stuff and a previous Open
      stuff_1 NT_Y stuff_2 NT_X
    After the NT_X which we missed, there was a bunch of stuff and a close for NT_X
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C
    There could be more stuff here which can be saved...
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C stuff_4 C
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C C

    Returns the repaired sequence, or None if this fix does not apply.
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    cur_index = gold_index
    cur_index = advance_past_unaries(gold_sequence, cur_index)
    if not isinstance(gold_sequence[cur_index], OpenConstituent):
        return None
    if gold_sequence[cur_index].top_label in root_labels:
        return None
    # cur_index now points to the NT_X we missed (not counting unaries)

    stuff_start = cur_index + 1
    # can't be a Close, since we just went past an Open and checked for unaries
    # can't be an Open, since two Open in a row is illegal
    assert isinstance(gold_sequence[stuff_start], Shift)
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    # stuff_end is now the Close which ends NT_X
    cur_index = stuff_end + 1
    if cur_index >= len(gold_sequence):
        return None
    if isinstance(gold_sequence[cur_index], OpenConstituent):
        cur_index = advance_past_unaries(gold_sequence, cur_index)
        if cur_index >= len(gold_sequence):
            return None
        if isinstance(gold_sequence[cur_index], OpenConstituent):
            # an Open here signifies that there was a bracket containing X underneath Y
            # TODO: perhaps try to salvage something out of that situation?
            return None
    # the repair starts with the sequence up through the error,
    # then stuff_3, which includes the error
    # skip the Close for the missed NT_X
    # then finish the sequence with any potential stuff_4, the next Close, and everything else
    repair = gold_sequence[:gold_index] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index:]
    return repair
|
| 210 |
+
|
| 211 |
+
def fix_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix an Open replaced with a Close

    Call the Open NT_X
    Open legal, so there must be stuff:
      stuff NT_X
    Close legal, so there must be something to close:
      stuff_1 NT_Y stuff_2 NT_X

    The incorrect close makes the following brackets:
      (Y stuff_1 stuff_2)
    We were supposed to build
      (Y stuff_1 (X stuff_2 ...) (possibly more stuff))
    The simplest fix here is to reopen Y at this point.

    One issue might be if there is another bracket which encloses X underneath Y
    So, for example, the tree was supposed to be
      (Y stuff_1 (Z (X stuff_2 stuff_3) stuff_4))
    The pattern for this case is
      stuff_1 NT_Y stuff_2 NY_X stuff_3 close NT_Z stuff_4 close close

    Returns the repaired sequence, or None if this fix does not apply.
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, CloseConstituent):
        return None

    cur_index = advance_past_unaries(gold_sequence, gold_index)
    if cur_index >= len(gold_sequence):
        return None
    if not isinstance(gold_sequence[cur_index], OpenConstituent):
        return None
    if gold_sequence[cur_index].top_label in root_labels:
        return None

    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    # prev_open is now NT_Y from above

    stuff_start = cur_index + 1
    assert isinstance(gold_sequence[stuff_start], Shift)
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    # stuff_end is now the Close which ends NT_X
    # stuff_start:stuff_end is the stuff_3 block above
    cur_index = stuff_end + 1
    if cur_index >= len(gold_sequence):
        return None
    # if there are unary transitions here, we want to skip those.
    # those are unary transitions on X and cannot be recovered, since X is gone
    cur_index = advance_past_unaries(gold_sequence, cur_index)
    # now there is a certain failure case which has to be accounted for.

    # specifically, if there is a new non-terminal which opens
    # immediately after X closes, it is encompassing X in a way that
    # cannot be recovered now that part of X is stuck under Y.
    # The two choices at this point would be to eliminate the new
    # transition or just reject the tree from the repair
    # For now, we reject the tree
    if isinstance(gold_sequence[cur_index], OpenConstituent):
        return None

    # accept the Close, reopen NT_Y, rebuild stuff_3 inside it,
    # then continue with the remainder of the gold sequence
    repair = gold_sequence[:gold_index] + [pred_transition, prev_open] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index:]
    return repair
|
| 276 |
+
|
| 277 |
+
def fix_shift_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    This fixes Shift replaced with a Close transition.

    This error occurs in the following pattern:
      stuff_1 NT_X stuff... shift
    Instead of shift, you close the NT_X
    The easiest fix here is to just restore the NT_X.

    Returns the repaired sequence, or None if this fix does not apply.
    """

    if not isinstance(pred_transition, CloseConstituent):
        return None

    # this fix can also be applied if there were unaries on the
    # previous constituent.  we just skip those until the Shift
    cur_index = gold_index
    if isinstance(gold_transition, OpenConstituent):
        cur_index = advance_past_unaries(gold_sequence, cur_index)
    if not isinstance(gold_sequence[cur_index], Shift):
        return None

    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    # prev_open is now NT_X from above

    # accept the Close, then immediately reopen NT_X and continue from the Shift
    return gold_sequence[:gold_index] + [pred_transition, prev_open] + gold_sequence[cur_index:]
|
| 305 |
+
|
| 306 |
+
def fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous, late):
    """
    Fix a Close followed by an Open which was replaced with a Shift.

    The missed Close/Open pair is moved to after the next bracket is
    built (or, with late=True, after all of the following brackets).
    The ambiguous flag selects whether this fixer handles the case
    where multiple following brackets make the close position ambiguous.
    Returns the repaired sequence, or None if this fix does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    open_index = advance_past_unaries(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[open_index], OpenConstituent):
        return None
    if not isinstance(gold_sequence[open_index+1], Shift):
        return None

    # check that the next operation was to open a *different* constituent
    # from the one we just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    if gold_sequence[open_index] == prev_open:
        return None

    # check that the following stuff is a single bracket, not multiple brackets
    end_index = find_in_order_constituent_end(gold_sequence, open_index+1)
    if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):
        return None
    elif not ambiguous and isinstance(gold_sequence[end_index], Shift):
        return None

    # if closing at the end of the next blocks,
    # instead of closing after the first block ends,
    # we go to the end of the last block
    if late:
        end_index = advance_past_constituents(gold_sequence, open_index+1)

    # build the following bracket's contents first, then apply the
    # delayed Close (and unaries) and Open, then continue
    return gold_sequence[:gold_index] + gold_sequence[open_index+1:end_index] + gold_sequence[gold_index:open_index+1] + gold_sequence[end_index:]
|
| 346 |
+
|
| 347 |
+
def fix_close_open_shift_unambiguous_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Close/Open missed before a Shift, unambiguous bracket case; delegates to the shared fixer."""
    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence,
                                        gold_index, root_labels, ambiguous=False, late=False)
|
| 349 |
+
|
| 350 |
+
def fix_close_open_shift_ambiguous_bracket_early(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Close/Open missed before a Shift, ambiguous case, closing early; delegates to the shared fixer."""
    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence,
                                        gold_index, root_labels, ambiguous=True, late=False)
|
| 352 |
+
|
| 353 |
+
def fix_close_open_shift_ambiguous_bracket_late(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Close/Open missed before a Shift, ambiguous case, closing late; delegates to the shared fixer."""
    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence,
                                        gold_index, root_labels, ambiguous=True, late=True)
|
| 355 |
+
|
| 356 |
+
def fix_close_open_shift_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a missed Close/Open before a Shift by letting the model score the candidate close positions.

    Builds one candidate repair per possible close position and uses
    score_candidates to choose the one the model prefers.  Returns a
    (repair type, repaired sequence) pair, or None if this does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    open_index = advance_past_unaries(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[open_index], OpenConstituent):
        return None
    if not isinstance(gold_sequence[open_index+1], Shift):
        return None

    # check that the next operation was to open a *different* constituent
    # from the one we just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    if gold_sequence[open_index] == prev_open:
        return None

    # alright, at long last we have:
    # a close that was missed
    # a non-nested open that was missed
    end_index = find_in_order_constituent_end(gold_sequence, open_index+1)

    # each candidate is a 4-tuple of sequence pieces; the piece at
    # candidate_idx=2 (the delayed Close/Open) is the one being scored
    candidates = []
    candidates.append((gold_sequence[:gold_index], gold_sequence[open_index+1:end_index], gold_sequence[gold_index:open_index+1], gold_sequence[end_index:]))
    # every further Shift extends the candidate list with a later close position
    while isinstance(gold_sequence[end_index], Shift):
        end_index = find_in_order_constituent_end(gold_sequence, end_index+1)
        candidates.append((gold_sequence[:gold_index], gold_sequence[open_index+1:end_index], gold_sequence[gold_index:open_index+1], gold_sequence[end_index:]))

    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=2)
    if len(candidates) == 1:
        # only one possible close position - not actually ambiguous
        return RepairType.CLOSE_OPEN_SHIFT_UNAMBIGUOUS_BRACKET, best_candidate

    # normalize "latest position" to -1 so the reported repair value is stable
    if best_idx == len(candidates) - 1:
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
|
| 403 |
+
|
| 404 |
+
def fix_close_open_shift_nested(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a Close X..Open X..Shift pattern where both the Close and Open were skipped.

    Here the pattern we are trying to fix is
      stuff_A open_X stuff_B *close* open_X shift...
    replaced with
      stuff_A open_X stuff_B shift...
    the missed close & open means a missed recall error for (X A B)
    whereas the previous open_X can still get the outer bracket
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    # handle the sequence:
    # stuff_A open_X stuff_B close open_Y close open_X shift
    open_index = advance_past_unaries(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[open_index], OpenConstituent):
        return None
    if not isinstance(gold_sequence[open_index+1], Shift):
        return None

    # check that the next operation was to open the same constituent
    # we just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    if gold_sequence[open_index] != prev_open:
        return None

    # drop the missed close/open (and unaries); the earlier open_X
    # keeps building the outer bracket
    return gold_sequence[:gold_index] + gold_sequence[open_index+1:]
|
| 443 |
+
|
| 444 |
+
def fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous, late):
    """
    Repair Close/Shift -> Shift by moving the Close to after the next block is created

    The ambiguous flag selects whether this fixer handles the case where
    multiple following brackets make the close position ambiguous;
    late=True closes after all following brackets instead of the first.
    Returns the repaired sequence, or None if this fix does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 2:
        return None
    start_index = gold_index + 1
    start_index = advance_past_unaries(gold_sequence, start_index)
    if len(gold_sequence) < start_index + 2:
        return None
    if not isinstance(gold_sequence[start_index], Shift):
        return None

    end_index = find_in_order_constituent_end(gold_sequence, start_index)
    if end_index is None:
        return None
    # if this *isn't* a close, we don't allow it in the unambiguous case
    # that case seems to be ambiguous...
    #   stuff_1 close stuff_2 stuff_3
    # if you would normally start building stuff_3,
    # it is not clear if you want to close at the end of
    # stuff_2 or build stuff_3 instead.
    if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):
        return None
    elif not ambiguous and isinstance(gold_sequence[end_index], Shift):
        return None

    # close at the end of the brackets, rather than once the first bracket is finished
    if late:
        end_index = advance_past_constituents(gold_sequence, start_index)

    return gold_sequence[:gold_index] + gold_sequence[start_index:end_index] + [CloseConstituent()] + gold_sequence[end_index:]
|
| 480 |
+
|
| 481 |
+
def fix_close_shift_shift_unambiguous(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Close missed before a Shift, unambiguous case; delegates to the shared fixer."""
    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence,
                                 gold_index, root_labels, ambiguous=False, late=False)
|
| 483 |
+
|
| 484 |
+
def fix_close_shift_shift_ambiguous_early(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Close missed before a Shift, ambiguous case, closing early; delegates to the shared fixer."""
    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence,
                                 gold_index, root_labels, ambiguous=True, late=False)
|
| 486 |
+
|
| 487 |
+
def fix_close_shift_shift_ambiguous_late(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Close missed before a Shift, ambiguous case, closing late; delegates to the shared fixer."""
    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence,
                                 gold_index, root_labels, ambiguous=True, late=True)
|
| 489 |
+
|
| 490 |
+
def fix_close_shift_shift_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a missed Close before a Shift by letting the model score the candidate close positions.

    Builds one candidate repair per possible close position and uses
    score_candidates to choose the one the model prefers.  Returns a
    (repair type, repaired sequence) pair, or None if this does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 2:
        return None
    start_index = gold_index + 1
    start_index = advance_past_unaries(gold_sequence, start_index)
    if len(gold_sequence) < start_index + 2:
        return None
    if not isinstance(gold_sequence[start_index], Shift):
        return None

    # now we know that the gold pattern was
    #   Close (unaries) Shift
    # and instead the model predicted Shift
    # each candidate is a 4-tuple of sequence pieces; the piece at
    # candidate_idx=2 (the delayed Close) is the one being scored
    candidates = []
    current_index = start_index
    while isinstance(gold_sequence[current_index], Shift):
        current_index = find_in_order_constituent_end(gold_sequence, current_index)
        assert current_index is not None
        candidates.append((gold_sequence[:gold_index], gold_sequence[start_index:current_index], [CloseConstituent()], gold_sequence[current_index:]))
    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=2)
    if len(candidates) == 1:
        # only one possible close position - not actually ambiguous
        return RepairType.CLOSE_SHIFT_SHIFT, best_candidate
    # normalize "latest position" to -1 so the reported repair value is stable
    if best_idx == len(candidates) - 1:
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    #print(best_idx, len(candidates), repair_type)
    return repair_type, best_candidate
|
| 523 |
+
|
| 524 |
+
def ambiguous_shift_open_unary_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Treat a predicted Open (where a Shift was expected) as a unary:
    keep it, close it immediately, then continue with the gold sequence.
    """
    if not isinstance(gold_transition, Shift):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    prefix = gold_sequence[:gold_index]
    suffix = gold_sequence[gold_index:]
    return prefix + [pred_transition, CloseConstituent()] + suffix
|
| 531 |
+
|
| 532 |
+
def ambiguous_shift_open_early_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Keep a predicted Open (where a Shift was expected) and close it as
    soon as the current block finishes.
    """
    if not isinstance(gold_transition, Shift):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    # find where the current block ends, either via a Shift or a Close
    block_end = find_in_order_constituent_end(gold_sequence, gold_index)
    pieces = [gold_sequence[:gold_index],
              [pred_transition],
              gold_sequence[gold_index:block_end],
              [CloseConstituent()],
              gold_sequence[block_end:]]
    return [transition for piece in pieces for transition in piece]
|
| 542 |
+
|
| 543 |
+
def ambiguous_shift_open_late_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Keep a predicted Open (where a Shift was expected) and close it only
    after all of the following constituents are built.
    """
    if not isinstance(gold_transition, Shift):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    block_end = advance_past_constituents(gold_sequence, gold_index)
    pieces = [gold_sequence[:gold_index],
              [pred_transition],
              gold_sequence[gold_index:block_end],
              [CloseConstituent()],
              gold_sequence[block_end:]]
    return [transition for piece in pieces for transition in piece]
|
| 551 |
+
|
| 552 |
+
def ambiguous_shift_open_predicted_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Repair a Shift -> Open error by letting the model pick where to close the bracket.

    Builds up to three candidate repairs - close immediately (unary), close
    after the first following block, close after all following blocks - then
    scores them with the model.  Returns a RepairEnum tagged with which
    candidate won ("U", "S"/"E", or "L") along with the winning sequence, or
    None when the gold transition is not a Shift or the prediction is not an
    OpenConstituent.
    """
    if not isinstance(gold_transition, Shift) or not isinstance(pred_transition, OpenConstituent):
        return None

    prefix = gold_sequence[:gold_index]

    # candidate 0: close immediately, turning the predicted open into a unary
    candidates = [(prefix, [pred_transition], [CloseConstituent()], gold_sequence[gold_index:])]

    # candidate 1: close after the first following block
    early_index = find_in_order_constituent_end(gold_sequence, gold_index)
    candidates.append((prefix, [pred_transition] + gold_sequence[gold_index:early_index], [CloseConstituent()], gold_sequence[early_index:]))

    # candidate 2, only when it differs from candidate 1:
    # close after all of the following blocks
    late_index = advance_past_constituents(gold_sequence, gold_index)
    if late_index == early_index:
        labels = ("U", "S")
    else:
        candidates.append((prefix, [pred_transition] + gold_sequence[gold_index:late_index], [CloseConstituent()], gold_sequence[late_index:]))
        labels = ("U", "E", "L")

    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=2)
    return_label = labels[best_idx]
    repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_PREDICTED_CLOSE.name,
                             value="%d.%s" % (RepairType.SHIFT_OPEN_PREDICTED_CLOSE.value, return_label),
                             is_correct=False)
    return repair_type, best_candidate
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Close answered with a predicted Shift.

    Returns (RepairType.OTHER_CLOSE_SHIFT, None) for statistics bookkeeping,
    or None when this reporter does not apply.
    """
    applies = isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, Shift)
    return (RepairType.OTHER_CLOSE_SHIFT, None) if applies else None
|
| 594 |
+
|
| 595 |
+
def report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Close answered with a predicted Open.

    Returns (RepairType.OTHER_CLOSE_OPEN, None) for statistics bookkeeping,
    or None when this reporter does not apply.
    """
    applies = isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, OpenConstituent)
    return (RepairType.OTHER_CLOSE_OPEN, None) if applies else None
|
| 602 |
+
|
| 603 |
+
def report_open_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Open answered with a different predicted Open.

    Returns (RepairType.OTHER_OPEN_OPEN, None) for statistics bookkeeping,
    or None when this reporter does not apply.
    """
    applies = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent)
    return (RepairType.OTHER_OPEN_OPEN, None) if applies else None
|
| 610 |
+
|
| 611 |
+
def report_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Open answered with a predicted Shift.

    Returns (RepairType.OTHER_OPEN_SHIFT, None) for statistics bookkeeping,
    or None when this reporter does not apply.
    """
    applies = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, Shift)
    return (RepairType.OTHER_OPEN_SHIFT, None) if applies else None
|
| 618 |
+
|
| 619 |
+
def report_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Open answered with a predicted Close.

    Returns (RepairType.OTHER_OPEN_CLOSE, None) for statistics bookkeeping,
    or None when this reporter does not apply.
    """
    applies = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, CloseConstituent)
    return (RepairType.OTHER_OPEN_CLOSE, None) if applies else None
|
| 626 |
+
|
| 627 |
+
def report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a gold Shift answered with a predicted Open.

    Returns (RepairType.OTHER_SHIFT_OPEN, None) for statistics bookkeeping,
    or None when this reporter does not apply.
    """
    applies = isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent)
    return (RepairType.OTHER_SHIFT_OPEN, None) if applies else None
|
| 634 |
+
|
| 635 |
+
class RepairType(Enum):
|
| 636 |
+
"""
|
| 637 |
+
Keep track of which repair is used, if any, on an incorrect transition
|
| 638 |
+
|
| 639 |
+
Statistics on English w/ no charlm, no transformer,
|
| 640 |
+
eg word vectors only, best model as of January 2024
|
| 641 |
+
|
| 642 |
+
unambiguous transitions only:
|
| 643 |
+
oracle scheme dev test
|
| 644 |
+
no oracle 0.9245 0.9226
|
| 645 |
+
+wrong_open_root 0.9244 0.9224
|
| 646 |
+
+wrong_unary_chain 0.9243 0.9237
|
| 647 |
+
+wrong_open_unary 0.9249 0.9223
|
| 648 |
+
+wrong_open_general 0.9251 0.9215
|
| 649 |
+
+missed_unary 0.9248 0.9215
|
| 650 |
+
+open_shift 0.9243 0.9216
|
| 651 |
+
+open_close 0.9254 0.9217
|
| 652 |
+
+shift_close 0.9261 0.9238
|
| 653 |
+
+close_shift_nested 0.9253 0.9250
|
| 654 |
+
|
| 655 |
+
Redoing the wrong_open_general, which seemed to hurt test scores:
|
| 656 |
+
wrong_open_two_subtrees - L4 0.9244 0.9220
|
| 657 |
+
every else w/o ambiguous open/open fix 0.9259 0.9241
|
| 658 |
+
everything w/ open_two_subtrees 0.9261 0.9246
|
| 659 |
+
w/ ambiguous open_three_subtrees 0.9264 0.9243
|
| 660 |
+
|
| 661 |
+
Testing three different possible repairs for shift-open:
|
| 662 |
+
w/ ambiguous open_three_subtrees 0.9264 0.9243
|
| 663 |
+
immediate close (unary) 0.9267 0.9246
|
| 664 |
+
close after first bracket 0.9265 0.9256
|
| 665 |
+
close after last bracket 0.9264 0.9240
|
| 666 |
+
|
| 667 |
+
Testing three possible repairs for close-open-shift/shift
|
| 668 |
+
w/ ambiguous open_three_subtrees 0.9264 0.9243
|
| 669 |
+
unambiguous c-o-s/shift 0.9265 0.9246
|
| 670 |
+
ambiguous c-o-s/shift closed early 0.9262 0.9246
|
| 671 |
+
ambiguous c-o-s/shift closed late 0.9259 0.9245
|
| 672 |
+
|
| 673 |
+
Testing three possible repairs for close-shift/shift
|
| 674 |
+
w/ ambiguous open_three_subtrees 0.9264 0.9243
|
| 675 |
+
unambiguous c-s/shift 0.9253 0.9239
|
| 676 |
+
ambiguous c-s/shift closed early 0.9259 0.9235
|
| 677 |
+
ambiguous c-s/shift closed late 0.9252 0.9241
|
| 678 |
+
ambiguous c-s/shift predicted 0.9264 0.9243
|
| 679 |
+
|
| 680 |
+
--------------------------------------------------------
|
| 681 |
+
|
| 682 |
+
Running ID experiments to verify some of the above findings
|
| 683 |
+
no charlm or bert, only 200 epochs
|
| 684 |
+
|
| 685 |
+
Comparing wrong_open fixes
|
| 686 |
+
w/ ambiguous open_two_subtrees 0.8448 0.8335
|
| 687 |
+
w/ ambiguous open_three_subtrees 0.8424 0.8336
|
| 688 |
+
|
| 689 |
+
Testing three possible repairs for close-shift/shift
|
| 690 |
+
unambiguous c-s/shift 0.8448 0.8360
|
| 691 |
+
ambiguous c-s/shift closed early 0.8425 0.8352
|
| 692 |
+
ambiguous c-s/shift closed late 0.8452 0.8334
|
| 693 |
+
|
| 694 |
+
--------------------------------------------------------
|
| 695 |
+
|
| 696 |
+
Running ID experiments to verify some of the above findings
|
| 697 |
+
bert + peft, only 200 epochs
|
| 698 |
+
|
| 699 |
+
Comparing wrong_open fixes
|
| 700 |
+
w/o ambiguous open/open fix 0.8923 0.8834
|
| 701 |
+
w/ ambiguous open_two_subtrees 0.8908 0.8828
|
| 702 |
+
w/ ambiguous open_three_subtrees 0.8901 0.8801
|
| 703 |
+
|
| 704 |
+
Testing three possible repairs for close-shift/shift
|
| 705 |
+
unambiguous c-s/shift 0.8921 0.8825
|
| 706 |
+
ambiguous c-s/shift closed early 0.8924 0.8841
|
| 707 |
+
ambiguous c-s/shift closed late 0.8921 0.8806
|
| 708 |
+
ambiguous c-s/shift predicted 0.8923 0.8835
|
| 709 |
+
|
| 710 |
+
--------------------------------------------------------
|
| 711 |
+
|
| 712 |
+
Running DE experiments to verify some of the above findings
|
| 713 |
+
bert + peft, only 200 epochs
|
| 714 |
+
|
| 715 |
+
Comparing wrong_open fixes
|
| 716 |
+
w/o ambiguous open/open fix 0.9576 0.9402
|
| 717 |
+
w/ ambiguous open_two_subtrees 0.9570 0.9410
|
| 718 |
+
w/ ambiguous open_three_subtrees 0.9569 0.9412
|
| 719 |
+
|
| 720 |
+
Testing three possible repairs for close-shift/shift
|
| 721 |
+
unambiguous c-s/shift 0.9566 0.9408
|
| 722 |
+
ambiguous c-s/shift closed early 0.9564 0.9394
|
| 723 |
+
ambiguous c-s/shift closed late 0.9572 0.9408
|
| 724 |
+
ambiguous c-s/shift predicted 0.9571 0.9404
|
| 725 |
+
|
| 726 |
+
--------------------------------------------------------
|
| 727 |
+
|
| 728 |
+
Running IT experiments to verify some of the above findings
|
| 729 |
+
bert + peft, only 200 epochs
|
| 730 |
+
|
| 731 |
+
Comparing wrong_open fixes
|
| 732 |
+
w/o ambiguous open/open fix 0.8380 0.8361
|
| 733 |
+
w/ ambiguous open_two_subtrees 0.8377 0.8351
|
| 734 |
+
w/ ambiguous open_three_subtrees 0.8381 0.8368
|
| 735 |
+
|
| 736 |
+
Testing three possible repairs for close-shift/shift
|
| 737 |
+
unambiguous c-s/shift 0.8376 0.8392
|
| 738 |
+
ambiguous c-s/shift closed early 0.8363 0.8359
|
| 739 |
+
ambiguous c-s/shift closed late 0.8365 0.8383
|
| 740 |
+
ambiguous c-s/shift predicted 0.8379 0.8371
|
| 741 |
+
|
| 742 |
+
--------------------------------------------------------
|
| 743 |
+
|
| 744 |
+
Running ZH experiments to verify some of the above findings
|
| 745 |
+
bert + peft, only 200 epochs
|
| 746 |
+
|
| 747 |
+
Comparing wrong_open fixes
|
| 748 |
+
w/o ambiguous open/open fix 0.9160 0.9143
|
| 749 |
+
w/ ambiguous open_two_subtrees 0.9145 0.9144
|
| 750 |
+
w/ ambiguous open_three_subtrees 0.9146 0.9142
|
| 751 |
+
|
| 752 |
+
Testing three possible repairs for close-shift/shift
|
| 753 |
+
unambiguous c-s/shift 0.9155 0.9146
|
| 754 |
+
ambiguous c-s/shift closed early 0.9145 0.9153
|
| 755 |
+
ambiguous c-s/shift closed late 0.9138 0.9140
|
| 756 |
+
ambiguous c-s/shift predicted 0.9154 0.9144
|
| 757 |
+
|
| 758 |
+
--------------------------------------------------------
|
| 759 |
+
|
| 760 |
+
Running VI experiments to verify some of the above findings
|
| 761 |
+
bert + peft, only 200 epochs
|
| 762 |
+
|
| 763 |
+
Comparing wrong_open fixes
|
| 764 |
+
w/o ambiguous open/open fix 0.8282 0.7668
|
| 765 |
+
w/ ambiguous open_two_subtrees 0.8272 0.7670
|
| 766 |
+
w/ ambiguous open_three_subtrees 0.8282 0.7668
|
| 767 |
+
|
| 768 |
+
Testing three possible repairs for close-shift/shift
|
| 769 |
+
unambiguous c-s/shift 0.8285 0.7683
|
| 770 |
+
ambiguous c-s/shift closed early 0.8276 0.7678
|
| 771 |
+
ambiguous c-s/shift closed late 0.8278 0.7668
|
| 772 |
+
ambiguous c-s/shift predicted 0.8270 0.7668
|
| 773 |
+
|
| 774 |
+
--------------------------------------------------------
|
| 775 |
+
|
| 776 |
+
Testing a combination of ambiguous vs predicted transitions
|
| 777 |
+
|
| 778 |
+
ambiguous
|
| 779 |
+
EN: (no CSS_U) 0.9258 0.9252
|
| 780 |
+
ZH: (no CSS_U) 0.9153 0.9145
|
| 781 |
+
|
| 782 |
+
predicted
|
| 783 |
+
EN: (no CSS_U) 0.9264 0.9241
|
| 784 |
+
ZH: (no CSS_U) 0.9145 0.9141
|
| 785 |
+
"""
|
| 786 |
+
def __new__(cls, fn, correct=False, debug=False):
    """Create an enum member with an auto-incrementing value and an attached repair function.

    fn: callable which repairs this kind of error (None when there is no repair)
    correct: this member represents a correct transition
    debug: always run this member, as it just counts statistics
    """
    member = object.__new__(cls)
    # values are assigned 1..N in declaration order
    member._value_ = len(cls.__members__) + 1
    member.fn = fn
    member.correct = correct
    member.debug = debug
    return member
|
| 801 |
+
|
| 802 |
+
@property
def is_correct(self):
    """Return True when this member represents a correct transition (mirrors the ``correct`` attribute)."""
    return self.correct
|
| 805 |
+
|
| 806 |
+
# The first section is a sequence of repairs when the parser
|
| 807 |
+
# should have chosen NTx but instead chose NTy
|
| 808 |
+
|
| 809 |
+
# Blocks of transitions which can be abstracted away to be
|
| 810 |
+
# anything will be represented as S1, S2, etc... S for stuff
|
| 811 |
+
|
| 812 |
+
# We carve out an exception for a wrong open at the root
|
| 813 |
+
# The only possble transtions at this point are to close
|
| 814 |
+
# the error and try again with the root
|
| 815 |
+
WRONG_OPEN_ROOT_ERROR = (fix_wrong_open_root_error,)
|
| 816 |
+
|
| 817 |
+
# The simplest form of such an error is when there is a sequence
|
| 818 |
+
# of unary transitions and the parser chose a wrong parent.
|
| 819 |
+
# Remember that a unary transition is represented by a pair
|
| 820 |
+
# of transitions, NTx, Close.
|
| 821 |
+
# In this case, the correct sequence was
|
| 822 |
+
# S1 NTx Close NTy Close NTz ...
|
| 823 |
+
# but the parser chose NTy, NTz, etc
|
| 824 |
+
# The repair in this case is to simply discard the unchosen
|
| 825 |
+
# unary transitions and continue
|
| 826 |
+
WRONG_OPEN_UNARY_CHAIN = (fix_wrong_open_unary_chain,)
|
| 827 |
+
|
| 828 |
+
# Similar to the UNARY_CHAIN error, but in this case there is a
|
| 829 |
+
# bunch of stuff (one or more constituents built) between the
|
| 830 |
+
# missed open transition and the close transition
|
| 831 |
+
WRONG_OPEN_STUFF_UNARY = (fix_wrong_open_stuff_unary,)
|
| 832 |
+
|
| 833 |
+
# If the correct sequence is
|
| 834 |
+
# T1 O_x T2 C
|
| 835 |
+
# and instead we predicted
|
| 836 |
+
# T1 O_y ...
|
| 837 |
+
# this can be fixed with a unary transition after
|
| 838 |
+
# T1 O_y T2 C O_x C
|
| 839 |
+
# note that this is technically ambiguous
|
| 840 |
+
# could have done
|
| 841 |
+
# T1 O_x C O_y T2 C
|
| 842 |
+
# but doing this should be easier for the parser to detect (untested)
|
| 843 |
+
# also this way the same code paths can be used for two subtrees
|
| 844 |
+
# and for multiple subtrees
|
| 845 |
+
WRONG_OPEN_TWO_SUBTREES = (fix_wrong_open_two_subtrees,)
|
| 846 |
+
|
| 847 |
+
# If the gold transition is an Open because it is part of
|
| 848 |
+
# a unary transition, and the following transition is a
|
| 849 |
+
# correct Shift or Close, we can just skip past the unary.
|
| 850 |
+
MISSED_UNARY = (fix_missed_unary,)
|
| 851 |
+
|
| 852 |
+
# Open -> Shift errors which don't just represent a unary
|
| 853 |
+
# generally represent a missing bracket which cannot be
|
| 854 |
+
# recovered using the in-order mechanism. Dropping the
|
| 855 |
+
# missing transition is generally the only fix.
|
| 856 |
+
# (This means removing the corresponding Close)
|
| 857 |
+
# One could theoretically create a new transition which
|
| 858 |
+
# grabs two constituents, though
|
| 859 |
+
OPEN_SHIFT = (fix_open_shift,)
|
| 860 |
+
|
| 861 |
+
# Open -> Close is a rather drastic break in the
|
| 862 |
+
# potential structure of the tree. We can no longer
|
| 863 |
+
# recover the missed Open, and we might not be able
|
| 864 |
+
# to recover other following missed Opens as well.
|
| 865 |
+
# In most cases, the only thing to do is reopen the
|
| 866 |
+
# incorrectly closed outer bracket and keep going.
|
| 867 |
+
OPEN_CLOSE = (fix_open_close,)
|
| 868 |
+
|
| 869 |
+
# Similar to the Open -> Close error, but at least
|
| 870 |
+
# in this case we are just introducing one wrong bracket
|
| 871 |
+
# rather than also breaking some existing brackets.
|
| 872 |
+
# The fix here is to reopen the closed bracket.
|
| 873 |
+
SHIFT_CLOSE = (fix_shift_close,)
|
| 874 |
+
|
| 875 |
+
# Specifically fixes an error where bracket X is
|
| 876 |
+
# closed and then immediately opened to build a
|
| 877 |
+
# new X bracket. In this case, the simplest fix
|
| 878 |
+
# will be to skip both the close and the new open
|
| 879 |
+
# and continue from there.
|
| 880 |
+
CLOSE_OPEN_SHIFT_NESTED = (fix_close_open_shift_nested,)
|
| 881 |
+
|
| 882 |
+
# Fix an error where the correct sequence was to Close X, Open Y,
|
| 883 |
+
# then continue building,
|
| 884 |
+
# but instead the model did a Shift in place of C_X O_Y
|
| 885 |
+
# The damage here is a recall error for the missed X and
|
| 886 |
+
# a precision error for the incorrectly opened X
|
| 887 |
+
# However, the Y can actually be recovered - whenever we finally
|
| 888 |
+
# close X, we can then open Y
|
| 889 |
+
# One form of that is unambiguous, that of
|
| 890 |
+
# T_A O_X T_B C O_Y T_C C
|
| 891 |
+
# with only one subtree after the O_Y
|
| 892 |
+
# In that case, the Close that would have closed Y
|
| 893 |
+
# is the only place for the missing close of X
|
| 894 |
+
# So we can produce the following:
|
| 895 |
+
# T_A O_X T_B T_C C O_Y C
|
| 896 |
+
CLOSE_OPEN_SHIFT_UNAMBIGUOUS_BRACKET = (fix_close_open_shift_unambiguous_bracket,)
|
| 897 |
+
|
| 898 |
+
# Similarly to WRONG_OPEN_TWO_SUBTREES, if the correct sequence is
|
| 899 |
+
# T1 O_x T2 T3 C
|
| 900 |
+
# and instead we predicted
|
| 901 |
+
# T1 O_y ...
|
| 902 |
+
# this can be fixed by closing O_y in any number of places
|
| 903 |
+
# T1 O_y T2 C O_x T3 C
|
| 904 |
+
# T1 O_y T2 C T3 O_x C
|
| 905 |
+
# Either solution is a single precision error,
|
| 906 |
+
# but keeps the O_x subtree correct
|
| 907 |
+
# This is an ambiguous transition - we can experiment with different fixes
|
| 908 |
+
WRONG_OPEN_MULTIPLE_SUBTREES = (fix_wrong_open_multiple_subtrees,)
|
| 909 |
+
|
| 910 |
+
CORRECT = (None, True)
|
| 911 |
+
|
| 912 |
+
UNKNOWN = None
|
| 913 |
+
|
| 914 |
+
# If the model is supposed to build a block after a Close
|
| 915 |
+
# operation, attach that block to the piece to the left
|
| 916 |
+
# a couple different variations on this were tried
|
| 917 |
+
# we tried attaching all constituents to the
|
| 918 |
+
# bracket which should have been closed
|
| 919 |
+
# we tried attaching exactly one constituent
|
| 920 |
+
# and we tried attaching only if there was
|
| 921 |
+
# exactly one following constituent
|
| 922 |
+
# none of these improved f1. for example, on the VI dataset, we
|
| 923 |
+
# lost 0.15 F1 with the exactly one following constituent version
|
| 924 |
+
# it might be worthwhile double checking some of the other
|
| 925 |
+
# versions to make sure those also fail, though
|
| 926 |
+
CLOSE_SHIFT_SHIFT = (fix_close_shift_shift_unambiguous,)
|
| 927 |
+
|
| 928 |
+
# In the ambiguous close-shift/shift case, this closes the surrounding bracket
|
| 929 |
+
# (which should have already been closed)
|
| 930 |
+
# as soon as the next constituent is built
|
| 931 |
+
# this turns
|
| 932 |
+
# (A (B s1 s2) s3 s4)
|
| 933 |
+
# into
|
| 934 |
+
# (A (B s1 s2 s3) s4)
|
| 935 |
+
CLOSE_SHIFT_SHIFT_AMBIGUOUS_EARLY = (fix_close_shift_shift_ambiguous_early,)
|
| 936 |
+
|
| 937 |
+
# In the ambiguous close-shift/shift case, this closes the surrounding bracket
|
| 938 |
+
# (which should have already been closed)
|
| 939 |
+
# when the rest of the constituents in this bracket are built
|
| 940 |
+
# this turns
|
| 941 |
+
# (A (B s1 s2) s3 s4)
|
| 942 |
+
# into
|
| 943 |
+
# (A (B s1 s2 s3 s4))
|
| 944 |
+
CLOSE_SHIFT_SHIFT_AMBIGUOUS_LATE = (fix_close_shift_shift_ambiguous_late,)
|
| 945 |
+
|
| 946 |
+
# For the close-shift/shift errors which are ambiguous,
|
| 947 |
+
# this uses the model's predictions to guess which block
|
| 948 |
+
# to put the close after
|
| 949 |
+
CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED = (fix_close_shift_shift_ambiguous_predicted,)
|
| 950 |
+
|
| 951 |
+
# If a sequence should have gone Close - Open - Shift,
|
| 952 |
+
# and instead we went Shift,
|
| 953 |
+
# we need to close the previous bracket
|
| 954 |
+
# If it is ambiguous
|
| 955 |
+
# such as Close - Open - Shift - Shift
|
| 956 |
+
# close the bracket ASAP
|
| 957 |
+
# eg, Shift - Close - Open - Shift
|
| 958 |
+
CLOSE_OPEN_SHIFT_AMBIGUOUS_BRACKET_EARLY = (fix_close_open_shift_ambiguous_bracket_early,)
|
| 959 |
+
|
| 960 |
+
# for Close - Open - Shift - Shift
|
| 961 |
+
# close the bracket as late as possible
|
| 962 |
+
# eg, Shift - Shift - Close - Open
|
| 963 |
+
CLOSE_OPEN_SHIFT_AMBIGUOUS_BRACKET_LATE = (fix_close_open_shift_ambiguous_bracket_late,)
|
| 964 |
+
|
| 965 |
+
# If the sequence should have gone
|
| 966 |
+
# Close - Open - Shift
|
| 967 |
+
# and instead we predicted a Shift
|
| 968 |
+
# in a context where closing the bracket would be ambiguous
|
| 969 |
+
# we use the model to predict where the close should actually happen
|
| 970 |
+
CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED = (fix_close_open_shift_ambiguous_predicted,)
|
| 971 |
+
|
| 972 |
+
# This particular repair effectively turns the shift -> ambiguous open
|
| 973 |
+
# into a unary transition
|
| 974 |
+
SHIFT_OPEN_UNARY_CLOSE = (ambiguous_shift_open_unary_close,)
|
| 975 |
+
|
| 976 |
+
# Fix the shift -> ambiguous open by closing after the first constituent
|
| 977 |
+
# This is an ambiguous solution because it could also be closed either
|
| 978 |
+
# as a unary transition or with a close at the end of the outer bracket
|
| 979 |
+
SHIFT_OPEN_EARLY_CLOSE = (ambiguous_shift_open_early_close,)
|
| 980 |
+
|
| 981 |
+
# Fix the shift -> ambiguous open by closing after all constituents
|
| 982 |
+
# This is an ambiguous solution because it could also be closed either
|
| 983 |
+
# as a unary transition or with a close at the end of the first constituent
|
| 984 |
+
SHIFT_OPEN_LATE_CLOSE = (ambiguous_shift_open_late_close,)
|
| 985 |
+
|
| 986 |
+
# Use the model to predict when to close!
|
| 987 |
+
# The different options for where to put the Close are put into the model,
|
| 988 |
+
# and the highest scoring close is used
|
| 989 |
+
SHIFT_OPEN_PREDICTED_CLOSE = (ambiguous_shift_open_predicted_close,)
|
| 990 |
+
|
| 991 |
+
OTHER_CLOSE_SHIFT = (report_close_shift, False, True)
|
| 992 |
+
|
| 993 |
+
OTHER_CLOSE_OPEN = (report_close_open, False, True)
|
| 994 |
+
|
| 995 |
+
OTHER_OPEN_OPEN = (report_open_open, False, True)
|
| 996 |
+
|
| 997 |
+
OTHER_OPEN_CLOSE = (report_open_close, False, True)
|
| 998 |
+
|
| 999 |
+
OTHER_OPEN_SHIFT = (report_open_shift, False, True)
|
| 1000 |
+
|
| 1001 |
+
OTHER_SHIFT_OPEN = (report_shift_open, False, True)
|
| 1002 |
+
|
| 1003 |
+
# any other open transition we get wrong, which hasn't already
|
| 1004 |
+
# been carved out as an exception above, we just accept the
|
| 1005 |
+
# incorrect Open and keep going
|
| 1006 |
+
#
|
| 1007 |
+
# TODO: check if there is a way to improve this
|
| 1008 |
+
# it appears to hurt scores simply by existing
|
| 1009 |
+
# explanation: this is wrong logic
|
| 1010 |
+
# Suppose the correct sequence had been
|
| 1011 |
+
# T1 open(NP) T2 T3 close
|
| 1012 |
+
# Instead we had done
|
| 1013 |
+
# T1 open(VP) T2 T3 close
|
| 1014 |
+
# We can recover the missing NP!
|
| 1015 |
+
# T1 open(VP) T2 close open(NP) T3 close
|
| 1016 |
+
# Can also recover it as
|
| 1017 |
+
# T1 open(VP) T2 T3 close open(NP) close
|
| 1018 |
+
# So this is actually an ambiguous transition
|
| 1019 |
+
# except in the case of
|
| 1020 |
+
# T1 open(...) close
|
| 1021 |
+
# In this case, a unary transition can fix make it so we only have
|
| 1022 |
+
# a precision error, not also a recall error
|
| 1023 |
+
# Currently, the approach is to put this after the default fixes
|
| 1024 |
+
# and use the two & more-than-two versions of the fix above
|
| 1025 |
+
WRONG_OPEN_GENERAL = (fix_wrong_open_general,)
|
| 1026 |
+
|
| 1027 |
+
class InOrderOracle(DynamicOracle):
    """Dynamic oracle for the in-order transition scheme.

    Wires the in-order RepairType enum (whose members carry the repair
    functions defined in this module) into the generic DynamicOracle.
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        # Arguments are forwarded unchanged to DynamicOracle, with RepairType
        # inserted as the repair-type enum for this transition scheme.
        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)
|
stanza/stanza/models/constituency/lstm_model.py
ADDED
|
@@ -0,0 +1,1178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A version of the BaseModel which uses LSTMs to predict the correct next transition
|
| 3 |
+
based on the current known state.
|
| 4 |
+
|
| 5 |
+
The primary purpose of this class is to implement the prediction of the next
|
| 6 |
+
transition, which is done by concatenating the output of an LSTM operated over
|
| 7 |
+
previous transitions, the words, and the partially built constituents.
|
| 8 |
+
|
| 9 |
+
A complete processing of a sentence is as follows:
|
| 10 |
+
1) Run the input words through an encoder.
|
| 11 |
+
The encoder includes some or all of the following:
|
| 12 |
+
pretrained word embedding
|
| 13 |
+
finetuned word embedding for training set words - "delta_embedding"
|
| 14 |
+
POS tag embedding
|
| 15 |
+
pretrained charlm representation
|
| 16 |
+
BERT or similar large language model representation
|
| 17 |
+
attention transformer over the previous inputs
|
| 18 |
+
labeled attention transformer over the first attention layer
|
| 19 |
+
The encoded input is then put through a bi-lstm, giving a word representation
|
| 20 |
+
2) Transitions are put in an embedding, and transitions already used are tracked
|
| 21 |
+
in an LSTM
|
| 22 |
+
3) Constituents already built are also processed in an LSTM
|
| 23 |
+
4) Every transition is chosen by taking the output of the current word position,
|
| 24 |
+
the transition LSTM, and the constituent LSTM, and classifying the next
|
| 25 |
+
transition
|
| 26 |
+
5) Transitions are repeated (with constraints) until the sentence is completed
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from collections import namedtuple
|
| 30 |
+
import copy
|
| 31 |
+
from enum import Enum
|
| 32 |
+
import logging
|
| 33 |
+
import math
|
| 34 |
+
import random
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
import torch.nn as nn
|
| 38 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
| 39 |
+
|
| 40 |
+
from stanza.models.common.bert_embedding import extract_bert_embeddings
|
| 41 |
+
from stanza.models.common.maxout_linear import MaxoutLinear
|
| 42 |
+
from stanza.models.common.utils import attach_bert_model, unsort
|
| 43 |
+
from stanza.models.common.vocab import PAD_ID, UNK_ID
|
| 44 |
+
from stanza.models.constituency.base_model import BaseModel
|
| 45 |
+
from stanza.models.constituency.label_attention import LabelAttentionModule
|
| 46 |
+
from stanza.models.constituency.lstm_tree_stack import LSTMTreeStack
|
| 47 |
+
from stanza.models.constituency.parse_transitions import TransitionScheme
|
| 48 |
+
from stanza.models.constituency.parse_tree import Tree
|
| 49 |
+
from stanza.models.constituency.partitioned_transformer import PartitionedTransformerModule
|
| 50 |
+
from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
|
| 51 |
+
from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack
|
| 52 |
+
from stanza.models.constituency.tree_stack import TreeStack
|
| 53 |
+
from stanza.models.constituency.utils import build_nonlinearity, initialize_linear
|
| 54 |
+
|
| 55 |
+
logger = logging.getLogger('stanza')
tlogger = logging.getLogger('stanza.constituency.trainer')

# A word on the parser's word queue.
# hx: presumably the encoder output for this word — TODO confirm against usage
WordNode = namedtuple("WordNode", ['value', 'hx'])

# tree_hx and tree_cx are the hidden & cell states of the lstm going up
# the constituents in the case of the tree_lstm combination method
# (value holds the constituent itself)
Constituent = namedtuple("Constituent", ['value', 'tree_hx', 'tree_cx'])
|
| 63 |
+
|
| 64 |
+
# The sentence boundary vectors are marginally useful at best.
# However, they make it much easier to use non-bert layers as input to
# attention layers, as the attention layers work better when they have
# an index 0 to attend to.
class SentenceBoundary(Enum):
    """Which parts of the model get learned start/end boundary vectors.

    NONE: no boundary vectors anywhere
    WORDS: word_start_embedding / word_end_embedding on the word queue only
    EVERYTHING: additionally, the transition and constituent stacks use a
      boundary vector (see uses_boundary_vector in LSTMModel.__init__)
    """
    NONE = 1
    WORDS = 2
    EVERYTHING = 3
|
| 72 |
+
|
| 73 |
+
class StackHistory(Enum):
    """How to encode the history of a parser stack.

    LSTM uses an LSTMTreeStack; ATTN uses a TransformerTreeStack
    (see the transition_stack / constituent_stack setup in LSTMModel.__init__)
    """
    LSTM = 1
    ATTN = 2
|
| 76 |
+
|
| 77 |
+
# How to compose constituent children into new constituents
|
| 78 |
+
# MAX is simply take the max value of the children
|
| 79 |
+
# this is surprisingly effective
|
| 80 |
+
# for example, a Turkish dataset went from 81-81.5 dev, 75->75.5 test
|
| 81 |
+
# BILSTM is the method described in the papers of making an lstm
|
| 82 |
+
# out of the constituents
|
| 83 |
+
# BILSTM_MAX is the same as BILSTM, but instead of using a Linear
|
| 84 |
+
# to reduce the outputs of the lstm, we first take the max
|
| 85 |
+
# and then use a linear to reduce the max
|
| 86 |
+
# BIGRAM combines pairs of children and then takes the max over those
|
| 87 |
+
# ATTN means to put an attention layer over the children nodes
|
| 88 |
+
# we then take the max of the children with their attention
|
| 89 |
+
#
|
| 90 |
+
# Experiments show that MAX is noticeably better than the other options
|
| 91 |
+
# On ja_alt, here are a few results after 200 iterations,
|
| 92 |
+
# averaged over 5 iterations:
|
| 93 |
+
# MAX: 0.8985
|
| 94 |
+
# BILSTM: 0.8964
|
| 95 |
+
# BILSTM_MAX: 0.8973
|
| 96 |
+
# BIGRAM: 0.8982
|
| 97 |
+
#
|
| 98 |
+
# The MAX method has a linear transform after the max.
|
| 99 |
+
# Removing that transform makes the score go down to 0.8982
|
| 100 |
+
#
|
| 101 |
+
# We tried a few varieties of BILSTM_MAX
|
| 102 |
+
# In particular:
|
| 103 |
+
# max over LSTM, combining forward & backward using the max: 0.8970
|
| 104 |
+
# max over forward & backward separately, then reduce: 0.8970
|
| 105 |
+
# max over forward & backward only over 1:-1
|
| 106 |
+
# (eg, leave out the node embedding): 0.8969
|
| 107 |
+
# same as previous, but split the reduce into 2 pieces: 0.8973
|
| 108 |
+
# max over forward & backward separately, then reduce as
|
| 109 |
+
# 1/2(F + B) + W(F,B)
|
| 110 |
+
# the idea being that this way F and B are guaranteed
|
| 111 |
+
# to be represented: 0.8971
|
| 112 |
+
#
|
| 113 |
+
# BIGRAM is an attempt to mix information from nodes
|
| 114 |
+
# when building constituents, but it didn't help
|
| 115 |
+
# The first example, just taking pairs and learning
|
| 116 |
+
# a transform, went to NaN. Likely the transform
|
| 117 |
+
# expanded the embedding too much. Switching it to
|
| 118 |
+
# scale the matrix by 0.5 didn't go to Nan, but only
|
| 119 |
+
# resulted in 0.8982
|
| 120 |
+
#
|
| 121 |
+
# A couple varieties of ATTN:
|
| 122 |
+
# first an input linear, then attn, then an output linear
|
| 123 |
+
# the upside of this would be making the dimension of the attn
|
| 124 |
+
# independent from the rest of the model
|
| 125 |
+
# however, this caused an expansion in the magnitude of the vectors,
|
| 126 |
+
# resulting in NaN for deep enough trees
|
| 127 |
+
# adding layernorm or tanh to balance this out resulted in
|
| 128 |
+
# disappointing performance
|
| 129 |
+
# tanh: 0.8972
|
| 130 |
+
# another alternative not tested yet: lower initialization weights
|
| 131 |
+
# and enforce that the norms of the matrices are low enough that
|
| 132 |
+
# exponential explosion up the layers of the tree doesn't happen
|
| 133 |
+
# just an attention layer means hidden_size % reduce_heads == 0
|
| 134 |
+
# that is simple enough to enforce by slightly changing hidden_size
|
| 135 |
+
# if needed
|
| 136 |
+
# appending the embedding for the open state to the start of the
|
| 137 |
+
# sequence of children and taking only the content nodes
|
| 138 |
+
# was very disappointing: 0.8967
|
| 139 |
+
# taking the entire sequence of children including the open state
|
| 140 |
+
# embedding resulted in 0.8973
|
| 141 |
+
# long story short, this looks like an idea that should work, but it
|
| 142 |
+
# doesn't help. suggestions welcome for improving these results
|
| 143 |
+
#
|
| 144 |
+
# The current TREE_LSTM_CX mechanism uses a word's embedding
|
| 145 |
+
# as the hx and a trained embedding over tags as the cx 0.8996
|
| 146 |
+
# This worked slightly better than 0s for cx (TREE_LSTM) 0.8992
|
| 147 |
+
# A variant of TREE_LSTM which didn't work out:
|
| 148 |
+
# nodes are combined with an LSTM
|
| 149 |
+
# hx & cx are embeddings of the node type (eg S, NP, etc)
|
| 150 |
+
# input is the max over children: 0.8977
|
| 151 |
+
# Another variant which didn't work: use the word embedding
|
| 152 |
+
# as input to the same LSTM to get hx & cx 0.8985
|
| 153 |
+
# Note that although the scores for TREE_LSTM_CX are slightly higher
|
| 154 |
+
# than MAX for the JA dataset, the benefit was not as clear for EN,
|
| 155 |
+
# so we left the default at MAX.
|
| 156 |
+
# For example, on English WSJ, before switching to Bert POS and
|
| 157 |
+
# a learned Bert mixing layer, a comparison of 5x models trained
|
| 158 |
+
# for 400 iterations got dev scores of:
|
| 159 |
+
# TREE_LSTM_CX 0.9589
|
| 160 |
+
# MAX 0.9593
|
| 161 |
+
#
|
| 162 |
+
# UNTIED_MAX has a different reduce_linear for each type of
|
| 163 |
+
# constituent in the model. Similar to the different linear
|
| 164 |
+
# maps used in the CVG paper from Socher, Bauer, Manning, Ng
|
| 165 |
+
# This is implemented as a large CxHxH parameter,
|
| 166 |
+
# with num_constituent layers of hidden-hidden transform,
|
| 167 |
+
# along with a CxH bias parameter.
|
| 168 |
+
# Essentially C Linears stacked on top of each other,
|
| 169 |
+
# but in a parameter so that indexing can be done quickly.
|
| 170 |
+
# Unfortunately this does not beat out MAX with one combined linear.
|
| 171 |
+
# On an experiment on WSJ with all the best settings as of early
|
| 172 |
+
# October 2022, such as a Bert model POS tagger:
|
| 173 |
+
# MAX 0.9597
|
| 174 |
+
# UNTIED_MAX 0.9592
|
| 175 |
+
# Furthermore, starting from a finished MAX model and restarting
|
| 176 |
+
# by splitting the MAX layer into multiple pieces did not improve.
|
| 177 |
+
#
|
| 178 |
+
# KEY has a single Key which is used for a facsimile of ATTN
|
| 179 |
+
# each incoming subtree has its values weighted by a Query
|
| 180 |
+
# then the Key is used to calculate a softmax
|
| 181 |
+
# finally, a Value is used to scale the subtrees
|
| 182 |
+
# reduce_heads is used to determine the number of heads
|
| 183 |
+
# There is an option to use or not use position information
|
| 184 |
+
# using a sinusoidal position embedding
|
| 185 |
+
# UNTIED_KEY is the same, but has a different key
|
| 186 |
+
# for each possible constituent
|
| 187 |
+
# On a VI dataset:
|
| 188 |
+
# MAX 0.82064
|
| 189 |
+
# KEY (pos, 8) 0.81739
|
| 190 |
+
# UNTIED_KEY (pos, 8) 0.82046
|
| 191 |
+
# UNTIED_KEY (pos, 4) 0.81742
|
| 192 |
+
# Attempted to add a linear to mix the attn heads together,
|
| 193 |
+
# but that was awful: 0.81567
|
| 194 |
+
# Adding two position vectors, one in each direction, did not help:
|
| 195 |
+
# UNTIED_KEY (2x pos, 8) 0.8188
|
| 196 |
+
# To redo that experiment, double the width of reduce_query and
|
| 197 |
+
# reduce_value, then call reduce_position on nhx, flip it,
|
| 198 |
+
# and call reduce_position again
|
| 199 |
+
# Evidently the experiments to try should be:
|
| 200 |
+
# no pos at all
|
| 201 |
+
# more heads
|
| 202 |
+
class ConstituencyComposition(Enum):
    """How to compose constituent children into a single new constituent.

    See the extended comment block above this class for descriptions of
    each method and the experimental comparisons between them.
    """
    BILSTM = 1
    MAX = 2
    TREE_LSTM = 3
    BILSTM_MAX = 4
    BIGRAM = 5
    ATTN = 6
    TREE_LSTM_CX = 7
    UNTIED_MAX = 8
    KEY = 9
    UNTIED_KEY = 10
|
| 213 |
+
|
| 214 |
+
class LSTMModel(BaseModel, nn.Module):
|
| 215 |
+
def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, force_bert_saved, peft_name, transitions, constituents, tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
    """
    pretrain: a Pretrain object
    transitions: a list of all possible transitions which will be
      used to build trees
    constituents: a list of all possible constituents in the treebank
    tags: a list of all possible tags in the treebank
    words: a list of all known words, used for a delta word embedding.
      note that there will be an attempt made to learn UNK words as well,
      and tags by themselves may help UNK words
    rare_words: a list of rare words, used to occasionally replace with UNK
    root_labels: probably ROOT, although apparently some treebanks like TOP or even s
    constituent_opens: a list of all possible open nodes which will go on the stack
      - this might be different from constituents if there are nodes
        which represent multiple constituents at once
    args: hidden_size, transition_hidden_size, etc as gotten from
      constituency_parser.py

    Note that it might look like a hassle to pass all of this in
    when it can be collected directly from the trees themselves.
    However, that would only work at train time.  At eval or
    pipeline time we will load the lists from the saved model.
    """
    super().__init__(transition_scheme=args['transition_scheme'], unary_limit=unary_limit, reverse_sentence=args.get('reversed', False), root_labels=root_labels)

    self.args = args
    self.unsaved_modules = []

    emb_matrix = pretrain.emb
    self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))

    # replacing NBSP picks up a whole bunch of words for VI
    self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }
    # precompute tensors for the word indices
    # the tensors should be put on the GPU if needed by calling to(device)
    self.register_buffer('vocab_tensors', torch.tensor(range(len(pretrain.vocab)), requires_grad=False))
    self.vocab_size = emb_matrix.shape[0]
    self.embedding_dim = emb_matrix.shape[1]

    self.constituents = sorted(list(constituents))

    self.hidden_size = self.args['hidden_size']
    self.constituency_composition = self.args.get("constituency_composition", ConstituencyComposition.BILSTM)
    if self.constituency_composition in (ConstituencyComposition.ATTN, ConstituencyComposition.KEY, ConstituencyComposition.UNTIED_KEY):
        self.reduce_heads = self.args['reduce_heads']
        # round hidden_size up so the attention heads divide it evenly
        if self.hidden_size % self.reduce_heads != 0:
            self.hidden_size = self.hidden_size + self.reduce_heads - (self.hidden_size % self.reduce_heads)

    if args['constituent_stack'] == StackHistory.ATTN:
        self.reduce_heads = self.args['reduce_heads']
        if self.hidden_size % args['constituent_heads'] != 0:
            # TODO: technically we should either use the LCM of this and reduce_heads, or just have two separate fields
            # BUGFIX: was a bare `hidden_size`, which is not defined and raised NameError
            self.hidden_size = self.hidden_size + args['constituent_heads'] - (self.hidden_size % args['constituent_heads'])
        if self.constituency_composition == ConstituencyComposition.ATTN and self.hidden_size % self.reduce_heads != 0:
            raise ValueError("--reduce_heads and --constituent_heads not compatible!")

    self.transition_hidden_size = self.args['transition_hidden_size']
    if args['transition_stack'] == StackHistory.ATTN:
        if self.transition_hidden_size % args['transition_heads'] > 0:
            # BUGFIX: was a bare `transition_hidden_size`, which is not defined and raised NameError
            logger.warning("transition_hidden_size %d %% transition_heads %d != 0. reconfiguring", self.transition_hidden_size, args['transition_heads'])
            self.transition_hidden_size = self.transition_hidden_size + args['transition_heads'] - (self.transition_hidden_size % args['transition_heads'])

    self.tag_embedding_dim = self.args['tag_embedding_dim']
    self.transition_embedding_dim = self.args['transition_embedding_dim']
    self.delta_embedding_dim = self.args['delta_embedding_dim']

    self.word_input_size = self.embedding_dim + self.tag_embedding_dim + self.delta_embedding_dim

    if forward_charlm is not None:
        self.add_unsaved_module('forward_charlm', forward_charlm)
        self.word_input_size += self.forward_charlm.hidden_dim()
        if not forward_charlm.is_forward_lm:
            raise ValueError("Got a backward charlm as a forward charlm!")
    else:
        self.forward_charlm = None
    if backward_charlm is not None:
        self.add_unsaved_module('backward_charlm', backward_charlm)
        self.word_input_size += self.backward_charlm.hidden_dim()
        if backward_charlm.is_forward_lm:
            raise ValueError("Got a forward charlm as a backward charlm!")
    else:
        self.backward_charlm = None

    self.delta_words = sorted(set(words))
    self.delta_word_map = { word: i+2 for i, word in enumerate(self.delta_words) }
    assert PAD_ID == 0
    assert UNK_ID == 1
    # initialization is chosen based on the observed values of the norms
    # after several long training cycles
    # (this is true for other embeddings and embedding-like vectors as well)
    # the experiments show this slightly helps were done with
    # Adadelta and the correct initialization may be slightly
    # different for a different optimizer.
    # in fact, it is likely a scheme other than normal_ would
    # be better - the optimizer tends to learn the weights
    # rather close to 0 before learning in the direction it
    # actually wants to go
    self.delta_embedding = nn.Embedding(num_embeddings = len(self.delta_words)+2,
                                        embedding_dim = self.delta_embedding_dim,
                                        padding_idx = 0)
    nn.init.normal_(self.delta_embedding.weight, std=0.05)
    self.register_buffer('delta_tensors', torch.tensor(range(len(self.delta_words) + 2), requires_grad=False))

    self.rare_words = set(rare_words)

    self.tags = sorted(list(tags))
    if self.tag_embedding_dim > 0:
        self.tag_map = { t: i+2 for i, t in enumerate(self.tags) }
        self.tag_embedding = nn.Embedding(num_embeddings = len(tags)+2,
                                          embedding_dim = self.tag_embedding_dim,
                                          padding_idx = 0)
        nn.init.normal_(self.tag_embedding.weight, std=0.25)
        self.register_buffer('tag_tensors', torch.tensor(range(len(self.tags) + 2), requires_grad=False))

    self.num_lstm_layers = self.args['num_lstm_layers']
    self.num_tree_lstm_layers = self.args['num_tree_lstm_layers']
    self.lstm_layer_dropout = self.args['lstm_layer_dropout']

    self.word_dropout = nn.Dropout(self.args['word_dropout'])
    self.predict_dropout = nn.Dropout(self.args['predict_dropout'])
    self.lstm_input_dropout = nn.Dropout(self.args['lstm_input_dropout'])

    # also register a buffer of zeros so that we can always get zeros on the appropriate device
    self.register_buffer('word_zeros', torch.zeros(self.hidden_size * self.num_tree_lstm_layers))
    self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))

    # possibly add a couple vectors for bookends of the sentence
    # We put the word_start and word_end here, AFTER counting the
    # charlm dimension, but BEFORE counting the bert dimension,
    # as we want word_start and word_end to not have dimensions
    # for the bert embedding.  The bert model will add its own
    # start and end representation.
    self.sentence_boundary_vectors = self.args['sentence_boundary_vectors']
    if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
        self.register_parameter('word_start_embedding', torch.nn.Parameter(0.2 * torch.randn(self.word_input_size, requires_grad=True)))
        self.register_parameter('word_end_embedding', torch.nn.Parameter(0.2 * torch.randn(self.word_input_size, requires_grad=True)))

    # we set up the bert AFTER building word_start and word_end
    # so that we can use the charlm endpoint values rather than
    # try to train our own
    self.force_bert_saved = force_bert_saved or self.args['bert_finetune'] or self.args['stage1_bert_finetune']
    attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), self.force_bert_saved)
    self.peft_name = peft_name

    if bert_model is not None:
        if bert_tokenizer is None:
            raise ValueError("Cannot have a bert model without a tokenizer")
        self.bert_dim = self.bert_model.config.hidden_size
        if args['bert_hidden_layers']:
            # The average will be offset by 1/N so that the default zeros
            # represents an average of the N layers
            if args['bert_hidden_layers'] > bert_model.config.num_hidden_layers:
                # limit ourselves to the number of layers actually available
                # note that we can +1 because of the initial embedding layer
                args['bert_hidden_layers'] = bert_model.config.num_hidden_layers + 1
            self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
            nn.init.zeros_(self.bert_layer_mix.weight)
        else:
            # an average of layers 2, 3, 4 will be used
            # (for historic reasons)
            self.bert_layer_mix = None
        self.word_input_size = self.word_input_size + self.bert_dim

    self.partitioned_transformer_module = None
    self.pattn_d_model = 0
    if LSTMModel.uses_pattn(self.args):
        # Initializations of parameters for the Partitioned Attention
        # round off the size of the model so that it divides in half evenly
        self.pattn_d_model = self.args['pattn_d_model'] // 2 * 2

        # Initializations for the Partitioned Attention
        # experiments suggest having a bias does not help here
        self.partitioned_transformer_module = PartitionedTransformerModule(
            self.args['pattn_num_layers'],
            d_model=self.pattn_d_model,
            n_head=self.args['pattn_num_heads'],
            d_qkv=self.args['pattn_d_kv'],
            d_ff=self.args['pattn_d_ff'],
            ff_dropout=self.args['pattn_relu_dropout'],
            residual_dropout=self.args['pattn_residual_dropout'],
            attention_dropout=self.args['pattn_attention_dropout'],
            word_input_size=self.word_input_size,
            bias=self.args['pattn_bias'],
            morpho_emb_dropout=self.args['pattn_morpho_emb_dropout'],
            timing=self.args['pattn_timing'],
            encoder_max_len=self.args['pattn_encoder_max_len']
        )
        self.word_input_size += self.pattn_d_model

    self.label_attention_module = None
    if LSTMModel.uses_lattn(self.args):
        if self.partitioned_transformer_module is None:
            logger.error("Not using Labeled Attention, as the Partitioned Attention module is not used")
        else:
            # TODO: think of a couple ways to use alternate inputs
            # for example, could pass in the word inputs with a positional embedding
            # that would also allow it to work in the case of no partitioned module
            if self.args['lattn_combined_input']:
                self.lattn_d_input = self.word_input_size
            else:
                self.lattn_d_input = self.pattn_d_model
            self.label_attention_module = LabelAttentionModule(self.lattn_d_input,
                                                               self.args['lattn_d_input_proj'],
                                                               self.args['lattn_d_kv'],
                                                               self.args['lattn_d_kv'],
                                                               self.args['lattn_d_l'],
                                                               self.args['lattn_d_proj'],
                                                               self.args['lattn_combine_as_self'],
                                                               self.args['lattn_resdrop'],
                                                               self.args['lattn_q_as_matrix'],
                                                               self.args['lattn_residual_dropout'],
                                                               self.args['lattn_attention_dropout'],
                                                               self.pattn_d_model // 2,
                                                               self.args['lattn_d_ff'],
                                                               self.args['lattn_relu_dropout'],
                                                               self.args['lattn_partitioned'])
            self.word_input_size = self.word_input_size + self.args['lattn_d_proj']*self.args['lattn_d_l']

    self.word_lstm = nn.LSTM(input_size=self.word_input_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)

    # after putting the word_delta_tag input through the word_lstm, we get back
    # hidden_size * 2 output with the front and back lstms concatenated.
    # this transforms it into hidden_size with the values mixed together
    self.word_to_constituent = nn.Linear(self.hidden_size * 2, self.hidden_size * self.num_tree_lstm_layers)
    initialize_linear(self.word_to_constituent, self.args['nonlinearity'], self.hidden_size * 2)

    self.transitions = sorted(list(transitions))
    self.transition_map = { t: i for i, t in enumerate(self.transitions) }
    # precompute tensors for the transitions
    self.register_buffer('transition_tensors', torch.tensor(range(len(transitions)), requires_grad=False))
    self.transition_embedding = nn.Embedding(num_embeddings = len(transitions),
                                             embedding_dim = self.transition_embedding_dim)
    nn.init.normal_(self.transition_embedding.weight, std=0.25)
    if args['transition_stack'] == StackHistory.LSTM:
        self.transition_stack = LSTMTreeStack(input_size=self.transition_embedding_dim,
                                              hidden_size=self.transition_hidden_size,
                                              num_lstm_layers=self.num_lstm_layers,
                                              dropout=self.lstm_layer_dropout,
                                              uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,
                                              input_dropout=self.lstm_input_dropout)
    elif args['transition_stack'] == StackHistory.ATTN:
        self.transition_stack = TransformerTreeStack(input_size=self.transition_embedding_dim,
                                                     output_size=self.transition_hidden_size,
                                                     input_dropout=self.lstm_input_dropout,
                                                     use_position=True,
                                                     num_heads=args['transition_heads'])
    else:
        raise ValueError("Unhandled transition_stack StackHistory: {}".format(args['transition_stack']))

    self.constituent_opens = sorted(list(constituent_opens))
    # an embedding for the spot on the constituent LSTM taken up by the Open transitions
    # the pattern when condensing constituents is embedding - con1 - con2 - con3 - embedding
    # TODO: try the two ends have different embeddings?
    self.constituent_open_map = { x: i for (i, x) in enumerate(self.constituent_opens) }
    self.constituent_open_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),
                                                   embedding_dim = self.hidden_size)
    nn.init.normal_(self.constituent_open_embedding.weight, std=0.2)

    # input_size is hidden_size - could introduce a new constituent_size instead if we liked
    if args['constituent_stack'] == StackHistory.LSTM:
        self.constituent_stack = LSTMTreeStack(input_size=self.hidden_size,
                                               hidden_size=self.hidden_size,
                                               num_lstm_layers=self.num_lstm_layers,
                                               dropout=self.lstm_layer_dropout,
                                               uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,
                                               input_dropout=self.lstm_input_dropout)
    elif args['constituent_stack'] == StackHistory.ATTN:
        self.constituent_stack = TransformerTreeStack(input_size=self.hidden_size,
                                                      output_size=self.hidden_size,
                                                      input_dropout=self.lstm_input_dropout,
                                                      use_position=True,
                                                      num_heads=args['constituent_heads'])
    else:
        # BUGFIX: the message previously reported args['transition_stack'] instead
        # of the constituent_stack value that actually failed to match
        raise ValueError("Unhandled constituent_stack StackHistory: {}".format(args['constituent_stack']))


    if args['combined_dummy_embedding']:
        self.dummy_embedding = self.constituent_open_embedding
    else:
        self.dummy_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),
                                            embedding_dim = self.hidden_size)
        nn.init.normal_(self.dummy_embedding.weight, std=0.2)
    self.register_buffer('constituent_open_tensors', torch.tensor(range(len(constituent_opens)), requires_grad=False))

    # TODO: refactor
    if (self.constituency_composition == ConstituencyComposition.BILSTM or
        self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
        # forward and backward pieces for crunching several
        # constituents into one, combined into a bi-lstm
        # TODO: make the hidden size here an option?
        self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)
        # affine transformation from bi-lstm reduce to a new hidden layer
        if self.constituency_composition == ConstituencyComposition.BILSTM:
            self.reduce_linear = nn.Linear(self.hidden_size * 2, self.hidden_size)
            initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size * 2)
        else:
            self.reduce_forward = nn.Linear(self.hidden_size, self.hidden_size)
            self.reduce_backward = nn.Linear(self.hidden_size, self.hidden_size)
            initialize_linear(self.reduce_forward, self.args['nonlinearity'], self.hidden_size)
            initialize_linear(self.reduce_backward, self.args['nonlinearity'], self.hidden_size)
    elif self.constituency_composition == ConstituencyComposition.MAX:
        # transformation to turn several constituents into one new constituent
        self.reduce_linear = nn.Linear(self.hidden_size, self.hidden_size)
        initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)
    elif self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
        # transformation to turn several constituents into one new constituent
        # one HxH transform (plus bias) per constituent type, stacked in a
        # single parameter so they can be indexed in bulk
        self.register_parameter('reduce_linear_weight', torch.nn.Parameter(torch.randn(len(constituent_opens), self.hidden_size, self.hidden_size, requires_grad=True)))
        self.register_parameter('reduce_linear_bias', torch.nn.Parameter(torch.randn(len(constituent_opens), self.hidden_size, requires_grad=True)))
        for layer_idx in range(len(constituent_opens)):
            nn.init.kaiming_normal_(self.reduce_linear_weight[layer_idx], nonlinearity=self.args['nonlinearity'])
        nn.init.uniform_(self.reduce_linear_bias, 0, 1 / (self.hidden_size * 2) ** 0.5)
    elif self.constituency_composition == ConstituencyComposition.BIGRAM:
        self.reduce_linear = nn.Linear(self.hidden_size, self.hidden_size)
        self.reduce_bigram = nn.Linear(self.hidden_size * 2, self.hidden_size)
        initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)
        initialize_linear(self.reduce_bigram, self.args['nonlinearity'], self.hidden_size)
    elif self.constituency_composition == ConstituencyComposition.ATTN:
        self.reduce_attn = nn.MultiheadAttention(self.hidden_size, self.reduce_heads)
    elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
        if self.args['reduce_position']:
            # unsaved module so that if it grows, we don't save
            # the larger version unnecessarily
            # under any normal circumstances, the growth will
            # happen early in training when the model is not
            # behaving well, then will not be needed once the
            # model learns not to make super degenerate
            # constituents
            self.add_unsaved_module("reduce_position", ConcatSinusoidalEncoding(self.args['reduce_position'], 50))
        else:
            self.add_unsaved_module("reduce_position", nn.Identity())
        self.reduce_query = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size, bias=False)
        self.reduce_value = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size)
        if self.constituency_composition == ConstituencyComposition.KEY:
            self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
        else:
            self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(len(constituent_opens), self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
    elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:
        self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
    elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:
        self.constituent_reduce_embedding = nn.Embedding(num_embeddings = len(tags)+2,
                                                         embedding_dim = self.num_tree_lstm_layers * self.hidden_size)
        self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
    else:
        raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))

    self.nonlinearity = build_nonlinearity(self.args['nonlinearity'])

    # matrix for predicting the next transition using word/constituent/transition queues
    # word size + constituency size + transition size
    # TODO: .get() is only necessary until all models rebuilt with this param
    self.maxout_k = self.args.get('maxout_k', 0)
    self.output_layers = self.build_output_layers(self.args['num_output_layers'], len(transitions), self.maxout_k)
|
| 567 |
+
|
| 568 |
+
@staticmethod
|
| 569 |
+
def uses_lattn(args):
|
| 570 |
+
return args.get('use_lattn', True) and args.get('lattn_d_proj', 0) > 0 and args.get('lattn_d_l', 0) > 0
|
| 571 |
+
|
| 572 |
+
@staticmethod
|
| 573 |
+
def uses_pattn(args):
|
| 574 |
+
return args['pattn_num_heads'] > 0 and args['pattn_num_layers'] > 0
|
| 575 |
+
|
| 576 |
+
def copy_with_new_structure(self, other):
    """
    Copy parameters from the other model to this model

    word_lstm can change size if the other model didn't use pattn / lattn and this one does.
    In that case, the new values are initialized to 0.
    This will rebuild the model in such a way that the outputs will be
    exactly the same as the previous model.

    Raises ValueError if the two models have incompatible constituency
    compositions, or AttributeError (chained) if a parameter from the
    other model has no counterpart here.
    """
    if self.constituency_composition != other.constituency_composition and self.constituency_composition != ConstituencyComposition.UNTIED_MAX:
        raise ValueError("Models are incompatible: self.constituency_composition == {}, other.constituency_composition == {}".format(self.constituency_composition, other.constituency_composition))
    for name, other_parameter in other.named_parameters():
        # this allows other.constituency_composition == UNTIED_MAX to fall through
        if name.startswith('reduce_linear.') and self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
            # the other model has one tied reduce_linear; replicate its
            # weight/bias into every per-constituent slot of this model
            if name == 'reduce_linear.weight':
                my_parameter = self.reduce_linear_weight
            elif name == 'reduce_linear.bias':
                my_parameter = self.reduce_linear_bias
            else:
                raise ValueError("Unexpected other parameter name {}".format(name))
            for idx in range(len(self.constituent_opens)):
                my_parameter[idx].data.copy_(other_parameter.data)
        elif name.startswith('word_lstm.weight_ih_l0'):
            # bottom layer shape may have changed from adding a new pattn / lattn block
            my_parameter = self.get_parameter(name)
            # -1 so that it can be converted easier to a different parameter
            copy_size = min(other_parameter.data.shape[-1], my_parameter.data.shape[-1])
            # zero-fill so any newly added input columns start inert
            #new_values = my_parameter.data.clone().detach()
            new_values = torch.zeros_like(my_parameter.data)
            new_values[..., :copy_size] = other_parameter.data[..., :copy_size]
            my_parameter.data.copy_(new_values)
        else:
            # all remaining parameters must match shape exactly
            try:
                self.get_parameter(name).data.copy_(other_parameter.data)
            except AttributeError as e:
                raise AttributeError("Could not process %s" % name) from e
def build_output_layers(self, num_output_layers, final_layer_size, maxout_k):
    """
    Build a ModuleList of Linear transformations for the given num_output_layers

    The final layer size can be specified.
    Initial layer size is the combination of word, constituent, and transition vectors
    Middle layer sizes are self.hidden_size

    num_output_layers: total number of layers (1 input layer + middle layers)
    final_layer_size: output width of the last layer (number of transitions)
    maxout_k: if nonzero, use MaxoutLinear layers with k pieces instead of
        plain Linear layers (and skip the custom initialization)
    """
    middle_layers = num_output_layers - 1
    # input width is the concatenation of the three stack summaries:
    # word_lstm: hidden_size * num_tree_lstm_layers
    # transition_stack: transition_hidden_size
    # constituent_stack: hidden_size
    predict_input_size = [self.hidden_size + self.hidden_size * self.num_tree_lstm_layers + self.transition_hidden_size] + [self.hidden_size] * middle_layers
    predict_output_size = [self.hidden_size] * middle_layers + [final_layer_size]
    if not maxout_k:
        output_layers = nn.ModuleList([nn.Linear(input_size, output_size)
                                       for input_size, output_size in zip(predict_input_size, predict_output_size)])
        # initialization is scaled to the fan-in and nonlinearity
        for output_layer, input_size in zip(output_layers, predict_input_size):
            initialize_linear(output_layer, self.args['nonlinearity'], input_size)
    else:
        output_layers = nn.ModuleList([MaxoutLinear(input_size, output_size, maxout_k)
                                       for input_size, output_size in zip(predict_input_size, predict_output_size)])
    return output_layers
def num_words_known(self, words):
    """Count how many of the given words are in the vocab, trying a lowercased fallback."""
    known = 0
    for word in words:
        if word in self.vocab_map or word.lower() in self.vocab_map:
            known += 1
    return known
@property
def retag_method(self):
    """The retagging method configured for this model.

    TODO: make the method an enum
    """
    method = self.args['retag_method']
    return method
def uses_xpos(self):
    """Return True when retagging is active and configured to use xpos tags."""
    if self.args['retag_package'] is None:
        return False
    return self.args['retag_method'] == 'xpos'
def add_unsaved_module(self, name, module):
    """
    Register a module which will not be saved to disk.

    Best used for large models such as pretrained word embeddings.
    Character language models are additionally frozen so they act as
    fixed feature extractors.
    """
    self.unsaved_modules.append(name)
    setattr(self, name, module)
    # charlms are never finetuned by this model
    is_charlm = name in ('forward_charlm', 'backward_charlm')
    if is_charlm and module is not None:
        for _, param in module.named_parameters():
            param.requires_grad = False
def is_unsaved_module(self, name):
    """Return True if the top-level component of a (possibly dotted) parameter name is unsaved."""
    top_level, _, _ = name.partition('.')
    return top_level in self.unsaved_modules
def get_norms(self):
    """
    Return a list of text lines describing the norms of all trainable parameters.

    For the UNTIED_MAX composition, the per-constituent reduce_linear
    weights and biases are reported individually and excluded from the
    generic parameter listing.  Each generic line reports the parameter
    name, its norm, and how many entries are (near) zero.
    """
    lines = []
    skip = set()
    if self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
        skip = {'reduce_linear_weight', 'reduce_linear_bias'}
        lines.append("reduce_linear:")
        for c_idx, c_open in enumerate(self.constituent_opens):
            lines.append(" %s weight %.6g bias %.6g" % (c_open, torch.norm(self.reduce_linear_weight[c_idx]).item(), torch.norm(self.reduce_linear_bias[c_idx]).item()))
    active_params = [(name, param) for name, param in self.named_parameters() if param.requires_grad and name not in skip]
    if len(active_params) == 0:
        return lines
    # NOTE: removed a leftover debug print(len(active_params)) which wrote
    # to stdout on every call

    # column widths are sized to the longest name / norm string
    max_name_len = max(len(name) for name, param in active_params)
    max_norm_len = max(len("%.6g" % torch.norm(param).item()) for name, param in active_params)
    format_string = "%-" + str(max_name_len) + "s norm %" + str(max_norm_len) + "s zeros %d / %d"
    for name, param in active_params:
        # count of near-zero entries helps spot dead or pruned weights
        zeros = torch.sum(param.abs() < 0.000001).item()
        norm = "%.6g" % torch.norm(param).item()
        lines.append(format_string % (name, norm, zeros, param.nelement()))
    return lines
def log_norms(self):
    """Log the norms of all model parameters at INFO level."""
    report = ["NORMS FOR MODEL PARAMETERS"] + self.get_norms()
    logger.info("\n".join(report))
def log_shapes(self):
    """Log the shape of every trainable parameter at INFO level.

    NOTE(review): the header text says NORMS but this method logs
    shapes; kept as-is to preserve the exact log output.
    """
    lines = ["NORMS FOR MODEL PARAMETERS"]
    lines.extend("{} {}".format(name, param.shape)
                 for name, param in self.named_parameters()
                 if param.requires_grad)
    logger.info("\n".join(lines))
def initial_word_queues(self, tagged_word_lists):
    """
    Produce initial word queues out of the model's LSTMs for use in the tagged word lists.

    Operates in a batched fashion to reduce the runtime for the LSTM operations

    tagged_word_lists: one list of tagged word trees per sentence; each
        word is a tree whose children[0].label is the word text.
    Returns one list of WordNode per sentence, with sentinel WordNode(None, ...)
    entries at both ends, reversed if self.reverse_sentence is set.
    """
    device = next(self.parameters()).device

    vocab_map = self.vocab_map
    def map_word(word):
        # exact match first, then lowercased, then UNK
        idx = vocab_map.get(word, None)
        if idx is not None:
            return idx
        return vocab_map.get(word.lower(), UNK_ID)

    all_word_inputs = []
    all_word_labels = [[word.children[0].label for word in tagged_words]
                       for tagged_words in tagged_word_lists]

    for sentence_idx, tagged_words in enumerate(tagged_word_lists):
        word_labels = all_word_labels[sentence_idx]
        word_idx = torch.stack([self.vocab_tensors[map_word(word.children[0].label)] for word in tagged_words])
        word_input = self.embedding(word_idx)

        # this occasionally learns UNK at train time
        if self.training:
            delta_labels = [None if word in self.rare_words and random.random() < self.args['rare_word_unknown_frequency'] else word
                            for word in word_labels]
        else:
            delta_labels = word_labels
        delta_idx = torch.stack([self.delta_tensors[self.delta_word_map.get(word, UNK_ID)] for word in delta_labels])

        delta_input = self.delta_embedding(delta_idx)
        word_inputs = [word_input, delta_input]

        if self.tag_embedding_dim > 0:
            # tags are also occasionally dropped to UNK at train time
            if self.training:
                tag_labels = [None if random.random() < self.args['tag_unknown_frequency'] else word.label for word in tagged_words]
            else:
                tag_labels = [word.label for word in tagged_words]
            tag_idx = torch.stack([self.tag_tensors[self.tag_map.get(tag, UNK_ID)] for tag in tag_labels])
            tag_input = self.tag_embedding(tag_idx)
            word_inputs.append(tag_input)

        all_word_inputs.append(word_inputs)

    # character language model features, batched over all sentences
    if self.forward_charlm is not None:
        all_forward_chars = self.forward_charlm.build_char_representation(all_word_labels)
        for word_inputs, forward_chars in zip(all_word_inputs, all_forward_chars):
            word_inputs.append(forward_chars)
    if self.backward_charlm is not None:
        all_backward_chars = self.backward_charlm.build_char_representation(all_word_labels)
        for word_inputs, backward_chars in zip(all_word_inputs, all_backward_chars):
            word_inputs.append(backward_chars)

    all_word_inputs = [torch.cat(word_inputs, dim=1) for word_inputs in all_word_inputs]
    if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
        word_start = self.word_start_embedding.unsqueeze(0)
        word_end = self.word_end_embedding.unsqueeze(0)
        all_word_inputs = [torch.cat([word_start, word_inputs, word_end], dim=0) for word_inputs in all_word_inputs]

    if self.bert_model is not None:
        # BERT embedding extraction
        # result will be len+2 for each sentence
        # we will take 1:-1 if we don't care about the endpoints
        bert_embeddings = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, all_word_labels, device,
                                                  keep_endpoints=self.sentence_boundary_vectors is not SentenceBoundary.NONE,
                                                  num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                  detach=not self.args['bert_finetune'] and not self.args['stage1_bert_finetune'],
                                                  peft_name=self.peft_name)
        if self.bert_layer_mix is not None:
            # add the average so that the default behavior is to
            # take an average of the N layers, and anything else
            # other than that needs to be learned
            bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]

        all_word_inputs = [torch.cat((x, y), axis=1) for x, y in zip(all_word_inputs, bert_embeddings)]

    # Extract partitioned representation
    if self.partitioned_transformer_module is not None:
        partitioned_embeddings = self.partitioned_transformer_module(None, all_word_inputs)
        all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, partitioned_embeddings)]

    # Extract Labeled Representation
    if self.label_attention_module is not None:
        if self.args['lattn_combined_input']:
            labeled_representations = self.label_attention_module(all_word_inputs, tagged_word_lists)
        else:
            # NOTE(review): this branch reads partitioned_embeddings, which is
            # only bound when the pattn module ran above — presumably the
            # config guarantees that combination; verify against callers
            labeled_representations = self.label_attention_module(partitioned_embeddings, tagged_word_lists)
        all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, labeled_representations)]

    all_word_inputs = [self.word_dropout(word_inputs) for word_inputs in all_word_inputs]
    packed_word_input = torch.nn.utils.rnn.pack_sequence(all_word_inputs, enforce_sorted=False)
    word_output, _ = self.word_lstm(packed_word_input)
    # would like to do word_to_constituent here, but it seems PackedSequence doesn't support Linear
    # word_output will now be sentence x batch x 2*hidden_size
    word_output, word_output_lens = torch.nn.utils.rnn.pad_packed_sequence(word_output)
    # now sentence x batch x hidden_size

    word_queues = []
    for sentence_idx, tagged_words in enumerate(tagged_word_lists):
        if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
            sentence_output = word_output[:len(tagged_words)+2, sentence_idx, :]
        else:
            sentence_output = word_output[:len(tagged_words), sentence_idx, :]
        sentence_output = self.word_to_constituent(sentence_output)
        sentence_output = self.nonlinearity(sentence_output)
        # TODO: this makes it so constituents downstream are
        # build with the outputs of the LSTM, not the word
        # embeddings themselves.  It is possible we want to
        # transform the word_input to hidden_size in some way
        # and use that instead
        if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
            word_queue = [WordNode(None, sentence_output[0, :])]
            word_queue += [WordNode(tag_node, sentence_output[idx+1, :])
                           for idx, tag_node in enumerate(tagged_words)]
            word_queue.append(WordNode(None, sentence_output[len(tagged_words)+1, :]))
        else:
            word_queue = [WordNode(None, self.word_zeros)]
            word_queue += [WordNode(tag_node, sentence_output[idx, :])
                           for idx, tag_node in enumerate(tagged_words)]
            word_queue.append(WordNode(None, self.word_zeros))

        if self.reverse_sentence:
            word_queue = list(reversed(word_queue))
        word_queues.append(word_queue)

    return word_queues
def initial_transitions(self):
    """Build the starting (empty) transition stack state."""
    stack = self.transition_stack
    return stack.initial_state()
def initial_constituents(self):
    """Build the starting constituent stack, seeded with a zero-valued Constituent."""
    seed = Constituent(None, self.constituent_zeros, self.constituent_zeros)
    return self.constituent_stack.initial_state(seed)
def get_word(self, word_node):
    """Return the tagged word wrapped inside a WordNode."""
    return word_node.value
def transform_word_to_constituent(self, state):
    """
    Turn the word at the state's current word_position into a Constituent.

    The shape of the hx/cx depends on the constituency composition:
    tree-lstm variants reshape to (num_tree_lstm_layers, hidden_size),
    all other compositions keep a (1, hidden_size) hx and no cx.
    """
    word_node = state.get_word(state.word_position)
    word = word_node.value
    if self.constituency_composition == ConstituencyComposition.TREE_LSTM:
        # zero cx for a leaf
        return Constituent(word, word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size), self.word_zeros.view(self.num_tree_lstm_layers, self.hidden_size))
    elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:
        # the UNK tag will be trained thanks to occasionally dropping out tags
        tag = word.label
        tree_hx = word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size)
        tag_tensor = self.tag_tensors[self.tag_map.get(tag, UNK_ID)]
        tree_cx = self.constituent_reduce_embedding(tag_tensor)
        tree_cx = tree_cx.view(self.num_tree_lstm_layers, self.hidden_size)
        # cx is a tag-conditioned gating of the hx
        return Constituent(word, tree_hx, tree_cx * tree_hx)
    else:
        return Constituent(word, word_node.hx[:self.hidden_size].unsqueeze(0), None)
def dummy_constituent(self, dummy):
    """Build a placeholder Constituent for a dummy open node.

    The cx doesn't matter: the dummy is discarded when building a new constituent.
    """
    open_idx = self.constituent_open_tensors[self.constituent_open_map[dummy.label]]
    hx = self.dummy_embedding(open_idx)
    return Constituent(dummy, hx.unsqueeze(0), None)
def build_constituents(self, labels, children_lists):
    """
    Build new constituents with the given label from the list of children

    labels is a list of labels for each of the new nodes to construct
    children_lists is a list of children that go under each of the new nodes
    lists of each are used so that we can stack operations

    The tensor math is dispatched on self.constituency_composition.
    """
    # at the end of each of these operations, we expect lstm_hx.shape
    # is (L, N, hidden_size) for N lists of children
    if (self.constituency_composition == ConstituencyComposition.BILSTM or
        self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
        node_hx = [[child.value.tree_hx.squeeze(0) for child in children] for children in children_lists]
        label_hx = [self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]]) for label in labels]

        max_length = max(len(children) for children in children_lists)
        zeros = torch.zeros(self.hidden_size, device=label_hx[0].device)
        # weirdly, this is faster than using pack_sequence
        # each sequence is bracketed by the label embedding on both ends
        unpacked_hx = [[lhx] + nhx + [lhx] + [zeros] * (max_length - len(nhx)) for lhx, nhx in zip(label_hx, node_hx)]
        unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in unpacked_hx]
        packed_hx = torch.stack(unpacked_hx, axis=1)
        packed_hx = torch.nn.utils.rnn.pack_padded_sequence(packed_hx, [len(x)+2 for x in children_lists], enforce_sorted=False)
        lstm_output = self.constituent_reduce_lstm(packed_hx)
        # take just the output of the final layer
        # result of lstm is ouput, (hx, cx)
        # so [1][0] gets hx
        # [1][0][-1] is the final output
        # will be shape len(children_lists) * 2, hidden_size for bidirectional
        # where forward outputs are -2 and backwards are -1
        if self.constituency_composition == ConstituencyComposition.BILSTM:
            lstm_output = lstm_output[1][0]
            forward_hx = lstm_output[-2, :, :]
            backward_hx = lstm_output[-1, :, :]
            hx = self.reduce_linear(torch.cat((forward_hx, backward_hx), axis=1))
        else:
            # BILSTM_MAX: max-pool over the per-token outputs, excluding
            # the bracketing label positions
            lstm_output, lstm_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_output[0])
            lstm_output = [lstm_output[1:length-1, x, :] for x, length in zip(range(len(lstm_lengths)), lstm_lengths)]
            lstm_output = torch.stack([torch.max(x, 0).values for x in lstm_output], axis=0)
            hx = self.reduce_forward(lstm_output[:, :self.hidden_size]) + self.reduce_backward(lstm_output[:, self.hidden_size:])
        lstm_hx = self.nonlinearity(hx).unsqueeze(0)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.MAX:
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        unpacked_hx = [self.lstm_input_dropout(torch.max(torch.stack(nhx), 0).values) for nhx in node_hx]
        packed_hx = torch.stack(unpacked_hx, axis=1)
        hx = self.reduce_linear(packed_hx)
        lstm_hx = self.nonlinearity(hx)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        unpacked_hx = [self.lstm_input_dropout(torch.max(torch.stack(nhx), 0).values) for nhx in node_hx]
        # shape == len(labels),1,hidden_size after the stack
        #packed_hx = torch.stack(unpacked_hx, axis=0)
        label_indices = [self.constituent_open_map[label] for label in labels]
        # we would like to stack the reduce_linear_weight calculations as follows:
        #reduce_weight = self.reduce_linear_weight[label_indices]
        #reduce_bias = self.reduce_linear_bias[label_indices]
        # this would allow for faster vectorized operations.
        # however, this runs out of memory on larger training examples,
        # presumably because there are too many stacks in a row and each one
        # has its own gradient kept for the entire calculation
        # fortunately, this operation is not a huge part of the expense
        hx = [torch.matmul(self.reduce_linear_weight[label_idx], hx_layer.squeeze(0)) + self.reduce_linear_bias[label_idx]
              for label_idx, hx_layer in zip(label_indices, unpacked_hx)]
        hx = torch.stack(hx, axis=0)
        hx = hx.unsqueeze(0)
        lstm_hx = self.nonlinearity(hx)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.BIGRAM:
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        unpacked_hx = []
        for nhx in node_hx:
            # tanh or otherwise limit the size of the output?
            stacked_nhx = self.lstm_input_dropout(torch.cat(nhx, axis=0))
            if stacked_nhx.shape[0] > 1:
                # append adjacent-pair (bigram) features before max-pooling
                bigram_hx = torch.cat((stacked_nhx[:-1, :], stacked_nhx[1:, :]), axis=1)
                bigram_hx = self.reduce_bigram(bigram_hx) / 2
                stacked_nhx = torch.cat((stacked_nhx, bigram_hx), axis=0)
            unpacked_hx.append(torch.max(stacked_nhx, 0).values)
        packed_hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)
        hx = self.reduce_linear(packed_hx)
        lstm_hx = self.nonlinearity(hx)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.ATTN:
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        label_hx = [self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]]) for label in labels]
        unpacked_hx = [torch.stack(nhx) for nhx in node_hx]
        # the label embedding is prepended as an extra "token"
        unpacked_hx = [torch.cat((lhx.unsqueeze(0).unsqueeze(0), nhx), axis=0) for lhx, nhx in zip(label_hx, unpacked_hx)]
        unpacked_hx = [self.reduce_attn(nhx, nhx, nhx)[0].squeeze(1) for nhx in unpacked_hx]
        unpacked_hx = [self.lstm_input_dropout(torch.max(nhx, 0).values) for nhx in unpacked_hx]
        hx = torch.stack(unpacked_hx, axis=0)
        lstm_hx = self.nonlinearity(hx).unsqueeze(0)
        lstm_cx = None
    elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
        node_hx = [torch.stack([child.value.tree_hx for child in children]) for children in children_lists]
        # add a position vector to each node_hx
        node_hx = [self.reduce_position(x.reshape(x.shape[0], -1)) for x in node_hx]
        query_hx = [self.reduce_query(nhx) for nhx in node_hx]
        # reshape query for MHA
        query_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in query_hx]
        if self.constituency_composition == ConstituencyComposition.KEY:
            # one shared learned key per head
            queries = [torch.matmul(nhx, self.reduce_key) for nhx in query_hx]
        else:
            # UNTIED_KEY: a separate learned key per constituent label
            label_indices = [self.constituent_open_map[label] for label in labels]
            queries = [torch.matmul(nhx, self.reduce_key[label_idx]) for nhx, label_idx in zip(query_hx, label_indices)]
        # softmax each head
        weights = [torch.nn.functional.softmax(nhx, dim=1).transpose(1, 2) for nhx in queries]
        value_hx = [self.reduce_value(nhx) for nhx in node_hx]
        value_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in value_hx]
        # use the softmaxes to add up the heads
        unpacked_hx = [torch.matmul(weight, nhx).squeeze(1) for weight, nhx in zip(weights, value_hx)]
        unpacked_hx = [nhx.reshape(-1) for nhx in unpacked_hx]
        hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)
        lstm_hx = self.nonlinearity(hx)
        lstm_cx = None
    elif self.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
        label_hx = [self.lstm_input_dropout(self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]])) for label in labels]
        label_hx = torch.stack(label_hx).unsqueeze(0)

        max_length = max(len(children) for children in children_lists)

        # stacking will let us do elementwise multiplication faster, hopefully
        node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
        unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in node_hx]
        unpacked_hx = [nhx.max(dim=0) for nhx in unpacked_hx]
        packed_hx = torch.stack([nhx.values for nhx in unpacked_hx], axis=1)
        #packed_hx = packed_hx.max(dim=0).values

        # pick the cx entries that match the argmax positions of the hx max
        node_cx = [torch.stack([child.value.tree_cx for child in children]) for children in children_lists]
        node_cx_indices = [uhx.indices.unsqueeze(0) for uhx in unpacked_hx]
        unpacked_cx = [ncx.gather(0, nci).squeeze(0) for ncx, nci in zip(node_cx, node_cx_indices)]
        packed_cx = torch.stack(unpacked_cx, axis=1)

        _, (lstm_hx, lstm_cx) = self.constituent_reduce_lstm(label_hx, (packed_hx, packed_cx))
    else:
        raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))

    constituents = []
    for idx, (label, children) in enumerate(zip(labels, children_lists)):
        children = [child.value.value for child in children]
        if isinstance(label, str):
            node = Tree(label=label, children=children)
        else:
            # a sequence of labels builds a unary chain, innermost last
            for value in reversed(label):
                node = Tree(label=value, children=children)
                children = node
        constituents.append(Constituent(node, lstm_hx[:, idx, :], lstm_cx[:, idx, :] if lstm_cx is not None else None))
    return constituents
def push_constituents(self, constituent_stacks, constituents):
    """
    Push a batch of new Constituents onto their constituent stacks.

    Another possibility here would be to use output[0, i, :]
    from the constituency lstm for the value of the new node.
    This might theoretically make the new constituent include
    information from neighboring constituents.  However, this
    lowers the scores of various models.
    For example, an experiment on ja_alt built this way,
    averaged over 5 trials, had the following loss in accuracy:
      150 epochs: 0.8971 to 0.8953
      200 epochs: 0.8985 to 0.8964
    """
    # removed an unused `current_nodes` list that was built and discarded
    # stack the top layer of each constituent's hx: (1, N, hidden_size)
    constituent_input = torch.stack([x.tree_hx[-1:] for x in constituents], axis=1)
    # the constituents are already Constituent(tree, tree_hx, tree_cx)
    return self.constituent_stack.push_states(constituent_stacks, constituents, constituent_input)
def get_top_constituent(self, constituents):
    """
    Extract just the constituent tree from the top of a state's
    constituent stack, ignoring the other bookkeeping pieces.
    """
    # TreeStack value -> LSTMTreeStack value -> Constituent value -> constituent
    top = constituents.value
    return top.value.value
def push_transitions(self, transition_stacks, transitions):
    """
    Push the given transitions onto their stacks as a single batched operation.

    Significantly faster than doing one transition at a time.
    """
    indices = [self.transition_map[transition] for transition in transitions]
    transition_idx = torch.stack([self.transition_tensors[index] for index in indices])
    transition_input = self.transition_embedding(transition_idx).unsqueeze(0)
    return self.transition_stack.push_states(transition_stacks, transitions, transition_input)
def get_top_transition(self, transitions):
    """
    Extract just the transition from the top of a state's transition
    stack, ignoring the other bookkeeping pieces.
    """
    # TreeStack value -> LSTMTreeStack value -> transition
    top = transitions.value
    return top.value
def forward(self, states):
    """
    Return logits for a prediction of what transition to make next

    We've basically done all the work analyzing the state as
    part of applying the transitions, so this method is very simple

    return shape: (num_states, num_transitions)
    """
    word_hx = torch.stack([state.get_word(state.word_position).hx for state in states])
    transition_hx = torch.stack([self.transition_stack.output(state.transitions) for state in states])
    # this .output() is the output of the constituent stack, not the
    # constituent itself
    # this way, we can, as an option, NOT include the constituents to the left
    # when building the current vector for a constituent
    # and the vector used for inference will still incorporate the entire LSTM
    constituent_hx = torch.stack([self.constituent_stack.output(state.constituents) for state in states])

    hx = torch.cat((word_hx, transition_hx, constituent_hx), axis=1)
    for idx, output_layer in enumerate(self.output_layers):
        hx = self.predict_dropout(hx)
        # maxout layers supply their own nonlinearity, and the final
        # layer emits raw logits, so the explicit nonlinearity is only
        # applied between plain Linear layers
        if not self.maxout_k and idx < len(self.output_layers) - 1:
            hx = self.nonlinearity(hx)
        hx = output_layer(hx)
    return hx
def predict(self, states, is_legal=True):
    """
    Generate and return predictions, along with the transitions those predictions represent

    If is_legal is set to True, will only return legal transitions.
    This means returning None if there are no legal transitions.
    Hopefully the constraints prevent that from happening

    Returns (raw logits, chosen transition per state, score per state).
    """
    predictions = self.forward(states)
    pred_max = torch.argmax(predictions, dim=1)
    scores = torch.take_along_dim(predictions, pred_max.unsqueeze(1), dim=1)
    pred_max = pred_max.detach().cpu()

    pred_trans = [self.transitions[pred_max[idx]] for idx in range(len(states))]
    if is_legal:
        for idx, (state, trans) in enumerate(zip(states, pred_trans)):
            if not trans.is_legal(state, self):
                # fall back to the highest scoring legal transition
                _, indices = predictions[idx, :].sort(descending=True)
                for index in indices:
                    if self.transitions[index].is_legal(state, self):
                        pred_trans[idx] = self.transitions[index]
                        scores[idx] = predictions[idx, index]
                        break
                else: # yeah, else on a for loop, deal with it
                    pred_trans[idx] = None
                    # NOTE(review): assigning None into a tensor element looks
                    # suspect — presumably this branch is unreachable in
                    # practice because some transition is always legal; verify
                    scores[idx] = None

    return predictions, pred_trans, scores.squeeze(1)
def weighted_choice(self, states):
    """
    Generate and return predictions, and randomly choose a prediction weighted by the scores

    Returns (predictions, pred_trans, all_scores): the raw score matrix,
    one sampled Transition per state (None for a state with no legal
    transition), and the raw score of each sampled transition.

    TODO: pass in a temperature
    """
    predictions = self.forward(states)
    pred_trans = []
    all_scores = []
    for state, prediction in zip(states, predictions):
        # restrict sampling to the transitions legal in this state
        legal_idx = [idx for idx in range(prediction.shape[0]) if self.transitions[idx].is_legal(state, self)]
        if len(legal_idx) == 0:
            pred_trans.append(None)
            continue
        scores = prediction[legal_idx]
        # sample proportionally to the softmax of the legal scores
        scores = torch.softmax(scores, dim=0)
        idx = torch.multinomial(scores, 1)
        idx = legal_idx[idx]
        pred_trans.append(self.transitions[idx])
        all_scores.append(prediction[idx])
    if all_scores:
        all_scores = torch.stack(all_scores)
    else:
        # previously torch.stack([]) raised a RuntimeError when no state
        # had any legal transition; return an empty tensor instead
        all_scores = predictions.new_zeros(0)
    return predictions, pred_trans, all_scores
|
| 1135 |
+
|
| 1136 |
+
def predict_gold(self, states):
    """
    For each State, return the next item in the gold_sequence

    Returns (predictions, transitions, scores): the raw score matrix,
    the gold transition for each state, and the model's score for each
    of those gold transitions.
    """
    predictions = self.forward(states)
    transitions = [state.gold_sequence[state.num_transitions] for state in states]
    gold_idx = [self.transition_map[trans] for trans in transitions]
    indices = torch.tensor(gold_idx, device=predictions.device)
    scores = torch.take_along_dim(predictions, indices.unsqueeze(1), dim=1)
    return predictions, transitions, scores.squeeze(1)
|
| 1145 |
+
|
| 1146 |
+
def get_params(self, skip_modules=True):
    """
    Get a dictionary for saving the model

    skip_modules: drop parameters belonging to unsaved submodules
      (e.g. pretrained embeddings) from the saved state dict
    """
    model_state = self.state_dict()
    # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
    if skip_modules:
        skipped = [k for k in model_state.keys() if self.is_unsaved_module(k)]
        for k in skipped:
            del model_state[k]
    # deep copy so the enum -> name conversions below don't clobber self.args
    config = copy.deepcopy(self.args)
    config['sentence_boundary_vectors'] = config['sentence_boundary_vectors'].name
    config['constituency_composition'] = config['constituency_composition'].name
    config['transition_stack'] = config['transition_stack'].name
    config['constituent_stack'] = config['constituent_stack'].name
    config['transition_scheme'] = config['transition_scheme'].name
    assert isinstance(self.rare_words, set)
    params = {
        'model': model_state,
        'model_type': "LSTM",
        'config': config,
        # transitions are saved via repr and rebuilt when loading
        'transitions': [repr(x) for x in self.transitions],
        'constituents': self.constituents,
        'tags': self.tags,
        'words': self.delta_words,
        # sets are not reliably serializable; store as a list
        'rare_words': list(self.rare_words),
        'root_labels': self.root_labels,
        'constituent_opens': self.constituent_opens,
        'unary_limit': self.unary_limit(),
    }

    return params
|
| 1178 |
+
|
stanza/stanza/models/constituency/parse_tree.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tree datastructure
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import deque, Counter
|
| 6 |
+
import copy
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from io import StringIO
|
| 9 |
+
import itertools
|
| 10 |
+
import re
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
from stanza.models.common.stanza_object import StanzaObject
|
| 14 |
+
|
| 15 |
+
# useful more for the "is" functionality than the time savings
CLOSE_PAREN = ')'
SPACE_SEPARATOR = ' '
OPEN_PAREN = '('

# shared immutable empty tuple for childless nodes
EMPTY_CHILDREN = ()

# used to split off the functional tags from various treebanks
# for example, the Icelandic treebank (which we don't currently
# incorporate) uses * to distinguish 'ADJP', 'ADJP*OC' but we treat
# those as the same
CONSTITUENT_SPLIT = re.compile("[-=#*]")

# These words occur in the VLSP dataset.
# The documentation claims there might be *O*, although those don't
# seem to exist in practice
WORDS_TO_PRUNE = ('*E*', '*T*', '*O*')
|
| 32 |
+
|
| 33 |
+
class TreePrintMethod(Enum):
    """
    Describes a few options for printing trees.

    This probably doesn't need to be used directly. See __format__
    """
    ONE_LINE = 1        # (ROOT (S ... ))
    LABELED_PARENS = 2  # (_ROOT (_S ... )_S )_ROOT
    PRETTY = 3          # multiple lines
    VLSP = 4            # <s> (S ... ) </s>
    LATEX_TREE = 5      # \Tree [.S [.NP ... ] ]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Tree(StanzaObject):
    """
    A data structure to represent a parse tree
    """
    def __init__(self, label=None, children=None):
        # children are normalized to an immutable tuple
        if children is None:
            self.children = EMPTY_CHILDREN
        elif isinstance(children, Tree):
            # allow a single Tree to be passed as the only child
            self.children = (children,)
        else:
            self.children = tuple(children)

        self.label = label
|
| 59 |
+
|
| 60 |
+
def is_leaf(self):
    """A leaf has no children; for a word node, the label is the word"""
    return len(self.children) == 0
|
| 62 |
+
|
| 63 |
+
def is_preterminal(self):
    """A preterminal (tag node) has exactly one child, which is a leaf"""
    return len(self.children) == 1 and len(self.children[0].children) == 0
|
| 65 |
+
|
| 66 |
+
def yield_preterminals(self):
    """
    Yield the preterminals one at a time in order

    Implemented iteratively with itertools.chain so very deep trees
    do not overflow the call stack.

    Raises ValueError if called on a bare leaf.
    """
    if self.is_preterminal():
        yield self
        return

    if self.is_leaf():
        raise ValueError("Attempted to iterate preterminals on non-internal node")

    iterator = iter(self.children)
    node = next(iterator, None)
    while node is not None:
        if node.is_preterminal():
            yield node
        else:
            # expand an internal node in place: its children come next,
            # followed by the remainder of the current iterator
            iterator = itertools.chain(node.children, iterator)
        node = next(iterator, None)
|
| 85 |
+
|
| 86 |
+
def leaf_labels(self):
    """
    Get the labels of the leaves

    Called on a leaf, returns that leaf's label as a singleton list.
    """
    if self.is_leaf():
        return [self.label]

    words = [x.children[0].label for x in self.yield_preterminals()]
    return words
|
| 95 |
+
|
| 96 |
+
def __len__(self):
    # number of words in the tree; O(n) since it materializes the leaf list
    return len(self.leaf_labels())
|
| 98 |
+
|
| 99 |
+
def all_leaves_are_preterminals(self):
    """
    Returns True if all leaves are under preterminals, False otherwise
    """
    if self.is_leaf():
        # a bare leaf is, by definition, not under a preterminal
        return False

    if self.is_preterminal():
        return True

    return all(t.all_leaves_are_preterminals() for t in self.children)
|
| 110 |
+
|
| 111 |
+
def pretty_print(self, normalize=None):
    """
    Print with newlines & indentation on each line

    Preterminals and nodes with all preterminal children go on their own line

    You can pass in your own normalize() function. If you do,
    make sure the function updates the parens to be something
    other than () or the brackets will be broken
    """
    if normalize is None:
        normalize = lambda x: x.replace("(", "-LRB-").replace(")", "-RRB-")

    indent = 0
    with StringIO() as buf:
        # explicit stack instead of recursion; CLOSE_PAREN sentinels mark
        # where a constituent ends (identity-compared via "is")
        stack = deque()
        stack.append(self)
        while len(stack) > 0:
            node = stack.pop()

            if node is CLOSE_PAREN:
                # if we're trying to pretty print trees, pop all off close parens
                # then write a newline
                while node is CLOSE_PAREN:
                    indent -= 1
                    buf.write(CLOSE_PAREN)
                    if len(stack) == 0:
                        node = None
                        break
                    node = stack.pop()
                buf.write("\n")
                if node is None:
                    break
                # the popped node was not a close paren: put it back for processing
                stack.append(node)
            elif node.is_preterminal():
                buf.write(" " * indent)
                buf.write("%s%s %s%s" % (OPEN_PAREN, normalize(node.label), normalize(node.children[0].label), CLOSE_PAREN))
                # only newline if the constituent is not about to close
                if len(stack) == 0 or stack[-1] is not CLOSE_PAREN:
                    buf.write("\n")
            elif all(x.is_preterminal() for x in node.children):
                # a node whose children are all preterminals fits on one line
                buf.write(" " * indent)
                buf.write("%s%s" % (OPEN_PAREN, normalize(node.label)))
                for child in node.children:
                    buf.write(" %s%s %s%s" % (OPEN_PAREN, normalize(child.label), normalize(child.children[0].label), CLOSE_PAREN))
                buf.write(CLOSE_PAREN)
                if len(stack) == 0 or stack[-1] is not CLOSE_PAREN:
                    buf.write("\n")
            else:
                # general internal node: open it, push the close sentinel,
                # then push children in reverse so they pop in order
                buf.write(" " * indent)
                buf.write("%s%s\n" % (OPEN_PAREN, normalize(node.label)))
                stack.append(CLOSE_PAREN)
                for child in reversed(node.children):
                    stack.append(child)
                indent += 1

        buf.seek(0)
        return buf.read()
|
| 168 |
+
|
| 169 |
+
def __format__(self, spec):
    """
    Turn the tree into a string representing the tree

    Note that this is not a recursive traversal
    Otherwise, a tree too deep might blow up the call stack

    There is a type specific format:
      O -> one line PTB format, which is the default anyway
      L -> open and close brackets are labeled, spaces in the tokens are replaced with _
      P -> pretty print over multiple lines
      V -> surround lines with <s>...</s>, don't print ROOT, and turn () into L/RBKT
      ? -> spaces in the tokens are replaced with ? for any value of ? other than OLP
           warning: this may be removed in the future
      ?{OLPV} -> specific format AND a custom space replacement
      Vi -> add an ID to the <s> in the V format. Also works with ?Vi
    """
    space_replacement = " "
    print_format = TreePrintMethod.ONE_LINE
    # parse the format spec: an optional single replacement char
    # followed by an optional format letter
    if spec == 'L':
        print_format = TreePrintMethod.LABELED_PARENS
        space_replacement = "_"
    elif spec and spec[-1] == 'L':
        print_format = TreePrintMethod.LABELED_PARENS
        space_replacement = spec[0]
    elif spec == 'O':
        print_format = TreePrintMethod.ONE_LINE
    elif spec and spec[-1] == 'O':
        print_format = TreePrintMethod.ONE_LINE
        space_replacement = spec[0]
    elif spec == 'P':
        print_format = TreePrintMethod.PRETTY
    elif spec and spec[-1] == 'P':
        print_format = TreePrintMethod.PRETTY
        space_replacement = spec[0]
    elif spec and spec[0] == 'V':
        print_format = TreePrintMethod.VLSP
        use_tree_id = spec[-1] == 'i'
    elif spec and len(spec) > 1 and spec[1] == 'V':
        print_format = TreePrintMethod.VLSP
        space_replacement = spec[0]
        use_tree_id = spec[-1] == 'i'
    elif spec == 'T':
        print_format = TreePrintMethod.LATEX_TREE
    elif spec and len(spec) > 1 and spec[1] == 'T':
        print_format = TreePrintMethod.LATEX_TREE
        space_replacement = spec[0]
    elif spec:
        space_replacement = spec[0]
        warnings.warn("Use of a custom replacement without a format specifier is deprecated. Please use {}O instead".format(space_replacement), stacklevel=2)

    # VLSP uses LBKT/RBKT for brackets inside tokens; everything else uses PTB escapes
    LRB = "LBKT" if print_format == TreePrintMethod.VLSP else "-LRB-"
    RRB = "RBKT" if print_format == TreePrintMethod.VLSP else "-RRB-"
    def normalize(text):
        return text.replace(" ", space_replacement).replace("(", LRB).replace(")", RRB)

    if print_format is TreePrintMethod.PRETTY:
        return self.pretty_print(normalize)

    with StringIO() as buf:
        # iterative traversal: the stack holds Tree nodes and literal
        # string fragments (separators / close markers) to emit
        stack = deque()
        if print_format == TreePrintMethod.VLSP:
            if use_tree_id:
                # NOTE(review): self.tree_id is assumed to be set elsewhere
                # before formatting with Vi -- confirm against callers
                buf.write("<s id={}>\n".format(self.tree_id))
            else:
                buf.write("<s>\n")
            if len(self.children) == 0:
                raise ValueError("Cannot print an empty tree with V format")
            elif len(self.children) > 1:
                raise ValueError("Cannot print a tree with %d branches with V format" % len(self.children))
            # skip the ROOT level entirely in VLSP format
            stack.append(self.children[0])
        elif print_format == TreePrintMethod.LATEX_TREE:
            buf.write("\\Tree ")
            if len(self.children) == 0:
                raise ValueError("Cannot print an empty tree with T format")
            elif len(self.children) == 1 and len(self.children[0].children) == 0:
                # degenerate single-word tree: emit a placeholder tag
                buf.write("[.? ")
                buf.write(normalize(self.children[0].label))
                buf.write(" ]")
            elif self.label == 'ROOT':
                stack.append(self.children[0])
            else:
                stack.append(self)
        else:
            stack.append(self)
        while len(stack) > 0:
            node = stack.pop()

            if isinstance(node, str):
                # a pre-rendered fragment (separator or close marker)
                buf.write(node)
                continue
            if len(node.children) == 0:
                if node.label is not None:
                    buf.write(normalize(node.label))
                continue

            if print_format is TreePrintMethod.LATEX_TREE:
                if node.is_preterminal():
                    # latex output drops the tag, printing only the word
                    buf.write(normalize(node.children[0].label))
                    continue
                buf.write("[.%s" % normalize(node.label))
                stack.append(" ]")
            elif print_format is TreePrintMethod.ONE_LINE or print_format is TreePrintMethod.VLSP:
                buf.write(OPEN_PAREN)
                if node.label is not None:
                    buf.write(normalize(node.label))
                stack.append(CLOSE_PAREN)
            elif print_format is TreePrintMethod.LABELED_PARENS:
                buf.write("%s_%s" % (OPEN_PAREN, normalize(node.label)))
                stack.append(CLOSE_PAREN + "_" + normalize(node.label))
                # labeled close parens are preceded by a space
                stack.append(SPACE_SEPARATOR)

            # children in reverse so they pop in order, each preceded by a space
            for child in reversed(node.children):
                stack.append(child)
                stack.append(SPACE_SEPARATOR)
        if print_format == TreePrintMethod.VLSP:
            buf.write("\n</s>")
        buf.seek(0)
        return buf.read()
|
| 288 |
+
|
| 289 |
+
def __repr__(self):
    # delegate to __format__ with the default (one-line) spec
    return format(self)
|
| 291 |
+
|
| 292 |
+
def __eq__(self, other):
    """Trees are equal iff the labels match and the children compare equal pairwise"""
    if self is other:
        return True
    if not isinstance(other, Tree):
        return False
    return (self.label == other.label
            and len(self.children) == len(other.children)
            and all(mine == theirs for mine, theirs in zip(self.children, other.children)))
|
| 304 |
+
|
| 305 |
+
def depth(self):
    """Longest path from this node down to a leaf; a leaf has depth 0"""
    if not self.children:
        return 0
    # recursive; assumes trees are not deep enough to exhaust the call stack
    return 1 + max(x.depth() for x in self.children)
|
| 309 |
+
|
| 310 |
+
def visit_preorder(self, internal=None, preterminal=None, leaf=None):
    """
    Visit the tree in a preorder order

    Applies the given functions to each node.
      internal: if not None, applies this function to each non-leaf, non-preterminal node
      preterminal: if not None, applies this function to each preterminal
      leaf: if not None, applies this function to each leaf

    The functions should *not* destructively alter the trees.
    There is no attempt to interpret the results of calling these functions.
    Rather, you can use visit_preorder to collect stats on trees, etc.
    """
    if self.is_leaf():
        if leaf:
            leaf(self)
    elif self.is_preterminal():
        if preterminal:
            preterminal(self)
    else:
        if internal:
            internal(self)
    for child in self.children:
        child.visit_preorder(internal, preterminal, leaf)
|
| 334 |
+
|
| 335 |
+
@staticmethod
def get_unique_constituent_labels(trees):
    """
    Walks over all of the trees and gets all of the unique constituent names from the trees

    Accepts a single Tree or an iterable of Trees; returns a sorted list.
    """
    if isinstance(trees, Tree):
        trees = [trees]
    constituents = Tree.get_constituent_counts(trees)
    return sorted(set(constituents.keys()))
|
| 344 |
+
|
| 345 |
+
@staticmethod
def get_constituent_counts(trees):
    """
    Walks over all of the trees and gets the count of the unique constituent names from the trees

    Returns a Counter keyed on internal-node labels.
    """
    if isinstance(trees, Tree):
        trees = [trees]

    constituents = Counter()
    for tree in trees:
        tree.visit_preorder(internal = lambda x: constituents.update([x.label]))
    return constituents
|
| 357 |
+
|
| 358 |
+
@staticmethod
def get_unique_tags(trees):
    """
    Walks over all of the trees and gets all of the unique tags from the trees

    Returns a sorted list of preterminal labels.
    """
    if isinstance(trees, Tree):
        trees = [trees]

    tags = set()
    for tree in trees:
        tree.visit_preorder(preterminal = lambda x: tags.add(x.label))
    return sorted(tags)
|
| 370 |
+
|
| 371 |
+
@staticmethod
def get_unique_words(trees):
    """
    Walks over all of the trees and gets all of the unique words from the trees

    Returns a sorted list of leaf labels.
    """
    if isinstance(trees, Tree):
        trees = [trees]

    words = set()
    for tree in trees:
        tree.visit_preorder(leaf = lambda x: words.add(x.label))
    return sorted(words)
|
| 383 |
+
|
| 384 |
+
@staticmethod
def get_common_words(trees, num_words):
    """
    Walks over all of the trees and gets the most frequently occurring words.

    Returns the num_words most common leaf labels, sorted alphabetically.
    """
    if num_words == 0:
        # NOTE(review): this branch returns a set() while the normal path
        # returns a sorted list -- callers presumably only iterate or test
        # membership, but confirm before relying on the type
        return set()

    if isinstance(trees, Tree):
        trees = [trees]

    words = Counter()
    for tree in trees:
        tree.visit_preorder(leaf = lambda x: words.update([x.label]))
    return sorted(x[0] for x in words.most_common()[:num_words])
|
| 399 |
+
|
| 400 |
+
@staticmethod
def get_rare_words(trees, threshold=0.05):
    """
    Walks over all of the trees and gets the least frequently occurring words.

    threshold: choose the bottom X percent (at least one word)
    """
    if isinstance(trees, Tree):
        trees = [trees]

    words = Counter()
    for tree in trees:
        tree.visit_preorder(leaf = lambda x: words.update([x.label]))
    threshold = max(int(len(words) * threshold), 1)
    # most_common()[:-threshold-1:-1] walks the count list from the tail,
    # yielding the `threshold` least common words
    return sorted(x[0] for x in words.most_common()[:-threshold-1:-1])
|
| 415 |
+
|
| 416 |
+
@staticmethod
def get_root_labels(trees):
    """Sorted list of the unique labels at the root of the given trees"""
    return sorted(set(x.label for x in trees))
|
| 419 |
+
|
| 420 |
+
@staticmethod
def get_compound_constituents(trees, separate_root=False):
    """
    Collect the chains of unary constituent labels appearing in the trees.

    Returns a sorted list of label tuples; a chain of unary internal
    nodes becomes one tuple such as (S, VP).
    separate_root: record the root label as its own singleton chain
    """
    constituents = set()
    stack = deque()
    for tree in trees:
        if separate_root:
            constituents.add((tree.label,))
            for child in tree.children:
                stack.append(child)
        else:
            stack.append(tree)
    while len(stack) > 0:
        node = stack.pop()
        if node.is_leaf() or node.is_preterminal():
            continue
        labels = [node.label]
        # follow the unary chain downwards, collecting each label
        while len(node.children) == 1 and not node.children[0].is_preterminal():
            node = node.children[0]
            labels.append(node.label)
        constituents.add(tuple(labels))
        for child in node.children:
            stack.append(child)
    return sorted(constituents)
|
| 443 |
+
|
| 444 |
+
# TODO: test different pattern
def simplify_labels(self, pattern=CONSTITUENT_SPLIT):
    """
    Return a copy of the tree with the -=# removed

    Leaves the text of the leaves alone.
    """
    new_label = self.label
    # check len(new_label) just in case it's a tag of - or =
    # also leave the PTB bracket escapes intact
    if new_label and not self.is_leaf() and len(new_label) > 1 and new_label not in ('-LRB-', '-RRB-'):
        new_label = pattern.split(new_label)[0]
    new_children = [child.simplify_labels(pattern) for child in self.children]
    return Tree(new_label, new_children)
|
| 457 |
+
|
| 458 |
+
def reverse(self):
    """
    Flip a tree backwards

    The intent is to train a parser backwards to see if the
    forward and backwards parsers can augment each other
    """
    if self.is_leaf():
        return Tree(self.label)

    flipped_children = [kid.reverse() for kid in self.children[::-1]]
    return Tree(self.label, flipped_children)
|
| 470 |
+
|
| 471 |
+
def remap_constituent_labels(self, label_map):
    """
    Copies the tree with some labels replaced.

    Labels in the map are replaced with the mapped value.
    Labels not in the map are unchanged.
    Note: only internal-node labels are remapped; preterminal (tag)
    labels and words are copied as-is.
    """
    if self.is_leaf():
        return Tree(self.label)
    if self.is_preterminal():
        return Tree(self.label, Tree(self.children[0].label))
    new_label = label_map.get(self.label, self.label)
    return Tree(new_label, [child.remap_constituent_labels(label_map) for child in self.children])
|
| 484 |
+
|
| 485 |
+
def remap_words(self, word_map):
    """
    Copies the tree with some words (leaf labels) replaced.

    Labels in the map are replaced with the mapped value.
    Labels not in the map are unchanged.
    """
    if self.is_leaf():
        new_label = word_map.get(self.label, self.label)
        return Tree(new_label)
    if self.is_preterminal():
        return Tree(self.label, self.children[0].remap_words(word_map))
    return Tree(self.label, [child.remap_words(word_map) for child in self.children])
|
| 498 |
+
|
| 499 |
+
def replace_words(self, words):
    """
    Replace all leaf words with the words in the given list (or iterable)

    Raises ValueError if the word count does not match the leaf count.
    Returns a new tree
    """
    word_iterator = iter(words)
    def recursive_replace_words(subtree):
        if subtree.is_leaf():
            word = next(word_iterator, None)
            if word is None:
                raise ValueError("Not enough words to replace all leaves")
            return Tree(word)
        return Tree(subtree.label, [recursive_replace_words(x) for x in subtree.children])

    new_tree = recursive_replace_words(self)
    # any word left in the iterator means the caller supplied too many
    if any(True for _ in word_iterator):
        raise ValueError("Too many words for the given tree")
    return new_tree
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def replace_tags(self, tags):
    """
    Return a deep copy of this tree with the preterminal labels replaced.

    tags: either a Tree (whose preterminal labels are used, in order)
      or an iterable of tag strings
    Raises ValueError when the number of tags does not match the number
    of preterminals, or the tree is badly structured.
    """
    if self.is_leaf():
        raise ValueError("Must call replace_tags with non-leaf")

    if isinstance(tags, Tree):
        tag_iterator = (x.label for x in tags.yield_preterminals())
    else:
        tag_iterator = iter(tags)

    new_tree = copy.deepcopy(self)
    # deque used as a stack: pop() + reversed extend gives preorder,
    # so tags are consumed left to right
    queue = deque()
    queue.append(new_tree)
    while len(queue) > 0:
        next_node = queue.pop()
        if next_node.is_preterminal():
            try:
                label = next(tag_iterator)
            except StopIteration:
                raise ValueError("Not enough tags in sentence for given tree")
            next_node.label = label
        elif next_node.is_leaf():
            # a leaf directly under an internal node is malformed here
            raise ValueError("Got a badly structured tree: {}".format(self))
        else:
            queue.extend(reversed(next_node.children))

    if any(True for _ in tag_iterator):
        raise ValueError("Too many tags for the given tree")

    return new_tree
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def prune_none(self):
    """
    Return a copy of the tree, eliminating all nodes which are in one of two categories:
        they are a preterminal -NONE-, such as appears in PTB
          (or a word in WORDS_TO_PRUNE, such as *E* in a VLSP dataset)
        they have been pruned to 0 children by the recursive call

    Returns None if this entire subtree is pruned away.
    """
    if self.is_leaf():
        return Tree(self.label)
    if self.is_preterminal():
        if self.label == '-NONE-' or self.children[0].label in WORDS_TO_PRUNE:
            return None
        return Tree(self.label, Tree(self.children[0].label))
    # must be internal node
    new_children = [child.prune_none() for child in self.children]
    new_children = [child for child in new_children if child is not None]
    if len(new_children) == 0:
        return None
    return Tree(self.label, new_children)
|
| 570 |
+
|
| 571 |
+
def count_unary_depth(self):
    """
    Length of the longest chain of unary internal nodes anywhere in the tree.

    Preterminals and leaves do not count towards a chain.
    """
    if self.is_preterminal() or self.is_leaf():
        return 0
    if len(self.children) == 1:
        t = self
        score = 0
        # walk down the unary chain, counting its length
        while not t.is_preterminal() and not t.is_leaf() and len(t.children) == 1:
            score = score + 1
            t = t.children[0]
        # NOTE(review): if the chain bottoms out at a childless node this
        # max() would see an empty sequence -- presumably well-formed trees
        # never hit that; confirm against the data this runs on
        child_score = max(tc.count_unary_depth() for tc in t.children)
        score = max(score, child_score)
        return score
    score = max(t.count_unary_depth() for t in self.children)
    return score
|
| 585 |
+
|
| 586 |
+
@staticmethod
def write_treebank(trees, out_file, fmt="{}"):
    """
    Write every tree to out_file, one per line, rendered with fmt
    """
    with open(out_file, "w", encoding="utf-8") as fout:
        for tree in trees:
            fout.write("%s\n" % fmt.format(tree))
|
stanza/stanza/models/constituency/positional_encoding.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Based on
|
| 3 |
+
https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
class SinusoidalEncoding(nn.Module):
    """
    Fixed sine & cosine positional embedding, indexable by position.
    """
    def __init__(self, model_dim, max_len):
        super().__init__()
        self.register_buffer('pe', self.build_position(model_dim, max_len))

    @staticmethod
    def build_position(model_dim, max_len, device=None):
        """Build a (max_len, model_dim) table: sin on even dims, cos on odd dims."""
        positions = torch.arange(max_len).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        table = torch.zeros(max_len, model_dim)
        angles = positions * frequencies
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        return table if device is None else table.to(device=device)

    def forward(self, x):
        """Look up the encodings for positions x, growing the table on demand."""
        largest = max(x)
        if largest >= self.pe.shape[0]:
            # try to drop the reference first before creating a new encoding
            # the goal being to save memory if we are close to the memory limit
            device = self.pe.device
            model_dim = self.pe.shape[1]
            self.register_buffer('pe', None)
            # TODO: this may result in very poor performance
            # in the event of a model that increases size one at a time
            self.register_buffer('pe', self.build_position(model_dim, largest + 1, device=device))
        return self.pe[x]

    def max_len(self):
        """Current capacity (number of positions) of the encoding table."""
        return self.pe.shape[0]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AddSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position.  Adds the position to the given matrix

    Default behavior is batch_first
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x, scale=1.0):
        """
        Adds the positional encoding to the input tensor

        The tensor is expected to be of the shape B, N, D (batch first)
        or N, D for a single unbatched sequence.
        Properly masking the output tensor is up to the caller

        scale: multiplier applied to the encoding before adding
        """
        if len(x.shape) == 3:
            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
            timing = timing.expand(x.shape[0], -1, -1)
        elif len(x.shape) == 2:
            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
        else:
            # previously any other rank fell through to a confusing
            # NameError on an unbound 'timing'
            raise ValueError("AddSinusoidalEncoding expects a 2d or 3d tensor, got shape %s" % (tuple(x.shape),))
        return x + timing * scale
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ConcatSinusoidalEncoding(nn.Module):
    """
    Sine/cosine positional encoding, concatenated onto the input so the
    result is wider along the final dimension.

    Default behavior is batch_first
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x):
        """Return x with its positional encoding appended on the last axis."""
        if len(x.shape) == 3:
            positions = torch.arange(x.shape[1], device=x.device)
            timing = self.encoding(positions).expand(x.shape[0], -1, -1)
        else:
            positions = torch.arange(x.shape[0], device=x.device)
            timing = self.encoding(positions)

        return torch.cat((x, timing), dim=-1)
|
stanza/stanza/models/constituency/retagging.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Refactor a few functions specifically for retagging trees
|
| 3 |
+
|
| 4 |
+
Retagging is important because the gold tags will not be available at runtime
|
| 5 |
+
|
| 6 |
+
Note that the method which does the actual retagging is in utils.py
|
| 7 |
+
so as to avoid unnecessary circular imports
|
| 8 |
+
(eg, Pipeline imports constituency/trainer which imports this which imports Pipeline)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import copy
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
from stanza import Pipeline
|
| 15 |
+
|
| 16 |
+
from stanza.models.common.foundation_cache import FoundationCache
|
| 17 |
+
from stanza.models.common.vocab import VOCAB_PREFIX
|
| 18 |
+
from stanza.resources.common import download_resources_json, load_resources_json, get_language_resources
|
| 19 |
+
|
| 20 |
+
# shared logger for the constituency trainer machinery
tlogger = logging.getLogger('stanza.constituency.trainer')

# Language-specific default for which tag column to use when retagging.
# xpos tagger doesn't produce PP tag on the turin treebank,
# so instead we use upos to avoid unknown tag errors
RETAG_METHOD = {
    "da": "upos", # the DDT has no xpos tags anyway
    "de": "upos", # DE GSD is also missing a few punctuation tags
    "es": "upos", # AnCora has half-finished xpos tags
    "id": "upos", # GSD is missing a few punctuation tags - fixed in 2.12, though
    "it": "upos",
    "pt": "upos", # default PT model has no xpos either
    "vi": "xpos", # the new version of UD can be merged with xpos from VLSP22
}
|
| 33 |
+
|
| 34 |
+
def add_retag_args(parser):
    """
    Arguments specifically for retagging treebanks

    Registers the retag_* options on the given argparse parser.
    """
    add = parser.add_argument
    add('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time')
    add('--retag_method', default=None, choices=['xpos', 'upos'], help='Which tags to use when retagging. Default depends on the language')
    add('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default. Can specify multiple taggers with ; in which case the majority vote wins')
    add('--retag_pretrain_path', default=None, help='Use this for a pretrain path for the retagging pipeline. Generally not needed unless using a custom POS model with a custom pretrain')
    add('--retag_charlm_forward_file', default=None, help='Use this for a forward charlm path for the retagging pipeline. Generally not needed unless using a custom POS model with a custom charlm')
    add('--retag_charlm_backward_file', default=None, help='Use this for a backward charlm path for the retagging pipeline. Generally not needed unless using a custom POS model with a custom charlm')
    add('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees")
|
| 45 |
+
|
| 46 |
+
def postprocess_args(args):
    """
    After parsing args, unify some settings

    Fills in a language specific default for args['retag_method'] when
    it was not given, then sets args['retag_xpos'] accordingly.

    Raises ValueError for an unknown retag_method.
    """
    # use a language specific default for retag_method if we know the language
    # otherwise, use xpos
    if args['retag_method'] is None and 'lang' in args and args['lang'] in RETAG_METHOD:
        args['retag_method'] = RETAG_METHOD[args['lang']]
    if args['retag_method'] is None:
        args['retag_method'] = 'xpos'

    if args['retag_method'] == 'xpos':
        args['retag_xpos'] = True
    elif args['retag_method'] == 'upos':
        args['retag_xpos'] = False
    else:
        # bug fix: previously formatted the undefined name 'xpos',
        # which raised NameError instead of the intended ValueError
        raise ValueError("Unknown retag method {}".format(args['retag_method']))
|
| 63 |
+
|
| 64 |
+
def build_retag_pipeline(args):
    """
    Builds retag pipelines based on the arguments

    May alter the arguments if the pipeline is incompatible, such as
    taggers with no xpos

    Will return a list of one or more retag pipelines.
    Multiple tagger models can be specified by having them
    semi-colon separated in retag_model_path.

    Returns None when args['retag_package'] is None or when
    args['mode'] is 'remove_optimizer'.
    """
    # some argument sets might not use 'mode'
    if args['retag_package'] is not None and args.get('mode', None) != 'remove_optimizer':
        download_resources_json()
        resources = load_resources_json()

        # work out which language / pos package the retag_package names:
        # either "lang_package" or a bare package name for args['lang']
        if '_' in args['retag_package']:
            lang, package = args['retag_package'].split('_', 1)
            lang_resources = get_language_resources(resources, lang)
            if lang_resources is None and 'lang' in args:
                lang_resources = get_language_resources(resources, args['lang'])
            if lang_resources is not None and 'pos' in lang_resources and args['retag_package'] in lang_resources['pos']:
                # the full underscored name is itself a known pos package
                # for this language, so do not treat it as lang_package
                lang = args['lang']
                package = args['retag_package']
        else:
            if 'lang' not in args:
                raise ValueError("Retag package %s does not specify the language, and it is not clear from the arguments" % args['retag_package'])
            lang = args.get('lang', None)
            package = args['retag_package']
        foundation_cache = FoundationCache()
        retag_args = {"lang": lang,
                      "processors": "tokenize, pos",
                      "tokenize_pretokenized": True,
                      "package": {"pos": package}}
        if args['retag_pretrain_path'] is not None:
            retag_args['pos_pretrain_path'] = args['retag_pretrain_path']
        if args['retag_charlm_forward_file'] is not None:
            retag_args['pos_forward_charlm_path'] = args['retag_charlm_forward_file']
        if args['retag_charlm_backward_file'] is not None:
            retag_args['pos_backward_charlm_path'] = args['retag_charlm_backward_file']

        def build(retag_args, path):
            # path: explicit POS model file, or None to use the downloaded package
            retag_args = copy.deepcopy(retag_args)
            # we just downloaded the resources a moment ago
            # no need to repeatedly download
            retag_args['download_method'] = 'reuse_resources'
            if path is not None:
                retag_args['allow_unknown_language'] = True
                retag_args['pos_model_path'] = path
                tlogger.debug('Creating retag pipeline using %s', path)
            else:
                tlogger.debug('Creating retag pipeline for %s package', package)

            retag_pipeline = Pipeline(foundation_cache=foundation_cache, **retag_args)
            # an xpos vocab holding only the special prefix symbols means the
            # tagger has no real xpos tags, so fall back to upos
            if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX):
                tlogger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package)
                args['retag_xpos'] = False
                args['retag_method'] = 'upos'
            return retag_pipeline

        if args['retag_model_path'] is None:
            return [build(retag_args, None)]
        paths = args['retag_model_path'].split(";")
        # can be length 1 if only one tagger to work with
        return [build(retag_args, path) for path in paths]

    return None
|
stanza/stanza/models/constituency/state.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import namedtuple
|
| 2 |
+
|
| 3 |
+
class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'gold_tree', 'gold_sequence',
                                 'sentence_length', 'num_opens', 'word_position', 'score'])):
    """
    A snapshot of a partially completed transition-based parse.

    Carries the words not yet consumed, the transitions executed so
    far, and the constituents built so far.  At training time the gold
    tree and/or gold transition sequence ride along as well; at
    runtime those gold fields will not be available.

    num_opens makes it cheap to check
      1) whether the parser is stuck producing opens forever
      2) whether a close is legal (there must be a previous open)

    sentence_length records how long the sentence is so runaway
    parses can be aborted.

    Non-stack information such as sentence_length and num_opens will
    be copied from the original_state if possible, with the exact
    arguments overriding the values in the original_state.

    word_position is an index into word_queue rather than a shrinking
    buffer: the queue is processed once at the start of parsing, so
    advancing an index is cheaper than rebuilding the list.

    The word_queue should have both a start and an end sentinel word;
    either sentinel may be None if unused.  The transition and
    constituent stacks each begin with a sentinel whose parent is None.
    """
    def empty_word_queue(self):
        # every word has been consumed once the position catches up
        # with the sentence length
        return self.sentence_length == self.word_position

    def empty_transitions(self):
        # the sentinel at the bottom of the stack has no parent
        return self.transitions.parent is None

    def has_one_constituent(self):
        # length 1 is just the sentinel, so 2 means exactly one constituent
        return self.constituents.length == 2

    @property
    def empty_constituents(self):
        # only the sentinel remains when its parent is None
        return self.constituents.parent is None

    def num_constituents(self):
        return self.constituents.length - 1

    @property
    def num_transitions(self):
        # -1 for the sentinel value
        return self.transitions.length - 1

    def get_word(self, pos):
        # offset by one for the initial sentinel
        # (which you can actually get with pos=-1)
        return self.word_queue[pos + 1]

    def finished(self, model):
        # done when all words are used, one constituent remains,
        # and that constituent carries a root label
        if not self.empty_word_queue():
            return False
        if not self.has_one_constituent():
            return False
        return model.get_top_constituent(self.constituents).label in model.root_labels

    def get_tree(self, model):
        return model.get_top_constituent(self.constituents)

    def all_transitions(self, model):
        # TODO: rewrite this to be nicer / faster? or just refactor?
        stack = self.transitions
        unrolled = []
        while stack.parent is not None:
            unrolled.append(model.get_top_transition(stack))
            stack = stack.parent
        unrolled.reverse()
        return unrolled

    def all_constituents(self, model):
        # TODO: rewrite this to be nicer / faster?
        stack = self.constituents
        unrolled = []
        while stack.parent is not None:
            unrolled.append(model.get_top_constituent(stack))
            stack = stack.parent
        unrolled.reverse()
        return unrolled

    def all_words(self, model):
        return [model.get_word(x) for x in self.word_queue]

    def to_string(self, model):
        return "State(\n buffer:%s\n transitions:%s\n constituents:%s\n word_position:%d num_opens:%d)" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model)), self.word_position, self.num_opens)

    def __str__(self):
        return "State(\n buffer:%s\n transitions:%s\n constituents:%s)" % (str(self.word_queue), str(self.transitions), str(self.constituents))
|
| 96 |
+
|
| 97 |
+
class MultiState(namedtuple('MultiState', ['states', 'gold_tree', 'gold_sequence', 'score'])):
    """
    One State per model of an ensemble, advanced in lockstep.

    Queries about the structural parts of the parse (counts, positions,
    stacks) are answered by delegating to the first underlying state,
    since those are identical across the ensemble.
    """
    def finished(self, ensemble):
        head = self.states[0]
        return head.finished(ensemble.models[0])

    def get_tree(self, ensemble):
        head = self.states[0]
        return head.get_tree(ensemble.models[0])

    @property
    def empty_constituents(self):
        return self.states[0].empty_constituents

    def num_constituents(self):
        head = self.states[0]
        return len(head.constituents) - 1

    @property
    def num_transitions(self):
        # -1 for the sentinel value
        head = self.states[0]
        return len(head.transitions) - 1

    @property
    def num_opens(self):
        return self.states[0].num_opens

    @property
    def sentence_length(self):
        return self.states[0].sentence_length

    def empty_word_queue(self):
        return self.states[0].empty_word_queue()

    def empty_transitions(self):
        return self.states[0].empty_transitions()

    @property
    def constituents(self):
        # warning! if there is information in the constituents such as
        # the embedding of the constituent, only the first model's
        # version is returned here; the other models' constituent
        # states won't be returned
        return self.states[0].constituents

    @property
    def transitions(self):
        # warning! if there is information in the transitions such as
        # the embedding of the transition, only the first model's
        # version is returned here; the other models' transition
        # states won't be returned
        return self.states[0].transitions
|
stanza/stanza/models/constituency/top_down_oracle.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
from stanza.models.constituency.dynamic_oracle import advance_past_constituents, score_candidates, DynamicOracle, RepairEnum
|
| 5 |
+
from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent
|
| 6 |
+
|
| 7 |
+
def find_constituent_end(gold_sequence, cur_index):
    """
    Find the Close which ends the next constituent opened at or after cur_index

    Returns the index of the matching CloseConstituent.
    Raises AssertionError if the sequence ends with the constituent unclosed.
    """
    start_index = cur_index
    count = 0
    while cur_index < len(gold_sequence):
        if isinstance(gold_sequence[cur_index], OpenConstituent):
            count = count + 1
        elif isinstance(gold_sequence[cur_index], CloseConstituent):
            count = count - 1
            if count == 0:
                return cur_index
        cur_index += 1
    # bug fix: previously this reported cur_index, which the loop had
    # already advanced to len(gold_sequence), not the starting index
    raise AssertionError("Open constituent not closed starting from index %d in sequence %s" % (start_index, gold_sequence))
|
| 21 |
+
|
| 22 |
+
def fix_shift_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Predicted a close when we should have shifted

    The repair keeps the rest of the tree building the same, including
    performing the missed Shift immediately afterwards, and simply
    drops the matching Close from later in the transition sequence.

    Anything else would make the situation of one precision, one
    recall error worse
    """
    if not isinstance(pred_transition, CloseConstituent) or not isinstance(gold_transition, Shift):
        return None

    close_index = advance_past_constituents(gold_sequence, gold_index)
    repaired = gold_sequence[:gold_index]
    repaired = repaired + [pred_transition]
    repaired = repaired + gold_sequence[gold_index:close_index]
    repaired = repaired + gold_sequence[close_index + 1:]
    return repaired
|
| 41 |
+
|
| 42 |
+
def fix_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Predicted a close when we should have opened a constituent

    The enclosing constituent becomes both a precision and a recall
    error, BUT the constituent we were about to open can be salvaged
    by proceeding as if everything else is still the same.

    The next thing the model should do is open the transition it forgot about
    """
    if not isinstance(pred_transition, CloseConstituent):
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None

    close_index = advance_past_constituents(gold_sequence, gold_index)
    prefix = gold_sequence[:gold_index]
    middle = gold_sequence[gold_index:close_index]
    suffix = gold_sequence[close_index + 1:]
    return prefix + [pred_transition] + middle + suffix
|
| 60 |
+
|
| 61 |
+
def fix_one_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Predicted a shift when we should have opened a constituent

    Pretending that constituent never existed costs a single recall
    error: keep the Shift where it happened, drop the skipped Open,
    and drop the Close which matched it later in the sequence.

    For the corresponding multiple opens, shift error, see fix_multiple_open_shift
    """
    if not isinstance(pred_transition, Shift):
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(gold_sequence[gold_index + 1], Shift):
        return None

    close_index = advance_past_constituents(gold_sequence, gold_index + 1)
    if close_index is None:
        return None
    # gold_index held the skipped Open, gold_index+1 the Shift now being
    # performed early, and close_index the Close matching the skipped Open
    updated_sequence = (gold_sequence[:gold_index]
                        + [pred_transition]
                        + gold_sequence[gold_index + 2:close_index]
                        + gold_sequence[close_index + 1:])
    #print("Input sequence: %s\nIndex %d\nGold %s Pred %s\nUpdated sequence %s" % (gold_sequence, gold_index, gold_transition, pred_transition, updated_sequence))
    return updated_sequence
|
| 92 |
+
|
| 93 |
+
def fix_multiple_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Predicted a shift when we should have opened multiple constituents instead

    This causes a single recall error per constituent if we just
    pretend those constituents don't exist

    For each open constituent, we find the corresponding close,
    then remove both the open & close
    """
    if not isinstance(pred_transition, Shift):
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None

    # scan forward past the run of Opens to the Shift they lead to
    shift_index = gold_index
    while shift_index < len(gold_sequence) and isinstance(gold_sequence[shift_index], OpenConstituent):
        shift_index += 1
    if shift_index >= len(gold_sequence):
        raise AssertionError("Found a sequence of OpenConstituent at the end of a TOP_DOWN sequence!")
    if not isinstance(gold_sequence[shift_index], Shift):
        raise AssertionError("Expected to find a Shift after a sequence of OpenConstituent. There should not be a %s" % gold_sequence[shift_index])

    #print("Input sequence: %s\nIndex %d\nGold %s Pred %s" % (gold_sequence, gold_index, gold_transition, pred_transition))
    updated_sequence = gold_sequence
    # peel off the Opens one at a time, innermost first: each pass removes
    # the Open just before shift_index together with its matching Close
    while shift_index > gold_index:
        close_index = advance_past_constituents(updated_sequence, shift_index)
        if close_index is None:
            raise AssertionError("Did not find a corresponding Close for this Open")
        # cut out the corresponding open and close
        updated_sequence = updated_sequence[:shift_index-1] + updated_sequence[shift_index:close_index] + updated_sequence[close_index+1:]
        shift_index -= 1
        #print(" %s" % updated_sequence)

    #print("Final updated sequence: %s" % updated_sequence)
    return updated_sequence
|
| 130 |
+
|
| 131 |
+
def fix_nested_open_constituent(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to predict Open(X), then Open(Y), but predicted Open(Y) instead

    Treated as a single recall error: the outer Open and its matching
    Close are dropped, and parsing continues from the inner constituent.

    We could even go crazy and turn it into a Unary, such as
    Open(Y), Open(X), Open(Y)... but presumably that would be very
    confusing to the parser, not to mention ambiguous as to where to
    close the new constituent.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None

    assert len(gold_sequence) > gold_index + 1

    nested = gold_sequence[gold_index + 1]
    if not isinstance(nested, OpenConstituent):
        return None

    # This replacement works if we skipped exactly one level
    if nested.label != pred_transition.label:
        return None

    close_index = advance_past_constituents(gold_sequence, gold_index + 1)
    assert close_index is not None
    return (gold_sequence[:gold_index]
            + gold_sequence[gold_index + 1:close_index]
            + gold_sequence[close_index + 1:])
|
| 161 |
+
|
| 162 |
+
def fix_shift_open_immediate_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened

    The awkward part of this error is that the Close matching the
    spurious Open is ambiguous: it could go immediately before the
    next Close, immediately after the Shift, or anywhere in between.

    One unambiguous case is when the proper sequence was Shift - Close;
    then the only possible repair is Open - Shift - Close - Close.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, Shift):
        return None

    assert len(gold_sequence) > gold_index + 1
    if not isinstance(gold_sequence[gold_index + 1], CloseConstituent):
        # this is the ambiguous case
        return None

    repair = [pred_transition, gold_transition, CloseConstituent()]
    return gold_sequence[:gold_index] + repair + gold_sequence[gold_index + 1:]
|
| 185 |
+
|
| 186 |
+
def fix_shift_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a gold Shift mispredicted as an Open by treating the Open as a unary bracket.

    The Close matching the erroneous Open could legally land in several
    places; this resolution closes it immediately after the Shift, so
    the extra bracket wraps just the shifted word.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent) or not isinstance(gold_transition, Shift):
        return None

    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # unambiguous case - handled by the immediate_close repair instead
        return None

    prefix = gold_sequence[:gold_index]
    suffix = gold_sequence[gold_index+1:]
    return prefix + [pred_transition, gold_transition, CloseConstituent()] + suffix
|
| 208 |
+
|
| 209 |
+
def fix_shift_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a gold Shift mispredicted as an Open by closing at the end of the enclosing bracket.

    The Close matching the erroneous Open could legally land in several
    places; this resolution puts it just before the Close of the
    surrounding constituent, so the extra bracket wraps everything the
    enclosing bracket was going to contain.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent) or not isinstance(gold_transition, Shift):
        return None

    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # unambiguous case - handled by the immediate_close repair instead
        return None

    # position of the Close belonging to the enclosing constituent
    enclosing_close = advance_past_constituents(gold_sequence, gold_index)

    repaired = list(gold_sequence[:gold_index])
    repaired.append(pred_transition)
    repaired.extend(gold_sequence[gold_index:enclosing_close])
    repaired.append(CloseConstituent())
    repaired.extend(gold_sequence[enclosing_close:])
    return repaired
|
| 233 |
+
|
| 234 |
+
def fix_shift_open_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened, in the ambiguous case.

    Rather than committing to one fixed placement of the matching
    Close, this repair builds every legal placement (after the Shift,
    or after any complete constituent up to the enclosing Close) and
    asks the model to score them via score_candidates.

    Returns (RepairEnum, best_candidate) rather than a plain sequence,
    so the chosen placement index is recorded in the repair statistics.
    Returns None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, Shift):
        return None

    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # this is the unambiguous case, which should already be handled
        return None

    # at this point: have Opened a constituent which we don't want
    # need to figure out where to Close it
    # could close it after the shift or after any given block
    candidates = []
    current_index = gold_index
    while not isinstance(gold_sequence[current_index], CloseConstituent):
        if isinstance(gold_sequence[current_index], Shift):
            # a single word: the block ends right here
            end_index = current_index
        else:
            # a nested constituent: skip to its matching Close
            end_index = find_constituent_end(gold_sequence, current_index)
        # candidate = (prefix, erroneous Open, covered block, new Close, remainder)
        candidates.append((gold_sequence[:gold_index], [pred_transition], gold_sequence[gold_index:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))
        current_index = end_index + 1

    # candidate_idx=3 points score_candidates at the inserted Close
    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=3)
    if best_idx == len(candidates) - 1:
        # the last candidate closes at the end of the enclosing bracket; record that as -1
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def fix_close_shift_ambiguous_immediate(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a gold Close mispredicted as a Shift by closing immediately after the Shift.

    This variant closes right away no matter what follows the next
    Shift.  An alternate strategy would be to Close at the closing of
    the outer constituent.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, Shift) or not isinstance(gold_transition, CloseConstituent):
        return None

    # find the end of the run of gold Closes starting here
    close_end = gold_index
    while isinstance(gold_sequence[close_end], CloseConstituent):
        close_end += 1

    if not isinstance(gold_sequence[close_end], Shift):
        # TODO: we should be able to handle this case too (an Open)
        # however, it will be rare once the parser gets going and it
        # would cause a lot of errors, anyway
        return None

    if isinstance(gold_sequence[close_end + 1], CloseConstituent):
        # this one should just have been satisfied in the non-ambiguous version
        return None

    pending_closes = gold_sequence[gold_index:close_end]
    return gold_sequence[:gold_index] + [pred_transition] + pending_closes + gold_sequence[close_end+1:]
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def fix_close_shift_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Instead of a Close, we predicted a Shift. This time, we close at the end of the outer bracket no matter what comes after the next Shift.

    An alternate strategy would be to Close as soon as possible after the Shift.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, Shift):
        return None

    if not isinstance(gold_transition, CloseConstituent):
        return None

    # count the run of gold Closes starting at gold_index
    num_closes = 0
    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):
        num_closes += 1

    if not isinstance(gold_sequence[gold_index + num_closes], Shift):
        # TODO: we should be able to handle this case too (an Open)
        # however, it will be rare once the parser gets going and it
        # would cause a lot of errors, anyway
        return None

    if isinstance(gold_sequence[gold_index + num_closes + 1], CloseConstituent):
        # this one should just have been satisfied in the non-ambiguous version
        return None

    # outer_close_index is now where the constituent which the broken constituent(s) reside inside gets closed
    outer_close_index = advance_past_constituents(gold_sequence, gold_index + num_closes)

    # move the skipped Closes to just before the outer constituent's Close:
    # prefix + (Shift and everything up to the outer Close) + deferred Closes + rest
    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+num_closes:outer_close_index] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[outer_close_index:]
    return updated_sequence
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def fix_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, count_opens=False):
    """
    We were supposed to Close, but instead did a Shift

    In most cases, this will be ambiguous.  There is now a constituent
    which has been missed, no matter what we do, and we are on the
    hook for eventually closing this constituent, creating a precision
    error as well.  The ambiguity arises because there will be
    multiple places where the Close could occur if there are more
    constituents created between now and when the outer constituent is
    Closed.

    The non-ambiguous case is if the proper sequence was
      Close - Shift - Close
    similar cases are also non-ambiguous, such as
      Close - Close - Shift - Close
    for that matter, so is the following, although the Opens will be lost
      Close - Open - Shift - Close - Close

    count_opens is an option to make it easy to count with or without
    Open as different oracle fixes

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, Shift):
        return None

    if not isinstance(gold_transition, CloseConstituent):
        return None

    # count the run of gold Closes starting at gold_index
    num_closes = 0
    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):
        num_closes += 1

    # We may allow unary transitions here
    # the opens will be lost in the repaired sequence
    num_opens = 0
    if count_opens:
        while isinstance(gold_sequence[gold_index + num_closes + num_opens], OpenConstituent):
            num_opens += 1

    if not isinstance(gold_sequence[gold_index + num_closes + num_opens], Shift):
        if count_opens:
            # after Closes then Opens, only a Shift is legal here
            raise AssertionError("Should have found a Shift after a sequence of Opens or a Close with no Open. Started counting at %d in sequence %s" % (gold_index, gold_sequence))
        return None

    # the transition right after the Shift must be a Close (unambiguous case)
    if not isinstance(gold_sequence[gold_index + num_closes + num_opens + 1], CloseConstituent):
        return None
    # each counted Open must also be immediately Closed after the Shift
    for idx in range(num_opens):
        if not isinstance(gold_sequence[gold_index + num_closes + num_opens + idx + 1], CloseConstituent):
            return None

    # Now we know it is Close x num_closes, Shift, Close
    # Since we have erroneously predicted a Shift now, the best we can
    # do is to follow that, then add num_closes Closes
    # (the num_opens Opens and their matching Closes - num_opens*2 items -
    # plus the gold Shift are dropped from the remainder)
    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[gold_index+num_closes+num_opens*2+1:]
    return updated_sequence
|
| 386 |
+
|
| 387 |
+
def fix_close_shift_with_opens(*args, **kwargs):
    """Variant of fix_close_shift which also counts (and drops) unary Opens before the Shift."""
    return fix_close_shift(*args, count_opens=True, **kwargs)
|
| 389 |
+
|
| 390 |
+
def fix_close_next_correct_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Close, but instead predicted Shift when the next transition is Shift

    This differs from the previous Close-Shift in that this case does
    not have an unambiguous place to put the Close.  Instead, we let
    the model predict where to put the Close

    Note that this can also work for Close-Open with the next Open correct

    Not covered (yet?) is multiple Close in a row

    Returns (RepairEnum, best_candidate) rather than a plain sequence,
    so the chosen placement index is recorded in the repair statistics.
    Returns None if this repair does not apply.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, (Shift, OpenConstituent)):
        return None
    # only applies when the prediction matches the transition right after the skipped Close
    if gold_sequence[gold_index+1] != pred_transition:
        return None

    # build one candidate per legal placement of the deferred Close
    candidates = []
    current_index = gold_index + 1
    while not isinstance(gold_sequence[current_index], CloseConstituent):
        if isinstance(gold_sequence[current_index], Shift):
            # a single word: the block ends right here
            end_index = current_index
        else:
            # a nested constituent: skip to its matching Close
            end_index = find_constituent_end(gold_sequence, current_index)
        # candidate = (prefix, block with the Close skipped, new Close, remainder)
        candidates.append((gold_sequence[:gold_index], gold_sequence[gold_index+1:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))
        current_index = end_index + 1

    # candidate_idx=3 points score_candidates at the inserted Close
    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=3)
    if best_idx == len(candidates) - 1:
        # the last candidate closes at the end of the enclosing bracket; record that as -1
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def fix_close_open_correct_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):
    """
    Repair a gold Close mispredicted as an Open, when the next gold Open matches the prediction.

    In general this is ambiguous (like close/shift), since we need to
    decide when to close the incorrect constituent.  It is unambiguous
    when exactly one constituent was supposed to come after the Close
    and it matches the Open just predicted: that constituent is simply
    absorbed into the un-Closed one, eg
      "ate (NP spaghetti) (PP with a fork)" ->
      "ate (NP spaghetti (PP with a fork))"
    (delicious)

    With check_close=False the Close-after test is skipped, so any
    number of constituents may follow; this resolves the ambiguous form
    by closing right after the first matching constituent.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent) or not isinstance(gold_transition, CloseConstituent):
        return None

    if gold_sequence[gold_index+1] != pred_transition:
        return None

    close_index = find_constituent_end(gold_sequence, gold_index+1)
    if check_close and not isinstance(gold_sequence[close_index+1], CloseConstituent):
        # more than one constituent follows - ambiguous, so bail out
        return None

    # the Close moves to the end of the constituent we just Opened
    repaired = list(gold_sequence[:gold_index])
    repaired.extend(gold_sequence[gold_index+1:close_index+1])
    repaired.append(gold_transition)
    repaired.extend(gold_sequence[close_index+1:])
    return repaired
|
| 465 |
+
|
| 466 |
+
def fix_close_open_correct_open_ambiguous_immediate(*args, **kwargs):
    """Ambiguous close/open repair: close right after the first matching constituent."""
    return fix_close_open_correct_open(*args, check_close=False, **kwargs)
|
| 468 |
+
|
| 469 |
+
def fix_close_open_correct_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):
    """
    Repair a gold Close mispredicted as an Open, resolving the ambiguity later in the tree.

    The deferred Close is placed just before the Close of the
    surrounding constituent.  check_close is unused in this variant; it
    is kept so the signature mirrors fix_close_open_correct_open.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent) or not isinstance(gold_transition, CloseConstituent):
        return None

    if gold_sequence[gold_index+1] != pred_transition:
        return None

    # index of the Close belonging to the surrounding constituent
    enclosing_close = advance_past_constituents(gold_sequence, gold_index+1)

    repaired = list(gold_sequence[:gold_index])
    repaired.extend(gold_sequence[gold_index+1:enclosing_close])
    repaired.append(gold_transition)
    repaired.extend(gold_sequence[enclosing_close:])
    return repaired
|
| 486 |
+
|
| 487 |
+
def fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open/Open label error by treating the predicted Open as a unary bracket.

    Applies when the error is not covered by the unambiguous single
    recall repair: the erroneous Open is wrapped around exactly the
    constituent the gold Open was going to build.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent) or not isinstance(gold_transition, OpenConstituent):
        return None

    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # This case is covered by the nested open repair
        return None

    end_of_gold = find_constituent_end(gold_sequence, gold_index)
    assert end_of_gold is not None
    assert isinstance(gold_sequence[end_of_gold], CloseConstituent)

    repaired = list(gold_sequence[:gold_index])
    repaired.append(pred_transition)
    repaired.extend(gold_sequence[gold_index:end_of_gold])
    repaired.append(CloseConstituent())
    repaired.extend(gold_sequence[end_of_gold:])
    return repaired
|
| 508 |
+
|
| 509 |
+
def fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open/Open label error by closing at the end of the outer constituent.

    Applies when the error is not covered by the unambiguous single
    recall repair: the erroneous Open is kept until the enclosing
    bracket finishes, then closed.

    Returns the repaired sequence, or None if this repair does not apply.
    """
    if not isinstance(pred_transition, OpenConstituent) or not isinstance(gold_transition, OpenConstituent):
        return None

    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # This case is covered by the nested open repair
        return None

    enclosing_close = advance_past_constituents(gold_sequence, gold_index)

    repaired = list(gold_sequence[:gold_index])
    repaired.append(pred_transition)
    repaired.extend(gold_sequence[gold_index:enclosing_close])
    repaired.append(CloseConstituent())
    repaired.extend(gold_sequence[enclosing_close:])
    return repaired
|
| 531 |
+
|
| 532 |
+
def fix_open_open_ambiguous_random(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open/Open label error by randomly picking one of the two ambiguous resolutions.

    If the error is not covered by the unambiguous single recall
    repair, we flip a 50/50 coin between closing at the end of the
    outer constituent (fix_open_open_ambiguous_later) and treating the
    predicted Open as a unary bracket (fix_open_open_ambiguous_unary).

    Returns the repaired sequence, or None if this repair does not apply.

    Bug fix: the delegated calls previously dropped the trailing
    `model, state` arguments, which would raise a TypeError since both
    delegates take seven positional parameters.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None

    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # This case is covered by the nested open repair
        return None

    # 50/50 choice between the two ambiguous resolutions
    if random.random() < 0.5:
        return fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state)
    else:
        return fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state)
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report-only entry: flag a gold Shift mispredicted as an Open, without repairing."""
    if isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_SHIFT_OPEN, None
    return None
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report-only entry: flag a gold Close mispredicted as a Shift, without repairing."""
    if isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, Shift):
        return RepairType.OTHER_CLOSE_SHIFT, None
    return None
|
| 573 |
+
|
| 574 |
+
def report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report-only entry: flag a gold Close mispredicted as an Open, without repairing."""
    if isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_CLOSE_OPEN, None
    return None
|
| 581 |
+
|
| 582 |
+
def report_open_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report-only entry: flag a gold Open mispredicted as a different Open, without repairing."""
    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_OPEN_OPEN, None
    return None
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
class RepairType(Enum):
    """
    Keep track of which repair is used, if any, on an incorrect transition

    A test of the top-down oracle with no charlm or transformer
    (eg, word vectors only) on EN PTB3 goes as follows.
    3x training rounds, best training parameters as of Jan. 2024
    unambiguous transitions only:
      oracle scheme           dev     test
      no oracle               0.9230  0.9194
      +shift/close            0.9224  0.9180
      +open/close             0.9225  0.9193
      +open/shift (one)       0.9245  0.9207
      +open/shift (mult)      0.9243  0.9211
      +open/open nested       0.9258  0.9213
      +shift/open             0.9266  0.9229
      +close/shift (only)     0.9270  0.9230
      +close/shift w/ opens   0.9262  0.9221
      +close/open one con     0.9273  0.9230

    Potential solutions for various ambiguous transitions:

    close/open
      can close immediately after the corresponding constituent or after any number of constituents

    close/shift
      can close immediately
      can close anywhere up to the next close
      any number of missed Opens are treated as recall errors

    open/open
      could treat as unary
      could close at any number of positions after the next structures, up to the outer open's closing

    shift/open ambiguity resolutions:
      treat as unary
      treat as wrapper around the next full constituent to build
      treat as wrapper around everything to build until the next constituent

    testing one at a time in addition to the full set of unambiguous corrections:
      +close/open immediate   0.9259  0.9225
      +close/open later       0.9258  0.9257
      +close/shift immediate  0.9261  0.9219
      +close/shift later      0.9270  0.9230
      +open/open later        0.9269  0.9239
      +open/open unary        0.9275  0.9246
      +shift/open later       0.9263  0.9253
      +shift/open unary       0.9264  0.9243

    so there is some evidence that open/open or shift/open would be beneficial

    Training by randomly choosing between the open/open, 50/50
      +open/open random       0.9257  0.9235
    so that didn't work great compared to the individual transitions

    Testing deterministic resolutions of the ambiguous transitions
    vs predicting the appropriate transition to use:
      SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR,CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR,CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR
      SHIFT_OPEN_AMBIGUOUS_PREDICTED,CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED

      EN ambiguous (no charlm or transformer)  0.9268  0.9231
      EN predicted                             0.9270  0.9257
      EN none of the above                     0.9268  0.9229

      ZH ambiguous                             0.9137  0.9127
      ZH predicted                             0.9148  0.9141
      ZH none of the above                     0.9141  0.9143

      DE ambiguous                             0.9579  0.9408
      DE predicted                             0.9575  0.9406
      DE none of the above                     0.9581  0.9411

      ID ambiguous                             0.8889  0.8794
      ID predicted                             0.8911  0.8801
      ID none of the above                     0.8913  0.8822

      IT ambiguous                             0.8404  0.8380
      IT predicted                             0.8397  0.8398
      IT none of the above                     0.8400  0.8409

      VI ambiguous                             0.8290  0.7676
      VI predicted                             0.8287  0.7682
      VI none of the above                     0.8292  0.7691
    """
    def __new__(cls, fn, correct=False, debug=False):
        """
        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error
        """
        # auto-number members 1..N in declaration order
        value = len(cls.__members__)
        obj = object.__new__(cls)
        obj._value_ = value + 1
        obj.fn = fn              # repair function, or None for CORRECT / UNKNOWN
        obj.correct = correct    # True only for the CORRECT member
        obj.debug = debug        # True for report-only members which do not repair
        return obj

    @property
    def is_correct(self):
        return self.correct

    # The parser chose to close a bracket instead of shift something
    # into the bracket
    # This causes both a precision and a recall error as there is now
    # an incorrect bracket and a missing correct bracket
    # Any bracket creation here would cause more wrong brackets, though
    SHIFT_CLOSE_ERROR = (fix_shift_close,)

    OPEN_CLOSE_ERROR = (fix_open_close,)

    # open followed by shift was instead predicted to be shift
    ONE_OPEN_SHIFT_ERROR = (fix_one_open_shift,)

    # open followed by shift was instead predicted to be shift
    MULTIPLE_OPEN_SHIFT_ERROR = (fix_multiple_open_shift,)

    # should have done Open(X), Open(Y)
    # instead just did Open(Y)
    NESTED_OPEN_OPEN_ERROR = (fix_nested_open_constituent,)

    SHIFT_OPEN_ERROR = (fix_shift_open_immediate_close,)

    CLOSE_SHIFT_ERROR = (fix_close_shift,)

    CLOSE_SHIFT_WITH_OPENS_ERROR = (fix_close_shift_with_opens,)

    CLOSE_OPEN_ONE_CON_ERROR = (fix_close_open_correct_open,)

    CORRECT = (None, True)

    UNKNOWN = None

    CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR = (fix_close_open_correct_open_ambiguous_immediate,)

    CLOSE_OPEN_AMBIGUOUS_LATER_ERROR = (fix_close_open_correct_open_ambiguous_later,)

    CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR = (fix_close_shift_ambiguous_immediate,)

    CLOSE_SHIFT_AMBIGUOUS_LATER_ERROR = (fix_close_shift_ambiguous_later,)

    # can potentially fix either close/shift or close/open
    # as long as the gold transition after the close
    # was the same as the transition we just predicted
    CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED = (fix_close_next_correct_predicted,)

    OPEN_OPEN_AMBIGUOUS_UNARY_ERROR = (fix_open_open_ambiguous_unary,)

    OPEN_OPEN_AMBIGUOUS_LATER_ERROR = (fix_open_open_ambiguous_later,)

    OPEN_OPEN_AMBIGUOUS_RANDOM_ERROR = (fix_open_open_ambiguous_random,)

    SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR = (fix_shift_open_ambiguous_unary,)

    SHIFT_OPEN_AMBIGUOUS_LATER_ERROR = (fix_shift_open_ambiguous_later,)

    SHIFT_OPEN_AMBIGUOUS_PREDICTED = (fix_shift_open_ambiguous_predicted,)

    # debug=True: these only report the error category, they do not repair
    OTHER_SHIFT_OPEN = (report_shift_open, False, True)

    OTHER_CLOSE_SHIFT = (report_close_shift, False, True)

    OTHER_CLOSE_OPEN = (report_close_open, False, True)

    OTHER_OPEN_OPEN = (report_open_open, False, True)
|
| 754 |
+
|
| 755 |
+
class TopDownOracle(DynamicOracle):
    """
    Dynamic oracle for the top-down transition scheme, wiring the RepairType repairs above into DynamicOracle.
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        # all of the oracle machinery lives in DynamicOracle; this class only supplies RepairType
        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)
|
stanza/stanza/models/constituency/trainer.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file includes a variety of methods needed to train new
|
| 3 |
+
constituency parsers. It also includes a method to load an
|
| 4 |
+
already-trained parser.
|
| 5 |
+
|
| 6 |
+
See the `train` method for the code block which starts from
|
| 7 |
+
raw treebank and returns a new parser.
|
| 8 |
+
`evaluate` reads a treebank and gives a score for those trees.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import copy
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain, NoTransformerFoundationCache
|
| 18 |
+
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper, pop_peft_args
|
| 19 |
+
from stanza.models.constituency.base_trainer import BaseTrainer, ModelType
|
| 20 |
+
from stanza.models.constituency.lstm_model import LSTMModel, SentenceBoundary, StackHistory, ConstituencyComposition
|
| 21 |
+
from stanza.models.constituency.parse_transitions import Transition, TransitionScheme
|
| 22 |
+
from stanza.models.constituency.utils import build_optimizer, build_scheduler
|
| 23 |
+
# TODO: could put find_wordvec_pretrain, choose_charlm, etc in a more central place if it becomes widely used
|
| 24 |
+
from stanza.utils.training.common import find_wordvec_pretrain, choose_charlm, find_charlm_file
|
| 25 |
+
from stanza.resources.default_packages import default_charlms, default_pretrains
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger('stanza')
|
| 28 |
+
tlogger = logging.getLogger('stanza.constituency.trainer')
|
| 29 |
+
|
| 30 |
+
class Trainer(BaseTrainer):
|
| 31 |
+
"""
|
| 32 |
+
Stores a constituency model and its optimizer
|
| 33 |
+
|
| 34 |
+
Not inheriting from common/trainer.py because there's no concept of change_lr (yet?)
|
| 35 |
+
"""
|
| 36 |
+
    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
        """
        Bundle a constituency model with its training state.

        All arguments are forwarded unchanged to BaseTrainer:
        model: the parser model being trained or evaluated
        optimizer, scheduler: current optimizer and LR scheduler, if training
        epochs_trained, batches_trained: progress counters
        best_f1, best_epoch: best score recorded so far and the epoch it occurred
        first_optimizer: True while still on the first (warmup) optimizer stage
        """
        super().__init__(model, optimizer, scheduler, epochs_trained, batches_trained, best_f1, best_epoch, first_optimizer)
|
| 38 |
+
|
| 39 |
+
    def save(self, filename, save_optimizer=True):
        """
        Save the model (and by default the optimizer) to the given path

        Delegates entirely to BaseTrainer.save
        """
        super().save(filename, save_optimizer)
|
| 44 |
+
|
| 45 |
+
def get_peft_params(self):
|
| 46 |
+
# Hide import so that peft dependency is optional
|
| 47 |
+
if self.model.args.get('use_peft', False):
|
| 48 |
+
from peft import get_peft_model_state_dict
|
| 49 |
+
return get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
    @property
    def model_type(self):
        # this Trainer always wraps the LSTM constituency parser
        return ModelType.LSTM
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def find_and_load_pretrain(saved_args, foundation_cache):
|
| 58 |
+
if 'wordvec_pretrain_file' not in saved_args:
|
| 59 |
+
return None
|
| 60 |
+
if os.path.exists(saved_args['wordvec_pretrain_file']):
|
| 61 |
+
return load_pretrain(saved_args['wordvec_pretrain_file'], foundation_cache)
|
| 62 |
+
logger.info("Unable to find pretrain in %s Will try to load from the default resources instead", saved_args['wordvec_pretrain_file'])
|
| 63 |
+
language = saved_args['lang']
|
| 64 |
+
wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains)
|
| 65 |
+
return load_pretrain(wordvec_pretrain, foundation_cache)
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def find_and_load_charlm(charlm_file, direction, saved_args, foundation_cache):
|
| 69 |
+
try:
|
| 70 |
+
return load_charlm(charlm_file, foundation_cache)
|
| 71 |
+
except FileNotFoundError as e:
|
| 72 |
+
logger.info("Unable to load charlm from %s Will try to load from the default resources instead", charlm_file)
|
| 73 |
+
language = saved_args['lang']
|
| 74 |
+
dataset = saved_args['shorthand'].split("_")[1]
|
| 75 |
+
charlm = choose_charlm(language, dataset, "default", default_charlms, {})
|
| 76 |
+
charlm_file = find_charlm_file(direction, language, charlm)
|
| 77 |
+
return load_charlm(charlm_file, foundation_cache)
|
| 78 |
+
|
| 79 |
+
    def log_num_words_known(self, words):
        """Log how many of the given training words the model's embedding covers."""
        tlogger.info("Number of words in the training set found in the embedding: %d out of %d", self.model.num_words_known(words), len(words))
|
| 81 |
+
|
| 82 |
+
@staticmethod
|
| 83 |
+
def load_optimizer(model, checkpoint, first_optimizer, filename):
|
| 84 |
+
optimizer = build_optimizer(model.args, model, first_optimizer)
|
| 85 |
+
if checkpoint.get('optimizer_state_dict', None) is not None:
|
| 86 |
+
try:
|
| 87 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 88 |
+
except ValueError as e:
|
| 89 |
+
raise ValueError("Failed to load optimizer from %s" % filename) from e
|
| 90 |
+
else:
|
| 91 |
+
logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
|
| 92 |
+
return optimizer
|
| 93 |
+
|
| 94 |
+
    @staticmethod
    def load_scheduler(model, optimizer, checkpoint, first_optimizer):
        """
        Build a fresh LR scheduler for the optimizer.

        If the checkpoint contains scheduler state, it is restored;
        otherwise the newly built scheduler is returned as-is.
        """
        scheduler = build_scheduler(model.args, optimizer, first_optimizer=first_optimizer)
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return scheduler
|
| 100 |
+
|
| 101 |
+
    @staticmethod
    def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None):
        """
        Build a new model just from the saved params and some extra args

        Refactoring allows other processors to include a constituency parser as a module

        params: the dict saved with the model - config, weights, vocab, transitions
        peft_params: saved PEFT adapter weights, used when the config has use_peft
        args: runtime overrides for the saved config; may be None
        foundation_cache: optional cache so pretrain / charlm / transformer can be shared
        peft_name: reuse an already-registered adapter name instead of creating one

        Raises ValueError if params['model_type'] is not a known type.
        """
        saved_args = dict(params['config'])
        # older save files stored these enum-valued options as strings;
        # convert them back to the enums the model expects
        if isinstance(saved_args['sentence_boundary_vectors'], str):
            saved_args['sentence_boundary_vectors'] = SentenceBoundary[saved_args['sentence_boundary_vectors']]
        if isinstance(saved_args['constituency_composition'], str):
            saved_args['constituency_composition'] = ConstituencyComposition[saved_args['constituency_composition']]
        if isinstance(saved_args['transition_stack'], str):
            saved_args['transition_stack'] = StackHistory[saved_args['transition_stack']]
        if isinstance(saved_args['constituent_stack'], str):
            saved_args['constituent_stack'] = StackHistory[saved_args['constituent_stack']]
        if isinstance(saved_args['transition_scheme'], str):
            saved_args['transition_scheme'] = TransitionScheme[saved_args['transition_scheme']]

        # some parameters which change the structure of a model have
        # to be ignored, or the model will not function when it is
        # reloaded from disk
        if args is None: args = {}
        update_args = copy.deepcopy(args)
        pop_peft_args(update_args)
        update_args.pop("bert_hidden_layers", None)
        update_args.pop("bert_model", None)
        update_args.pop("constituency_composition", None)
        update_args.pop("constituent_stack", None)
        update_args.pop("num_tree_lstm_layers", None)
        update_args.pop("transition_scheme", None)
        update_args.pop("transition_stack", None)
        update_args.pop("maxout_k", None)
        # if the pretrain or charlms are not specified, don't override the values in the model
        # (if any), since the model won't even work without loading the same charlm
        if 'wordvec_pretrain_file' in update_args and update_args['wordvec_pretrain_file'] is None:
            update_args.pop('wordvec_pretrain_file')
        if 'charlm_forward_file' in update_args and update_args['charlm_forward_file'] is None:
            update_args.pop('charlm_forward_file')
        if 'charlm_backward_file' in update_args and update_args['charlm_backward_file'] is None:
            update_args.pop('charlm_backward_file')
        # we don't pop bert_finetune, with the theory being that if
        # the saved model has bert_finetune==True we can load the bert
        # weights but then not further finetune if bert_finetune==False
        saved_args.update(update_args)

        # TODO: not needed if we rebuild the models
        # backfill defaults for options older save files did not record
        if saved_args.get("bert_finetune", None) is None:
            saved_args["bert_finetune"] = False
        if saved_args.get("stage1_bert_finetune", None) is None:
            saved_args["stage1_bert_finetune"] = False

        model_type = params['model_type']
        if model_type == 'LSTM':
            pt = Trainer.find_and_load_pretrain(saved_args, foundation_cache)
            if saved_args.get('use_peft', False):
                # if loading a peft model, we first load the base transformer
                # then we load the weights using the saved weights in the file
                if peft_name is None:
                    bert_model, bert_tokenizer, peft_name = load_bert_with_peft(saved_args.get('bert_model', None), "constituency", foundation_cache)
                else:
                    bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)
                bert_model = load_peft_wrapper(bert_model, peft_params, saved_args, logger, peft_name)
                bert_saved = True
            elif saved_args['bert_finetune'] or saved_args['stage1_bert_finetune'] or any(x.startswith("bert_model.") for x in params['model'].keys()):
                # if bert_finetune is True, don't use the cached model!
                # otherwise, other uses of the cached model will be ruined
                bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None))
                bert_saved = True
            else:
                bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)
                bert_saved = False
            forward_charlm = Trainer.find_and_load_charlm(saved_args["charlm_forward_file"], "forward", saved_args, foundation_cache)
            backward_charlm = Trainer.find_and_load_charlm(saved_args["charlm_backward_file"], "backward", saved_args, foundation_cache)

            # TODO: the isinstance will be unnecessary after 1.10.0
            # older save files stored transitions as their repr strings
            transitions = params['transitions']
            if all(isinstance(x, str) for x in transitions):
                transitions = [Transition.from_repr(x) for x in transitions]

            model = LSTMModel(pretrain=pt,
                              forward_charlm=forward_charlm,
                              backward_charlm=backward_charlm,
                              bert_model=bert_model,
                              bert_tokenizer=bert_tokenizer,
                              force_bert_saved=bert_saved,
                              peft_name=peft_name,
                              transitions=transitions,
                              constituents=params['constituents'],
                              tags=params['tags'],
                              words=params['words'],
                              rare_words=set(params['rare_words']),
                              root_labels=params['root_labels'],
                              constituent_opens=params['constituent_opens'],
                              unary_limit=params['unary_limit'],
                              args=saved_args)
        else:
            raise ValueError("Unknown model type {}".format(model_type))
        # strict=False: the transformer weights may be absent from the save file
        model.load_state_dict(params['model'], strict=False)
        # model will stay on CPU if device==None
        # can be moved elsewhere later, of course
        model = model.to(args.get('device', None))
        return model
|
| 204 |
+
|
| 205 |
+
    @staticmethod
    def build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file):
        """
        Create a Trainer for a training run, honoring the training mode in args.

        Modes, in priority order: resume from checkpoint, finetune an
        existing model, relearn_structure from an existing model, or
        build a brand new model (optionally multistage).
        """
        # TODO: turn finetune, relearn_structure, multistage into an enum?
        # finetune just means continue learning, so checkpoint is sufficient
        # relearn_structure is essentially a one stage multistage
        # multistage with a checkpoint will have the proper optimizer for that epoch
        # and no special learning mode means we are training a new model and should continue
        if args['checkpoint'] and args['checkpoint_save_name'] and os.path.exists(args['checkpoint_save_name']):
            tlogger.info("Found checkpoint to continue training: %s", args['checkpoint_save_name'])
            trainer = Trainer.load(args['checkpoint_save_name'], args, load_optimizer=True, foundation_cache=foundation_cache)
            return trainer

        # in the 'finetune' case, this will preload the models into foundation_cache,
        # so the effort is not wasted
        pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file'])
        forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])
        backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])

        if args['finetune']:
            tlogger.info("Loading model to finetune: %s", model_load_file)
            trainer = Trainer.load(model_load_file, args, load_optimizer=True, foundation_cache=NoTransformerFoundationCache(foundation_cache))
            # a new finetuning will start with a new epochs_trained count
            trainer.epochs_trained = 0
            return trainer

        if args['relearn_structure']:
            tlogger.info("Loading model to continue training with new structure from %s", model_load_file)
            temp_args = dict(args)
            # remove the pattn & lattn layers unless the saved model had them
            temp_args.pop('pattn_num_layers', None)
            temp_args.pop('lattn_d_proj', None)
            trainer = Trainer.load(model_load_file, temp_args, load_optimizer=False, foundation_cache=NoTransformerFoundationCache(foundation_cache))

            # using the model's current values works for if the new
            # dataset is the same or smaller
            # TODO: handle a larger dataset as well
            model = LSTMModel(pt,
                              forward_charlm,
                              backward_charlm,
                              trainer.model.bert_model,
                              trainer.model.bert_tokenizer,
                              trainer.model.force_bert_saved,
                              trainer.model.peft_name,
                              trainer.model.transitions,
                              trainer.model.constituents,
                              trainer.model.tags,
                              trainer.model.delta_words,
                              trainer.model.rare_words,
                              trainer.model.root_labels,
                              trainer.model.constituent_opens,
                              trainer.model.unary_limit(),
                              args)
            model = model.to(args['device'])
            # copy over the old model's weights where the structures line up
            model.copy_with_new_structure(trainer.model)
            optimizer = build_optimizer(args, model, False)
            scheduler = build_scheduler(args, optimizer)
            trainer = Trainer(model, optimizer, scheduler)
            return trainer

        if args['multistage']:
            # run adadelta over the model for half the time with no pattn or lattn
            # training then switches to a different optimizer for the rest
            # this works surprisingly well
            tlogger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['epochs'] // 2)
            temp_args = dict(args)
            # remove the attention layers for the temporary model
            temp_args['pattn_num_layers'] = 0
            temp_args['lattn_d_proj'] = 0
            args = temp_args

        peft_name = None
        if args['use_peft']:
            peft_name = "constituency"
            bert_model, bert_tokenizer = load_bert(args['bert_model'])
            # NOTE(review): temp_args is only bound inside the multistage branch
            # above - use_peft without multistage would hit a NameError here;
            # presumably this should be args.  Confirm before changing.
            bert_model = build_peft_wrapper(bert_model, temp_args, tlogger, adapter_name=peft_name)
        elif args['bert_finetune'] or args['stage1_bert_finetune']:
            # finetuning: never share the cached transformer, it would be modified
            bert_model, bert_tokenizer = load_bert(args['bert_model'])
        else:
            bert_model, bert_tokenizer = load_bert(args['bert_model'], foundation_cache)
        model = LSTMModel(pt,
                          forward_charlm,
                          backward_charlm,
                          bert_model,
                          bert_tokenizer,
                          False,
                          peft_name,
                          train_transitions,
                          train_constituents,
                          tags,
                          words,
                          rare_words,
                          root_labels,
                          open_nodes,
                          unary_limit,
                          args)
        model = model.to(args['device'])

        optimizer = build_optimizer(args, model, build_simple_adadelta=args['multistage'])
        scheduler = build_scheduler(args, optimizer, first_optimizer=args['multistage'])

        trainer = Trainer(model, optimizer, scheduler, first_optimizer=args['multistage'])
        return trainer
|
stanza/stanza/models/constituency/transformer_tree_stack.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Based on
|
| 3 |
+
|
| 4 |
+
Transition-based Parsing with Stack-Transformers
|
| 5 |
+
Ramon Fernandez Astudillo, Miguel Ballesteros, Tahira Naseem,
|
| 6 |
+
Austin Blodget, and Radu Florian
|
| 7 |
+
https://aclanthology.org/2020.findings-emnlp.89.pdf
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from collections import namedtuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
from stanza.models.constituency.positional_encoding import SinusoidalEncoding
|
| 16 |
+
from stanza.models.constituency.tree_stack import TreeStack
|
| 17 |
+
|
| 18 |
+
Node = namedtuple("Node", ['value', 'key_stack', 'value_stack', 'output'])
|
| 19 |
+
|
| 20 |
+
class TransformerTreeStack(nn.Module):
|
| 21 |
+
def __init__(self, input_size, output_size, input_dropout, length_limit=None, use_position=False, num_heads=1):
|
| 22 |
+
"""
|
| 23 |
+
Builds the internal matrices and start parameter
|
| 24 |
+
|
| 25 |
+
TODO: currently only one attention head, implement MHA
|
| 26 |
+
"""
|
| 27 |
+
super().__init__()
|
| 28 |
+
|
| 29 |
+
self.input_size = input_size
|
| 30 |
+
self.output_size = output_size
|
| 31 |
+
self.inv_sqrt_output_size = 1 / output_size ** 0.5
|
| 32 |
+
self.num_heads = num_heads
|
| 33 |
+
|
| 34 |
+
self.w_query = nn.Linear(input_size, output_size)
|
| 35 |
+
self.w_key = nn.Linear(input_size, output_size)
|
| 36 |
+
self.w_value = nn.Linear(input_size, output_size)
|
| 37 |
+
|
| 38 |
+
self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
|
| 39 |
+
if isinstance(input_dropout, nn.Module):
|
| 40 |
+
self.input_dropout = input_dropout
|
| 41 |
+
else:
|
| 42 |
+
self.input_dropout = nn.Dropout(input_dropout)
|
| 43 |
+
|
| 44 |
+
if length_limit is not None and length_limit < 1:
|
| 45 |
+
raise ValueError("length_limit < 1 makes no sense")
|
| 46 |
+
self.length_limit = length_limit
|
| 47 |
+
|
| 48 |
+
self.use_position = use_position
|
| 49 |
+
if use_position:
|
| 50 |
+
self.position_encoding = SinusoidalEncoding(model_dim=self.input_size, max_len=512)
|
| 51 |
+
|
| 52 |
+
def attention(self, key, query, value, mask=None):
|
| 53 |
+
"""
|
| 54 |
+
Calculate attention for the given key, query value
|
| 55 |
+
|
| 56 |
+
Where B is the number of items stacked together, N is the length:
|
| 57 |
+
The key should be BxNxD
|
| 58 |
+
The query is BxD
|
| 59 |
+
The value is BxNxD
|
| 60 |
+
|
| 61 |
+
If mask is specified, it should be BxN of True/False values,
|
| 62 |
+
where True means that location is masked out
|
| 63 |
+
|
| 64 |
+
Reshapes and reorders are used to handle num_heads
|
| 65 |
+
|
| 66 |
+
Return will be softmax(query x key^T) * value
|
| 67 |
+
of size BxD
|
| 68 |
+
"""
|
| 69 |
+
B = key.shape[0]
|
| 70 |
+
N = key.shape[1]
|
| 71 |
+
D = key.shape[2]
|
| 72 |
+
|
| 73 |
+
H = self.num_heads
|
| 74 |
+
|
| 75 |
+
# query is now BxDx1
|
| 76 |
+
query = query.unsqueeze(2)
|
| 77 |
+
# BxHxD/Hx1
|
| 78 |
+
query = query.reshape((B, H, -1, 1))
|
| 79 |
+
|
| 80 |
+
# BxNxHxD/H
|
| 81 |
+
key = key.reshape((B, N, H, -1))
|
| 82 |
+
# BxHxNxD/H
|
| 83 |
+
key = key.transpose(1, 2)
|
| 84 |
+
|
| 85 |
+
# BxNxHxD/H
|
| 86 |
+
value = value.reshape((B, N, H, -1))
|
| 87 |
+
# BxHxNxD/H
|
| 88 |
+
value = value.transpose(1, 2)
|
| 89 |
+
|
| 90 |
+
# BxHxNxD/H x BxHxD/Hx1
|
| 91 |
+
# result shape: BxHxN
|
| 92 |
+
attn = torch.matmul(key, query).squeeze(3) * self.inv_sqrt_output_size
|
| 93 |
+
if mask is not None:
|
| 94 |
+
# mask goes from BxN -> Bx1xN
|
| 95 |
+
mask = mask.unsqueeze(1)
|
| 96 |
+
mask = mask.expand(-1, H, -1)
|
| 97 |
+
attn.masked_fill_(mask, float('-inf'))
|
| 98 |
+
# attn shape will now be BxHx1xN
|
| 99 |
+
attn = torch.softmax(attn, dim=2).unsqueeze(2)
|
| 100 |
+
# BxHx1xN x BxHxNxD/H -> BxHxD/H
|
| 101 |
+
output = torch.matmul(attn, value).squeeze(2)
|
| 102 |
+
output = output.reshape(B, -1)
|
| 103 |
+
return output
|
| 104 |
+
|
| 105 |
+
def initial_state(self, initial_value=None):
|
| 106 |
+
"""
|
| 107 |
+
Return an initial state based on a single layer of attention
|
| 108 |
+
|
| 109 |
+
Running attention might be overkill, but it is the simplest
|
| 110 |
+
way to put the Linears and start_embedding in the computation graph
|
| 111 |
+
"""
|
| 112 |
+
start = self.start_embedding
|
| 113 |
+
if self.use_position:
|
| 114 |
+
position = self.position_encoding([0]).squeeze(0)
|
| 115 |
+
start = start + position
|
| 116 |
+
|
| 117 |
+
# N=1
|
| 118 |
+
# shape: 1xD
|
| 119 |
+
key = self.w_key(start).unsqueeze(0)
|
| 120 |
+
|
| 121 |
+
# shape: D
|
| 122 |
+
query = self.w_query(start)
|
| 123 |
+
|
| 124 |
+
# shape: 1xD
|
| 125 |
+
value = self.w_value(start).unsqueeze(0)
|
| 126 |
+
|
| 127 |
+
# unsqueeze to make it look like we are part of a batch of size 1
|
| 128 |
+
output = self.attention(key.unsqueeze(0), query.unsqueeze(0), value.unsqueeze(0)).squeeze(0)
|
| 129 |
+
return TreeStack(value=Node(initial_value, key, value, output), parent=None, length=1)
|
| 130 |
+
|
| 131 |
+
    def push_states(self, stacks, values, inputs):
        """
        Push new inputs to the stacks and rerun attention on them

        Where B is the number of items stacked together, I is input_size
        stacks: B TreeStacks such as produced by initial_state and/or push_states
        values: the new items to push on the stacks such as tree nodes or anything
        inputs: BxI for the new input items

        Runs attention starting from the existing keys & values.
        Returns B new TreeStacks, each with one Node pushed.
        """
        device = self.w_key.weight.device

        batch_len = len(stacks) # B
        # rows currently stored for each stack (may be less than the stack's
        # push count once length_limit truncation has kicked in)
        positions = [x.value.key_stack.shape[0] for x in stacks]
        max_len = max(positions) # N

        if self.use_position:
            position_encodings = self.position_encoding(positions)
            inputs = inputs + position_encodings

        inputs = self.input_dropout(inputs)
        # accept a 1xBxI tensor (e.g. an LSTM-style output) and flatten to BxI
        if len(inputs.shape) == 3:
            if inputs.shape[0] == 1:
                inputs = inputs.squeeze(0)
            else:
                raise ValueError("Expected the inputs to be of shape 1xBxI, got {}".format(inputs.shape))

        # assemble a right-aligned, zero-padded Bx(N+1)xD key matrix:
        # old keys occupy the trailing rows, the new key goes last
        new_keys = self.w_key(inputs)
        key_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)
        key_stack[:, -1, :] = new_keys
        for stack_idx, stack in enumerate(stacks):
            key_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.key_stack

        # same right-aligned layout for the values
        new_values = self.w_value(inputs)
        value_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)
        value_stack[:, -1, :] = new_values
        for stack_idx, stack in enumerate(stacks):
            value_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.value_stack

        query = self.w_query(inputs)

        # mask out the leading padding rows of shorter stacks
        # NOTE(review): the guard compares len(stack) to max_len while the
        # padding amount comes from positions[stack_idx]; if those two can
        # disagree (e.g. after length_limit truncation) the mask may be
        # skipped when it is needed - confirm against TreeStack semantics
        mask = torch.zeros(batch_len, max_len+1, device=device, dtype=torch.bool)
        for stack_idx, stack in enumerate(stacks):
            if len(stack) < max_len:
                masked = max_len - positions[stack_idx]
                mask[stack_idx, :masked] = True

        batched_output = self.attention(key_stack, query, value_stack, mask)

        new_stacks = []
        for stack_idx, (stack, node_value, new_key, new_value, output) in enumerate(zip(stacks, values, key_stack, value_stack, batched_output)):
            # slice off max_len-positions[stack_idx] rows so that we ignore
            # the padding at the start of shorter stacks
            new_key_stack = new_key[max_len-positions[stack_idx]:, :]
            new_value_stack = new_value[max_len-positions[stack_idx]:, :]
            if self.length_limit is not None and new_key_stack.shape[0] > self.length_limit + 1:
                # drop the oldest real entry (row 1) but keep the start
                # embedding's entry (row 0)
                new_key_stack = torch.cat([new_key_stack[:1, :], new_key_stack[2:, :]], axis=0)
                new_value_stack = torch.cat([new_value_stack[:1, :], new_value_stack[2:, :]], axis=0)
            new_stacks.append(stack.push(value=Node(node_value, new_key_stack, new_value_stack, output)))
        return new_stacks
|
| 191 |
+
|
| 192 |
+
    def output(self, stack):
        """
        Return the attention output stored at the top of the stack

        Refactored so that alternate structures have an easy way of getting the output
        """
        return stack.value.output
|
stanza/stanza/models/constituency/transition_sequence.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Build a transition sequence from parse trees.
|
| 3 |
+
|
| 4 |
+
Supports multiple transition schemes - TOP_DOWN and variants, IN_ORDER
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
from stanza.models.common import utils
|
| 10 |
+
from stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme, Finalize
|
| 11 |
+
from stanza.models.constituency.tree_reader import read_trees
|
| 12 |
+
from stanza.utils.get_tqdm import get_tqdm
|
| 13 |
+
|
| 14 |
+
tqdm = get_tqdm()
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger('stanza.constituency.trainer')
|
| 17 |
+
|
| 18 |
+
def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
    """
    For tree (X A B C D), yield Open(X) A B C D Close

    The details are in how to treat unary transitions.
    Three possibilities handled by this method:
      TOP_DOWN_UNARY:    (Y (X ...)) -> Open(X) ... Close Unary(Y)
      TOP_DOWN_COMPOUND: (Y (X ...)) -> Open(Y, X) ... Close
      TOP_DOWN:          (Y (X ...)) -> Open(Y) Open(X) ... Close Close
    """
    if tree.is_preterminal():
        yield Shift()
        return
    if tree.is_leaf():
        return

    if transition_scheme is TransitionScheme.TOP_DOWN_UNARY and len(tree.children) == 1:
        # collapse the unary chain into one CompoundUnary emitted after the subtree
        labels = []
        while not tree.is_preterminal() and len(tree.children) == 1:
            labels.append(tree.label)
            tree = tree.children[0]
        yield from yield_top_down_sequence(tree, transition_scheme)
        yield CompoundUnary(*labels)
        return

    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
        # fold a chain of unary constituents into a single compound Open
        labels = [tree.label]
        while len(tree.children) == 1 and not tree.children[0].is_preterminal():
            tree = tree.children[0]
            labels.append(tree.label)
        yield OpenConstituent(*labels)
    else:
        yield OpenConstituent(tree.label)
    for child in tree.children:
        yield from yield_top_down_sequence(child, transition_scheme)
    yield CloseConstituent()
|
| 58 |
+
|
| 59 |
+
def yield_in_order_sequence(tree):
    """
    For tree (X A B C D), yield A Open(X) B C D Close
    """
    if tree.is_preterminal():
        yield Shift()
        return
    if tree.is_leaf():
        return

    # leftmost child is produced before its constituent is opened
    yield from yield_in_order_sequence(tree.children[0])
    yield OpenConstituent(tree.label)
    for child in tree.children[1:]:
        yield from yield_in_order_sequence(child)
    yield CloseConstituent()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def yield_in_order_compound_sequence(tree, transition_scheme):
    """
    In-order transitions with compound handling of unary chains.

    IN_ORDER_COMPOUND folds the unary labels into the OpenConstituent;
    IN_ORDER_UNARY keeps a plain Open and emits a CompoundUnary instead.
    The root label is emitted as a Finalize at the very end.

    Raises ValueError for an empty tree or a tree with more than one
    node directly under the root.
    """
    def walk(tree):
        if tree.is_leaf():
            return

        # gather the labels of any unary chain above this node
        labels = []
        while len(tree.children) == 1 and not tree.is_preterminal():
            labels.append(tree.label)
            tree = tree.children[0]

        if tree.is_preterminal():
            yield Shift()
            if labels:
                yield CompoundUnary(*labels)
            return

        # leftmost child first, as in the plain in-order scheme
        yield from walk(tree.children[0])

        if transition_scheme is TransitionScheme.IN_ORDER_UNARY:
            yield OpenConstituent(tree.label)
        else:
            labels.append(tree.label)
            yield OpenConstituent(*labels)

        for child in tree.children[1:]:
            yield from walk(child)

        yield CloseConstituent()

        if transition_scheme is TransitionScheme.IN_ORDER_UNARY and labels:
            yield CompoundUnary(*labels)

    if len(tree.children) == 0:
        raise ValueError("Cannot build {} on an empty tree".format(transition_scheme))
    if len(tree.children) != 1:
        raise ValueError("Cannot build {} with a tree that has two top level nodes: {}".format(transition_scheme, tree))

    yield from walk(tree.children[0])
    yield Finalize(tree.label)
|
| 126 |
+
|
| 127 |
+
def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
    """
    Turn a single tree into a list of transitions based on the TransitionScheme
    """
    if transition_scheme is TransitionScheme.IN_ORDER:
        generator = yield_in_order_sequence(tree)
    elif transition_scheme in (TransitionScheme.IN_ORDER_COMPOUND, TransitionScheme.IN_ORDER_UNARY):
        generator = yield_in_order_compound_sequence(tree, transition_scheme)
    else:
        # all of the TOP_DOWN variants are dispatched inside the generator
        generator = yield_top_down_sequence(tree, transition_scheme)
    return list(generator)
|
| 138 |
+
|
| 139 |
+
def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, reverse=False):
    """
    Turn each of the trees in the treebank into a list of transitions based on the TransitionScheme

    If reverse is set, each tree is reversed before conversion.
    """
    return [build_sequence(tree.reverse() if reverse else tree, transition_scheme)
            for tree in trees]
|
| 147 |
+
|
| 148 |
+
def all_transitions(transition_lists):
    """
    Given a list of transition lists, combine them all into a list of unique transitions.
    """
    # union of every per-sequence list, returned in a deterministic sorted order
    return sorted(set().union(*transition_lists))
|
| 156 |
+
|
| 157 |
+
def convert_trees_to_sequences(trees, treebank_name, transition_scheme, reverse=False):
    """
    Convenience wrapper around build_treebank and all_transitions, possibly with a tqdm.

    Returns (sequences, transitions): the per-tree transition sequences and
    the sorted list of unique transitions seen across all of them.
    """
    if not trees:
        return [], []
    logger.info("Building %s transition sequences", treebank_name)
    # show a progress bar only when INFO logging is visible
    show_progress = logger.getEffectiveLevel() <= logging.INFO
    tree_iterator = tqdm(trees) if show_progress else trees
    sequences = build_treebank(tree_iterator, transition_scheme, reverse)
    return sequences, all_transitions(sequences)
|
| 171 |
+
|
| 172 |
+
def main():
    """
    Demonstrate the conversion by printing the transitions for a sample tree
    """
    sample = "( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    #sample = "(WP Who)"

    tree = read_trees(sample)[0]

    print(tree)
    print(build_sequence(tree))
|
| 184 |
+
|
| 185 |
+
# allow running this module directly to demo the transition conversion
if __name__ == '__main__':
    main()
|
stanza/stanza/models/constituency/tree_embedding.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A module to use a Constituency Parser to make an embedding for a tree
|
| 3 |
+
|
| 4 |
+
The embedding can be produced just from the words and the top of the
|
| 5 |
+
tree, or it can be done with a form of attention over the nodes
|
| 6 |
+
|
| 7 |
+
Can be done over an existing parse tree or unparsed text
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from stanza.models.constituency.trainer import Trainer
|
| 15 |
+
|
| 16 |
+
class TreeEmbedding(nn.Module):
    """
    Uses a constituency parser to build an embedding for a tree.

    The embedding is assembled from the parser's internal states: the
    endpoints of the word LSTM, the transition stack, and the constituent
    stack.  Optionally a query/key/value attention over the tree's nodes
    refines the result (node_attn) and/or every word contributes
    (all_words).

    Bug fix: forward() previously called a nonexistent module-level
    function ``embed_trees(self, inputs)``; it now correctly calls the
    method ``self.embed_trees(inputs)``.
    """
    def __init__(self, constituency_parser, args):
        """
        constituency_parser: a constituency parser model whose internal
            states will be turned into embeddings
        args: dict with at least the keys all_words, backprop, node_attn,
            top_layer (see self.config below)
        """
        super(TreeEmbedding, self).__init__()

        self.config = {
            "all_words": args["all_words"],
            "backprop": args["backprop"],
            #"batch_norm": args["batch_norm"],
            "node_attn": args["node_attn"],
            "top_layer": args["top_layer"],
        }

        self.constituency_parser = constituency_parser

        # word_lstm: hidden_size * num_tree_lstm_layers * 2 (start & end)
        # transition_stack: transition_hidden_size
        # constituent_stack: hidden_size
        self.hidden_size = self.constituency_parser.hidden_size + self.constituency_parser.transition_hidden_size
        if self.config["all_words"]:
            # each word contributes on its own, so only one copy of the word hx is used
            self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers
        else:
            # both the begin and end word states are concatenated
            self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers * 2

        if self.config["node_attn"]:
            self.query = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size)
            self.key = nn.Linear(self.hidden_size, self.constituency_parser.hidden_size)
            self.value = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size)

            # TODO: cat transition and constituent hx as well?
            self.output_size = self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers
        else:
            self.output_size = self.hidden_size

        # TODO: maybe have batch_norm, maybe use Identity
        #if self.config["batch_norm"]:
        #    self.input_norm = nn.BatchNorm1d(self.output_size)

    def embed_trees(self, inputs):
        """
        Run the parser over the trees and assemble one embedding per tree.

        Returns either a batched tensor (no node_attn, no all_words) or a
        list of per-tree tensors.
        """
        if self.config["backprop"]:
            states = self.constituency_parser.analyze_trees(inputs)
        else:
            # detach the parser from the gradient computation when frozen
            with torch.no_grad():
                states = self.constituency_parser.analyze_trees(inputs)

        constituent_lists = [x.constituents for x in states]
        states = [x.state for x in states]

        word_begin_hx = torch.stack([state.word_queue[0].hx for state in states])
        word_end_hx = torch.stack([state.word_queue[state.word_position].hx for state in states])
        transition_hx = torch.stack([self.constituency_parser.transition_stack.output(state.transitions) for state in states])
        # go down one layer to get the embedding off the top of the S, not the ROOT
        # (in terms of the typical treebank)
        # the idea being that the ROOT has no additional information
        # and may even have 0s for the embedding in certain circumstances,
        # such as after learning UNTIED_MAX long enough
        if self.config["top_layer"]:
            constituent_hx = torch.stack([self.constituency_parser.constituent_stack.output(state.constituents) for state in states])
        else:
            constituent_hx = torch.cat([constituents[-2].tree_hx for constituents in constituent_lists], dim=0)

        if self.config["all_words"]:
            # need B matrices of N x hidden_size
            key = [torch.stack([torch.cat([word.hx, thx, chx]) for word in state.word_queue], dim=0)
                   for state, thx, chx in zip(states, transition_hx, constituent_hx)]
        else:
            key = torch.cat((word_begin_hx, word_end_hx, transition_hx, constituent_hx), dim=1).unsqueeze(1)

        if not self.config["node_attn"]:
            return key
        key = [self.key(x) for x in key]

        node_hx = [torch.stack([con.tree_hx for con in constituents], dim=0) for constituents in constituent_lists]
        queries = [self.query(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]
        values = [self.value(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]
        # TODO: could pad to make faster here
        attn = [torch.matmul(q, k.transpose(0, 1)) for q, k in zip(queries, key)]
        attn = [torch.softmax(x, dim=0) for x in attn]
        previous_layer = [torch.matmul(weight.transpose(0, 1), value) for weight, value in zip(attn, values)]
        return previous_layer

    def forward(self, inputs):
        """Embed a batch of trees (see embed_trees)."""
        # the original called the nonexistent module-level name
        # embed_trees(self, inputs), which raises NameError at runtime
        return self.embed_trees(inputs)

    def get_norms(self):
        """Return human-readable parameter norm lines for logging."""
        lines = ["constituency_parser." + x for x in self.constituency_parser.get_norms()]
        for name, param in self.named_parameters():
            if param.requires_grad and not name.startswith('constituency_parser.'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        return lines


    def get_params(self, skip_modules=True):
        """Return a checkpoint dict: own weights, parser params, and config."""
        model_state = self.state_dict()
        # skip all of the constituency parameters here -
        # we will add them by calling the model's get_params()
        skipped = [k for k in model_state.keys() if k.startswith("constituency_parser.")]
        for k in skipped:
            del model_state[k]

        parser = self.constituency_parser.get_params(skip_modules)

        params = {
            'model': model_state,
            'constituency': parser,
            'config': self.config,
        }
        return params

    @staticmethod
    def from_parser_file(args, foundation_cache=None):
        """Build a TreeEmbedding around a parser loaded from args['model']."""
        constituency_parser = Trainer.load(args['model'], args, foundation_cache)
        return TreeEmbedding(constituency_parser.model, args)

    @staticmethod
    def model_from_params(params, args, foundation_cache=None):
        """Rebuild a TreeEmbedding from a checkpoint made by get_params."""
        # TODO: integrate with peft
        constituency_parser = Trainer.model_from_params(params['constituency'], None, args, foundation_cache)
        model = TreeEmbedding(constituency_parser, params['config'])
        model.load_state_dict(params['model'], strict=False)
        return model
|
stanza/stanza/models/coref/config.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Describes Config, a simple namespace for config values.
|
| 2 |
+
|
| 3 |
+
For description of all config values, refer to config.toml.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Dict, List
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class Config:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
    """ Contains values needed to set up the coreference model. """
    # name of the config.toml section these values were read from
    section: str

    # TODO: can either eliminate data_dir or use it for the train/dev/test data
    data_dir: str
    # where checkpoints and final models are written
    save_dir: str
    save_name: str

    # paths to the train / dev / test files (json or legacy jsonlines)
    train_data: str
    dev_data: str
    test_data: str

    # torch device string, e.g. "cuda:0" or "cpu"
    device: str

    # transformer name and the max subword window passed through it
    bert_model: str
    bert_window_size: int

    # feature embedding sizes and scoring network dimensions
    embedding_size: int
    sp_embedding_size: int
    a_scoring_batch_size: int
    hidden_size: int
    n_hidden_layers: int

    # mention extractor checks spans of up to this many words
    max_span_len: int

    # number of candidate pairs kept per mention by the rough scorer
    rough_k: int

    # LoRA adapter settings (only used when lora is true)
    lora: bool
    lora_alpha: int
    lora_rank: int
    lora_dropout: float

    # whether to use the full pairwise encoding (speaker & genre features)
    full_pairwise: bool

    lora_target_modules: List[str]
    lora_modules_to_save: List[str]

    clusters_starts_are_singletons: bool
    bert_finetune: bool
    dropout_rate: float
    learning_rate: float
    bert_learning_rate: float
    # we find that setting this to a small but non-zero number
    # makes the model less likely to forget how to do anything
    bert_finetune_begin_epoch: float
    train_epochs: int
    bce_loss_weight: float

    # extra kwargs passed to the tokenizer, keyed by transformer name
    tokenizer_kwargs: Dict[str, dict]
    # directory for conll prediction files
    conll_log_dir: str

    save_each_checkpoint: bool
    log_norms: bool
    # whether the model supports singleton clusters
    singletons: bool
|
| 66 |
+
|
stanza/stanza/models/coref/coref_config.toml
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# Before you start changing anything here, read the comments.
|
| 3 |
+
# All of them can be found below in the "DEFAULT" section
|
| 4 |
+
|
| 5 |
+
[DEFAULT]
|
| 6 |
+
|
| 7 |
+
# The directory that contains extracted files of everything you've downloaded.
|
| 8 |
+
data_dir = "data/coref"
|
| 9 |
+
|
| 10 |
+
# where to put checkpoints and final models
|
| 11 |
+
save_dir = "saved_models/coref"
|
| 12 |
+
save_name = "bert-large-cased"
|
| 13 |
+
|
| 14 |
+
# Train, dev and test jsonlines
|
| 15 |
+
# train_data = "data/coref/en_gum-ud.train.nosgl.json"
|
| 16 |
+
# dev_data = "data/coref/en_gum-ud.dev.nosgl.json"
|
| 17 |
+
# test_data = "data/coref/en_gum-ud.test.nosgl.json"
|
| 18 |
+
|
| 19 |
+
train_data = "data/coref/corefud_concat_v1_0_langid.train.json"
|
| 20 |
+
dev_data = "data/coref/corefud_concat_v1_0_langid.dev.json"
|
| 21 |
+
test_data = "data/coref/corefud_concat_v1_0_langid.dev.json"
|
| 22 |
+
|
| 23 |
+
#train_data = "data/coref/english_train_head.jsonlines"
|
| 24 |
+
#dev_data = "data/coref/english_development_head.jsonlines"
|
| 25 |
+
#test_data = "data/coref/english_test_head.jsonlines"
|
| 26 |
+
|
| 27 |
+
# do not use the full pairwise encoding scheme
|
| 28 |
+
full_pairwise = false
|
| 29 |
+
|
| 30 |
+
# The device where everything is to be placed. "cuda:N"/"cpu" are supported.
|
| 31 |
+
device = "cuda:0"
|
| 32 |
+
|
| 33 |
+
save_each_checkpoint = false
|
| 34 |
+
log_norms = false
|
| 35 |
+
|
| 36 |
+
# Bert settings ======================
|
| 37 |
+
|
| 38 |
+
# Base bert model architecture and tokenizer
|
| 39 |
+
bert_model = "bert-large-cased"
|
| 40 |
+
|
| 41 |
+
# Controls max length of sequences passed through bert to obtain its
|
| 42 |
+
# contextual embeddings
|
| 43 |
+
# Must be less than or equal to 512
|
| 44 |
+
bert_window_size = 512
|
| 45 |
+
|
| 46 |
+
# General model settings =============
|
| 47 |
+
|
| 48 |
+
# Controls the dimensionality of feature embeddings
|
| 49 |
+
embedding_size = 20
|
| 50 |
+
|
| 51 |
+
# Controls the dimensionality of distance embeddings used by SpanPredictor
|
| 52 |
+
sp_embedding_size = 64
|
| 53 |
+
|
| 54 |
+
# Controls the number of spans for which anaphoricity can be scores in one
|
| 55 |
+
# batch. Only affects final scoring; mention extraction and rough scoring
|
| 56 |
+
# are less memory intensive, so they are always done in just one batch.
|
| 57 |
+
a_scoring_batch_size = 128
|
| 58 |
+
|
| 59 |
+
# AnaphoricityScorer FFNN parameters
|
| 60 |
+
hidden_size = 1024
|
| 61 |
+
n_hidden_layers = 1
|
| 62 |
+
|
| 63 |
+
# Do you want to support singletons?
|
| 64 |
+
singletons = true
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Mention extraction settings ========
|
| 68 |
+
|
| 69 |
+
# Mention extractor will check spans up to max_span_len words
|
| 70 |
+
# The default value is chosen to be big enough to hold any dev data span
|
| 71 |
+
max_span_len = 64
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Pruning settings ===================
|
| 75 |
+
|
| 76 |
+
# Controls how many pairs should be preserved per mention
|
| 77 |
+
# after applying rough scoring.
|
| 78 |
+
rough_k = 50
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Lora settings ===================
|
| 82 |
+
|
| 83 |
+
# LoRA settings
|
| 84 |
+
lora = false
|
| 85 |
+
lora_alpha = 128
|
| 86 |
+
lora_dropout = 0.1
|
| 87 |
+
lora_rank = 64
|
| 88 |
+
lora_target_modules = []
|
| 89 |
+
lora_modules_to_save = []
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Training settings ==================
|
| 93 |
+
|
| 94 |
+
# Controls whether the first dummy node predicts cluster starts or singletons
|
| 95 |
+
clusters_starts_are_singletons = true
|
| 96 |
+
|
| 97 |
+
# Controls whether to fine-tune bert_model
|
| 98 |
+
bert_finetune = true
|
| 99 |
+
|
| 100 |
+
# Controls the dropout rate throughout all models
|
| 101 |
+
dropout_rate = 0.3
|
| 102 |
+
|
| 103 |
+
# Bert learning rate (only used if bert_finetune is set)
|
| 104 |
+
bert_learning_rate = 1e-6
|
| 105 |
+
bert_finetune_begin_epoch = 0.5
|
| 106 |
+
|
| 107 |
+
# Task learning rate
|
| 108 |
+
learning_rate = 3e-4
|
| 109 |
+
|
| 110 |
+
# For how many epochs the training is done
|
| 111 |
+
train_epochs = 32
|
| 112 |
+
|
| 113 |
+
# Controls the weight of binary cross entropy loss added to nlml loss
|
| 114 |
+
bce_loss_weight = 0.5
|
| 115 |
+
|
| 116 |
+
# The directory that will contain conll prediction files
|
| 117 |
+
conll_log_dir = "data/conll_logs"
|
| 118 |
+
|
| 119 |
+
# =============================================================================
|
| 120 |
+
# Extra keyword arguments to be passed to bert tokenizers of specified models
|
| 121 |
+
[DEFAULT.tokenizer_kwargs]
|
| 122 |
+
[DEFAULT.tokenizer_kwargs.roberta-large]
|
| 123 |
+
"add_prefix_space" = true
|
| 124 |
+
|
| 125 |
+
[DEFAULT.tokenizer_kwargs.xlm-roberta-large]
|
| 126 |
+
"add_prefix_space" = true
|
| 127 |
+
|
| 128 |
+
[DEFAULT.tokenizer_kwargs.spanbert-large-cased]
|
| 129 |
+
"do_lower_case" = false
|
| 130 |
+
|
| 131 |
+
[DEFAULT.tokenizer_kwargs.bert-large-cased]
|
| 132 |
+
"do_lower_case" = false
|
| 133 |
+
|
| 134 |
+
# =============================================================================
|
| 135 |
+
# The sections listed here do not need to make use of all config variables
|
| 136 |
+
# If a variable is omitted, its default value will be used instead
|
| 137 |
+
|
| 138 |
+
[roberta]
|
| 139 |
+
bert_model = "roberta-large"
|
| 140 |
+
|
| 141 |
+
[roberta_lora]
|
| 142 |
+
bert_model = "roberta-large"
|
| 143 |
+
bert_learning_rate = 0.00005
|
| 144 |
+
lora = true
|
| 145 |
+
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
|
| 146 |
+
lora_modules_to_save = [ "pooler" ]
|
| 147 |
+
|
| 148 |
+
[scandibert_lora]
|
| 149 |
+
bert_model = "vesteinn/ScandiBERT"
|
| 150 |
+
bert_learning_rate = 0.0002
|
| 151 |
+
lora = true
|
| 152 |
+
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
|
| 153 |
+
lora_modules_to_save = [ "pooler" ]
|
| 154 |
+
|
| 155 |
+
[xlm_roberta]
|
| 156 |
+
bert_model = "FacebookAI/xlm-roberta-large"
|
| 157 |
+
bert_learning_rate = 0.00001
|
| 158 |
+
bert_finetune = true
|
| 159 |
+
|
| 160 |
+
[xlm_roberta_lora]
|
| 161 |
+
bert_model = "FacebookAI/xlm-roberta-large"
|
| 162 |
+
bert_learning_rate = 0.000025
|
| 163 |
+
lora = true
|
| 164 |
+
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
|
| 165 |
+
lora_modules_to_save = [ "pooler" ]
|
| 166 |
+
|
| 167 |
+
[deeppavlov_slavic_bert_lora]
|
| 168 |
+
bert_model = "DeepPavlov/bert-base-bg-cs-pl-ru-cased"
|
| 169 |
+
bert_learning_rate = 0.000025
|
| 170 |
+
lora = true
|
| 171 |
+
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
|
| 172 |
+
lora_modules_to_save = [ "pooler" ]
|
| 173 |
+
|
| 174 |
+
[deberta_lora]
|
| 175 |
+
bert_model = "microsoft/deberta-v3-large"
|
| 176 |
+
bert_learning_rate = 0.00001
|
| 177 |
+
lora = true
|
| 178 |
+
lora_target_modules = [ "query_proj", "value_proj", "output.dense" ]
|
| 179 |
+
lora_modules_to_save = [ ]
|
| 180 |
+
|
| 181 |
+
[electra]
|
| 182 |
+
bert_model = "google/electra-large-discriminator"
|
| 183 |
+
bert_learning_rate = 0.00002
|
| 184 |
+
|
| 185 |
+
[electra_lora]
|
| 186 |
+
bert_model = "google/electra-large-discriminator"
|
| 187 |
+
bert_learning_rate = 0.000025
|
| 188 |
+
lora = true
|
| 189 |
+
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
|
| 190 |
+
lora_modules_to_save = [ ]
|
| 191 |
+
|
| 192 |
+
[hungarian_electra_lora]
|
| 193 |
+
# TODO: experiment with tokenizer options for this to see if that's
|
| 194 |
+
# why the results are so low using this transformer
|
| 195 |
+
bert_model = "NYTK/electra-small-discriminator-hungarian"
|
| 196 |
+
bert_learning_rate = 0.000025
|
| 197 |
+
lora = true
|
| 198 |
+
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
|
| 199 |
+
lora_modules_to_save = [ ]
|
| 200 |
+
|
| 201 |
+
[muril_large_cased_lora]
|
| 202 |
+
bert_model = "google/muril-large-cased"
|
| 203 |
+
bert_learning_rate = 0.000025
|
| 204 |
+
lora = true
|
| 205 |
+
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
|
| 206 |
+
lora_modules_to_save = [ "pooler" ]
|
| 207 |
+
|
| 208 |
+
[indic_bert_lora]
|
| 209 |
+
bert_model = "ai4bharat/indic-bert"
|
| 210 |
+
bert_learning_rate = 0.0005
|
| 211 |
+
lora = true
|
| 212 |
+
# indic-bert is an albert with repeating layers of different names
|
| 213 |
+
lora_target_modules = [ "query", "value", "dense", "ffn", "full_layer" ]
|
| 214 |
+
lora_modules_to_save = [ "pooler" ]
|
| 215 |
+
|
| 216 |
+
[bert_multilingual_cased_lora]
|
| 217 |
+
# LR sweep on a Hindi dataset
|
| 218 |
+
# 0.00001: 0.53238
|
| 219 |
+
# 0.00002: 0.54012
|
| 220 |
+
# 0.000025: 0.54206
|
| 221 |
+
# 0.00003: 0.54050
|
| 222 |
+
# 0.00004: 0.55081
|
| 223 |
+
# 0.00005: 0.55135
|
| 224 |
+
# 0.000075: 0.54482
|
| 225 |
+
# 0.0001: 0.53888
|
| 226 |
+
bert_model = "google-bert/bert-base-multilingual-cased"
|
| 227 |
+
bert_learning_rate = 0.00005
|
| 228 |
+
lora = true
|
| 229 |
+
lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
|
| 230 |
+
lora_modules_to_save = [ "pooler" ]
|
| 231 |
+
|
| 232 |
+
[t5_lora]
|
| 233 |
+
bert_model = "google-t5/t5-large"
|
| 234 |
+
bert_learning_rate = 0.000025
|
| 235 |
+
bert_window_size = 1024
|
| 236 |
+
lora = true
|
| 237 |
+
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
|
| 238 |
+
lora_modules_to_save = [ ]
|
| 239 |
+
|
| 240 |
+
[mt5_lora]
|
| 241 |
+
bert_model = "google/mt5-base"
|
| 242 |
+
bert_learning_rate = 0.000025
|
| 243 |
+
lora_alpha = 64
|
| 244 |
+
lora_rank = 32
|
| 245 |
+
lora = true
|
| 246 |
+
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
|
| 247 |
+
lora_modules_to_save = [ ]
|
| 248 |
+
|
| 249 |
+
[deepnarrow_t5_xl_lora]
|
| 250 |
+
bert_model = "google/t5-efficient-xl"
|
| 251 |
+
bert_learning_rate = 0.00025
|
| 252 |
+
lora = true
|
| 253 |
+
lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
|
| 254 |
+
lora_modules_to_save = [ ]
|
| 255 |
+
|
| 256 |
+
[roberta_no_finetune]
|
| 257 |
+
bert_model = "roberta-large"
|
| 258 |
+
bert_finetune = false
|
| 259 |
+
|
| 260 |
+
[roberta_no_bce]
|
| 261 |
+
bert_model = "roberta-large"
|
| 262 |
+
bce_loss_weight = 0.0
|
| 263 |
+
|
| 264 |
+
[spanbert]
|
| 265 |
+
bert_model = "SpanBERT/spanbert-large-cased"
|
| 266 |
+
|
| 267 |
+
[spanbert_no_bce]
|
| 268 |
+
bert_model = "SpanBERT/spanbert-large-cased"
|
| 269 |
+
bce_loss_weight = 0.0
|
| 270 |
+
|
| 271 |
+
[bert]
|
| 272 |
+
bert_model = "bert-large-cased"
|
| 273 |
+
|
| 274 |
+
[longformer]
|
| 275 |
+
bert_model = "allenai/longformer-large-4096"
|
| 276 |
+
bert_window_size = 2048
|
| 277 |
+
|
| 278 |
+
[debug]
|
| 279 |
+
bert_window_size = 384
|
| 280 |
+
bert_finetune = false
|
| 281 |
+
device = "cpu:0"
|
| 282 |
+
|
| 283 |
+
[debug_gpu]
|
| 284 |
+
bert_window_size = 384
|
| 285 |
+
bert_finetune = false
|
stanza/stanza/models/coref/dataset.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
|
| 5 |
+
from stanza.models.coref.tokenizer_customization import TOKENIZER_FILTERS, TOKENIZER_MAPS
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger('stanza')
|
| 8 |
+
|
| 9 |
+
class CorefDataset(Dataset):
    """Loads coref documents from a json (or legacy jsonlines) file and
    precomputes the subword tokenization for each document.
    """

    def __init__(self, path, config, tokenizer):
        """
        path: json file containing a list of docs, or one doc per line (jsonlines)
        config: coref Config; bert_model selects tokenizer-specific tweaks
        tokenizer: the transformer tokenizer used to split words into subwords
        """
        self.config = config
        self.tokenizer = tokenizer

        # by default, this doesn't filter anything (see lambda _: True);
        # however, there are some subword symbols which are standalone
        # tokens which we don't want on models like Albert; hence we
        # pass along a filter if needed.
        self.__filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,
                                                   lambda _: True)
        self.__token_map = TOKENIZER_MAPS.get(self.config.bert_model, {})

        try:
            with open(path, encoding="utf-8") as fin:
                data_f = json.load(fin)
        except json.decoder.JSONDecodeError:
            # read the old jsonlines format if necessary
            with open(path, encoding="utf-8") as fin:
                text = "[" + ",\n".join(fin) + "]"
            data_f = json.loads(text)
        logger.info("Processing %d docs from %s...", len(data_f), path)
        self.__raw = data_f
        # average number of head->span mappings per document
        # NOTE(review): raises ZeroDivisionError on an empty file — confirm inputs are never empty
        self.__avg_span = sum(len(doc["head2span"]) for doc in self.__raw) / len(self.__raw)
        self.__out = []
        for doc in self.__raw:
            # mentions arrive as json lists; convert to (start, end) tuples
            doc["span_clusters"] = [[tuple(mention) for mention in cluster]
                                    for cluster in doc["span_clusters"]]
            word2subword = []
            subwords = []
            word_id = []
            for i, word in enumerate(doc["cased_words"]):
                # a hand-written token mapping overrides the tokenizer, if present
                tokenized_word = self.__token_map.get(word, self.tokenizer.tokenize(word))
                tokenized_word = list(filter(self.__filter_func, tokenized_word))
                # (start, end) range of this word's subwords in the subword list
                word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))
                subwords.extend(tokenized_word)
                # map each subword back to the index of the word it came from
                word_id.extend([i] * len(tokenized_word))
            doc["word2subword"] = word2subword
            doc["subwords"] = subwords
            doc["word_id"] = word_id
            self.__out.append(doc)
        logger.info("Loaded %d docs from %s.", len(data_f), path)

    @property
    def avg_span(self):
        """Average number of head->span entries per document."""
        return self.__avg_span

    def __getitem__(self, x):
        return self.__out[x]

    def __len__(self):
        return len(self.__out)
|
stanza/stanza/models/coref/pairwise_encoder.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Describes PairwiseEncodes, that transforms pairwise features, such as
|
| 2 |
+
distance between the mentions, same/different speaker into feature embeddings
|
| 3 |
+
"""
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from stanza.models.coref.config import Config
|
| 9 |
+
from stanza.models.coref.const import Doc
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PairwiseEncoder(torch.nn.Module):
    """ A Pytorch module to obtain feature embeddings for pairwise features

    Usage:
        encoder = PairwiseEncoder(config)
        pairwise_features = encoder(pair_indices, doc)
    """
    def __init__(self, config: Config):
        super().__init__()
        emb_size = config.embedding_size

        # genre string (first two chars of the document id) -> embedding index
        self.genre2int = {g: gi for gi, g in enumerate(["bc", "bn", "mz", "nw",
                                                        "pt", "tc", "wb"])}
        self.genre_emb = torch.nn.Embedding(len(self.genre2int), emb_size)

        # each position corresponds to a bucket:
        # [(0, 2), (2, 3), (3, 4), (4, 5), (5, 8),
        #  (8, 16), (16, 32), (32, 64), (64, float("inf"))]
        self.distance_emb = torch.nn.Embedding(9, emb_size)

        # two possibilities: same vs different speaker
        self.speaker_emb = torch.nn.Embedding(2, emb_size)

        self.dropout = torch.nn.Dropout(config.dropout_rate)

        # when False, only the distance feature is produced
        self.__full_pw = config.full_pairwise

        if self.__full_pw:
            self.shape = emb_size * 3  # genre, distance, speaker
        else:
            self.shape = emb_size  # distance only

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.genre_emb.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                top_indices: torch.Tensor,
                doc: Doc) -> torch.Tensor:
        """Return feature embeddings for each (word, candidate) pair."""
        word_ids = torch.arange(0, len(doc["cased_words"]), device=self.device)

        # bucketing the distance (see __init__())
        # small distances (< 5) map directly to buckets 0-3; larger ones
        # use log2 buckets 4-8
        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
                    ).clamp_min_(min=1)
        log_distance = distance.to(torch.float).log2().floor_()
        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
        distance = self.distance_emb(distance)

        if not self.__full_pw:
            return self.dropout(distance)

        # calculate speaker embeddings
        speaker_map = torch.tensor(self._speaker_map(doc), device=self.device)
        same_speaker = (speaker_map[top_indices] == speaker_map.unsqueeze(1))
        same_speaker = self.speaker_emb(same_speaker.to(torch.long))


        # if there is no genre information, use "wb" as the genre (which is
        # what the Pipeline does)
        genre = torch.tensor(self.genre2int.get(doc["document_id"][:2], self.genre2int["wb"]),
                             device=self.device).expand_as(top_indices)
        genre = self.genre_emb(genre)

        return self.dropout(torch.cat((same_speaker, distance, genre), dim=2))

    @staticmethod
    def _speaker_map(doc: Doc) -> List[int]:
        """
        Returns a tensor where i-th element is the speaker id of i-th word.
        """
        # if speaker is not found in the doc, simply return "speaker#1" for all the speakers
        # and embed them using the same ID

        # speaker string -> speaker id
        str2int = {s: i for i, s in enumerate(set(doc.get("speaker", ["speaker#1"
                                                                      for _ in range(len(doc["deprel"]))])))}

        # word id -> speaker id
        return [str2int[s] for s in doc.get("speaker", ["speaker#1"
                                                        for _ in range(len(doc["deprel"]))])]
|
stanza/stanza/models/coref/rough_scorer.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Describes RoughScorer, a simple bilinear module to calculate rough
|
| 2 |
+
anaphoricity scores.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from stanza.models.coref.config import Config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RoughScorer(torch.nn.Module):
    """
    Is needed to give a roughly estimate of the anaphoricity of two candidates,
    only top scoring candidates are considered on later steps to reduce
    computational complexity.
    """
    def __init__(self, features: int, config: Config):
        super().__init__()
        self.dropout = torch.nn.Dropout(config.dropout_rate)
        # bilinear scoring: score(i, j) = (W m_i) . m_j
        self.bilinear = torch.nn.Linear(features, features)

        # number of top antecedent candidates kept per mention
        self.k = config.rough_k

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                mentions: torch.Tensor,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns rough anaphoricity scores for candidates, which consist of
        the bilinear output of the current model summed with mention scores.

        Returns (see _prune): top scores [n_mentions, k],
        top indices [n_mentions, k], and the full rough score matrix
        [n_mentions, n_mentions].
        """
        # [n_mentions, n_mentions]
        # pair_mask[i, j] is 0 where j < i (a valid, earlier antecedent)
        # and -inf elsewhere, so a mention can never pick itself or a
        # later mention as antecedent
        pair_mask = torch.arange(mentions.shape[0])
        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
        pair_mask = torch.log((pair_mask > 0).to(torch.float))
        pair_mask = pair_mask.to(mentions.device)

        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)

        rough_scores = pair_mask + bilinear_scores

        return self._prune(rough_scores)

    def _prune(self,
               rough_scores: torch.Tensor
               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Selects top-k rough antecedent scores for each mention.

        Args:
            rough_scores: tensor of shape [n_mentions, n_mentions], containing
                rough antecedent scores of each mention-antecedent pair.

        Returns:
            FloatTensor of shape [n_mentions, k], top rough scores
            LongTensor of shape [n_mentions, k], top indices
            FloatTensor of shape [n_mentions, n_mentions], the unpruned
                rough_scores passed in (returned for downstream reuse)
        """
        # k is capped by the number of mentions so topk never overruns
        top_scores, indices = torch.topk(rough_scores,
                                         k=min(self.k, len(rough_scores)),
                                         dim=1, sorted=False)
        return top_scores, indices, rough_scores
|
stanza/stanza/models/coref/utils.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Contains functions not directly linked to coreference resolution """
|
| 2 |
+
|
| 3 |
+
from typing import List, Set
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from stanza.models.coref.const import EPSILON
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GraphNode:
    """A node in an undirected graph, identified by an integer id."""

    def __init__(self, node_id: int):
        self.id = node_id
        # set of nodes this node shares an edge with
        self.links: Set["GraphNode"] = set()
        self.visited = False

    def link(self, another: "GraphNode"):
        """Create an undirected edge between this node and `another`."""
        for node, neighbor in ((self, another), (another, self)):
            node.links.add(neighbor)

    def __repr__(self) -> str:
        return str(self.id)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """ Prepends zeros (or a very small value if eps is True)
    to the first (not zeroth) dimension of tensor.
    """
    dummy_shape = list(tensor.shape)
    dummy_shape[1] = 1
    if eps:
        dummy = torch.full(dummy_shape, EPSILON,
                           device=tensor.device, dtype=tensor.dtype)
    else:
        # new_zeros keeps the device and dtype of the input tensor
        dummy = tensor.new_zeros(dummy_shape)
    return torch.cat((dummy, tensor), dim=1)
|
stanza/stanza/models/depparse/model.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
|
| 9 |
+
|
| 10 |
+
from stanza.models.common.bert_embedding import extract_bert_embeddings
|
| 11 |
+
from stanza.models.common.biaffine import DeepBiaffineScorer
|
| 12 |
+
from stanza.models.common.foundation_cache import load_charlm
|
| 13 |
+
from stanza.models.common.hlstm import HighwayLSTM
|
| 14 |
+
from stanza.models.common.dropout import WordDropout
|
| 15 |
+
from stanza.models.common.utils import attach_bert_model
|
| 16 |
+
from stanza.models.common.vocab import CompositeVocab
|
| 17 |
+
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
|
| 18 |
+
from stanza.models.common import utils
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger('stanza')
|
| 21 |
+
|
| 22 |
+
class Parser(nn.Module):
    """Deep biaffine dependency parser (Dozat & Manning style).

    Builds word representations from several optional sources (word/lemma
    embeddings, POS/UFeats embeddings, character models, a transformer, and
    pretrained vectors), runs them through a highway BiLSTM, and scores
    head attachments and dependency relations with deep biaffine scorers.
    """
    def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
        super().__init__()

        self.vocab = vocab
        self.args = args
        self.share_hid = share_hid
        self.unsaved_modules = []

        # input layers: input_size accumulates the width of every enabled feature
        input_size = 0
        if self.args['word_emb_dim'] > 0:
            # frequent word embeddings
            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)
            self.lemma_emb = nn.Embedding(len(vocab['lemma']), self.args['word_emb_dim'], padding_idx=0)
            input_size += self.args['word_emb_dim'] * 2

        if self.args['tag_emb_dim'] > 0:
            if self.args.get('use_upos', True):
                self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)
            if self.args.get('use_xpos', True):
                if not isinstance(vocab['xpos'], CompositeVocab):
                    self.xpos_emb = nn.Embedding(len(vocab['xpos']), self.args['tag_emb_dim'], padding_idx=0)
                else:
                    # composite xpos: one embedding per sub-field, summed at forward time
                    self.xpos_emb = nn.ModuleList()

                    for l in vocab['xpos'].lens():
                        self.xpos_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))
            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
                # upos and xpos embeddings share one tag_emb_dim slot (they are summed)
                input_size += self.args['tag_emb_dim']

            if self.args.get('use_ufeats', True):
                self.ufeats_emb = nn.ModuleList()

                for l in vocab['feats'].lens():
                    self.ufeats_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))

                input_size += self.args['tag_emb_dim']

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
                    raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
                if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
                    raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
                logger.debug("Depparse model loading charmodels: %s and %s", args['charlm_forward_file'], args['charlm_backward_file'])
                # pretrained character LMs are loaded but not saved with this model
                self.add_unsaved_module('charmodel_forward', load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache))
                self.add_unsaved_module('charmodel_backward', load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache))
                input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
            else:
                self.charmodel = CharacterModel(args, vocab)
                self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
                input_size += self.args['transformed_dim']

        self.peft_name = peft_name
        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
        if self.args.get('bert_model', None):
            # TODO: refactor bert_hidden_layers between the different models
            if args.get('bert_hidden_layers', False):
                # The average will be offset by 1/N so that the default zeros
                # represents an average of the N layers
                self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
                nn.init.zeros_(self.bert_layer_mix.weight)
            else:
                # an average of layers 2, 3, 4 will be used
                # (for historic reasons)
                self.bert_layer_mix = None
            input_size += self.bert_model.config.hidden_size

        if self.args['pretrain']:
            # pretrained embeddings, by default this won't be saved into model file
            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
            input_size += self.args['transformed_dim']

        # recurrent layers
        self.parserlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
        # learned replacement vector used by word dropout
        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
        self.parserlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
        self.parserlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))

        # classifiers
        self.unlabeled = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=args['dropout'])
        self.deprel = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], len(vocab['deprel']), pairwise=True, dropout=args['dropout'])
        if args['linearization']:
            # auxiliary scorer: is the head to the left or right of the dependent
            self.linearization = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=args['dropout'])
        if args['distance']:
            # auxiliary scorer: predicted |head - dependent| distance
            self.distance = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=args['dropout'])

        # criterion
        self.crit = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum') # ignore padding

        self.drop = nn.Dropout(args['dropout'])
        self.worddrop = WordDropout(args['word_dropout'])

    def add_unsaved_module(self, name, module):
        """Attach a module whose weights should NOT be serialized with the model."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def log_norms(self):
        utils.log_norms(self)

    def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
        """Returns (loss, preds).

        In training mode, loss is the per-word attachment + relation loss and
        preds is empty.  In eval mode, loss is 0 and preds holds the
        log-softmaxed attachment scores and argmax deprel predictions.
        """
        def pack(x):
            return pack_padded_sequence(x, sentlens, batch_first=True)

        inputs = []
        if self.args['pretrain']:
            pretrained_emb = self.pretrained_emb(pretrained)
            pretrained_emb = self.trans_pretrained(pretrained_emb)
            pretrained_emb = pack(pretrained_emb)
            inputs += [pretrained_emb]

        #def pad(x):
        #    return pad_packed_sequence(PackedSequence(x, pretrained_emb.batch_sizes), batch_first=True)[0]

        if self.args['word_emb_dim'] > 0:
            word_emb = self.word_emb(word)
            word_emb = pack(word_emb)
            lemma_emb = self.lemma_emb(lemma)
            lemma_emb = pack(lemma_emb)
            inputs += [word_emb, lemma_emb]

        if self.args['tag_emb_dim'] > 0:
            if self.args.get('use_upos', True):
                pos_emb = self.upos_emb(upos)
            else:
                pos_emb = 0

            if self.args.get('use_xpos', True):
                if isinstance(self.vocab['xpos'], CompositeVocab):
                    for i in range(len(self.vocab['xpos'])):
                        pos_emb += self.xpos_emb[i](xpos[:, :, i])
                else:
                    pos_emb += self.xpos_emb(xpos)

            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
                pos_emb = pack(pos_emb)
                inputs += [pos_emb]

            if self.args.get('use_ufeats', True):
                feats_emb = 0
                for i in range(len(self.vocab['feats'])):
                    feats_emb += self.ufeats_emb[i](ufeats[:, :, i])
                feats_emb = pack(feats_emb)

                # BUG FIX: this previously appended pos_emb a second time,
                # silently discarding the UFeats embeddings just computed
                # (the widths coincide, so it never crashed)
                inputs += [feats_emb]

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                # \n is to add a somewhat neutral "word" for the ROOT
                charlm_text = [["\n"] + x for x in text]
                all_forward_chars = self.charmodel_forward.build_char_representation(charlm_text)
                all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))
                all_backward_chars = self.charmodel_backward.build_char_representation(charlm_text)
                all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))
                inputs += [all_forward_chars, all_backward_chars]
            else:
                char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
                char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)
                inputs += [char_reps]

        if self.bert_model is not None:
            device = next(self.parameters()).device
            processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=True,
                                                     num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                     detach=not self.args.get('bert_finetune', False) or not self.training,
                                                     peft_name=self.peft_name)
            if self.bert_layer_mix is not None:
                # use a linear layer to weighted average the embedding dynamically
                processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]

            # we are using the first endpoint from the transformer as the "word" for ROOT
            processed_bert = [x[:-1, :] for x in processed_bert]
            processed_bert = pad_sequence(processed_bert, batch_first=True)
            inputs += [pack(processed_bert)]

        lstm_inputs = torch.cat([x.data for x in inputs], 1)

        lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
        lstm_inputs = self.drop(lstm_inputs)

        lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)

        lstm_outputs, _ = self.parserlstm(lstm_inputs, sentlens, hx=(self.parserlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.parserlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
        lstm_outputs, _ = pad_packed_sequence(lstm_outputs, batch_first=True)

        unlabeled_scores = self.unlabeled(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
        deprel_scores = self.deprel(self.drop(lstm_outputs), self.drop(lstm_outputs))

        #goldmask = head.new_zeros(*head.size(), head.size(-1)+1, dtype=torch.uint8)
        #goldmask.scatter_(2, head.unsqueeze(2), 1)

        if self.args['linearization'] or self.args['distance']:
            # head_offset[b, i, j] = j - i, the signed distance from word i to candidate head j
            head_offset = torch.arange(word.size(1), device=head.device).view(1, 1, -1).expand(word.size(0), -1, -1) - torch.arange(word.size(1), device=head.device).view(1, -1, 1).expand(word.size(0), -1, -1)

        if self.args['linearization']:
            lin_scores = self.linearization(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
            # detached: the auxiliary signal biases attachment scores without backprop here
            unlabeled_scores += F.logsigmoid(lin_scores * torch.sign(head_offset).float()).detach()

        if self.args['distance']:
            dist_scores = self.distance(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
            dist_pred = 1 + F.softplus(dist_scores)
            dist_target = torch.abs(head_offset)
            dist_kld = -torch.log((dist_target.float() - dist_pred)**2/2 + 1)
            unlabeled_scores += dist_kld.detach()

        # a word can never be its own head
        diag = torch.eye(head.size(-1)+1, dtype=torch.bool, device=head.device).unsqueeze(0)
        unlabeled_scores.masked_fill_(diag, -float('inf'))

        preds = []

        if self.training:
            unlabeled_scores = unlabeled_scores[:, 1:, :] # exclude attachment for the root symbol
            unlabeled_scores = unlabeled_scores.masked_fill(word_mask.unsqueeze(1), -float('inf'))
            unlabeled_target = head.masked_fill(word_mask[:, 1:], -1)
            loss = self.crit(unlabeled_scores.contiguous().view(-1, unlabeled_scores.size(2)), unlabeled_target.view(-1))

            deprel_scores = deprel_scores[:, 1:] # exclude attachment for the root symbol
            #deprel_scores = deprel_scores.masked_select(goldmask.unsqueeze(3)).view(-1, len(self.vocab['deprel']))
            # score relations only against the gold head of each word
            deprel_scores = torch.gather(deprel_scores, 2, head.unsqueeze(2).unsqueeze(3).expand(-1, -1, -1, len(self.vocab['deprel']))).view(-1, len(self.vocab['deprel']))
            deprel_target = deprel.masked_fill(word_mask[:, 1:], -1)
            loss += self.crit(deprel_scores.contiguous(), deprel_target.view(-1))

            if self.args['linearization']:
                #lin_scores = lin_scores[:, 1:].masked_select(goldmask)
                lin_scores = torch.gather(lin_scores[:, 1:], 2, head.unsqueeze(2)).view(-1)
                lin_scores = torch.cat([-lin_scores.unsqueeze(1)/2, lin_scores.unsqueeze(1)/2], 1)
                #lin_target = (head_offset[:, 1:] > 0).long().masked_select(goldmask)
                lin_target = torch.gather((head_offset[:, 1:] > 0).long(), 2, head.unsqueeze(2))
                loss += self.crit(lin_scores.contiguous(), lin_target.view(-1))

            if self.args['distance']:
                #dist_kld = dist_kld[:, 1:].masked_select(goldmask)
                dist_kld = torch.gather(dist_kld[:, 1:], 2, head.unsqueeze(2))
                loss -= dist_kld.sum()

            loss /= wordchars.size(0) # number of words
        else:
            loss = 0
            preds.append(F.log_softmax(unlabeled_scores, 2).detach().cpu().numpy())
            preds.append(deprel_scores.max(3)[1].detach().cpu().numpy())

        return loss, preds
|