bowphs committed on
Commit
fff2f9b
·
verified ·
1 Parent(s): ba68d3c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. stanza/stanza/models/common/__init__.py +0 -0
  2. stanza/stanza/models/common/bert_embedding.py +509 -0
  3. stanza/stanza/models/common/biaffine.py +80 -0
  4. stanza/stanza/models/common/build_short_name_to_treebank.py +78 -0
  5. stanza/stanza/models/common/char_model.py +362 -0
  6. stanza/stanza/models/common/chuliu_edmonds.py +281 -0
  7. stanza/stanza/models/common/constant.py +550 -0
  8. stanza/stanza/models/common/count_ner_coverage.py +38 -0
  9. stanza/stanza/models/common/count_pretrain_coverage.py +41 -0
  10. stanza/stanza/models/common/crf.py +149 -0
  11. stanza/stanza/models/common/data.py +155 -0
  12. stanza/stanza/models/common/doc.py +1741 -0
  13. stanza/stanza/models/common/dropout.py +75 -0
  14. stanza/stanza/models/common/exceptions.py +15 -0
  15. stanza/stanza/models/common/foundation_cache.py +148 -0
  16. stanza/stanza/models/common/hlstm.py +124 -0
  17. stanza/stanza/models/common/large_margin_loss.py +68 -0
  18. stanza/stanza/models/common/loss.py +134 -0
  19. stanza/stanza/models/common/maxout_linear.py +42 -0
  20. stanza/stanza/models/common/packed_lstm.py +105 -0
  21. stanza/stanza/models/common/peft_config.py +119 -0
  22. stanza/stanza/models/common/seq2seq_constant.py +17 -0
  23. stanza/stanza/models/common/seq2seq_model.py +364 -0
  24. stanza/stanza/models/common/seq2seq_utils.py +121 -0
  25. stanza/stanza/models/common/short_name_to_treebank.py +619 -0
  26. stanza/stanza/models/common/trainer.py +20 -0
  27. stanza/stanza/models/common/utils.py +816 -0
  28. stanza/stanza/models/common/vocab.py +298 -0
  29. stanza/stanza/models/constituency/base_model.py +532 -0
  30. stanza/stanza/models/constituency/base_trainer.py +153 -0
  31. stanza/stanza/models/constituency/ensemble.py +486 -0
  32. stanza/stanza/models/constituency/in_order_compound_oracle.py +327 -0
  33. stanza/stanza/models/constituency/in_order_oracle.py +1029 -0
  34. stanza/stanza/models/constituency/lstm_model.py +1178 -0
  35. stanza/stanza/models/constituency/parse_tree.py +591 -0
  36. stanza/stanza/models/constituency/positional_encoding.py +89 -0
  37. stanza/stanza/models/constituency/retagging.py +130 -0
  38. stanza/stanza/models/constituency/state.py +144 -0
  39. stanza/stanza/models/constituency/top_down_oracle.py +757 -0
  40. stanza/stanza/models/constituency/trainer.py +306 -0
  41. stanza/stanza/models/constituency/transformer_tree_stack.py +198 -0
  42. stanza/stanza/models/constituency/transition_sequence.py +186 -0
  43. stanza/stanza/models/constituency/tree_embedding.py +135 -0
  44. stanza/stanza/models/coref/config.py +66 -0
  45. stanza/stanza/models/coref/coref_config.toml +285 -0
  46. stanza/stanza/models/coref/dataset.py +61 -0
  47. stanza/stanza/models/coref/pairwise_encoder.py +94 -0
  48. stanza/stanza/models/coref/rough_scorer.py +61 -0
  49. stanza/stanza/models/coref/utils.py +35 -0
  50. stanza/stanza/models/depparse/model.py +265 -0
stanza/stanza/models/common/__init__.py ADDED
File without changes
stanza/stanza/models/common/bert_embedding.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import logging
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
8
+
9
+ logger = logging.getLogger('stanza')
10
+
11
+ BERT_ARGS = {
12
+ "vinai/phobert-base": { "use_fast": True },
13
+ "vinai/phobert-large": { "use_fast": True },
14
+ }
15
+
16
class TextTooLongError(ValueError):
    """
    A text was too long for the underlying model (possibly BERT)

    Carries the offending line number and text so callers can report
    or skip the specific input that overflowed the model.
    """
    def __init__(self, length, max_len, line_num, text):
        message = ("Found a text of length %d (possibly after tokenizing). "
                   "Maximum handled length is %d Error occurred at line %d") % (length, max_len, line_num)
        super().__init__(message)
        self.line_num = line_num
        self.text = text
24
+
25
+
26
def update_max_length(model_name, tokenizer):
    """Force model_max_length to 512 for checkpoints known to ship without a usable limit."""
    capped_models = {
        'hf-internal-testing/tiny-bert',
        'google/muril-base-cased',
        'google/muril-large-cased',
        'airesearch/wangchanberta-base-att-spm-uncased',
        'camembert/camembert-large',
        'hfl/chinese-electra-180g-large-discriminator',
        'NYTK/electra-small-discriminator-hungarian',
    }
    if model_name in capped_models:
        tokenizer.model_max_length = 512
35
+
36
def load_tokenizer(model_name, tokenizer_kwargs=None, local_files_only=False):
    """
    Load an AutoTokenizer for the given transformer name.

    Returns None if model_name is falsy.  Any per-model defaults from
    BERT_ARGS are applied first, then tokenizer_kwargs, then
    local_files_only (which always wins).

    Raises ImportError if the transformers library is not installed.
    """
    if model_name:
        # note that use_fast is the default
        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError("Please install transformers library for BERT support! Try `pip install transformers`.")
        # copy the per-model defaults: the original code mutated the shared
        # BERT_ARGS entry, leaking add_prefix_space / local_files_only /
        # caller kwargs into the module-level dict across calls
        bert_args = dict(BERT_ARGS.get(model_name, dict()))
        if not model_name.startswith("vinai/phobert"):
            bert_args["add_prefix_space"] = True
        if tokenizer_kwargs:
            bert_args.update(tokenizer_kwargs)
        bert_args['local_files_only'] = local_files_only
        bert_tokenizer = AutoTokenizer.from_pretrained(model_name, **bert_args)
        update_max_length(model_name, bert_tokenizer)
        return bert_tokenizer
    return None
53
+
54
def load_bert(model_name, tokenizer_kwargs=None, local_files_only=False):
    """
    Load a transformer model and its matching tokenizer.

    Returns (model, tokenizer), or (None, None) when no name is given.
    Raises ImportError if the transformers library is not installed.
    """
    if not model_name:
        return None, None
    # such as: "vinai/phobert-base"
    try:
        from transformers import AutoModel
    except ImportError:
        raise ImportError("Please install transformers library for BERT support! Try `pip install transformers`.")
    bert_model = AutoModel.from_pretrained(model_name, local_files_only=local_files_only)
    bert_tokenizer = load_tokenizer(model_name, tokenizer_kwargs=tokenizer_kwargs, local_files_only=local_files_only)
    return bert_model, bert_tokenizer
65
+
66
def tokenize_manual(model_name, sent, tokenizer):
    """
    Tokenize a sentence manually, using for checking long sentences and PHOBert.

    Returns (pieces, ids_with_endpoints): the string token pieces and the
    id sequence wrapped in bos/eos.
    """
    # PhoBERT expects "_" between the syllables of a word, so spaces and
    # non-breaking spaces become "_"; other models just normalize \xa0
    if model_name.startswith("vinai/phobert"):
        words = [word.replace("\xa0", "_").replace(" ", "_") for word in sent]
    else:
        words = [word.replace("\xa0", " ") for word in sent]

    # join into one string and let the transformers tokenizer split it
    pieces = tokenizer.tokenize(' '.join(words))

    # map string pieces to vocabulary ids
    piece_ids = tokenizer.convert_tokens_to_ids(pieces)

    # wrap with the sentence start / end markers
    ids_with_endpoints = [tokenizer.bos_token_id] + piece_ids + [tokenizer.eos_token_id]

    return pieces, ids_with_endpoints
86
+
87
def filter_data(model_name, data, tokenizer = None, log_level=logging.DEBUG):
    """
    Filter out the (NER, POS) data that is too long for BERT model.
    """
    if tokenizer is None:
        tokenizer = load_tokenizer(model_name)
    #eliminate all the sentences that are too long for bert model
    kept = []
    for sent in data:
        # items may be bare words or (word, tag, ...) tuples
        words = [item if isinstance(item, str) else item[0] for item in sent]
        _, token_ids = tokenize_manual(model_name, words, tokenizer)

        # require two fewer than the max to leave room for special tokens
        if len(token_ids) <= tokenizer.model_max_length - 2:
            kept.append(sent)

    logger.log(log_level, "Eliminated %d of %d datapoints because their length is over maximum size of BERT model.", (len(data)-len(kept)), len(data))

    return kept
107
+
108
def needs_length_filter(model_name):
    """
    TODO: we were lazy and didn't implement any form of length fudging for models other than bert/roberta/electra
    """
    is_unfudged_family = 'bart' in model_name or 'xlnet' in model_name
    return is_unfudged_family or model_name.startswith("vinai/phobert")
117
+
118
def cloned_feature(feature, num_layers, detach=True):
    """
    Clone & detach the feature, keeping the last N layers (or averaging -2,-3,-4 if not specified)

    averaging 3 of the last 4 layers worked well for non-VI languages
    """
    # in most cases, need to call with features.hidden_states
    # bartpho is different - it has features.decoder_hidden_states
    # feature[2] is the same for bert, but it didn't work for
    # older versions of transformers for xlnet
    if num_layers is None:
        # sum layers -4, -3, -2 and scale by 1/4 (the historical default)
        combined = torch.stack(feature[-4:-1], axis=3).sum(axis=3) / 4
    else:
        combined = torch.stack(feature[-num_layers:], axis=3)
    if not detach:
        return combined
    return combined.clone().detach()
136
+
137
def extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Handles vi-bart. May need testing before using on other bart

    https://github.com/VinAIResearch/BARTpho

    data: list of list of string (the text tokens)
    Returns one tensor of word embeddings per sentence.
    """
    processed = [] # final product, returns the list of list of word representation

    # BARTpho expects syllables of a word joined with "_"
    sentences = [" ".join([word.replace(" ", "_") for word in sentence]) for sentence in data]
    tokenized = tokenizer(sentences, return_tensors='pt', padding=True, return_attention_mask=True)
    input_ids = tokenized['input_ids'].to(device)
    attention_mask = tokenized['attention_mask'].to(device)

    # batches of 128 sentences
    for i in range(int(math.ceil(len(sentences)/128))):
        start_sentence = i * 128
        end_sentence = min(start_sentence + 128, len(sentences))
        # slice into locals: the original code overwrote input_ids /
        # attention_mask each iteration, so every batch after the first
        # sliced an already-sliced tensor and came out empty
        batch_ids = input_ids[start_sentence:end_sentence]
        batch_mask = attention_mask[start_sentence:end_sentence]

        if detach:
            with torch.no_grad():
                features = model(batch_ids, attention_mask=batch_mask, output_hidden_states=True)
                features = cloned_feature(features.decoder_hidden_states, num_layers, detach)
        else:
            features = model(batch_ids, attention_mask=batch_mask, output_hidden_states=True)
            features = cloned_feature(features.decoder_hidden_states, num_layers, detach)

        # pair each batch's features with that batch's sentences
        # (the original zipped against the head of the full dataset)
        for feature, sentence in zip(features, data[start_sentence:end_sentence]):
            # +2 for the endpoints
            feature = feature[:len(sentence)+2]
            if not keep_endpoints:
                feature = feature[1:-1]
            processed.append(feature)

    return processed
172
+
173
def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Extract transformer embeddings using a method specifically for phobert

    Since phobert doesn't have the is_split_into_words / tokenized.word_ids(batch_index=0)
    capability, we instead look for @@ to denote a continued token.
    data: list of list of string (the text tokens)

    Returns a list of tensors, one per sentence, each with one row per
    selected word piece (plus the sentence endpoints if keep_endpoints).
    Raises TextTooLongError when a tokenized sentence exceeds the model limit.
    """
    processed = [] # final product, returns the list of list of word representation
    tokenized_sents = [] # list of sentences, each is a torch tensor with start and end token
    list_tokenized = [] # list of tokenized sentences from phobert
    for idx, sent in enumerate(data):

        tokenized, tokenized_sent = tokenize_manual(model_name, sent, tokenizer)

        # save the string pieces for the offset computation further down
        list_tokenized.append(tokenized)

        if len(tokenized_sent) > tokenizer.model_max_length:
            logger.error("Invalid size, max size: %d, got %d %s", tokenizer.model_max_length, len(tokenized_sent), data[idx])
            raise TextTooLongError(len(tokenized_sent), tokenizer.model_max_length, idx, " ".join(data[idx]))

        #add to tokenized_sents
        tokenized_sents.append(torch.tensor(tokenized_sent).detach())

        # placeholder entry; replaced wholesale once features are computed
        processed_sent = []
        processed.append(processed_sent)

    # done loading bert emb

    size = len(tokenized_sents)

    #padding the inputs
    tokenized_sents_padded = torch.nn.utils.rnn.pad_sequence(tokenized_sents,batch_first=True,padding_value=tokenizer.pad_token_id)

    features = []

    # Feed into PhoBERT 128 at a time in a batch fashion. In testing, the loop was
    # run only 1 time as the batch size for the outer model was less than that
    # (30 for conparser, for example)
    for i in range(int(math.ceil(size/128))):
        padded_input = tokenized_sents_padded[128*i:128*i+128]
        start_sentence = i * 128
        end_sentence = start_sentence + padded_input.shape[0]
        # attention covers only the real tokens of each sentence, not padding
        attention_mask = torch.zeros(end_sentence - start_sentence, padded_input.shape[1], device=device)
        for sent_idx, sent in enumerate(tokenized_sents[start_sentence:end_sentence]):
            attention_mask[sent_idx, :len(sent)] = 1
        if detach:
            with torch.no_grad():
                # TODO: is the clone().detach() necessary?
                feature = model(padded_input.clone().detach().to(device), attention_mask=attention_mask, output_hidden_states=True)
                features += cloned_feature(feature.hidden_states, num_layers, detach)
        else:
            feature = model(padded_input.to(device), attention_mask=attention_mask, output_hidden_states=True)
            features += cloned_feature(feature.hidden_states, num_layers, detach)

    assert len(features)==size
    assert len(features)==len(processed)

    #process the output
    #only take the vector of the last word piece of a word/ you can do other methods such as first word piece or averaging.
    # idx2+1 compensates for the start token at the start of a sentence
    # a piece is selected exactly when the previous piece does not end in "@@"
    # (i.e. when the previous piece finished its word)
    offsets = [[idx2+1 for idx2, _ in enumerate(list_tokenized[idx]) if (idx2 > 0 and not list_tokenized[idx][idx2-1].endswith("@@")) or (idx2==0)]
               for idx, sent in enumerate(processed)]
    if keep_endpoints:
        # [0] and [-1] grab the start and end representations as well
        offsets = [[0] + off + [-1] for off in offsets]
    processed = [feature[offset] for feature, offset in zip(features, offsets)]

    # This is a list of tensors
    # Each tensor holds the representation of a sentence extracted from phobert
    return processed
245
+
246
# Tokenizers known to silently delete certain characters or whole words
# (details per entry below).
# NOTE(review): not referenced in this chunk of the file — presumably
# consulted by callers elsewhere; verify before removing.
BAD_TOKENIZERS = ('bert-base-german-cased',
                  # the dbmdz tokenizers turn one or more types of characters into empty words
                  # for example, from PoSTWITA:
                  # ewww 󾓺 — in viaggio Roma
                  # the character which may not be rendering properly is 0xFE4FA
                  # https://github.com/dbmdz/berts/issues/48
                  'dbmdz/bert-base-german-cased',
                  'dbmdz/bert-base-italian-xxl-cased',
                  'dbmdz/bert-base-italian-cased',
                  'dbmdz/electra-base-italian-xxl-cased-discriminator',
                  # each of these (perhaps using similar tokenizers?)
                  # does not digest the script-flip-mark \u200f
                  'avichr/heBERT',
                  'onlplab/alephbert-base',
                  'imvladikon/alephbertgimmel-base-512',
                  # these indonesian models fail on a sentence in the Indonesian GSD dataset:
                  # 'Tak', 'dapat', 'disangkal', 'jika', '\u200e', 'kemenangan', ...
                  # weirdly some other indonesian models (even by the same group) don't have that problem
                  'cahya/bert-base-indonesian-1.5G',
                  'indolem/indobert-base-uncased',
                  'google/muril-base-cased',
                  'l3cube-pune/marathi-roberta')
268
+
269
def fix_blank_tokens(tokenizer, data):
    """Patch bert tokenizers with missing characters

    Some tokenizers (notably the German ones listed above) reduce soft
    hyphens and other unknown characters to nothing.  A word tokenized
    entirely into such characters is vaporized: its token list holds only
    the two special tokens, and no embedding would be produced for it.

    The workaround: any word whose tokenization has length 2 (just the
    special tokens) is replaced with a plain "-".

    Even the standard Bert / Electra tokenizers can do this for one-character
    "words", so this function is always run rather than gated per model.
    """
    patched = []
    for sentence in data:
        token_lists = tokenizer(sentence, is_split_into_words=False).input_ids
        # len(ids) == 2 means only the special tokens survived: the word vanished
        fixed = [word if len(ids) > 2 else "-" for word, ids in zip(sentence, token_lists)]
        patched.append(fixed)
    return patched
291
+
292
def extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach=True):
    """
    Extract word embeddings from an xlnet model.

    data: list of list of string (the text tokens)
    Returns one tensor per sentence, indexed by the last token piece of
    each word (plus endpoints unless keep_endpoints is False).
    Raises TextTooLongError when a sentence exceeds the model limit.
    """
    # using attention masks makes contextual embeddings much more useful for downstream tasks
    tokenized = tokenizer(data, is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=False)
    #tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)

    # one slot per word plus two for the sentence endpoints
    list_offsets = [[None] * (len(sentence)+2) for sentence in data]
    for idx in range(len(data)):
        offsets = tokenized.word_ids(batch_index=idx)
        list_offsets[idx][0] = 0
        for pos, offset in enumerate(offsets):
            if offset is None:
                break
            # this uses the last token piece for any offset by overwriting the previous value
            # this will be one token earlier
            # we will add a <pad> to the start of each sentence for the endpoints
            list_offsets[idx][offset+1] = pos + 1
        list_offsets[idx][-1] = list_offsets[idx][-2] + 1
        # a None left over means the tokenizer vaporized a word
        if any(x is None for x in list_offsets[idx]):
            raise ValueError("OOPS, hit None when preparing to use Bert\ndata[idx]: {}\noffsets: {}\nlist_offsets[idx]: {}".format(data[idx], offsets, list_offsets[idx], tokenized))

        if len(offsets) > tokenizer.model_max_length - 2:
            logger.error("Invalid size, max size: %d, got %d %s", tokenizer.model_max_length, len(offsets), data[idx])
            raise TextTooLongError(len(offsets), tokenizer.model_max_length, idx, " ".join(data[idx]))

    features = []
    # process 128 sentences at a time
    for i in range(int(math.ceil(len(data)/128))):
        # TODO: find a suitable representation for attention masks for xlnet
        # xlnet base on WSJ:
        # sep_token_id at beginning, cls_token_id at end: 0.9441
        # bos_token_id at beginning, eos_token_id at end: 0.9463
        # bos_token_id at beginning, sep_token_id at end: 0.9459
        # bos_token_id at beginning, cls_token_id at end: 0.9457
        # bos_token_id at beginning, sep/cls at end: 0.9454
        # use the xlnet tokenization with words at end,
        # begin token is last pad, end token is sep, no mask: 0.9463
        # same, but with masks: 0.9440
        # x[:-2] drops the tokenizer's own trailing special tokens,
        # replaced here with an explicit bos ... eos wrapping
        input_ids = [[tokenizer.bos_token_id] + x[:-2] + [tokenizer.eos_token_id] for x in tokenized['input_ids'][128*i:128*i+128]]
        max_len = max(len(x) for x in input_ids)
        attention_mask = torch.zeros(len(input_ids), max_len, dtype=torch.long, device=device)
        # mark the real tokens, then right-pad each row in place to max_len
        for idx, input_row in enumerate(input_ids):
            attention_mask[idx, :len(input_row)] = 1
            if len(input_row) < max_len:
                input_row.extend([tokenizer.pad_token_id] * (max_len - len(input_row)))
        if detach:
            with torch.no_grad():
                id_tensor = torch.tensor(input_ids, device=device)
                feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
                # feature[2] is the same for bert, but it didn't work for
                # older versions of transformers for xlnet
                # feature = feature[2]
                features += cloned_feature(feature.hidden_states, num_layers, detach)
        else:
            id_tensor = torch.tensor(input_ids, device=device)
            feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
            # feature[2] is the same for bert, but it didn't work for
            # older versions of transformers for xlnet
            # feature = feature[2]
            features += cloned_feature(feature.hidden_states, num_layers, detach)

    processed = []
    #process the output
    if not keep_endpoints:
        #remove the bos and eos tokens
        list_offsets = [sent[1:-1] for sent in list_offsets]
    for feature, offsets in zip(features, list_offsets):
        new_sent = feature[offsets]
        processed.append(new_sent)

    return processed
361
+
362
def build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device):
    """
    Extract an embedding from the given transformer for a certain attention mask and tokens range

    In the event that the tokens are longer than the max length
    supported by the model, the range is split up into overlapping
    sections and the overlapping pieces are connected. No idea if
    this is actually any good, but at least it returns something
    instead of horribly failing

    TODO: at least two upgrades are very relevant
    1) cut off some overlap at the end as well
    2) use this on the phobert, bart, and xln versions as well
    """
    # common case: everything fits in one forward pass
    if attention_tensor.shape[1] <= tokenizer.model_max_length:
        features = model(id_tensor, attention_mask=attention_tensor, output_hidden_states=True)
        features = cloned_feature(features.hidden_states, num_layers, detach)
        return features

    slices = []
    # each window advances by slice_len; prefix_len tokens of context
    # overlap with the previous window and are discarded from its output
    slice_len = max(tokenizer.model_max_length - 20, tokenizer.model_max_length // 2)
    prefix_len = tokenizer.model_max_length - slice_len
    if slice_len < 5:
        raise RuntimeError("Really tiny tokenizer!")
    remaining_attention = attention_tensor
    remaining_ids = id_tensor
    while True:
        attention_slice = remaining_attention[:, :tokenizer.model_max_length]
        id_slice = remaining_ids[:, :tokenizer.model_max_length]
        features = model(id_slice, attention_mask=attention_slice, output_hidden_states=True)
        features = cloned_feature(features.hidden_states, num_layers, detach)
        if len(slices) > 0:
            # drop the overlapping context tokens already covered by the previous slice
            features = features[:, prefix_len:, :]
        slices.append(features)
        if remaining_attention.shape[1] <= tokenizer.model_max_length:
            break
        remaining_attention = remaining_attention[:, slice_len:]
        remaining_ids = remaining_ids[:, slice_len:]
    # stitch the windows back together along the token dimension
    slices = torch.cat(slices, axis=1)
    return slices
402
+
403
+
404
def convert_to_position_list(sentence, offsets):
    """
    Convert a transformers-tokenized sentence's offsets to a list of word to position

    offsets: the word_ids() output — one word index (or None) per token piece.
    Returns a list of length len(sentence) + 2: slot 0 is the start
    endpoint, slot -1 the end endpoint, and each middle slot is the
    position of the word's final token piece (None if the word vanished).
    """
    # slot 0 and slot -1 are reserved for the sentence endpoints
    positions = [None] * (len(sentence) + 2)
    for piece_pos, word_idx in enumerate(offsets):
        if word_idx is not None:
            # later pieces of the same word overwrite earlier ones,
            # so each word ends up mapped to its last token piece
            positions[word_idx + 1] = piece_pos
    positions[0] = 0
    # walk backwards to find the last word that survived tokenization;
    # the end marker goes one position past it (backwards because the
    # final word may have been erased by the tokenizer — slot 0 is always
    # set, so this loop terminates)
    for candidate in reversed(positions[:-1]):
        if candidate is not None:
            positions[-1] = candidate + 1
            break
    return positions
425
+
426
def extract_base_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach):
    """
    Extract embeddings from a generic bert/roberta/electra style model.

    data: list of list of string (the text tokens)
    Returns one tensor of word embeddings per sentence, using the last
    token piece of each word (endpoints kept only if keep_endpoints).
    """
    #add add_prefix_space = True for RoBerTa-- error if not
    # using attention masks makes contextual embeddings much more useful for downstream tasks
    tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)
    list_offsets = []
    for idx in range(len(data)):
        converted_offsets = convert_to_position_list(data[idx], tokenized.word_ids(batch_index=idx))
        list_offsets.append(converted_offsets)

    if any(any(x is None for x in converted_offsets) for converted_offsets in list_offsets):
        # at least one of the tokens in the data is composed entirely of characters the tokenizer doesn't know about
        # one possible approach would be to retokenize only those sentences
        # however, in that case the attention mask might be of a different length,
        # as would the token ids, and it would be a pain to fix those
        # easiest to just retokenize the whole thing, hopefully a rare event
        data = fix_blank_tokens(tokenizer, data)

        tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)
        list_offsets = []
        for idx in range(len(data)):
            converted_offsets = convert_to_position_list(data[idx], tokenized.word_ids(batch_index=idx))
            list_offsets.append(converted_offsets)

        # bug fix: the original raise referenced an undefined name (offsets),
        # so reaching this branch produced a NameError instead of the
        # intended diagnostic; report the first offending sentence instead
        for idx, converted_offsets in enumerate(list_offsets):
            if any(x is None for x in converted_offsets):
                raise ValueError("OOPS, hit None when preparing to use Bert\ndata[idx]: {}\nlist_offsets[idx]: {}".format(data[idx], converted_offsets))

    features = []
    # process 128 sentences at a time
    for i in range(int(math.ceil(len(data)/128))):
        attention_tensor = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)
        id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)
        if detach:
            with torch.no_grad():
                features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)
        else:
            features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)

    processed = []
    #process the output
    if not keep_endpoints:
        #remove the bos and eos tokens
        list_offsets = [sent[1:-1] for sent in list_offsets]
    for feature, offsets in zip(features, list_offsets):
        new_sent = feature[offsets]
        processed.append(new_sent)

    return processed
477
+
478
def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers=None, detach=True, peft_name=None):
    """
    Extract transformer embeddings using a generic roberta extraction

    data: list of list of string (the text tokens)
    num_layers: how many to return. If None, the average of -2, -3, -4 is returned
    peft_name: if given, that peft adapter is activated before extraction;
        if None, any loaded adapters are disabled first

    Dispatches to a model-family specific extractor (phobert / bart /
    xlnet) and falls back to the generic bert-style extraction.
    """
    # TODO: can maybe cache this value for a model and save some time
    # TODO: too bad it isn't thread safe, but then again, who does?
    if peft_name is None:
        # NOTE(review): reads a private HF attribute to check whether any
        # peft adapters were loaded — may break on transformers upgrades
        if model._hf_peft_config_loaded:
            model.disable_adapters()
    else:
        model.enable_adapters()
        model.set_adapter(peft_name)

    # phobert has no word_ids() support, so it gets a dedicated path
    if model_name.startswith("vinai/phobert"):
        return extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    if 'bart' in model_name:
        # this should work with "vinai/bartpho-word"
        # not sure this works with any other Bart
        return extract_bart_word_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    # presumably the downstream tokenizer call wants a list — TODO confirm
    if isinstance(data, tuple):
        data = list(data)

    if "xlnet" in model_name:
        return extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)

    return extract_base_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers, detach)
+
stanza/stanza/models/common/biaffine.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class PairwiseBilinear(nn.Module):
    ''' A bilinear module that deals with broadcasting for efficient memory usage.
    Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)
    Output: tensor of size (N x L1 x L2 x O)'''
    def __init__(self, input1_size, input2_size, output_size, bias=True):
        super().__init__()

        self.input1_size = input1_size
        self.input2_size = input2_size
        self.output_size = output_size

        # NOTE(review): torch.Tensor(...) leaves the weight UNINITIALIZED;
        # PairwiseBiaffineScorer below zeroes it after construction, so any
        # standalone user must initialize it themselves
        self.weight = nn.Parameter(torch.Tensor(input1_size, input2_size, output_size))
        # when bias=False this is the integer 0, not a Parameter
        # NOTE(review): self.bias is never applied in forward() — confirm intended
        self.bias = nn.Parameter(torch.Tensor(output_size)) if bias else 0

    def forward(self, input1, input2):
        input1_size = list(input1.size())
        input2_size = list(input2.size())
        # documents the final shape; not otherwise used below
        output_size = [input1_size[0], input1_size[1], input2_size[1], self.output_size]

        # ((N x L1) x D1) * (D1 x (D2 x O)) -> (N x L1) x (D2 x O)
        intermediate = torch.mm(input1.view(-1, input1_size[-1]), self.weight.view(-1, self.input2_size * self.output_size))
        # (N x L2 x D2) -> (N x D2 x L2)
        input2 = input2.transpose(1, 2)
        # (N x (L1 x O) x D2) * (N x D2 x L2) -> (N x (L1 x O) x L2)
        output = intermediate.view(input1_size[0], input1_size[1] * self.output_size, input2_size[2]).bmm(input2)
        # (N x (L1 x O) x L2) -> (N x L1 x L2 x O)
        output = output.view(input1_size[0], input1_size[1], self.output_size, input2_size[1]).transpose(2, 3)

        return output
34
+
35
class BiaffineScorer(nn.Module):
    """Biaffine scorer: a bilinear map over two inputs, each augmented with a constant 1 feature."""
    def __init__(self, input1_size, input2_size, output_size):
        super().__init__()
        # +1 on each side accounts for the appended constant feature
        self.W_bilin = nn.Bilinear(input1_size + 1, input2_size + 1, output_size)

        # start from an all-zero transform
        self.W_bilin.weight.data.zero_()
        self.W_bilin.bias.data.zero_()

    def forward(self, input1, input2):
        # append a column of ones along the last dimension of each input
        ones1 = input1.new_ones(*input1.size()[:-1], 1)
        ones2 = input2.new_ones(*input2.size()[:-1], 1)
        augmented1 = torch.cat([input1, ones1], len(input1.size()) - 1)
        augmented2 = torch.cat([input2, ones2], len(input2.size()) - 1)
        return self.W_bilin(augmented1, augmented2)
47
+
48
class PairwiseBiaffineScorer(nn.Module):
    """Biaffine scoring over all pairs of positions, built on PairwiseBilinear."""
    def __init__(self, input1_size, input2_size, output_size):
        super().__init__()
        # +1 for the constant bias feature appended to each input
        self.W_bilin = PairwiseBilinear(input1_size + 1, input2_size + 1, output_size)

        # PairwiseBilinear leaves its weight uninitialized; zero it here
        self.W_bilin.weight.data.zero_()
        self.W_bilin.bias.data.zero_()

    def forward(self, input1, input2):
        # append a constant 1 feature along the last dimension of each input
        last1 = len(input1.size()) - 1
        last2 = len(input2.size()) - 1
        augmented1 = torch.cat([input1, input1.new_ones(*input1.size()[:-1], 1)], last1)
        augmented2 = torch.cat([input2, input2.new_ones(*input2.size()[:-1], 1)], last2)
        return self.W_bilin(augmented1, augmented2)
60
+
61
class DeepBiaffineScorer(nn.Module):
    """Projects each input through a one-layer MLP, then scores with a (pairwise) biaffine layer."""
    def __init__(self, input1_size, input2_size, hidden_size, output_size, hidden_func=F.relu, dropout=0, pairwise=True):
        super().__init__()
        self.W1 = nn.Linear(input1_size, hidden_size)
        self.W2 = nn.Linear(input2_size, hidden_size)
        self.hidden_func = hidden_func
        # pairwise scores every (position1, position2) pair; otherwise
        # the inputs are scored position-by-position
        scorer_cls = PairwiseBiaffineScorer if pairwise else BiaffineScorer
        self.scorer = scorer_cls(hidden_size, hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input1, input2):
        hidden1 = self.dropout(self.hidden_func(self.W1(input1)))
        hidden2 = self.dropout(self.hidden_func(self.W2(input2)))
        return self.scorer(hidden1, hidden2)
75
+
76
if __name__ == "__main__":
    # smoke test: score two random inputs with a freshly built scorer
    # NOTE(review): the pairwise scorer's forward indexes a third dimension,
    # but these inputs are 2-D — confirm this demo still runs as intended
    x1 = torch.randn(3,4)
    x2 = torch.randn(3,5)
    scorer = DeepBiaffineScorer(4, 5, 6, 7)
    print(scorer(x1, x2))
stanza/stanza/models/common/build_short_name_to_treebank.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import glob
import os

from stanza.models.common.constant import treebank_to_short_name, UnknownLanguageError, treebank_special_cases
from stanza.utils import default_paths

# Autogenerates short_name_to_treebank.py from the UD treebank directories
# found under $UDBASE.  Runs as a flat script when executed.
paths = default_paths.get_default_paths()
udbase = paths["UDBASE"]

directories = glob.glob(udbase + "/UD_*")
directories.sort()

output_name = os.path.join(os.path.split(__file__)[0], "short_name_to_treebank.py")
ud_names = [os.path.split(ud_path)[1] for ud_path in directories]
short_names = []

# check that all languages are known in the language map
# use that language map to come up with a shortname for these treebanks
for directory, ud_name in zip(directories, ud_names):
    try:
        short_names.append(treebank_to_short_name(ud_name))
    except UnknownLanguageError as e:
        raise UnknownLanguageError("Could not find language short name for dataset %s, path %s" % (ud_name, directory)) from e

# Norwegian and Chinese treebanks are ambiguous (NN vs NB; simplified vs
# traditional) and must be resolved by hand in treebank_special_cases
for directory, ud_name in zip(directories, ud_names):
    if ud_name.startswith("UD_Norwegian"):
        if ud_name not in treebank_special_cases:
            raise ValueError("Please figure out if dataset %s is NN or NB, then add to treebank_special_cases" % ud_name)
    if ud_name.startswith("UD_Chinese"):
        if ud_name not in treebank_special_cases:
            # bug fix: this message used to be a copy-paste of the Norwegian
            # "NN or NB" message, which makes no sense for Chinese
            raise ValueError("Please figure out if dataset %s is simplified or traditional Chinese, then add to treebank_special_cases" % ud_name)

# column-align the generated dict entries for readability
max_len = max(len(x) for x in short_names) + 8
line_format = " %-" + str(max_len) + "s '%s',\n"


print("Writing to %s" % output_name)
with open(output_name, "w") as fout:
    fout.write("# This module is autogenerated by build_short_name_to_treebank.py\n")
    fout.write("# Please do not edit\n")
    fout.write("\n")
    fout.write("SHORT_NAMES = {\n")
    for short_name, ud_name in zip(short_names, ud_names):
        fout.write(line_format % ("'" + short_name + "':", ud_name))

        # also register alternate language codes which map to the same treebank:
        # zh <-> zh-hans/zh-hant and nb_bokmaal <-> no_bokmaal
        if short_name.startswith("zh_"):
            short_name = "zh-hans_" + short_name[3:]
            fout.write(line_format % ("'" + short_name + "':", ud_name))
        elif short_name.startswith("zh-hans_") or short_name.startswith("zh-hant_"):
            short_name = "zh_" + short_name[8:]
            fout.write(line_format % ("'" + short_name + "':", ud_name))
        elif short_name == 'nb_bokmaal':
            short_name = 'no_bokmaal'
            fout.write(line_format % ("'" + short_name + "':", ud_name))

    fout.write("}\n")

    fout.write("""

def short_name_to_treebank(short_name):
    return SHORT_NAMES[short_name]


""")

    max_len = max(len(x) for x in ud_names) + 5
    line_format = " %-" + str(max_len) + "s '%s',\n"
    fout.write("CANONICAL_NAMES = {\n")
    for ud_name in ud_names:
        fout.write(line_format % ("'" + ud_name.lower() + "':", ud_name))
    fout.write("}\n")
    fout.write("""

def canonical_treebank_name(ud_name):
    if ud_name in SHORT_NAMES:
        return SHORT_NAMES[ud_name]
    return CANONICAL_NAMES.get(ud_name.lower(), ud_name)
""")
stanza/stanza/models/common/char_model.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Based on
3
+
4
+ @inproceedings{akbik-etal-2018-contextual,
5
+ title = "Contextual String Embeddings for Sequence Labeling",
6
+ author = "Akbik, Alan and
7
+ Blythe, Duncan and
8
+ Vollgraf, Roland",
9
+ booktitle = "Proceedings of the 27th International Conference on Computational Linguistics",
10
+ month = aug,
11
+ year = "2018",
12
+ address = "Santa Fe, New Mexico, USA",
13
+ publisher = "Association for Computational Linguistics",
14
+ url = "https://aclanthology.org/C18-1139",
15
+ pages = "1638--1649",
16
+ }
17
+ """
18
+
19
+ from collections import Counter
20
+ from operator import itemgetter
21
+ import os
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, pack_padded_sequence, PackedSequence
26
+
27
+ from stanza.models.common.data import get_long_tensor
28
+ from stanza.models.common.packed_lstm import PackedLSTM
29
+ from stanza.models.common.utils import open_read_text, tensor_unsort, unsort
30
+ from stanza.models.common.dropout import SequenceUnitDropout
31
+ from stanza.models.common.vocab import UNK_ID, CharVocab
32
+
33
class CharacterModel(nn.Module):
    """Character LSTM producing one vector per word.

    Characters of each word are embedded and run through a (possibly
    bidirectional) LSTM; the word representation is either an
    attention-weighted sum over the character states or the final hidden
    states, depending on `attention`.
    """
    def __init__(self, args, vocab, pad=False, bidirectional=False, attention=True):
        super().__init__()
        self.args = args
        self.pad = pad          # if True, forward() returns a padded tensor rather than a PackedSequence
        self.num_dir = 2 if bidirectional else 1
        self.attn = attention

        # char embeddings
        self.char_emb = nn.Embedding(len(vocab['char']), self.args['char_emb_dim'], padding_idx=0)
        if self.attn:
            # one scalar attention score per character state; zero-initialized
            # so attention starts out uniform (sigmoid(0) everywhere)
            self.char_attn = nn.Linear(self.num_dir * self.args['char_hidden_dim'], 1, bias=False)
            self.char_attn.weight.data.zero_()

        # modules
        self.charlstm = PackedLSTM(self.args['char_emb_dim'], self.args['char_hidden_dim'], self.args['char_num_layers'], batch_first=True, \
                dropout=0 if self.args['char_num_layers'] == 1 else args['dropout'], rec_dropout = self.args['char_rec_dropout'], bidirectional=bidirectional)
        # learned initial LSTM states, expanded to the batch size in forward()
        self.charlstm_h_init = nn.Parameter(torch.zeros(self.num_dir * self.args['char_num_layers'], 1, self.args['char_hidden_dim']))
        self.charlstm_c_init = nn.Parameter(torch.zeros(self.num_dir * self.args['char_num_layers'], 1, self.args['char_hidden_dim']))

        self.dropout = nn.Dropout(args['dropout'])

    def forward(self, chars, chars_mask, word_orig_idx, sentlens, wordlens):
        """Return per-word representations grouped by sentence.

        chars: padded char id tensor, one row per word, words sorted by
            decreasing length (required by pack_padded_sequence) —
            TODO confirm against callers
        chars_mask: unused here
        word_orig_idx: permutation used to restore the original word order
        sentlens: number of words per sentence, used to regroup words
        wordlens: length of each word, used for packing
        """
        embs = self.dropout(self.char_emb(chars))
        batch_size = embs.size(0)
        embs = pack_padded_sequence(embs, wordlens, batch_first=True)
        output = self.charlstm(embs, wordlens, hx=(\
                self.charlstm_h_init.expand(self.num_dir * self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous(), \
                self.charlstm_c_init.expand(self.num_dir * self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous()))

        # apply attention, otherwise take final states
        if self.attn:
            # weighted sum over character states; weights computed directly
            # on the packed data to skip the padding positions
            char_reps = output[0]
            weights = torch.sigmoid(self.char_attn(self.dropout(char_reps.data)))
            char_reps = PackedSequence(char_reps.data * weights, char_reps.batch_sizes)
            char_reps, _ = pad_packed_sequence(char_reps, batch_first=True)
            res = char_reps.sum(1)
        else:
            # concatenate the last layer's final states (both directions if bidirectional)
            h, c = output[1]
            res = h[-2:].transpose(0,1).contiguous().view(batch_size, -1)

        # recover character order and word separation
        res = tensor_unsort(res, word_orig_idx)
        res = pack_sequence(res.split(sentlens))
        if self.pad:
            res = pad_packed_sequence(res, batch_first=True)[0]

        return res
82
def build_charlm_vocab(path, cutoff=0):
    """
    Build a vocab for a CharacterLanguageModel

    Requires a large amount of memory, but only need to build once

    here we need some trick to deal with excessively large files
    for each file we accumulate the counter of characters, and
    at the end we simply pass a list of chars to the vocab builder
    """
    char_counts = Counter()
    if os.path.isdir(path):
        filenames = sorted(os.listdir(path))
    else:
        # a single file: split off the directory so the join below works
        path, single_file = os.path.split(path)
        filenames = [single_file]

    for filename in filenames:
        full_path = os.path.join(path, filename)
        with open_read_text(full_path) as fin:
            for line in fin:
                char_counts.update(line)

    if not char_counts:
        raise ValueError("Training data was empty!")
    # remove infrequent characters from vocab, then build a single sorted
    # "sentence" of all surviving characters for the vocab builder
    frequent_chars = sorted(ch for ch, count in char_counts.items() if count >= cutoff)
    if not frequent_chars:
        raise ValueError("All characters in the training data were less frequent than --cutoff!")
    return CharVocab([frequent_chars]) # skip cutoff argument because this has been dealt with
117
+
118
# Sentinel characters marking the start of a sequence and the end of a word
# for the character language model.  CHARLM_END (a space) doubles as the
# padding character when batching.
CHARLM_START = "\n"
CHARLM_END = " "
121
class CharacterLanguageModel(nn.Module):
    """Unidirectional character-level language model.

    Embeds characters, runs them through an LSTM, and predicts the next
    character with a linear decoder.  A model may be a forward or a
    backward LM (is_forward_lm); downstream models use it as a (usually
    frozen) feature extractor via get_representation /
    build_char_representation / per_char_representation.
    """

    def __init__(self, args, vocab, pad=False, is_forward_lm=True):
        super().__init__()
        self.args = args
        self.vocab = vocab
        self.is_forward_lm = is_forward_lm
        self.pad = pad                # if True, get_representation returns a padded tensor
        self.finetune = True # always finetune unless otherwise specified

        # char embeddings
        self.char_emb = nn.Embedding(len(self.vocab['char']), self.args['char_emb_dim'], padding_idx=None) # we use space as padding, so padding_idx is not necessary

        # modules
        self.charlstm = PackedLSTM(self.args['char_emb_dim'], self.args['char_hidden_dim'], self.args['char_num_layers'], batch_first=True, \
                dropout=0 if self.args['char_num_layers'] == 1 else args['char_dropout'], rec_dropout = self.args['char_rec_dropout'], bidirectional=False)
        # learned initial LSTM states, expanded to the batch size in forward()
        self.charlstm_h_init = nn.Parameter(torch.zeros(self.args['char_num_layers'], 1, self.args['char_hidden_dim']))
        self.charlstm_c_init = nn.Parameter(torch.zeros(self.args['char_num_layers'], 1, self.args['char_hidden_dim']))

        # decoder: projects hidden states to per-character logits
        self.decoder = nn.Linear(self.args['char_hidden_dim'], len(self.vocab['char']))
        self.dropout = nn.Dropout(args['char_dropout'])
        # randomly replaces input characters with UNK during training
        self.char_dropout = SequenceUnitDropout(args.get('char_unit_dropout', 0), UNK_ID)

    def forward(self, chars, charlens, hidden=None):
        """Run the LSTM over a padded batch of character ids.

        chars: padded char id tensor, rows sorted by decreasing length
        charlens: length of each row, for packing
        hidden: optional initial (h, c); defaults to the learned init states

        Returns (padded lstm outputs, final hidden state, per-position logits).
        """
        chars = self.char_dropout(chars)
        embs = self.dropout(self.char_emb(chars))
        batch_size = embs.size(0)
        embs = pack_padded_sequence(embs, charlens, batch_first=True)
        if hidden is None:
            hidden = (self.charlstm_h_init.expand(self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous(),
                      self.charlstm_c_init.expand(self.args['char_num_layers'], batch_size, self.args['char_hidden_dim']).contiguous())
        output, hidden = self.charlstm(embs, charlens, hx=hidden)
        output = self.dropout(pad_packed_sequence(output, batch_first=True)[0])
        decoded = self.decoder(output)
        return output, hidden, decoded

    def get_representation(self, chars, charoffsets, charlens, char_orig_idx):
        """Feature extraction (no gradients): LSTM states at the given
        character offsets, restored to the original sentence order."""
        with torch.no_grad():
            output, _, _ = self.forward(chars, charlens)
            res = [output[i, offsets] for i, offsets in enumerate(charoffsets)]
            res = unsort(res, char_orig_idx)
            res = pack_sequence(res)
            if self.pad:
                res = pad_packed_sequence(res, batch_first=True)[0]
        return res

    def per_char_representation(self, words):
        """Return a list of tensors, one per word, with a representation for
        each character of that word (no gradients)."""
        device = next(self.parameters()).device
        vocab = self.char_vocab()

        # sort words by decreasing length for packing, remembering original positions
        all_data = [(vocab.map(word), len(word), idx) for idx, word in enumerate(words)]
        all_data.sort(key=itemgetter(1), reverse=True)
        chars = [x[0] for x in all_data]
        char_lens = [x[1] for x in all_data]
        char_tensor = get_long_tensor(chars, len(chars), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)
        with torch.no_grad():
            output, _, _ = self.forward(char_tensor, char_lens)
            # trim padding, then restore the caller's word order
            output = [x[:y, :] for x, y in zip(output, char_lens)]
            output = unsort(output, [x[2] for x in all_data])
        return output

    def build_char_representation(self, sentences):
        """
        Return values from this charlm for a list of list of words

        input: [[str]]
          K sentences, each of length Ki (can be different for each sentence)
        output: [tensor(Ki x dim)]
          list of tensors, each one with shape Ki by the dim of the character model

        Values are taken from the last character in a word for each word.
        The words are effectively treated as if they are whitespace separated
        (which may actually be somewhat inaccurate for languages such as Chinese or for MWT)
        """
        forward = self.is_forward_lm
        vocab = self.char_vocab()
        device = next(self.parameters()).device

        all_data = []
        for idx, words in enumerate(sentences):
            if not forward:
                # a backward LM reads the sentence reversed, char-reversed words included
                words = [x[::-1] for x in reversed(words)]

            chars = [CHARLM_START]
            offsets = []
            for w in words:
                chars.extend(w)
                chars.append(CHARLM_END)
                # record the index of the separator after each word:
                # its LSTM state summarizes the word just read
                offsets.append(len(chars) - 1)
            if not forward:
                # put offsets back in the original word order
                offsets.reverse()
            chars = vocab.map(chars)
            all_data.append((chars, offsets, len(chars), len(all_data)))

        # sort sentences by decreasing char length for packing
        all_data.sort(key=itemgetter(2), reverse=True)
        chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data))
        # TODO: can this be faster?
        chars = get_long_tensor(chars, len(all_data), pad_id=vocab.unit2id(CHARLM_END)).to(device=device)

        with torch.no_grad():
            output, _, _ = self.forward(chars, char_lens)
            res = [output[i, offsets] for i, offsets in enumerate(char_offsets)]
            res = unsort(res, orig_idx)

        return res

    def hidden_dim(self):
        # size of the representations produced by this model
        return self.args['char_hidden_dim']

    def char_vocab(self):
        return self.vocab['char']

    def train(self, mode=True):
        """
        Override the default train() function, so that when self.finetune == False, the training mode
        won't be impacted by the parent models' status change.
        """
        if not mode: # eval() is always allowed, regardless of finetune status
            super().train(mode)
        else:
            if self.finetune: # only set to training mode in finetune status
                super().train(mode)

    def full_state(self):
        """Everything needed to reconstruct this model with from_full_state()."""
        state = {
            'vocab': self.vocab['char'].state_dict(),
            'args': self.args,
            'state_dict': self.state_dict(),
            'pad': self.pad,
            'is_forward_lm': self.is_forward_lm
        }
        return state

    def save(self, filename):
        # creates the parent directory if needed
        os.makedirs(os.path.split(filename)[0], exist_ok=True)
        state = self.full_state()
        torch.save(state, filename, _use_new_zipfile_serialization=False)

    @classmethod
    def from_full_state(cls, state, finetune=False):
        """Rebuild a model from the dict produced by full_state()."""
        vocab = {'char': CharVocab.load_state_dict(state['vocab'])}
        model = cls(state['args'], vocab, state['pad'], state['is_forward_lm'])
        model.load_state_dict(state['state_dict'])
        model.eval()
        model.finetune = finetune # set finetune status
        return model

    @classmethod
    def load(cls, filename, finetune=False):
        """Load a saved charlm; set finetune=True to allow further training."""
        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        # allow saving just the Model object,
        # and allow for old charlms to still work
        if 'state_dict' in state:
            return cls.from_full_state(state, finetune)
        return cls.from_full_state(state['model'], finetune)
277
+
278
class CharacterLanguageModelWordAdapter(nn.Module):
    """
    Adapts a character model to return embeddings for each character in a word
    """
    def __init__(self, charlms):
        super().__init__()
        self.charlms = charlms

    def forward(self, words):
        # wrap each word with the charlm's sentinel start/end characters
        wrapped = [CHARLM_START + word + CHARLM_END for word in words]
        all_padded = []
        for charlm in self.charlms:
            reps = charlm.per_char_representation(wrapped)
            longest = max(rep.shape[0] for rep in reps)
            # pad each word's per-char representation out to the longest word
            padded = torch.zeros(len(reps), longest, reps[0].shape[1],
                                 dtype=reps[0].dtype, device=reps[0].device)
            for word_idx, rep in enumerate(reps):
                padded[word_idx, :rep.shape[0], :] = rep
            all_padded.append(padded)
        # stack the charlms' outputs along the feature axis
        return torch.cat(all_padded, dim=2)

    def hidden_dim(self):
        return sum(charlm.hidden_dim() for charlm in self.charlms)
300
+
301
class CharacterLanguageModelTrainer():
    """Bundles a CharacterLanguageModel with its optimizer, loss, and
    scheduler, plus enough bookkeeping (epoch, global_step) to resume
    training from a checkpoint."""
    def __init__(self, model, params, optimizer, criterion, scheduler, epoch=1, global_step=0):
        self.model = model
        self.params = params          # trainable parameters handed to the optimizer
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.epoch = epoch
        self.global_step = global_step

    def save(self, filename, full=True):
        """Save model state; with full=True also save optimizer/criterion/scheduler
        state so training can be resumed."""
        os.makedirs(os.path.split(filename)[0], exist_ok=True)
        state = {
            'model': self.model.full_state(),
            'epoch': self.epoch,
            'global_step': self.global_step,
        }
        if full and self.optimizer is not None:
            state['optimizer'] = self.optimizer.state_dict()
        if full and self.criterion is not None:
            state['criterion'] = self.criterion.state_dict()
        if full and self.scheduler is not None:
            state['scheduler'] = self.scheduler.state_dict()
        torch.save(state, filename, _use_new_zipfile_serialization=False)

    @classmethod
    def from_new_model(cls, args, vocab):
        """Build a fresh trainer: new model, SGD optimizer, cross-entropy loss,
        and a reduce-on-plateau LR scheduler."""
        model = CharacterLanguageModel(args, vocab, is_forward_lm=True if args['direction'] == 'forward' else False)
        model = model.to(args['device'])
        params = [param for param in model.parameters() if param.requires_grad]
        optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
        criterion = torch.nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=args['anneal'], patience=args['patience'])
        return cls(model, params, optimizer, criterion, scheduler)


    @classmethod
    def load(cls, args, filename, finetune=False):
        """
        Load the model along with any other saved state for training

        Note that you MUST set finetune=True if planning to continue training
        Otherwise the only benefit you will get will be a warm GPU
        """
        state = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        model = CharacterLanguageModel.from_full_state(state['model'], finetune)
        model = model.to(args['device'])

        # optimizer/criterion/scheduler state is optional in the checkpoint
        # (save(full=False) omits it); fresh instances are used when missing
        params = [param for param in model.parameters() if param.requires_grad]
        optimizer = torch.optim.SGD(params, lr=args['lr0'], momentum=args['momentum'], weight_decay=args['weight_decay'])
        if 'optimizer' in state: optimizer.load_state_dict(state['optimizer'])

        criterion = torch.nn.CrossEntropyLoss()
        if 'criterion' in state: criterion.load_state_dict(state['criterion'])

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, factor=args['anneal'], patience=args['patience'])
        if 'scheduler' in state: scheduler.load_state_dict(state['scheduler'])

        epoch = state.get('epoch', 1)
        global_step = state.get('global_step', 0)
        return cls(model, params, optimizer, criterion, scheduler, epoch, global_step)
362
+
stanza/stanza/models/common/chuliu_edmonds.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from Tim's code here: https://github.com/tdozat/Parser-v3/blob/master/scripts/chuliu_edmonds.py
2
+
3
+ import numpy as np
4
+
5
def tarjan(tree):
    """Finds the cycles in a dependency graph

    The input should be a numpy array of integers,
    where in the standard use case,
    tree[i] is the head of node i.

    tree[0] == 0 to represent the root

    so for example, for the English sentence "This is a test",
    the input is

    [0 4 4 4 0]

    "Arthritis makes my hip hurt"

    [0 2 0 4 2 2]

    The return is a list of cycles, where in cycle has True if the
    node at that index is participating in the cycle.
    So, for example, the previous examples both return empty lists,
    whereas an input of
    np.array([0, 3, 1, 2])
    has an output of
    [np.array([False, True, True, True])]
    """
    # standard Tarjan SCC bookkeeping: discovery index and lowest reachable
    # index per node (-1 == not yet visited), plus an explicit node stack
    indices = -np.ones_like(tree)
    lowlinks = -np.ones_like(tree)
    onstack = np.zeros_like(tree, dtype=bool)
    stack = list()
    _index = [0]   # mutable counter shared with the nested functions
    cycles = []
    #-------------------------------------------------------------
    def maybe_pop_cycle(i):
        # an SCC root: pop its component off the stack; only components
        # with more than one node count as cycles
        if lowlinks[i] == indices[i]:
            # There's a cycle!
            cycle = np.zeros_like(indices, dtype=bool)
            while stack[-1] != i:
                j = stack.pop()
                onstack[j] = False
                cycle[j] = True
            stack.pop()
            onstack[i] = False
            cycle[i] = True
            if cycle.sum() > 1:
                cycles.append(cycle)

    def initialize_strong_connect(i):
        # assign the next discovery index to node i and push it
        _index[0] += 1
        index = _index[-1]
        indices[i] = lowlinks[i] = index - 1
        stack.append(i)
        onstack[i] = True

    def strong_connect(i):
        # this ridiculous atrocity is because somehow people keep
        # coming up with graphs which overflow python's call stack
        # so instead we make our own call stack and turn the recursion
        # into a loop
        # see for example
        # https://github.com/stanfordnlp/stanza/issues/962
        # https://github.com/spraakbanken/sparv-pipeline/issues/166
        # in an ideal world this block of code would look like this
        #  initialize_strong_connect(i)
        #  dependents = iter(np.where(np.equal(tree, i))[0])
        #  for j in dependents:
        #      if indices[j] == -1:
        #          strong_connect(j)
        #          lowlinks[i] = min(lowlinks[i], lowlinks[j])
        #      elif onstack[j]:
        #          lowlinks[i] = min(lowlinks[i], indices[j])
        #
        #  maybe_pop_cycle(i)
        # each frame is (node, its partially-consumed dependents iterator,
        # the dependent whose recursive "call" just returned)
        call_stack = [(i, None, None)]
        while len(call_stack) > 0:
            i, dependents_iterator, j = call_stack.pop()
            if dependents_iterator is None: # first time getting here for this i
                initialize_strong_connect(i)
                dependents_iterator = iter(np.where(np.equal(tree, i))[0])
            else: # been here before. j was the dependent we were just considering
                lowlinks[i] = min(lowlinks[i], lowlinks[j])
            for j in dependents_iterator:
                if indices[j] == -1:
                    # have to remember where we were...
                    # put the current iterator & its state on the "call stack"
                    # we will come back to it later
                    call_stack.append((i, dependents_iterator, j))
                    # also, this is what we do next...
                    call_stack.append((j, None, None))
                    # this will break this iterator for now
                    # the next time through, we will continue progressing this iterator
                    break
                elif onstack[j]:
                    lowlinks[i] = min(lowlinks[i], indices[j])
            else:
                # this is an intended use of for/else
                # please stop filing git issues on obscure language features
                # we finished iterating without a break
                # and can finally resolve any possible cycles
                maybe_pop_cycle(i)
            # at this point, there are two cases:
            #
            # we iterated all the way through an iterator (the else in the for/else)
            #   and have resolved any possible cycles. can then proceed to the previous
            #   iterator we were considering (or finish, if there are no others)
            # OR
            # we have hit a break in the iteration over the dependents
            #   for a node
            #   and we need to dig deeper into the graph and resolve the dependent's dependents
            #   before we can continue the previous node
            #
            # either way, we check to see if there are unfinished subtrees
            # when that is finally done, we can return

    #-------------------------------------------------------------
    # run the SCC search from every unvisited node
    for i in range(len(tree)):
        if indices[i] == -1:
            strong_connect(i)
    return cycles
124
+
125
def process_cycle(tree, cycle, scores):
    """
    Build a subproblem with one cycle broken

    tree: current head assignment per node
    cycle: boolean mask over nodes marking the cycle to contract
    scores: scores[i, j] = score of j heading i

    Returns the contracted score matrix (the cycle collapsed into one extra
    metanode in the last row/column) plus the index bookkeeping needed by
    expand_contracted_tree to undo the contraction.
    """
    # indices of cycle in original tree; (c) in t
    cycle_locs = np.where(cycle)[0]
    # heads of cycle in original tree; (c) in t
    cycle_subtree = tree[cycle]
    # scores of cycle in original tree; (c) in R
    cycle_scores = scores[cycle, cycle_subtree]
    # total score of cycle; () in R
    cycle_score = cycle_scores.sum()

    # locations of noncycle; (t) in [0,1]
    noncycle = np.logical_not(cycle)
    # indices of noncycle in original tree; (n) in t
    noncycle_locs = np.where(noncycle)[0]
    #print(cycle_locs, noncycle_locs)

    # scores of cycle's potential heads; (c x n) - (c) + () -> (n x c) in R
    metanode_head_scores = scores[cycle][:,noncycle] - cycle_scores[:,None] + cycle_score
    # scores of cycle's potential dependents; (n x c) in R
    metanode_dep_scores = scores[noncycle][:,cycle]
    # best noncycle head for each cycle dependent; (n) in c
    metanode_heads = np.argmax(metanode_head_scores, axis=0)
    # best cycle head for each noncycle dependent; (n) in c
    metanode_deps = np.argmax(metanode_dep_scores, axis=1)

    # scores of noncycle graph; (n x n) in R
    subscores = scores[noncycle][:,noncycle]
    # pad to contracted graph; (n+1 x n+1) in R
    subscores = np.pad(subscores, ( (0,1) , (0,1) ), 'constant')
    # set the contracted graph scores of cycle's potential heads; (c x n)[:, (n) in n] in R -> (n) in R
    subscores[-1, :-1] = metanode_head_scores[metanode_heads, np.arange(len(noncycle_locs))]
    # set the contracted graph scores of cycle's potential dependents; (n x c)[(n) in n] in R-> (n) in R
    subscores[:-1,-1] = metanode_dep_scores[np.arange(len(noncycle_locs)), metanode_deps]
    return subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps
162
+
163
+
164
def expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps):
    """
    Given a partially solved tree with a cycle and a solved subproblem
    for the cycle, build a larger solution without the cycle

    The bookkeeping arguments are exactly those returned by process_cycle;
    contracted_tree is the MST solution of the contracted graph, whose last
    node is the metanode standing in for the whole cycle.
    """
    # head of the cycle; () in n
    #print(contracted_tree)
    cycle_head = contracted_tree[-1]
    # fixed tree: (n) in n+1
    contracted_tree = contracted_tree[:-1]
    # initialize new tree; (t) in 0
    new_tree = -np.ones_like(tree)
    #print(0, new_tree)
    # fixed tree with no heads coming from the cycle: (n) in [0,1]
    contracted_subtree = contracted_tree < len(contracted_tree)
    # add the nodes to the new tree (t)[(n)[(n) in [0,1]] in t] in t = (n)[(n)[(n) in [0,1]] in n] in t
    new_tree[noncycle_locs[contracted_subtree]] = noncycle_locs[contracted_tree[contracted_subtree]]
    #print(1, new_tree)
    # fixed tree with heads coming from the cycle: (n) in [0,1]
    contracted_subtree = np.logical_not(contracted_subtree)
    # add the nodes to the tree (t)[(n)[(n) in [0,1]] in t] in t = (c)[(n)[(n) in [0,1]] in c] in t
    new_tree[noncycle_locs[contracted_subtree]] = cycle_locs[metanode_deps[contracted_subtree]]
    #print(2, new_tree)
    # add the old cycle to the tree; (t)[(c) in t] in t = (t)[(c) in t] in t
    new_tree[cycle_locs] = tree[cycle_locs]
    #print(3, new_tree)
    # root of the cycle; (n)[() in n] in c = () in c
    cycle_root = metanode_heads[cycle_head]
    # add the root of the cycle to the new tree; (t)[(c)[() in c] in t] = (c)[() in c]
    # (this is the edge that breaks the cycle)
    new_tree[cycle_locs[cycle_root]] = noncycle_locs[cycle_head]
    #print(4, new_tree)
    return new_tree
196
+
197
def prepare_scores(scores):
    """Modify `scores` in place: forbid self-loops and pin node 0 as the root.

    All of row 0 (candidate heads for the root) is disallowed except the
    0->0 self-arc, which anchors the root with score 0.
    """
    neg_inf = -float('inf')
    np.fill_diagonal(scores, neg_inf)   # no node may head itself
    scores[0, :] = neg_inf              # root takes no head...
    scores[0, 0] = 0                    # ...except itself, at zero cost
205
+
206
def chuliu_edmonds(scores):
    """Chu-Liu/Edmonds maximum spanning arborescence.

    scores[i, j] is the score of choosing j as the head of i; scores is
    modified in place by prepare_scores.  Returns an array `tree` where
    tree[i] is the chosen head of node i and node 0 heads itself (root).

    Repeatedly takes the greedy argmax tree, contracts one cycle at a time
    into a metanode (process_cycle), and unrolls the contractions in
    reverse order (expand_contracted_tree).
    """
    subtree_stack = []

    prepare_scores(scores)
    tree = np.argmax(scores, axis=1)
    cycles = tarjan(tree)

    #print(scores)
    #print(cycles)

    # recursive implementation:
    #if cycles:
    #    # t = len(tree); c = len(cycle); n = len(noncycle)
    #    # cycles.pop(): locations of cycle; (t) in [0,1]
    #    subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps = process_cycle(tree, cycles.pop(), scores)
    #    # MST with contraction; (n+1) in n+1
    #    contracted_tree = chuliu_edmonds(subscores)
    #    tree = expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps)
    # unfortunately, while the recursion is simpler to understand, it can get too deep for python's stack limit
    # so instead we make our own recursion, with blackjack and (you know how it goes)

    # contraction phase: each iteration contracts one cycle and pushes the
    # state needed to undo it later
    while cycles:
        # t = len(tree); c = len(cycle); n = len(noncycle)
        # cycles.pop(): locations of cycle; (t) in [0,1]
        subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps = process_cycle(tree, cycles.pop(), scores)
        subtree_stack.append((tree, cycles, scores, subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps))

        scores = subscores
        prepare_scores(scores)
        tree = np.argmax(scores, axis=1)
        cycles = tarjan(tree)

    # expansion phase: unroll the contractions in LIFO order
    while len(subtree_stack) > 0:
        contracted_tree = tree
        (tree, cycles, scores, subscores, cycle_locs, noncycle_locs, metanode_heads, metanode_deps) = subtree_stack.pop()
        tree = expand_contracted_tree(tree, contracted_tree, cycle_locs, noncycle_locs, metanode_heads, metanode_deps)

    return tree
244
+
245
+ #===============================================================
246
def chuliu_edmonds_one_root(scores):
    """Chu-Liu/Edmonds constrained to exactly one root.

    scores[i, j] is the score of j heading i.  If the unconstrained MST
    already has a single child of node 0, it is returned directly;
    otherwise each candidate root is tried in turn with all other edges
    into 0 forbidden, and the best-scoring resulting tree is returned.

    Raises ValueError if no candidate root yields a tree free of -inf
    edges (including the degenerate case of no candidate roots at all).
    """

    scores = scores.astype(np.float64)
    tree = chuliu_edmonds(scores)
    roots_to_try = np.where(np.equal(tree[1:], 0))[0]+1
    if len(roots_to_try) == 1:
        return tree

    #-------------------------------------------------------------
    def set_root(scores, root):
        """Copy of scores where `root` is the only node allowed under 0;
        also returns the original score of the root edge."""
        root_score = scores[root,0]
        scores = np.array(scores)
        scores[1:,0] = -float('inf')
        scores[root] = -float('inf')
        scores[root,0] = 0
        return scores, root_score
    #-------------------------------------------------------------

    best_score, best_tree = -np.inf, None
    for root in roots_to_try:
        _scores, root_score = set_root(scores, root)
        _tree = chuliu_edmonds(_scores)
        # a tree using any -inf edge is invalid; score it as -inf
        tree_probs = _scores[np.arange(len(_scores)), _tree]
        tree_score = tree_probs.sum() + root_score if (tree_probs > -np.inf).all() else -np.inf
        if tree_score > best_score:
            best_score = tree_score
            best_tree = _tree
    if best_tree is None:
        # previously this was a bare except around an assert which wrote a
        # debug.log file (and could hit unbound locals when roots_to_try was
        # empty); raise with the diagnostics instead
        raise ValueError("Could not find a valid single-root tree.\n"
                         "tree: {}\nscores: {}\nroots_to_try: {}".format(tree, scores, roots_to_try))
    return best_tree
stanza/stanza/models/common/constant.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Global constants.
3
+
4
+ These language codes mirror UD language codes when possible
5
+ """
6
+
7
+ import re
8
+
9
class UnknownLanguageError(ValueError):
    """Raised by lang_to_langcode when a language name or code cannot be resolved."""
    pass
11
+
12
+ # tuples in a list so we can assert that the langcodes are all unique
13
+ # When applicable, we favor the UD decision over any other possible
14
+ # language code or language name
15
+ # ISO 639-1 is out of date, but many of the UD datasets are labeled
16
+ # using the two letter abbreviations, so we add those for non-UD
17
+ # languages in the hopes that we've guessed right if those languages
18
+ # are eventually processed
19
+ lcode2lang_raw = [
20
+ ("abq", "Abaza"),
21
+ ("ab", "Abkhazian"),
22
+ ("aa", "Afar"),
23
+ ("af", "Afrikaans"),
24
+ ("ak", "Akan"),
25
+ ("akk", "Akkadian"),
26
+ ("aqz", "Akuntsu"),
27
+ ("sq", "Albanian"),
28
+ ("am", "Amharic"),
29
+ ("grc", "Ancient_Greek"),
30
+ ("hbo", "Ancient_Hebrew"),
31
+ ("apu", "Apurina"),
32
+ ("ar", "Arabic"),
33
+ ("arz", "Egyptian_Arabic"),
34
+ ("an", "Aragonese"),
35
+ ("hy", "Armenian"),
36
+ ("as", "Assamese"),
37
+ ("aii", "Assyrian"),
38
+ ("ast", "Asturian"),
39
+ ("av", "Avaric"),
40
+ ("ae", "Avestan"),
41
+ ("ay", "Aymara"),
42
+ ("az", "Azerbaijani"),
43
+ ("bm", "Bambara"),
44
+ ("ba", "Bashkir"),
45
+ ("eu", "Basque"),
46
+ ("bar", "Bavarian"),
47
+ ("bej", "Beja"),
48
+ ("be", "Belarusian"),
49
+ ("bn", "Bengali"),
50
+ ("bho", "Bhojpuri"),
51
+ ("bpy", "Bishnupriya_Manipuri"),
52
+ ("bi", "Bislama"),
53
+ ("bor", "Bororo"),
54
+ ("bs", "Bosnian"),
55
+ ("br", "Breton"),
56
+ ("bg", "Bulgarian"),
57
+ ("bxr", "Buryat"),
58
+ ("yue", "Cantonese"),
59
+ ("cpg", "Cappadocian"),
60
+ ("ca", "Catalan"),
61
+ ("ceb", "Cebuano"),
62
+ ("km", "Central_Khmer"),
63
+ ("ch", "Chamorro"),
64
+ ("ce", "Chechen"),
65
+ ("ny", "Chichewa"),
66
+ ("ckt", "Chukchi"),
67
+ ("cv", "Chuvash"),
68
+ ("xcl", "Classical_Armenian"),
69
+ ("lzh", "Classical_Chinese"),
70
+ ("cop", "Coptic"),
71
+ ("kw", "Cornish"),
72
+ ("co", "Corsican"),
73
+ ("cr", "Cree"),
74
+ ("hr", "Croatian"),
75
+ ("cs", "Czech"),
76
+ ("da", "Danish"),
77
+ ("dar", "Dargwa"),
78
+ ("dv", "Dhivehi"),
79
+ ("nl", "Dutch"),
80
+ ("dz", "Dzongkha"),
81
+ ("egy", "Egyptian"),
82
+ ("en", "English"),
83
+ ("myv", "Erzya"),
84
+ ("eo", "Esperanto"),
85
+ ("et", "Estonian"),
86
+ ("ee", "Ewe"),
87
+ ("ext", "Extremaduran"),
88
+ ("fo", "Faroese"),
89
+ ("fj", "Fijian"),
90
+ ("fi", "Finnish"),
91
+ ("fon", "Fon"),
92
+ ("fr", "French"),
93
+ ("qfn", "Frisian_Dutch"),
94
+ ("ff", "Fulah"),
95
+ ("gl", "Galician"),
96
+ ("lg", "Ganda"),
97
+ ("ka", "Georgian"),
98
+ ("de", "German"),
99
+ ("aln", "Gheg"),
100
+ ("bbj", "Ghomálá'"),
101
+ ("got", "Gothic"),
102
+ ("el", "Greek"),
103
+ ("kl", "Greenlandic"),
104
+ ("gub", "Guajajara"),
105
+ ("gn", "Guarani"),
106
+ ("gu", "Gujarati"),
107
+ ("gwi", "Gwichin"),
108
+ ("ht", "Haitian"),
109
+ ("ha", "Hausa"),
110
+ ("he", "Hebrew"),
111
+ ("hz", "Herero"),
112
+ ("azz", "Highland_Puebla_Nahuatl"),
113
+ ("hil", "Hiligaynon"),
114
+ ("hi", "Hindi"),
115
+ ("qhe", "Hindi_English"),
116
+ ("ho", "Hiri_Motu"),
117
+ ("hit", "Hittite"),
118
+ ("hu", "Hungarian"),
119
+ ("is", "Icelandic"),
120
+ ("io", "Ido"),
121
+ ("ig", "Igbo"),
122
+ ("ilo", "Ilocano"),
123
+ ("arc", "Imperial_Aramaic"),
124
+ ("id", "Indonesian"),
125
+ ("iu", "Inuktitut"),
126
+ ("ik", "Inupiaq"),
127
+ ("ga", "Irish"),
128
+ ("it", "Italian"),
129
+ ("ja", "Japanese"),
130
+ ("jv", "Javanese"),
131
+ ("urb", "Kaapor"),
132
+ ("kab", "Kabyle"),
133
+ ("xnr", "Kangri"),
134
+ ("kn", "Kannada"),
135
+ ("kr", "Kanuri"),
136
+ ("pam", "Kapampangan"),
137
+ ("krl", "Karelian"),
138
+ ("arr", "Karo"),
139
+ ("ks", "Kashmiri"),
140
+ ("kk", "Kazakh"),
141
+ ("kfm", "Khunsari"),
142
+ ("quc", "Kiche"),
143
+ ("cgg", "Kiga"),
144
+ ("ki", "Kikuyu"),
145
+ ("rw", "Kinyarwanda"),
146
+ ("ky", "Kyrgyz"),
147
+ ("kv", "Komi"),
148
+ ("koi", "Komi_Permyak"),
149
+ ("kpv", "Komi_Zyrian"),
150
+ ("kg", "Kongo"),
151
+ ("ko", "Korean"),
152
+ ("ku", "Kurdish"),
153
+ ("kmr", "Kurmanji"),
154
+ ("kj", "Kwanyama"),
155
+ ("lad", "Ladino"),
156
+ ("lo", "Lao"),
157
+ ("ltg", "Latgalian"),
158
+ ("la", "Latin"),
159
+ ("lv", "Latvian"),
160
+ ("lij", "Ligurian"),
161
+ ("li", "Limburgish"),
162
+ ("ln", "Lingala"),
163
+ ("lt", "Lithuanian"),
164
+ ("liv", "Livonian"),
165
+ ("olo", "Livvi"),
166
+ ("nds", "Low_Saxon"),
167
+ ("lu", "Luba_Katanga"),
168
+ ("lb", "Luxembourgish"),
169
+ ("mk", "Macedonian"),
170
+ ("jaa", "Madi"),
171
+ ("mag", "Magahi"),
172
+ ("qaf", "Maghrebi_Arabic_French"),
173
+ ("mai", "Maithili"),
174
+ ("mpu", "Makurap"),
175
+ ("mg", "Malagasy"),
176
+ ("ms", "Malay"),
177
+ ("ml", "Malayalam"),
178
+ ("mt", "Maltese"),
179
+ ("mjl", "Mandyali"),
180
+ ("gv", "Manx"),
181
+ ("mi", "Maori"),
182
+ ("mr", "Marathi"),
183
+ ("mh", "Marshallese"),
184
+ ("mzn", "Mazandarani"),
185
+ ("gun", "Mbya_Guarani"),
186
+ ("enm", "Middle_English"),
187
+ ("frm", "Middle_French"),
188
+ ("min", "Minangkabau"),
189
+ ("xmf", "Mingrelian"),
190
+ ("mwl", "Mirandese"),
191
+ ("mdf", "Moksha"),
192
+ ("mn", "Mongolian"),
193
+ ("mos", "Mossi"),
194
+ ("myu", "Munduruku"),
195
+ ("my", "Myanmar"),
196
+ ("nqo", "N'Ko"),
197
+ ("nah", "Nahuatl"),
198
+ ("pcm", "Naija"),
199
+ ("na", "Nauru"),
200
+ ("nv", "Navajo"),
201
+ ("nyq", "Nayini"),
202
+ ("ng", "Ndonga"),
203
+ ("nap", "Neapolitan"),
204
+ ("ne", "Nepali"),
205
+ ("new", "Newar"),
206
+ ("yrl", "Nheengatu"),
207
+ ("nyn", "Nkore"),
208
+ ("frr", "North_Frisian"),
209
+ ("nd", "North_Ndebele"),
210
+ ("sme", "North_Sami"),
211
+ ("nso", "Northern_Sotho"),
212
+ ("gya", "Northwest_Gbaya"),
213
+ ("nb", "Norwegian_Bokmaal"),
214
+ ("nn", "Norwegian_Nynorsk"),
215
+ ("ii", "Nuosu"),
216
+ ("oc", "Occitan"),
217
+ ("or", "Odia"),
218
+ ("oj", "Ojibwa"),
219
+ ("cu", "Old_Church_Slavonic"),
220
+ ("orv", "Old_East_Slavic"),
221
+ ("ang", "Old_English"),
222
+ ("fro", "Old_French"),
223
+ ("sga", "Old_Irish"),
224
+ ("ojp", "Old_Japanese"),
225
+ ("otk", "Old_Turkish"),
226
+ ("om", "Oromo"),
227
+ ("os", "Ossetian"),
228
+ ("ota", "Ottoman_Turkish"),
229
+ ("pi", "Pali"),
230
+ ("ps", "Pashto"),
231
+ ("pad", "Paumari"),
232
+ ("fa", "Persian"),
233
+ ("pay", "Pesh"),
234
+ ("xpg", "Phrygian"),
235
+ ("pbv", "Pnar"),
236
+ ("pl", "Polish"),
237
+ ("qpm", "Pomak"),
238
+ ("pnt", "Pontic"),
239
+ ("pt", "Portuguese"),
240
+ ("pra", "Prakrit"),
241
+ ("pa", "Punjabi"),
242
+ ("qu", "Quechua"),
243
+ ("rhg", "Rohingya"),
244
+ ("ro", "Romanian"),
245
+ ("rm", "Romansh"),
246
+ ("rn", "Rundi"),
247
+ ("ru", "Russian"),
248
+ ("sm", "Samoan"),
249
+ ("sg", "Sango"),
250
+ ("sa", "Sanskrit"),
251
+ ("skr", "Saraiki"),
252
+ ("sc", "Sardinian"),
253
+ ("sco", "Scots"),
254
+ ("gd", "Scottish_Gaelic"),
255
+ ("sr", "Serbian"),
256
+ ("sn", "Shona"),
257
+ ("zh-hans", "Simplified_Chinese"),
258
+ ("sd", "Sindhi"),
259
+ ("si", "Sinhala"),
260
+ ("sms", "Skolt_Sami"),
261
+ ("sk", "Slovak"),
262
+ ("sl", "Slovenian"),
263
+ ("soj", "Soi"),
264
+ ("so", "Somali"),
265
+ ("ckb", "Sorani"),
266
+ ("ajp", "South_Levantine_Arabic"),
267
+ ("nr", "South_Ndebele"),
268
+ ("st", "Southern_Sotho"),
269
+ ("es", "Spanish"),
270
+ ("ssp", "Spanish_Sign_Language"),
271
+ ("su", "Sundanese"),
272
+ ("sw", "Swahili"),
273
+ ("ss", "Swati"),
274
+ ("sv", "Swedish"),
275
+ ("swl", "Swedish_Sign_Language"),
276
+ ("gsw", "Swiss_German"),
277
+ ("syr", "Syriac"),
278
+ ("tl", "Tagalog"),
279
+ ("ty", "Tahitian"),
280
+ ("tg", "Tajik"),
281
+ ("ta", "Tamil"),
282
+ ("tt", "Tatar"),
283
+ ("eme", "Teko"),
284
+ ("te", "Telugu"),
285
+ ("qte", "Telugu_English"),
286
+ ("th", "Thai"),
287
+ ("bo", "Tibetan"),
288
+ ("ti", "Tigrinya"),
289
+ ("to", "Tonga"),
290
+ ("zh-hant", "Traditional_Chinese"),
291
+ ("ts", "Tsonga"),
292
+ ("tn", "Tswana"),
293
+ ("tpn", "Tupinamba"),
294
+ ("tr", "Turkish"),
295
+ ("qtd", "Turkish_German"),
296
+ ("tk", "Turkmen"),
297
+ ("tw", "Twi"),
298
+ ("uk", "Ukrainian"),
299
+ ("xum", "Umbrian"),
300
+ ("hsb", "Upper_Sorbian"),
301
+ ("ur", "Urdu"),
302
+ ("ug", "Uyghur"),
303
+ ("uz", "Uzbek"),
304
+ ("ve", "Venda"),
305
+ ("vep", "Veps"),
306
+ ("vi", "Vietnamese"),
307
+ ("vo", "Volapük"),
308
+ ("wa", "Walloon"),
309
+ ("war", "Waray"),
310
+ ("wbp", "Warlpiri"),
311
+ ("cy", "Welsh"),
312
+ ("hyw", "Western_Armenian"),
313
+ ("fy", "Western_Frisian"),
314
+ ("nhi", "Western_Sierra_Puebla_Nahuatl"),
315
+ ("wo", "Wolof"),
316
+ ("xav", "Xavante"),
317
+ ("xh", "Xhosa"),
318
+ ("sjo", "Xibe"),
319
+ ("sah", "Yakut"),
320
+ ("yi", "Yiddish"),
321
+ ("yo", "Yoruba"),
322
+ ("ess", "Yupik"),
323
+ ("say", "Zaar"),
324
+ ("zza", "Zazaki"),
325
+ ("zea", "Zeelandic"),
326
+ ("za", "Zhuang"),
327
+ ("zu", "Zulu"),
328
+ ]
329
+
330
# build the dictionary, checking for duplicate language codes
lcode2lang = {}
for code, language in lcode2lang_raw:
    assert code not in lcode2lang
    lcode2lang[code] = language

# invert the dictionary, checking for possible duplicate language names
lang2lcode = {}
for code, language in lcode2lang_raw:
    assert language not in lang2lcode
    lang2lcode[language] = code

# check that nothing got clobbered
# (redundant with the per-entry asserts above, but cheap insurance)
assert len(lcode2lang_raw) == len(lcode2lang)
assert len(lcode2lang_raw) == len(lang2lcode)
345
+
346
# some of the two letter langcodes get used elsewhere as three letters
# for example, Wolof is abbreviated "wo" in UD, but "wol" in Masakhane NER
two_to_three_letters_raw = (
    ("bm", "bam"),
    ("ee", "ewe"),
    ("ha", "hau"),
    ("ig", "ibo"),
    ("rw", "kin"),
    ("lg", "lug"),
    ("ny", "nya"),
    ("sn", "sna"),
    ("sw", "swa"),
    ("tn", "tsn"),
    ("tw", "twi"),
    ("wo", "wol"),
    ("xh", "xho"),
    ("yo", "yor"),
    ("zu", "zul"),

    # this is a weird case where a 2 letter code was available,
    # but UD used the 3 letter code instead
    ("se", "sme"),
)

# register each alias: whichever form is already known is canonical,
# and the other form is added as an alias for the same language
for two, three in two_to_three_letters_raw:
    if two in lcode2lang:
        # two-letter form is canonical; the three-letter form must be new
        # (removed the tautological `assert two in lcode2lang` that used
        # to sit inside this branch)
        assert three not in lcode2lang
        assert three not in lang2lcode
        lang2lcode[three] = two
        lcode2lang[three] = lcode2lang[two]
    elif three in lcode2lang:
        # three-letter form is canonical; the two-letter form must be new
        assert two not in lcode2lang
        assert two not in lang2lcode
        lang2lcode[two] = three
        lcode2lang[two] = lcode2lang[three]
    else:
        raise AssertionError("Found a proposed alias %s -> %s when neither code was already known" % (two, three))

two_to_three_letters = {
    two: three for two, three in two_to_three_letters_raw
}

three_to_two_letters = {
    three: two for two, three in two_to_three_letters_raw
}

# neither side of the raw alias list may contain duplicates
assert len(two_to_three_letters) == len(two_to_three_letters_raw)
assert len(three_to_two_letters) == len(two_to_three_letters_raw)
396
+
397
# additional useful code to language mapping
# added after dict invert to avoid conflict
lcode2lang['nb'] = 'Norwegian' # Norwegian Bokmaal mapped to default norwegian
lcode2lang['no'] = 'Norwegian'
lcode2lang['zh'] = 'Simplified_Chinese'

# alternate names for languages whose codes are already registered above;
# these only extend lang2lcode (name -> code), never lcode2lang
extra_lang_to_lcodes = [
    ("ab", "Abkhaz"),
    ("gsw", "Alemannic"),
    ("my", "Burmese"),
    ("ckb", "Central_Kurdish"),
    ("ny", "Chewa"),
    ("zh", "Chinese"),
    ("za", "Chuang"),
    ("dv", "Divehi"),
    ("eme", "Emerillon"),
    ("lij", "Genoese"),
    ("ga", "Gaelic"),
    ("ne", "Gorkhali"),
    ("ht", "Haitian_Creole"),
    ("ilo", "Ilokano"),
    ("nr", "isiNdebele"),
    ("xh", "isiXhosa"),
    ("zu", "isiZulu"),
    ("jaa", "Jamamadí"),
    ("kab", "Kabylian"),
    ("kl", "Kalaallisut"),
    ("km", "Khmer"),
    ("ky", "Kirghiz"),
    ("lb", "Letzeburgesch"),
    ("lg", "Luganda"),
    ("jaa", "Madí"),
    ("dv", "Maldivian"),
    ("mjl", "Mandeali"),
    ("skr", "Multani"),
    ("nb", "Norwegian"),
    ("ny", "Nyanja"),
    ("sga", "Old_Gaelic"),
    ("or", "Oriya"),
    ("arr", "Ramarama"),
    ("sah", "Sakha"),
    ("nso", "Sepedi"),
    ("tn", "Setswana"),
    ("ii", "Sichuan_Yi"),
    ("si", "Sinhalese"),
    ("ss", "Siswati"),
    ("soj", "Sohi"),
    ("st", "Sesotho"),
    ("ve", "Tshivenda"),
    ("ts", "Xitsonga"),
    ("fy", "West_Frisian"),
    ("zza", "Zaza"),
    ]

for code, language in extra_lang_to_lcodes:
    # each alternate name must be new, and its code must already exist
    assert language not in lang2lcode
    assert code in lcode2lang
    lang2lcode[language] = code

# treebank names changed from Old Russian to Old East Slavic in 2.8
lang2lcode['Old_Russian'] = 'orv'
458
+
459
# build a lowercase map from language to langcode
# (lets lang_to_langcode match names case-insensitively)
langlower2lcode = {}
for k in lang2lcode:
    langlower2lcode[k.lower()] = lang2lcode[k]

# treebanks whose short names cannot be derived mechanically:
# Chinese treebanks distinguish simplified vs traditional script,
# and Norwegian treebanks distinguish Bokmaal vs Nynorsk
treebank_special_cases = {
    "UD_Chinese-Beginner": "zh-hans_beginner",
    "UD_Chinese-GSDSimp": "zh-hans_gsdsimp",
    "UD_Chinese-GSD": "zh-hant_gsd",
    "UD_Chinese-HK": "zh-hant_hk",
    "UD_Chinese-CFL": "zh-hans_cfl",
    "UD_Chinese-PatentChar": "zh-hans_patentchar",
    "UD_Chinese-PUD": "zh-hant_pud",
    "UD_Norwegian-Bokmaal": "nb_bokmaal",
    "UD_Norwegian-Nynorsk": "nn_nynorsk",
    "UD_Norwegian-NynorskLIA": "nn_nynorsklia",
}

# matches a name already in short form: langcode_corpus
SHORTNAME_RE = re.compile("^[a-z-]+_[a-z0-9-_]+$")
478
+
479
def langcode_to_lang(lcode):
    """Return the canonical language name for a langcode.

    Tries the code as given, then lowercased; if neither is known,
    returns the code unchanged.
    """
    for candidate in (lcode, lcode.lower()):
        if candidate in lcode2lang:
            return lcode2lang[candidate]
    return lcode
486
+
487
def pretty_langcode_to_lang(lcode):
    """Return a human-friendly language name for a langcode.

    Underscores become spaces, and the two Chinese script variants get
    their conventional parenthesized form.
    """
    pretty = langcode_to_lang(lcode).replace("_", " ")
    if pretty == 'Simplified Chinese':
        return 'Chinese (Simplified)'
    if pretty == 'Traditional Chinese':
        return 'Chinese (Traditional)'
    return pretty
495
+
496
def lang_to_langcode(lang):
    """Resolve a language name or langcode to the canonical langcode.

    Lookup order: exact name, lowercased name, exact code, lowercased code.
    Raises UnknownLanguageError if nothing matches.
    """
    if lang in lang2lcode:
        return lang2lcode[lang]
    lowered = lang.lower()
    if lowered in langlower2lcode:
        return langlower2lcode[lowered]
    if lang in lcode2lang:
        return lang
    if lowered in lcode2lang:
        return lowered
    raise UnknownLanguageError("Unable to find language code for %s" % lang)
508
+
509
# langcodes for right-to-left languages, including many stanza does not support
RIGHT_TO_LEFT = set(["ar", "arc", "az", "ckb", "dv", "ff", "he", "ku", "mzn", "nqo", "ps", "fa", "rhg", "sd", "syr", "ur"])

def is_right_to_left(lang):
    """
    Covers all the RtL languages we support, as well as many we don't.

    If a language is left out, please let us know!
    """
    # accepts either a language name or a langcode;
    # raises UnknownLanguageError for completely unknown input
    lcode = lang_to_langcode(lang)
    return lcode in RIGHT_TO_LEFT
519
+
520
def treebank_to_short_name(treebank):
    """ Convert treebank name to short code, eg UD_English-EWT -> en_ewt """
    # a few treebanks (Chinese script variants, Norwegian Bokmaal/Nynorsk)
    # cannot be converted mechanically
    if treebank in treebank_special_cases:
        return treebank_special_cases.get(treebank)
    # already in short form: just canonicalize the langcode part
    if SHORTNAME_RE.match(treebank):
        lang, corpus = treebank.split("_", 1)
        lang = lang_to_langcode(lang)
        return lang + "_" + corpus

    if treebank.startswith('UD_'):
        treebank = treebank[3:]
    # special case starting with zh in case the input is an already-converted ZH treebank
    if treebank.startswith("zh-hans") or treebank.startswith("zh-hant"):
        # split after the 7 character script prefix, skipping the separator
        splits = (treebank[:len("zh-hans")], treebank[len("zh-hans")+1:])
    else:
        # UD style is Language-Corpus; fall back to Language_Corpus
        splits = treebank.split('-')
        if len(splits) == 1:
            splits = treebank.split("_", 1)
    assert len(splits) == 2, "Unable to process %s" % treebank
    lang, corpus = splits

    lcode = lang_to_langcode(lang)

    short = "{}_{}".format(lcode, corpus.lower())
    return short
545
+
546
def treebank_to_langid(treebank):
    """Convert a treebank name to its langid (the part of the short name before the first underscore)."""
    return treebank_to_short_name(treebank).split("_", 1)[0]
550
+
stanza/stanza/models/common/count_ner_coverage.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stanza.models.common import pretrain
2
+ import argparse
3
+
4
def parse_args():
    """Parse the command line: NER files to score plus the pretrain to score them against."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument('ners', type=str, nargs='*', help='Which treebanks to run on')
    argparser.add_argument('--pretrain', type=str, default="/home/john/stanza_resources/hi/pretrain/hdtb.pt", help='Which pretrain to use')
    argparser.set_defaults(ners=["/home/john/stanza/data/ner/hi_fire2013.train.csv",
                                 "/home/john/stanza/data/ner/hi_fire2013.dev.csv"])
    return argparser.parse_args()
12
+
13
+
14
def read_ner(filename):
    """Read a tab-separated NER file and return the tokens whose tag is not 'O'.

    Each non-blank line is expected to be "token<TAB>tag"; blank lines
    (sentence separators) are skipped.

    Fixes a leaked file handle: the file is now opened with a context
    manager instead of open(...).readlines(), and each line is split once.
    """
    words = []
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            pieces = line.split("\t")
            if pieces[1] == 'O':
                continue
            words.append(pieces[0])
    return words
24
+
25
def count_coverage(pretrain, words):
    """Return the fraction of `words` found in the pretrain's vocabulary.

    Fixes a ZeroDivisionError on an empty word list by returning 0.0,
    and replaces the manual counter with sum().
    """
    if not words:
        return 0.0
    found = sum(1 for w in words if w in pretrain.vocab)
    return found / len(words)
31
+
32
# command-line driver: load the pretrain once, then report the vocab
# coverage of each NER dataset given on the command line
args = parse_args()
pt = pretrain.Pretrain(args.pretrain)
for dataset in args.ners:
    words = read_ner(dataset)
    print(dataset)
    print(count_coverage(pt, words))
    print()
stanza/stanza/models/common/count_pretrain_coverage.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A simple script to count the fraction of words in a UD dataset which are in a particular pretrain.
2
+
3
+ For example, this script shows that the word2vec Armenian vectors,
4
+ truncated at 250K words, have 75% coverage of the Western Armenian
5
+ dataset, whereas the vectors available here have 88% coverage:
6
+
7
+ https://github.com/ispras-texterra/word-embeddings-eval-hy
8
+ """
9
+
10
+ from stanza.models.common import pretrain
11
+ from stanza.utils.conll import CoNLL
12
+
13
+ import argparse
14
+
15
def parse_args():
    """Parse the command line: conllu treebanks to score plus the pretrain to score them against."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument('treebanks', type=str, nargs='*', help='Which treebanks to run on')
    argparser.add_argument('--pretrain', type=str, default="/home/john/extern_data/wordvec/glove/armenian.pt", help='Which pretrain to use')
    argparser.set_defaults(treebanks=["/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Western_Armenian-ArmTDP/hyw_armtdp-ud-train.conllu",
                                      "/home/john/extern_data/ud2/ud-treebanks-v2.7/UD_Armenian-ArmTDP/hy_armtdp-ud-train.conllu"])
    return argparser.parse_args()
23
+
24
+
25
# command-line driver: load the pretrain once, then report what fraction
# of each treebank's words appear in the pretrain vocabulary
args = parse_args()
pt = pretrain.Pretrain(args.pretrain)
pt.load()
print("Pretrain stats: {} vectors, {} dim".format(len(pt.vocab), pt.emb[0].shape[0]))

for treebank in args.treebanks:
    print(treebank)
    found = 0
    total = 0
    doc = CoNLL.conll2doc(treebank)
    # count exact-match coverage over every word token in the treebank
    for sentence in doc.sentences:
        for word in sentence.words:
            total = total + 1
            if word.text in pt.vocab:
                found = found + 1

    print (found / total)
stanza/stanza/models/common/crf.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CRF loss and viterbi decoding.
3
+ """
4
+
5
+ import math
6
+ from numbers import Number
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.init as init
11
+
12
class CRFLoss(nn.Module):
    """
    Calculate log-space crf loss, given unary potentials, a transition matrix
    and gold tag sequences.

    The masks tensors mark PADDING positions with True/1: masked unary and
    binary scores are zeroed, and masks.eq(0) counts the real tokens.
    """
    def __init__(self, num_tag, batch_average=True):
        super().__init__()
        # tag-to-tag transition scores, learned starting from all zeros
        self._transitions = nn.Parameter(torch.zeros(num_tag, num_tag))
        self._batch_average = batch_average # if not batch average, average on all tokens

    def forward(self, inputs, masks, tag_indices):
        """
        inputs: batch_size x seq_len x num_tags
        masks: batch_size x seq_len
        tag_indices: batch_size x seq_len

        @return:
        loss: CRF negative log likelihood on all instances.
        transitions: the transition matrix
        """
        # TODO: handle <start> and <end> tags
        input_bs, input_sl, input_nc = inputs.size()
        # numerator: emission + transition scores of the gold tag path
        unary_scores = self.crf_unary_score(inputs, masks, tag_indices, input_bs, input_sl, input_nc)
        binary_scores = self.crf_binary_score(inputs, masks, tag_indices, input_bs, input_sl, input_nc)
        # denominator: log partition over all possible tag paths
        log_norm = self.crf_log_norm(inputs, masks, tag_indices)
        log_likelihood = unary_scores + binary_scores - log_norm # batch_size
        loss = torch.sum(-log_likelihood)
        if self._batch_average:
            loss = loss / input_bs
        else:
            # average over real (non-padding) tokens instead of sentences
            total = masks.eq(0).sum()
            loss = loss / (total + 1e-8)
        return loss, self._transitions

    def crf_unary_score(self, inputs, masks, tag_indices, input_bs, input_sl, input_nc):
        """
        Sum of the emission scores of the gold tags for each sentence.

        @return:
        unary_scores: batch_size
        """
        # flatten to (bs, sl*nc) so position t with tag k lives at t*nc + k
        flat_inputs = inputs.view(input_bs, -1)
        flat_tag_indices = tag_indices + torch.arange(input_sl, device=tag_indices.device).long().unsqueeze(0) * input_nc
        unary_scores = torch.gather(flat_inputs, 1, flat_tag_indices).view(input_bs, -1)
        # padding positions contribute nothing
        unary_scores.masked_fill_(masks, 0)
        return unary_scores.sum(dim=1)

    def crf_binary_score(self, inputs, masks, tag_indices, input_bs, input_sl, input_nc):
        """
        Sum of the transition scores between consecutive gold tags.

        @return:
        binary_scores: batch_size
        """
        # get number of transitions
        nt = tag_indices.size(-1) - 1
        start_indices = tag_indices[:, :nt]
        end_indices = tag_indices[:, 1:]
        # flat matrices: transition (i -> j) lives at index i*nc + j
        flat_transition_indices = start_indices * input_nc + end_indices
        flat_transition_indices = flat_transition_indices.view(-1)
        flat_transition_matrix = self._transitions.view(-1)
        binary_scores = torch.gather(flat_transition_matrix, 0, flat_transition_indices)\
                .view(input_bs, -1)
        # a transition whose target position is padding does not count
        score_masks = masks[:, 1:]
        binary_scores.masked_fill_(score_masks, 0)
        return binary_scores.sum(dim=1)

    def crf_log_norm(self, inputs, masks, tag_indices):
        """
        Calculate the CRF partition in log space for each instance, following:
        http://www.cs.columbia.edu/~mcollins/fb.pdf
        @return:
        log_norm: batch_size
        """
        start_inputs = inputs[:,0,:] # bs x nc
        rest_inputs = inputs[:,1:,:]
        # TODO: technically we need to pay attention to the initial
        # value being masked. Currently we do compensate for the
        # entire row being masked at the end of the operation
        rest_masks = masks[:,1:]
        alphas = start_inputs # bs x nc
        trans = self._transitions.unsqueeze(0) # 1 x nc x nc
        # accumulate alphas in log space (forward algorithm)
        for i in range(rest_inputs.size(1)):
            transition_scores = alphas.unsqueeze(2) + trans # bs x nc x nc
            new_alphas = rest_inputs[:,i,:] + log_sum_exp(transition_scores, dim=1)
            m = rest_masks[:,i].unsqueeze(1).expand_as(new_alphas) # bs x nc, 1 for padding idx
            # apply masks: keep the previous alphas at padding positions
            new_alphas.masked_scatter_(m, alphas.masked_select(m))
            alphas = new_alphas
        log_norm = log_sum_exp(alphas, dim=1)

        # if any row was entirely masked, we just turn its log denominator to 0
        # eg, the empty summation for the denominator will be 1, and its log will be 0
        all_masked = torch.all(masks, dim=1)
        log_norm = log_norm * torch.logical_not(all_masked)
        return log_norm
106
+
107
def viterbi_decode(scores, transition_params):
    """
    Decode the best tag sequence with the viterbi algorithm.

    scores: seq_len x num_tags (numpy array) of per-position tag scores
    transition_params: num_tags x num_tags (numpy array)

    @return:
    viterbi: list of tag ids forming the highest scoring path
    viterbi_score: score of that path
    """
    seq_len = scores.shape[0]
    trellis = np.zeros_like(scores)
    backpointers = np.zeros_like(scores, dtype=np.int32)
    trellis[0] = scores[0]

    # forward pass: best score for each tag at each position,
    # remembering which previous tag produced it
    for t in range(1, seq_len):
        candidates = trellis[t - 1][:, None] + transition_params
        backpointers[t] = np.argmax(candidates, 0)
        trellis[t] = scores[t] + np.max(candidates, 0)

    # backward pass: follow the backpointers from the best final tag
    path = [np.argmax(trellis[-1])]
    for bp in backpointers[:0:-1]:
        path.append(bp[path[-1]])
    path.reverse()
    return path, np.max(trellis[-1])
131
+
132
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation

    value.exp().sum(dim, keepdim).log()

    Shifts by the max before exponentiating to avoid overflow.
    With dim=None, reduces over the whole tensor.
    """
    if dim is None:
        m = torch.max(value)
        total = torch.sum(torch.exp(value - m))
        # a plain Python number comes back for 0-dim inputs
        if isinstance(total, Number):
            return m + math.log(total)
        return m + torch.log(total)
    m, _ = torch.max(value, dim=dim, keepdim=True)
    shifted = value - m
    if keepdim is False:
        m = m.squeeze(dim)
    return m + torch.log(torch.sum(torch.exp(shifted), dim=dim, keepdim=keepdim))
stanza/stanza/models/common/data.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for data transformations.
3
+ """
4
+
5
+ import logging
6
+ import random
7
+
8
+ import torch
9
+
10
+ import stanza.models.common.seq2seq_constant as constant
11
+ from stanza.models.common.doc import HEAD, ID, UPOS
12
+
13
+ logger = logging.getLogger('stanza')
14
+
15
def map_to_ids(tokens, vocab):
    """Map each token to its id in `vocab`, substituting UNK_ID for unknown tokens."""
    return [vocab[tok] if tok in vocab else constant.UNK_ID for tok in tokens]
18
+
19
def get_long_tensor(tokens_list, batch_size, pad_id=constant.PAD_ID):
    """ Convert (list of )+ tokens to a padded LongTensor. """
    # measure the max length at every nesting level by repeatedly
    # flattening one level; supports arbitrarily nested lists of ints
    sizes = []
    x = tokens_list
    while isinstance(x[0], list):
        sizes.append(max(len(y) for y in x))
        x = [z for y in x for z in y]
    # TODO: pass in a device parameter and put it directly on the relevant device?
    # that might be faster than creating it and then moving it
    tokens = torch.LongTensor(batch_size, *sizes).fill_(pad_id)
    # NOTE(review): for nested inputs this assumes each element s is
    # rectangular enough for torch.LongTensor(s) to succeed — confirm
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens
32
+
33
def get_float_tensor(features_list, batch_size):
    """Pad per-token feature lists into a (batch_size, max_len, feat_dim) FloatTensor.

    Unused positions stay zero.  Returns None when no features are provided.
    """
    if features_list is None or features_list[0] is None:
        return None
    max_len = max(len(feats) for feats in features_list)
    feat_dim = len(features_list[0][0])
    out = torch.zeros(batch_size, max_len, feat_dim, dtype=torch.float)
    for row, feats in enumerate(features_list):
        out[row, :len(feats), :] = torch.FloatTensor(feats)
    return out
42
+
43
def sort_all(batch, lens):
    """Sort every field in `batch` by descending `lens`.

    Returns (sorted fields, original indices), where the indices give each
    sorted row's position in the original order.
    """
    if batch == [[]]:
        return [[]], []
    # prepend lens and the original positions, sort rows together,
    # then transpose back into columns
    augmented = [lens, range(len(lens))] + list(batch)
    resorted = [list(col) for col in zip(*sorted(zip(*augmented), reverse=True))]
    return resorted[2:], resorted[1]
50
+
51
def get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate, desired_ratio=0.1, max_ratio=0.5):
    """
    Returns X so that if you randomly select X * N sentences, you get 10%

    The ratio assumes the final dataset stays at size N rather than
    growing to N + X * N.

    should_augment_predicate: returns True if the sentence has some
      feature we may occasionally want to change (eg depparse sentences
      ending in punct)
    can_augment_predicate: True only when removing that feature is safe
      (eg the final punct heads no other word); must be at least as
      restrictive as should_augment_predicate, otherwise AssertionError
      is raised
    """
    n_data = len(train_data)
    n_should = sum(should_augment_predicate(sent) for sent in train_data)
    n_can = sum(can_augment_predicate(sent) for sent in train_data)
    n_bad = sum(can_augment_predicate(sent) and not should_augment_predicate(sent)
                for sent in train_data)
    if n_bad > 0:
        raise AssertionError("can_augment_predicate allowed sentences not allowed by should_augment_predicate")

    if n_can == 0:
        logger.warning("Found no sentences which matched can_augment_predicate {}".format(can_augment_predicate))
        return 0.0
    # sentences which already lack the feature count toward the target,
    # so only the shortfall needs to be generated
    needed = n_data * desired_ratio - (n_data - n_should)
    if needed < 0:
        return 0.0
    return min(needed / n_can, max_ratio)
86
+
87
+
88
def should_augment_nopunct_predicate(sentence):
    """Return True if the sentence (a list of conllu word dicts) ends with a PUNCT word."""
    last_word = sentence[-1]
    return last_word.get(UPOS, None) == 'PUNCT'
91
+
92
def can_augment_nopunct_predicate(sentence):
    """
    Check that the sentence ends with PUNCT and also doesn't have any words which depend on the last word
    """
    last_word = sentence[-1]
    if last_word.get(UPOS, None) != 'PUNCT':
        return False
    # don't cut off MWT (a multiword token has a multi-part ID)
    if len(last_word[ID]) > 1:
        return False
    # the final punct must not be the head of any other word
    if any(len(word[ID]) == 1 and word[HEAD] == last_word[ID][0] for word in sentence):
        return False
    return True
105
+
106
def augment_punct(train_data, augment_ratio,
                  should_augment_predicate=None,
                  can_augment_predicate=None,
                  keep_original_sentences=True):

    """
    Adds extra training data to compensate for some models having all sentences end with PUNCT

    Some of the models (for example, UD_Hebrew-HTB) have the flaw that
    all of the training sentences end with PUNCT.  The model therefore
    learns to finish every sentence with punctuation, even if it is
    given a sentence with non-punct at the end.

    One simple way to fix this is to train on some fraction of training data with punct.

    Params:
    train_data: list of list of dicts, eg a conll doc
    augment_ratio: the fraction to augment. if None, a best guess is made to get to 10%

    should_augment_predicate: a function which returns T/F if a sentence already ends with not PUNCT
    can_augment_predicate: a function which returns T/F if it makes sense to remove the last PUNCT
    (both default to the nopunct predicates in this module when None)

    Fixes two defects: the non-augmented branch used to append the local
    `new_sentence` (wrong object, and NameError if never yet bound)
    instead of the original sentence, and sentences failing
    can_augment_predicate were dropped entirely even when
    keep_original_sentences was True.

    TODO: do this dynamically, as part of the DataLoader or elsewhere?
    One complication is the data comes back from the DataLoader as
    tensors & indices, so it is much more complicated to manipulate
    """
    # bind the module-level default predicates lazily; passing None keeps
    # the old behavior while staying easy to override in tests
    if should_augment_predicate is None:
        should_augment_predicate = should_augment_nopunct_predicate
    if can_augment_predicate is None:
        can_augment_predicate = can_augment_nopunct_predicate

    if len(train_data) == 0:
        return []

    if augment_ratio is None:
        augment_ratio = get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate)

    if augment_ratio <= 0:
        if keep_original_sentences:
            return list(train_data)
        else:
            return []

    new_data = []
    for sentence in train_data:
        if can_augment_predicate(sentence):
            if random.random() < augment_ratio and len(sentence) > 1:
                # shallow copy without the final punct word
                # todo: could deep copy the words, or not copy at all
                new_data.append(list(sentence[:-1]))
            elif keep_original_sentences:
                new_data.append(sentence)
        elif keep_original_sentences:
            # sentences we cannot augment are kept as-is
            new_data.append(sentence)

    return new_data
stanza/stanza/models/common/doc.py ADDED
@@ -0,0 +1,1741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic data structures
3
+ """
4
+
5
+ import io
6
+ from itertools import repeat
7
+ import re
8
+ import json
9
+ import pickle
10
+ import warnings
11
+
12
+ from enum import Enum
13
+
14
+ import networkx as nx
15
+
16
+ from stanza.models.common.stanza_object import StanzaObject
17
+ from stanza.models.common.utils import misc_to_space_after, space_after_to_misc, misc_to_space_before, space_before_to_misc
18
+ from stanza.models.ner.utils import decode_from_bioes
19
+ from stanza.models.constituency import tree_reader
20
+ from stanza.models.coref.coref_chain import CorefMention, CorefChain, CorefAttachment
21
+
22
class MWTProcessingType(Enum):
    """How set_mwt_expansions should treat a single token."""
    FLATTEN = 0 # flatten the current token into one ID instead of MWT
    PROCESS = 1 # process the current token as an MWT and expand it as such
    SKIP = 2 # do nothing on this token, simply increment IDs
26
+
27
# Recognizers for multi-word tokens: either an id of the form "3-4",
# or a MISC column carrying the MWT=Yes annotation
multi_word_token_id = re.compile(r"([0-9]+)-([0-9]+)")
multi_word_token_misc = re.compile(r".*MWT=Yes.*")

# Canonical key names used in the CoNLL-U style dicts that back
# Document / Sentence / Token / Word objects
MEXP = 'manual_expansion'
ID = 'id'
TEXT = 'text'
LEMMA = 'lemma'
UPOS = 'upos'
XPOS = 'xpos'
FEATS = 'feats'
HEAD = 'head'
DEPREL = 'deprel'
DEPS = 'deps'
MISC = 'misc'
NER = 'ner'
MULTI_NER = 'multi_ner' # will represent tags from multiple NER models
START_CHAR = 'start_char'
END_CHAR = 'end_char'
TYPE = 'type'
SENTIMENT = 'sentiment'
CONSTITUENCY = 'constituency'
COREF_CHAINS = 'coref_chains'

# field indices when converting the document to conll
FIELD_TO_IDX = {ID: 0, TEXT: 1, LEMMA: 2, UPOS: 3, XPOS: 4, FEATS: 5, HEAD: 6, DEPREL: 7, DEPS: 8, MISC: 9}
FIELD_NUM = len(FIELD_TO_IDX)
53
+
54
class DocJSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands the coref objects attached to a Document."""

    def default(self, obj):
        # CorefMention is a plain record; its attribute dict is already JSON-ready
        if isinstance(obj, CorefMention):
            return obj.__dict__
        # CorefAttachment knows how to serialize itself
        if isinstance(obj, CorefAttachment):
            return obj.to_json()
        # anything else is handed to the base class, which raises TypeError
        return super().default(obj)
61
+
62
class Document(StanzaObject):
    """ A document class that stores attributes of a document and carries a list of sentences.
    """

    def __init__(self, sentences, text=None, comments=None, empty_sentences=None):
        """ Construct a document given a list of sentences in the form of lists of CoNLL-U dicts.

        Args:
            sentences: a list of sentences, which being a list of token entry, in the form of a CoNLL-U dict.
            text: the raw text of the document.
            comments: A list of list of strings to use as comments on the sentences, either None or the same length as sentences
            empty_sentences: optional per-sentence lists of CoNLL-U "empty word" entries
        """
        self._sentences = []
        self._lang = None
        self._text = text
        self._num_tokens = 0
        self._num_words = 0

        self._process_sentences(sentences, comments, empty_sentences)
        self._ents = []
        self._coref = []
        # character offsets are only meaningful when the raw text is known
        if self._text is not None:
            self.build_ents()
            self.mark_whitespace()

    def mark_whitespace(self):
        """Record the exact whitespace surrounding each token, using the raw text and char offsets."""
        for sentence in self._sentences:
            # TODO: pairwise, once we move to minimum 3.10
            for prev_token, next_token in zip(sentence.tokens[:-1], sentence.tokens[1:]):
                whitespace = self._text[prev_token.end_char:next_token.start_char]
                prev_token.spaces_after = whitespace
        # whitespace spanning the gap between consecutive sentences
        for prev_sentence, next_sentence in zip(self._sentences[:-1], self._sentences[1:]):
            prev_token = prev_sentence.tokens[-1]
            next_token = next_sentence.tokens[0]
            whitespace = self._text[prev_token.end_char:next_token.start_char]
            prev_token.spaces_after = whitespace
        # trailing whitespace after the last token of the document
        if len(self._sentences) > 0 and len(self._sentences[-1].tokens) > 0:
            final_token = self._sentences[-1].tokens[-1]
            whitespace = self._text[final_token.end_char:]
            final_token.spaces_after = whitespace
        # leading whitespace before the first token of the document
        if len(self._sentences) > 0 and len(self._sentences[0].tokens) > 0:
            first_token = self._sentences[0].tokens[0]
            whitespace = self._text[:first_token.start_char]
            first_token.spaces_before = whitespace


    @property
    def lang(self):
        """ Access the language of this document """
        return self._lang

    @lang.setter
    def lang(self, value):
        """ Set the language of this document """
        self._lang = value

    @property
    def text(self):
        """ Access the raw text for this document. """
        return self._text

    @text.setter
    def text(self, value):
        """ Set the raw text for this document. """
        self._text = value

    @property
    def sentences(self):
        """ Access the list of sentences for this document. """
        return self._sentences

    @sentences.setter
    def sentences(self, value):
        """ Set the list of sentences for this document. """
        self._sentences = value

    @property
    def num_tokens(self):
        """ Access the number of tokens for this document. """
        return self._num_tokens

    @num_tokens.setter
    def num_tokens(self, value):
        """ Set the number of tokens for this document. """
        self._num_tokens = value

    @property
    def num_words(self):
        """ Access the number of words for this document. """
        return self._num_words

    @num_words.setter
    def num_words(self, value):
        """ Set the number of words for this document. """
        self._num_words = value

    @property
    def ents(self):
        """ Access the list of entities in this document. """
        return self._ents

    @ents.setter
    def ents(self, value):
        """ Set the list of entities in this document. """
        self._ents = value

    @property
    def entities(self):
        """ Access the list of entities. This is just an alias of `ents`. """
        return self._ents

    @entities.setter
    def entities(self, value):
        """ Set the list of entities in this document. """
        self._ents = value

    def _process_sentences(self, sentences, comments=None, empty_sentences=None):
        """Build Sentence objects from the raw CoNLL-U dicts, attach text snippets and comments."""
        self.sentences = []
        if empty_sentences is None:
            empty_sentences = repeat([])
        for sent_idx, (tokens, empty_words) in enumerate(zip(sentences, empty_sentences)):
            try:
                sentence = Sentence(tokens, doc=self, empty_words=empty_words)
            except IndexError as e:
                raise IndexError("Could not process document at sentence %d" % sent_idx) from e
            except ValueError as e:
                tokens = ["|%s|" % t for t in tokens]
                tokens = ", ".join(tokens)
                raise ValueError("Could not process document at sentence %d\n  Raw tokens: %s" % (sent_idx, tokens)) from e
            self.sentences.append(sentence)
            begin_idx, end_idx = sentence.tokens[0].start_char, sentence.tokens[-1].end_char
            if all((self.text is not None, begin_idx is not None, end_idx is not None)): sentence.text = self.text[begin_idx: end_idx]
            sentence.index = sent_idx

        self._count_words()

        # Add a #text comment to each sentence in a doc if it doesn't already exist
        if not comments:
            comments = [[] for x in self.sentences]
        else:
            comments = [list(x) for x in comments]
        for sentence, sentence_comments in zip(self.sentences, comments):
            # the space after text can occur in treebanks such as the Naija-NSC treebank,
            # which extensively uses `# text_en =` and `# text_ortho`
            if sentence.text and not any(comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text=") for comment in sentence_comments):
                # split/join to handle weird whitespace, especially newlines
                sentence_comments.append("# text = " + ' '.join(sentence.text.split()))
            elif not sentence.text:
                # recover the sentence text from a pre-existing #text comment if possible
                for comment in sentence_comments:
                    if comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text="):
                        sentence.text = comment.split("=", 1)[-1].strip()
                        break

            for comment in sentence_comments:
                sentence.add_comment(comment)

            # look for sent_id in the comments
            # if it's there, overwrite the sent_idx id from above
            for comment in sentence_comments:
                if comment.startswith("# sent_id"):
                    sentence.sent_id = comment.split("=", 1)[-1].strip()
                    break
            else:
                # no sent_id found.  add a comment with our enumerated id
                # setting the sent_id on the sentence will automatically add the comment
                sentence.sent_id = str(sentence.index)

    def _count_words(self):
        """
        Count the number of tokens and words
        """
        self.num_tokens = sum([len(sentence.tokens) for sentence in self.sentences])
        self.num_words = sum([len(sentence.words) for sentence in self.sentences])

    def get(self, fields, as_sentences=False, from_token=False):
        """ Get fields from a list of field names.
        If only one field name (string or singleton list) is provided,
        return a list of that field; if more than one, return a list of list.
        Note that all returned fields are after multi-word expansion.

        Args:
            fields: name of the fields as a list or a single string
            as_sentences: if True, return the fields as a list of sentences; otherwise as a whole list
            from_token: if True, get the fields from Token; otherwise from Word

        Returns:
            All requested fields.
        """
        if isinstance(fields, str):
            fields = [fields]
        assert isinstance(fields, list), "Must provide field names as a list."
        assert len(fields) >= 1, "Must have at least one field."

        results = []
        for sentence in self.sentences:
            cursent = []
            # decide word or token
            if from_token:
                units = sentence.tokens
            else:
                units = sentence.words
            for unit in units:
                if len(fields) == 1:
                    cursent += [getattr(unit, fields[0])]
                else:
                    cursent += [[getattr(unit, field) for field in fields]]

            # decide whether append the results as a sentence or a whole list
            if as_sentences:
                results.append(cursent)
            else:
                results += cursent
        return results

    def set(self, fields, contents, to_token=False, to_sentence=False):
        """Set fields based on contents. If only one field (string or
        singleton list) is provided, then a list of content will be
        expected; otherwise a list of list of contents will be expected.

        Args:
            fields: name of the fields as a list or a single string
            contents: field values to set; total length should be equal to number of words/tokens
            to_token: if True, set field values to tokens; otherwise to words
            to_sentence: if True, set field values on the sentences themselves

        """
        if isinstance(fields, str):
            fields = [fields]
        assert isinstance(fields, (tuple, list)), "Must provide field names as a list."
        assert isinstance(contents, (tuple, list)), "Must provide contents as a list (one item per line)."
        assert len(fields) >= 1, "Must have at least one field."

        assert not to_sentence or not to_token, "Both to_token and to_sentence set to True, which is very confusing"

        if to_sentence:
            assert len(self.sentences) == len(contents), \
                "Contents must have the same length as the sentences"
            for sentence, content in zip(self.sentences, contents):
                if len(fields) == 1:
                    setattr(sentence, fields[0], content)
                else:
                    for field, piece in zip(fields, content):
                        setattr(sentence, field, piece)
        else:
            assert (to_token and self.num_tokens == len(contents)) or self.num_words == len(contents), \
                "Contents must have the same length as the original file."

            cidx = 0
            for sentence in self.sentences:
                # decide word or token
                if to_token:
                    units = sentence.tokens
                else:
                    units = sentence.words
                for unit in units:
                    if len(fields) == 1:
                        setattr(unit, fields[0], contents[cidx])
                    else:
                        for field, content in zip(fields, contents[cidx]):
                            setattr(unit, field, content)
                    cidx += 1

    def set_mwt_expansions(self, expansions,
                           fake_dependencies=False,
                           process_manual_expanded=None):
        """ Extend the multi-word tokens annotated by tokenizer.  A list of list of expansions
        will be expected for each multi-word token.  Use `process_manual_expanded` to limit
        processing for tokens marked manually expanded:

        There are two types of MWT expansions: those with `misc`: `MWT=True`, and those with
        `manual_expansion`: True.  The latter of which means that it is an expansion which the
        user manually specified through a postprocessor; the former means that it is a MWT
        which the detector picked out, but needs to be automatically expanded.

        process_manual_expanded = None - default; doesn't process manually expanded tokens
                                = True - process only manually expanded tokens (with `manual_expansion`: True)
                                = False - process only tokens explicitly tagged as MWT (`misc`: `MWT=True`)
        """

        idx_e = 0
        for sentence in self.sentences:
            idx_w = 0
            for token in sentence.tokens:
                idx_w += 1
                is_multi = (len(token.id) > 1)
                is_mwt = (multi_word_token_misc.match(token.misc) if token.misc is not None else None)
                is_manual_expansion = token.manual_expansion

                # by default a token is flattened to a single id unless one
                # of the cases below upgrades it to PROCESS or SKIP
                perform_mwt_processing = MWTProcessingType.FLATTEN

                if (process_manual_expanded and is_manual_expansion):
                    perform_mwt_processing = MWTProcessingType.PROCESS
                elif (process_manual_expanded==False and is_mwt):
                    perform_mwt_processing = MWTProcessingType.PROCESS
                elif (process_manual_expanded==False and is_manual_expansion):
                    perform_mwt_processing = MWTProcessingType.SKIP
                elif (process_manual_expanded==None and (is_mwt or is_multi)):
                    perform_mwt_processing = MWTProcessingType.PROCESS

                if perform_mwt_processing == MWTProcessingType.FLATTEN:
                    for word in token.words:
                        token.id = (idx_w, )
                        # delete dependency information
                        word.deps = None
                        word.head, word.deprel = None, None
                        word.id = idx_w
                elif perform_mwt_processing == MWTProcessingType.PROCESS:
                    expanded = [x for x in expansions[idx_e].split(' ') if len(x) > 0]
                    # in the event the MWT annotator only split the
                    # Token into a single Word, we preserve its text
                    # otherwise the Token's text is different from its
                    # only Word's text
                    if len(expanded) == 1:
                        expanded = [token.text]
                    idx_e += 1
                    idx_w_end = idx_w + len(expanded) - 1
                    if token.misc:  # None can happen when using a prebuilt doc
                        # drop the MWT=Yes annotation now that the expansion has been applied
                        token.misc = None if token.misc == 'MWT=Yes' else '|'.join([x for x in token.misc.split('|') if x != 'MWT=Yes'])
                    token.id = (idx_w, idx_w_end) if len(expanded) > 1 else (idx_w,)
                    token.words = []
                    for i, e_word in enumerate(expanded):
                        token.words.append(Word(sentence, {ID: idx_w + i, TEXT: e_word}))
                    idx_w = idx_w_end
                elif perform_mwt_processing == MWTProcessingType.SKIP:
                    # shift the ids past the expansions seen so far without touching the words
                    token.id = tuple(orig_id + idx_e for orig_id in token.id)
                    for i in token.words:
                        i.id += idx_e
                    idx_w = token.id[-1]
                # NOTE(review): clearing the flag here applies to every token,
                # not just the SKIP branch — confirm this matches upstream intent
                token.manual_expansion = None

            # reprocess the words using the new tokens
            sentence.words = []
            for token in sentence.tokens:
                token.sent = sentence
                for word in token.words:
                    word.sent = sentence
                    word.parent = token
                    sentence.words.append(word)
                # distribute character offsets to the words when the expansion
                # text exactly covers the token text
                if token.start_char is not None and token.end_char is not None and "".join(word.text for word in token.words) == token.text:
                    start_char = token.start_char
                    for word in token.words:
                        end_char = start_char + len(word.text)
                        word.start_char = start_char
                        word.end_char = end_char
                        start_char = end_char

            if fake_dependencies:
                sentence.build_fake_dependencies()
            else:
                sentence.rebuild_dependencies()

        self._count_words() # update number of words & tokens
        assert idx_e == len(expansions), "{} {}".format(idx_e, len(expansions))
        return

    def get_mwt_expansions(self, evaluation=False):
        """ Get the multi-word tokens.  For training (evaluation=False), return a list of
        [multi-word token text, expanded text] pairs; with evaluation=True, return only
        the multi-word token surface strings.

        Tokens marked as manual expansions are skipped unless their MISC explicitly
        carries MWT=Yes.
        """
        expansions = []
        for sentence in self.sentences:
            for token in sentence.tokens:
                is_multi = (len(token.id) > 1)
                is_mwt = multi_word_token_misc.match(token.misc) if token.misc is not None else None
                is_manual_expansion = token.manual_expansion
                if (is_multi and not is_manual_expansion) or is_mwt:
                    src = token.text
                    dst = ' '.join([word.text for word in token.words])
                    expansions.append([src, dst])
        if evaluation: expansions = [e[0] for e in expansions]
        return expansions

    def build_ents(self):
        """ Build the list of entities by iterating over all words. Return all entities as a list. """
        self.ents = []
        for s in self.sentences:
            s_ents = s.build_ents()
            self.ents += s_ents
        return self.ents

    def sort_features(self):
        """ Sort the features on all the words... useful for prototype treebanks, for example """
        for sentence in self.sentences:
            for word in sentence.words:
                if not word.feats:
                    continue
                pieces = word.feats.split("|")
                pieces = sorted(pieces)
                word.feats = "|".join(pieces)

    def iter_words(self):
        """ An iterator that returns all of the words in this Document. """
        for s in self.sentences:
            yield from s.words

    def iter_tokens(self):
        """ An iterator that returns all of the tokens in this Document. """
        for s in self.sentences:
            yield from s.tokens

    def sentence_comments(self):
        """ Returns a list of list of comments for the sentences """
        return [[comment for comment in sentence.comments] for sentence in self.sentences]

    @property
    def coref(self):
        """
        Access the coref lists of the document
        """
        return self._coref

    @coref.setter
    def coref(self, chains):
        """ Set the document's coref lists """
        self._coref = chains
        self._attach_coref_mentions(chains)

    def _attach_coref_mentions(self, chains):
        """Annotate each word with the CorefAttachments of the mentions covering it."""
        for sentence in self.sentences:
            for word in sentence.words:
                word.coref_chains = []

        for chain in chains:
            for mention_idx, mention in enumerate(chain.mentions):
                sentence = self.sentences[mention.sentence]
                for word_idx in range(mention.start_word, mention.end_word):
                    is_start = word_idx == mention.start_word
                    is_end = word_idx == mention.end_word - 1
                    is_representative = mention_idx == chain.representative_index
                    attachment = CorefAttachment(chain, is_start, is_end, is_representative)
                    sentence.words[word_idx].coref_chains.append(attachment)

    def reindex_sentences(self, start_index):
        """Renumber the sent_id of each sentence, starting from start_index."""
        for sent_id, sentence in zip(range(start_index, start_index + len(self.sentences)), self.sentences):
            sentence.sent_id = str(sent_id)

    def to_dict(self):
        """ Dumps the whole document into a list of list of dictionary for each token in each sentence in the doc.
        """
        return [sentence.to_dict() for sentence in self.sentences]

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def __format__(self, spec):
        # 'c' / 'C' delegate to the Sentence CoNLL formatting specs
        if spec == 'c':
            return "\n\n".join("{:c}".format(s) for s in self.sentences)
        elif spec == 'C':
            return "\n\n".join("{:C}".format(s) for s in self.sentences)
        else:
            return str(self)

    def to_serialized(self):
        """ Dumps the whole document including text to a byte array containing a list of list of dictionaries for each token in each sentence in the doc.
        """
        return pickle.dumps((self.text, self.to_dict(), self.sentence_comments()))

    @classmethod
    def from_serialized(cls, serialized_string):
        """ Create and initialize a new document from a serialized string generated by Document.to_serialized_string():
        """
        # NOTE(review): pickle.loads is called again below after this
        # validation pass — the already-unpickled `stuff` could be reused
        stuff = pickle.loads(serialized_string)
        if not isinstance(stuff, tuple):
            raise TypeError("Serialized data was not a tuple when building a Document")
        if len(stuff) == 2:
            # older serialization format without comments
            text, sentences = pickle.loads(serialized_string)
            doc = cls(sentences, text)
        else:
            text, sentences, comments = pickle.loads(serialized_string)
            doc = cls(sentences, text, comments)
        return doc
534
+
535
+
536
+ class Sentence(StanzaObject):
537
+ """ A sentence class that stores attributes of a sentence and carries a list of tokens.
538
+ """
539
+
540
+ def __init__(self, tokens, doc=None, empty_words=None):
541
+ """ Construct a sentence given a list of tokens in the form of CoNLL-U dicts.
542
+ """
543
+ self._tokens = []
544
+ self._words = []
545
+ self._dependencies = []
546
+ self._text = None
547
+ self._ents = []
548
+ self._doc = doc
549
+ self._constituency = None
550
+ self._sentiment = None
551
+ # comments are a list of comment lines occurring before the
552
+ # sentence in a CoNLL-U file. Can be empty
553
+ self._comments = []
554
+ self._doc_id = None
555
+
556
+ # enhanced_dependencies represents the DEPS column
557
+ # this is a networkx MultiDiGraph
558
+ # with edges from the parent to the dependent
559
+ # however, we set it to None until needed, as it is somewhat slow
560
+ self._enhanced_dependencies = None
561
+ self._process_tokens(tokens)
562
+
563
+ if empty_words is not None:
564
+ self._empty_words = [Word(self, entry) for entry in empty_words]
565
+ else:
566
+ self._empty_words = []
567
+
568
    def _process_tokens(self, tokens):
        """Build the parallel Token and Word lists from raw CoNLL-U dicts.

        st/en track the id span of the most recently seen multi-word token so
        that following single-word entries can be attached to it.
        """
        st, en = -1, -1
        self.tokens, self.words = [], []
        for i, entry in enumerate(tokens):
            if ID not in entry: # manually set a 1-based id for word if not exist
                entry[ID] = (i+1, )
            if isinstance(entry[ID], int):
                entry[ID] = (entry[ID], )
            if len(entry.get(ID)) > 1: # if this token is a multi-word token
                st, en = entry[ID]
                self.tokens.append(Token(self, entry))
            else: # else this token is a word
                new_word = Word(self, entry)
                if len(self.words) > 0 and self.words[-1].id == new_word.id:
                    # this can happen in the following context:
                    # a document was created with MWT=Yes to mark that a token should be split
                    # and then there was an MWT "expansion" with a single word after that token
                    # we replace the Word in the Token assuming that the expansion token might
                    # have more information than the Token dict did
                    # note that a single word MWT like that can be detected with something like
                    # multi_word_token_misc.match(entry.get(MISC)) if entry.get(MISC, None)
                    self.words[-1] = new_word
                    self.tokens[-1].words[-1] = new_word
                    continue
                self.words.append(new_word)
                idx = entry.get(ID)[0]
                if idx <= en:
                    # this word falls inside the currently open multi-word token
                    self.tokens[-1].words.append(new_word)
                else:
                    # free-standing word: wrap it in a Token of its own
                    self.tokens.append(Token(self, entry, words=[new_word]))
                new_word.parent = self.tokens[-1]

        # put all of the whitespace annotations (if any) on the Tokens instead of the Words
        for token in self.tokens:
            token.consolidate_whitespace()
        self.rebuild_dependencies()
604
+
605
+ def has_enhanced_dependencies(self):
606
+ """
607
+ Whether or not the enhanced dependencies are part of this sentence
608
+ """
609
+ return self._enhanced_dependencies is not None and len(self._enhanced_dependencies) > 0
610
+
611
    @property
    def index(self):
        """
        Access the index of this sentence within the doc.

        If multiple docs were processed together,
        the sentence index will continue counting across docs.
        """
        return self._index

    @index.setter
    def index(self, value):
        """ Set the sentence's index value. """
        self._index = value
625
+
626
    @property
    def id(self):
        """
        Deprecated alias of `index`: the position of this sentence within the doc.

        If multiple docs were processed together,
        the sentence index will continue counting across docs.
        """
        warnings.warn("Use of sentence.id is deprecated.  Please use sentence.index instead", stacklevel=2)
        return self._index

    @id.setter
    def id(self, value):
        """ Deprecated alias of the `index` setter. """
        warnings.warn("Use of sentence.id is deprecated.  Please use sentence.index instead", stacklevel=2)
        self._index = value
642
+
643
+ @property
644
+ def sent_id(self):
645
+ """ conll-style sent_id Will be set from index if unknown """
646
+ return self._sent_id
647
+
648
+ @sent_id.setter
649
+ def sent_id(self, value):
650
+ """ Set the sentence's sent_id value. """
651
+ self._sent_id = value
652
+ sent_id_comment = "# sent_id = " + str(value)
653
+ for comment_idx, comment in enumerate(self._comments):
654
+ if comment.startswith("# sent_id = "):
655
+ self._comments[comment_idx] = sent_id_comment
656
+ break
657
+ else: # this is intended to be a for/else loop
658
+ self._comments.append(sent_id_comment)
659
+
660
+ @property
661
+ def doc_id(self):
662
+ """ conll-style doc_id Can be left blank if unknown """
663
+ return self._doc_id
664
+
665
+ @doc_id.setter
666
+ def doc_id(self, value):
667
+ """ Set the sentence's doc_id value. """
668
+ self._doc_id = value
669
+ doc_id_comment = "# doc_id = " + str(value)
670
+ for comment_idx, comment in enumerate(self._comments):
671
+ if comment.startswith("# doc_id = "):
672
+ self._comments[comment_idx] = doc_id_comment
673
+ break
674
+ else: # this is intended to be a for/else loop
675
+ self._comments.append(doc_id_comment)
676
+
677
    @property
    def doc(self):
        """ Access the parent doc of this sentence. """
        return self._doc

    @doc.setter
    def doc(self, value):
        """ Set the parent doc of this sentence. """
        self._doc = value
686
+
687
    @property
    def text(self):
        """ Access the raw text for this sentence. """
        return self._text

    @text.setter
    def text(self, value):
        """ Set the raw text for this sentence. """
        self._text = value
696
+
697
    @property
    def dependencies(self):
        """ Access list of dependencies for this sentence. """
        return self._dependencies

    @dependencies.setter
    def dependencies(self, value):
        """ Set the list of dependencies for this sentence. """
        self._dependencies = value
706
+
707
    @property
    def tokens(self):
        """ Access the list of tokens for this sentence. """
        return self._tokens

    @tokens.setter
    def tokens(self, value):
        """ Set the list of tokens for this sentence. """
        self._tokens = value
716
+
717
    @property
    def words(self):
        """ Access the list of words for this sentence. """
        return self._words

    @words.setter
    def words(self, value):
        """ Set the list of words for this sentence. """
        self._words = value
726
+
727
    @property
    def empty_words(self):
        """ Access the list of empty words (decimal-id CoNLL-U entries) for this sentence. """
        return self._empty_words

    @empty_words.setter
    def empty_words(self, value):
        """ Set the list of empty words for this sentence. """
        self._empty_words = value
736
+
737
    @property
    def ents(self):
        """ Access the list of entities in this sentence. """
        return self._ents

    @ents.setter
    def ents(self, value):
        """ Set the list of entities in this sentence. """
        self._ents = value
746
+
747
    @property
    def entities(self):
        """ Access the list of entities. This is just an alias of `ents`. """
        return self._ents

    @entities.setter
    def entities(self, value):
        """ Set the list of entities in this sentence. """
        self._ents = value
756
+
757
def build_ents(self):
    """ Build the list of entities by iterating over all tokens. Return all entities as a list.

    Note that unlike other attributes, since NER requires raw text, the actual tagging are always
    performed at and attached to the `Token`s, instead of `Word`s.
    """
    self.ents = []
    ner_tags = [token.ner for token in self.tokens]
    for entity in decode_from_bioes(ner_tags):
        covered_tokens = self.tokens[entity['start']:entity['end'] + 1]
        self.ents.append(Span(tokens=covered_tokens, type=entity['type'], doc=self.doc, sent=self))
    return self.ents
770
+
771
@property
def sentiment(self):
    """ Returns the sentiment value for this sentence """
    return self._sentiment

@sentiment.setter
def sentiment(self, value):
    """ Set the sentiment value.

    Keeps the CoNLL comments in sync: an existing "# sentiment = ..."
    comment is replaced in place, otherwise a new one is appended.
    """
    self._sentiment = value
    new_comment = "# sentiment = " + str(value)
    for idx in range(len(self._comments)):
        if self._comments[idx].startswith("# sentiment = "):
            self._comments[idx] = new_comment
            return
    self._comments.append(new_comment)
787
+
788
@property
def constituency(self):
    """ Returns the constituency tree for this sentence """
    return self._constituency

@constituency.setter
def constituency(self, value):
    """
    Set the constituency tree

    This incidentally updates the #constituency comment if it already exists,
    or otherwise creates a new comment # constituency = ...
    """
    self._constituency = value
    # newlines cannot survive inside a single CoNLL comment line,
    # so they are replaced with a *NL* marker
    rendered = ("# constituency = " + str(value)).replace("\n", "*NL*").replace("\r", "")
    for idx, existing in enumerate(self._comments):
        if existing.startswith("# constituency = "):
            self._comments[idx] = rendered
            break
    else:  # no existing constituency comment found
        self._comments.append(rendered)
810
+
811
+
812
@property
def comments(self):
    """ Returns CoNLL-style comments for this sentence.

    Read-only view; use add_comment() to append new comments.
    """
    return self._comments
816
+
817
def add_comment(self, comment):
    """ Adds a single comment to this sentence.

    If the comment does not already have # at the start, it will be added.

    Comments carrying known annotations (constituency, sentiment, sent_id,
    doc_id) also update the corresponding attribute; any previously stored
    comment for the same annotation is removed before the new one is appended.
    """
    if not comment.startswith("#"):
        comment = "# " + comment
    if comment.startswith("# constituency ="):
        _, tree_text = comment.split("=", 1)
        tree = tree_reader.read_trees(tree_text)
        if len(tree) > 1:
            raise ValueError("Multiple constituency trees for one sentence: %s" % tree_text)
        self._constituency = tree[0]
        # drop any stale constituency comment; the new comment is appended below
        self._comments = [x for x in self._comments if not x.startswith("# constituency =")]
    elif comment.startswith("# sentiment ="):
        _, sentiment = comment.split("=", 1)
        sentiment = int(sentiment.strip())
        self._sentiment = sentiment
        self._comments = [x for x in self._comments if not x.startswith("# sentiment =")]
    elif comment.startswith("# sent_id ="):
        _, sent_id = comment.split("=", 1)
        sent_id = sent_id.strip()
        self._sent_id = sent_id
        self._comments = [x for x in self._comments if not x.startswith("# sent_id =")]
    elif comment.startswith("# doc_id ="):
        _, doc_id = comment.split("=", 1)
        doc_id = doc_id.strip()
        self._doc_id = doc_id
        self._comments = [x for x in self._comments if not x.startswith("# doc_id =")]
    self._comments.append(comment)
847
+
848
def rebuild_dependencies(self):
    """Rebuild the dependency list, but only if complete head/deprel info is present."""
    have_all_deps = all(w.head is not None and w.deprel is not None for w in self.words)
    have_all_words = (len(self.words) >= len(self.tokens)) and (len(self.words) == self.words[-1].id)
    if have_all_deps and have_all_words:
        self.build_dependencies()
853
+
854
def build_dependencies(self):
    """ Build the dependency graph for this sentence. Each dependency graph entry is
    a list of (head, deprel, word).
    """
    self.dependencies = []
    for word in self.words:
        if word.head == 0:
            # the root attaches to a synthetic ROOT word
            governor = Word(self, {ID: 0, TEXT: "ROOT"})
        else:
            # head is 1-based, so subtract one to index into words
            try:
                governor = self.words[word.head - 1]
            except IndexError as e:
                raise IndexError("Word head {} is not a valid word index for word {}".format(word.head, word.id)) from e
        if word.head != governor.id:
            raise ValueError("Dependency tree is incorrectly constructed")
        self.dependencies.append((governor, word.deprel, word))
873
+
874
def build_fake_dependencies(self):
    """Attach a trivial left-branching dependency structure to this sentence.

    The first word becomes the root; every following word depends on the
    previous word with the generic relation "dep".

    NOTE(review): the first element of each appended triple is the integer
    index, unlike build_dependencies which stores the head Word object —
    confirm that consumers such as print_dependencies (which reads .id on
    the head) are never used on fake dependencies.
    """
    self.dependencies = []
    for word_idx, word in enumerate(self.words):
        word.head = word_idx # note that this goes one previous to the index
        word.deprel = "root" if word_idx == 0 else "dep"
        word.deps = "%d:%s" % (word.head, word.deprel)
        self.dependencies.append((word_idx, word.deprel, word))
881
+
882
def print_dependencies(self, file=None):
    """ Print each dependency of this sentence as a (text, head id, relation) tuple. """
    for governor, relation, dependent in self.dependencies:
        print((dependent.text, governor.id, relation), file=file)
886
+
887
def dependencies_string(self):
    """ Render the dependencies of this sentence as a single stripped string. """
    buffer = io.StringIO()
    self.print_dependencies(file=buffer)
    return buffer.getvalue().strip()
892
+
893
def print_tokens(self, file=None):
    """ Print every token of this sentence, one pretty-printed token per line. """
    for token in self.tokens:
        print(token.pretty_print(), file=file)
897
+
898
def tokens_string(self):
    """ Render the tokens of this sentence as a single stripped string. """
    buffer = io.StringIO()
    self.print_tokens(file=buffer)
    return buffer.getvalue().strip()
903
+
904
def print_words(self, file=None):
    """ Print every word of this sentence, one pretty-printed word per line. """
    for word in self.words:
        print(word.pretty_print(), file=file)
908
+
909
def words_string(self):
    """ Render the words of this sentence as a single stripped string. """
    buffer = io.StringIO()
    self.print_words(file=buffer)
    return buffer.getvalue().strip()
914
+
915
def to_dict(self):
    """ Dumps the sentence into a list of dictionary for each token in the sentence.

    Empty (null) words are interleaved in id order: each empty word is
    emitted just before the first token whose id exceeds its anchor id,
    and any trailing empty words are emitted at the end.
    """
    result = []
    empties = self._empty_words
    next_empty = 0
    for token in self.tokens:
        while next_empty < len(empties) and empties[next_empty].id[0] < token.id[0]:
            result.append(empties[next_empty].to_dict())
            next_empty += 1
        result.extend(token.to_dict())
    for leftover in empties[next_empty:]:
        result.append(leftover.to_dict())
    return result
928
+
929
def __repr__(self):
    # JSON rendering of to_dict(), mainly useful for debugging
    return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)
931
+
932
def __format__(self, spec):
    """'c' renders CoNLL token lines; 'C' additionally prefixes the comments; anything else falls back to str()."""
    if spec != 'c' and spec != 'C':
        return str(self)

    # interleave empty (null) words with the tokens in id order,
    # mirroring the layout used by to_dict()
    lines = []
    empties = self._empty_words
    next_empty = 0
    for token in self.tokens:
        while next_empty < len(empties) and empties[next_empty].id[0] < token.id[0]:
            lines.append(empties[next_empty].to_conll_text())
            next_empty += 1
        lines.append(token.to_conll_text())
    for leftover in empties[next_empty:]:
        lines.append(leftover.to_conll_text())

    token_text = "\n".join(lines)
    if spec == 'C' and len(self.comments) > 0:
        return "\n".join(self.comments) + "\n" + token_text
    return token_text
954
+
955
def init_from_misc(unit):
    """Create attributes by parsing from the `misc` field.

    Also, remove start_char, end_char, and any other values we can set
    from the misc field if applicable, so that we don't repeat ourselves
    """
    kept_items = []
    for item in unit._misc.split('|'):
        parts = item.split('=', 1)
        if len(parts) == 2:
            key, value = parts
            # start & end char are kept as ints
            if key in (START_CHAR, END_CHAR):
                value = int(value)
            # only consume the item if the unit has a matching private attribute
            attr = '_' + key
            if hasattr(unit, attr):
                setattr(unit, attr, value)
                continue
            if key == NER:
                # special case skipping NER for Words, since there is no Word NER field
                continue
        # items without '=' or with unknown keys stay in misc
        kept_items.append(item)
    unit._misc = "|".join(kept_items)
980
+
981
+
982
def dict_to_conll_text(token_dict, id_connector="-"):
    """Convert one token/word dictionary into a single CoNLL-U line (FIELD_NUM tab-separated columns).

    id_connector joins the parts of a tuple id: "-" for MWT ranges (e.g. 1-2),
    "." for empty nodes (e.g. 1.1).  Annotations with no column of their own
    (start_char/end_char/ner/coref chains) are folded into the MISC column.
    """
    token_conll = ['_' for i in range(FIELD_NUM)]
    misc = []
    for key in token_dict:
        if key == START_CHAR or key == END_CHAR:
            misc.append("{}={}".format(key, token_dict[key]))
        elif key == NER:
            # TODO: potentially need to escape =|\ in the NER
            misc.append("{}={}".format(key, token_dict[key]))
        elif key == COREF_CHAINS:
            chains = token_dict[key]
            if len(chains) > 0:
                misc_chains = []
                for chain in chains:
                    # encode where in the mention span this word sits
                    if chain.is_start and chain.is_end:
                        coref_position = "unit-"
                    elif chain.is_start:
                        coref_position = "start-"
                    elif chain.is_end:
                        coref_position = "end-"
                    else:
                        coref_position = "middle-"
                    is_representative = "repr-" if chain.is_representative else ""
                    misc_chains.append("%s%sid%d" % (coref_position, is_representative, chain.chain.index))
                misc.append("{}={}".format(key, ",".join(misc_chains)))
        elif key == MISC:
            # avoid appending a blank misc entry.
            # otherwise the resulting misc field in the conll doc will wind up being blank text
            # TODO: potentially need to escape =|\ in the MISC as well
            if token_dict[key]:
                misc.append(token_dict[key])
        elif key == ID:
            token_conll[FIELD_TO_IDX[key]] = id_connector.join([str(x) for x in token_dict[key]]) if isinstance(token_dict[key], tuple) else str(token_dict[key])
        elif key in FIELD_TO_IDX:
            token_conll[FIELD_TO_IDX[key]] = str(token_dict[key])
    if misc:
        token_conll[FIELD_TO_IDX[MISC]] = "|".join(misc)
    else:
        token_conll[FIELD_TO_IDX[MISC]] = '_'
    # when a word (not mwt token) without head is found, we insert dummy head as required by the UD eval script
    if '-' not in token_conll[FIELD_TO_IDX[ID]] and '.' not in token_conll[FIELD_TO_IDX[ID]] and HEAD not in token_dict:
        token_conll[FIELD_TO_IDX[HEAD]] = str(int(token_dict[ID] if isinstance(token_dict[ID], int) else token_dict[ID][0]) - 1) # evaluation script requires head: int
    return "\t".join(token_conll)
1025
+
1026
+
1027
class Token(StanzaObject):
    """ A token class that stores attributes of a token and carries a list of words. A token corresponds to a unit in the raw
    text. In some languages such as English, a token has a one-to-one mapping to a word, while in other languages such as French,
    a (multi-word) token might be expanded into multiple words that carry syntactic annotations.
    """

    def __init__(self, sentence, token_entry, words=None):
        """
        Construct a token given a dictionary format token entry. Optionally link itself to the corresponding words.
        The owning sentence must be passed in.

        Raises ValueError if the entry has no id or no text.
        """
        self._id = token_entry.get(ID)
        self._text = token_entry.get(TEXT)
        if not self._id:
            raise ValueError('id not included for the token')
        if not self._text:
            raise ValueError('text not included for the token')
        self._misc = token_entry.get(MISC, None)
        self._ner = token_entry.get(NER, None)
        self._multi_ner = token_entry.get(MULTI_NER, None)
        self._words = words if words is not None else []
        self._start_char = token_entry.get(START_CHAR, None)
        self._end_char = token_entry.get(END_CHAR, None)
        self._sent = sentence
        self._mexp = token_entry.get(MEXP, None)
        self._spaces_before = ""
        self._spaces_after = " "

        # misc may carry start_char / end_char / ner annotations; promote
        # those to attributes and strip them from misc
        if self._misc is not None:
            init_from_misc(self)

    @property
    def id(self):
        """ Access the index of this token. """
        return self._id

    @id.setter
    def id(self, value):
        """ Set the token's id value. """
        self._id = value

    @property
    def manual_expansion(self):
        """ Access the whether this token was manually expanded. """
        return self._mexp

    @manual_expansion.setter
    def manual_expansion(self, value):
        """ Set the whether this token was manually expanded. """
        self._mexp = value

    @property
    def text(self):
        """ Access the text of this token. Example: 'The' """
        return self._text

    @text.setter
    def text(self, value):
        """ Set the token's text value. Example: 'The' """
        self._text = value

    @property
    def misc(self):
        """ Access the miscellaneousness of this token. """
        return self._misc

    @misc.setter
    def misc(self, value):
        """ Set the token's miscellaneousness value. None and '_' are normalized to None. """
        self._misc = value if self._is_null(value) == False else None

    def consolidate_whitespace(self):
        """
        Remove whitespace misc annotations from the Words and mark the whitespace on the Tokens
        """
        found_after = False
        found_before = False
        num_words = len(self.words)
        for word_idx, word in enumerate(self.words):
            misc = word.misc
            if not misc:
                continue
            pieces = misc.split("|")
            # SpacesBefore is only meaningful on the first word of a token
            if word_idx == 0:
                if any(piece.startswith("SpacesBefore=") for piece in pieces):
                    self.spaces_before = misc_to_space_before(misc)
                    found_before = True
            else:
                if any(piece.startswith("SpacesBefore=") for piece in pieces):
                    warnings.warn("Found a SpacesBefore MISC annotation on a Word that was not the first Word in a Token")
            # SpaceAfter / SpacesAfter is only meaningful on the last word of a token
            if word_idx == num_words - 1:
                if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
                    self.spaces_after = misc_to_space_after(misc)
                    found_after = True
            else:
                if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
                    unexpected_space_after = misc_to_space_after(misc)
                    if unexpected_space_after == "":
                        warnings.warn("Unexpected SpaceAfter=No annotation on a word in the middle of an MWT")
                    else:
                        warnings.warn("Unexpected SpacesAfter on a word in the middle on an MWT")
            pieces = [x for x in pieces if not x.startswith("SpacesAfter=") and not x.startswith("SpaceAfter=") and not x.startswith("SpacesBefore=")]
            word.misc = "|".join(pieces)

        # the token's own misc may also carry whitespace annotations;
        # warn if they conflict with what the words said
        misc = self.misc
        if misc:
            pieces = misc.split("|")
            if any(piece.startswith("SpacesBefore=") for piece in pieces):
                spaces_before = misc_to_space_before(misc)
                if found_before:
                    if spaces_before != self.spaces_before:
                        warnings.warn("Found conflicting SpacesBefore on a token and its word!")
                else:
                    self.spaces_before = spaces_before
            if any(piece.startswith("SpaceAfter=") or piece.startswith("SpacesAfter=") for piece in pieces):
                spaces_after = misc_to_space_after(misc)
                if found_after:
                    if spaces_after != self.spaces_after:
                        warnings.warn("Found conflicting SpaceAfter / SpacesAfter on a token and its word!")
                else:
                    self.spaces_after = spaces_after
            pieces = [x for x in pieces if not x.startswith("SpacesAfter=") and not x.startswith("SpaceAfter=") and not x.startswith("SpacesBefore=")]
            self.misc = "|".join(pieces)

    @property
    def spaces_before(self):
        """ SpacesBefore for the token. Translated from the MISC fields """
        return self._spaces_before

    @spaces_before.setter
    def spaces_before(self, value):
        self._spaces_before = value

    @property
    def spaces_after(self):
        """ SpaceAfter or SpacesAfter for the token. Translated from the MISC field """
        return self._spaces_after

    @spaces_after.setter
    def spaces_after(self, value):
        self._spaces_after = value

    @property
    def words(self):
        """ Access the list of syntactic words underlying this token. """
        return self._words

    @words.setter
    def words(self, value):
        """ Set this token's list of underlying syntactic words, reparenting each word to this token. """
        self._words = value
        for w in self._words:
            w.parent = self

    @property
    def start_char(self):
        """ Access the start character index for this token in the raw text. """
        return self._start_char

    @property
    def end_char(self):
        """ Access the end character index for this token in the raw text. """
        return self._end_char

    @property
    def ner(self):
        """ Access the NER tag of this token. Example: 'B-ORG'"""
        return self._ner

    @ner.setter
    def ner(self, value):
        """ Set the token's NER tag. Example: 'B-ORG'"""
        self._ner = value if self._is_null(value) == False else None

    @property
    def multi_ner(self):
        """ Access the MULTI_NER tag of this token. Example: '(B-ORG, B-DISEASE)'"""
        return self._multi_ner

    @multi_ner.setter
    def multi_ner(self, value):
        """ Set the token's MULTI_NER tag. Example: '(B-ORG, B-DISEASE)'"""
        self._multi_ner = value if self._is_null(value) == False else None

    @property
    def sent(self):
        """ Access the pointer to the sentence that this token belongs to. """
        return self._sent

    @sent.setter
    def sent(self, value):
        """ Set the pointer to the sentence that this token belongs to. """
        self._sent = value

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def __format__(self, spec):
        if spec == 'C':
            # BUGFIX: this used to be "\n".join(self.to_conll_text()), but
            # to_conll_text() already returns a single newline-joined string,
            # so joining it inserted a newline between every character.
            # Word.__format__ returns the conll text directly, as we now do.
            return self.to_conll_text()
        elif spec == 'P':
            return self.pretty_print()
        else:
            return str(self)

    def to_conll_text(self):
        """ Render this token (and its words, for an MWT) as CoNLL-U lines. """
        return "\n".join(dict_to_conll_text(x) for x in self.to_dict())

    # default is a tuple rather than a list so the shared default cannot be mutated
    def to_dict(self, fields=(ID, TEXT, MISC, START_CHAR, END_CHAR, NER, MULTI_NER, MEXP)):
        """ Dumps the token into a list of dictionary for this token with its extended words
        if the token is a multi-word token.
        """
        ret = []
        if len(self.id) > 1:
            # multi-word token: emit a dict for the token range itself
            token_dict = {}
            for field in fields:
                if getattr(self, field) is not None:
                    token_dict[field] = getattr(self, field)
            if MISC in fields:
                spaces_after = self.spaces_after
                if spaces_after is not None and spaces_after != ' ':
                    space_misc = space_after_to_misc(spaces_after)
                    if token_dict.get(MISC):
                        token_dict[MISC] = token_dict[MISC] + "|" + space_misc
                    else:
                        token_dict[MISC] = space_misc

                spaces_before = self.spaces_before
                if spaces_before is not None and spaces_before != '':
                    space_misc = space_before_to_misc(spaces_before)
                    if token_dict.get(MISC):
                        token_dict[MISC] = token_dict[MISC] + "|" + space_misc
                    else:
                        token_dict[MISC] = space_misc

            ret.append(token_dict)
        for word in self.words:
            word_dict = word.to_dict()
            if len(self.id) == 1 and NER in fields and getattr(self, NER) is not None: # propagate NER label to Word if it is a single-word token
                word_dict[NER] = getattr(self, NER)
            if len(self.id) == 1 and MULTI_NER in fields and getattr(self, MULTI_NER) is not None: # propagate MULTI_NER label to Word if it is a single-word token
                word_dict[MULTI_NER] = getattr(self, MULTI_NER)
            if len(self.id) == 1 and MISC in fields:
                # for a single-word token, whitespace annotations live on the word
                spaces_after = self.spaces_after
                if spaces_after is not None and spaces_after != ' ':
                    space_misc = space_after_to_misc(spaces_after)
                    if word_dict.get(MISC):
                        word_dict[MISC] = word_dict[MISC] + "|" + space_misc
                    else:
                        word_dict[MISC] = space_misc

                spaces_before = self.spaces_before
                if spaces_before is not None and spaces_before != '':
                    space_misc = space_before_to_misc(spaces_before)
                    if word_dict.get(MISC):
                        word_dict[MISC] = word_dict[MISC] + "|" + space_misc
                    else:
                        word_dict[MISC] = space_misc
            ret.append(word_dict)
        return ret

    def pretty_print(self):
        """ Print this token with its extended words in one line. """
        return f"<{self.__class__.__name__} id={'-'.join([str(x) for x in self.id])};words=[{', '.join([word.pretty_print() for word in self.words])}]>"

    def _is_null(self, value):
        # '_' is the CoNLL-U placeholder for a missing value
        return (value is None) or (value == '_')

    def is_mwt(self):
        """ Return True if this token expands into more than one word. """
        return len(self.words) > 1
1297
+
1298
class Word(StanzaObject):
    """ A word class that stores attributes of a word.
    """

    def __init__(self, sentence, word_entry):
        """ Construct a word given a dictionary format word entry.

        The owning sentence must be passed in; the word registers its
        enhanced dependencies on the sentence's dependency graph.
        """
        self._id = word_entry.get(ID, None)
        # a 1-element tuple id is unwrapped to a plain int
        if isinstance(self._id, tuple):
            if len(self._id) == 1:
                self._id = self._id[0]
        self._text = word_entry.get(TEXT, None)

        assert self._id is not None and self._text is not None, 'id and text should be included for the word. {}'.format(word_entry)

        self._lemma = word_entry.get(LEMMA, None)
        self._upos = word_entry.get(UPOS, None)
        self._xpos = word_entry.get(XPOS, None)
        self._feats = word_entry.get(FEATS, None)
        self._head = word_entry.get(HEAD, None)
        self._deprel = word_entry.get(DEPREL, None)
        self._misc = word_entry.get(MISC, None)
        self._start_char = word_entry.get(START_CHAR, None)
        self._end_char = word_entry.get(END_CHAR, None)
        self._parent = None
        self._sent = sentence
        self._mexp = word_entry.get(MEXP, None)
        self._coref_chains = None

        # misc may carry start_char/end_char annotations; promote those to attributes
        if self._misc is not None:
            init_from_misc(self)

        # use the setter, which will go up to the sentence and set the
        # dependencies on that graph
        self.deps = word_entry.get(DEPS, None)

    @property
    def manual_expansion(self):
        """ Access the whether this token was manually expanded. """
        return self._mexp

    @manual_expansion.setter
    def manual_expansion(self, value):
        """ Set the whether this token was manually expanded. """
        self._mexp = value

    @property
    def id(self):
        """ Access the index of this word. """
        return self._id

    @id.setter
    def id(self, value):
        """ Set the word's index value. """
        self._id = value

    @property
    def text(self):
        """ Access the text of this word. Example: 'The'"""
        return self._text

    @text.setter
    def text(self, value):
        """ Set the word's text value. Example: 'The'"""
        self._text = value

    @property
    def lemma(self):
        """ Access the lemma of this word. """
        return self._lemma

    @lemma.setter
    def lemma(self, value):
        """ Set the word's lemma value.

        A '_' lemma is normalized to None, except when the word text itself
        is '_', in which case '_' is a legitimate lemma.
        """
        self._lemma = value if self._is_null(value) == False or self._text == '_' else None

    @property
    def upos(self):
        """ Access the universal part-of-speech of this word. Example: 'NOUN'"""
        return self._upos

    @upos.setter
    def upos(self, value):
        """ Set the word's universal part-of-speech value. Example: 'NOUN'"""
        self._upos = value if self._is_null(value) == False else None

    @property
    def xpos(self):
        """ Access the treebank-specific part-of-speech of this word. Example: 'NNP'"""
        return self._xpos

    @xpos.setter
    def xpos(self, value):
        """ Set the word's treebank-specific part-of-speech value. Example: 'NNP'"""
        self._xpos = value if self._is_null(value) == False else None

    @property
    def feats(self):
        """ Access the morphological features of this word. Example: 'Gender=Fem'"""
        return self._feats

    @feats.setter
    def feats(self, value):
        """ Set this word's morphological features. Example: 'Gender=Fem'"""
        self._feats = value if self._is_null(value) == False else None

    @property
    def head(self):
        """ Access the id of the governor of this word. """
        return self._head

    @head.setter
    def head(self, value):
        """ Set the word's governor id value. Coerced to int; '_'/None becomes None. """
        self._head = int(value) if self._is_null(value) == False else None

    @property
    def deprel(self):
        """ Access the dependency relation of this word. Example: 'nmod'"""
        return self._deprel

    @deprel.setter
    def deprel(self, value):
        """ Set the word's dependency relation value. Example: 'nmod'"""
        self._deprel = value if self._is_null(value) == False else None

    @property
    def deps(self):
        """ Access the dependencies of this word.

        Rendered from the sentence's enhanced-dependency multigraph as a
        CoNLL-U DEPS string ("head:rel|head:rel..."); empty-word heads use
        the "a.b" id form.  Returns None when there are no incoming edges.
        """
        graph = self._sent._enhanced_dependencies
        if graph is None or not graph.has_node(self.id):
            return None

        data = []
        # normalize int ids to 1-tuples so mixed int/tuple ids sort together
        predecessors = sorted(list(graph.predecessors(self.id)), key=lambda x: x if isinstance(x, tuple) else (x,))
        for parent in predecessors:
            deps = sorted(list(graph.get_edge_data(parent, self.id)))
            for dep in deps:
                if isinstance(parent, int):
                    data.append("%d:%s" % (parent, dep))
                else:
                    data.append("%d.%d:%s" % (parent[0], parent[1], dep))
        if not data:
            return None

        return "|".join(data)

    @deps.setter
    def deps(self, value):
        """ Set the word's dependencies value.

        Accepts either a CoNLL-U DEPS string or a list of (parent, relation)
        pairs.  Existing incoming edges for this word are removed first.
        """
        graph = self._sent._enhanced_dependencies
        # if we don't have a graph, and we aren't trying to set any actual
        # dependencies, we can save the time of doing anything else
        if graph is None and value is None:
            return

        if graph is None:
            graph = nx.MultiDiGraph()
            self._sent._enhanced_dependencies = graph
        # need to make a new list: cannot iterate and delete at the same time
        if graph.has_node(self.id):
            in_edges = list(graph.in_edges(self.id))
            graph.remove_edges_from(in_edges)

        if value is None:
            return

        if isinstance(value, str):
            value = value.split("|")
        if all(isinstance(x, str) for x in value):
            value = [x.split(":", maxsplit=1) for x in value]
        for parent, dep in value:
            # we have to match the format of the IDs. since the IDs
            # of the words are int if they aren't empty words, we need
            # to convert single int IDs into int instead of tuple
            parent = tuple(map(int, parent.split(".", maxsplit=1)))
            if len(parent) == 1:
                parent = parent[0]
            graph.add_edge(parent, self.id, dep)

    @property
    def misc(self):
        """ Access the miscellaneousness of this word. """
        return self._misc

    @misc.setter
    def misc(self, value):
        """ Set the word's miscellaneousness value. None and '_' are normalized to None. """
        self._misc = value if self._is_null(value) == False else None

    @property
    def start_char(self):
        """ Access the start character index for this token in the raw text. """
        return self._start_char

    @start_char.setter
    def start_char(self, value):
        self._start_char = value

    @property
    def end_char(self):
        """ Access the end character index for this token in the raw text. """
        return self._end_char

    @end_char.setter
    def end_char(self, value):
        self._end_char = value

    @property
    def parent(self):
        """ Access the parent token of this word. In the case of a multi-word token, a token can be the parent of
        multiple words. Note that this should return a reference to the parent token object.
        """
        return self._parent

    @parent.setter
    def parent(self, value):
        """ Set this word's parent token. In the case of a multi-word token, a token can be the parent of
        multiple words. Note that value here should be a reference to the parent token object.
        """
        self._parent = value

    @property
    def pos(self):
        """ Access the universal part-of-speech of this word. Example: 'NOUN'. Alias of `upos`. """
        return self._upos

    @pos.setter
    def pos(self, value):
        """ Set the word's universal part-of-speech value. Example: 'NOUN'. Alias of the `upos` setter. """
        self._upos = value if self._is_null(value) == False else None

    @property
    def coref_chains(self):
        """
        coref_chains points to a list of CorefChain namedtuple, which has a list of mentions and a representative mention.

        Useful for disambiguating words such as "him" (in languages where coref is available)

        Theoretically it is possible for multiple corefs to occur at the same word. For example,
        "Chris Manning's NLP Group"
        could have "Chris Manning" and "Chris Manning's NLP Group" as overlapping entities
        """
        return self._coref_chains

    @coref_chains.setter
    def coref_chains(self, chain):
        """ Set the backref for the coref chains """
        self._coref_chains = chain

    @property
    def sent(self):
        """ Access the pointer to the sentence that this word belongs to. """
        return self._sent

    @sent.setter
    def sent(self, value):
        """ Set the pointer to the sentence that this word belongs to. """
        self._sent = value

    def __repr__(self):
        # JSON rendering of to_dict(), mainly useful for debugging
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def __format__(self, spec):
        # 'C' -> CoNLL-U line; 'P' -> compact one-line summary
        if spec == 'C':
            return self.to_conll_text()
        elif spec == 'P':
            return self.pretty_print()
        else:
            return str(self)

    def to_conll_text(self):
        """
        Turn a word into a conll representation (10 column tab separated)
        """
        token_dict = self.to_dict()
        # '.' connects empty-word ids such as 1.1
        return dict_to_conll_text(token_dict, '.')

    def to_dict(self, fields=[ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, MEXP, COREF_CHAINS]):
        """ Dumps the word into a dictionary.

        Only fields with a non-None value are included.
        """
        word_dict = {}
        for field in fields:
            if getattr(self, field) is not None:
                word_dict[field] = getattr(self, field)
        return word_dict

    def pretty_print(self):
        """ Print the word in one line. """
        features = [ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL]
        feature_str = ";".join(["{}={}".format(k, getattr(self, k)) for k in features if getattr(self, k) is not None])
        return f"<{self.__class__.__name__} {feature_str}>"

    def _is_null(self, value):
        # '_' is the CoNLL-U placeholder for a missing value
        return (value is None) or (value == '_')
1593
+
1594
+
1595
class Span(StanzaObject):
    """ A span class that stores attributes of a textual span. A span can be typed.
    A range of objects (e.g., entity mentions) can be represented as spans.
    """

    def __init__(self, span_entry=None, tokens=None, type=None, doc=None, sent=None):
        """ Construct a span given a span entry or a list of tokens. A valid reference to a doc
        must be provided to construct a span (otherwise the text of the span cannot be initialized).
        """
        assert span_entry is not None or (tokens is not None and type is not None), \
            'Either a span_entry or a token list needs to be provided to construct a span.'
        assert doc is not None, 'A parent doc must be provided to construct a span.'
        # initialize all scalar attributes to None before filling them in
        self._text, self._type, self._start_char, self._end_char = [None] * 4
        self._tokens = []
        self._words = []
        self._doc = doc
        self._sent = sent

        if span_entry is not None:
            self.init_from_entry(span_entry)

        if tokens is not None:
            self.init_from_tokens(tokens, type)

    def init_from_entry(self, span_entry):
        # copy the basic fields from a dict-style span entry
        self.text = span_entry.get(TEXT, None)
        self.type = span_entry.get(TYPE, None)
        self.start_char = span_entry.get(START_CHAR, None)
        self.end_char = span_entry.get(END_CHAR, None)

    def init_from_tokens(self, tokens, type):
        assert isinstance(tokens, list), 'Tokens must be provided as a list to construct a span.'
        assert len(tokens) > 0, "Tokens of a span cannot be an empty list."
        self.tokens = tokens
        self.type = type
        # load start and end char offsets from tokens
        self.start_char = self.tokens[0].start_char
        self.end_char = self.tokens[-1].end_char
        if self.doc is not None and self.doc.text is not None:
            # preferred path: slice the span text straight out of the document text
            self.text = self.doc.text[self.start_char:self.end_char]
        elif tokens[0].sent is tokens[-1].sent:
            # no document text available: recover the span text from the
            # sentence text using offsets relative to the sentence start
            sentence = tokens[0].sent
            text_start = tokens[0].start_char - sentence.tokens[0].start_char
            text_end = tokens[-1].end_char - sentence.tokens[0].start_char
            self.text = sentence.text[text_start:text_end]
        else:
            # TODO: do any spans ever cross sentences?
            raise RuntimeError("Document text does not exist, and the span tested crosses two sentences, so it is impossible to extract the entity text!")
        # collect the words of the span following tokens
        self.words = [w for t in tokens for w in t.words]
        # set the sentence back-pointer to point to the sentence of the first token
        self.sent = tokens[0].sent

    @property
    def doc(self):
        """ Access the parent doc of this span. """
        return self._doc

    @doc.setter
    def doc(self, value):
        """ Set the parent doc of this span. """
        self._doc = value

    @property
    def text(self):
        """ Access the text of this span. Example: 'Stanford University'"""
        return self._text

    @text.setter
    def text(self, value):
        """ Set the span's text value. Example: 'Stanford University'"""
        self._text = value

    @property
    def tokens(self):
        """ Access reference to a list of tokens that correspond to this span. """
        return self._tokens

    @tokens.setter
    def tokens(self, value):
        """ Set the span's list of tokens. """
        self._tokens = value

    @property
    def words(self):
        """ Access reference to a list of words that correspond to this span. """
        return self._words

    @words.setter
    def words(self, value):
        """ Set the span's list of words. """
        self._words = value

    @property
    def type(self):
        """ Access the type of this span. Example: 'PERSON'"""
        return self._type

    @type.setter
    def type(self, value):
        """ Set the type of this span. """
        self._type = value

    @property
    def start_char(self):
        """ Access the start character offset of this span. """
        return self._start_char

    @start_char.setter
    def start_char(self, value):
        """ Set the start character offset of this span. """
        self._start_char = value

    @property
    def end_char(self):
        """ Access the end character offset of this span. """
        return self._end_char

    @end_char.setter
    def end_char(self, value):
        """ Set the end character offset of this span. """
        self._end_char = value

    @property
    def sent(self):
        """ Access the pointer to the sentence that this span belongs to. """
        return self._sent

    @sent.setter
    def sent(self, value):
        """ Set the pointer to the sentence that this span belongs to. """
        self._sent = value

    def to_dict(self):
        """ Dumps the span into a dictionary, keeping only text/type/char offsets
        (tokens, words and back-pointers are intentionally not serialized). """
        attrs = ['text', 'type', 'start_char', 'end_char']
        span_dict = dict([(attr_name, getattr(self, attr_name)) for attr_name in attrs])
        return span_dict

    def __repr__(self):
        return json.dumps(self.to_dict(), indent=2, ensure_ascii=False, cls=DocJSONEncoder)

    def pretty_print(self):
        """ Print the span in one line. """
        span_dict = self.to_dict()
        feature_str = ";".join(["{}={}".format(k,v) for k,v in span_dict.items()])
        return f"<{self.__class__.__name__} {feature_str}>"
stanza/stanza/models/common/dropout.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class WordDropout(nn.Module):
    """ Word-level dropout for embedded inputs (e.g., anything fed to an LSTM).

    Randomly zeroes out whole vectors along the last (hidden) dimension; if a
    replacement vector is supplied, dropped positions receive it instead of zero.
    """
    def __init__(self, dropprob):
        super().__init__()
        self.dropprob = dropprob

    def forward(self, x, replacement=None):
        """ Drop whole units of x with probability dropprob while training. """
        if not self.training or self.dropprob == 0:
            return x

        # one bernoulli draw per unit, broadcast over the hidden dimension
        mask_shape = list(x.size())
        mask_shape[-1] = 1
        drop_mask = torch.rand(*mask_shape, device=x.device) < self.dropprob

        dropped = x.masked_fill(drop_mask, 0)
        if replacement is not None:
            dropped = dropped + drop_mask.float() * replacement
        return dropped

    def extra_repr(self):
        return 'p={}'.format(self.dropprob)
29
+
30
class LockedDropout(nn.Module):
    """
    Variational ("locked") dropout: the same dropout mask is reused across the
    time dimension of a sequence.  Modified from the LockedDropout implementation
    in the flair library (https://github.com/zalandoresearch/flair).
    """
    def __init__(self, dropprob, batch_first=True):
        super().__init__()
        self.dropprob = dropprob
        self.batch_first = batch_first

    def forward(self, x):
        if not self.training or self.dropprob == 0:
            return x

        keep_prob = 1 - self.dropprob
        # sample one mask per sequence; the time axis (dim 1 if batch_first,
        # dim 0 otherwise) is broadcast so every timestep shares the mask
        if self.batch_first:
            noise = x.new_empty(x.size(0), 1, x.size(2), requires_grad=False).bernoulli_(keep_prob)
        else:
            noise = x.new_empty(1, x.size(1), x.size(2), requires_grad=False).bernoulli_(keep_prob)

        # inverted dropout scaling keeps the expected activation unchanged
        return noise.div(keep_prob).expand_as(x) * x

    def extra_repr(self):
        return 'p={}'.format(self.dropprob)
54
+
55
class SequenceUnitDropout(nn.Module):
    """ Unit dropout over sequences of indices (word ids, char ids, ...).

    While training, each index is independently replaced with replacement_id
    (typically the id of <UNK>) with probability dropprob.
    """
    def __init__(self, dropprob, replacement_id):
        super().__init__()
        self.dropprob = dropprob
        self.replacement_id = replacement_id

    def forward(self, x):
        """ :param: x must be a LongTensor of unit indices. """
        if not self.training or self.dropprob == 0:
            return x
        drop_mask = torch.rand(x.size(), device=x.device) < self.dropprob
        return x.masked_fill(drop_mask, self.replacement_id)

    def extra_repr(self):
        return 'p={}, replacement_id={}'.format(self.dropprob, self.replacement_id)
75
+
stanza/stanza/models/common/exceptions.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A couple more specific FileNotFoundError exceptions
3
+
4
+ The idea being, the caller can catch it and report a more useful error resolution
5
+ """
6
+
7
+ import errno
8
+
9
class ForwardCharlmNotFoundError(FileNotFoundError):
    """Raised when a forward character LM file is missing.

    Carries errno.ENOENT plus the offending filename so callers can report
    a specific resolution for forward charlms.
    """
    def __init__(self, msg, filename):
        super().__init__(errno.ENOENT, msg, filename)
12
+
13
class BackwardCharlmNotFoundError(FileNotFoundError):
    """Raised when a backward character LM file is missing.

    Carries errno.ENOENT plus the offending filename so callers can report
    a specific resolution for backward charlms.
    """
    def __init__(self, msg, filename):
        super().__init__(errno.ENOENT, msg, filename)
stanza/stanza/models/common/foundation_cache.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Keeps BERT, charlm, word embedings in a cache to save memory
3
+ """
4
+
5
+ from collections import namedtuple
6
+ from copy import deepcopy
7
+ import logging
8
+ import threading
9
+
10
+ from stanza.models.common import bert_embedding
11
+ from stanza.models.common.char_model import CharacterLanguageModel
12
+ from stanza.models.common.pretrain import Pretrain
13
+
14
+ logger = logging.getLogger('stanza')
15
+
16
+ BertRecord = namedtuple('BertRecord', ['model', 'tokenizer', 'peft_ids'])
17
+
18
class FoundationCache:
    """Caches transformers (with peft bookkeeping), charlms, and pretrains.

    Each resource is loaded at most once; a single lock guards all three caches.
    Passing another FoundationCache as `other` shares its caches and lock.
    """
    def __init__(self, other=None, local_files_only=False):
        if other is None:
            self.bert = {}
            self.charlms = {}
            self.pretrains = {}
            # future proof the module by using a lock for the glorious day
            # when the GIL is finally gone
            self.lock = threading.Lock()
        else:
            # share the underlying caches (and the lock) with the other cache
            self.bert = other.bert
            self.charlms = other.charlms
            self.pretrains = other.pretrains
            self.lock = other.lock
        self.local_files_only = local_files_only

    def load_bert(self, transformer_name, local_files_only=None):
        """Load (or reuse) a transformer; returns (model, tokenizer)."""
        model, tokenizer, _ = self.load_bert_with_peft(transformer_name, None, local_files_only=local_files_only)
        return model, tokenizer

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        """
        Load a transformer only once; returns (model, tokenizer, unique_peft_name).

        Uses a lock for thread safety
        """
        if transformer_name is None:
            return None, None, None
        with self.lock:
            if transformer_name not in self.bert:
                if local_files_only is None:
                    local_files_only = self.local_files_only
                model, tokenizer = bert_embedding.load_bert(transformer_name, local_files_only=local_files_only)
                self.bert[transformer_name] = BertRecord(model, tokenizer, {})
            else:
                logger.debug("Reusing bert %s", transformer_name)

            record = self.bert[transformer_name]
            if not peft_name:
                return record.model, record.tokenizer, None
            # hand out a unique adapter name per request so repeated loads
            # of the same peft_name don't collide on the shared model
            if peft_name not in record.peft_ids:
                record.peft_ids[peft_name] = 0
            else:
                record.peft_ids[peft_name] += 1
            return record.model, record.tokenizer, "%s_%d" % (peft_name, record.peft_ids[peft_name])

    def load_charlm(self, filename):
        """Load (or reuse) a character language model from filename."""
        if not filename:
            return None
        with self.lock:
            if filename not in self.charlms:
                logger.debug("Loading charlm from %s", filename)
                self.charlms[filename] = CharacterLanguageModel.load(filename, finetune=False)
            else:
                logger.debug("Reusing charlm from %s", filename)
            return self.charlms[filename]

    def load_pretrain(self, filename):
        """
        Load a pretrained word embedding only once

        Uses a lock for thread safety
        """
        if filename is None:
            return None
        with self.lock:
            if filename not in self.pretrains:
                logger.debug("Loading pretrain %s", filename)
                self.pretrains[filename] = Pretrain(filename)
            else:
                logger.debug("Reusing pretrain %s", filename)
            return self.pretrains[filename]
94
+
95
class NoTransformerFoundationCache(FoundationCache):
    """
    A FoundationCache which deliberately does NOT cache transformers.

    Charlms and pretrains still go through the shared cache, but every
    transformer request falls back to a fresh load.  Useful when loading a
    downstream model such as POS with a finetuned transformer: reusing the
    cached copy would leak those finetuned weights into other models.
    """
    def load_bert(self, transformer_name, local_files_only=None):
        if local_files_only is None:
            local_files_only = self.local_files_only
        # delegate to the uncached module-level loader
        return load_bert(transformer_name, local_files_only=local_files_only)

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        if local_files_only is None:
            local_files_only = self.local_files_only
        return load_bert_with_peft(transformer_name, peft_name, local_files_only=local_files_only)
109
+
110
def load_bert(model_name, foundation_cache=None, local_files_only=None):
    """
    Load a bert, going through the foundation cache when one is given,
    otherwise loading directly from bert_embedding.
    """
    if foundation_cache is not None:
        return foundation_cache.load_bert(model_name, local_files_only=local_files_only)
    return bert_embedding.load_bert(model_name, local_files_only=local_files_only)
118
+
119
def load_bert_with_peft(model_name, peft_name, foundation_cache=None, local_files_only=None):
    """
    Load a bert plus a peft adapter name, using the foundation cache when available.

    Without a cache the requested peft_name is returned unchanged.
    """
    if foundation_cache is not None:
        return foundation_cache.load_bert_with_peft(model_name, peft_name, local_files_only=local_files_only)
    model, tokenizer = bert_embedding.load_bert(model_name, local_files_only=local_files_only)
    return model, tokenizer, peft_name
124
+
125
def load_charlm(charlm_file, foundation_cache=None, finetune=False):
    """
    Load a character language model, optionally through the foundation cache.

    A model which will be finetuned is always loaded fresh: its weights will
    diverge, so sharing it through the cache would corrupt other users.
    """
    if not charlm_file:
        return None

    if finetune:
        # can't use the cache in the case of a model which will be finetuned
        # and the numbers will be different for other users of the model
        return CharacterLanguageModel.load(charlm_file, finetune=True)

    if foundation_cache is not None:
        return foundation_cache.load_charlm(charlm_file)

    logger.debug("Loading charlm from %s", charlm_file)
    return CharacterLanguageModel.load(charlm_file, finetune=False)
139
+
140
def load_pretrain(filename, foundation_cache=None):
    """
    Load a pretrained word embedding, optionally through the foundation cache.
    """
    if not filename:
        return None

    if foundation_cache is not None:
        return foundation_cache.load_pretrain(filename)

    logger.debug("Loading pretrain from %s", filename)
    return Pretrain(filename)
stanza/stanza/models/common/hlstm.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
5
+
6
+ from stanza.models.common.packed_lstm import PackedLSTM
7
+
8
class HLSTMCell(nn.Module):
    """
    A Highway LSTM Cell as proposed in Zhang et al. (2018) Highway Long Short-Term Memory RNNs for
    Distant Speech Recognition.

    forward() takes input of shape (N, input_size), an optional c_l_minus_one of
    shape (N, hidden_size) (the cell state of the same timestep in the layer
    below), and an optional hx = (h, c) pair from the previous timestep; it
    returns the new (h, c), each of shape (N, hidden_size).

    Note: this used to subclass nn.modules.rnn.RNNCellBase with a no-argument
    super().__init__() and used its check_forward_input/check_forward_hidden
    helpers.  Modern PyTorch requires constructor arguments for RNNCellBase and
    has removed those helpers, which broke this class, so it now subclasses
    nn.Module and validates shapes inline.  torch.sigmoid/torch.tanh replace
    the deprecated F.sigmoid/F.tanh (identical results).
    """
    def __init__(self, input_size, hidden_size, bias=True):
        super(HLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # LSTM parameters
        self.Wi = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wf = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wo = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)
        self.Wg = nn.Linear(input_size + hidden_size, hidden_size, bias=bias)

        # highway gate parameters
        self.gate = nn.Linear(input_size + 2 * hidden_size, hidden_size, bias=bias)

    def _check_hidden(self, input, hidden, name):
        """Validate that a hidden/cell tensor matches the input batch and hidden_size."""
        if input.size(0) != hidden.size(0):
            raise RuntimeError("Input batch size {} doesn't match hidden{} batch size {}".format(
                input.size(0), name, hidden.size(0)))
        if hidden.size(1) != self.hidden_size:
            raise RuntimeError("hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                name, hidden.size(1), self.hidden_size))

    def forward(self, input, c_l_minus_one=None, hx=None):
        if input.size(1) != self.input_size:
            raise RuntimeError("input has inconsistent input_size: got {}, expected {}".format(
                input.size(1), self.input_size))
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            hx = (hx, hx)
        if c_l_minus_one is None:
            c_l_minus_one = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)

        self._check_hidden(input, hx[0], '[0]')
        self._check_hidden(input, hx[1], '[1]')
        self._check_hidden(input, c_l_minus_one, 'c_l_minus_one')

        # vanilla LSTM computation
        rec_input = torch.cat([input, hx[0]], 1)
        i = torch.sigmoid(self.Wi(rec_input))
        f = torch.sigmoid(self.Wf(rec_input))
        o = torch.sigmoid(self.Wo(rec_input))
        g = torch.tanh(self.Wg(rec_input))

        # highway gate mixes in the lower layer's cell state
        gate = torch.sigmoid(self.gate(torch.cat([c_l_minus_one, hx[1], input], 1)))

        c = gate * c_l_minus_one + f * hx[1] + i * g
        h = o * torch.tanh(c)

        return h, c
53
+
54
# Highway LSTM network, does NOT use the HLSTMCell above
class HighwayLSTM(nn.Module):
    """
    A Highway LSTM network, as used in the original Tensorflow version of the Dozat parser. Note that this
    is independent from the HLSTMCell above.

    Each layer is a (possibly bidirectional) PackedLSTM; the layer's output is
    combined with a gated highway connection computed from the layer's input:
        out = lstm(x) + sigmoid(gate(x)) * highway_func(highway(x))
    """
    def __init__(self, input_size, hidden_size,
            num_layers=1, bias=True, batch_first=False,
            dropout=0, bidirectional=False, rec_dropout=0, highway_func=None, pad=False):
        super(HighwayLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.dropout_state = {}
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.highway_func = highway_func
        self.pad = pad

        self.lstm = nn.ModuleList()
        self.highway = nn.ModuleList()
        self.gate = nn.ModuleList()
        self.drop = nn.Dropout(dropout, inplace=True)

        # stack per-layer LSTMs with matching highway/gate linear maps;
        # after the first layer, the input size is the (possibly bidirectional)
        # hidden size of the previous layer
        in_size = input_size
        for l in range(num_layers):
            self.lstm.append(PackedLSTM(in_size, hidden_size, num_layers=1, bias=bias,
                batch_first=batch_first, dropout=0, bidirectional=bidirectional, rec_dropout=rec_dropout))
            self.highway.append(nn.Linear(in_size, hidden_size * self.num_directions))
            self.gate.append(nn.Linear(in_size, hidden_size * self.num_directions))
            # zero biases so the highway path starts neutral
            self.highway[-1].bias.data.zero_()
            self.gate[-1].bias.data.zero_()
            in_size = hidden_size * self.num_directions

    def forward(self, input, seqlens, hx=None):
        # identity is used as the highway activation when none was provided
        highway_func = (lambda x: x) if self.highway_func is None else self.highway_func

        hs = []
        cs = []

        if not isinstance(input, PackedSequence):
            input = pack_padded_sequence(input, seqlens, batch_first=self.batch_first)

        for l in range(self.num_layers):
            if l > 0:
                # inter-layer dropout on the packed data
                input = PackedSequence(self.drop(input.data), input.batch_sizes, input.sorted_indices, input.unsorted_indices)
            # slice out this layer's directions from the stacked initial state, if given
            layer_hx = (hx[0][l * self.num_directions:(l+1)*self.num_directions], hx[1][l * self.num_directions:(l+1)*self.num_directions]) if hx is not None else None
            h, (ht, ct) = self.lstm[l](input, seqlens, layer_hx)

            hs.append(ht)
            cs.append(ct)

            # highway connection: add the gated transform of the layer INPUT to the lstm output
            input = PackedSequence(h.data + torch.sigmoid(self.gate[l](input.data)) * highway_func(self.highway[l](input.data)), input.batch_sizes, input.sorted_indices, input.unsorted_indices)

        if self.pad:
            input = pad_packed_sequence(input, batch_first=self.batch_first)[0]
        # final states are concatenated over layers along dim 0
        return input, (torch.cat(hs, 0), torch.cat(cs, 0))
114
+
115
if __name__ == "__main__":
    # Smoke test for HighwayLSTM.
    T = 10
    bidir = True
    num_dir = 2 if bidir else 1
    rnn = HighwayLSTM(10, 20, num_layers=2, bidirectional=True)
    input = torch.randn(T, 3, 10)
    hx = torch.randn(2 * num_dir, 3, 20)
    cx = torch.randn(2 * num_dir, 3, 20)
    # BUGFIX: forward() is (input, seqlens, hx); the hidden state used to be
    # passed in the seqlens slot, which crashed pack_padded_sequence.
    seqlens = [T, T, T]
    output = rnn(input, seqlens, (hx, cx))
    print(output)
stanza/stanza/models/common/large_margin_loss.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LargeMarginInSoftmax, from the article
3
+
4
+ @inproceedings{kobayashi2019bmvc,
5
+ title={Large Margin In Softmax Cross-Entropy Loss},
6
+ author={Takumi Kobayashi},
7
+ booktitle={Proceedings of the British Machine Vision Conference (BMVC)},
8
+ year={2019}
9
+ }
10
+
11
+ implementation from
12
+
13
+ https://github.com/tk1980/LargeMarginInSoftmax
14
+
15
+ There is no license specifically chosen; they just ask people to cite the paper if the work is useful.
16
+ """
17
+
18
+
19
+ import math
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.init as init
24
+ import torch.nn.functional as F
25
+
26
+
27
class LargeMarginInSoftmaxLoss(nn.CrossEntropyLoss):
    r"""
    Softmax cross-entropy combined with the large-margin inducing regularization of
    T. Kobayashi, "Large-Margin In Softmax Cross-Entropy Loss." In BMVC2019.

    The regularizer pushes the softmax over the NON-target classes towards the
    uniform distribution, which enlarges the margin of the target logit.

    Inherits all nn.CrossEntropyLoss parameters; in addition:
        reg_lambda (float, optional): regularization strength. (default: 0.3)
        deg_logit (bool, optional): underestimate (degrade) the target logit by -1 or not.
            (default: None) If set, realizes the modified-loss variant described in
            Table 4 of the paper.
    """
    def __init__(self, reg_lambda=0.3, deg_logit=None,
                 weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
        super(LargeMarginInSoftmaxLoss, self).__init__(weight=weight, size_average=size_average,
                                                       ignore_index=ignore_index, reduce=reduce, reduction=reduction)
        self.reg_lambda = reg_lambda
        self.deg_logit = deg_logit

    def forward(self, input, target):
        num_samples = input.size(0)
        num_classes = input.size(1)
        # one-hot indicator of the target class for each sample
        target_mask = torch.zeros_like(input, requires_grad=False)
        target_mask[range(num_samples), target] = 1

        if self.deg_logit is not None:
            input = input - self.deg_logit * target_mask

        loss = F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

        # push the target logit far down so the softmax effectively ranges
        # over the non-target classes only
        nontarget = input - 1.e6 * target_mask
        reg = 0.5 * ((F.softmax(nontarget, dim=1) - 1.0 / (num_classes - 1))
                     * F.log_softmax(nontarget, dim=1) * (1.0 - target_mask)).sum(dim=1)
        if self.reduction == 'sum':
            reg = reg.sum()
        elif self.reduction == 'mean':
            reg = reg.mean()
        # 'none' keeps the per-sample regularizer as-is

        return loss + self.reg_lambda * reg
stanza/stanza/models/common/loss.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Different loss functions.
3
+ """
4
+
5
+ import logging
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import stanza.models.common.seq2seq_constant as constant
11
+
12
+ logger = logging.getLogger('stanza')
13
+
14
def SequenceLoss(vocab_size):
    """Build an NLLLoss over the vocabulary which ignores the PAD symbol
    by giving it zero weight."""
    token_weights = torch.ones(vocab_size)
    token_weights[constant.PAD_ID] = 0
    return nn.NLLLoss(token_weights)
19
+
20
def weighted_cross_entropy_loss(labels, log_dampened=False):
    """
    Build a CrossEntropyLoss whose class weights rebalance the classes.

    Each class weight is the inverse of that class's relative frequency in
    `labels`, so every class has the same effective weight.  With
    log_dampened=True, the weights are softened to 1 + log(inverse_freq)
    so the majority class keeps some priority.

    labels: sequence of integer class labels; classes are assumed to be
        0..num_classes-1 with every class present (np.unique order).

    Fixes: removed the dead `all_labels` local (it was computed but never
    used) and replaced the dated `.type('torch.FloatTensor')` with the
    equivalent `.float()`.
    """
    _, counts = np.unique(labels, return_counts=True)
    freqs = counts / float(np.sum(counts))       # relative frequency per class
    weights = np.sum(freqs) / freqs              # inverse frequency (freqs sum to 1)
    if log_dampened:
        weights = 1 + np.log(weights)
    logging.getLogger('stanza').debug("Reweighting cross entropy by {}".format(weights))
    return nn.CrossEntropyLoss(weight=torch.from_numpy(weights).float())
38
+
39
class FocalLoss(nn.Module):
    """
    Multi-category focal loss: each example's cross-entropy is downweighted
    by how confidently the model already predicts the correct class.

    From "Focal Loss for Dense Object Detection"
    https://arxiv.org/abs/1708.02002
    """
    def __init__(self, reduction='mean', gamma=2.0):
        super().__init__()
        if reduction not in ('sum', 'none', 'mean'):
            raise ValueError("Unknown reduction: %s" % reduction)

        self.reduction = reduction
        # per-sample CE; the focal weighting and reduction are applied afterwards
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
        self.gamma = gamma

    def forward(self, inputs, targets):
        """
        Weight the loss using the model's assessment of the correct answer.

        inputs: [N, C]
        targets: [N]
        """
        if inputs.dim() == 2 and targets.dim() == 1:
            if inputs.shape[0] != targets.shape[0]:
                raise ValueError("Expected inputs N,C and targets N, but got {} and {}".format(inputs.shape, targets.shape))
        elif inputs.dim() == 1 and targets.dim() == 0:
            raise NotImplementedError("This would be a reasonable thing to implement, but we haven't done it yet")
        else:
            raise ValueError("Expected inputs N,C and targets N, but got {} and {}".format(inputs.shape, targets.shape))

        per_sample = self.ce_loss(inputs, targets)
        assert per_sample.dim() == 1 and per_sample.shape[0] == inputs.shape[0]

        # exp(-CE) is the probability assigned to the correct class, so this is
        # CE * (1 - p_correct)^gamma
        # https://www.tutorialexample.com/implement-focal-loss-for-multi-label-classification-in-pytorch-pytorch-tutorial/
        weighted = per_sample * ((1 - torch.exp(-per_sample)) ** self.gamma)
        assert weighted.dim() == 1 and weighted.shape[0] == inputs.shape[0]
        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'none':
            return weighted
        raise AssertionError("unknown reduction! how did this happen??")
87
+
88
class MixLoss(nn.Module):
    """
    A mixture of SequenceLoss and CrossEntropyLoss:

        Loss = SequenceLoss + alpha * CELoss
    """
    def __init__(self, vocab_size, alpha):
        super().__init__()
        self.seq_loss = SequenceLoss(vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()
        assert alpha >= 0
        self.alpha = alpha

    def forward(self, seq_inputs, seq_targets, class_inputs, class_targets):
        """Sequence NLL loss plus alpha-weighted classification CE loss."""
        return self.seq_loss(seq_inputs, seq_targets) + self.alpha * self.ce_loss(class_inputs, class_targets)
105
+
106
class MaxEntropySequenceLoss(nn.Module):
    """
    A max entropy loss that encourages the model to have large entropy,
    therefore giving more diverse outputs:

        Loss = NLLLoss + alpha * EntropyLoss
    """
    def __init__(self, vocab_size, alpha):
        super().__init__()
        nll_weight = torch.ones(vocab_size)
        nll_weight[constant.PAD_ID] = 0      # ignore padding
        self.nll = nn.NLLLoss(nll_weight)
        self.alpha = alpha

    def forward(self, inputs, targets):
        """
        inputs: [N, C] log probabilities
        targets: [N]
        """
        assert inputs.size(0) == targets.size(0)
        nll_loss = self.nll(inputs, targets)
        # entropy term: zero out rows whose target is PAD, then compute
        # sum(p * log p) averaged over the minibatch
        pad_mask = targets.eq(constant.PAD_ID).unsqueeze(1).expand_as(inputs)
        log_probs = inputs.clone().masked_fill_(pad_mask, 0.0)
        ent_loss = torch.exp(log_probs).mul(log_probs).sum() / inputs.size(0)
        return nll_loss + self.alpha * ent_loss
134
+
stanza/stanza/models/common/maxout_linear.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A layer which implements maxout from the "Maxout Networks" paper
3
+
4
+ https://arxiv.org/pdf/1302.4389v4.pdf
5
+ Goodfellow, Warde-Farley, Mirza, Courville, Bengio
6
+
7
+ or a simpler explanation here:
8
+
9
+ https://stats.stackexchange.com/questions/129698/what-is-maxout-in-neural-network/298705#298705
10
+
11
+ The implementation here:
12
+ for k layers of maxout, in -> out channels, we make a single linear
13
+ map of size in -> out*k
14
+ then we reshape the end to be (..., k, out)
15
+ and return the max over the k layers
16
+ """
17
+
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
class MaxoutLinear(nn.Module):
    """Maxout layer from "Maxout Networks" (Goodfellow et al.),
    https://arxiv.org/pdf/1302.4389v4.pdf

    Implemented as a single linear map of size in -> out * k whose output is
    reshaped to (..., k, out); the max over the k pieces is returned.
    """
    def __init__(self, in_channels, out_channels, maxout_k):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.maxout_k = maxout_k

        # one oversized linear holds all k pieces; keeping them in a single
        # matmul is simpler and easier for pytorch to parallelize
        self.linear = nn.Linear(in_channels, out_channels * maxout_k)

    def forward(self, inputs):
        """Apply the k linear pieces at once and take the elementwise max."""
        pieces = self.linear(inputs)
        pieces = pieces.view(*pieces.shape[:-1], self.maxout_k, self.out_channels)
        return torch.max(pieces, dim=-2)[0]
42
+
stanza/stanza/models/common/packed_lstm.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, PackedSequence
5
+
6
class PackedLSTM(nn.Module):
    """LSTM wrapper that accepts either padded or already-packed input.

    With rec_dropout == 0 it delegates to the fast native nn.LSTM; otherwise it
    uses LSTMwRecDropout, which supports recurrent dropout.  With pad=True the
    packed output is padded back before being returned.
    """
    def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):
        super().__init__()

        self.batch_first = batch_first
        self.pad = pad
        if rec_dropout == 0:
            # use the fast, native LSTM implementation
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)
        else:
            self.lstm = LSTMwRecDropout(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, rec_dropout=rec_dropout)

    def forward(self, input, lengths, hx=None):
        """Run the LSTM; lengths are only used when input still needs packing."""
        if not isinstance(input, PackedSequence):
            input = pack_padded_sequence(input, lengths, batch_first=self.batch_first)

        output, hidden = self.lstm(input, hx)
        if self.pad:
            output = pad_packed_sequence(output, batch_first=self.batch_first)[0]
        return output, hidden
26
+
27
class LSTMwRecDropout(nn.Module):
    """ An LSTM implementation that supports recurrent dropout.

    Recurrent dropout samples one mask per sequence and applies it to the
    hidden state at every timestep of a layer/direction (variational dropout
    on the recurrent connection).  Input and output are PackedSequences.
    """
    def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, pad=False, rec_dropout=0):
        super().__init__()
        self.batch_first = batch_first
        self.pad = pad
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.dropout = dropout
        self.drop = nn.Dropout(dropout, inplace=True)
        self.rec_drop = nn.Dropout(rec_dropout, inplace=True)

        self.num_directions = 2 if bidirectional else 1

        # one LSTMCell per (layer, direction); layers past the first consume
        # the concatenated output of both directions of the previous layer
        self.cells = nn.ModuleList()
        for l in range(num_layers):
            in_size = input_size if l == 0 else self.num_directions * hidden_size
            for d in range(self.num_directions):
                self.cells.append(nn.LSTMCell(in_size, hidden_size, bias=bias))

    def forward(self, input, hx=None):
        def rnn_loop(x, batch_sizes, cell, inits, reverse=False):
            # RNN loop for one layer in one direction with recurrent dropout
            # Assumes input is PackedSequence, returns PackedSequence as well
            batch_size = batch_sizes[0].item()
            # per-sequence state slots; rows drop out of a packed batch as
            # shorter sequences end, so states are kept as 1-row pieces
            states = [list(init.split([1] * batch_size)) for init in inits]
            # one recurrent dropout mask per sequence, reused for all timesteps
            h_drop_mask = x.new_ones(batch_size, self.hidden_size)
            h_drop_mask = self.rec_drop(h_drop_mask)
            resh = []

            if not reverse:
                st = 0
                for bs in batch_sizes:
                    s1 = cell(x[st:st+bs], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))
                    resh.append(s1[0])
                    for j in range(bs):
                        states[0][j] = s1[0][j].unsqueeze(0)
                        states[1][j] = s1[1][j].unsqueeze(0)
                    st += bs
            else:
                # walk the packed data back to front so the reverse direction
                # starts from each sequence's last real timestep
                en = x.size(0)
                for i in range(batch_sizes.size(0)-1, -1, -1):
                    bs = batch_sizes[i]
                    s1 = cell(x[en-bs:en], (torch.cat(states[0][:bs], 0) * h_drop_mask[:bs], torch.cat(states[1][:bs], 0)))
                    resh.append(s1[0])
                    for j in range(bs):
                        states[0][j] = s1[0][j].unsqueeze(0)
                        states[1][j] = s1[1][j].unsqueeze(0)
                    en -= bs
                resh = list(reversed(resh))

            return torch.cat(resh, 0), tuple(torch.cat(s, 0) for s in states)

        all_states = [[], []]
        inputdata, batch_sizes = input.data, input.batch_sizes
        for l in range(self.num_layers):
            new_input = []

            # inter-layer (non-recurrent) dropout
            if self.dropout > 0 and l > 0:
                inputdata = self.drop(inputdata)
            for d in range(self.num_directions):
                idx = l * self.num_directions + d
                cell = self.cells[idx]
                out, states = rnn_loop(inputdata, batch_sizes, cell, (hx[i][idx] for i in range(2)) if hx is not None else (input.data.new_zeros(input.batch_sizes[0].item(), self.hidden_size, requires_grad=False) for _ in range(2)), reverse=(d == 1))

                new_input.append(out)
                all_states[0].append(states[0].unsqueeze(0))
                all_states[1].append(states[1].unsqueeze(0))

            if self.num_directions > 1:
                # concatenate both directions
                inputdata = torch.cat(new_input, 1)
            else:
                inputdata = new_input[0]

            # NOTE(review): this PackedSequence is rebuilt without
            # sorted_indices/unsorted_indices — appears to assume the input was
            # packed pre-sorted; confirm against callers
            input = PackedSequence(inputdata, batch_sizes)

        return input, tuple(torch.cat(x, 0) for x in all_states)
stanza/stanza/models/common/peft_config.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Set a few common flags for peft uage
3
+ """
4
+
5
+
6
+ TRANSFORMER_LORA_RANK = {}
7
+ DEFAULT_LORA_RANK = 64
8
+
9
+ TRANSFORMER_LORA_ALPHA = {}
10
+ DEFAULT_LORA_ALPHA = 128
11
+
12
+ TRANSFORMER_LORA_DROPOUT = {}
13
+ DEFAULT_LORA_DROPOUT = 0.1
14
+
15
+ TRANSFORMER_LORA_TARGETS = {}
16
+ DEFAULT_LORA_TARGETS = "query,value,output.dense,intermediate.dense"
17
+
18
+ TRANSFORMER_LORA_SAVE = {}
19
+ DEFAULT_LORA_SAVE = ""
20
+
21
+ def add_peft_args(parser):
22
+ """
23
+ Add common default flags to an argparse
24
+ """
25
+ parser.add_argument('--lora_rank', type=int, default=None, help="Rank of a LoRA approximation. Default will be %d or a model-specific parameter" % DEFAULT_LORA_RANK)
26
+ parser.add_argument('--lora_alpha', type=int, default=None, help="Alpha of a LoRA approximation. Default will be %d or a model-specific parameter" % DEFAULT_LORA_ALPHA)
27
+ parser.add_argument('--lora_dropout', type=float, default=None, help="Dropout for the LoRA approximation. Default will be %s or a model-specific parameter" % DEFAULT_LORA_DROPOUT)
28
+ parser.add_argument('--lora_target_modules', type=str, default=None, help="Comma separated list of LoRA targets. Default will be '%s' or a model-specific parameter" % DEFAULT_LORA_TARGETS)
29
+ parser.add_argument('--lora_modules_to_save', type=str, default=None, help="Comma separated list of modules to save (eg, fully tune) when using LoRA. Default will be '%s' or a model-specific parameter" % DEFAULT_LORA_SAVE)
30
+
31
+ parser.add_argument('--use_peft', default=False, action='store_true', help="Finetune Bert using peft")
32
+
33
def pop_peft_args(args):
    """
    Pop all of the peft-related arguments from a given dict

    Useful for making sure a model loaded from disk is recreated with
    the right shapes, for example
    """
    peft_keys = ('lora_rank', 'lora_alpha', 'lora_dropout',
                 'lora_target_modules', 'lora_modules_to_save',
                 'use_peft')
    for key in peft_keys:
        # missing keys are silently ignored
        args.pop(key, None)
47
+
48
+
49
def resolve_peft_args(args, logger, check_bert_finetune=True):
    """
    Fill in any unset LoRA arguments from the per-transformer tables.

    Comma separated target / save strings are turned into lists.  When
    check_bert_finetune is set, --use_peft implies --bert_finetune.
    """
    # models without a transformer have nothing to resolve
    if not hasattr(args, 'bert_model'):
        return

    if args.lora_rank is None:
        args.lora_rank = TRANSFORMER_LORA_RANK.get(args.bert_model, DEFAULT_LORA_RANK)

    if args.lora_alpha is None:
        args.lora_alpha = TRANSFORMER_LORA_ALPHA.get(args.bert_model, DEFAULT_LORA_ALPHA)

    if args.lora_dropout is None:
        args.lora_dropout = TRANSFORMER_LORA_DROPOUT.get(args.bert_model, DEFAULT_LORA_DROPOUT)

    if args.lora_target_modules is None:
        args.lora_target_modules = TRANSFORMER_LORA_TARGETS.get(args.bert_model, DEFAULT_LORA_TARGETS)
    # a blank string means "no targets", otherwise split on commas
    targets = args.lora_target_modules
    args.lora_target_modules = targets.split(",") if targets.strip() else []

    if args.lora_modules_to_save is None:
        args.lora_modules_to_save = TRANSFORMER_LORA_SAVE.get(args.bert_model, DEFAULT_LORA_SAVE)
    to_save = args.lora_modules_to_save
    args.lora_modules_to_save = to_save.split(",") if to_save.strip() else []

    if check_bert_finetune and hasattr(args, 'bert_finetune'):
        if args.use_peft and not args.bert_finetune:
            logger.info("--use_peft set. setting --bert_finetune as well")
            args.bert_finetune = True
80
+
81
def build_peft_config(args, logger):
    """
    Build a LoraConfig from the lora_* entries of an args dict.
    """
    # Hide import so that the peft dependency is optional
    from peft import LoraConfig
    logger.debug("Creating lora adapter with rank %d and alpha %d", args['lora_rank'], args['lora_alpha'])
    return LoraConfig(inference_mode=False,
                      r=args['lora_rank'],
                      target_modules=args['lora_target_modules'],
                      lora_alpha=args['lora_alpha'],
                      lora_dropout=args['lora_dropout'],
                      modules_to_save=args['lora_modules_to_save'],
                      bias="none")
93
+
94
def build_peft_wrapper(bert_model, args, logger, adapter_name="default"):
    """
    Wrap bert_model with a freshly built LoRA adapter and activate it.
    """
    # Hide import so that the peft dependency is optional
    from peft import get_peft_model
    config = build_peft_config(args, logger)

    wrapped = get_peft_model(bert_model, config, adapter_name=adapter_name)
    # apparently get_peft_model doesn't actually mark that
    # peft configs are loaded, making it impossible to turn off (or on)
    # the peft adapter later
    bert_model._hf_peft_config_loaded = True
    wrapped._hf_peft_config_loaded = True
    wrapped.set_adapter(adapter_name)
    return wrapped
107
+
108
def load_peft_wrapper(bert_model, lora_params, args, logger, adapter_name):
    """
    Attach saved LoRA weights to bert_model under adapter_name and activate it.
    """
    config = build_peft_config(args, logger)

    try:
        bert_model.load_adapter(adapter_name=adapter_name, peft_config=config, adapter_state_dict=lora_params)
    except (ValueError, TypeError):
        # this can happen if the adapter already exists...
        # in that case, try setting the adapter weights?
        from peft import set_peft_model_state_dict
        set_peft_model_state_dict(bert_model, lora_params, adapter_name=adapter_name)
    bert_model.set_adapter(adapter_name)
    return bert_model
stanza/stanza/models/common/seq2seq_constant.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Constants for seq2seq models.
3
+ """
4
+
5
+ PAD = '<PAD>'
6
+ PAD_ID = 0
7
+ UNK = '<UNK>'
8
+ UNK_ID = 1
9
+ SOS = '<SOS>'
10
+ SOS_ID = 2
11
+ EOS = '<EOS>'
12
+ EOS_ID = 3
13
+
14
+ VOCAB_PREFIX = [PAD, UNK, SOS, EOS]
15
+
16
+ EMB_INIT_RANGE = 1.0
17
+ INFINITY_NUMBER = 1e12
stanza/stanza/models/common/seq2seq_model.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The full encoder-decoder model, built on top of the base seq2seq modules.
3
+ """
4
+
5
+ import logging
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+
11
+ import stanza.models.common.seq2seq_constant as constant
12
+ from stanza.models.common import utils
13
+ from stanza.models.common.seq2seq_modules import LSTMAttention
14
+ from stanza.models.common.beam import Beam
15
+ from stanza.models.common.seq2seq_constant import UNK_ID
16
+
17
+ logger = logging.getLogger('stanza')
18
+
19
class Seq2SeqModel(nn.Module):
    """
    A complete encoder-decoder model, with optional attention.

    A parent class which makes use of the contextual_embedding (such as a charlm)
    can make use of unsaved_modules when saving.
    """
    def __init__(self, args, emb_matrix=None, contextual_embedding=None):
        super().__init__()

        # names of submodules which should be skipped when saving
        self.unsaved_modules = []

        self.vocab_size = args['vocab_size']
        self.emb_dim = args['emb_dim']
        self.hidden_dim = args['hidden_dim']
        self.nlayers = args['num_layers'] # encoder layers, decoder layers = 1
        self.emb_dropout = args.get('emb_dropout', 0.0)
        self.dropout = args['dropout']
        self.pad_token = constant.PAD_ID
        self.max_dec_len = args['max_dec_len']
        # finetune only the top N embedding rows; 1e10 effectively means "all"
        self.top = args.get('top', 1e10)
        self.args = args
        self.emb_matrix = emb_matrix
        self.add_unsaved_module("contextual_embedding", contextual_embedding)

        logger.debug("Building an attentional Seq2Seq model...")
        logger.debug("Using a Bi-LSTM encoder")
        self.num_directions = 2
        # halve the per-direction size so both directions concatenated
        # match the decoder's hidden size
        self.enc_hidden_dim = self.hidden_dim // 2
        self.dec_hidden_dim = self.hidden_dim

        self.use_pos = args.get('pos', False)
        self.pos_dim = args.get('pos_dim', 0)
        self.pos_vocab_size = args.get('pos_vocab_size', 0)
        self.pos_dropout = args.get('pos_dropout', 0)
        self.edit = args.get('edit', False)
        self.num_edit = args.get('num_edit', 0)
        self.copy = args.get('copy', False)

        self.emb_drop = nn.Dropout(self.emb_dropout)
        self.drop = nn.Dropout(self.dropout)
        # third positional argument of nn.Embedding is padding_idx
        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
        self.input_dim = self.emb_dim
        if self.contextual_embedding is not None:
            self.input_dim += self.contextual_embedding.hidden_dim()
        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
            bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)
        self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \
            batch_first=True, attn_type=self.args['attn_type'])
        self.dec2vocab = nn.Linear(self.dec_hidden_dim, self.vocab_size)
        if self.use_pos and self.pos_dim > 0:
            logger.debug("Using POS in encoder")
            self.pos_embedding = nn.Embedding(self.pos_vocab_size, self.pos_dim, self.pad_token)
            self.pos_drop = nn.Dropout(self.pos_dropout)
        if self.edit:
            # small MLP classifier over the final encoder state
            edit_hidden = self.hidden_dim//2
            self.edit_clf = nn.Sequential(
                nn.Linear(self.hidden_dim, edit_hidden),
                nn.ReLU(),
                nn.Linear(edit_hidden, self.num_edit))

        if self.copy:
            # scalar gate: copy from the source vs generate from the vocab
            self.copy_gate = nn.Linear(self.dec_hidden_dim, 1)

        # buffer so the SOS id moves with the module across devices
        SOS_tensor = torch.LongTensor([constant.SOS_ID])
        self.register_buffer('SOS_tensor', SOS_tensor)

        self.init_weights()

    def add_unsaved_module(self, name, module):
        """Attach a submodule while marking it to be skipped when saving."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def init_weights(self):
        """Initialize embeddings (optionally from a pretrained matrix) and set up finetuning."""
        # initialize embeddings
        init_range = constant.EMB_INIT_RANGE
        if self.emb_matrix is not None:
            if isinstance(self.emb_matrix, np.ndarray):
                self.emb_matrix = torch.from_numpy(self.emb_matrix)
            assert self.emb_matrix.size() == (self.vocab_size, self.emb_dim), \
                "Input embedding matrix must match size: {} x {}".format(self.vocab_size, self.emb_dim)
            self.embedding.weight.data.copy_(self.emb_matrix)
        else:
            self.embedding.weight.data.uniform_(-init_range, init_range)
        # decide finetuning
        if self.top <= 0:
            logger.debug("Do not finetune embedding layer.")
            self.embedding.weight.requires_grad = False
        elif self.top < self.vocab_size:
            logger.debug("Finetune top {} embeddings.".format(self.top))
            # the hook zeroes gradients for every row past the top
            self.embedding.weight.register_hook(lambda x: utils.keep_partial_grad(x, self.top))
        else:
            logger.debug("Finetune all embeddings.")
        # initialize pos embeddings
        if self.use_pos:
            self.pos_embedding.weight.data.uniform_(-init_range, init_range)

    def zero_state(self, inputs):
        """Return zeroed (h0, c0) encoder states sized for this batch."""
        batch_size = inputs.size(0)
        device = self.SOS_tensor.device
        h0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
        c0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
        return h0, c0

    def encode(self, enc_inputs, lens):
        """ Encode source sequence. """
        h0, c0 = self.zero_state(enc_inputs)

        packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
        packed_h_in, (hn, cn) = self.encoder(packed_inputs, (h0, c0))
        h_in, _ = nn.utils.rnn.pad_packed_sequence(packed_h_in, batch_first=True)
        # concatenate the final states of both directions for the decoder init
        hn = torch.cat((hn[-1], hn[-2]), 1)
        cn = torch.cat((cn[-1], cn[-2]), 1)
        return h_in, (hn, cn)

    def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None, never_decode_unk=False):
        """ Decode a step, based on context encoding and source context states."""
        dec_hidden = (hn, cn)
        # the decoder also returns the attention log-weights when copying
        decoder_output = self.decoder(dec_inputs, dec_hidden, ctx, ctx_mask, return_logattn=self.copy)
        if self.copy:
            h_out, dec_hidden, log_attn = decoder_output
        else:
            h_out, dec_hidden = decoder_output

        h_out_reshape = h_out.contiguous().view(h_out.size(0) * h_out.size(1), -1)
        decoder_logits = self.dec2vocab(h_out_reshape)
        decoder_logits = decoder_logits.view(h_out.size(0), h_out.size(1), -1)
        log_probs = self.get_log_prob(decoder_logits)

        if self.copy:
            copy_logit = self.copy_gate(h_out)
            if self.use_pos:
                # can't copy the UPOS
                log_attn = log_attn[:, :, 1:]

            # renormalize
            log_attn = torch.log_softmax(log_attn, -1)
            # calculate copy probability for each word in the vocab
            log_copy_prob = torch.nn.functional.logsigmoid(copy_logit) + log_attn
            # scatter logsumexp
            mx = log_copy_prob.max(-1, keepdim=True)[0]
            log_copy_prob = log_copy_prob - mx
            # here we make space in the log probs for vocab items
            # which might be copied from the encoder side, but which
            # were not known at training time
            # note that such an item cannot possibly be predicted by
            # the model as a raw output token
            # however, the copy gate might score high on copying a
            # previously unknown vocab item
            copy_prob = torch.exp(log_copy_prob)
            copied_vocab_shape = list(log_probs.size())
            if torch.max(src) >= copied_vocab_shape[-1]:
                copied_vocab_shape[-1] = torch.max(src) + 1
            copied_vocab_prob = log_probs.new_zeros(copied_vocab_shape)
            scattered_copy = src.unsqueeze(1).expand(src.size(0), copy_prob.size(1), src.size(1))
            # fill in the copy tensor with the copy probs of each character
            # the rest of the copy tensor will be filled with -largenumber
            copied_vocab_prob = copied_vocab_prob.scatter_add(-1, scattered_copy, copy_prob)
            zero_mask = (copied_vocab_prob == 0)
            log_copied_vocab_prob = torch.log(copied_vocab_prob.masked_fill(zero_mask, 1e-12)) + mx
            log_copied_vocab_prob = log_copied_vocab_prob.masked_fill(zero_mask, -1e12)

            # combine with normal vocab probability
            log_nocopy_prob = -torch.log(1 + torch.exp(copy_logit))
            if log_probs.shape[-1] < copied_vocab_shape[-1]:
                # for previously unknown vocab items which are in the encoder,
                # we reuse the UNK_ID prediction
                # this gives a baseline number which we can combine with
                # the copy gate prediction
                # technically this makes log_probs no longer represent
                # a probability distribution when looking at unknown vocab
                # this is probably not a serious problem
                # an example of this usage is in the Lemmatizer, such as a
                # plural word in English with the character "ã" in it instead of "a"
                # if "ã" is not known in the training data, the lemmatizer would
                # ordinarily be unable to output it, and thus the seq2seq model
                # would have no chance to depluralize "ãntennae" -> "ãntenna"
                # however, if we temporarily add "ã" to the encoder vocab,
                # then let the copy gate accept that letter, we find the Lemmatizer
                # seq2seq model will want to copy that particular vocab item
                # this allows the Lemmatizer to produce "ã" instead of requiring
                # that it produces UNK, then going back to the input text to
                # figure out which UNK it intended to produce
                new_log_probs = log_probs.new_zeros(copied_vocab_shape)
                new_log_probs[:, :, :log_probs.shape[-1]] = log_probs
                new_log_probs[:, :, log_probs.shape[-1]:] = new_log_probs[:, :, UNK_ID].unsqueeze(2)
                log_probs = new_log_probs
            log_probs = log_probs + log_nocopy_prob
            log_probs = torch.logsumexp(torch.stack([log_copied_vocab_prob, log_probs]), 0)

        if never_decode_unk:
            log_probs[:, :, UNK_ID] = float("-inf")
        return log_probs, dec_hidden

    def embed(self, src, src_mask, pos, raw):
        """
        Build the encoder inputs.

        Returns (enc_inputs, batch_size, src_lens, src_mask); when POS is
        used, the POS embedding is prepended as an extra first "token" and
        the mask is extended to match.
        """
        embed_src = src.clone()
        # ids >= vocab_size come from the extended (copy) vocab; embed them as UNK
        embed_src[embed_src >= self.vocab_size] = UNK_ID
        enc_inputs = self.emb_drop(self.embedding(embed_src))
        batch_size = enc_inputs.size(0)
        if self.use_pos:
            assert pos is not None, "Missing POS input for seq2seq lemmatizer."
            pos_inputs = self.pos_drop(self.pos_embedding(pos))
            enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)
            pos_src_mask = src_mask.new_zeros([batch_size, 1])
            src_mask = torch.cat([pos_src_mask, src_mask], dim=1)
        if raw is not None and self.contextual_embedding is not None:
            raw_inputs = self.contextual_embedding(raw)
            if self.use_pos:
                raw_zeros = raw_inputs.new_zeros((raw_inputs.shape[0], 1, raw_inputs.shape[2]))
                raw_inputs = torch.cat([raw_inputs, raw_zeros], dim=1)
            enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2)
        # lengths = count of mask entries equal to PAD_ID (0)
        # NOTE(review): this assumes non-zero mask values mark padding — confirm against callers
        src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
        return enc_inputs, batch_size, src_lens, src_mask

    def forward(self, src, src_mask, tgt_in, pos=None, raw=None):
        """Run a training step: encode src, then decode with teacher forcing on tgt_in."""
        # prepare for encoder/decoder
        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

        # encode source
        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

        if self.edit:
            edit_logits = self.edit_clf(hn)
        else:
            edit_logits = None

        dec_inputs = self.emb_drop(self.embedding(tgt_in))

        log_probs, _ = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src)
        return log_probs, edit_logits

    def get_log_prob(self, logits):
        """Log-softmax over the vocab dimension, preserving the input's rank."""
        logits_reshape = logits.view(-1, self.vocab_size)
        log_probs = F.log_softmax(logits_reshape, dim=1)
        if logits.dim() == 2:
            return log_probs
        return log_probs.view(logits.size(0), logits.size(1), logits.size(2))

    def predict_greedy(self, src, src_mask, pos=None, raw=None, never_decode_unk=False):
        """ Predict with greedy decoding. """
        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

        # encode source
        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

        if self.edit:
            edit_logits = self.edit_clf(hn)
        else:
            edit_logits = None

        # greedy decode by step
        dec_inputs = self.embedding(self.SOS_tensor)
        dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1))

        done = [False for _ in range(batch_size)]
        total_done = 0
        max_len = 0
        output_seqs = [[] for _ in range(batch_size)]

        while total_done < batch_size and max_len < self.max_dec_len:
            log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
            assert log_probs.size(1) == 1, "Output must have 1-step of output."
            _, preds = log_probs.squeeze(1).max(1, keepdim=True)
            # if a unlearned character is predicted via the copy mechanism,
            # use the UNK embedding for it
            dec_inputs = preds.clone()
            dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
            dec_inputs = self.embedding(dec_inputs) # update decoder inputs
            max_len += 1
            for i in range(batch_size):
                if not done[i]:
                    token = preds.data[i][0].item()
                    if token == constant.EOS_ID:
                        done[i] = True
                        total_done += 1
                    else:
                        output_seqs[i].append(token)
        return output_seqs, edit_logits

    def predict(self, src, src_mask, pos=None, beam_size=5, raw=None, never_decode_unk=False):
        """ Predict with beam search. """
        if beam_size == 1:
            return self.predict_greedy(src, src_mask, pos, raw, never_decode_unk=never_decode_unk)

        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

        # (1) encode source
        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

        if self.edit:
            edit_logits = self.edit_clf(hn)
        else:
            edit_logits = None

        # (2) set up beam
        with torch.no_grad():
            h_in = h_in.data.repeat(beam_size, 1, 1) # repeat data for beam search
            src_mask = src_mask.repeat(beam_size, 1)
            # repeat decoder hidden states
            hn = hn.data.repeat(beam_size, 1)
            cn = cn.data.repeat(beam_size, 1)
        device = self.SOS_tensor.device
        beam = [Beam(beam_size, device) for _ in range(batch_size)]

        def update_state(states, idx, positions, beam_size):
            """ Select the states according to back pointers. """
            for e in states:
                br, d = e.size()
                s = e.contiguous().view(beam_size, br // beam_size, d)[:,idx]
                s.data.copy_(s.data.index_select(0, positions))

        # (3) main loop
        for i in range(self.max_dec_len):
            dec_inputs = torch.stack([b.get_current_state() for b in beam]).t().contiguous().view(-1, 1)
            # if a unlearned character is predicted via the copy mechanism,
            # use the UNK embedding for it
            dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
            dec_inputs = self.embedding(dec_inputs)
            log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
            log_probs = log_probs.view(beam_size, batch_size, -1).transpose(0,1).contiguous() # [batch, beam, V]

            # advance each beam
            done = []
            for b in range(batch_size):
                is_done = beam[b].advance(log_probs.data[b])
                if is_done:
                    done += [b]
                # update beam state
                update_state((hn, cn), b, beam[b].get_current_origin(), beam_size)

            if len(done) == batch_size:
                break

        # back trace and find hypothesis
        all_hyp, all_scores = [], []
        for b in range(batch_size):
            scores, ks = beam[b].sort_best()
            all_scores += [scores[0]]
            k = ks[0]
            hyp = beam[b].get_hyp(k)
            hyp = utils.prune_hyp(hyp)
            hyp = [i.item() for i in hyp]
            all_hyp += [hyp]

        return all_hyp, edit_logits
364
+
stanza/stanza/models/common/seq2seq_utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils for seq2seq models.
3
+ """
4
+ from collections import Counter
5
+ import random
6
+ import json
7
+ import torch
8
+
9
+ import stanza.models.common.seq2seq_constant as constant
10
+
11
+ # torch utils
12
def get_optimizer(name, parameters, lr):
    """
    Create an optimizer by name.

    sgd and adagrad use the provided lr; adam and adamax use their defaults.
    """
    if name == 'sgd':
        return torch.optim.SGD(parameters, lr=lr)
    if name == 'adagrad':
        return torch.optim.Adagrad(parameters, lr=lr)
    if name == 'adam':
        return torch.optim.Adam(parameters) # use default lr
    if name == 'adamax':
        return torch.optim.Adamax(parameters) # use default lr
    raise Exception("Unsupported optimizer: {}".format(name))
23
+
24
def change_lr(optimizer, new_lr):
    """Set the learning rate of every param group to new_lr."""
    for group in optimizer.param_groups:
        group['lr'] = new_lr
27
+
28
def flatten_indices(seq_lens, width):
    """Flatten per-row positions into flat indices for rows of the given width."""
    return [row * width + col
            for row, length in enumerate(seq_lens)
            for col in range(length)]
34
+
35
def keep_partial_grad(grad, topk):
    """
    Keep only the topk rows of grads.
    """
    # zeroes rows past topk in place and returns the same tensor;
    # presumably registered as an embedding gradient hook so only the
    # most frequent vocab entries are finetuned -- confirm with callers
    assert topk < grad.size(0)
    grad.data[topk:].zero_()
    return grad
42
+
43
+ # other utils
44
def save_config(config, path, verbose=True):
    """Write config to path as indented JSON and return it unchanged."""
    with open(path, 'w') as fout:
        json.dump(config, fout, indent=2)
    if verbose:
        print("Config saved to file {}".format(path))
    return config
50
+
51
def load_config(path, verbose=True):
    """Load and return a JSON config from path."""
    with open(path) as fin:
        config = json.load(fin)
    if verbose:
        print("Config loaded from file {}".format(path))
    return config
57
+
58
def unmap_with_copy(indices, src_tokens, vocab):
    """
    Unmap a list of list of indices, by optionally copying from src_tokens.

    Non-negative ids are looked up in the vocab; a negative id -k encodes
    a copy of src token k-1.
    """
    result = []
    for id_seq, tokens in zip(indices, src_tokens):
        words = []
        for token_id in id_seq:
            if token_id >= 0:
                words.append(vocab.id2word[token_id])
            else:
                # flip and minus 1
                words.append(tokens[-token_id - 1])
        result.append(words)
    return result
73
+
74
def prune_decoded_seqs(seqs):
    """
    Prune decoded sequences after EOS token.

    Each sequence containing the EOS symbol is truncated just before it;
    sequences without EOS are kept whole.
    """
    out = []
    for s in seqs:
        if constant.EOS in s:
            # the constants module defines EOS ('<EOS>'), not EOS_TOKEN;
            # the previous constant.EOS_TOKEN lookup raised AttributeError
            # whenever an EOS symbol was actually present
            idx = s.index(constant.EOS)
            out.append(s[:idx])
        else:
            out.append(s)
    return out
86
+
87
def prune_hyp(hyp):
    """
    Prune a decoded hypothesis

    Truncates at the first EOS id; returns the hypothesis unchanged when
    no EOS id is present.
    """
    try:
        return hyp[:hyp.index(constant.EOS_ID)]
    except ValueError:
        # no EOS id in the hypothesis
        return hyp
96
+
97
def prune(data_list, lens):
    """Truncate each sequence in data_list to its corresponding length."""
    assert len(data_list) == len(lens)
    return [seq[:length] for seq, length in zip(data_list, lens)]
103
+
104
def sort(packed, ref, reverse=True):
    """
    Sort a series of packed list, according to a ref list.
    Also return the original index before the sort.
    """
    assert (isinstance(packed, tuple) or isinstance(packed, list)) and isinstance(ref, list)
    # sort whole rows of (ref value, original index, *packed values);
    # the original index breaks ties deterministically
    rows = sorted(zip(ref, range(len(ref)), *packed), reverse=reverse)
    columns = [list(col) for col in zip(*rows)]
    # drop the ref column; keep original indices plus the sorted lists
    return tuple(columns[1:])
113
+
114
def unsort(sorted_list, oidx):
    """
    Unsort a sorted list, based on the original idx.

    Returns a new list where element oidx[i] of the output is sorted_list[i].
    An empty input returns an empty list (the previous tuple-unpacking
    implementation raised ValueError on empty input).
    """
    assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices."
    return [item for _, item in sorted(zip(oidx, sorted_list))]
121
+
stanza/stanza/models/common/short_name_to_treebank.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module is autogenerated by build_short_name_to_treebank.py
2
+ # Please do not edit
3
+
4
+ SHORT_NAMES = {
5
+ 'abq_atb': 'UD_Abaza-ATB',
6
+ 'ab_abnc': 'UD_Abkhaz-AbNC',
7
+ 'af_afribooms': 'UD_Afrikaans-AfriBooms',
8
+ 'akk_pisandub': 'UD_Akkadian-PISANDUB',
9
+ 'akk_riao': 'UD_Akkadian-RIAO',
10
+ 'aqz_tudet': 'UD_Akuntsu-TuDeT',
11
+ 'sq_staf': 'UD_Albanian-STAF',
12
+ 'sq_tsa': 'UD_Albanian-TSA',
13
+ 'am_att': 'UD_Amharic-ATT',
14
+ 'grc_proiel': 'UD_Ancient_Greek-PROIEL',
15
+ 'grc_ptnk': 'UD_Ancient_Greek-PTNK',
16
+ 'grc_perseus': 'UD_Ancient_Greek-Perseus',
17
+ 'hbo_ptnk': 'UD_Ancient_Hebrew-PTNK',
18
+ 'apu_ufpa': 'UD_Apurina-UFPA',
19
+ 'ar_nyuad': 'UD_Arabic-NYUAD',
20
+ 'ar_padt': 'UD_Arabic-PADT',
21
+ 'ar_pud': 'UD_Arabic-PUD',
22
+ 'hy_armtdp': 'UD_Armenian-ArmTDP',
23
+ 'hy_bsut': 'UD_Armenian-BSUT',
24
+ 'aii_as': 'UD_Assyrian-AS',
25
+ 'az_tuecl': 'UD_Azerbaijani-TueCL',
26
+ 'bm_crb': 'UD_Bambara-CRB',
27
+ 'eu_bdt': 'UD_Basque-BDT',
28
+ 'bar_maibaam': 'UD_Bavarian-MaiBaam',
29
+ 'bej_autogramm': 'UD_Beja-Autogramm',
30
+ 'be_hse': 'UD_Belarusian-HSE',
31
+ 'bn_bru': 'UD_Bengali-BRU',
32
+ 'bho_bhtb': 'UD_Bhojpuri-BHTB',
33
+ 'bor_bdt': 'UD_Bororo-BDT',
34
+ 'br_keb': 'UD_Breton-KEB',
35
+ 'bg_btb': 'UD_Bulgarian-BTB',
36
+ 'bxr_bdt': 'UD_Buryat-BDT',
37
+ 'yue_hk': 'UD_Cantonese-HK',
38
+ 'cpg_amgic': 'UD_Cappadocian-AMGiC',
39
+ 'cpg_tuecl': 'UD_Cappadocian-TueCL',
40
+ 'ca_ancora': 'UD_Catalan-AnCora',
41
+ 'ceb_gja': 'UD_Cebuano-GJA',
42
+ 'zh-hans_beginner': 'UD_Chinese-Beginner',
43
+ 'zh_beginner': 'UD_Chinese-Beginner',
44
+ 'zh-hans_cfl': 'UD_Chinese-CFL',
45
+ 'zh_cfl': 'UD_Chinese-CFL',
46
+ 'zh-hant_gsd': 'UD_Chinese-GSD',
47
+ 'zh_gsd': 'UD_Chinese-GSD',
48
+ 'zh-hans_gsdsimp': 'UD_Chinese-GSDSimp',
49
+ 'zh_gsdsimp': 'UD_Chinese-GSDSimp',
50
+ 'zh-hant_hk': 'UD_Chinese-HK',
51
+ 'zh_hk': 'UD_Chinese-HK',
52
+ 'zh-hant_pud': 'UD_Chinese-PUD',
53
+ 'zh_pud': 'UD_Chinese-PUD',
54
+ 'zh-hans_patentchar': 'UD_Chinese-PatentChar',
55
+ 'zh_patentchar': 'UD_Chinese-PatentChar',
56
+ 'ckt_hse': 'UD_Chukchi-HSE',
57
+ 'xcl_caval': 'UD_Classical_Armenian-CAVaL',
58
+ 'lzh_kyoto': 'UD_Classical_Chinese-Kyoto',
59
+ 'lzh_tuecl': 'UD_Classical_Chinese-TueCL',
60
+ 'cop_scriptorium': 'UD_Coptic-Scriptorium',
61
+ 'hr_set': 'UD_Croatian-SET',
62
+ 'cs_cac': 'UD_Czech-CAC',
63
+ 'cs_cltt': 'UD_Czech-CLTT',
64
+ 'cs_fictree': 'UD_Czech-FicTree',
65
+ 'cs_pdt': 'UD_Czech-PDT',
66
+ 'cs_pud': 'UD_Czech-PUD',
67
+ 'cs_poetry': 'UD_Czech-Poetry',
68
+ 'da_ddt': 'UD_Danish-DDT',
69
+ 'nl_alpino': 'UD_Dutch-Alpino',
70
+ 'nl_lassysmall': 'UD_Dutch-LassySmall',
71
+ 'egy_ujaen': 'UD_Egyptian-UJaen',
72
+ 'en_atis': 'UD_English-Atis',
73
+ 'en_ctetex': 'UD_English-CTeTex',
74
+ 'en_eslspok': 'UD_English-ESLSpok',
75
+ 'en_ewt': 'UD_English-EWT',
76
+ 'en_gentle': 'UD_English-GENTLE',
77
+ 'en_gum': 'UD_English-GUM',
78
+ 'en_gumreddit': 'UD_English-GUMReddit',
79
+ 'en_lines': 'UD_English-LinES',
80
+ 'en_pud': 'UD_English-PUD',
81
+ 'en_partut': 'UD_English-ParTUT',
82
+ 'en_pronouns': 'UD_English-Pronouns',
83
+ 'myv_jr': 'UD_Erzya-JR',
84
+ 'et_edt': 'UD_Estonian-EDT',
85
+ 'et_ewt': 'UD_Estonian-EWT',
86
+ 'fo_farpahc': 'UD_Faroese-FarPaHC',
87
+ 'fo_oft': 'UD_Faroese-OFT',
88
+ 'fi_ftb': 'UD_Finnish-FTB',
89
+ 'fi_ood': 'UD_Finnish-OOD',
90
+ 'fi_pud': 'UD_Finnish-PUD',
91
+ 'fi_tdt': 'UD_Finnish-TDT',
92
+ 'fr_fqb': 'UD_French-FQB',
93
+ 'fr_gsd': 'UD_French-GSD',
94
+ 'fr_pud': 'UD_French-PUD',
95
+ 'fr_partut': 'UD_French-ParTUT',
96
+ 'fr_parisstories': 'UD_French-ParisStories',
97
+ 'fr_rhapsodie': 'UD_French-Rhapsodie',
98
+ 'fr_sequoia': 'UD_French-Sequoia',
99
+ 'qfn_fame': 'UD_Frisian_Dutch-Fame',
100
+ 'gl_ctg': 'UD_Galician-CTG',
101
+ 'gl_pud': 'UD_Galician-PUD',
102
+ 'gl_treegal': 'UD_Galician-TreeGal',
103
+ 'ka_glc': 'UD_Georgian-GLC',
104
+ 'de_gsd': 'UD_German-GSD',
105
+ 'de_hdt': 'UD_German-HDT',
106
+ 'de_lit': 'UD_German-LIT',
107
+ 'de_pud': 'UD_German-PUD',
108
+ 'aln_gps': 'UD_Gheg-GPS',
109
+ 'got_proiel': 'UD_Gothic-PROIEL',
110
+ 'el_gdt': 'UD_Greek-GDT',
111
+ 'el_gud': 'UD_Greek-GUD',
112
+ 'gub_tudet': 'UD_Guajajara-TuDeT',
113
+ 'gn_oldtudet': 'UD_Guarani-OldTuDeT',
114
+ 'gu_gujtb': 'UD_Gujarati-GujTB',
115
+ 'gwi_tuecl': 'UD_Gwichin-TueCL',
116
+ 'ht_autogramm': 'UD_Haitian_Creole-Autogramm',
117
+ 'ha_northernautogramm': 'UD_Hausa-NorthernAutogramm',
118
+ 'ha_southernautogramm': 'UD_Hausa-SouthernAutogramm',
119
+ 'he_htb': 'UD_Hebrew-HTB',
120
+ 'he_iahltknesset': 'UD_Hebrew-IAHLTknesset',
121
+ 'he_iahltwiki': 'UD_Hebrew-IAHLTwiki',
122
+ 'azz_itml': 'UD_Highland_Puebla_Nahuatl-ITML',
123
+ 'hi_hdtb': 'UD_Hindi-HDTB',
124
+ 'hi_pud': 'UD_Hindi-PUD',
125
+ 'hit_hittb': 'UD_Hittite-HitTB',
126
+ 'hu_szeged': 'UD_Hungarian-Szeged',
127
+ 'is_gc': 'UD_Icelandic-GC',
128
+ 'is_icepahc': 'UD_Icelandic-IcePaHC',
129
+ 'is_modern': 'UD_Icelandic-Modern',
130
+ 'is_pud': 'UD_Icelandic-PUD',
131
+ 'id_csui': 'UD_Indonesian-CSUI',
132
+ 'id_gsd': 'UD_Indonesian-GSD',
133
+ 'id_pud': 'UD_Indonesian-PUD',
134
+ 'ga_cadhan': 'UD_Irish-Cadhan',
135
+ 'ga_idt': 'UD_Irish-IDT',
136
+ 'ga_twittirish': 'UD_Irish-TwittIrish',
137
+ 'it_isdt': 'UD_Italian-ISDT',
138
+ 'it_markit': 'UD_Italian-MarkIT',
139
+ 'it_old': 'UD_Italian-Old',
140
+ 'it_pud': 'UD_Italian-PUD',
141
+ 'it_partut': 'UD_Italian-ParTUT',
142
+ 'it_parlamint': 'UD_Italian-ParlaMint',
143
+ 'it_postwita': 'UD_Italian-PoSTWITA',
144
+ 'it_twittiro': 'UD_Italian-TWITTIRO',
145
+ 'it_vit': 'UD_Italian-VIT',
146
+ 'it_valico': 'UD_Italian-Valico',
147
+ 'ja_bccwj': 'UD_Japanese-BCCWJ',
148
+ 'ja_bccwjluw': 'UD_Japanese-BCCWJLUW',
149
+ 'ja_gsd': 'UD_Japanese-GSD',
150
+ 'ja_gsdluw': 'UD_Japanese-GSDLUW',
151
+ 'ja_pud': 'UD_Japanese-PUD',
152
+ 'ja_pudluw': 'UD_Japanese-PUDLUW',
153
+ 'jv_csui': 'UD_Javanese-CSUI',
154
+ 'urb_tudet': 'UD_Kaapor-TuDeT',
155
+ 'xnr_kdtb': 'UD_Kangri-KDTB',
156
+ 'krl_kkpp': 'UD_Karelian-KKPP',
157
+ 'arr_tudet': 'UD_Karo-TuDeT',
158
+ 'kk_ktb': 'UD_Kazakh-KTB',
159
+ 'kfm_aha': 'UD_Khunsari-AHA',
160
+ 'quc_iu': 'UD_Kiche-IU',
161
+ 'koi_uh': 'UD_Komi_Permyak-UH',
162
+ 'kpv_ikdp': 'UD_Komi_Zyrian-IKDP',
163
+ 'kpv_lattice': 'UD_Komi_Zyrian-Lattice',
164
+ 'ko_gsd': 'UD_Korean-GSD',
165
+ 'ko_ksl': 'UD_Korean-KSL',
166
+ 'ko_kaist': 'UD_Korean-Kaist',
167
+ 'ko_pud': 'UD_Korean-PUD',
168
+ 'kmr_mg': 'UD_Kurmanji-MG',
169
+ 'ky_ktmu': 'UD_Kyrgyz-KTMU',
170
+ 'ky_tuecl': 'UD_Kyrgyz-TueCL',
171
+ 'ltg_cairo': 'UD_Latgalian-Cairo',
172
+ 'la_circse': 'UD_Latin-CIRCSE',
173
+ 'la_ittb': 'UD_Latin-ITTB',
174
+ 'la_llct': 'UD_Latin-LLCT',
175
+ 'la_proiel': 'UD_Latin-PROIEL',
176
+ 'la_perseus': 'UD_Latin-Perseus',
177
+ 'la_udante': 'UD_Latin-UDante',
178
+ 'lv_cairo': 'UD_Latvian-Cairo',
179
+ 'lv_lvtb': 'UD_Latvian-LVTB',
180
+ 'lij_glt': 'UD_Ligurian-GLT',
181
+ 'lt_alksnis': 'UD_Lithuanian-ALKSNIS',
182
+ 'lt_hse': 'UD_Lithuanian-HSE',
183
+ 'olo_kkpp': 'UD_Livvi-KKPP',
184
+ 'nds_lsdc': 'UD_Low_Saxon-LSDC',
185
+ 'lb_luxbank': 'UD_Luxembourgish-LuxBank',
186
+ 'mk_mtb': 'UD_Macedonian-MTB',
187
+ 'jaa_jarawara': 'UD_Madi-Jarawara',
188
+ 'qaf_arabizi': 'UD_Maghrebi_Arabic_French-Arabizi',
189
+ 'mpu_tudet': 'UD_Makurap-TuDeT',
190
+ 'ml_ufal': 'UD_Malayalam-UFAL',
191
+ 'mt_mudt': 'UD_Maltese-MUDT',
192
+ 'gv_cadhan': 'UD_Manx-Cadhan',
193
+ 'mr_ufal': 'UD_Marathi-UFAL',
194
+ 'gun_dooley': 'UD_Mbya_Guarani-Dooley',
195
+ 'gun_thomas': 'UD_Mbya_Guarani-Thomas',
196
+ 'frm_profiterole': 'UD_Middle_French-PROFITEROLE',
197
+ 'mdf_jr': 'UD_Moksha-JR',
198
+ 'myu_tudet': 'UD_Munduruku-TuDeT',
199
+ 'pcm_nsc': 'UD_Naija-NSC',
200
+ 'nyq_aha': 'UD_Nayini-AHA',
201
+ 'nap_rb': 'UD_Neapolitan-RB',
202
+ 'yrl_complin': 'UD_Nheengatu-CompLin',
203
+ 'sme_giella': 'UD_North_Sami-Giella',
204
+ 'gya_autogramm': 'UD_Northwest_Gbaya-Autogramm',
205
+ 'nb_bokmaal': 'UD_Norwegian-Bokmaal',
206
+ 'no_bokmaal': 'UD_Norwegian-Bokmaal',
207
+ 'nn_nynorsk': 'UD_Norwegian-Nynorsk',
208
+ 'cu_proiel': 'UD_Old_Church_Slavonic-PROIEL',
209
+ 'orv_birchbark': 'UD_Old_East_Slavic-Birchbark',
210
+ 'orv_rnc': 'UD_Old_East_Slavic-RNC',
211
+ 'orv_ruthenian': 'UD_Old_East_Slavic-Ruthenian',
212
+ 'orv_torot': 'UD_Old_East_Slavic-TOROT',
213
+ 'fro_profiterole': 'UD_Old_French-PROFITEROLE',
214
+ 'sga_dipsgg': 'UD_Old_Irish-DipSGG',
215
+ 'sga_dipwbg': 'UD_Old_Irish-DipWBG',
216
+ 'otk_clausal': 'UD_Old_Turkish-Clausal',
217
+ 'ota_boun': 'UD_Ottoman_Turkish-BOUN',
218
+ 'ota_dudu': 'UD_Ottoman_Turkish-DUDU',
219
+ 'ps_sikaram': 'UD_Pashto-Sikaram',
220
+ 'pad_tuecl': 'UD_Paumari-TueCL',
221
+ 'fa_perdt': 'UD_Persian-PerDT',
222
+ 'fa_seraji': 'UD_Persian-Seraji',
223
+ 'pay_chibergis': 'UD_Pesh-ChibErgIS',
224
+ 'xpg_kul': 'UD_Phrygian-KUL',
225
+ 'pl_lfg': 'UD_Polish-LFG',
226
+ 'pl_pdb': 'UD_Polish-PDB',
227
+ 'pl_pud': 'UD_Polish-PUD',
228
+ 'qpm_philotis': 'UD_Pomak-Philotis',
229
+ 'pt_bosque': 'UD_Portuguese-Bosque',
230
+ 'pt_cintil': 'UD_Portuguese-CINTIL',
231
+ 'pt_dantestocks': 'UD_Portuguese-DANTEStocks',
232
+ 'pt_gsd': 'UD_Portuguese-GSD',
233
+ 'pt_pud': 'UD_Portuguese-PUD',
234
+ 'pt_petrogold': 'UD_Portuguese-PetroGold',
235
+ 'pt_porttinari': 'UD_Portuguese-Porttinari',
236
+ 'ro_art': 'UD_Romanian-ArT',
237
+ 'ro_nonstandard': 'UD_Romanian-Nonstandard',
238
+ 'ro_rrt': 'UD_Romanian-RRT',
239
+ 'ro_simonero': 'UD_Romanian-SiMoNERo',
240
+ 'ro_tuecl': 'UD_Romanian-TueCL',
241
+ 'ru_gsd': 'UD_Russian-GSD',
242
+ 'ru_pud': 'UD_Russian-PUD',
243
+ 'ru_poetry': 'UD_Russian-Poetry',
244
+ 'ru_syntagrus': 'UD_Russian-SynTagRus',
245
+ 'ru_taiga': 'UD_Russian-Taiga',
246
+ 'sa_ufal': 'UD_Sanskrit-UFAL',
247
+ 'sa_vedic': 'UD_Sanskrit-Vedic',
248
+ 'gd_arcosg': 'UD_Scottish_Gaelic-ARCOSG',
249
+ 'sr_set': 'UD_Serbian-SET',
250
+ 'si_stb': 'UD_Sinhala-STB',
251
+ 'sms_giellagas': 'UD_Skolt_Sami-Giellagas',
252
+ 'sk_snk': 'UD_Slovak-SNK',
253
+ 'sl_ssj': 'UD_Slovenian-SSJ',
254
+ 'sl_sst': 'UD_Slovenian-SST',
255
+ 'soj_aha': 'UD_Soi-AHA',
256
+ 'ajp_madar': 'UD_South_Levantine_Arabic-MADAR',
257
+ 'es_ancora': 'UD_Spanish-AnCora',
258
+ 'es_coser': 'UD_Spanish-COSER',
259
+ 'es_gsd': 'UD_Spanish-GSD',
260
+ 'es_pud': 'UD_Spanish-PUD',
261
+ 'ssp_lse': 'UD_Spanish_Sign_Language-LSE',
262
+ 'sv_lines': 'UD_Swedish-LinES',
263
+ 'sv_pud': 'UD_Swedish-PUD',
264
+ 'sv_talbanken': 'UD_Swedish-Talbanken',
265
+ 'swl_sslc': 'UD_Swedish_Sign_Language-SSLC',
266
+ 'gsw_uzh': 'UD_Swiss_German-UZH',
267
+ 'tl_trg': 'UD_Tagalog-TRG',
268
+ 'tl_ugnayan': 'UD_Tagalog-Ugnayan',
269
+ 'ta_mwtt': 'UD_Tamil-MWTT',
270
+ 'ta_ttb': 'UD_Tamil-TTB',
271
+ 'tt_nmctt': 'UD_Tatar-NMCTT',
272
+ 'eme_tudet': 'UD_Teko-TuDeT',
273
+ 'te_mtg': 'UD_Telugu-MTG',
274
+ 'qte_tect': 'UD_Telugu_English-TECT',
275
+ 'th_pud': 'UD_Thai-PUD',
276
+ 'tn_popapolelo': 'UD_Tswana-Popapolelo',
277
+ 'tpn_tudet': 'UD_Tupinamba-TuDeT',
278
+ 'tr_atis': 'UD_Turkish-Atis',
279
+ 'tr_boun': 'UD_Turkish-BOUN',
280
+ 'tr_framenet': 'UD_Turkish-FrameNet',
281
+ 'tr_gb': 'UD_Turkish-GB',
282
+ 'tr_imst': 'UD_Turkish-IMST',
283
+ 'tr_kenet': 'UD_Turkish-Kenet',
284
+ 'tr_pud': 'UD_Turkish-PUD',
285
+ 'tr_penn': 'UD_Turkish-Penn',
286
+ 'tr_tourism': 'UD_Turkish-Tourism',
287
+ 'qtd_sagt': 'UD_Turkish_German-SAGT',
288
+ 'uk_iu': 'UD_Ukrainian-IU',
289
+ 'uk_parlamint': 'UD_Ukrainian-ParlaMint',
290
+ 'xum_ikuvina': 'UD_Umbrian-IKUVINA',
291
+ 'hsb_ufal': 'UD_Upper_Sorbian-UFAL',
292
+ 'ur_udtb': 'UD_Urdu-UDTB',
293
+ 'ug_udt': 'UD_Uyghur-UDT',
294
+ 'uz_ut': 'UD_Uzbek-UT',
295
+ 'vep_vwt': 'UD_Veps-VWT',
296
+ 'vi_tuecl': 'UD_Vietnamese-TueCL',
297
+ 'vi_vtb': 'UD_Vietnamese-VTB',
298
+ 'wbp_ufal': 'UD_Warlpiri-UFAL',
299
+ 'cy_ccg': 'UD_Welsh-CCG',
300
+ 'hyw_armtdp': 'UD_Western_Armenian-ArmTDP',
301
+ 'nhi_itml': 'UD_Western_Sierra_Puebla_Nahuatl-ITML',
302
+ 'wo_wtb': 'UD_Wolof-WTB',
303
+ 'xav_xdt': 'UD_Xavante-XDT',
304
+ 'sjo_xdt': 'UD_Xibe-XDT',
305
+ 'sah_yktdt': 'UD_Yakut-YKTDT',
306
+ 'yo_ytb': 'UD_Yoruba-YTB',
307
+ 'ess_sli': 'UD_Yupik-SLI',
308
+ 'say_autogramm': 'UD_Zaar-Autogramm',
309
+ }
310
+
311
+
312
def short_name_to_treebank(short_name):
    """Look up the full UD treebank name for a shorthand such as 'en_ewt'.

    Raises KeyError if the shorthand is not a known treebank.
    """
    treebank = SHORT_NAMES[short_name]
    return treebank
314
+
315
+
316
# CANONICAL_NAMES maps a lowercased full treebank name to its canonical
# mixed-case UD name, e.g. 'ud_english-ewt' -> 'UD_English-EWT'.
# Every canonical name is exactly a value stored in SHORT_NAMES and every key
# is exactly that value lowercased, so the table is derived from SHORT_NAMES
# instead of being duplicated as a second hand-maintained literal that could
# drift out of sync.
# NOTE(review): this assumes SHORT_NAMES covers every treebank that appeared
# in the previous literal table -- confirm against the generator script
# (build_short_name_to_treebank.py) if SHORT_NAMES is ever edited by hand.
CANONICAL_NAMES = {name.lower(): name for name in SHORT_NAMES.values()}
614
+
615
+
616
def canonical_treebank_name(ud_name):
    """Return the canonically cased UD treebank name for *ud_name*.

    Accepts either a shorthand ('en_ewt') or a full name in any casing
    ('ud_english-ewt').  Names that are not recognized at all are returned
    unchanged.
    """
    if ud_name in SHORT_NAMES:
        return SHORT_NAMES[ud_name]
    lowered = ud_name.lower()
    return CANONICAL_NAMES.get(lowered, ud_name)
stanza/stanza/models/common/trainer.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
class Trainer:
    """Base class bundling a model and optimizer with save/load helpers.

    Subclasses are expected to provide ``self.model``, ``self.optimizer``
    and ``self.args`` (a dict with at least a 'mode' key).
    """

    def change_lr(self, new_lr):
        """Set the learning rate of every optimizer param group to new_lr."""
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

    def save(self, filename):
        """Serialize the model and optimizer state dicts to *filename*."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(checkpoint, filename)

    def load(self, filename):
        """Restore model weights and, when training, the optimizer state."""
        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        self.model.load_state_dict(checkpoint['model'])
        # The optimizer state is only needed when training will continue.
        if self.args['mode'] == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer'])
stanza/stanza/models/common/utils.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions.
3
+ """
4
+
5
+ import argparse
6
+ from collections import Counter
7
+ from contextlib import contextmanager
8
+ import gzip
9
+ import json
10
+ import logging
11
+ import lzma
12
+ import os
13
+ import random
14
+ import re
15
+ import sys
16
+ import unicodedata
17
+ import zipfile
18
+
19
+ import torch
20
+ import numpy as np
21
+
22
+ from stanza.models.common.constant import lcode2lang
23
+ import stanza.models.common.seq2seq_constant as constant
24
+ from stanza.resources.default_packages import TRANSFORMER_NICKNAMES
25
+ import stanza.utils.conll18_ud_eval as ud_eval
26
+ from stanza.utils.conll18_ud_eval import UDError
27
+
28
+ logger = logging.getLogger('stanza')
29
+
30
+ # filenames
31
def get_wordvec_file(wordvec_dir, shorthand, wordvec_type=None):
    """Locate the word vector file for a treebank shorthand like 'en_ewt'.

    If wordvec_type is given, only that subdirectory is considered and a
    missing directory is an error; otherwise the 'word2vec' directory is
    preferred over 'fasttext'.  Within the chosen directory, an .xz file is
    preferred, then .txt; if neither exists the extensionless base path is
    returned as-is.

    Raises FileNotFoundError when no suitable directory can be found.
    """
    lcode, tcode = shorthand.split('_', 1)
    lang = lcode2lang[lcode]
    # candidate language folders
    word2vec_dir = os.path.join(wordvec_dir, 'word2vec', lang)
    fasttext_dir = os.path.join(wordvec_dir, 'fasttext', lang)
    if wordvec_type is not None:
        lang_dir = os.path.join(wordvec_dir, wordvec_type, lang)
        if not os.path.exists(lang_dir):
            raise FileNotFoundError("Word vector type {} was specified, but directory {} does not exist".format(wordvec_type, lang_dir))
    elif os.path.exists(word2vec_dir):  # first try word2vec
        lang_dir = word2vec_dir
    elif os.path.exists(fasttext_dir):  # otherwise try fasttext
        lang_dir = fasttext_dir
    else:
        raise FileNotFoundError("Cannot locate word vector directory for language: {} Looked in {} and {}".format(lang, word2vec_dir, fasttext_dir))
    # look for the wordvec filename inside lang_dir, preferring compressed
    base = os.path.join(lang_dir, '{}.vectors'.format(lcode))
    for extension in (".xz", ".txt"):
        if os.path.exists(base + extension):
            return base + extension
    return base
57
+
58
@contextmanager
def output_stream(filename=None):
    """Context manager yielding a writable text stream.

    With a filename, the file is opened for UTF-8 writing and closed when
    the context exits; with None, sys.stdout is yielded (and left open).
    """
    if filename is None:
        yield sys.stdout
        return
    with open(filename, "w", encoding="utf-8") as fout:
        yield fout
70
+
71
+
72
@contextmanager
def open_read_text(filename, encoding="utf-8"):
    """Open *filename* for text reading, transparently decompressing.

    Files ending in .xz or .gz are decompressed on the fly; anything else
    is opened as a plain text file.  Use as a context manager:

        with open_read_text(filename) as fin:
            ...

    The file is closed once the context exits.
    """
    # pick the opener by extension, then manage the handle uniformly
    if filename.endswith(".xz"):
        fin = lzma.open(filename, mode='rt', encoding=encoding)
    elif filename.endswith(".gz"):
        fin = gzip.open(filename, mode='rt', encoding=encoding)
    else:
        fin = open(filename, encoding=encoding)
    with fin:
        yield fin
94
+
95
@contextmanager
def open_read_binary(filename):
    """Open *filename* for binary reading, transparently decompressing.

    .xz and .gz files are decompressed on the fly; a .zip archive can be
    read only if it contains exactly one member.  Anything else is opened
    as a plain binary file.  Use as a context manager:

        with open_read_binary(filename) as fin:
            ...

    The file is closed once the context exits.

    Raises ValueError for an empty zip archive or one with several members.
    """
    if filename.endswith(".xz"):
        with lzma.open(filename, mode='rb') as fin:
            yield fin
    elif filename.endswith(".gz"):
        with gzip.open(filename, mode='rb') as fin:
            yield fin
    elif filename.endswith(".zip"):
        with zipfile.ZipFile(filename) as zin:
            input_names = zin.namelist()
            if len(input_names) == 0:
                raise ValueError("Empty zip archive")
            if len(input_names) > 1:
                # bug fix: the filename was never interpolated into the message
                raise ValueError("zip file %s has more than one file in it" % filename)
            with zin.open(input_names[0]) as fin:
                yield fin
    else:
        with open(filename, mode='rb') as fin:
            yield fin
128
+
129
+ # training schedule
130
def get_adaptive_eval_interval(cur_dev_size, thres_dev_size, base_interval):
    """Adjust the evaluation interval to the size of the dev set.

    For dev sets no larger than thres_dev_size the base interval is used;
    larger dev sets scale the interval linearly, rounded to an integer
    multiple of base_interval.
    """
    if cur_dev_size <= thres_dev_size:
        return base_interval
    multiplier = round(cur_dev_size / thres_dev_size)
    return base_interval * multiplier
140
+
141
+ # ud utils
142
def ud_scores(gold_conllu_file, system_conllu_file):
    """Evaluate a system CoNLL-U file against a gold CoNLL-U file.

    Returns the evaluation produced by conll18_ud_eval.  A UDError raised
    while reading either file is re-raised with the offending filename.
    """
    def load_or_raise(path):
        # wrap parse failures so the error names the unreadable file
        try:
            return ud_eval.load_conllu_file(path)
        except UDError as e:
            raise UDError("Could not read %s" % path) from e

    gold_ud = load_or_raise(gold_conllu_file)
    system_ud = load_or_raise(system_conllu_file)
    return ud_eval.evaluate(gold_ud, system_ud)
155
+
156
def harmonic_mean(a, weights=None):
    """Compute the (optionally weighted) harmonic mean of the values in *a*.

    Returns 0 if any element is zero: the harmonic mean is dominated by
    zeros, and this also avoids a ZeroDivisionError.
    """
    if any(x == 0 for x in a):
        return 0
    assert weights is None or len(weights) == len(a), 'Weights has length {} which is different from that of the array ({}).'.format(len(weights), len(a))
    if weights is None:
        return len(a) / sum(1/x for x in a)
    return sum(weights) / sum(w/x for x, w in zip(a, weights))
165
+
166
+ # torch utils
167
def dispatch_optimizer(name, parameters, opt_logger, lr=None, betas=None, eps=None, momentum=None, **extra_args):
    """Build and return a torch optimizer chosen by *name*.

    Supported names: amsgrad, amsgradw, sgd, adagrad, adam, adamw, adamax,
    adadelta, adabelief, madgrad, mirror_madgrad.  Each branch logs (at
    debug level on opt_logger) the hyperparameters it actually uses and
    forwards **extra_args to the optimizer constructor.  The adabelief and
    madgrad variants import their third-party packages lazily and raise
    ModuleNotFoundError with an install hint when the package is missing.

    Raises ValueError for an unsupported optimizer name.
    """
    # extra kwargs are appended to each debug line so the full configuration
    # of the optimizer is visible in the logs
    extra_logging = ""
    if len(extra_args) > 0:
        extra_logging = ", " + ", ".join("%s=%s" % (x, y) for x, y in extra_args.items())

    if name == 'amsgrad':
        opt_logger.debug("Building Adam w/ amsgrad with lr=%f, betas=%s, eps=%f%s", lr, betas, eps, extra_logging)
        return torch.optim.Adam(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args)
    elif name == 'amsgradw':
        opt_logger.debug("Building AdamW w/ amsgrad with lr=%f, betas=%s, eps=%f%s", lr, betas, eps, extra_logging)
        return torch.optim.AdamW(parameters, amsgrad=True, lr=lr, betas=betas, eps=eps, **extra_args)
    elif name == 'sgd':
        opt_logger.debug("Building SGD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return torch.optim.SGD(parameters, lr=lr, momentum=momentum, **extra_args)
    elif name == 'adagrad':
        opt_logger.debug("Building Adagrad with lr=%f%s", lr, extra_logging)
        return torch.optim.Adagrad(parameters, lr=lr, **extra_args)
    elif name == 'adam':
        opt_logger.debug("Building Adam with lr=%f, betas=%s, eps=%f%s", lr, betas, eps, extra_logging)
        return torch.optim.Adam(parameters, lr=lr, betas=betas, eps=eps, **extra_args)
    elif name == 'adamw':
        opt_logger.debug("Building AdamW with lr=%f, betas=%s, eps=%f%s", lr, betas, eps, extra_logging)
        return torch.optim.AdamW(parameters, lr=lr, betas=betas, eps=eps, **extra_args)
    elif name == 'adamax':
        opt_logger.debug("Building Adamax%s", extra_logging)
        return torch.optim.Adamax(parameters, **extra_args) # use default lr
    elif name == 'adadelta':
        opt_logger.debug("Building Adadelta with lr=%f%s", lr, extra_logging)
        return torch.optim.Adadelta(parameters, lr=lr, **extra_args)
    elif name == 'adabelief':
        # lazy import: adabelief-pytorch is an optional dependency
        try:
            from adabelief_pytorch import AdaBelief
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create adabelief optimizer. Perhaps the adabelief-pytorch package is not installed") from e
        opt_logger.debug("Building AdaBelief with lr=%f, eps=%f%s", lr, eps, extra_logging)
        # TODO: add weight_decouple and rectify as extra args?
        return AdaBelief(parameters, lr=lr, eps=eps, weight_decouple=True, rectify=True, **extra_args)
    elif name == 'madgrad':
        # lazy import: madgrad is an optional dependency
        try:
            import madgrad
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create madgrad optimizer. Perhaps the madgrad package is not installed") from e
        opt_logger.debug("Building MADGRAD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return madgrad.MADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)
    elif name == 'mirror_madgrad':
        # lazy import: madgrad is an optional dependency
        try:
            import madgrad
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError("Could not create mirror_madgrad optimizer. Perhaps the madgrad package is not installed") from e
        opt_logger.debug("Building MirrorMADGRAD with lr=%f, momentum=%f%s", lr, momentum, extra_logging)
        return madgrad.MirrorMADGRAD(parameters, lr=lr, momentum=momentum, **extra_args)
    else:
        raise ValueError("Unsupported optimizer: {}".format(name))
220
+
221
+
222
def get_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0, weight_decay=None, bert_learning_rate=0.0, bert_weight_decay=None, charlm_learning_rate=0.0, is_peft=False, bert_finetune_layers=None, opt_logger=None):
    """
    Build one optimizer over all trainable parameters of model, split into named param groups.

    Parameters are partitioned by attribute-name prefix into three groups:
    'base' (everything else), 'charlm' (charmodel_forward./charmodel_backward.),
    and 'bert' (bert_model.).  The charlm and bert groups use scaled learning
    rates (lr * charlm_learning_rate, lr * bert_learning_rate) and are only
    added when that scale is positive.  bert_weight_decay, if given, overrides
    weight_decay for the bert group only.
    """
    opt_logger = opt_logger if opt_logger is not None else logger
    base_parameters = [p for n, p in model.named_parameters()
                       if p.requires_grad and not n.startswith("bert_model.")
                       and not n.startswith("charmodel_forward.") and not n.startswith("charmodel_backward.")]
    parameters = [{'param_group_name': 'base', 'params': base_parameters}]

    charlm_parameters = [p for n, p in model.named_parameters()
                         if p.requires_grad and (n.startswith("charmodel_forward.") or n.startswith("charmodel_backward."))]
    if len(charlm_parameters) > 0 and charlm_learning_rate > 0:
        parameters.append({'param_group_name': 'charlm', 'params': charlm_parameters, 'lr': lr * charlm_learning_rate})

    if not is_peft:
        bert_parameters = [p for n, p in model.named_parameters() if p.requires_grad and n.startswith("bert_model.")]

        # bert_finetune_layers limits the bert finetuning to the *last* N layers of the model
        if len(bert_parameters) > 0 and bert_finetune_layers is not None:
            num_layers = model.bert_model.config.num_hidden_layers
            start_layer = num_layers - bert_finetune_layers
            # rebuild the list keeping only params whose name mentions one of the kept layers
            bert_parameters = []
            for layer_num in range(start_layer, num_layers):
                bert_parameters.extend([param for name, param in model.named_parameters()
                                        if param.requires_grad and name.startswith("bert_model.") and "layer.%d." % layer_num in name])

        if len(bert_parameters) > 0 and bert_learning_rate > 0:
            opt_logger.debug("Finetuning %d bert parameters with LR %s and WD %s", len(bert_parameters), lr * bert_learning_rate, bert_weight_decay)
            parameters.append({'param_group_name': 'bert', 'params': bert_parameters, 'lr': lr * bert_learning_rate})
            if bert_weight_decay is not None:
                parameters[-1]['weight_decay'] = bert_weight_decay
    else:
        # some optimizers seem to train some even with a learning rate of 0...
        if bert_learning_rate > 0:
            # because PEFT handles what to hand to an optimizer, we don't want to touch that
            parameters.append({'param_group_name': 'bert', 'params': model.bert_model.parameters(), 'lr': lr * bert_learning_rate})
            if bert_weight_decay is not None:
                parameters[-1]['weight_decay'] = bert_weight_decay

    extra_args = {}
    if weight_decay is not None:
        extra_args["weight_decay"] = weight_decay

    return dispatch_optimizer(name, parameters, opt_logger=opt_logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
264
+
265
def get_split_optimizer(name, model, lr, betas=(0.9, 0.999), eps=1e-8, momentum=0, weight_decay=None, bert_learning_rate=0.0, bert_weight_decay=None, charlm_learning_rate=0.0, is_peft=False, bert_finetune_layers=None):
    """
    Same as `get_optimizer`, but splits the optimizer for Bert into a separate optimizer.

    Returns a dict with "general_optimizer" (base + charlm groups) and, when
    there are bert parameters and bert_learning_rate > 0, "bert_optimizer".
    """
    base_parameters = [p for n, p in model.named_parameters()
                       if p.requires_grad and not n.startswith("bert_model.")
                       and not n.startswith("charmodel_forward.") and not n.startswith("charmodel_backward.")]
    parameters = [{'param_group_name': 'base', 'params': base_parameters}]

    charlm_parameters = [p for n, p in model.named_parameters()
                         if p.requires_grad and (n.startswith("charmodel_forward.") or n.startswith("charmodel_backward."))]
    if len(charlm_parameters) > 0 and charlm_learning_rate > 0:
        parameters.append({'param_group_name': 'charlm', 'params': charlm_parameters, 'lr': lr * charlm_learning_rate})

    bert_parameters = None
    if not is_peft:
        trainable_parameters = [p for n, p in model.named_parameters() if p.requires_grad and n.startswith("bert_model.")]

        # bert_finetune_layers limits the bert finetuning to the *last* N layers of the model
        if len(trainable_parameters) > 0 and bert_finetune_layers is not None:
            num_layers = model.bert_model.config.num_hidden_layers
            start_layer = num_layers - bert_finetune_layers
            trainable_parameters = []
            for layer_num in range(start_layer, num_layers):
                trainable_parameters.extend([param for name, param in model.named_parameters()
                                             if param.requires_grad and name.startswith("bert_model.") and "layer.%d." % layer_num in name])

        if len(trainable_parameters) > 0:
            bert_parameters = [{'param_group_name': 'bert', 'params': trainable_parameters, 'lr': lr * bert_learning_rate}]
    else:
        # because PEFT handles what to hand to an optimizer, we don't want to touch that
        bert_parameters = [{'param_group_name': 'bert', 'params': model.bert_model.parameters(), 'lr': lr * bert_learning_rate}]

    extra_args = {}
    if weight_decay is not None:
        extra_args["weight_decay"] = weight_decay

    optimizers = {
        "general_optimizer": dispatch_optimizer(name, parameters, opt_logger=logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
    }
    if bert_parameters is not None and bert_learning_rate > 0.0:
        # bert_weight_decay overrides weight_decay for the bert optimizer
        if bert_weight_decay is not None:
            extra_args['weight_decay'] = bert_weight_decay
        optimizers["bert_optimizer"] = dispatch_optimizer(name, bert_parameters, opt_logger=logger, lr=lr, betas=betas, eps=eps, momentum=momentum, **extra_args)
    return optimizers
308
+
309
+
310
def change_lr(optimizer, new_lr):
    """Set the learning rate of every param group in the optimizer to new_lr."""
    for group in optimizer.param_groups:
        group['lr'] = new_lr
313
+
314
def flatten_indices(seq_lens, width):
    """
    Map (row, col) positions in a padded batch to flat indices.

    Row i of a batch padded to `width` contributes indices
    i*width + 0 .. i*width + seq_lens[i]-1.
    """
    return [row * width + col
            for row, length in enumerate(seq_lens)
            for col in range(length)]
320
+
321
def keep_partial_grad(grad, topk):
    """
    Zero out all but the first topk rows of grad, in place.

    Returns the same tensor for convenience.
    """
    assert topk < grad.size(0)
    # zero the tail rows without allocating a new tensor
    grad.data.narrow(0, topk, grad.size(0) - topk).zero_()
    return grad
328
+
329
# other utils
def ensure_dir(d, verbose=True):
    """Create directory d (and any missing parents) if it does not already exist."""
    if os.path.exists(d):
        return
    if verbose:
        logger.info("Directory {} does not exist; creating...".format(d))
    # exist_ok: guard against race conditions
    os.makedirs(d, exist_ok=True)
336
+
337
def save_config(config, path, verbose=True):
    """Write config to path as indented JSON and return the config unchanged."""
    with open(path, 'w') as fout:
        json.dump(config, fout, indent=2)
    if verbose:
        print("Config saved to file {}".format(path))
    return config
343
+
344
def load_config(path, verbose=True):
    """Read a JSON config from path and return it as a dict."""
    with open(path) as fin:
        config = json.load(fin)
    if verbose:
        print("Config loaded from file {}".format(path))
    return config
350
+
351
def print_config(config):
    """Log every key/value pair of the config, one per line."""
    info = "Running with the following configs:\n"
    for key, value in config.items():
        info += "\t{} : {}\n".format(key, str(value))
    logger.info("\n" + info + "\n")
356
+
357
def normalize_text(text):
    """Return text in Unicode NFD (canonical decomposition) form."""
    return unicodedata.normalize('NFD', text)
359
+
360
def unmap_with_copy(indices, src_tokens, vocab):
    """
    Convert id sequences back to words, optionally copying from src_tokens.

    Non-negative ids are looked up in vocab.id2word; a negative id -k
    encodes position k-1 in the corresponding source sentence and copies
    that source token through verbatim.
    """
    result = []
    for ids, tokens in zip(indices, src_tokens):
        words = []
        for i in ids:
            if i >= 0:
                words.append(vocab.id2word[i])
            else:
                # negative ids encode copy positions: -1 -> 0, -2 -> 1, ...
                words.append(tokens[-i - 1])
        result.append(words)
    return result
375
+
376
def prune_decoded_seqs(seqs):
    """
    Prune decoded sequences after the EOS token.

    Sequences that contain no EOS are returned unchanged.
    """
    out = []
    for s in seqs:
        if constant.EOS in s:
            # bug fix: this previously called s.index(constant.EOS_TOKEN),
            # a different symbol than the one tested for membership above,
            # which would fail whenever the two names differ
            idx = s.index(constant.EOS)
            out.append(s[:idx])
        else:
            out.append(s)
    return out
388
+
389
def prune_hyp(hyp):
    """
    Prune a decoded hypothesis at the first EOS id, if any.
    """
    try:
        return hyp[:hyp.index(constant.EOS_ID)]
    except ValueError:
        # no EOS found; keep the whole hypothesis
        return hyp
398
+
399
def prune(data_list, lens):
    """Truncate each element of data_list to the corresponding length in lens."""
    assert len(data_list) == len(lens)
    return [seq[:length] for seq, length in zip(data_list, lens)]
405
+
406
def sort(packed, ref, reverse=True):
    """
    Sort a series of packed list, according to a ref list.
    Also return the original index before the sort.
    """
    assert (isinstance(packed, tuple) or isinstance(packed, list)) and isinstance(ref, list)
    # prepend the sort key (ref) and the original positions, transpose so each
    # element carries its key, sort, transpose back, then drop the key column;
    # the first returned list is therefore the original indices
    packed = [ref] + [range(len(ref))] + list(packed)
    sorted_packed = [list(t) for t in zip(*sorted(zip(*packed), reverse=reverse))]
    return tuple(sorted_packed[1:])
415
+
416
def unsort(sorted_list, oidx):
    """
    Undo a sort: restore sorted_list to the order given by the original indices oidx.
    """
    assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices."
    if not sorted_list:
        return []
    # pairing each item with its original index and sorting by that index
    # puts the items back in their pre-sort order
    pairs = sorted(zip(oidx, sorted_list))
    return [item for _, item in pairs]
425
+
426
def sort_with_indices(data, key=None, reverse=False):
    """
    Sort data and return both the data and the original indices.

    One useful application is to sort by length, which can be done with key=len
    Returns the data as a sorted tuple, then the indices of the original list.
    """
    if not data:
        return [], []
    # sort (index, value) pairs by the value (optionally through key)
    sort_key = (lambda pair: key(pair[1])) if key else (lambda pair: pair[1])
    ordered = sorted(enumerate(data), key=sort_key, reverse=reverse)
    indices, values = zip(*ordered)
    return values, indices
442
+
443
def split_into_batches(data, batch_size):
    """
    Returns a list of intervals so that each interval is either <= batch_size or one element long.

    Long elements are not dropped from the intervals.
    data is a list of lists
    batch_size is how long to make each batch
    return value is a list of pairs, start_idx end_idx
    """
    intervals = []
    interval_start = 0
    interval_size = 0   # total length of the lines accumulated since interval_start
    for idx, line in enumerate(data):
        if len(line) > batch_size:
            # guess we'll just hope the model can handle a batch of this size after all
            # flush whatever was accumulated, then give this line its own interval
            if interval_size > 0:
                intervals.append((interval_start, idx))
            intervals.append((idx, idx+1))
            interval_start = idx+1
            interval_size = 0
        elif len(line) + interval_size > batch_size:
            # this line puts us over batch_size
            # close the current interval and start a new one at this line
            intervals.append((interval_start, idx))
            interval_start = idx
            interval_size = len(line)
        else:
            interval_size = interval_size + len(line)
    if interval_size > 0:
        # there's some leftover
        intervals.append((interval_start, len(data)))
    return intervals
474
+
475
def tensor_unsort(sorted_tensor, oidx):
    """
    Unsort a sorted tensor on its 0-th dimension, based on the original indices.
    """
    assert sorted_tensor.size(0) == len(oidx), "Number of list elements must match with original indices."
    # invert the permutation: for each original position, find where it went
    backidx = [pos for pos, _ in sorted(enumerate(oidx), key=lambda pair: pair[1])]
    return sorted_tensor[backidx]
482
+
483
+
484
def set_random_seed(seed):
    """
    Set a random seed on all of the things which might need it.
    torch, np, python random, and torch.cuda

    If seed is None, a seed is chosen at random (and returned) so that the
    run can still be reproduced after the fact.  Returns the seed used.
    """
    if seed is None:
        seed = random.randint(0, 1000000000)

    # the original called torch.manual_seed twice; the duplicate was removed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all seeds every GPU; the single-device call is kept
        # alongside it as in the original (probably redundant)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    return seed
501
+
502
def find_missing_tags(known_tags, test_tags):
    """
    Return the sorted tags from test_tags that never occur in known_tags.

    Either argument may be a flat list of tags or a list of lists (sentences).
    """
    if isinstance(known_tags, list) and isinstance(known_tags[0], list):
        known_tags = {tag for sent in known_tags for tag in sent}
    if isinstance(test_tags, list) and isinstance(test_tags[0], list):
        test_tags = sorted({tag for sent in test_tags for tag in sent})
    return sorted(tag for tag in test_tags if tag not in known_tags)
509
+
510
def warn_missing_tags(known_tags, test_tags, test_set_name):
    """
    Print a warning if any tags present in the second list are not in the first list.

    Can also handle a list of lists.  Returns True if anything was missing.
    """
    missing_tags = find_missing_tags(known_tags, test_tags)
    if not missing_tags:
        return False
    logger.warning("Found tags in {} missing from the expected tag set: {}".format(test_set_name, missing_tags))
    return True
521
+
522
def checkpoint_name(save_dir, save_name, checkpoint_name):
    """
    Will return a recommended checkpoint name for the given dir, save_name, optional checkpoint_name

    For example, can pass in args['save_dir'], args['save_name'], args['checkpoint_save_name']
    """
    if checkpoint_name:
        # an explicit checkpoint name wins; anchor it under save_dir
        # unless it already lives there
        if os.path.split(checkpoint_name)[0] == save_dir:
            return checkpoint_name
        return os.path.join(save_dir, checkpoint_name)

    if os.path.split(save_name)[0] != save_dir:
        save_name = os.path.join(save_dir, save_name)
    # derive the checkpoint name from the model file name
    if save_name.endswith(".pt"):
        return save_name[:-3] + "_checkpoint.pt"
    return save_name + "_checkpoint"
541
+
542
def default_device():
    """
    Pick a default device based on what's available on this system
    """
    return 'cuda' if torch.cuda.is_available() else 'cpu'
549
+
550
def add_device_args(parser):
    """
    Add args which specify cpu, cuda, or arbitrary device
    """
    # --device takes any torch device string; --cuda/--cpu are shortcuts that
    # overwrite the same dest, so the last flag given wins
    parser.add_argument('--device', type=str, default=default_device(), help='Which device to run on - use a torch device string name')
    parser.add_argument('--cuda', dest='device', action='store_const', const='cuda', help='Run on CUDA')
    parser.add_argument('--cpu', dest='device', action='store_const', const='cpu', help='Ignore CUDA and run on CPU')
557
+
558
def load_elmo(elmo_model):
    """
    Load an elmoformanylangs Embedder from the given model path.

    The import is deferred so that Elmo support remains an optional dependency.
    """
    # This import is here so that Elmo integration can be treated
    # as an optional feature
    import elmoformanylangs

    logger.info("Loading elmo: %s" % elmo_model)
    elmo_model = elmoformanylangs.Embedder(elmo_model)
    return elmo_model
566
+
567
def log_training_args(args, args_logger, name="training"):
    """
    For record keeping purposes, log the arguments when training
    """
    if isinstance(args, argparse.Namespace):
        args = vars(args)
    # one "key: value" line per argument, sorted for stable output
    log_lines = ['%s: %s' % (key, args[key]) for key in sorted(args.keys())]
    args_logger.info('ARGS USED AT %s TIME:\n%s\n', name.upper(), '\n'.join(log_lines))
576
+
577
def embedding_name(args):
    """
    Return the generic name of the biggest embedding used by a model.

    Used by POS and depparse, for example.

    TODO: Probably will make the transformer names a bit more informative,
    such as electra, roberta, etc. Maybe even phobert for VI, for example
    """
    # priority: transformer > charlm > word vectors; expressed here as
    # early returns from highest priority downwards
    if args['bert_model']:
        if args['bert_model'] in TRANSFORMER_NICKNAMES:
            return TRANSFORMER_NICKNAMES[args['bert_model']]
        return "transformer"
    if args.get('charlm', True) and (args['charlm_forward_file'] or args['charlm_backward_file']):
        return "charlm"
    if args['wordvec_pretrain_file'] is None and args['wordvec_file'] is None:
        return "nopretrain"
    return "nocharlm"
598
+
599
def standard_model_file_name(args, model_type, **kwargs):
    """
    Returns a model file name based on some common args found in the various models.

    The expectation is that the args will have something like

    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_parser.pt", help="File name to save the model")

    Then the model shorthand, embedding type, and other args will be
    turned into arguments in a format string
    """
    embedding = embedding_name(args)

    finetune = ""
    transformer_lr = ""
    if args.get("bert_finetune", False):
        finetune = "finetuned"
        if "bert_learning_rate" in args:
            transformer_lr = "{}".format(args["bert_learning_rate"])

    use_peft = "nopeft"
    if args.get("bert_finetune", False) and args.get("use_peft", False):
        use_peft = "peft"

    bert_finetuning = ""
    if args.get("bert_finetune", False):
        if args.get("use_peft", False):
            bert_finetuning = "peft"
        else:
            bert_finetuning = "ft"

    seed = args.get('seed', None)
    if seed is None:
        seed = ""
    else:
        seed = str(seed)

    format_args = {
        "batch_size": args['batch_size'],
        "bert_finetuning": bert_finetuning,
        "embedding": embedding,
        "finetune": finetune,
        "peft": use_peft,
        "seed": seed,
        "shorthand": args['shorthand'],
        "transformer_lr": transformer_lr,
    }
    format_args.update(**kwargs)
    model_file = args['save_name'].format(**format_args)
    # collapse runs of underscores left behind by empty format fields
    model_file = re.sub("_+", "_", model_file)

    # dead local `model_dir = os.path.split(model_file)[0]` removed: it was
    # computed here but never used

    # prefer an existing file outside save_dir if one matches exactly;
    # otherwise anchor the name under save_dir
    if not os.path.exists(os.path.join(args['save_dir'], model_file)) and os.path.exists(model_file):
        return model_file
    return os.path.join(args['save_dir'], model_file)
655
+
656
def escape_misc_space(space):
    """
    Escape whitespace and MISC metacharacters for storage in a CoNLL-U MISC field.

    Inverse of unescape_misc_space.
    """
    # one-to-one character replacements; anything not listed passes through
    replacements = {
        ' ': '\\s',
        '\t': '\\t',
        '\r': '\\r',
        '\n': '\\n',
        '|': '\\p',
        '\\': '\\\\',
        '\u00A0': '\\u00A0',
    }
    return "".join(replacements.get(char, char) for char in space)
677
+
678
def unescape_misc_space(misc_space):
    """
    Decode a MISC-escaped whitespace string back to raw characters.

    Inverse of escape_misc_space.
    """
    # escape codes checked in this exact order at each position; unmatched
    # characters pass through unchanged
    codes = [('\\s', ' '), ('\\t', '\t'), ('\\r', '\r'), ('\\n', '\n'),
             ('\\p', '|'), ('\\\\', '\\'), ('\\u00A0', '\u00A0')]
    chars = []
    pos = 0
    length = len(misc_space)
    while pos < length:
        for code, replacement in codes:
            if misc_space.startswith(code, pos):
                chars.append(replacement)
                pos += len(code)
                break
        else:
            chars.append(misc_space[pos])
            pos += 1
    return "".join(chars)
708
+
709
def space_before_to_misc(space):
    """
    Convert whitespace to SpacesBefore specifically for the start of a document.

    In general, UD datasets do not have both SpacesAfter on a token and SpacesBefore on the next token.

    The space(s) are only marked on one of the tokens.

    Only at the very beginning of a document is it necessary to mark what spaces occurred before the actual text,
    and the default assumption is that there is no space if there is no SpacesBefore annotation.
    """
    if not space:
        return ""
    return "SpacesBefore=%s" % escape_misc_space(space)
724
+
725
def space_after_to_misc(space):
    """
    Convert whitespace back to the escaped format - either SpaceAfter=No or SpacesAfter=...
    """
    if not space:
        return "SpaceAfter=No"
    if space == " ":
        # a single plain space is the default and needs no annotation
        return ""
    return "SpacesAfter=%s" % escape_misc_space(space)
735
+
736
def misc_to_space_before(misc):
    """
    Find any SpacesBefore annotation in the MISC column and turn it into a space value
    """
    if not misc:
        return ""
    for piece in misc.split("|"):
        if piece.lower().startswith("spacesbefore="):
            # only split on the first '='; the value itself may contain '='
            return unescape_misc_space(piece.split("=", maxsplit=1)[1])
    return ""
749
+
750
def misc_to_space_after(misc):
    """
    Convert either SpaceAfter=No or the SpacesAfter annotation

    see https://universaldependencies.org/misc.html#spacesafter

    We compensate for some treebanks using SpaceAfter=\n instead of SpacesAfter=\n
    On the way back, though, those annotations will be turned into SpacesAfter
    """
    if not misc:
        return " "
    pieces = misc.split("|")
    if any(piece.lower() == "spaceafter=no" for piece in pieces):
        return ""
    if "SpaceAfter=Yes" in pieces:
        # as of UD 2.11, the Cantonese treebank had this as a misc feature
        return " "
    if "SpaceAfter=No~" in pieces:
        # as of UD 2.11, a weird typo in the Russian Taiga dataset
        return ""
    for piece in pieces:
        if piece.startswith(("SpaceAfter=", "SpacesAfter=")):
            return unescape_misc_space(piece.split("=", maxsplit=1)[1])
    return " "
775
+
776
def log_norms(model):
    """
    Log the L2 norm and element count of every trainable parameter in the model.

    Useful for eyeballing whether training is diverging or some weights are dead.
    """
    pieces = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            pieces.append((name, "%.6g" % torch.norm(param).item(), "%d" % param.numel()))
    lines = ["NORMS FOR MODEL PARAMETERS"]   # fixed typo: was "PARAMTERS"
    # guard: the max() calls below raise ValueError on an empty sequence,
    # so a model with no trainable parameters now logs just the header
    if pieces:
        name_len = max(len(x[0]) for x in pieces)
        norm_len = max(len(x[1]) for x in pieces)
        line_format = " %-" + str(name_len) + "s %" + str(norm_len) + "s %s"
        for line in pieces:
            lines.append(line_format % line)
    logger.info("\n".join(lines))
788
+
789
def attach_bert_model(model, bert_model, bert_tokenizer, use_peft, force_bert_saved):
    """
    Attach a transformer and its tokenizer to a model, deciding whether its weights are saved.

    - use_peft: attach as an unsaved module (PEFT has its own save path) and
      leave it in train mode
    - force_bert_saved: attach as a regular submodule so the weights are
      serialized with the model (e.g. when finetuning)
    - otherwise: attach unsaved and freeze all transformer parameters
    The tokenizer is always attached as an unsaved module.
    """
    if use_peft:
        # we use a peft-specific pathway for saving peft weights
        model.add_unsaved_module('bert_model', bert_model)
        model.bert_model.train()
    elif force_bert_saved:
        model.bert_model = bert_model
    elif bert_model is not None:
        model.add_unsaved_module('bert_model', bert_model)
        # frozen transformer: no gradients flow into it
        for _, parameter in bert_model.named_parameters():
            parameter.requires_grad = False
    else:
        model.bert_model = None
    model.add_unsaved_module('bert_tokenizer', bert_tokenizer)
803
+
804
def build_save_each_filename(base_filename):
    """
    If the given name doesn't have %d in it, add %04d at the end of the filename

    This way, there's something to count how many models have been saved
    """
    try:
        # a successful % formatting means the name already has a counter slot
        base_filename % 1
    except TypeError:
        # bug fix: this branch previously referenced an undefined name
        # (model_save_each_file), raising NameError instead of building the name
        # so models.pt -> models_%04d.pt, etc
        pieces = os.path.splitext(base_filename)
        base_filename = pieces[0] + "_%04d" + pieces[1]
    return base_filename
stanza/stanza/models/common/vocab.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import copy
2
+ from collections import Counter, OrderedDict
3
+ from collections.abc import Iterable
4
+ import os
5
+ import pickle
6
+
7
# Special symbols shared by all vocabularies.  Their ids are fixed by their
# position in VOCAB_PREFIX, so the *_ID constants must stay in sync with it.
PAD = '<PAD>'
PAD_ID = 0
UNK = '<UNK>'
UNK_ID = 1
EMPTY = '<EMPTY>'
EMPTY_ID = 2
ROOT = '<ROOT>'
ROOT_ID = 3
VOCAB_PREFIX = [PAD, UNK, EMPTY, ROOT]
VOCAB_PREFIX_SIZE = len(VOCAB_PREFIX)
+
18
class BaseVocab:
    """ A base class for common vocabulary operations. Each subclass should at least
    implement its own build_vocab() function."""
    def __init__(self, data=None, lang="", idx=0, cutoff=0, lower=False):
        # data: training examples the vocab is built from (format is subclass-specific)
        # idx: which field of each word/record to read when building
        # cutoff: minimum frequency for a unit to be kept (used by e.g. CharVocab)
        # lower: lowercase units before lookup
        self.data = data
        self.lang = lang
        self.idx = idx
        self.cutoff = cutoff
        self.lower = lower
        if data is not None:
            self.build_vocab()
        # attributes serialized by state_dict(); _unit2id/_id2unit come from build_vocab
        self.state_attrs = ['lang', 'idx', 'cutoff', 'lower', '_unit2id', '_id2unit']

    def build_vocab(self):
        raise NotImplementedError("This BaseVocab does not have build_vocab implemented. This method should create _id2unit and _unit2id")

    def state_dict(self):
        """ Returns a dictionary containing all states that are necessary to recover
        this vocab. Useful for serialization."""
        state = OrderedDict()
        for attr in self.state_attrs:
            if hasattr(self, attr):
                state[attr] = getattr(self, attr)
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """ Returns a new Vocab instance constructed from a state dict. """
        new = cls()
        for attr, value in state_dict.items():
            setattr(new, attr, value)
        return new

    def normalize_unit(self, unit):
        """Apply the vocab's normalization (currently just lowercasing) to a unit."""
        # be sure to look in subclasses for other normalization being done
        # especially PretrainWordVocab
        if unit is None:
            return unit
        if self.lower:
            return unit.lower()
        return unit

    def unit2id(self, unit):
        """Map a unit to its id, falling back to the UNK id for unseen units."""
        unit = self.normalize_unit(unit)
        if unit in self._unit2id:
            return self._unit2id[unit]
        else:
            return self._unit2id[UNK]

    def id2unit(self, id):
        """Map an id back to its unit."""
        return self._id2unit[id]

    def map(self, units):
        """Convert a sequence of units to a list of ids."""
        return [self.unit2id(x) for x in units]

    def unmap(self, ids):
        """Convert a sequence of ids back to a list of units."""
        return [self.id2unit(x) for x in ids]

    def __str__(self):
        lang_str = "(%s)" % self.lang if self.lang else ""
        name = str(type(self)) + lang_str
        return "<%s: %s>" % (name, self._id2unit)

    def __len__(self):
        return len(self._id2unit)

    def __getitem__(self, key):
        # str -> id lookup; int (or list, for composite subclasses) -> unit lookup
        if isinstance(key, str):
            return self.unit2id(key)
        elif isinstance(key, int) or isinstance(key, list):
            return self.id2unit(key)
        else:
            raise TypeError("Vocab key must be one of str, list, or int")

    def __contains__(self, key):
        return self.normalize_unit(key) in self._unit2id

    @property
    def size(self):
        return len(self)
+
99
class DeltaVocab(BaseVocab):
    """
    A vocab that starts off with a BaseVocab, then possibly adds more tokens based on the text in the given data

    Currently meant only for characters, such as built by MWT or Lemma

    Expected data format is either a list of strings, or a list of list of strings
    """
    def __init__(self, data, orig_vocab):
        # keep the original vocab so its ids stay stable; new characters are
        # appended after the originals
        self.orig_vocab = orig_vocab
        super().__init__(data=data, lang=orig_vocab.lang, idx=orig_vocab.idx, cutoff=orig_vocab.cutoff, lower=orig_vocab.lower)

    def build_vocab(self):
        # flatten the data into one string of characters, handling both the
        # list-of-strings and list-of-list-of-strings formats
        if all(isinstance(word, str) for word in self.data):
            allchars = "".join(self.data)
        else:
            allchars = "".join([word for sentence in self.data for word in sentence])

        # characters not known to the original vocab get appended, preserving
        # every existing id; if nothing is new, the original tables are shared
        unk = [c for c in allchars if c not in self.orig_vocab._unit2id]
        if len(unk) > 0:
            unk = sorted(set(unk))
            self._id2unit = self.orig_vocab._id2unit + unk
            self._unit2id = dict(self.orig_vocab._unit2id)
            for c in unk:
                self._unit2id[c] = len(self._unit2id)
        else:
            self._id2unit = self.orig_vocab._id2unit
            self._unit2id = self.orig_vocab._unit2id
+
128
class CompositeVocab(BaseVocab):
    ''' Vocabulary class that handles parsing and printing composite values such as
    compositional XPOS and universal morphological features (UFeats).

    Two key options are `keyed` and `sep`. `sep` specifies the separator used between
    different parts of the composite values, which is `|` for UFeats, for example.
    If `keyed` is `True`, then the incoming value is treated similarly to UFeats, where
    each part is a key/value pair separated by an equal sign (`=`). There is no inherent
    order to the keys, and we sort them alphabetically for serialization and deserialization.
    Whenever a part is absent, its internal value is a special `<EMPTY>` symbol that will
    be treated accordingly when generating the output. If `keyed` is `False`, then the parts
    are treated as positioned values, and `<EMPTY>` is used to pad parts at the end when the
    incoming value is not long enough.'''

    def __init__(self, data=None, lang="", idx=0, sep="", keyed=False):
        self.sep = sep
        self.keyed = keyed
        super().__init__(data, lang, idx=idx)
        self.state_attrs += ['sep', 'keyed']

    def unit2parts(self, unit):
        # unpack parts of a unit: a dict for keyed vocabs, a list otherwise
        if not self.sep:
            # empty separator means each character is its own part
            parts = [x for x in unit]
        else:
            parts = unit.split(self.sep)
        if self.keyed:
            if len(parts) == 1 and parts[0] == '_':
                return dict()
            parts = [x.split('=') for x in parts]
            if any(len(x) != 2 for x in parts):
                raise ValueError('Received "%s" for a dictionary which is supposed to be keyed, eg the entries should all be of the form key=value and separated by %s' % (unit, self.sep))

            # Just treat multi-valued properties values as one possible value
            parts = dict(parts)
        elif unit == '_':
            # '_' is the CoNLL-U marker for "no value"
            parts = []
        return parts

    def unit2id(self, unit):
        # returns one id per sub-vocab (per key when keyed, per position otherwise)
        parts = self.unit2parts(unit)
        if self.keyed:
            # treat multi-valued properties as singletons
            return [self._unit2id[k].get(parts[k], UNK_ID) if k in parts else EMPTY_ID for k in self._unit2id]
        else:
            return [self._unit2id[i].get(parts[i], UNK_ID) if i < len(parts) else EMPTY_ID for i in range(len(self._unit2id))]

    def id2unit(self, id):
        # special case: allow single ids for vocabs with length 1
        if len(self._id2unit) == 1 and not isinstance(id, Iterable):
            id = (id,)
        items = []
        for v, k in zip(id, self._id2unit.keys()):
            # EMPTY parts are simply omitted from the output
            if v == EMPTY_ID: continue
            if self.keyed:
                items.append("{}={}".format(k, self._id2unit[k][v]))
            else:
                items.append(self._id2unit[k][v])
        if self.sep is not None:
            res = self.sep.join(items)
            if res == "":
                res = "_"
            return res
        else:
            return items

    def build_vocab(self):
        allunits = [w[self.idx] for sent in self.data for w in sent]
        if self.keyed:
            self._id2unit = dict()

            for u in allunits:
                parts = self.unit2parts(u)
                for key in parts:
                    if key not in self._id2unit:
                        self._id2unit[key] = copy(VOCAB_PREFIX)

                    # treat multi-valued properties as singletons
                    if parts[key] not in self._id2unit[key]:
                        self._id2unit[key].append(parts[key])

            # special handle for the case where upos/xpos/ufeats are always empty
            if len(self._id2unit) == 0:
                self._id2unit['_'] = copy(VOCAB_PREFIX) # use an arbitrary key

        else:
            self._id2unit = dict()

            # note: a dead `maxlen` computation and a redundant always-true
            # `i < len(parts)` guard were removed from this branch
            allparts = [self.unit2parts(u) for u in allunits]

            for parts in allparts:
                for i, p in enumerate(parts):
                    if i not in self._id2unit:
                        self._id2unit[i] = copy(VOCAB_PREFIX)
                    if p not in self._id2unit[i]:
                        self._id2unit[i].append(p)

            # special handle for the case where upos/xpos/ufeats are always empty
            if len(self._id2unit) == 0:
                self._id2unit[0] = copy(VOCAB_PREFIX) # use an arbitrary key

        self._id2unit = OrderedDict([(k, self._id2unit[k]) for k in sorted(self._id2unit.keys())])
        self._unit2id = {k: {w:i for i, w in enumerate(self._id2unit[k])} for k in self._id2unit}

    def lens(self):
        """Number of entries in each sub-vocab, in key order."""
        return [len(self._unit2id[k]) for k in self._unit2id]

    def items(self, idx):
        """The unit list for a single key / position."""
        return self._id2unit[idx]

    def __str__(self):
        pieces = ["[" + ",".join(x) + "]" for _, x in self._id2unit.items()]
        rep = "<{}:\n {}>".format(type(self), "\n ".join(pieces))
        return rep
+
244
class BaseMultiVocab:
    """ A convenient vocab container that can store multiple BaseVocab instances, and support
    safe serialization of all instances via state dicts. Each subclass of this base class
    should implement the load_state_dict() function to specify how a saved state dict
    should be loaded back."""
    def __init__(self, vocab_dict=None):
        # insertion order is preserved so state_dict round-trips deterministically
        self._vocabs = OrderedDict()
        if vocab_dict is None:
            return
        # check all values provided must be a subclass of the Vocab base class
        assert all([isinstance(v, BaseVocab) for v in vocab_dict.values()])
        for k, v in vocab_dict.items():
            self._vocabs[k] = v

    def __setitem__(self, key, item):
        self._vocabs[key] = item

    def __getitem__(self, key):
        return self._vocabs[key]

    def __str__(self):
        return "<{}: [{}]>".format(type(self), ", ".join(self._vocabs.keys()))

    def __contains__(self, key):
        return key in self._vocabs

    def keys(self):
        return self._vocabs.keys()

    def state_dict(self):
        """ Build a state dict by iteratively calling state_dict() of all vocabs. """
        state = OrderedDict()
        for k, v in self._vocabs.items():
            state[k] = v.state_dict()
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """ Construct a MultiVocab by reading from a state dict."""
        raise NotImplementedError
+ raise NotImplementedError
284
+
285
+
286
+
287
class CharVocab(BaseVocab):
    """A character-level vocab; ids are sorted by descending frequency after the VOCAB_PREFIX symbols."""
    def build_vocab(self):
        if isinstance(self.data[0][0], (list, tuple)): # general data from DataLoader
            counter = Counter([c for sent in self.data for w in sent for c in w[self.idx]])
            # the frequency cutoff is only applied to this data format
            for k in list(counter.keys()):
                if counter[k] < self.cutoff:
                    del counter[k]
        else: # special data from Char LM
            counter = Counter([c for sent in self.data for c in sent])
        # sort by (count, char) descending so ids are stable across runs
        self._id2unit = VOCAB_PREFIX + list(sorted(list(counter.keys()), key=lambda k: (counter[k], k), reverse=True))
        self._unit2id = {w:i for i, w in enumerate(self._id2unit)}
298
+
stanza/stanza/models/constituency/base_model.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The BaseModel is passed to the transitions so that the transitions
3
+ can operate on a parsing state without knowing the exact
4
+ representation used in the model.
5
+
6
+ For example, a SimpleModel simply looks at the top of the various stacks in the state.
7
+
8
+ A model with LSTM representations for the different transitions may
9
+ attach the hidden and output states of the LSTM to the word /
10
+ constituent / transition stacks.
11
+
12
+ Reminder: the parsing state is a list of words to parse, the
13
+ transitions used to build a (possibly incomplete) parse, and the
14
+ constituent(s) built so far by those transitions. Each of these
15
+ components are represented using stacks to improve the efficiency
16
+ of operations such as "combine the most recent 4 constituents"
17
+ or "turn the next input word into a constituent"
18
+ """
19
+
20
+ from abc import ABC, abstractmethod
21
+ from collections import defaultdict
22
+ import logging
23
+
24
+ import torch
25
+
26
+ from stanza.models.common import utils
27
+ from stanza.models.constituency import transition_sequence
28
+ from stanza.models.constituency.parse_transitions import TransitionScheme, CloseConstituent
29
+ from stanza.models.constituency.parse_tree import Tree
30
+ from stanza.models.constituency.state import State
31
+ from stanza.models.constituency.tree_stack import TreeStack
32
+ from stanza.server.parser_eval import ParseResult, ScoredTree
33
+
34
+ # default unary limit. some treebanks may have longer chains (CTB, for example)
35
+ UNARY_LIMIT = 4
36
+
37
+ logger = logging.getLogger('stanza.constituency.trainer')
38
+
39
+ class BaseModel(ABC):
40
+ """
41
+ This base class defines abstract methods for manipulating a State.
42
+
43
+ Applying transitions may change important metadata about a State
44
+ such as the vectors associated with LSTM hidden states, for example.
45
+
46
+ The constructor forwards all unused arguments to other classes in the
47
+ constructor sequence, so put this before other classes such as nn.Module
48
+ """
49
def __init__(self, transition_scheme, unary_limit, reverse_sentence, root_labels, *args, **kwargs):
    """Record the parsing configuration shared by all models.

    Unused positional/keyword arguments are forwarded to the next class in
    the MRO so this mixin can sit in front of classes such as nn.Module.
    """
    super().__init__(*args, **kwargs)  # forwards all unused arguments

    self._transition_scheme = transition_scheme
    self._unary_limit = unary_limit
    self._reverse_sentence = reverse_sentence
    self._root_labels = sorted(root_labels)

    # `in` on a tuple of Enum members compares with ==, which for Enum is
    # identity — equivalent to the chained `is` comparisons
    self._is_top_down = self._transition_scheme in (TransitionScheme.TOP_DOWN,
                                                    TransitionScheme.TOP_DOWN_UNARY,
                                                    TransitionScheme.TOP_DOWN_COMPOUND)
60
+
61
+ @abstractmethod
62
+ def initial_word_queues(self, tagged_word_lists):
63
+ """
64
+ For each list of tagged words, builds a TreeStack of word nodes
65
+
66
+ The word lists should be backwards so that the first word is the last word put on the stack (LIFO)
67
+ """
68
+
69
+ @abstractmethod
70
+ def initial_transitions(self):
71
+ """
72
+ Builds an initial transition stack with whatever values need to go into first position
73
+ """
74
+
75
+ @abstractmethod
76
+ def initial_constituents(self):
77
+ """
78
+ Builds an initial constituent stack with whatever values need to go into first position
79
+ """
80
+
81
+ @abstractmethod
82
+ def get_word(self, word_node):
83
+ """
84
+ Get the word corresponding to this position in the word queue
85
+ """
86
+
87
+ @abstractmethod
88
+ def transform_word_to_constituent(self, state):
89
+ """
90
+ Transform the top node of word_queue to something that can push on the constituent stack
91
+ """
92
+
93
+ @abstractmethod
94
+ def dummy_constituent(self, dummy):
95
+ """
96
+ When using a dummy node as a sentinel, transform it to something usable by this model
97
+ """
98
+
99
+ @abstractmethod
100
+ def build_constituents(self, labels, children_lists):
101
+ """
102
+ Build multiple constituents at once. This gives the opportunity for batching operations
103
+ """
104
+
105
+ @abstractmethod
106
+ def push_constituents(self, constituent_stacks, constituents):
107
+ """
108
+ Add a multiple constituents to multiple constituent_stacks
109
+
110
+ Useful to factor this out in case batching will help
111
+ """
112
+
113
+ @abstractmethod
114
+ def get_top_constituent(self, constituents):
115
+ """
116
+ Get the first constituent from the constituent stack
117
+
118
+ For example, a model might want to remove embeddings and LSTM state vectors
119
+ """
120
+
121
+ @abstractmethod
122
+ def push_transitions(self, transition_stacks, transitions):
123
+ """
124
+ Add a multiple transitions to multiple transition_stacks
125
+
126
+ Useful to factor this out in case batching will help
127
+ """
128
+
129
+ @abstractmethod
130
+ def get_top_transition(self, transitions):
131
+ """
132
+ Get the first transition from the transition stack
133
+
134
+ For example, a model might want to remove transition embeddings before returning the transition
135
+ """
136
+
137
+ @property
138
+ def root_labels(self):
139
+ """
140
+ Return ROOT labels for this model. Probably ROOT, TOP, or both
141
+
142
+ (Danish uses 's', though)
143
+ """
144
+ return self._root_labels
145
+
146
+ def unary_limit(self):
147
+ """
148
+ Limit on the number of consecutive unary transitions
149
+ """
150
+ return self._unary_limit
151
+
152
+
153
+ def transition_scheme(self):
154
+ """
155
+ Transition scheme used - see parse_transitions
156
+ """
157
+ return self._transition_scheme
158
+
159
+ def has_unary_transitions(self):
160
+ """
161
+ Whether or not this model uses unary transitions, based on transition_scheme
162
+ """
163
+ return self._transition_scheme is TransitionScheme.TOP_DOWN_UNARY
164
+
165
+ @property
166
+ def is_top_down(self):
167
+ """
168
+ Whether or not this model is TOP_DOWN
169
+ """
170
+ return self._is_top_down
171
+
172
+ @property
173
+ def reverse_sentence(self):
174
+ """
175
+ Whether or not this model is built to parse backwards
176
+ """
177
+ return self._reverse_sentence
178
+
179
+ def predict(self, states, is_legal=True):
180
+ raise NotImplementedError("LSTMModel can predict, but SimpleModel cannot")
181
+
182
+ def weighted_choice(self, states):
183
+ raise NotImplementedError("LSTMModel can weighted_choice, but SimpleModel cannot")
184
+
185
def predict_gold(self, states, is_legal=True):
    """
    For each State, return the next item in the gold_sequence

    Returns the same (scores, transitions, logits) triple shape as
    predict(), with no scores attached.
    """
    transitions = [state.gold_sequence[state.num_transitions] for state in states]
    if is_legal:
        for state, trans in zip(states, transitions):
            if trans.is_legal(state, self):
                continue
            raise RuntimeError("Transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(state.num_transitions, trans, state.gold_tree, state.gold_sequence))
    return None, transitions, None
195
+
196
def initial_state_from_preterminals(self, preterminal_lists, gold_trees, gold_sequences):
    """
    what is passed in should be a list of list of preterminals
    """
    word_queues = self.initial_word_queues(preterminal_lists)
    # the bottoms of these TreeStacks are shared between all the States
    transitions = self.initial_transitions()
    constituents = self.initial_constituents()
    states = []
    for wq in word_queues:
        # -2 because the queue starts and ends with a sentinel
        states.append(State(sentence_length=len(wq) - 2,
                            num_opens=0,
                            word_queue=wq,
                            gold_tree=None,
                            gold_sequence=None,
                            transitions=transitions,
                            constituents=constituents,
                            word_position=0,
                            score=0.0))
    if gold_trees:
        states = [state._replace(gold_tree=tree) for tree, state in zip(gold_trees, states)]
    if gold_sequences:
        states = [state._replace(gold_sequence=seq) for seq, state in zip(gold_sequences, states)]
    return states
219
+
220
def initial_state_from_words(self, word_lists):
    """Build initial States from lists of (word, tag) pairs, one list per sentence."""
    preterminal_lists = []
    for words in word_lists:
        preterminal_lists.append([Tree(tag, Tree(word)) for word, tag in words])
    return self.initial_state_from_preterminals(preterminal_lists, gold_trees=None, gold_sequences=None)
224
+
225
def initial_state_from_gold_trees(self, trees, gold_sequences=None):
    """Build initial States from gold trees, keeping the gold tree attached to each State."""
    preterminal_lists = [[Tree(preterminal.label, Tree(preterminal.children[0].label))
                          for preterminal in tree.yield_preterminals()]
                         for tree in trees]
    return self.initial_state_from_preterminals(preterminal_lists, gold_trees=trees, gold_sequences=gold_sequences)
230
+
231
def build_batch_from_trees(self, batch_size, data_iterator):
    """
    Read from the data_iterator batch_size trees and turn them into new parsing states
    """
    state_batch = []
    while len(state_batch) < batch_size:
        gold_tree = next(data_iterator, None)
        if gold_tree is None:
            break
        state_batch.append(gold_tree)

    if state_batch:
        state_batch = self.initial_state_from_gold_trees(state_batch)
    return state_batch
245
+
246
def build_batch_from_trees_with_gold_sequence(self, batch_size, data_iterator):
    """
    Same as build_batch_from_trees, but use the model parameters to turn the trees into gold sequences and include the sequence
    """
    state_batch = self.build_batch_from_trees(batch_size, data_iterator)
    if not state_batch:
        return state_batch

    gold_sequences = transition_sequence.build_treebank([state.gold_tree for state in state_batch], self.transition_scheme(), self.reverse_sentence)
    return [state._replace(gold_sequence=sequence)
            for state, sequence in zip(state_batch, gold_sequences)]
257
+
258
def build_batch_from_tagged_words(self, batch_size, data_iterator):
    """
    Read from the data_iterator batch_size tagged sentences and turn them into new parsing states

    Expects a list of list of (word, tag)
    """
    state_batch = []
    while len(state_batch) < batch_size:
        sentence = next(data_iterator, None)
        if sentence is None:
            break
        state_batch.append(sentence)

    if state_batch:
        state_batch = self.initial_state_from_words(state_batch)
    return state_batch
274
+
275
+
276
def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
    """
    Repeat transitions to build a list of trees from the input batches.

    data_iterator yields the raw items to parse; build_batch_fn turns up
    to batch_size of them into State objects.  Batches are refilled as
    individual parses finish, so the batch stays full until the iterator
    is exhausted.

    transition_choice: which method of the model chooses the next
    transition (predict for the model's prediction, predict_gold to use
    the gold sequence).

    Returns a list of ParseResult, one per input, restored to input order.
    If keep_scores is true, each result's score is the sum of the values
    returned by the model for its transitions.
    """
    treebank = []
    treebank_indices = []
    state_batch = build_batch_fn(batch_size, data_iterator)
    # parses finish at different times; remembering each state's original
    # index lets us unsort the finished treebank at the end
    batch_indices = list(range(len(state_batch)))
    horizon_iterator = iter([])

    if keep_constituents:
        constituents = defaultdict(list)

    while state_batch:
        pred_scores, transitions, scores = transition_choice(state_batch)
        if keep_scores and scores is not None:
            state_batch = [state._replace(score=state.score + score)
                           for state, score in zip(state_batch, scores)]
        state_batch = self.bulk_apply(state_batch, transitions)

        if keep_constituents:
            for t_idx, transition in enumerate(transitions):
                if isinstance(transition, CloseConstituent):
                    # state.constituents is a TreeStack; .value is the stack
                    # node and .value.value the Constituent itself (tree +
                    # embedding)
                    constituents[batch_indices[t_idx]].append(state_batch[t_idx].constituents.value.value)

        finished = set()
        for idx, state in enumerate(state_batch):
            if not state.finished(self):
                continue
            predicted_tree = state.get_tree(self)
            if self.reverse_sentence:
                predicted_tree = predicted_tree.reverse()
            treebank.append(ParseResult(state.gold_tree,
                                        [ScoredTree(predicted_tree, state.score)],
                                        state if keep_state else None,
                                        constituents[batch_indices[idx]] if keep_constituents else None))
            treebank_indices.append(batch_indices[idx])
            finished.add(idx)

        if finished:
            state_batch = [state for idx, state in enumerate(state_batch) if idx not in finished]
            batch_indices = [b_idx for idx, b_idx in enumerate(batch_indices) if idx not in finished]

        # top the batch back up from the horizon iterator / data iterator
        for _ in range(batch_size - len(state_batch)):
            horizon_state = next(horizon_iterator, None)
            if not horizon_state:
                horizon_batch = build_batch_fn(batch_size, data_iterator)
                if len(horizon_batch) == 0:
                    break
                horizon_iterator = iter(horizon_batch)
                horizon_state = next(horizon_iterator, None)

            state_batch.append(horizon_state)
            batch_indices.append(len(treebank) + len(state_batch))

    return utils.unsort(treebank, treebank_indices)
347
+
348
def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
    """
    Inference-time wrapper around parse_sentences.

    Runs under torch.no_grad() so gradients are not tracked, which makes
    the model faster and lighter on memory at inference time.
    """
    with torch.no_grad():
        return self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice, keep_state, keep_constituents, keep_scores)
357
+
358
def analyze_trees(self, trees, batch_size=None, keep_state=True, keep_constituents=True, keep_scores=True):
    """
    Return a ParseResult for each tree in the trees list

    The transitions run are exactly those represented by each gold tree;
    the model's output layers are available in result.state.

    keep_state defaults to True here because a method which keeps the
    grad is likely to want the resulting state as well.
    """
    if batch_size is None:
        # TODO: refactor?
        batch_size = self.args['eval_batch_size']
    return self.parse_sentences(iter(trees), self.build_batch_from_trees_with_gold_sequence, batch_size, self.predict_gold, keep_state, keep_constituents, keep_scores=keep_scores)
374
+
375
def parse_tagged_words(self, words, batch_size):
    """
    This parses tagged words and returns a list of trees.

    `parse_tagged_words` is useful at Pipeline time -
    it takes words & tags and processes that into trees.

    The tagged words should be represented:
      one list per sentence
      each sentence is a list of (word, tag)
    The return value is a list of ParseTree objects
    """
    logger.debug("Processing %d sentences", len(words))
    # inference only: switch off dropout etc
    self.eval()

    treebank = self.parse_sentences_no_grad(iter(words), self.build_batch_from_tagged_words, batch_size, self.predict, keep_state=False, keep_constituents=False)

    # each ParseResult carries one ScoredTree; return just the trees
    return [result.predictions[0].tree for result in treebank]
395
+
396
def bulk_apply(self, state_batch, transitions, fail=False):
    """
    Apply the given list of Transitions to the given list of States, using the model as a reference

    state_batch: list of States
    transitions: list of transitions, one per state
    fail: throw an exception on a failed transition, as opposed to skipping the tree

    Returns the updated list of States.  States with a missing transition
    or a runaway transition count are dropped (or raise if fail=True).
    """
    dropped = set()

    word_positions = []
    constituent_stacks = []
    new_constituents = []
    callbacks = defaultdict(list)

    for idx, (state, transition) in enumerate(zip(state_batch, transitions)):
        if not transition:
            error = "Got stuck and couldn't find a legal transition on the following gold tree:\n{}\n\nFinal state:\n{}".format(state.gold_tree, state.to_string(self))
            if fail:
                raise ValueError(error)
            logger.error(error)
            dropped.add(idx)
            continue

        # x20 is somewhat empirically chosen based on certain treebanks
        # having deep unary structures, especially early on when the
        # model is fumbling around
        if state.num_transitions >= len(state.word_queue) * 20:
            if state.gold_tree:
                error = "Went infinite on the following gold tree:\n{}\n\nFinal state:\n{}".format(state.gold_tree, state.to_string(self))
            else:
                error = "Went infinite!:\nFinal state:\n{}".format(state.to_string(self))
            if fail:
                raise ValueError(error)
            logger.error(error)
            dropped.add(idx)
            continue

        wq, stack, constituent, callback = transition.update_state(state, self)

        word_positions.append(wq)
        constituent_stacks.append(stack)
        new_constituents.append(constituent)
        if callback:
            # indexed by position in new_constituents, not `idx`,
            # in case something was dropped above
            callbacks[callback].append(len(new_constituents) - 1)

    # batch-build the constituents which asked for a callback
    for key, positions in callbacks.items():
        built = key.build_constituents(self, [new_constituents[x] for x in positions])
        for position, constituent in zip(positions, built):
            new_constituents[position] = constituent

    if dropped:
        state_batch = [state for idx, state in enumerate(state_batch) if idx not in dropped]
        transitions = [trans for idx, trans in enumerate(transitions) if idx not in dropped]

    if not state_batch:
        return state_batch

    new_transitions = self.push_transitions([state.transitions for state in state_batch], transitions)
    new_constituents = self.push_constituents(constituent_stacks, new_constituents)

    return [state._replace(num_opens=state.num_opens + transition.delta_opens(),
                           word_position=word_position,
                           transitions=transition_stack,
                           constituents=constituents)
            for (state, transition, word_position, transition_stack, constituents)
            in zip(state_batch, transitions, word_positions, new_transitions, new_constituents)]
471
+
472
class SimpleModel(BaseModel):
    """
    A model which pushes and pops with no extra data attached.

    Primarily used for testing operations which don't need the NN's
    weights, and for rebuilding trees from transitions when the NN state
    is irrelevant, as this class is much faster than the NN model.
    """
    def __init__(self, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, unary_limit=UNARY_LIMIT, reverse_sentence=False, root_labels=("ROOT",)):
        super().__init__(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse_sentence, root_labels=root_labels)

    def initial_word_queues(self, tagged_word_lists):
        # each queue is sentinel + words + sentinel, reversed wholesale
        # when the model parses the sentence backwards
        word_queues = []
        for tagged_words in tagged_word_lists:
            word_queue = [None] + list(tagged_words) + [None]
            if self.reverse_sentence:
                word_queue.reverse()
            word_queues.append(word_queue)
        return word_queues

    def initial_transitions(self):
        return TreeStack(value=None, parent=None, length=1)

    def initial_constituents(self):
        return TreeStack(value=None, parent=None, length=1)

    def get_word(self, word_node):
        return word_node

    def transform_word_to_constituent(self, state):
        return state.get_word(state.word_position)

    def dummy_constituent(self, dummy):
        return dummy

    def build_constituents(self, labels, children_lists):
        constituents = []
        for label, children in zip(labels, children_lists):
            if isinstance(label, str):
                label = (label,)
            # wrap the children from the innermost label outwards,
            # producing a (possibly unary) chain of Trees
            for value in reversed(label):
                children = Tree(label=value, children=children)
            constituents.append(children)
        return constituents

    def push_constituents(self, constituent_stacks, constituents):
        return [stack.push(constituent)
                for stack, constituent in zip(constituent_stacks, constituents)]

    def get_top_constituent(self, constituents):
        return constituents.value

    def push_transitions(self, transition_stacks, transitions):
        return [stack.push(transition)
                for stack, transition in zip(transition_stacks, transitions)]

    def get_top_transition(self, transitions):
        return transitions.value
stanza/stanza/models/constituency/base_trainer.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import logging
3
+ import os
4
+
5
+ import torch
6
+
7
+ from pickle import UnpicklingError
8
+ import warnings
9
+
10
+ logger = logging.getLogger('stanza')
11
+
12
class ModelType(Enum):
    """Which flavor of constituency model a checkpoint holds."""
    LSTM = 1
    ENSEMBLE = 2
15
+
16
class BaseTrainer:
    """
    Common checkpointing / bookkeeping for constituency parser trainers.

    Wraps a model plus its optimizer & scheduler and tracks training
    progress (epochs, batches, best dev score) so training can resume
    from a checkpoint.  Subclasses provide model_type, get_peft_params,
    model_from_params, load_optimizer, and load_scheduler.
    """
    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        # keeping track of the epochs trained will be useful
        # for adjusting the learning scheme
        self.epochs_trained = epochs_trained
        self.batches_trained = batches_trained
        self.best_f1 = best_f1
        self.best_epoch = best_epoch
        self.first_optimizer = first_optimizer

    def save(self, filename, save_optimizer=True):
        """Save model params + training progress (and optionally optimizer/scheduler state) to filename."""
        params = self.model.get_params()
        checkpoint = {
            'params': params,
            'epochs_trained': self.epochs_trained,
            'batches_trained': self.batches_trained,
            'best_f1': self.best_f1,
            'best_epoch': self.best_epoch,
            'model_type': self.model_type.name,
            'first_optimizer': self.first_optimizer,
        }
        checkpoint["bert_lora"] = self.get_peft_params()
        if save_optimizer and self.optimizer is not None:
            checkpoint['optimizer_state_dict'] = self.optimizer.state_dict()
            # a trainer can have an optimizer but no scheduler; previously
            # this raised AttributeError when self.scheduler was None
            if self.scheduler is not None:
                checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        torch.save(checkpoint, filename, _use_new_zipfile_serialization=False)
        logger.info("Model saved to %s", filename)

    def log_norms(self):
        self.model.log_norms()

    def log_shapes(self):
        self.model.log_shapes()

    @property
    def transitions(self):
        return self.model.transitions

    @property
    def root_labels(self):
        return self.model.root_labels

    @property
    def device(self):
        return next(self.model.parameters()).device

    def train(self):
        return self.model.train()

    def eval(self):
        return self.model.eval()

    # TODO: make ABC with methods such as model_from_params?
    # TODO: if we save the type in the checkpoint, use that here to figure out which to load
    @staticmethod
    def load(filename, args=None, load_optimizer=False, foundation_cache=None, peft_name=None):
        """
        Load back a model and possibly its optimizer.

        If filename does not exist, args['save_dir'] is tried as a
        fallback directory before giving up with FileNotFoundError.
        """
        # hide the import here to avoid circular imports
        from stanza.models.constituency.ensemble import EnsembleTrainer
        from stanza.models.constituency.trainer import Trainer

        if not os.path.exists(filename):
            # guard args being None: previously this raised AttributeError
            # instead of the intended FileNotFoundError
            save_dir = args.get('save_dir', None) if args is not None else None
            if save_dir is None:
                raise FileNotFoundError("Cannot find model in {} and args['save_dir'] is None".format(filename))
            elif os.path.exists(os.path.join(save_dir, filename)):
                filename = os.path.join(save_dir, filename)
            else:
                raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(save_dir, filename)))
        try:
            # TODO: currently cannot switch this to weights_only=True
            # without in some way changing the model to save enums in
            # a safe manner, probably by converting to int
            try:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
            except UnpicklingError as e:
                checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
                warnings.warn("The saved constituency parser has an old format using Enum, set, unsanitized Transitions, etc. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the constituency parser using this version ASAP.")
        except BaseException:
            logger.exception("Cannot load model from %s", filename)
            raise
        logger.debug("Loaded model from %s", filename)

        params = checkpoint['params']

        if 'model_type' not in checkpoint:
            # old models will have this trait
            # TODO: can remove this after 1.10
            checkpoint['model_type'] = ModelType.LSTM
        if isinstance(checkpoint['model_type'], str):
            checkpoint['model_type'] = ModelType[checkpoint['model_type']]
        if checkpoint['model_type'] == ModelType.LSTM:
            clazz = Trainer
        elif checkpoint['model_type'] == ModelType.ENSEMBLE:
            clazz = EnsembleTrainer
        else:
            raise ValueError("Unexpected model type: %s" % checkpoint['model_type'])
        model = clazz.model_from_params(params, checkpoint.get('bert_lora', None), args, foundation_cache, peft_name)

        epochs_trained = checkpoint['epochs_trained']
        batches_trained = checkpoint.get('batches_trained', 0)
        best_f1 = checkpoint['best_f1']
        best_epoch = checkpoint['best_epoch']

        if 'first_optimizer' not in checkpoint:
            # this will only apply to old (LSTM) Trainers
            # EnsembleTrainers will always have this value saved
            # so here we can compensate by looking at the old training statistics...
            # we use params['config'] here instead of model.args
            # because the args might have a different training
            # mechanism, but in order to reload the optimizer, we need
            # to match the optimizer we build with the one that was
            # used at training time
            build_simple_adadelta = params['config']['multistage'] and epochs_trained < params['config']['epochs'] // 2
            checkpoint['first_optimizer'] = build_simple_adadelta
        first_optimizer = checkpoint['first_optimizer']

        if load_optimizer:
            optimizer = clazz.load_optimizer(model, checkpoint, first_optimizer, filename)
            scheduler = clazz.load_scheduler(model, optimizer, checkpoint, first_optimizer)
        else:
            optimizer = None
            scheduler = None

        if checkpoint['model_type'] == ModelType.LSTM:
            logger.debug("-- MODEL CONFIG --")
            for k in model.args.keys():
                logger.debug("  --%s: %s", k, model.args[k])
            return Trainer(model=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch, first_optimizer=first_optimizer)
        elif checkpoint['model_type'] == ModelType.ENSEMBLE:
            return EnsembleTrainer(ensemble=model, optimizer=optimizer, scheduler=scheduler, epochs_trained=epochs_trained, batches_trained=batches_trained, best_f1=best_f1, best_epoch=best_epoch, first_optimizer=first_optimizer)
        else:
            raise ValueError("Unexpected model type: %s" % checkpoint['model_type'])
153
+
stanza/stanza/models/constituency/ensemble.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prototype of ensembling N models together on the same dataset
3
+
4
+ The main inference method is to run the normal transition sequence,
5
+ but sum the scores for the N models and use that to choose the highest
6
+ scoring transition
7
+
8
+ Example of how to run it to build a silver dataset
9
+ (or just parse a text file in general):
10
+
11
+ # first, use this tool to build a saved ensemble
12
+ python3 stanza/models/constituency/ensemble.py
13
+ saved_models/constituency/wsj_inorder_?.pt
14
+ --save_name saved_models/constituency/en_ensemble.pt
15
+
16
+ # then use the ensemble directly as a model in constituency_parser.py
17
+ python3 stanza/models/constituency_parser.py
18
+ --save_name saved_models/constituency/en_ensemble.pt
19
+ --mode parse_text
20
+ --tokenized_file /nlp/scr/horatio/en_silver/en_split_100
21
+ --predict_file /nlp/scr/horatio/en_silver/en_split_100.inorder.mrg
22
+ --retag_package en_combined_bert
23
+ --lang en
24
+
25
+ then, ideally, run a second time with a set of topdown models,
26
+ then take the trees which match from the files
27
+ """
28
+
29
+
30
import argparse
from collections import defaultdict
import copy
import logging
import os

import torch
import torch.nn as nn

from stanza.models.common import utils
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.constituency.base_trainer import BaseTrainer, ModelType
from stanza.models.constituency.state import MultiState
from stanza.models.constituency.trainer import Trainer
from stanza.models.constituency.utils import build_optimizer, build_scheduler
from stanza.server.parser_eval import ParseResult, ScoredTree
45
+
46
+ logger = logging.getLogger('stanza.constituency.trainer')
47
+
48
class Ensemble(nn.Module):
    """
    An ensemble of N constituency parser models over the same transition scheme.

    At each parsing step every submodel scores the candidate transitions;
    the per-model scores are summed (plus a learned per-model,
    per-transition weighting) and the highest scoring legal transition is
    applied to every submodel's state in lockstep.
    """
    def __init__(self, args, filenames=None, models=None, foundation_cache=None):
        """
        Loads each model in filenames

        If foundation_cache is None, we build one on our own,
        as the expectation is the models will reuse modules
        such as pretrain, charlm, bert
        """
        super().__init__()

        self.args = args
        if filenames:
            if models:
                raise ValueError("both filenames and models set when making the Ensemble")

            if foundation_cache is None:
                foundation_cache = FoundationCache()

            if isinstance(filenames, str):
                filenames = [filenames]
            logger.info("Models used for ensemble:\n %s", "\n ".join(filenames))
            models = [Trainer.load(filename, args, load_optimizer=False, foundation_cache=foundation_cache).model for filename in filenames]
        elif not models:
            raise ValueError("filenames and models both not set!")

        self.models = nn.ModuleList(models)

        # all submodels must agree on the transition system, constituent
        # inventory, and sentence direction or joint decoding is impossible
        # NOTE(review): these error messages index filenames[0], which would
        # fail if the Ensemble was built from `models` (filenames is None)
        # and a mismatch occurs -- confirm callers pass compatible models
        for model_idx, model in enumerate(self.models):
            if self.models[0].transition_scheme() != model.transition_scheme():
                raise ValueError("Models {} and {} are incompatible. {} vs {}".format(filenames[0], filenames[model_idx], self.models[0].transition_scheme(), model.transition_scheme()))
            if self.models[0].transitions != model.transitions:
                raise ValueError(f"Models {filenames[0]} and {filenames[model_idx]} are incompatible: different transitions\n{filenames[0]}:\n{self.models[0].transitions}\n{filenames[model_idx]}:\n{model.transitions}")
            if self.models[0].constituents != model.constituents:
                raise ValueError("Models %s and %s are incompatible: different constituents" % (filenames[0], filenames[model_idx]))
            if self.models[0].root_labels != model.root_labels:
                raise ValueError("Models %s and %s are incompatible: different root_labels" % (filenames[0], filenames[model_idx]))
            if self.models[0].uses_xpos() != model.uses_xpos():
                raise ValueError("Models %s and %s are incompatible: different uses_xpos" % (filenames[0], filenames[model_idx]))
            if self.models[0].reverse_sentence != model.reverse_sentence:
                raise ValueError("Models %s and %s are incompatible: different reverse_sentence" % (filenames[0], filenames[model_idx]))

        self._reverse_sentence = self.models[0].reverse_sentence

        # submodels are not trained (so far)
        self.detach_submodels()

        logger.debug("Number of models in the Ensemble: %d", len(self.models))
        # learned (num_models x num_transitions) weights added on top of the
        # plain sum of submodel scores; the only trainable parameter here
        self.register_parameter('weighted_sum', torch.nn.Parameter(torch.zeros(len(self.models), len(self.transitions), requires_grad=True)))

    def detach_submodels(self):
        # submodels are not trained (so far)
        for model in self.models:
            for _, parameter in model.named_parameters():
                parameter.requires_grad = False

    def train(self, mode=True):
        super().train(mode)
        if mode:
            # peft has a weird interaction where it turns requires_grad back on
            # even if it was previously off
            self.detach_submodels()

    @property
    def transitions(self):
        # all submodels share identical transitions (checked in __init__)
        return self.models[0].transitions

    @property
    def root_labels(self):
        return self.models[0].root_labels

    @property
    def device(self):
        return next(self.parameters()).device

    def unary_limit(self):
        """
        Limit on the number of consecutive unary transitions
        """
        # the most conservative limit across all submodels
        return min(m.unary_limit() for m in self.models)

    def transition_scheme(self):
        return self.models[0].transition_scheme()

    def has_unary_transitions(self):
        return self.models[0].has_unary_transitions()

    @property
    def is_top_down(self):
        return self.models[0].is_top_down

    @property
    def reverse_sentence(self):
        return self._reverse_sentence

    @property
    def retag_method(self):
        # TODO: make the method an enum
        return self.models[0].args['retag_method']

    def uses_xpos(self):
        return self.models[0].uses_xpos()

    def get_top_constituent(self, constituents):
        return self.models[0].get_top_constituent(constituents)

    def get_top_transition(self, transitions):
        return self.models[0].get_top_transition(transitions)

    def log_norms(self):
        # log the ensemble's own trainable parameters, then each submodel's
        lines = ["NORMS FOR MODEL PARAMETERS"]
        for name, param in self.named_parameters():
            if param.requires_grad and not name.startswith("models."):
                zeros = torch.sum(param.abs() < 0.000001).item()
                norm = "%.6g" % torch.norm(param).item()
                lines.append("%s %s %d %d" % (name, norm, zeros, param.nelement()))
        for model_idx, model in enumerate(self.models):
            sublines = model.get_norms()
            if len(sublines) > 0:
                lines.append(" ---- MODEL %d ----" % model_idx)
                lines.extend(sublines)
        logger.info("\n".join(lines))

    def log_shapes(self):
        # NOTE(review): header string says NORMS but this logs shapes --
        # looks like a copy-paste from log_norms
        lines = ["NORMS FOR MODEL PARAMETERS"]
        for name, param in self.named_parameters():
            if param.requires_grad:
                lines.append("{} {}".format(name, param.shape))
        logger.info("\n".join(lines))

    def get_params(self):
        """Return a checkpoint dict: ensemble-level params plus each child's params."""
        model_state = self.state_dict()
        # don't save the children in the base params
        model_state = {k: v for k, v in model_state.items() if not k.startswith("models.")}
        return {
            "base_params": model_state,
            "children_params": [x.get_params() for x in self.models]
        }

    def initial_state_from_preterminals(self, preterminal_lists, gold_trees, gold_sequences):
        # build one initial state per submodel, then zip them together into
        # one MultiState per sentence
        state_batch = [model.initial_state_from_preterminals(preterminal_lists, gold_trees, gold_sequences) for model in self.models]
        state_batch = list(zip(*state_batch))
        state_batch = [MultiState(states, gold_tree, gold_sequence, 0.0)
                       for states, gold_tree, gold_sequence in zip(state_batch, gold_trees, gold_sequences)]
        return state_batch

    def build_batch_from_tagged_words(self, batch_size, data_iterator):
        """
        Read from the data_iterator batch_size tagged sentences and turn them into new parsing states

        Expects a list of list of (word, tag)
        """
        state_batch = []
        for _ in range(batch_size):
            sentence = next(data_iterator, None)
            if sentence is None:
                break
            state_batch.append(sentence)

        if len(state_batch) > 0:
            state_batch = [model.initial_state_from_words(state_batch) for model in self.models]
            state_batch = list(zip(*state_batch))
            state_batch = [MultiState(states, None, None, 0.0) for states in state_batch]
        return state_batch

    def build_batch_from_trees(self, batch_size, data_iterator):
        """
        Read from the data_iterator batch_size trees and turn them into N lists of parsing states
        """
        state_batch = []
        for _ in range(batch_size):
            gold_tree = next(data_iterator, None)
            if gold_tree is None:
                break
            state_batch.append(gold_tree)

        if len(state_batch) > 0:
            state_batch = [model.initial_state_from_gold_trees(state_batch) for model in self.models]
            state_batch = list(zip(*state_batch))
            state_batch = [MultiState(states, None, None, 0.0) for states in state_batch]
        return state_batch

    def predict(self, states, is_legal=True):
        """
        Score the next transition for each state using all submodels.

        Returns (summed prediction scores, chosen transitions, chosen scores).
        When is_legal, an illegal top choice falls back to the highest
        scoring legal transition; None when nothing legal is found.
        """
        # regroup the MultiStates into one batch of states per submodel
        states = list(zip(*[x.states for x in states]))
        predictions = [model.forward(state_batch) for model, state_batch in zip(self.models, states)]

        # batch X num transitions X num models
        predictions = torch.stack(predictions, dim=2)

        # learned per-model, per-transition weights added to the plain sum
        flat_predictions = torch.einsum("BTM,MT->BT", predictions, self.weighted_sum)
        predictions = torch.sum(predictions, dim=2) + flat_predictions

        model = self.models[0]

        # TODO: possibly refactor with lstm_model.predict
        pred_max = torch.argmax(predictions, dim=1)
        scores = torch.take_along_dim(predictions, pred_max.unsqueeze(1), dim=1)
        pred_max = pred_max.detach().cpu()

        pred_trans = [model.transitions[pred_max[idx]] for idx in range(len(states[0]))]
        if is_legal:
            # legality is checked against model 0's state; all submodels use
            # the same transition scheme (checked in __init__)
            for idx, (state, trans) in enumerate(zip(states[0], pred_trans)):
                if not trans.is_legal(state, model):
                    _, indices = predictions[idx, :].sort(descending=True)
                    for index in indices:
                        if model.transitions[index].is_legal(state, model):
                            pred_trans[idx] = model.transitions[index]
                            scores[idx] = predictions[idx, index]
                            break
                    else: # yeah, else on a for loop, deal with it
                        # NOTE(review): assigning None into a tensor element
                        # would raise; presumably this path is unreachable
                        # because some transition is always legal -- confirm
                        pred_trans[idx] = None
                        scores[idx] = None

        return predictions, pred_trans, scores.squeeze(1)

    def bulk_apply(self, state_batch, transitions, fail=False):
        """Apply the chosen transitions to every submodel's copy of each state."""
        # NOTE(review): new_states is unused
        new_states = []

        states = list(zip(*[x.states for x in state_batch]))
        states = [x.bulk_apply(y, transitions, fail=fail) for x, y in zip(self.models, states)]
        states = list(zip(*states))
        state_batch = [x._replace(states=y) for x, y in zip(state_batch, states)]
        return state_batch

    def parse_tagged_words(self, words, batch_size):
        """
        This parses tagged words and returns a list of trees.

        `parse_tagged_words` is useful at Pipeline time -
        it takes words & tags and processes that into trees.

        The tagged words should be represented:
          one list per sentence
          each sentence is a list of (word, tag)
        The return value is a list of ParseTree objects

        TODO: this really ought to be refactored with base_model
        """
        logger.debug("Processing %d sentences", len(words))
        self.eval()

        sentence_iterator = iter(words)
        treebank = self.parse_sentences_no_grad(sentence_iterator, self.build_batch_from_tagged_words, batch_size, self.predict, keep_state=False, keep_constituents=False)

        results = [t.predictions[0].tree for t in treebank]
        return results

    def parse_sentences(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
        """
        Repeat transitions to build a list of trees from the input batches.

        The data_iterator should be anything which returns the data for a parse task via next()
        build_batch_fn is a function that turns that data into State objects
        This will be called to generate batches of size batch_size until the data is exhausted

        The return is a list of tuples: (gold_tree, [(predicted, score) ...])
        gold_tree will be left blank if the data did not include gold trees
        currently score is always 1.0, but the interface may be expanded
        to get a score from the result of the parsing

        transition_choice: which method of the model to use for
        choosing the next transition

        TODO: refactor with base_model
        """
        # NOTE(review): keep_state and keep_scores are accepted but unused here
        treebank = []
        treebank_indices = []
        # this will produce tuples of states
        # batch size lists of num models tuples
        state_batch = build_batch_fn(batch_size, data_iterator)
        batch_indices = list(range(len(state_batch)))
        horizon_iterator = iter([])

        if keep_constituents:
            # NOTE(review): requires `from collections import defaultdict` at
            # module level -- confirm it is imported; the dict is also never
            # filled in below, so this flag currently has no visible effect
            constituents = defaultdict(list)

        while len(state_batch) > 0:
            pred_scores, transitions, scores = transition_choice(state_batch)
            # num models lists of batch size states
            state_batch = self.bulk_apply(state_batch, transitions)

            remove = set()
            for idx, states in enumerate(state_batch):
                if states.finished(self):
                    predicted_tree = states.get_tree(self)
                    if self.reverse_sentence:
                        predicted_tree = predicted_tree.reverse()
                    gold_tree = states.gold_tree
                    # TODO: could easily store the score here
                    # not sure what it means to store the state,
                    # since each model is tracking its own state
                    treebank.append(ParseResult(gold_tree, [ScoredTree(predicted_tree, None)], None, None))
                    treebank_indices.append(batch_indices[idx])
                    remove.add(idx)

            if len(remove) > 0:
                state_batch = [state for idx, state in enumerate(state_batch) if idx not in remove]
                batch_indices = [batch_idx for idx, batch_idx in enumerate(batch_indices) if idx not in remove]

            # refill the batch from the horizon so the batch stays full
            for _ in range(batch_size - len(state_batch)):
                horizon_state = next(horizon_iterator, None)
                if not horizon_state:
                    horizon_batch = build_batch_fn(batch_size, data_iterator)
                    if len(horizon_batch) == 0:
                        break
                    horizon_iterator = iter(horizon_batch)
                    horizon_state = next(horizon_iterator, None)

                state_batch.append(horizon_state)
                batch_indices.append(len(treebank) + len(state_batch))

        # restore original input order
        treebank = utils.unsort(treebank, treebank_indices)
        return treebank

    def parse_sentences_no_grad(self, data_iterator, build_batch_fn, batch_size, transition_choice, keep_state=False, keep_constituents=False, keep_scores=False):
        # inference-only wrapper around parse_sentences (no autograd graph)
        with torch.no_grad():
            return self.parse_sentences(data_iterator, build_batch_fn, batch_size, transition_choice, keep_state, keep_constituents, keep_scores)
365
+
366
class EnsembleTrainer(BaseTrainer):
    """
    Stores a list of constituency models, useful for combining their results into one stronger model
    """
    def __init__(self, ensemble, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
        super().__init__(ensemble, optimizer, scheduler, epochs_trained, batches_trained, best_f1, best_epoch, first_optimizer)

    @staticmethod
    def from_files(args, filenames, foundation_cache=None):
        """Build an EnsembleTrainer by loading each model file in filenames."""
        ensemble = Ensemble(args, filenames, foundation_cache=foundation_cache)
        ensemble = ensemble.to(args.get('device', None))
        return EnsembleTrainer(ensemble)

    def get_peft_params(self):
        # one entry per submodel: the peft adapter state dict when the
        # submodel uses peft, otherwise None (keeps positions aligned)
        params = []
        for model in self.model.models:
            if model.args.get('use_peft', False):
                from peft import get_peft_model_state_dict
                params.append(get_peft_model_state_dict(model.bert_model, adapter_name=model.peft_name))
            else:
                params.append(None)

        return params

    @property
    def model_type(self):
        return ModelType.ENSEMBLE

    def log_num_words_known(self, words):
        # one summary line when all submodels agree, otherwise one per model
        nwk = [m.num_words_known(words) for m in self.model.models]
        if all(x == nwk[0] for x in nwk):
            logger.info("Number of words in the training set known to each sub-model: %d out of %d", nwk[0], len(words))
        else:
            logger.info("Number of words in the training set known to the sub-models:\n %s" % "\n ".join(["%d/%d" % (x, len(words)) for x in nwk]))

    @staticmethod
    def build_optimizer(args, model, first_optimizer):
        """Build an optimizer over only the ensemble-level parameters, not the submodels."""
        def fake_named_parameters():
            # expose only parameters which do not belong to the submodels
            for n, p in model.named_parameters():
                if not n.startswith("models."):
                    yield n, p

        # TODO: there has to be a cleaner way to do this, like maybe a "keep" callback
        # TODO: if we finetune the underlying models, we will want a series of optimizers
        # so that they can have a different learning rate from the ensemble's fields
        fake_model = copy.copy(model)
        fake_model.named_parameters = fake_named_parameters
        optimizer = build_optimizer(args, fake_model, first_optimizer)
        return optimizer

    @staticmethod
    def load_optimizer(model, checkpoint, first_optimizer, filename):
        """Rebuild the optimizer and restore its state from a checkpoint, if saved."""
        optimizer = EnsembleTrainer.build_optimizer(model.models[0].args, model, first_optimizer)
        if checkpoint.get('optimizer_state_dict', None) is not None:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except ValueError as e:
                raise ValueError("Failed to load optimizer from %s" % filename) from e
        else:
            logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
        return optimizer

    @staticmethod
    def load_scheduler(model, optimizer, checkpoint, first_optimizer):
        """Rebuild the LR scheduler and restore its state from a checkpoint, if saved."""
        scheduler = build_scheduler(model.models[0].args, optimizer, first_optimizer=first_optimizer)
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return scheduler

    @staticmethod
    def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None):
        """Reconstruct an Ensemble from saved params (see Ensemble.get_params)."""
        # TODO: no need for the if/else once the models are rebuilt
        children_params = params["children_params"] if isinstance(params, dict) else params
        base_params = params["base_params"] if isinstance(params, dict) else {}

        # TODO: fill in peft_name
        if peft_params is None:
            peft_params = [None] * len(children_params)
        if peft_name is None:
            peft_name = [None] * len(children_params)

        if len(children_params) != len(peft_params):
            raise ValueError("Model file had params length %d and peft params length %d" % (len(params), len(peft_params)))
        if len(children_params) != len(peft_name):
            raise ValueError("Model file had params length %d and peft name length %d" % (len(params), len(peft_name)))

        models = [Trainer.model_from_params(model_param, peft_param, args, foundation_cache, peft_name=pname)
                  for model_param, peft_param, pname in zip(children_params, peft_params, peft_name)]
        ensemble = Ensemble(args, models=models)
        # strict=False: base_params deliberately omit the children's weights
        ensemble.load_state_dict(base_params, strict=False)
        ensemble = ensemble.to(args.get('device', None))
        return ensemble
458
+
459
def parse_args(args=None):
    """
    Parse the command line options for building a saved ensemble.

    args: optional list of argument strings; defaults to sys.argv when None

    Returns a dict of the parsed options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')

    utils.add_device_args(parser)

    parser.add_argument('--lang', default='en', help='Language to use')

    parser.add_argument('models', type=str, nargs='+', default=None, help="Which model(s) to load")

    parser.add_argument('--save_name', type=str, default=None, required=True, help='Where to save the combined ensemble')

    # bug fix: previously the `args` parameter was ignored and sys.argv was
    # always parsed; pass it through so callers can supply their own list
    args = vars(parser.parse_args(args))

    return args
477
+
478
def main(args=None):
    """Build an ensemble from the model files given on the command line and save it."""
    options = parse_args(args)

    # share pretrain / charlm / bert modules between the loaded models
    cache = FoundationCache()
    trainer = EnsembleTrainer.from_files(options, options['models'], cache)
    trainer.save(options['save_name'], save_optimizer=False)
484
+
485
if __name__ == "__main__":
    # script entry point: combine the given model files into one saved ensemble
    main()
stanza/stanza/models/constituency/in_order_compound_oracle.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ from stanza.models.constituency.dynamic_oracle import advance_past_constituents, find_in_order_constituent_end, find_previous_open, DynamicOracle
4
+ from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, CompoundUnary, Finalize
5
+
6
def fix_missing_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a skipped CompoundUnary after a Shift.

    Applies when the parser jumped straight to the transition which follows
    the gold CompoundUnary; the repair drops the missed unary from the gold
    sequence and carries on.
    """
    # only relevant when the gold transition is the CompoundUnary itself
    if not isinstance(gold_transition, CompoundUnary):
        return None

    # the prediction must match what the gold sequence does next
    if pred_transition != gold_sequence[gold_index + 1]:
        return None
    # a Finalize here can happen when the entire tree is a single word,
    # but missing the ROOT transition cannot be repaired
    if isinstance(pred_transition, Finalize):
        return None

    repaired = list(gold_sequence[:gold_index])
    repaired.extend(gold_sequence[gold_index + 1:])
    return repaired
21
+
22
def fix_wrong_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Accept a mispredicted CompoundUnary by splicing the predicted unary
    into the gold sequence in place of the gold one.
    """
    both_unary = (isinstance(gold_transition, CompoundUnary)
                  and isinstance(pred_transition, CompoundUnary))
    if not both_unary:
        return None

    assert gold_transition != pred_transition

    repaired = list(gold_sequence[:gold_index])
    repaired.append(pred_transition)
    repaired.extend(gold_sequence[gold_index + 1:])
    return repaired
32
+
33
def fix_spurious_unary_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Accept an extra predicted CompoundUnary by inserting it into the gold
    sequence, then continuing with the gold transitions unchanged.
    """
    spurious = (isinstance(pred_transition, CompoundUnary)
                and not isinstance(gold_transition, CompoundUnary))
    if not spurious:
        return None

    repaired = list(gold_sequence[:gold_index])
    repaired.append(pred_transition)
    repaired.extend(gold_sequence[gold_index:])
    return repaired
41
+
42
def fix_open_shift_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a missed Open constituent where we predicted a Shift and the next transition was a Shift

    In fact, the subsequent transition MUST be a Shift with this transition scheme
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None

    if not isinstance(pred_transition, Shift):
        return None

    #if not isinstance(gold_sequence[gold_index+1], Shift):
    #    return None
    assert isinstance(gold_sequence[gold_index+1], Shift)

    # close_index represents the Close for the missing Open
    close_index = advance_past_constituents(gold_sequence, gold_index+1)
    assert close_index is not None
    # the missed Open cannot be recovered: drop both the Open and its
    # matching Close, keeping everything in between (a recall error only)
    return gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + gold_sequence[close_index+1:]
62
+
63
def fix_open_open_two_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Detect a wrong Open label over a bracket with exactly two subtrees.

    With no unary transitions in this scheme, a two-subtree bracket with a
    substituted label cannot be repaired.
    """
    if gold_transition == pred_transition:
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
    if isinstance(gold_sequence[block_end], Shift):
        # this is a multiple subtrees version of this error
        # we are only skipping the two subtrees errors for now
        return None

    # no fix is possible, so we just return here
    # NOTE(review): unlike the sibling repairs this returns a
    # (RepairType, None) tuple instead of a sequence -- presumably the
    # oracle framework treats the tuple as "classified but unfixable";
    # confirm against DynamicOracle's handling of repair results
    return RepairType.OPEN_OPEN_TWO_SUBTREES_ERROR, None
80
+
81
def fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three):
    """
    Repair a wrong Open label by closing the predicted bracket early and
    reopening the gold label (see the OPEN_OPEN_*_SUBTREES comments below).

    exactly_three: when True, only handle the exactly-three-subtree case;
    when False, only handle the more-than-three case -- each caller bails
    out of the other case so the two RepairTypes stay disjoint.
    """
    if gold_transition == pred_transition:
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[block_end], Shift):
        # this is a multiple subtrees version of this error
        # we are only skipping the two subtrees errors for now
        return None

    next_block_end = find_in_order_constituent_end(gold_sequence, block_end+1)
    if exactly_three and isinstance(gold_sequence[next_block_end], Shift):
        # for exactly three subtrees,
        # we can put back the missing open transition
        # and now we have no recall error, only precision error
        # for more than three, we separate that out as an ambiguous choice
        return None
    elif not exactly_three and isinstance(gold_sequence[next_block_end], CloseConstituent):
        # this is ambiguous, but we can still try this fix
        return None

    # at this point, we build a new sequence with the origin constituent inserted
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:block_end] + [CloseConstituent(), gold_transition] + gold_sequence[block_end:]
109
+
110
+
111
def fix_open_open_three_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    # exactly three subtrees: the repair is unambiguous (see fix_open_open_error)
    return fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three=True)
113
+
114
def fix_open_open_many_subtrees_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    # more than three subtrees: same fix, but the split point is ambiguous
    return fix_open_open_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, exactly_three=False)
116
+
117
def fix_open_close_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Find the closed bracket, reopen it

    The Open we just missed must be forgotten - it cannot be reopened
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None

    if not isinstance(pred_transition, CloseConstituent):
        return None

    # find the appropriate Open so we can reopen it
    open_idx = find_previous_open(gold_sequence, gold_index)
    # actually, if the Close is legal, this can't happen
    # but it might happen in a unit test which doesn't check legality
    if open_idx is None:
        return None

    # also, since we are punting on the missed Open, we need to skip
    # the Close which would have closed it
    close_idx = advance_past_constituents(gold_sequence, gold_index+1)

    # accept the predicted Close, immediately reopen the same label, and
    # drop the now-orphaned Close belonging to the missed Open
    return gold_sequence[:gold_index] + [pred_transition, gold_sequence[open_idx]] + gold_sequence[gold_index+1:close_idx] + gold_sequence[close_idx+1:]
141
+
142
def fix_shift_close_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Close predicted where a Shift was expected.

    The wrongly closed bracket is immediately reopened with the same label,
    so only a precision error is introduced.
    """
    applicable = (isinstance(gold_transition, Shift)
                  and isinstance(pred_transition, CloseConstituent))
    if not applicable:
        return None

    # never applicable at the very start or right after an Open
    if gold_index == 0:
        return None
    if isinstance(gold_sequence[gold_index - 1], OpenConstituent):
        return None

    open_idx = find_previous_open(gold_sequence, gold_index)
    assert open_idx is not None

    repaired = list(gold_sequence[:gold_index])
    repaired.append(pred_transition)
    # reopen the constituent we just closed, then resume the gold sequence
    repaired.append(gold_sequence[open_idx])
    repaired.extend(gold_sequence[gold_index:])
    return repaired
160
+
161
def fix_shift_open_unambiguous_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open predicted where a Shift was expected, when the end of
    the spurious constituent is unambiguous.
    """
    if not isinstance(gold_transition, Shift):
        return None

    if not isinstance(pred_transition, OpenConstituent):
        return None

    bracket_end = find_in_order_constituent_end(gold_sequence, gold_index)
    assert bracket_end is not None
    if isinstance(gold_sequence[bracket_end], Shift):
        # this is an ambiguous error
        # multiple possible places to end the wrong constituent
        return None
    assert isinstance(gold_sequence[bracket_end], CloseConstituent)

    # accept the spurious Open and close it at the unambiguous end point
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:bracket_end] + [CloseConstituent()] + gold_sequence[bracket_end:]
177
+
178
def fix_close_shift_unambiguous_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift predicted where Close-then-Shift was expected, when the
    place to close the still-open constituent is unambiguous.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None

    if not isinstance(pred_transition, Shift):
        return None
    if not isinstance(gold_sequence[gold_index+1], Shift):
        return None

    bracket_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
    assert bracket_end is not None
    if isinstance(gold_sequence[bracket_end], Shift):
        # this is an ambiguous error
        # multiple possible places to end the wrong constituent
        return None
    assert isinstance(gold_sequence[bracket_end], CloseConstituent)

    # drop the missed Close and reinsert it once the next constituent ends
    return gold_sequence[:gold_index] + gold_sequence[gold_index+1:bracket_end] + [CloseConstituent()] + gold_sequence[bracket_end:]
196
+
197
class RepairType(Enum):
    """
    Keep track of which repair is used, if any, on an incorrect transition

    Effects of different repair types:
    no oracle: 0.9251 0.9226
    +missing_unary: 0.9246 0.9214
    +wrong_unary: 0.9236 0.9213
    +spurious_unary: 0.9247 0.9229
    +open_shift_error: 0.9258 0.9226
    +open_open_two_subtrees: 0.9256 0.9215 # nothing changes with this one...
    +open_open_three_subtrees: 0.9256 0.9226
    +open_open_many_subtrees: 0.9257 0.9234
    +shift_close: 0.9267 0.9250
    +shift_open: 0.9273 0.9247
    +close_shift: 0.9266 0.9229
    +open_close: 0.9267 0.9256
    """
    def __new__(cls, fn, correct=False, debug=False):
        """
        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error
        """
        # members get 1-based values in declaration order
        value = len(cls.__members__)
        obj = object.__new__(cls)
        obj._value_ = value + 1
        obj.fn = fn
        obj.correct = correct
        obj.debug = debug
        return obj

    @property
    def is_correct(self):
        """True only for members declared with correct=True (ie CORRECT)."""
        return self.correct

    # The correct sequence went Shift - Unary - Stuff
    # but the CompoundUnary was missed and Stuff predicted
    # so now we just proceed as if nothing happened
    # note that CompoundUnary happens immediately after a Shift
    # complicated nodes are created with single Open transitions
    MISSING_UNARY_ERROR = (fix_missing_unary_error,)

    # Predicted a wrong CompoundUnary. No way to fix this, so just keep going
    WRONG_UNARY_ERROR = (fix_wrong_unary_error,)

    # The correct sequence went Shift - Stuff
    # but instead we predicted a CompoundUnary
    # again, we just keep going
    SPURIOUS_UNARY_ERROR = (fix_spurious_unary_error,)

    # Were supposed to open a new constituent,
    # but instead shifted an item onto the stack
    #
    # The missed Open cannot be recovered
    #
    # One could ask, is it possible to open a bigger constituent later,
    # but if the constituent patterns go
    #   X (good open) Y (missed open) Z
    # when we eventually close Y and Z, because of the missed Open,
    # it is guaranteed to capture X as well
    # since it will grab constituents until one left of the previous Open before Y
    #
    # Therefore, in this case, we must simply forget about this Open (recall error)
    OPEN_SHIFT_ERROR = (fix_open_shift_error,)

    # With this transition scheme, it is not possible to fix the following pattern:
    #   T1 O_x T2 C -> T1 O_y T2 C
    # seeing as how there are no unary transitions
    # so whatever precision & recall errors are caused by substituting O_x -> O_y
    # (which could include multiple transitions)
    # those errors are unfixable in any way
    OPEN_OPEN_TWO_SUBTREES_ERROR = (fix_open_open_two_subtrees_error,)

    # With this transition scheme, a three subtree branch with a wrong Open
    # has a non-ambiguous fix
    #   T1 O_x T2 T3 C -> T1 O_y T2 T3 C
    # this can become
    #   T1 O_y T2 C O_x T3 C
    # now there are precision errors from the incorrectly added transition(s),
    # but the correctly replaced transitions are unambiguous
    OPEN_OPEN_THREE_SUBTREES_ERROR = (fix_open_open_three_subtrees_error,)

    # We were supposed to shift a new item onto the stack,
    # but instead we closed the previous constituent
    # This causes a precision error, but we can avoid the recall error
    # by immediately reopening the closed constituent.
    SHIFT_CLOSE_ERROR = (fix_shift_close_error,)

    # We opened a new constituent instead of shifting
    # In the event that the next constituent ends with a close,
    # rather than building another new constituent,
    # then there is no ambiguity
    SHIFT_OPEN_UNAMBIGUOUS_ERROR = (fix_shift_open_unambiguous_error,)

    # Suppose we were supposed to Close, then Shift
    # but instead we just did a Shift
    # Similar to shift_open_unambiguous, we now have an opened
    # constituent which shouldn't be there
    # We can scroll past the next constituent created to see
    # if the outer constituents close at that point
    # If so, we can close this constituent as well in an unambiguous manner
    # TODO: analyze the case where we were supposed to Close, Open
    # but instead did a Shift
    CLOSE_SHIFT_UNAMBIGUOUS_ERROR = (fix_close_shift_unambiguous_error,)

    # Supposed to open a new constituent,
    # instead closed an existing constituent
    #
    # X (good open) Y (open -> close) Z
    #
    # the constituent that should contain Y, Z is unfortunately lost
    # since now the stack has
    #
    # XY ...
    #
    # furthermore, there is now a precision error for the extra XY
    # constituent that should not exist
    # however, what we can do to minimize further errors is
    # to at least reopen the label between X and Y
    OPEN_CLOSE_ERROR = (fix_open_close_error,)

    # this is ambiguous, but we can still try the same fix as three_subtrees (see above)
    OPEN_OPEN_MANY_SUBTREES_ERROR = (fix_open_open_many_subtrees_error,)

    # no repair function; marks a prediction which matched gold
    CORRECT = (None, True)

    # bare None: __new__ receives fn=None, correct=False
    UNKNOWN = None
323
+
324
+
325
class InOrderCompoundOracle(DynamicOracle):
    """
    Dynamic oracle for the in-order transition scheme with compound unary
    transitions: wires the RepairType enum above into the generic DynamicOracle.
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)
stanza/stanza/models/constituency/in_order_oracle.py ADDED
@@ -0,0 +1,1029 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ from stanza.models.constituency.dynamic_oracle import advance_past_constituents, find_in_order_constituent_end, find_previous_open, score_candidates, DynamicOracle, RepairEnum
4
+ from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent
5
+
6
def fix_wrong_open_root_error(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair an Open/Open error at the ROOT.

    If the gold transition opened a ROOT label but the model opened some
    other constituent, keep the predicted open, immediately close it as a
    unary, and then continue with the original gold sequence.
    Returns the repaired sequence, or None when the pattern does not match.
    """
    # guard: nothing to repair when the prediction was actually correct
    if gold_transition == pred_transition:
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if gold_transition.top_label not in root_labels:
        return None

    # wrap the mistake in a unary bracket, then resume the gold sequence
    unary_patch = [pred_transition, CloseConstituent()]
    return gold_sequence[:gold_index] + unary_patch + gold_sequence[gold_index:]
17
+
18
def fix_wrong_open_unary_chain(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a wrong open/open in a unary chain by removing the skipped unary transitions

    Only applies if the wrong pred transition is a transition found higher up in the unary chain
    """
    # useful to have this check here in case the call is made independently in a unit test
    if gold_transition == pred_transition:
        return None

    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent):
        cur_index = gold_index + 1  # This is now a Close if we are in this particular context
        # walk the Close/Open pairs of the unary chain, looking for the Open we predicted
        while cur_index + 1 < len(gold_sequence) and isinstance(gold_sequence[cur_index], CloseConstituent) and isinstance(gold_sequence[cur_index+1], OpenConstituent):
            cur_index = cur_index + 1  # advance to the next Open
            if gold_sequence[cur_index] == pred_transition:
                # splice out the unary levels that were skipped by the prediction
                return gold_sequence[:gold_index] + gold_sequence[cur_index:]
            cur_index = cur_index + 1  # advance to the next Close

    return None
37
+
38
def fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, more_than_two):
    """
    Shared repair for an Open/Open error over finished subtrees.

    Keeps the predicted Open over the next constituent, closes it, then
    reopens the gold label so the outer bracket can still be built.
    more_than_two selects which shape of following sequence is accepted:
    False handles the two-subtree case, True the many-subtree case.
    Returns the repaired sequence or None when the pattern does not match.
    """
    if gold_transition == pred_transition:
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # if Close, the gold was a unary
        return None
    assert not isinstance(gold_sequence[gold_index+1], OpenConstituent)
    assert isinstance(gold_sequence[gold_index+1], Shift)

    block_end = find_in_order_constituent_end(gold_sequence, gold_index+1)
    assert block_end is not None

    # a Close at block_end means exactly one more subtree follows;
    # a Shift means several subtrees follow before the enclosing Close
    if more_than_two and isinstance(gold_sequence[block_end], CloseConstituent):
        return None
    if not more_than_two and isinstance(gold_sequence[block_end], Shift):
        return None

    # close the wrongly opened bracket after the next block, then reopen the gold label
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:block_end] + [CloseConstituent(), gold_transition] + gold_sequence[block_end:]
62
+
63
def fix_wrong_open_two_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Open/Open repair for the case where exactly one subtree follows before the Close."""
    return fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence,
                                   gold_index, root_labels, False)
65
+
66
def fix_wrong_open_multiple_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Open/Open repair for the case where several subtrees follow before the Close."""
    return fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence,
                                   gold_index, root_labels, True)
68
+
69
def advance_past_unaries(gold_sequence, cur_index):
    """
    Skip consecutive unary Open/Close pairs starting at cur_index.

    Returns the first index which does not start an Open immediately
    followed by a Close (with at least one transition left after the pair).
    """
    limit = len(gold_sequence)
    while cur_index + 2 < limit:
        if not isinstance(gold_sequence[cur_index], OpenConstituent):
            break
        if not isinstance(gold_sequence[cur_index + 1], CloseConstituent):
            break
        cur_index += 2
    return cur_index
73
+
74
def fix_wrong_open_stuff_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a wrong open/open when there is an intervening constituent and then the guessed NT

    This happens when the correct pattern is
      stuff_1 NT_X stuff_2 close NT_Y ...
    and instead of guessing the gold transition NT_X,
    the prediction was NT_Y

    Returns the repaired sequence, or None when the pattern does not match.
    """
    if gold_transition == pred_transition:
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None
    # TODO: Here we could advance past unary transitions while
    # watching for hitting pred_transition. However, that is an open
    # question... is it better to try to keep such an Open as part of
    # the sequence, or is it better to skip them and attach the inner
    # nodes to the upper level
    stuff_start = gold_index + 1
    if not isinstance(gold_sequence[stuff_start], Shift):
        return None
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    if stuff_end is None:
        return None
    # at this point, stuff_end points to the Close which occurred after stuff_2
    # also, stuff_start points to the first transition which makes stuff_2, the Shift
    cur_index = stuff_end + 1
    # bounds check added: if the Close is the last transition there is no
    # NT_Y to find, and indexing past the end would raise IndexError
    # (the other repair functions in this module guard the same way)
    while cur_index < len(gold_sequence) and isinstance(gold_sequence[cur_index], OpenConstituent):
        if gold_sequence[cur_index] == pred_transition:
            # attach stuff_2 directly under the predicted NT_Y
            return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index+1:]
        # this was an OpenConstituent, but not the OpenConstituent we guessed
        # maybe there's a unary transition which lets us try again
        if cur_index + 2 < len(gold_sequence) and isinstance(gold_sequence[cur_index + 1], CloseConstituent):
            cur_index = cur_index + 2
        else:
            break

    # oh well, none of this worked
    return None
116
+
117
def fix_wrong_open_general(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a general wrong open/open transition by accepting the open and continuing

    A couple other open/open patterns have already been carved out

    TODO: negative checks for the previous patterns, in case we turn those off
    """
    if gold_transition == pred_transition:
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None
    # If the top is a ROOT, then replacing it with a non-ROOT creates an illegal
    # transition sequence. The ROOT case was already handled elsewhere anyway
    if gold_transition.top_label in root_labels:
        return None

    # simply substitute the predicted label and keep the rest of the gold sequence
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index+1:]
138
+
139
def fix_missed_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a missed unary which is followed by an otherwise correct transition

    (also handles multiple missed unary transitions)
    """
    if gold_transition == pred_transition:
        return None

    # skip the unary Open/Close pairs; if the prediction matches what comes
    # after them, the only error was dropping the unary levels
    cur_index = gold_index
    cur_index = advance_past_unaries(gold_sequence, cur_index)
    if gold_sequence[cur_index] == pred_transition:
        return gold_sequence[:gold_index] + gold_sequence[cur_index:]
    return None
153
+
154
def fix_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix an Open replaced with a Shift

    Suppose we were supposed to guess NT_X and instead did S

    We derive the repair as follows.

    For simplicity, assume the open is not a unary for now

    Since we know an Open was legal, there must be stuff
      stuff NT_X
    Shift is also legal, so there must be other stuff and a previous Open
      stuff_1 NT_Y stuff_2 NT_X
    After the NT_X which we missed, there was a bunch of stuff and a close for NT_X
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C
    There could be more stuff here which can be saved...
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C stuff_4 C
      stuff_1 NT_Y stuff_2 NT_X stuff_3 C C

    Returns the repaired sequence, or None when the pattern does not match.
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    cur_index = gold_index
    cur_index = advance_past_unaries(gold_sequence, cur_index)
    if not isinstance(gold_sequence[cur_index], OpenConstituent):
        return None
    if gold_sequence[cur_index].top_label in root_labels:
        return None
    # cur_index now points to the NT_X we missed (not counting unaries)

    stuff_start = cur_index + 1
    # can't be a Close, since we just went past an Open and checked for unaries
    # can't be an Open, since two Open in a row is illegal
    assert isinstance(gold_sequence[stuff_start], Shift)
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    # stuff_end is now the Close which ends NT_X
    cur_index = stuff_end + 1
    if cur_index >= len(gold_sequence):
        return None
    if isinstance(gold_sequence[cur_index], OpenConstituent):
        # unary levels on the lost NT_X cannot be recovered; skip them
        cur_index = advance_past_unaries(gold_sequence, cur_index)
        if cur_index >= len(gold_sequence):
            return None
        if isinstance(gold_sequence[cur_index], OpenConstituent):
            # an Open here signifies that there was a bracket containing X underneath Y
            # TODO: perhaps try to salvage something out of that situation?
            return None
    # the repair starts with the sequence up through the error,
    # then stuff_3, which includes the error
    # skip the Close for the missed NT_X
    # then finish the sequence with any potential stuff_4, the next Close, and everything else
    repair = gold_sequence[:gold_index] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index:]
    return repair
210
+
211
def fix_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix an Open replaced with a Close

    Call the Open NT_X
    Open legal, so there must be stuff:
      stuff NT_X
    Close legal, so there must be something to close:
      stuff_1 NT_Y stuff_2 NT_X

    The incorrect close makes the following brackets:
      (Y stuff_1 stuff_2)
    We were supposed to build
      (Y stuff_1 (X stuff_2 ...) (possibly more stuff))
    The simplest fix here is to reopen Y at this point.

    One issue might be if there is another bracket which encloses X underneath Y
    So, for example, the tree was supposed to be
      (Y stuff_1 (Z (X stuff_2 stuff_3) stuff_4))
    The pattern for this case is
      stuff_1 NT_Y stuff_2 NT_X stuff_3 close NT_Z stuff_4 close close
    """
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(pred_transition, CloseConstituent):
        return None

    cur_index = advance_past_unaries(gold_sequence, gold_index)
    if cur_index >= len(gold_sequence):
        return None
    if not isinstance(gold_sequence[cur_index], OpenConstituent):
        return None
    if gold_sequence[cur_index].top_label in root_labels:
        return None

    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    # prev_open is now NT_Y from above

    stuff_start = cur_index + 1
    assert isinstance(gold_sequence[stuff_start], Shift)
    stuff_end = advance_past_constituents(gold_sequence, stuff_start)
    # stuff_end is now the Close which ends NT_X
    # stuff_start:stuff_end is the stuff_3 block above
    cur_index = stuff_end + 1
    if cur_index >= len(gold_sequence):
        return None
    # if there are unary transitions here, we want to skip those.
    # those are unary transitions on X and cannot be recovered, since X is gone
    cur_index = advance_past_unaries(gold_sequence, cur_index)
    # now there is a certain failure case which has to be accounted for.

    # specifically, if there is a new non-terminal which opens
    # immediately after X closes, it is encompassing X in a way that
    # cannot be recovered now that part of X is stuck under Y.
    # The two choices at this point would be to eliminate the new
    # transition or just reject the tree from the repair
    # For now, we reject the tree
    if isinstance(gold_sequence[cur_index], OpenConstituent):
        return None

    # accept the close, reopen NT_Y, then continue with stuff_3 and the rest
    repair = gold_sequence[:gold_index] + [pred_transition, prev_open] + gold_sequence[stuff_start:stuff_end] + gold_sequence[cur_index:]
    return repair
276
+
277
def fix_shift_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    This fixes Shift replaced with a Close transition.

    This error occurs in the following pattern:
      stuff_1 NT_X stuff... shift
    Instead of shift, you close the NT_X
    The easiest fix here is to just restore the NT_X.
    """

    if not isinstance(pred_transition, CloseConstituent):
        return None

    # this fix can also be applied if there were unaries on the
    # previous constituent. we just skip those until the Shift
    cur_index = gold_index
    if isinstance(gold_transition, OpenConstituent):
        cur_index = advance_past_unaries(gold_sequence, cur_index)
    if not isinstance(gold_sequence[cur_index], Shift):
        return None

    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    # prev_open is now NT_X from above

    # accept the close, reopen NT_X, then continue from the Shift
    return gold_sequence[:gold_index] + [pred_transition, prev_open] + gold_sequence[cur_index:]
305
+
306
def fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous, late):
    """
    Shared repair for a missed Close followed by a different Open, when the model Shifted.

    The next bracket is built first, then the skipped Close/Open pair is
    replayed.  ambiguous selects whether the multi-bracket (True) or
    single-bracket (False) continuation is handled; late closes after all
    following brackets instead of just the first one.

    Returns the repaired sequence, or None when the pattern does not match.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    open_index = advance_past_unaries(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[open_index], OpenConstituent):
        return None
    if not isinstance(gold_sequence[open_index+1], Shift):
        return None

    # check that the next operation was to open a *different* constituent
    # from the one we just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    if gold_sequence[open_index] == prev_open:
        return None

    # check that the following stuff is a single bracket, not multiple brackets
    end_index = find_in_order_constituent_end(gold_sequence, open_index+1)
    # guard added: fix_close_shift_shift checks this result for None before
    # indexing with it; do the same here to avoid a TypeError on gold_sequence[None]
    if end_index is None:
        return None
    if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):
        return None
    elif not ambiguous and isinstance(gold_sequence[end_index], Shift):
        return None

    # if closing at the end of the next blocks,
    # instead of closing after the first block ends,
    # we go to the end of the last block
    if late:
        end_index = advance_past_constituents(gold_sequence, open_index+1)

    return gold_sequence[:gold_index] + gold_sequence[open_index+1:end_index] + gold_sequence[gold_index:open_index+1] + gold_sequence[end_index:]
346
+
347
def fix_close_open_shift_unambiguous_bracket(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Unambiguous single-bracket variant of the close/open/shift repair."""
    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence,
                                        gold_index, root_labels, False, False)
349
+
350
def fix_close_open_shift_ambiguous_bracket_early(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Ambiguous close/open/shift repair, closing after the first following bracket."""
    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence,
                                        gold_index, root_labels, True, False)
352
+
353
def fix_close_open_shift_ambiguous_bracket_late(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Ambiguous close/open/shift repair, closing after the last following bracket."""
    return fix_close_shift_open_bracket(gold_transition, pred_transition, gold_sequence,
                                        gold_index, root_labels, True, True)
355
+
356
def fix_close_open_shift_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a missed Close + different Open by letting the model score the close position.

    Builds one candidate repair per possible end of the following brackets
    and asks score_candidates which close position the model prefers.
    Returns (repair_type, candidate) or None when the pattern does not match.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    open_index = advance_past_unaries(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[open_index], OpenConstituent):
        return None
    if not isinstance(gold_sequence[open_index+1], Shift):
        return None

    # check that the next operation was to open a *different* constituent
    # from the one we just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    if gold_sequence[open_index] == prev_open:
        return None

    # alright, at long last we have:
    #   a close that was missed
    #   a non-nested open that was missed
    end_index = find_in_order_constituent_end(gold_sequence, open_index+1)
    # guard added: without it, gold_sequence[end_index] below would raise
    # TypeError when the end of the constituent cannot be found
    if end_index is None:
        return None

    candidates = []
    candidates.append((gold_sequence[:gold_index], gold_sequence[open_index+1:end_index], gold_sequence[gold_index:open_index+1], gold_sequence[end_index:]))
    while end_index < len(gold_sequence) and isinstance(gold_sequence[end_index], Shift):
        end_index = find_in_order_constituent_end(gold_sequence, end_index+1)
        # guard added: stop extending candidates if the next block has no findable end
        if end_index is None:
            break
        candidates.append((gold_sequence[:gold_index], gold_sequence[open_index+1:end_index], gold_sequence[gold_index:open_index+1], gold_sequence[end_index:]))

    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=2)
    if len(candidates) == 1:
        return RepairType.CLOSE_OPEN_SHIFT_UNAMBIGUOUS_BRACKET, best_candidate

    # the last candidate is reported as -1 so statistics group "closed at the very end"
    if best_idx == len(candidates) - 1:
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
403
+
404
def fix_close_open_shift_nested(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Fix a Close X..Open X..Shift pattern where both the Close and Open were skipped.

    Here the pattern we are trying to fix is
      stuff_A open_X stuff_B *close* open_X shift...
    replaced with
      stuff_A open_X stuff_B shift...
    the missed close & open means a missed recall error for (X A B)
    whereas the previous open_X can still get the outer bracket
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None

    if len(gold_sequence) < gold_index + 3:
        return None
    if not isinstance(gold_sequence[gold_index+1], OpenConstituent):
        return None

    # handle the sequence:
    #   stuff_A open_X stuff_B close open_Y close open_X shift
    open_index = advance_past_unaries(gold_sequence, gold_index+1)
    if not isinstance(gold_sequence[open_index], OpenConstituent):
        return None
    if not isinstance(gold_sequence[open_index+1], Shift):
        return None

    # check that the next operation was to open the same constituent
    # we just closed
    prev_open_index = find_previous_open(gold_sequence, gold_index)
    if prev_open_index is None:
        return None
    prev_open = gold_sequence[prev_open_index]
    if gold_sequence[open_index] != prev_open:
        return None

    # drop the Close/Open pair entirely and continue from the Shift
    return gold_sequence[:gold_index] + gold_sequence[open_index+1:]
443
+
444
def fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, ambiguous, late):
    """
    Repair Close/Shift -> Shift by moving the Close to after the next block is created
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 2:
        return None
    start_index = gold_index + 1
    start_index = advance_past_unaries(gold_sequence, start_index)
    if len(gold_sequence) < start_index + 2:
        return None
    if not isinstance(gold_sequence[start_index], Shift):
        return None

    end_index = find_in_order_constituent_end(gold_sequence, start_index)
    if end_index is None:
        return None
    # if this *isn't* a close, we don't allow it in the unambiguous case
    # that case seems to be ambiguous...
    #   stuff_1 close stuff_2 stuff_3
    # if you would normally start building stuff_3,
    # it is not clear if you want to close at the end of
    # stuff_2 or build stuff_3 instead.
    if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):
        return None
    elif not ambiguous and isinstance(gold_sequence[end_index], Shift):
        return None

    # close at the end of the brackets, rather than once the first bracket is finished
    if late:
        end_index = advance_past_constituents(gold_sequence, start_index)

    return gold_sequence[:gold_index] + gold_sequence[start_index:end_index] + [CloseConstituent()] + gold_sequence[end_index:]
480
+
481
def fix_close_shift_shift_unambiguous(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Unambiguous variant of the close/shift repair."""
    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence,
                                 gold_index, root_labels, False, False)
483
+
484
def fix_close_shift_shift_ambiguous_early(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Ambiguous close/shift repair, closing after the first following bracket."""
    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence,
                                 gold_index, root_labels, True, False)
486
+
487
def fix_close_shift_shift_ambiguous_late(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Ambiguous close/shift repair, closing after the last following bracket."""
    return fix_close_shift_shift(gold_transition, pred_transition, gold_sequence,
                                 gold_index, root_labels, True, True)
489
+
490
def fix_close_shift_shift_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair Close/Shift -> Shift by letting the model score where to place the Close.

    One candidate repair is built per possible end of the following blocks;
    score_candidates picks the close position the model prefers.
    Returns (repair_type, candidate) or None when the pattern does not match.
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, Shift):
        return None
    if len(gold_sequence) < gold_index + 2:
        return None
    start_index = gold_index + 1
    start_index = advance_past_unaries(gold_sequence, start_index)
    if len(gold_sequence) < start_index + 2:
        return None
    if not isinstance(gold_sequence[start_index], Shift):
        return None

    # now we know that the gold pattern was
    #   Close (unaries) Shift
    # and instead the model predicted Shift
    candidates = []
    current_index = start_index
    while isinstance(gold_sequence[current_index], Shift):
        current_index = find_in_order_constituent_end(gold_sequence, current_index)
        assert current_index is not None
        candidates.append((gold_sequence[:gold_index], gold_sequence[start_index:current_index], [CloseConstituent()], gold_sequence[current_index:]))
    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=2)
    if len(candidates) == 1:
        return RepairType.CLOSE_SHIFT_SHIFT, best_candidate
    # the last candidate is reported as -1 so statistics group "closed at the very end"
    if best_idx == len(candidates) - 1:
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
523
+
524
def ambiguous_shift_open_unary_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift mistaken for an Open by immediately closing the extra bracket.

    The wrongly opened constituent becomes a unary over nothing new; the
    gold sequence then continues unchanged.
    """
    if not (isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent)):
        return None

    unary_patch = [pred_transition, CloseConstituent()]
    return gold_sequence[:gold_index] + unary_patch + gold_sequence[gold_index:]
531
+
532
def ambiguous_shift_open_early_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift mistaken for an Open by closing the extra bracket after the first block.

    Returns the repaired sequence, or None when the pattern does not match.
    """
    if not isinstance(gold_transition, Shift):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    # Find when the current block ends,
    # either via a Shift or a Close
    # NOTE(review): assumes find_in_order_constituent_end cannot return None
    # here since gold_transition is a Shift - confirm against its contract
    end_index = find_in_order_constituent_end(gold_sequence, gold_index)
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:end_index] + [CloseConstituent()] + gold_sequence[end_index:]
542
+
543
def ambiguous_shift_open_late_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift mistaken for an Open by closing the extra bracket after the last block.

    Returns the repaired sequence, or None when the pattern does not match.
    """
    if not isinstance(gold_transition, Shift):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    # skip all the following sibling constituents, then close the extra bracket
    end_index = advance_past_constituents(gold_sequence, gold_index)
    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:end_index] + [CloseConstituent()] + gold_sequence[end_index:]
551
+
552
def ambiguous_shift_open_predicted_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Repair a Shift mistaken for an Open by letting the model choose the close position.

    Three candidates are considered: close immediately (unary, "U"), close
    after the first block ("E" / "S"), or close after the last block ("L").
    Returns (repair_type, candidate) or None when the pattern does not match.
    """
    if not isinstance(gold_transition, Shift):
        return None
    if not isinstance(pred_transition, OpenConstituent):
        return None

    unary_candidate = (gold_sequence[:gold_index], [pred_transition], [CloseConstituent()], gold_sequence[gold_index:])

    early_index = find_in_order_constituent_end(gold_sequence, gold_index)
    early_candidate = (gold_sequence[:gold_index], [pred_transition] + gold_sequence[gold_index:early_index], [CloseConstituent()], gold_sequence[early_index:])

    late_index = advance_past_constituents(gold_sequence, gold_index)
    if early_index == late_index:
        # early and late close positions coincide: only two real choices
        candidates = [unary_candidate, early_candidate]
        scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=2)
        if best_idx == 0:
            return_label = "U"
        else:
            return_label = "S"
    else:
        late_candidate = (gold_sequence[:gold_index], [pred_transition] + gold_sequence[gold_index:late_index], [CloseConstituent()], gold_sequence[late_index:])
        candidates = [unary_candidate, early_candidate, late_candidate]
        scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=2)
        if best_idx == 0:
            return_label = "U"
        elif best_idx == 1:
            return_label = "E"
        else:
            return_label = "L"
    repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_PREDICTED_CLOSE.name,
                             value="%d.%s" % (RepairType.SHIFT_OPEN_PREDICTED_CLOSE.value, return_label),
                             is_correct=False)
    return repair_type, best_candidate
585
+
586
+
587
def report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a Close mistaken for a Shift."""
    matches = isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, Shift)
    if matches:
        return RepairType.OTHER_CLOSE_SHIFT, None
    return None
594
+
595
def report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a Close mistaken for an Open."""
    matches = isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, OpenConstituent)
    if matches:
        return RepairType.OTHER_CLOSE_OPEN, None
    return None
602
+
603
def report_open_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) an Open mistaken for a different Open."""
    matches = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent)
    if matches:
        return RepairType.OTHER_OPEN_OPEN, None
    return None
610
+
611
def report_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) an Open mistaken for a Shift."""
    matches = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, Shift)
    if matches:
        return RepairType.OTHER_OPEN_SHIFT, None
    return None
618
+
619
def report_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) an Open mistaken for a Close."""
    matches = isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, CloseConstituent)
    if matches:
        return RepairType.OTHER_OPEN_CLOSE, None
    return None
626
+
627
def report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """Report (without repairing) a Shift mistaken for an Open."""
    matches = isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent)
    if matches:
        return RepairType.OTHER_SHIFT_OPEN, None
    return None
634
+
635
class RepairType(Enum):
    """
    Keep track of which repair is used, if any, on an incorrect transition

    Statistics on English w/ no charlm, no transformer,
    eg word vectors only, best model as of January 2024

    unambiguous transitions only:
      oracle scheme                             dev      test
      no oracle                                0.9245    0.9226
      +wrong_open_root                         0.9244    0.9224
      +wrong_unary_chain                       0.9243    0.9237
      +wrong_open_unary                        0.9249    0.9223
      +wrong_open_general                      0.9251    0.9215
      +missed_unary                            0.9248    0.9215
      +open_shift                              0.9243    0.9216
      +open_close                              0.9254    0.9217
      +shift_close                             0.9261    0.9238
      +close_shift_nested                      0.9253    0.9250

    Redoing the wrong_open_general, which seemed to hurt test scores:
      wrong_open_two_subtrees - L4             0.9244    0.9220
      everything else w/o ambiguous open/open fix  0.9259    0.9241
      everything w/ open_two_subtrees          0.9261    0.9246
      w/ ambiguous open_three_subtrees         0.9264    0.9243

    Testing three different possible repairs for shift-open:
      w/ ambiguous open_three_subtrees         0.9264    0.9243
      immediate close (unary)                  0.9267    0.9246
      close after first bracket                0.9265    0.9256
      close after last bracket                 0.9264    0.9240

    Testing three possible repairs for close-open-shift/shift
      w/ ambiguous open_three_subtrees         0.9264    0.9243
      unambiguous c-o-s/shift                  0.9265    0.9246
      ambiguous c-o-s/shift closed early       0.9262    0.9246
      ambiguous c-o-s/shift closed late        0.9259    0.9245

    Testing three possible repairs for close-shift/shift
      w/ ambiguous open_three_subtrees         0.9264    0.9243
      unambiguous c-s/shift                    0.9253    0.9239
      ambiguous c-s/shift closed early         0.9259    0.9235
      ambiguous c-s/shift closed late          0.9252    0.9241
      ambiguous c-s/shift predicted            0.9264    0.9243

    --------------------------------------------------------

    Running ID experiments to verify some of the above findings
    no charlm or bert, only 200 epochs

    Comparing wrong_open fixes
      w/ ambiguous open_two_subtrees           0.8448    0.8335
      w/ ambiguous open_three_subtrees         0.8424    0.8336

    Testing three possible repairs for close-shift/shift
      unambiguous c-s/shift                    0.8448    0.8360
      ambiguous c-s/shift closed early         0.8425    0.8352
      ambiguous c-s/shift closed late          0.8452    0.8334

    --------------------------------------------------------

    Running ID experiments to verify some of the above findings
    bert + peft, only 200 epochs

    Comparing wrong_open fixes
      w/o ambiguous open/open fix              0.8923    0.8834
      w/ ambiguous open_two_subtrees           0.8908    0.8828
      w/ ambiguous open_three_subtrees         0.8901    0.8801

    Testing three possible repairs for close-shift/shift
      unambiguous c-s/shift                    0.8921    0.8825
      ambiguous c-s/shift closed early         0.8924    0.8841
      ambiguous c-s/shift closed late          0.8921    0.8806
      ambiguous c-s/shift predicted            0.8923    0.8835

    --------------------------------------------------------

    Running DE experiments to verify some of the above findings
    bert + peft, only 200 epochs

    Comparing wrong_open fixes
      w/o ambiguous open/open fix              0.9576    0.9402
      w/ ambiguous open_two_subtrees           0.9570    0.9410
      w/ ambiguous open_three_subtrees         0.9569    0.9412

    Testing three possible repairs for close-shift/shift
      unambiguous c-s/shift                    0.9566    0.9408
      ambiguous c-s/shift closed early         0.9564    0.9394
      ambiguous c-s/shift closed late          0.9572    0.9408
      ambiguous c-s/shift predicted            0.9571    0.9404

    --------------------------------------------------------

    Running IT experiments to verify some of the above findings
    bert + peft, only 200 epochs

    Comparing wrong_open fixes
      w/o ambiguous open/open fix              0.8380    0.8361
      w/ ambiguous open_two_subtrees           0.8377    0.8351
      w/ ambiguous open_three_subtrees         0.8381    0.8368

    Testing three possible repairs for close-shift/shift
      unambiguous c-s/shift                    0.8376    0.8392
      ambiguous c-s/shift closed early         0.8363    0.8359
      ambiguous c-s/shift closed late          0.8365    0.8383
      ambiguous c-s/shift predicted            0.8379    0.8371

    --------------------------------------------------------

    Running ZH experiments to verify some of the above findings
    bert + peft, only 200 epochs

    Comparing wrong_open fixes
      w/o ambiguous open/open fix              0.9160    0.9143
      w/ ambiguous open_two_subtrees           0.9145    0.9144
      w/ ambiguous open_three_subtrees         0.9146    0.9142

    Testing three possible repairs for close-shift/shift
      unambiguous c-s/shift                    0.9155    0.9146
      ambiguous c-s/shift closed early         0.9145    0.9153
      ambiguous c-s/shift closed late          0.9138    0.9140
      ambiguous c-s/shift predicted            0.9154    0.9144

    --------------------------------------------------------

    Running VI experiments to verify some of the above findings
    bert + peft, only 200 epochs

    Comparing wrong_open fixes
      w/o ambiguous open/open fix              0.8282    0.7668
      w/ ambiguous open_two_subtrees           0.8272    0.7670
      w/ ambiguous open_three_subtrees         0.8282    0.7668

    Testing three possible repairs for close-shift/shift
      unambiguous c-s/shift                    0.8285    0.7683
      ambiguous c-s/shift closed early         0.8276    0.7678
      ambiguous c-s/shift closed late          0.8278    0.7668
      ambiguous c-s/shift predicted            0.8270    0.7668

    --------------------------------------------------------

    Testing a combination of ambiguous vs predicted transitions

    ambiguous
      EN: (no CSS_U)                           0.9258    0.9252
      ZH: (no CSS_U)                           0.9153    0.9145

    predicted
      EN: (no CSS_U)                           0.9264    0.9241
      ZH: (no CSS_U)                           0.9145    0.9141
    """
    def __new__(cls, fn, correct=False, debug=False):
        """
        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error

        correct: this represents a correct transition

        debug: always run this, as it just counts statistics
        """
        # members are numbered 1..N in declaration order - the ordering is
        # significant, since oracle levels are expressed in terms of these values
        value = len(cls.__members__)
        obj = object.__new__(cls)
        obj._value_ = value + 1
        obj.fn = fn
        obj.correct = correct
        obj.debug = debug
        return obj

    @property
    def is_correct(self):
        # True only for the CORRECT member, which represents no error at all
        return self.correct

    # The first section is a sequence of repairs when the parser
    # should have chosen NTx but instead chose NTy

    # Blocks of transitions which can be abstracted away to be
    # anything will be represented as S1, S2, etc... S for stuff

    # We carve out an exception for a wrong open at the root
    # The only possible transitions at this point are to close
    # the error and try again with the root
    WRONG_OPEN_ROOT_ERROR = (fix_wrong_open_root_error,)

    # The simplest form of such an error is when there is a sequence
    # of unary transitions and the parser chose a wrong parent.
    # Remember that a unary transition is represented by a pair
    # of transitions, NTx, Close.
    # In this case, the correct sequence was
    #   S1 NTx Close NTy Close NTz ...
    # but the parser chose NTy, NTz, etc
    # The repair in this case is to simply discard the unchosen
    # unary transitions and continue
    WRONG_OPEN_UNARY_CHAIN = (fix_wrong_open_unary_chain,)

    # Similar to the UNARY_CHAIN error, but in this case there is a
    # bunch of stuff (one or more constituents built) between the
    # missed open transition and the close transition
    WRONG_OPEN_STUFF_UNARY = (fix_wrong_open_stuff_unary,)

    # If the correct sequence is
    #   T1 O_x T2 C
    # and instead we predicted
    #   T1 O_y ...
    # this can be fixed with a unary transition after
    #   T1 O_y T2 C O_x C
    # note that this is technically ambiguous
    # could have done
    #   T1 O_x C O_y T2 C
    # but doing this should be easier for the parser to detect (untested)
    # also this way the same code paths can be used for two subtrees
    # and for multiple subtrees
    WRONG_OPEN_TWO_SUBTREES = (fix_wrong_open_two_subtrees,)

    # If the gold transition is an Open because it is part of
    # a unary transition, and the following transition is a
    # correct Shift or Close, we can just skip past the unary.
    MISSED_UNARY = (fix_missed_unary,)

    # Open -> Shift errors which don't just represent a unary
    # generally represent a missing bracket which cannot be
    # recovered using the in-order mechanism.  Dropping the
    # missing transition is generally the only fix.
    # (This means removing the corresponding Close)
    # One could theoretically create a new transition which
    # grabs two constituents, though
    OPEN_SHIFT = (fix_open_shift,)

    # Open -> Close is a rather drastic break in the
    # potential structure of the tree.  We can no longer
    # recover the missed Open, and we might not be able
    # to recover other following missed Opens as well.
    # In most cases, the only thing to do is reopen the
    # incorrectly closed outer bracket and keep going.
    OPEN_CLOSE = (fix_open_close,)

    # Similar to the Open -> Close error, but at least
    # in this case we are just introducing one wrong bracket
    # rather than also breaking some existing brackets.
    # The fix here is to reopen the closed bracket.
    SHIFT_CLOSE = (fix_shift_close,)

    # Specifically fixes an error where bracket X is
    # closed and then immediately opened to build a
    # new X bracket.  In this case, the simplest fix
    # will be to skip both the close and the new open
    # and continue from there.
    CLOSE_OPEN_SHIFT_NESTED = (fix_close_open_shift_nested,)

    # Fix an error where the correct sequence was to Close X, Open Y,
    # then continue building,
    # but instead the model did a Shift in place of C_X O_Y
    # The damage here is a recall error for the missed X and
    # a precision error for the incorrectly opened X
    # However, the Y can actually be recovered - whenever we finally
    # close X, we can then open Y
    # One form of that is unambiguous, that of
    #   T_A O_X T_B C O_Y T_C C
    # with only one subtree after the O_Y
    # In that case, the Close that would have closed Y
    # is the only place for the missing close of X
    # So we can produce the following:
    #   T_A O_X T_B T_C C O_Y C
    CLOSE_OPEN_SHIFT_UNAMBIGUOUS_BRACKET = (fix_close_open_shift_unambiguous_bracket,)

    # Similarly to WRONG_OPEN_TWO_SUBTREES, if the correct sequence is
    #   T1 O_x T2 T3 C
    # and instead we predicted
    #   T1 O_y ...
    # this can be fixed by closing O_y in any number of places
    #   T1 O_y T2 C O_x T3 C
    #   T1 O_y T2 C T3 O_x C
    # Either solution is a single precision error,
    # but keeps the O_x subtree correct
    # This is an ambiguous transition - we can experiment with different fixes
    WRONG_OPEN_MULTIPLE_SUBTREES = (fix_wrong_open_multiple_subtrees,)

    # the model's prediction matched the gold transition - no repair needed
    CORRECT = (None, True)

    # an error pattern not covered by any of the repairs above
    UNKNOWN = None

    # If the model is supposed to build a block after a Close
    # operation, attach that block to the piece to the left
    # a couple different variations on this were tried
    #   we tried attaching all constituents to the
    #     bracket which should have been closed
    #   we tried attaching exactly one constituent
    #   and we tried attaching only if there was
    #     exactly one following constituent
    # none of these improved f1.  for example, on the VI dataset, we
    # lost 0.15 F1 with the exactly one following constituent version
    # it might be worthwhile double checking some of the other
    # versions to make sure those also fail, though
    CLOSE_SHIFT_SHIFT = (fix_close_shift_shift_unambiguous,)

    # In the ambiguous close-shift/shift case, this closes the surrounding bracket
    # (which should have already been closed)
    # as soon as the next constituent is built
    # this turns
    #   (A (B s1 s2) s3 s4)
    # into
    #   (A (B s1 s2 s3) s4)
    CLOSE_SHIFT_SHIFT_AMBIGUOUS_EARLY = (fix_close_shift_shift_ambiguous_early,)

    # In the ambiguous close-shift/shift case, this closes the surrounding bracket
    # (which should have already been closed)
    # when the rest of the constituents in this bracket are built
    # this turns
    #   (A (B s1 s2) s3 s4)
    # into
    #   (A (B s1 s2 s3 s4))
    CLOSE_SHIFT_SHIFT_AMBIGUOUS_LATE = (fix_close_shift_shift_ambiguous_late,)

    # For the close-shift/shift errors which are ambiguous,
    # this uses the model's predictions to guess which block
    # to put the close after
    CLOSE_SHIFT_SHIFT_AMBIGUOUS_PREDICTED = (fix_close_shift_shift_ambiguous_predicted,)

    # If a sequence should have gone Close - Open - Shift,
    # and instead we went Shift,
    # we need to close the previous bracket
    # If it is ambiguous
    # such as Close - Open - Shift - Shift
    # close the bracket ASAP
    # eg, Shift - Close - Open - Shift
    CLOSE_OPEN_SHIFT_AMBIGUOUS_BRACKET_EARLY = (fix_close_open_shift_ambiguous_bracket_early,)

    # for Close - Open - Shift - Shift
    # close the bracket as late as possible
    # eg, Shift - Shift - Close - Open
    CLOSE_OPEN_SHIFT_AMBIGUOUS_BRACKET_LATE = (fix_close_open_shift_ambiguous_bracket_late,)

    # If the sequence should have gone
    # Close - Open - Shift
    # and instead we predicted a Shift
    # in a context where closing the bracket would be ambiguous
    # we use the model to predict where the close should actually happen
    CLOSE_OPEN_SHIFT_AMBIGUOUS_PREDICTED = (fix_close_open_shift_ambiguous_predicted,)

    # This particular repair effectively turns the shift -> ambiguous open
    # into a unary transition
    SHIFT_OPEN_UNARY_CLOSE = (ambiguous_shift_open_unary_close,)

    # Fix the shift -> ambiguous open by closing after the first constituent
    # This is an ambiguous solution because it could also be closed either
    # as a unary transition or with a close at the end of the outer bracket
    SHIFT_OPEN_EARLY_CLOSE = (ambiguous_shift_open_early_close,)

    # Fix the shift -> ambiguous open by closing after all constituents
    # This is an ambiguous solution because it could also be closed either
    # as a unary transition or with a close at the end of the first constituent
    SHIFT_OPEN_LATE_CLOSE = (ambiguous_shift_open_late_close,)

    # Use the model to predict when to close!
    # The different options for where to put the Close are put into the model,
    # and the highest scoring close is used
    SHIFT_OPEN_PREDICTED_CLOSE = (ambiguous_shift_open_predicted_close,)

    # debug=True members below: they never repair anything, they only
    # count occurrences of otherwise-unhandled error patterns

    OTHER_CLOSE_SHIFT = (report_close_shift, False, True)

    OTHER_CLOSE_OPEN = (report_close_open, False, True)

    OTHER_OPEN_OPEN = (report_open_open, False, True)

    OTHER_OPEN_CLOSE = (report_open_close, False, True)

    OTHER_OPEN_SHIFT = (report_open_shift, False, True)

    OTHER_SHIFT_OPEN = (report_shift_open, False, True)

    # any other open transition we get wrong, which hasn't already
    # been carved out as an exception above, we just accept the
    # incorrect Open and keep going
    #
    # TODO: check if there is a way to improve this
    # it appears to hurt scores simply by existing
    # explanation: this is wrong logic
    # Suppose the correct sequence had been
    #   T1 open(NP) T2 T3 close
    # Instead we had done
    #   T1 open(VP) T2 T3 close
    # We can recover the missing NP!
    #   T1 open(VP) T2 close open(NP) T3 close
    # Can also recover it as
    #   T1 open(VP) T2 T3 close open(NP) close
    # So this is actually an ambiguous transition
    # except in the case of
    #   T1 open(...) close
    # In this case, a unary transition can make it so we only have
    # a precision error, not also a recall error
    # Currently, the approach is to put this after the default fixes
    # and use the two & more-than-two versions of the fix above
    WRONG_OPEN_GENERAL = (fix_wrong_open_general,)
class InOrderOracle(DynamicOracle):
    """Dynamic oracle for the in-order transition scheme.

    Thin wrapper around DynamicOracle which plugs in the RepairType enum
    defined in this module as the set of available repairs.
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)
stanza/stanza/models/constituency/lstm_model.py ADDED
@@ -0,0 +1,1178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A version of the BaseModel which uses LSTMs to predict the correct next transition
3
+ based on the current known state.
4
+
5
+ The primary purpose of this class is to implement the prediction of the next
6
+ transition, which is done by concatenating the output of an LSTM operated over
7
+ previous transitions, the words, and the partially built constituents.
8
+
9
+ A complete processing of a sentence is as follows:
10
+ 1) Run the input words through an encoder.
11
+ The encoder includes some or all of the following:
12
+ pretrained word embedding
13
+ finetuned word embedding for training set words - "delta_embedding"
14
+ POS tag embedding
15
+ pretrained charlm representation
16
+ BERT or similar large language model representation
17
+ attention transformer over the previous inputs
18
+ labeled attention transformer over the first attention layer
19
+ The encoded input is then put through a bi-lstm, giving a word representation
20
+ 2) Transitions are put in an embedding, and transitions already used are tracked
21
+ in an LSTM
22
+ 3) Constituents already built are also processed in an LSTM
23
+ 4) Every transition is chosen by taking the output of the current word position,
24
+ the transition LSTM, and the constituent LSTM, and classifying the next
25
+ transition
26
+ 5) Transitions are repeated (with constraints) until the sentence is completed
27
+ """
28
+
29
+ from collections import namedtuple
30
+ import copy
31
+ from enum import Enum
32
+ import logging
33
+ import math
34
+ import random
35
+
36
+ import torch
37
+ import torch.nn as nn
38
+ from torch.nn.utils.rnn import pack_padded_sequence
39
+
40
+ from stanza.models.common.bert_embedding import extract_bert_embeddings
41
+ from stanza.models.common.maxout_linear import MaxoutLinear
42
+ from stanza.models.common.utils import attach_bert_model, unsort
43
+ from stanza.models.common.vocab import PAD_ID, UNK_ID
44
+ from stanza.models.constituency.base_model import BaseModel
45
+ from stanza.models.constituency.label_attention import LabelAttentionModule
46
+ from stanza.models.constituency.lstm_tree_stack import LSTMTreeStack
47
+ from stanza.models.constituency.parse_transitions import TransitionScheme
48
+ from stanza.models.constituency.parse_tree import Tree
49
+ from stanza.models.constituency.partitioned_transformer import PartitionedTransformerModule
50
+ from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
51
+ from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack
52
+ from stanza.models.constituency.tree_stack import TreeStack
53
+ from stanza.models.constituency.utils import build_nonlinearity, initialize_linear
54
+
55
+ logger = logging.getLogger('stanza')
56
+ tlogger = logging.getLogger('stanza.constituency.trainer')
57
+
58
# value: the underlying word; hx: its encoded representation from the word LSTM
WordNode = namedtuple("WordNode", ['value', 'hx'])

# lstm_hx & lstm_cx are the hidden & cell states of the LSTM going across constituents
# tree_hx and tree_cx are the states of the lstm going up the constituents in the case of the tree_lstm combination method
Constituent = namedtuple("Constituent", ['value', 'tree_hx', 'tree_cx'])
63
+
64
+ # The sentence boundary vectors are marginally useful at best.
65
+ # However, they make it much easier to use non-bert layers as input to
66
+ # attention layers, as the attention layers work better when they have
67
+ # an index 0 to attend to.
68
class SentenceBoundary(Enum):
    """Which sentence boundary vectors to add to the word representations.

    Per the note above: the boundary vectors themselves are only marginally
    useful, but they give attention layers an index 0 to attend to.
    """
    NONE = 1          # no boundary vectors
    WORDS = 2         # boundary vectors around the word sequence only
    EVERYTHING = 3    # boundary vectors everywhere they apply
72
+
73
class StackHistory(Enum):
    """How to track the history of a stack (transitions or constituents): LSTM or attention."""
    LSTM = 1
    ATTN = 2
+
77
+ # How to compose constituent children into new constituents
78
+ # MAX is simply take the max value of the children
79
+ # this is surprisingly effective
80
+ # for example, a Turkish dataset went from 81-81.5 dev, 75->75.5 test
81
+ # BILSTM is the method described in the papers of making an lstm
82
+ # out of the constituents
83
+ # BILSTM_MAX is the same as BILSTM, but instead of using a Linear
84
+ # to reduce the outputs of the lstm, we first take the max
85
+ # and then use a linear to reduce the max
86
+ # BIGRAM combines pairs of children and then takes the max over those
87
+ # ATTN means to put an attention layer over the children nodes
88
+ # we then take the max of the children with their attention
89
+ #
90
+ # Experiments show that MAX is noticeably better than the other options
91
+ # On ja_alt, here are a few results after 200 iterations,
92
+ # averaged over 5 iterations:
93
+ # MAX: 0.8985
94
+ # BILSTM: 0.8964
95
+ # BILSTM_MAX: 0.8973
96
+ # BIGRAM: 0.8982
97
+ #
98
+ # The MAX method has a linear transform after the max.
99
+ # Removing that transform makes the score go down to 0.8982
100
+ #
101
+ # We tried a few varieties of BILSTM_MAX
102
+ # In particular:
103
+ # max over LSTM, combining forward & backward using the max: 0.8970
104
+ # max over forward & backward separately, then reduce: 0.8970
105
+ # max over forward & backward only over 1:-1
106
+ # (eg, leave out the node embedding): 0.8969
107
+ # same as previous, but split the reduce into 2 pieces: 0.8973
108
+ # max over forward & backward separately, then reduce as
109
+ # 1/2(F + B) + W(F,B)
110
+ # the idea being that this way F and B are guaranteed
111
+ # to be represented: 0.8971
112
+ #
113
+ # BIGRAM is an attempt to mix information from nodes
114
+ # when building constituents, but it didn't help
115
+ # The first example, just taking pairs and learning
116
+ # a transform, went to NaN. Likely the transform
117
+ # expanded the embedding too much. Switching it to
118
+ # scale the matrix by 0.5 didn't go to Nan, but only
119
+ # resulted in 0.8982
120
+ #
121
+ # A couple varieties of ATTN:
122
+ # first an input linear, then attn, then an output linear
123
+ # the upside of this would be making the dimension of the attn
124
+ # independent from the rest of the model
125
+ # however, this caused an expansion in the magnitude of the vectors,
126
+ # resulting in NaN for deep enough trees
127
+ # adding layernorm or tanh to balance this out resulted in
128
+ # disappointing performance
129
+ # tanh: 0.8972
130
+ # another alternative not tested yet: lower initialization weights
131
+ # and enforce that the norms of the matrices are low enough that
132
+ # exponential explosion up the layers of the tree doesn't happen
133
+ # just an attention layer means hidden_size % reduce_heads == 0
134
+ # that is simple enough to enforce by slightly changing hidden_size
135
+ # if needed
136
+ # appending the embedding for the open state to the start of the
137
+ # sequence of children and taking only the content nodes
138
+ # was very disappointing: 0.8967
139
+ # taking the entire sequence of children including the open state
140
+ # embedding resulted in 0.8973
141
+ # long story short, this looks like an idea that should work, but it
142
+ # doesn't help. suggestions welcome for improving these results
143
+ #
144
+ # The current TREE_LSTM_CX mechanism uses a word's embedding
145
+ # as the hx and a trained embedding over tags as the cx 0.8996
146
+ # This worked slightly better than 0s for cx (TREE_LSTM) 0.8992
147
+ # A variant of TREE_LSTM which didn't work out:
148
+ # nodes are combined with an LSTM
149
+ # hx & cx are embeddings of the node type (eg S, NP, etc)
150
+ # input is the max over children: 0.8977
151
+ # Another variant which didn't work: use the word embedding
152
+ # as input to the same LSTM to get hx & cx 0.8985
153
+ # Note that although the scores for TREE_LSTM_CX are slightly higher
154
+ # than MAX for the JA dataset, the benefit was not as clear for EN,
155
+ # so we left the default at MAX.
156
+ # For example, on English WSJ, before switching to Bert POS and
157
+ # a learned Bert mixing layer, a comparison of 5x models trained
158
+ # for 400 iterations got dev scores of:
159
+ # TREE_LSTM_CX 0.9589
160
+ # MAX 0.9593
161
+ #
162
+ # UNTIED_MAX has a different reduce_linear for each type of
163
+ # constituent in the model. Similar to the different linear
164
+ # maps used in the CVG paper from Socher, Bauer, Manning, Ng
165
+ # This is implemented as a large CxHxH parameter,
166
+ # with num_constituent layers of hidden-hidden transform,
167
+ # along with a CxH bias parameter.
168
+ # Essentially C Linears stacked on top of each other,
169
+ # but in a parameter so that indexing can be done quickly.
170
+ # Unfortunately this does not beat out MAX with one combined linear.
171
+ # On an experiment on WSJ with all the best settings as of early
172
+ # October 2022, such as a Bert model POS tagger:
173
+ # MAX 0.9597
174
+ # UNTIED_MAX 0.9592
175
+ # Furthermore, starting from a finished MAX model and restarting
176
+ # by splitting the MAX layer into multiple pieces did not improve.
177
+ #
178
+ # KEY has a single Key which is used for a facsimile of ATTN
179
+ # each incoming subtree has its values weighted by a Query
180
+ # then the Key is used to calculate a softmax
181
+ # finally, a Value is used to scale the subtrees
182
+ # reduce_heads is used to determine the number of heads
183
+ # There is an option to use or not use position information
184
+ # using a sinusoidal position embedding
185
+ # UNTIED_KEY is the same, but has a different key
186
+ # for each possible constituent
187
+ # On a VI dataset:
188
+ # MAX 0.82064
189
+ # KEY (pos, 8) 0.81739
190
+ # UNTIED_KEY (pos, 8) 0.82046
191
+ # UNTIED_KEY (pos, 4) 0.81742
192
+ # Attempted to add a linear to mix the attn heads together,
193
+ # but that was awful: 0.81567
194
+ # Adding two position vectors, one in each direction, did not help:
195
+ # UNTIED_KEY (2x pos, 8) 0.8188
196
+ # To redo that experiment, double the width of reduce_query and
197
+ # reduce_value, then call reduce_position on nhx, flip it,
198
+ # and call reduce_position again
199
+ # Evidently the experiments to try should be:
200
+ # no pos at all
201
+ # more heads
202
class ConstituencyComposition(Enum):
    """How to compose constituent children into a new constituent representation.

    See the extensive experimental notes in the comment block above this class
    for comparisons between the methods; MAX is the default as it performed
    best or near-best across datasets.
    """
    BILSTM = 1        # bi-lstm over the children, linear reduce of the outputs
    MAX = 2           # elementwise max over the children, then a linear transform
    TREE_LSTM = 3     # lstm going up the tree, cx initialized to zeros
    BILSTM_MAX = 4    # bi-lstm, then max over its outputs before the linear reduce
    BIGRAM = 5        # combine pairs of children, then max over those
    ATTN = 6          # attention layer over the children, then max
    TREE_LSTM_CX = 7  # tree lstm with word embedding as hx, tag embedding as cx
    UNTIED_MAX = 8    # MAX with a separate reduce linear per constituent type
    KEY = 9           # single learned key attention over the children
    UNTIED_KEY = 10   # KEY with a separate key per constituent type
+
214
+ class LSTMModel(BaseModel, nn.Module):
215
+ def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_tokenizer, force_bert_saved, peft_name, transitions, constituents, tags, words, rare_words, root_labels, constituent_opens, unary_limit, args):
216
+ """
217
+ pretrain: a Pretrain object
218
+ transitions: a list of all possible transitions which will be
219
+ used to build trees
220
+ constituents: a list of all possible constituents in the treebank
221
+ tags: a list of all possible tags in the treebank
222
+ words: a list of all known words, used for a delta word embedding.
223
+ note that there will be an attempt made to learn UNK words as well,
224
+ and tags by themselves may help UNK words
225
+ rare_words: a list of rare words, used to occasionally replace with UNK
226
+ root_labels: probably ROOT, although apparently some treebanks like TOP or even s
227
+ constituent_opens: a list of all possible open nodes which will go on the stack
228
+ - this might be different from constituents if there are nodes
229
+ which represent multiple constituents at once
230
+ args: hidden_size, transition_hidden_size, etc as gotten from
231
+ constituency_parser.py
232
+
233
+ Note that it might look like a hassle to pass all of this in
234
+ when it can be collected directly from the trees themselves.
235
+ However, that would only work at train time. At eval or
236
+ pipeline time we will load the lists from the saved model.
237
+ """
238
+ super().__init__(transition_scheme=args['transition_scheme'], unary_limit=unary_limit, reverse_sentence=args.get('reversed', False), root_labels=root_labels)
239
+
240
+ self.args = args
241
+ self.unsaved_modules = []
242
+
243
+ emb_matrix = pretrain.emb
244
+ self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
245
+
246
+ # replacing NBSP picks up a whole bunch of words for VI
247
+ self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }
248
+ # precompute tensors for the word indices
249
+ # the tensors should be put on the GPU if needed by calling to(device)
250
+ self.register_buffer('vocab_tensors', torch.tensor(range(len(pretrain.vocab)), requires_grad=False))
251
+ self.vocab_size = emb_matrix.shape[0]
252
+ self.embedding_dim = emb_matrix.shape[1]
253
+
254
+ self.constituents = sorted(list(constituents))
255
+
256
+ self.hidden_size = self.args['hidden_size']
257
+ self.constituency_composition = self.args.get("constituency_composition", ConstituencyComposition.BILSTM)
258
+ if self.constituency_composition in (ConstituencyComposition.ATTN, ConstituencyComposition.KEY, ConstituencyComposition.UNTIED_KEY):
259
+ self.reduce_heads = self.args['reduce_heads']
260
+ if self.hidden_size % self.reduce_heads != 0:
261
+ self.hidden_size = self.hidden_size + self.reduce_heads - (self.hidden_size % self.reduce_heads)
262
+
263
+ if args['constituent_stack'] == StackHistory.ATTN:
264
+ self.reduce_heads = self.args['reduce_heads']
265
+ if self.hidden_size % args['constituent_heads'] != 0:
266
+ # TODO: technically we should either use the LCM of this and reduce_heads, or just have two separate fields
267
+ self.hidden_size = self.hidden_size + args['constituent_heads'] - (hidden_size % args['constituent_heads'])
268
+ if self.constituency_composition == ConstituencyComposition.ATTN and self.hidden_size % self.reduce_heads != 0:
269
+ raise ValueError("--reduce_heads and --constituent_heads not compatible!")
270
+
271
+ self.transition_hidden_size = self.args['transition_hidden_size']
272
+ if args['transition_stack'] == StackHistory.ATTN:
273
+ if self.transition_hidden_size % args['transition_heads'] > 0:
274
+ logger.warning("transition_hidden_size %d %% transition_heads %d != 0. reconfiguring", transition_hidden_size, args['transition_heads'])
275
+ self.transition_hidden_size = self.transition_hidden_size + args['transition_heads'] - (self.transition_hidden_size % args['transition_heads'])
276
+
277
+ self.tag_embedding_dim = self.args['tag_embedding_dim']
278
+ self.transition_embedding_dim = self.args['transition_embedding_dim']
279
+ self.delta_embedding_dim = self.args['delta_embedding_dim']
280
+
281
+ self.word_input_size = self.embedding_dim + self.tag_embedding_dim + self.delta_embedding_dim
282
+
283
+ if forward_charlm is not None:
284
+ self.add_unsaved_module('forward_charlm', forward_charlm)
285
+ self.word_input_size += self.forward_charlm.hidden_dim()
286
+ if not forward_charlm.is_forward_lm:
287
+ raise ValueError("Got a backward charlm as a forward charlm!")
288
+ else:
289
+ self.forward_charlm = None
290
+ if backward_charlm is not None:
291
+ self.add_unsaved_module('backward_charlm', backward_charlm)
292
+ self.word_input_size += self.backward_charlm.hidden_dim()
293
+ if backward_charlm.is_forward_lm:
294
+ raise ValueError("Got a forward charlm as a backward charlm!")
295
+ else:
296
+ self.backward_charlm = None
297
+
298
+ self.delta_words = sorted(set(words))
299
+ self.delta_word_map = { word: i+2 for i, word in enumerate(self.delta_words) }
300
+ assert PAD_ID == 0
301
+ assert UNK_ID == 1
302
+ # initialization is chosen based on the observed values of the norms
303
+ # after several long training cycles
304
+ # (this is true for other embeddings and embedding-like vectors as well)
305
+ # the experiments show this slightly helps were done with
306
+ # Adadelta and the correct initialization may be slightly
307
+ # different for a different optimizer.
308
+ # in fact, it is likely a scheme other than normal_ would
309
+ # be better - the optimizer tends to learn the weights
310
+ # rather close to 0 before learning in the direction it
311
+ # actually wants to go
312
+ self.delta_embedding = nn.Embedding(num_embeddings = len(self.delta_words)+2,
313
+ embedding_dim = self.delta_embedding_dim,
314
+ padding_idx = 0)
315
+ nn.init.normal_(self.delta_embedding.weight, std=0.05)
316
+ self.register_buffer('delta_tensors', torch.tensor(range(len(self.delta_words) + 2), requires_grad=False))
317
+
318
+ self.rare_words = set(rare_words)
319
+
320
+ self.tags = sorted(list(tags))
321
+ if self.tag_embedding_dim > 0:
322
+ self.tag_map = { t: i+2 for i, t in enumerate(self.tags) }
323
+ self.tag_embedding = nn.Embedding(num_embeddings = len(tags)+2,
324
+ embedding_dim = self.tag_embedding_dim,
325
+ padding_idx = 0)
326
+ nn.init.normal_(self.tag_embedding.weight, std=0.25)
327
+ self.register_buffer('tag_tensors', torch.tensor(range(len(self.tags) + 2), requires_grad=False))
328
+
329
+ self.num_lstm_layers = self.args['num_lstm_layers']
330
+ self.num_tree_lstm_layers = self.args['num_tree_lstm_layers']
331
+ self.lstm_layer_dropout = self.args['lstm_layer_dropout']
332
+
333
+ self.word_dropout = nn.Dropout(self.args['word_dropout'])
334
+ self.predict_dropout = nn.Dropout(self.args['predict_dropout'])
335
+ self.lstm_input_dropout = nn.Dropout(self.args['lstm_input_dropout'])
336
+
337
+ # also register a buffer of zeros so that we can always get zeros on the appropriate device
338
+ self.register_buffer('word_zeros', torch.zeros(self.hidden_size * self.num_tree_lstm_layers))
339
+ self.register_buffer('constituent_zeros', torch.zeros(self.num_lstm_layers, 1, self.hidden_size))
340
+
341
+ # possibly add a couple vectors for bookends of the sentence
342
+ # We put the word_start and word_end here, AFTER counting the
343
+ # charlm dimension, but BEFORE counting the bert dimension,
344
+ # as we want word_start and word_end to not have dimensions
345
+ # for the bert embedding. The bert model will add its own
346
+ # start and end representation.
347
+ self.sentence_boundary_vectors = self.args['sentence_boundary_vectors']
348
+ if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
349
+ self.register_parameter('word_start_embedding', torch.nn.Parameter(0.2 * torch.randn(self.word_input_size, requires_grad=True)))
350
+ self.register_parameter('word_end_embedding', torch.nn.Parameter(0.2 * torch.randn(self.word_input_size, requires_grad=True)))
351
+
352
+ # we set up the bert AFTER building word_start and word_end
353
+ # so that we can use the charlm endpoint values rather than
354
+ # try to train our own
355
+ self.force_bert_saved = force_bert_saved or self.args['bert_finetune'] or self.args['stage1_bert_finetune']
356
+ attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), self.force_bert_saved)
357
+ self.peft_name = peft_name
358
+
359
+ if bert_model is not None:
360
+ if bert_tokenizer is None:
361
+ raise ValueError("Cannot have a bert model without a tokenizer")
362
+ self.bert_dim = self.bert_model.config.hidden_size
363
+ if args['bert_hidden_layers']:
364
+ # The average will be offset by 1/N so that the default zeros
365
+ # represents an average of the N layers
366
+ if args['bert_hidden_layers'] > bert_model.config.num_hidden_layers:
367
+ # limit ourselves to the number of layers actually available
368
+ # note that we can +1 because of the initial embedding layer
369
+ args['bert_hidden_layers'] = bert_model.config.num_hidden_layers + 1
370
+ self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
371
+ nn.init.zeros_(self.bert_layer_mix.weight)
372
+ else:
373
+ # an average of layers 2, 3, 4 will be used
374
+ # (for historic reasons)
375
+ self.bert_layer_mix = None
376
+ self.word_input_size = self.word_input_size + self.bert_dim
377
+
378
+ self.partitioned_transformer_module = None
379
+ self.pattn_d_model = 0
380
+ if LSTMModel.uses_pattn(self.args):
381
+ # Initializations of parameters for the Partitioned Attention
382
+ # round off the size of the model so that it divides in half evenly
383
+ self.pattn_d_model = self.args['pattn_d_model'] // 2 * 2
384
+
385
+ # Initializations for the Partitioned Attention
386
+ # experiments suggest having a bias does not help here
387
+ self.partitioned_transformer_module = PartitionedTransformerModule(
388
+ self.args['pattn_num_layers'],
389
+ d_model=self.pattn_d_model,
390
+ n_head=self.args['pattn_num_heads'],
391
+ d_qkv=self.args['pattn_d_kv'],
392
+ d_ff=self.args['pattn_d_ff'],
393
+ ff_dropout=self.args['pattn_relu_dropout'],
394
+ residual_dropout=self.args['pattn_residual_dropout'],
395
+ attention_dropout=self.args['pattn_attention_dropout'],
396
+ word_input_size=self.word_input_size,
397
+ bias=self.args['pattn_bias'],
398
+ morpho_emb_dropout=self.args['pattn_morpho_emb_dropout'],
399
+ timing=self.args['pattn_timing'],
400
+ encoder_max_len=self.args['pattn_encoder_max_len']
401
+ )
402
+ self.word_input_size += self.pattn_d_model
403
+
404
+ self.label_attention_module = None
405
+ if LSTMModel.uses_lattn(self.args):
406
+ if self.partitioned_transformer_module is None:
407
+ logger.error("Not using Labeled Attention, as the Partitioned Attention module is not used")
408
+ else:
409
+ # TODO: think of a couple ways to use alternate inputs
410
+ # for example, could pass in the word inputs with a positional embedding
411
+ # that would also allow it to work in the case of no partitioned module
412
+ if self.args['lattn_combined_input']:
413
+ self.lattn_d_input = self.word_input_size
414
+ else:
415
+ self.lattn_d_input = self.pattn_d_model
416
+ self.label_attention_module = LabelAttentionModule(self.lattn_d_input,
417
+ self.args['lattn_d_input_proj'],
418
+ self.args['lattn_d_kv'],
419
+ self.args['lattn_d_kv'],
420
+ self.args['lattn_d_l'],
421
+ self.args['lattn_d_proj'],
422
+ self.args['lattn_combine_as_self'],
423
+ self.args['lattn_resdrop'],
424
+ self.args['lattn_q_as_matrix'],
425
+ self.args['lattn_residual_dropout'],
426
+ self.args['lattn_attention_dropout'],
427
+ self.pattn_d_model // 2,
428
+ self.args['lattn_d_ff'],
429
+ self.args['lattn_relu_dropout'],
430
+ self.args['lattn_partitioned'])
431
+ self.word_input_size = self.word_input_size + self.args['lattn_d_proj']*self.args['lattn_d_l']
432
+
433
+ self.word_lstm = nn.LSTM(input_size=self.word_input_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)
434
+
435
+ # after putting the word_delta_tag input through the word_lstm, we get back
436
+ # hidden_size * 2 output with the front and back lstms concatenated.
437
+ # this transforms it into hidden_size with the values mixed together
438
+ self.word_to_constituent = nn.Linear(self.hidden_size * 2, self.hidden_size * self.num_tree_lstm_layers)
439
+ initialize_linear(self.word_to_constituent, self.args['nonlinearity'], self.hidden_size * 2)
440
+
441
+ self.transitions = sorted(list(transitions))
442
+ self.transition_map = { t: i for i, t in enumerate(self.transitions) }
443
+ # precompute tensors for the transitions
444
+ self.register_buffer('transition_tensors', torch.tensor(range(len(transitions)), requires_grad=False))
445
+ self.transition_embedding = nn.Embedding(num_embeddings = len(transitions),
446
+ embedding_dim = self.transition_embedding_dim)
447
+ nn.init.normal_(self.transition_embedding.weight, std=0.25)
448
+ if args['transition_stack'] == StackHistory.LSTM:
449
+ self.transition_stack = LSTMTreeStack(input_size=self.transition_embedding_dim,
450
+ hidden_size=self.transition_hidden_size,
451
+ num_lstm_layers=self.num_lstm_layers,
452
+ dropout=self.lstm_layer_dropout,
453
+ uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,
454
+ input_dropout=self.lstm_input_dropout)
455
+ elif args['transition_stack'] == StackHistory.ATTN:
456
+ self.transition_stack = TransformerTreeStack(input_size=self.transition_embedding_dim,
457
+ output_size=self.transition_hidden_size,
458
+ input_dropout=self.lstm_input_dropout,
459
+ use_position=True,
460
+ num_heads=args['transition_heads'])
461
+ else:
462
+ raise ValueError("Unhandled transition_stack StackHistory: {}".format(args['transition_stack']))
463
+
464
+ self.constituent_opens = sorted(list(constituent_opens))
465
+ # an embedding for the spot on the constituent LSTM taken up by the Open transitions
466
+ # the pattern when condensing constituents is embedding - con1 - con2 - con3 - embedding
467
+ # TODO: try the two ends have different embeddings?
468
+ self.constituent_open_map = { x: i for (i, x) in enumerate(self.constituent_opens) }
469
+ self.constituent_open_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),
470
+ embedding_dim = self.hidden_size)
471
+ nn.init.normal_(self.constituent_open_embedding.weight, std=0.2)
472
+
473
+ # input_size is hidden_size - could introduce a new constituent_size instead if we liked
474
+ if args['constituent_stack'] == StackHistory.LSTM:
475
+ self.constituent_stack = LSTMTreeStack(input_size=self.hidden_size,
476
+ hidden_size=self.hidden_size,
477
+ num_lstm_layers=self.num_lstm_layers,
478
+ dropout=self.lstm_layer_dropout,
479
+ uses_boundary_vector=self.sentence_boundary_vectors is SentenceBoundary.EVERYTHING,
480
+ input_dropout=self.lstm_input_dropout)
481
+ elif args['constituent_stack'] == StackHistory.ATTN:
482
+ self.constituent_stack = TransformerTreeStack(input_size=self.hidden_size,
483
+ output_size=self.hidden_size,
484
+ input_dropout=self.lstm_input_dropout,
485
+ use_position=True,
486
+ num_heads=args['constituent_heads'])
487
+ else:
488
+ raise ValueError("Unhandled constituent_stack StackHistory: {}".format(args['transition_stack']))
489
+
490
+
491
+ if args['combined_dummy_embedding']:
492
+ self.dummy_embedding = self.constituent_open_embedding
493
+ else:
494
+ self.dummy_embedding = nn.Embedding(num_embeddings = len(self.constituent_open_map),
495
+ embedding_dim = self.hidden_size)
496
+ nn.init.normal_(self.dummy_embedding.weight, std=0.2)
497
+ self.register_buffer('constituent_open_tensors', torch.tensor(range(len(constituent_opens)), requires_grad=False))
498
+
499
+ # TODO: refactor
500
+ if (self.constituency_composition == ConstituencyComposition.BILSTM or
501
+ self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
502
+ # forward and backward pieces for crunching several
503
+ # constituents into one, combined into a bi-lstm
504
+ # TODO: make the hidden size here an option?
505
+ self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_lstm_layers, bidirectional=True, dropout=self.lstm_layer_dropout)
506
+ # affine transformation from bi-lstm reduce to a new hidden layer
507
+ if self.constituency_composition == ConstituencyComposition.BILSTM:
508
+ self.reduce_linear = nn.Linear(self.hidden_size * 2, self.hidden_size)
509
+ initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size * 2)
510
+ else:
511
+ self.reduce_forward = nn.Linear(self.hidden_size, self.hidden_size)
512
+ self.reduce_backward = nn.Linear(self.hidden_size, self.hidden_size)
513
+ initialize_linear(self.reduce_forward, self.args['nonlinearity'], self.hidden_size)
514
+ initialize_linear(self.reduce_backward, self.args['nonlinearity'], self.hidden_size)
515
+ elif self.constituency_composition == ConstituencyComposition.MAX:
516
+ # transformation to turn several constituents into one new constituent
517
+ self.reduce_linear = nn.Linear(self.hidden_size, self.hidden_size)
518
+ initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)
519
+ elif self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
520
+ # transformation to turn several constituents into one new constituent
521
+ self.register_parameter('reduce_linear_weight', torch.nn.Parameter(torch.randn(len(constituent_opens), self.hidden_size, self.hidden_size, requires_grad=True)))
522
+ self.register_parameter('reduce_linear_bias', torch.nn.Parameter(torch.randn(len(constituent_opens), self.hidden_size, requires_grad=True)))
523
+ for layer_idx in range(len(constituent_opens)):
524
+ nn.init.kaiming_normal_(self.reduce_linear_weight[layer_idx], nonlinearity=self.args['nonlinearity'])
525
+ nn.init.uniform_(self.reduce_linear_bias, 0, 1 / (self.hidden_size * 2) ** 0.5)
526
+ elif self.constituency_composition == ConstituencyComposition.BIGRAM:
527
+ self.reduce_linear = nn.Linear(self.hidden_size, self.hidden_size)
528
+ self.reduce_bigram = nn.Linear(self.hidden_size * 2, self.hidden_size)
529
+ initialize_linear(self.reduce_linear, self.args['nonlinearity'], self.hidden_size)
530
+ initialize_linear(self.reduce_bigram, self.args['nonlinearity'], self.hidden_size)
531
+ elif self.constituency_composition == ConstituencyComposition.ATTN:
532
+ self.reduce_attn = nn.MultiheadAttention(self.hidden_size, self.reduce_heads)
533
+ elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
534
+ if self.args['reduce_position']:
535
+ # unsaved module so that if it grows, we don't save
536
+ # the larger version unnecessarily
537
+ # under any normal circumstances, the growth will
538
+ # happen early in training when the model is not
539
+ # behaving well, then will not be needed once the
540
+ # model learns not to make super degenerate
541
+ # constituents
542
+ self.add_unsaved_module("reduce_position", ConcatSinusoidalEncoding(self.args['reduce_position'], 50))
543
+ else:
544
+ self.add_unsaved_module("reduce_position", nn.Identity())
545
+ self.reduce_query = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size, bias=False)
546
+ self.reduce_value = nn.Linear(self.hidden_size + self.args['reduce_position'], self.hidden_size)
547
+ if self.constituency_composition == ConstituencyComposition.KEY:
548
+ self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
549
+ else:
550
+ self.register_parameter('reduce_key', torch.nn.Parameter(torch.randn(len(constituent_opens), self.reduce_heads, self.hidden_size // self.reduce_heads, 1, requires_grad=True)))
551
+ elif self.constituency_composition == ConstituencyComposition.TREE_LSTM:
552
+ self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
553
+ elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:
554
+ self.constituent_reduce_embedding = nn.Embedding(num_embeddings = len(tags)+2,
555
+ embedding_dim = self.num_tree_lstm_layers * self.hidden_size)
556
+ self.constituent_reduce_lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=self.num_tree_lstm_layers, dropout=self.lstm_layer_dropout)
557
+ else:
558
+ raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))
559
+
560
+ self.nonlinearity = build_nonlinearity(self.args['nonlinearity'])
561
+
562
+ # matrix for predicting the next transition using word/constituent/transition queues
563
+ # word size + constituency size + transition size
564
+ # TODO: .get() is only necessary until all models rebuilt with this param
565
+ self.maxout_k = self.args.get('maxout_k', 0)
566
+ self.output_layers = self.build_output_layers(self.args['num_output_layers'], len(transitions), self.maxout_k)
567
+
568
+ @staticmethod
569
+ def uses_lattn(args):
570
+ return args.get('use_lattn', True) and args.get('lattn_d_proj', 0) > 0 and args.get('lattn_d_l', 0) > 0
571
+
572
+ @staticmethod
573
+ def uses_pattn(args):
574
+ return args['pattn_num_heads'] > 0 and args['pattn_num_layers'] > 0
575
+
576
def copy_with_new_structure(self, other):
    """
    Copy parameters from the other model to this model

    word_lstm can change size if the other model didn't use pattn / lattn and this one does.
    In that case, the new values are initialized to 0.
    This will rebuild the model in such a way that the outputs will be
    exactly the same as the previous model.
    """
    # compositions must match, with one exception: this model may use
    # UNTIED_MAX while the other used a tied reduce_linear; that case is
    # handled per-parameter in the loop below
    if self.constituency_composition != other.constituency_composition and self.constituency_composition != ConstituencyComposition.UNTIED_MAX:
        raise ValueError("Models are incompatible: self.constituency_composition == {}, other.constituency_composition == {}".format(self.constituency_composition, other.constituency_composition))
    for name, other_parameter in other.named_parameters():
        # this allows other.constituency_composition == UNTIED_MAX to fall through
        if name.startswith('reduce_linear.') and self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
            if name == 'reduce_linear.weight':
                my_parameter = self.reduce_linear_weight
            elif name == 'reduce_linear.bias':
                my_parameter = self.reduce_linear_bias
            else:
                raise ValueError("Unexpected other parameter name {}".format(name))
            # broadcast the other model's single tied matrix into every
            # per-constituent slot of the untied parameter
            for idx in range(len(self.constituent_opens)):
                my_parameter[idx].data.copy_(other_parameter.data)
        elif name.startswith('word_lstm.weight_ih_l0'):
            # bottom layer shape may have changed from adding a new pattn / lattn block
            my_parameter = self.get_parameter(name)
            # -1 so that it can be converted easier to a different parameter
            copy_size = min(other_parameter.data.shape[-1], my_parameter.data.shape[-1])
            #new_values = my_parameter.data.clone().detach()
            # new columns (from the added pattn/lattn inputs) stay zero, so
            # the copied model initially produces identical outputs
            new_values = torch.zeros_like(my_parameter.data)
            new_values[..., :copy_size] = other_parameter.data[..., :copy_size]
            my_parameter.data.copy_(new_values)
        else:
            try:
                self.get_parameter(name).data.copy_(other_parameter.data)
            except AttributeError as e:
                raise AttributeError("Could not process %s" % name) from e
612
+
613
def build_output_layers(self, num_output_layers, final_layer_size, maxout_k):
    """
    Build a ModuleList of Linear transformations for the given num_output_layers

    The final layer size can be specified.
    Initial layer size is the combination of word, constituent, and transition vectors
    Middle layer sizes are self.hidden_size

    If maxout_k is truthy, MaxoutLinear layers are built instead of
    plain Linear layers (and no extra initialization is applied).
    """
    # first layer input combines:
    #   word_lstm: hidden_size * num_tree_lstm_layers
    #   transition_stack: transition_hidden_size
    #   constituent_stack: hidden_size
    first_input = self.hidden_size + self.hidden_size * self.num_tree_lstm_layers + self.transition_hidden_size
    num_middle = num_output_layers - 1
    in_sizes = [first_input] + [self.hidden_size] * num_middle
    out_sizes = [self.hidden_size] * num_middle + [final_layer_size]
    if maxout_k:
        layers = [MaxoutLinear(n_in, n_out, maxout_k)
                  for n_in, n_out in zip(in_sizes, out_sizes)]
    else:
        layers = []
        for n_in, n_out in zip(in_sizes, out_sizes):
            linear = nn.Linear(n_in, n_out)
            initialize_linear(linear, self.args['nonlinearity'], n_in)
            layers.append(linear)
    return nn.ModuleList(layers)
636
+
637
def num_words_known(self, words):
    """Count how many of the words are in the pretrain vocab, with a lowercase fallback."""
    known = self.vocab_map
    total = 0
    for word in words:
        if word in known or word.lower() in known:
            total += 1
    return total
639
+
640
@property
def retag_method(self):
    """The tagging scheme used when retagging trees (e.g. 'xpos' or 'upos'), from the model args."""
    # TODO: make the method an enum
    return self.args['retag_method']
644
+
645
def uses_xpos(self):
    """True if a retag package is configured and its retag method is 'xpos'."""
    if self.args['retag_package'] is None:
        return False
    return self.args['retag_method'] == 'xpos'
647
+
648
def add_unsaved_module(self, name, module):
    """
    Adds a module which will not be saved to disk

    Best used for large models such as pretrained word embeddings
    """
    self.unsaved_modules.append(name)
    setattr(self, name, module)
    # the charlms are pretrained and frozen: their parameters should
    # never be updated while training this model
    if name in ('forward_charlm', 'backward_charlm') and module is not None:
        for parameter in module.parameters():
            parameter.requires_grad = False
659
+
660
def is_unsaved_module(self, name):
    """True if the (possibly dotted) parameter name belongs to an unsaved module."""
    prefix, _, _ = name.partition('.')
    return prefix in self.unsaved_modules
662
+
663
def get_norms(self):
    """
    Return a list of lines describing the norms of all trainable parameters.

    Each line reports the parameter name, its L2 norm, and how many of its
    entries are (near) zero.  For UNTIED_MAX composition, the per-constituent
    reduce matrices are reported individually instead of in the generic list.
    """
    lines = []
    skip = set()
    if self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
        # report these per constituent-open below, so exclude them from
        # the generic parameter listing
        skip = {'reduce_linear_weight', 'reduce_linear_bias'}
        lines.append("reduce_linear:")
        for c_idx, c_open in enumerate(self.constituent_opens):
            lines.append(" %s weight %.6g bias %.6g" % (c_open, torch.norm(self.reduce_linear_weight[c_idx]).item(), torch.norm(self.reduce_linear_bias[c_idx]).item()))
    active_params = [(name, param) for name, param in self.named_parameters() if param.requires_grad and name not in skip]
    if len(active_params) == 0:
        return lines
    # NOTE: removed a leftover debug print(len(active_params)) which wrote
    # to stdout every time norms were collected

    # pad columns so the report lines up when logged
    max_name_len = max(len(name) for name, param in active_params)
    max_norm_len = max(len("%.6g" % torch.norm(param).item()) for name, param in active_params)
    format_string = "%-" + str(max_name_len) + "s norm %" + str(max_norm_len) + "s zeros %d / %d"
    for name, param in active_params:
        zeros = torch.sum(param.abs() < 0.000001).item()
        norm = "%.6g" % torch.norm(param).item()
        lines.append(format_string % (name, norm, zeros, param.nelement()))
    return lines
684
+
685
def log_norms(self):
    """Log the norms of all trainable parameters at INFO level."""
    report = ["NORMS FOR MODEL PARAMETERS", *self.get_norms()]
    logger.info("\n".join(report))
689
+
690
def log_shapes(self):
    """Log the shapes of all trainable parameters at INFO level."""
    # header previously said "NORMS FOR MODEL PARAMETERS" - a copy/paste
    # slip from log_norms; this method reports shapes, not norms
    lines = ["SHAPES FOR MODEL PARAMETERS"]
    for name, param in self.named_parameters():
        if param.requires_grad:
            lines.append("{} {}".format(name, param.shape))
    logger.info("\n".join(lines))
696
+
697
def initial_word_queues(self, tagged_word_lists):
    """
    Produce initial word queues out of the model's LSTMs for use in the tagged word lists.

    Operates in a batched fashion to reduce the runtime for the LSTM operations

    tagged_word_lists: one list per sentence of tag nodes, where each node's
      first child holds the word label (tree structure assumed from the
      word.children[0].label accesses below - confirm against caller)

    Returns one list of WordNode per sentence.  Each queue is padded with a
    boundary WordNode(None, ...) at both ends, and reversed if the model
    parses right-to-left.
    """
    device = next(self.parameters()).device

    vocab_map = self.vocab_map
    def map_word(word):
        # exact match first, then lowercase fallback, then UNK
        idx = vocab_map.get(word, None)
        if idx is not None:
            return idx
        return vocab_map.get(word.lower(), UNK_ID)

    all_word_inputs = []
    all_word_labels = [[word.children[0].label for word in tagged_words]
                       for tagged_words in tagged_word_lists]

    # per-sentence embeddings: pretrained word + trained delta (+ tag)
    for sentence_idx, tagged_words in enumerate(tagged_word_lists):
        word_labels = all_word_labels[sentence_idx]
        word_idx = torch.stack([self.vocab_tensors[map_word(word.children[0].label)] for word in tagged_words])
        word_input = self.embedding(word_idx)

        # this occasionally learns UNK at train time
        if self.training:
            delta_labels = [None if word in self.rare_words and random.random() < self.args['rare_word_unknown_frequency'] else word
                            for word in word_labels]
        else:
            delta_labels = word_labels
        delta_idx = torch.stack([self.delta_tensors[self.delta_word_map.get(word, UNK_ID)] for word in delta_labels])

        delta_input = self.delta_embedding(delta_idx)
        word_inputs = [word_input, delta_input]

        if self.tag_embedding_dim > 0:
            # tags are also randomly dropped to UNK at train time
            if self.training:
                tag_labels = [None if random.random() < self.args['tag_unknown_frequency'] else word.label for word in tagged_words]
            else:
                tag_labels = [word.label for word in tagged_words]
            tag_idx = torch.stack([self.tag_tensors[self.tag_map.get(tag, UNK_ID)] for tag in tag_labels])
            tag_input = self.tag_embedding(tag_idx)
            word_inputs.append(tag_input)

        all_word_inputs.append(word_inputs)

    # charlm features are computed for the whole batch at once
    if self.forward_charlm is not None:
        all_forward_chars = self.forward_charlm.build_char_representation(all_word_labels)
        for word_inputs, forward_chars in zip(all_word_inputs, all_forward_chars):
            word_inputs.append(forward_chars)
    if self.backward_charlm is not None:
        all_backward_chars = self.backward_charlm.build_char_representation(all_word_labels)
        for word_inputs, backward_chars in zip(all_word_inputs, all_backward_chars):
            word_inputs.append(backward_chars)

    all_word_inputs = [torch.cat(word_inputs, dim=1) for word_inputs in all_word_inputs]
    # boundary vectors go in BEFORE bert so they do not carry bert dims
    if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
        word_start = self.word_start_embedding.unsqueeze(0)
        word_end = self.word_end_embedding.unsqueeze(0)
        all_word_inputs = [torch.cat([word_start, word_inputs, word_end], dim=0) for word_inputs in all_word_inputs]

    if self.bert_model is not None:
        # BERT embedding extraction
        # result will be len+2 for each sentence
        # we will take 1:-1 if we don't care about the endpoints
        bert_embeddings = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, all_word_labels, device,
                                                  keep_endpoints=self.sentence_boundary_vectors is not SentenceBoundary.NONE,
                                                  num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                  detach=not self.args['bert_finetune'] and not self.args['stage1_bert_finetune'],
                                                  peft_name=self.peft_name)
        if self.bert_layer_mix is not None:
            # add the average so that the default behavior is to
            # take an average of the N layers, and anything else
            # other than that needs to be learned
            bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]

        all_word_inputs = [torch.cat((x, y), axis=1) for x, y in zip(all_word_inputs, bert_embeddings)]

    # Extract partitioned representation
    if self.partitioned_transformer_module is not None:
        partitioned_embeddings = self.partitioned_transformer_module(None, all_word_inputs)
        all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, partitioned_embeddings)]

    # Extract Labeled Representation
    if self.label_attention_module is not None:
        if self.args['lattn_combined_input']:
            labeled_representations = self.label_attention_module(all_word_inputs, tagged_word_lists)
        else:
            # NOTE(review): this path assumes partitioned_embeddings was set
            # above; lattn without pattn is rejected at construction time
            labeled_representations = self.label_attention_module(partitioned_embeddings, tagged_word_lists)
        all_word_inputs = [torch.cat((x, y[:x.shape[0], :]), axis=1) for x, y in zip(all_word_inputs, labeled_representations)]

    all_word_inputs = [self.word_dropout(word_inputs) for word_inputs in all_word_inputs]
    packed_word_input = torch.nn.utils.rnn.pack_sequence(all_word_inputs, enforce_sorted=False)
    word_output, _ = self.word_lstm(packed_word_input)
    # would like to do word_to_constituent here, but it seems PackedSequence doesn't support Linear
    # word_output will now be sentence x batch x 2*hidden_size
    word_output, word_output_lens = torch.nn.utils.rnn.pad_packed_sequence(word_output)
    # now sentence x batch x hidden_size

    word_queues = []
    for sentence_idx, tagged_words in enumerate(tagged_word_lists):
        # +2 accounts for the boundary vectors added before the LSTM
        if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
            sentence_output = word_output[:len(tagged_words)+2, sentence_idx, :]
        else:
            sentence_output = word_output[:len(tagged_words), sentence_idx, :]
        sentence_output = self.word_to_constituent(sentence_output)
        sentence_output = self.nonlinearity(sentence_output)
        # TODO: this makes it so constituents downstream are
        # build with the outputs of the LSTM, not the word
        # embeddings themselves. It is possible we want to
        # transform the word_input to hidden_size in some way
        # and use that instead
        if self.sentence_boundary_vectors is not SentenceBoundary.NONE:
            word_queue = [WordNode(None, sentence_output[0, :])]
            word_queue += [WordNode(tag_node, sentence_output[idx+1, :])
                           for idx, tag_node in enumerate(tagged_words)]
            word_queue.append(WordNode(None, sentence_output[len(tagged_words)+1, :]))
        else:
            # no trained boundary vectors: use the zero buffer as bookends
            word_queue = [WordNode(None, self.word_zeros)]
            word_queue += [WordNode(tag_node, sentence_output[idx, :])
                           for idx, tag_node in enumerate(tagged_words)]
            word_queue.append(WordNode(None, self.word_zeros))

        if self.reverse_sentence:
            word_queue = list(reversed(word_queue))
        word_queues.append(word_queue)

    return word_queues
825
+
826
def initial_transitions(self):
    """
    Return an initial TreeStack with no transitions

    Delegates to the configured transition_stack (LSTM or attention based).
    """
    return self.transition_stack.initial_state()
831
+
832
def initial_constituents(self):
    """
    Return an initial TreeStack with no constituents

    The stack is seeded with a zero-valued Constituent so later operations
    always have a hidden state to build on.
    """
    return self.constituent_stack.initial_state(Constituent(None, self.constituent_zeros, self.constituent_zeros))
837
+
838
def get_word(self, word_node):
    """Return the tag node stored in a WordNode from the word queue."""
    return word_node.value
840
+
841
def transform_word_to_constituent(self, state):
    """
    Turn the word at the parser state's current position into a Constituent.

    The hx/cx layout depends on the composition mode: TREE_LSTM variants keep
    (num_tree_lstm_layers, hidden_size) hidden and cell states; other modes
    keep only the first hidden_size slice of the word's hx.
    """
    word_node = state.get_word(state.word_position)
    word = word_node.value
    if self.constituency_composition == ConstituencyComposition.TREE_LSTM:
        # cx starts at zero; hx comes from the word LSTM output
        return Constituent(word, word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size), self.word_zeros.view(self.num_tree_lstm_layers, self.hidden_size))
    elif self.constituency_composition == ConstituencyComposition.TREE_LSTM_CX:
        # the UNK tag will be trained thanks to occasionally dropping out tags
        tag = word.label
        tree_hx = word_node.hx.view(self.num_tree_lstm_layers, self.hidden_size)
        tag_tensor = self.tag_tensors[self.tag_map.get(tag, UNK_ID)]
        # cx is a learned, tag-specific gating of the hidden state
        tree_cx = self.constituent_reduce_embedding(tag_tensor)
        tree_cx = tree_cx.view(self.num_tree_lstm_layers, self.hidden_size)
        return Constituent(word, tree_hx, tree_cx * tree_hx)
    else:
        # non-tree-lstm compositions never read cx, so it can be None
        return Constituent(word, word_node.hx[:self.hidden_size].unsqueeze(0), None)
856
+
857
    def dummy_constituent(self, dummy):
        """
        Build a Constituent for a dummy (open-constituent placeholder) node.

        The hx comes from a learned embedding keyed by the dummy's label.
        """
        label = dummy.label
        open_index = self.constituent_open_tensors[self.constituent_open_map[label]]
        hx = self.dummy_embedding(open_index)
        # the cx doesn't matter: the dummy will be discarded when building a new constituent
        return Constituent(dummy, hx.unsqueeze(0), None)
863
+
864
    def build_constituents(self, labels, children_lists):
        """
        Build new constituents with the given label from the list of children

        labels is a list of labels for each of the new nodes to construct
        children_lists is a list of children that go under each of the new nodes
        lists of each are used so that we can stack operations

        Each branch below implements one ConstituencyComposition; all of
        them must leave lstm_hx (and optionally lstm_cx) in the same
        batched layout so the final loop can slice per-constituent vectors.
        """
        # at the end of each of these operations, we expect lstm_hx.shape
        # is (L, N, hidden_size) for N lists of children
        if (self.constituency_composition == ConstituencyComposition.BILSTM or
            self.constituency_composition == ConstituencyComposition.BILSTM_MAX):
            node_hx = [[child.value.tree_hx.squeeze(0) for child in children] for children in children_lists]
            label_hx = [self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]]) for label in labels]

            max_length = max(len(children) for children in children_lists)
            zeros = torch.zeros(self.hidden_size, device=label_hx[0].device)
            # weirdly, this is faster than using pack_sequence
            # each sequence is label, children..., label, zero-padding
            unpacked_hx = [[lhx] + nhx + [lhx] + [zeros] * (max_length - len(nhx)) for lhx, nhx in zip(label_hx, node_hx)]
            unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in unpacked_hx]
            packed_hx = torch.stack(unpacked_hx, axis=1)
            packed_hx = torch.nn.utils.rnn.pack_padded_sequence(packed_hx, [len(x)+2 for x in children_lists], enforce_sorted=False)
            lstm_output = self.constituent_reduce_lstm(packed_hx)
            # take just the output of the final layer
            # result of lstm is ouput, (hx, cx)
            # so [1][0] gets hx
            # [1][0][-1] is the final output
            # will be shape len(children_lists) * 2, hidden_size for bidirectional
            # where forward outputs are -2 and backwards are -1
            if self.constituency_composition == ConstituencyComposition.BILSTM:
                lstm_output = lstm_output[1][0]
                forward_hx = lstm_output[-2, :, :]
                backward_hx = lstm_output[-1, :, :]
                hx = self.reduce_linear(torch.cat((forward_hx, backward_hx), axis=1))
            else:
                # BILSTM_MAX: max-pool the per-timestep outputs (excluding
                # the label positions) instead of taking the final states
                lstm_output, lstm_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_output[0])
                lstm_output = [lstm_output[1:length-1, x, :] for x, length in zip(range(len(lstm_lengths)), lstm_lengths)]
                lstm_output = torch.stack([torch.max(x, 0).values for x in lstm_output], axis=0)
                hx = self.reduce_forward(lstm_output[:, :self.hidden_size]) + self.reduce_backward(lstm_output[:, self.hidden_size:])
            lstm_hx = self.nonlinearity(hx).unsqueeze(0)
            lstm_cx = None
        elif self.constituency_composition == ConstituencyComposition.MAX:
            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
            unpacked_hx = [self.lstm_input_dropout(torch.max(torch.stack(nhx), 0).values) for nhx in node_hx]
            packed_hx = torch.stack(unpacked_hx, axis=1)
            hx = self.reduce_linear(packed_hx)
            lstm_hx = self.nonlinearity(hx)
            lstm_cx = None
        elif self.constituency_composition == ConstituencyComposition.UNTIED_MAX:
            # like MAX, but each label gets its own linear reduction weights
            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
            unpacked_hx = [self.lstm_input_dropout(torch.max(torch.stack(nhx), 0).values) for nhx in node_hx]
            # shape == len(labels),1,hidden_size after the stack
            #packed_hx = torch.stack(unpacked_hx, axis=0)
            label_indices = [self.constituent_open_map[label] for label in labels]
            # we would like to stack the reduce_linear_weight calculations as follows:
            #reduce_weight = self.reduce_linear_weight[label_indices]
            #reduce_bias = self.reduce_linear_bias[label_indices]
            # this would allow for faster vectorized operations.
            # however, this runs out of memory on larger training examples,
            # presumably because there are too many stacks in a row and each one
            # has its own gradient kept for the entire calculation
            # fortunately, this operation is not a huge part of the expense
            hx = [torch.matmul(self.reduce_linear_weight[label_idx], hx_layer.squeeze(0)) + self.reduce_linear_bias[label_idx]
                  for label_idx, hx_layer in zip(label_indices, unpacked_hx)]
            hx = torch.stack(hx, axis=0)
            hx = hx.unsqueeze(0)
            lstm_hx = self.nonlinearity(hx)
            lstm_cx = None
        elif self.constituency_composition == ConstituencyComposition.BIGRAM:
            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
            unpacked_hx = []
            for nhx in node_hx:
                # tanh or otherwise limit the size of the output?
                stacked_nhx = self.lstm_input_dropout(torch.cat(nhx, axis=0))
                if stacked_nhx.shape[0] > 1:
                    # augment the children with reduced adjacent-pair vectors
                    bigram_hx = torch.cat((stacked_nhx[:-1, :], stacked_nhx[1:, :]), axis=1)
                    bigram_hx = self.reduce_bigram(bigram_hx) / 2
                    stacked_nhx = torch.cat((stacked_nhx, bigram_hx), axis=0)
                unpacked_hx.append(torch.max(stacked_nhx, 0).values)
            packed_hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)
            hx = self.reduce_linear(packed_hx)
            lstm_hx = self.nonlinearity(hx)
            lstm_cx = None
        elif self.constituency_composition == ConstituencyComposition.ATTN:
            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
            label_hx = [self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]]) for label in labels]
            unpacked_hx = [torch.stack(nhx) for nhx in node_hx]
            # prepend the label embedding before self-attention over children
            unpacked_hx = [torch.cat((lhx.unsqueeze(0).unsqueeze(0), nhx), axis=0) for lhx, nhx in zip(label_hx, unpacked_hx)]
            unpacked_hx = [self.reduce_attn(nhx, nhx, nhx)[0].squeeze(1) for nhx in unpacked_hx]
            unpacked_hx = [self.lstm_input_dropout(torch.max(nhx, 0).values) for nhx in unpacked_hx]
            hx = torch.stack(unpacked_hx, axis=0)
            lstm_hx = self.nonlinearity(hx).unsqueeze(0)
            lstm_cx = None
        elif self.constituency_composition == ConstituencyComposition.KEY or self.constituency_composition == ConstituencyComposition.UNTIED_KEY:
            node_hx = [torch.stack([child.value.tree_hx for child in children]) for children in children_lists]
            # add a position vector to each node_hx
            node_hx = [self.reduce_position(x.reshape(x.shape[0], -1)) for x in node_hx]
            query_hx = [self.reduce_query(nhx) for nhx in node_hx]
            # reshape query for MHA
            query_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in query_hx]
            if self.constituency_composition == ConstituencyComposition.KEY:
                queries = [torch.matmul(nhx, self.reduce_key) for nhx in query_hx]
            else:
                # UNTIED_KEY: one learned key per constituent label
                label_indices = [self.constituent_open_map[label] for label in labels]
                queries = [torch.matmul(nhx, self.reduce_key[label_idx]) for nhx, label_idx in zip(query_hx, label_indices)]
            # softmax each head
            weights = [torch.nn.functional.softmax(nhx, dim=1).transpose(1, 2) for nhx in queries]
            value_hx = [self.reduce_value(nhx) for nhx in node_hx]
            value_hx = [nhx.reshape(nhx.shape[0], self.reduce_heads, -1).transpose(0, 1) for nhx in value_hx]
            # use the softmaxes to add up the heads
            unpacked_hx = [torch.matmul(weight, nhx).squeeze(1) for weight, nhx in zip(weights, value_hx)]
            unpacked_hx = [nhx.reshape(-1) for nhx in unpacked_hx]
            hx = torch.stack(unpacked_hx, axis=0).unsqueeze(0)
            lstm_hx = self.nonlinearity(hx)
            lstm_cx = None
        elif self.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
            label_hx = [self.lstm_input_dropout(self.constituent_open_embedding(self.constituent_open_tensors[self.constituent_open_map[label]])) for label in labels]
            label_hx = torch.stack(label_hx).unsqueeze(0)

            max_length = max(len(children) for children in children_lists)

            # stacking will let us do elementwise multiplication faster, hopefully
            node_hx = [[child.value.tree_hx for child in children] for children in children_lists]
            unpacked_hx = [self.lstm_input_dropout(torch.stack(nhx)) for nhx in node_hx]
            unpacked_hx = [nhx.max(dim=0) for nhx in unpacked_hx]
            packed_hx = torch.stack([nhx.values for nhx in unpacked_hx], axis=1)
            #packed_hx = packed_hx.max(dim=0).values

            # gather picks, for each hidden unit, the cx of whichever child
            # won the elementwise max over hx
            node_cx = [torch.stack([child.value.tree_cx for child in children]) for children in children_lists]
            node_cx_indices = [uhx.indices.unsqueeze(0) for uhx in unpacked_hx]
            unpacked_cx = [ncx.gather(0, nci).squeeze(0) for ncx, nci in zip(node_cx, node_cx_indices)]
            packed_cx = torch.stack(unpacked_cx, axis=1)

            _, (lstm_hx, lstm_cx) = self.constituent_reduce_lstm(label_hx, (packed_hx, packed_cx))
        else:
            raise ValueError("Unhandled ConstituencyComposition: {}".format(self.constituency_composition))

        constituents = []
        for idx, (label, children) in enumerate(zip(labels, children_lists)):
            children = [child.value.value for child in children]
            if isinstance(label, str):
                node = Tree(label=label, children=children)
            else:
                # a compound (unary chain) label: build the nested Trees inside out
                for value in reversed(label):
                    node = Tree(label=value, children=children)
                    children = node
            constituents.append(Constituent(node, lstm_hx[:, idx, :], lstm_cx[:, idx, :] if lstm_cx is not None else None))
        return constituents
1012
+
1013
+ def push_constituents(self, constituent_stacks, constituents):
1014
+ # Another possibility here would be to use output[0, i, :]
1015
+ # from the constituency lstm for the value of the new node.
1016
+ # This might theoretically make the new constituent include
1017
+ # information from neighboring constituents. However, this
1018
+ # lowers the scores of various models.
1019
+ # For example, an experiment on ja_alt built this way,
1020
+ # averaged over 5 trials, had the following loss in accuracy:
1021
+ # 150 epochs: 0.8971 to 0.8953
1022
+ # 200 epochs: 0.8985 to 0.8964
1023
+ current_nodes = [stack.value for stack in constituent_stacks]
1024
+
1025
+ constituent_input = torch.stack([x.tree_hx[-1:] for x in constituents], axis=1)
1026
+ #constituent_input = constituent_input.unsqueeze(0)
1027
+ # the constituents are already Constituent(tree, tree_hx, tree_cx)
1028
+ return self.constituent_stack.push_states(constituent_stacks, constituents, constituent_input)
1029
+
1030
+ def get_top_constituent(self, constituents):
1031
+ """
1032
+ Extract only the top constituent from a state's constituent
1033
+ sequence, even though it has multiple addition pieces of
1034
+ information
1035
+ """
1036
+ # TreeStack value -> LSTMTreeStack value -> Constituent value -> constituent
1037
+ return constituents.value.value.value
1038
+
1039
    def push_transitions(self, transition_stacks, transitions):
        """
        Push all of the given transitions on to the stack as a batch operations.

        Significantly faster than doing one transition at a time.
        """
        # one embedding index per transition, then add the LSTM time dimension
        transition_idx = torch.stack([self.transition_tensors[self.transition_map[transition]] for transition in transitions])
        transition_input = self.transition_embedding(transition_idx).unsqueeze(0)
        return self.transition_stack.push_states(transition_stacks, transitions, transition_input)
1048
+
1049
+ def get_top_transition(self, transitions):
1050
+ """
1051
+ Extract only the top transition from a state's transition
1052
+ sequence, even though it has multiple addition pieces of
1053
+ information
1054
+ """
1055
+ # TreeStack value -> LSTMTreeStack value -> transition
1056
+ return transitions.value.value
1057
+
1058
    def forward(self, states):
        """
        Return logits for a prediction of what transition to make next

        We've basically done all the work analyzing the state as
        part of applying the transitions, so this method is very simple

        return shape: (num_states, num_transitions)
        """
        word_hx = torch.stack([state.get_word(state.word_position).hx for state in states])
        transition_hx = torch.stack([self.transition_stack.output(state.transitions) for state in states])
        # this .output() is the output of the constituent stack, not the
        # constituent itself
        # this way, we can, as an option, NOT include the constituents to the left
        # when building the current vector for a constituent
        # and the vector used for inference will still incorporate the entire LSTM
        constituent_hx = torch.stack([self.constituent_stack.output(state.constituents) for state in states])

        hx = torch.cat((word_hx, transition_hx, constituent_hx), axis=1)
        for idx, output_layer in enumerate(self.output_layers):
            hx = self.predict_dropout(hx)
            # maxout layers supply their own nonlinearity; otherwise apply
            # self.nonlinearity between (but not after) the output layers
            if not self.maxout_k and idx < len(self.output_layers) - 1:
                hx = self.nonlinearity(hx)
            hx = output_layer(hx)
        return hx
1083
+
1084
    def predict(self, states, is_legal=True):
        """
        Generate and return predictions, along with the transitions those predictions represent

        If is_legal is set to True, will only return legal transitions.
        This means returning None if there are no legal transitions.
        Hopefully the constraints prevent that from happening

        Returns (raw logits, list of Transition or None, per-state scores)
        """
        predictions = self.forward(states)
        pred_max = torch.argmax(predictions, dim=1)
        scores = torch.take_along_dim(predictions, pred_max.unsqueeze(1), dim=1)
        pred_max = pred_max.detach().cpu()

        pred_trans = [self.transitions[pred_max[idx]] for idx in range(len(states))]
        if is_legal:
            for idx, (state, trans) in enumerate(zip(states, pred_trans)):
                if not trans.is_legal(state, self):
                    # fall back to the best scoring transition which IS legal
                    _, indices = predictions[idx, :].sort(descending=True)
                    for index in indices:
                        if self.transitions[index].is_legal(state, self):
                            pred_trans[idx] = self.transitions[index]
                            scores[idx] = predictions[idx, index]
                            break
                    else: # yeah, else on a for loop, deal with it
                        pred_trans[idx] = None
                        # NOTE(review): assigning None into a float tensor will
                        # likely raise at runtime; presumably this branch is
                        # unreachable thanks to the transition constraints - confirm
                        scores[idx] = None

        return predictions, pred_trans, scores.squeeze(1)
1112
+
1113
+ def weighted_choice(self, states):
1114
+ """
1115
+ Generate and return predictions, and randomly choose a prediction weighted by the scores
1116
+
1117
+ TODO: pass in a temperature
1118
+ """
1119
+ predictions = self.forward(states)
1120
+ pred_trans = []
1121
+ all_scores = []
1122
+ for state, prediction in zip(states, predictions):
1123
+ legal_idx = [idx for idx in range(prediction.shape[0]) if self.transitions[idx].is_legal(state, self)]
1124
+ if len(legal_idx) == 0:
1125
+ pred_trans.append(None)
1126
+ continue
1127
+ scores = prediction[legal_idx]
1128
+ scores = torch.softmax(scores, dim=0)
1129
+ idx = torch.multinomial(scores, 1)
1130
+ idx = legal_idx[idx]
1131
+ pred_trans.append(self.transitions[idx])
1132
+ all_scores.append(prediction[idx])
1133
+ all_scores = torch.stack(all_scores)
1134
+ return predictions, pred_trans, all_scores
1135
+
1136
    def predict_gold(self, states):
        """
        For each State, return the next item in the gold_sequence

        Returns (raw logits, gold transitions, scores of the gold transitions)
        """
        predictions = self.forward(states)
        transitions = [y.gold_sequence[y.num_transitions] for y in states]
        indices = torch.tensor([self.transition_map[t] for t in transitions], device=predictions.device)
        # pick out the logit of the gold transition for each state
        scores = torch.take_along_dim(predictions, indices.unsqueeze(1), dim=1)
        return predictions, transitions, scores.squeeze(1)
1145
+
1146
    def get_params(self, skip_modules=True):
        """
        Get a dictionary for saving the model

        skip_modules: drop parameters belonging to unsaved submodules
        (such as pretrained embeddings), which are large and saved separately
        """
        model_state = self.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [k for k in model_state.keys() if self.is_unsaved_module(k)]
            for k in skipped:
                del model_state[k]
        config = copy.deepcopy(self.args)
        # enums are serialized by name so the saved config is not tied
        # to the enums' internal values
        config['sentence_boundary_vectors'] = config['sentence_boundary_vectors'].name
        config['constituency_composition'] = config['constituency_composition'].name
        config['transition_stack'] = config['transition_stack'].name
        config['constituent_stack'] = config['constituent_stack'].name
        config['transition_scheme'] = config['transition_scheme'].name
        assert isinstance(self.rare_words, set)
        params = {
            'model': model_state,
            'model_type': "LSTM",
            'config': config,
            'transitions': [repr(x) for x in self.transitions],
            'constituents': self.constituents,
            'tags': self.tags,
            'words': self.delta_words,
            'rare_words': list(self.rare_words),
            'root_labels': self.root_labels,
            'constituent_opens': self.constituent_opens,
            'unary_limit': self.unary_limit(),
        }

        return params
1178
+
stanza/stanza/models/constituency/parse_tree.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tree datastructure
3
+ """
4
+
5
+ from collections import deque, Counter
6
+ import copy
7
+ from enum import Enum
8
+ from io import StringIO
9
+ import itertools
10
+ import re
11
+ import warnings
12
+
13
+ from stanza.models.common.stanza_object import StanzaObject
14
+
15
# useful more for the "is" functionality than the time savings
CLOSE_PAREN = ')'
SPACE_SEPARATOR = ' '
OPEN_PAREN = '('

# shared immutable default for Tree.children; safe to share because it is empty
EMPTY_CHILDREN = ()

# used to split off the functional tags from various treebanks
# for example, the Icelandic treebank (which we don't currently
# incorporate) uses * to distinguish 'ADJP', 'ADJP*OC' but we treat
# those as the same
CONSTITUENT_SPLIT = re.compile("[-=#*]")

# These words occur in the VLSP dataset.
# The documentation claims there might be *O*, although those don't
# seem to exist in practice
WORDS_TO_PRUNE = ('*E*', '*T*', '*O*')
32
+
33
class TreePrintMethod(Enum):
    """
    Enumerates the supported ways of rendering a tree as text.

    Users normally pick one of these through Tree.__format__'s format
    spec rather than referencing the enum directly.
    """
    ONE_LINE = auto()        # (ROOT (S ... ))
    LABELED_PARENS = auto()  # (_ROOT (_S ... )_S )_ROOT
    PRETTY = auto()          # multiple lines
    VLSP = auto()            # <s> (S ... ) </s>
    LATEX_TREE = auto()      # \Tree [.S [.NP ... ] ]
44
+
45
+
46
+ class Tree(StanzaObject):
47
+ """
48
+ A data structure to represent a parse tree
49
+ """
50
+ def __init__(self, label=None, children=None):
51
+ if children is None:
52
+ self.children = EMPTY_CHILDREN
53
+ elif isinstance(children, Tree):
54
+ self.children = (children,)
55
+ else:
56
+ self.children = tuple(children)
57
+
58
+ self.label = label
59
+
60
+ def is_leaf(self):
61
+ return len(self.children) == 0
62
+
63
+ def is_preterminal(self):
64
+ return len(self.children) == 1 and len(self.children[0].children) == 0
65
+
66
    def yield_preterminals(self):
        """
        Yield the preterminals one at a time in order

        Uses an explicit iterator chain instead of recursion, so very
        deep trees cannot overflow the call stack.

        Raises ValueError if called on a leaf node.
        """
        if self.is_preterminal():
            yield self
            return

        if self.is_leaf():
            raise ValueError("Attempted to iterate preterminals on non-internal node")

        iterator = iter(self.children)
        node = next(iterator, None)
        while node is not None:
            if node.is_preterminal():
                yield node
            else:
                # expand an internal node in place: its children are
                # visited before the rest of the pending iterator
                iterator = itertools.chain(node.children, iterator)
            node = next(iterator, None)
85
+
86
+ def leaf_labels(self):
87
+ """
88
+ Get the labels of the leaves
89
+ """
90
+ if self.is_leaf():
91
+ return [self.label]
92
+
93
+ words = [x.children[0].label for x in self.yield_preterminals()]
94
+ return words
95
+
96
    def __len__(self):
        # the length of a tree is its number of words (leaves), not nodes
        return len(self.leaf_labels())
98
+
99
+ def all_leaves_are_preterminals(self):
100
+ """
101
+ Returns True if all leaves are under preterminals, False otherwise
102
+ """
103
+ if self.is_leaf():
104
+ return False
105
+
106
+ if self.is_preterminal():
107
+ return True
108
+
109
+ return all(t.all_leaves_are_preterminals() for t in self.children)
110
+
111
    def pretty_print(self, normalize=None):
        """
        Print with newlines & indentation on each line

        Preterminals and nodes with all preterminal children go on their own line

        You can pass in your own normalize() function. If you do,
        make sure the function updates the parens to be something
        other than () or the brackets will be broken
        """
        if normalize is None:
            normalize = lambda x: x.replace("(", "-LRB-").replace(")", "-RRB-")

        indent = 0
        with StringIO() as buf:
            # iterative traversal: the stack holds Tree nodes plus the
            # CLOSE_PAREN sentinel marking where a bracket must be closed
            stack = deque()
            stack.append(self)
            while len(stack) > 0:
                node = stack.pop()

                if node is CLOSE_PAREN:
                    # if we're trying to pretty print trees, pop all off close parens
                    # then write a newline
                    while node is CLOSE_PAREN:
                        indent -= 1
                        buf.write(CLOSE_PAREN)
                        if len(stack) == 0:
                            node = None
                            break
                        node = stack.pop()
                    buf.write("\n")
                    if node is None:
                        break
                    # put back the non-paren node we over-popped
                    stack.append(node)
                elif node.is_preterminal():
                    buf.write(" " * indent)
                    buf.write("%s%s %s%s" % (OPEN_PAREN, normalize(node.label), normalize(node.children[0].label), CLOSE_PAREN))
                    if len(stack) == 0 or stack[-1] is not CLOSE_PAREN:
                        buf.write("\n")
                elif all(x.is_preterminal() for x in node.children):
                    # a node whose children are all preterminals fits on one line
                    buf.write(" " * indent)
                    buf.write("%s%s" % (OPEN_PAREN, normalize(node.label)))
                    for child in node.children:
                        buf.write(" %s%s %s%s" % (OPEN_PAREN, normalize(child.label), normalize(child.children[0].label), CLOSE_PAREN))
                    buf.write(CLOSE_PAREN)
                    if len(stack) == 0 or stack[-1] is not CLOSE_PAREN:
                        buf.write("\n")
                else:
                    buf.write(" " * indent)
                    buf.write("%s%s\n" % (OPEN_PAREN, normalize(node.label)))
                    stack.append(CLOSE_PAREN)
                    for child in reversed(node.children):
                        stack.append(child)
                    indent += 1

            buf.seek(0)
            return buf.read()
168
+
169
    def __format__(self, spec):
        """
        Turn the tree into a string representing the tree

        Note that this is not a recursive traversal
        Otherwise, a tree too deep might blow up the call stack

        There is a type specific format:
        O -> one line PTB format, which is the default anyway
        L -> open and close brackets are labeled, spaces in the tokens are replaced with _
        P -> pretty print over multiple lines
        V -> surround lines with <s>...</s>, don't print ROOT, and turn () into L/RBKT
        ? -> spaces in the tokens are replaced with ? for any value of ? other than OLP
          warning: this may be removed in the future
        ?{OLPV} -> specific format AND a custom space replacement
        Vi -> add an ID to the <s> in the V format. Also works with ?Vi
        """
        # parse the format spec: optional single replacement char, then
        # an optional format letter (O/L/P/V/T), then optional 'i' for V
        space_replacement = " "
        print_format = TreePrintMethod.ONE_LINE
        if spec == 'L':
            print_format = TreePrintMethod.LABELED_PARENS
            space_replacement = "_"
        elif spec and spec[-1] == 'L':
            print_format = TreePrintMethod.LABELED_PARENS
            space_replacement = spec[0]
        elif spec == 'O':
            print_format = TreePrintMethod.ONE_LINE
        elif spec and spec[-1] == 'O':
            print_format = TreePrintMethod.ONE_LINE
            space_replacement = spec[0]
        elif spec == 'P':
            print_format = TreePrintMethod.PRETTY
        elif spec and spec[-1] == 'P':
            print_format = TreePrintMethod.PRETTY
            space_replacement = spec[0]
        elif spec and spec[0] == 'V':
            print_format = TreePrintMethod.VLSP
            use_tree_id = spec[-1] == 'i'
        elif spec and len(spec) > 1 and spec[1] == 'V':
            print_format = TreePrintMethod.VLSP
            space_replacement = spec[0]
            use_tree_id = spec[-1] == 'i'
        elif spec == 'T':
            print_format = TreePrintMethod.LATEX_TREE
        elif spec and len(spec) > 1 and spec[1] == 'T':
            print_format = TreePrintMethod.LATEX_TREE
            space_replacement = spec[0]
        elif spec:
            space_replacement = spec[0]
            warnings.warn("Use of a custom replacement without a format specifier is deprecated. Please use {}O instead".format(space_replacement), stacklevel=2)

        # VLSP uses LBKT/RBKT for literal parens in tokens; PTB uses -LRB-/-RRB-
        LRB = "LBKT" if print_format == TreePrintMethod.VLSP else "-LRB-"
        RRB = "RBKT" if print_format == TreePrintMethod.VLSP else "-RRB-"
        def normalize(text):
            return text.replace(" ", space_replacement).replace("(", LRB).replace(")", RRB)

        if print_format is TreePrintMethod.PRETTY:
            return self.pretty_print(normalize)

        with StringIO() as buf:
            stack = deque()
            if print_format == TreePrintMethod.VLSP:
                if use_tree_id:
                    # NOTE(review): self.tree_id is assumed to be set by the
                    # caller before using the Vi format; it is not set in
                    # __init__ - confirm
                    buf.write("<s id={}>\n".format(self.tree_id))
                else:
                    buf.write("<s>\n")
                if len(self.children) == 0:
                    raise ValueError("Cannot print an empty tree with V format")
                elif len(self.children) > 1:
                    raise ValueError("Cannot print a tree with %d branches with V format" % len(self.children))
                # skip the ROOT layer for VLSP output
                stack.append(self.children[0])
            elif print_format == TreePrintMethod.LATEX_TREE:
                buf.write("\\Tree ")
                if len(self.children) == 0:
                    raise ValueError("Cannot print an empty tree with T format")
                elif len(self.children) == 1 and len(self.children[0].children) == 0:
                    buf.write("[.? ")
                    buf.write(normalize(self.children[0].label))
                    buf.write(" ]")
                elif self.label == 'ROOT':
                    stack.append(self.children[0])
                else:
                    stack.append(self)
            else:
                stack.append(self)
            while len(stack) > 0:
                node = stack.pop()

                # plain strings on the stack are literal output (separators,
                # close brackets)
                if isinstance(node, str):
                    buf.write(node)
                    continue
                if len(node.children) == 0:
                    if node.label is not None:
                        buf.write(normalize(node.label))
                    continue

                if print_format is TreePrintMethod.LATEX_TREE:
                    if node.is_preterminal():
                        # latex trees print only the word, not the tag
                        buf.write(normalize(node.children[0].label))
                        continue
                    buf.write("[.%s" % normalize(node.label))
                    stack.append(" ]")
                elif print_format is TreePrintMethod.ONE_LINE or print_format is TreePrintMethod.VLSP:
                    buf.write(OPEN_PAREN)
                    if node.label is not None:
                        buf.write(normalize(node.label))
                    stack.append(CLOSE_PAREN)
                elif print_format is TreePrintMethod.LABELED_PARENS:
                    buf.write("%s_%s" % (OPEN_PAREN, normalize(node.label)))
                    stack.append(CLOSE_PAREN + "_" + normalize(node.label))
                    stack.append(SPACE_SEPARATOR)

                for child in reversed(node.children):
                    stack.append(child)
                    stack.append(SPACE_SEPARATOR)
            if print_format == TreePrintMethod.VLSP:
                buf.write("\n</s>")
            buf.seek(0)
            return buf.read()
288
+
289
+ def __repr__(self):
290
+ return "{}".format(self)
291
+
292
+ def __eq__(self, other):
293
+ if self is other:
294
+ return True
295
+ if not isinstance(other, Tree):
296
+ return False
297
+ if self.label != other.label:
298
+ return False
299
+ if len(self.children) != len(other.children):
300
+ return False
301
+ if any(c1 != c2 for c1, c2 in zip(self.children, other.children)):
302
+ return False
303
+ return True
304
+
305
+ def depth(self):
306
+ if not self.children:
307
+ return 0
308
+ return 1 + max(x.depth() for x in self.children)
309
+
310
    def visit_preorder(self, internal=None, preterminal=None, leaf=None):
        """
        Visit the tree in a preorder order

        Applies the given functions to each node.
        internal: if not None, applies this function to each non-leaf, non-preterminal node
        preterminal: if not None, applies this function to each preterminal
        leaf: if not None, applies this function to each leaf

        The functions should *not* destructively alter the trees.
        There is no attempt to interpret the results of calling these functions.
        Rather, you can use visit_preorder to collect stats on trees, etc.
        """
        if self.is_leaf():
            if leaf:
                leaf(self)
        elif self.is_preterminal():
            if preterminal:
                preterminal(self)
        else:
            if internal:
                internal(self)
        # recurse for every node; leaves simply have no children
        for child in self.children:
            child.visit_preorder(internal, preterminal, leaf)
334
+
335
+ @staticmethod
336
+ def get_unique_constituent_labels(trees):
337
+ """
338
+ Walks over all of the trees and gets all of the unique constituent names from the trees
339
+ """
340
+ if isinstance(trees, Tree):
341
+ trees = [trees]
342
+ constituents = Tree.get_constituent_counts(trees)
343
+ return sorted(set(constituents.keys()))
344
+
345
+ @staticmethod
346
+ def get_constituent_counts(trees):
347
+ """
348
+ Walks over all of the trees and gets the count of the unique constituent names from the trees
349
+ """
350
+ if isinstance(trees, Tree):
351
+ trees = [trees]
352
+
353
+ constituents = Counter()
354
+ for tree in trees:
355
+ tree.visit_preorder(internal = lambda x: constituents.update([x.label]))
356
+ return constituents
357
+
358
+ @staticmethod
359
+ def get_unique_tags(trees):
360
+ """
361
+ Walks over all of the trees and gets all of the unique tags from the trees
362
+ """
363
+ if isinstance(trees, Tree):
364
+ trees = [trees]
365
+
366
+ tags = set()
367
+ for tree in trees:
368
+ tree.visit_preorder(preterminal = lambda x: tags.add(x.label))
369
+ return sorted(tags)
370
+
371
+ @staticmethod
372
+ def get_unique_words(trees):
373
+ """
374
+ Walks over all of the trees and gets all of the unique words from the trees
375
+ """
376
+ if isinstance(trees, Tree):
377
+ trees = [trees]
378
+
379
+ words = set()
380
+ for tree in trees:
381
+ tree.visit_preorder(leaf = lambda x: words.add(x.label))
382
+ return sorted(words)
383
+
384
+ @staticmethod
385
+ def get_common_words(trees, num_words):
386
+ """
387
+ Walks over all of the trees and gets the most frequently occurring words.
388
+ """
389
+ if num_words == 0:
390
+ return set()
391
+
392
+ if isinstance(trees, Tree):
393
+ trees = [trees]
394
+
395
+ words = Counter()
396
+ for tree in trees:
397
+ tree.visit_preorder(leaf = lambda x: words.update([x.label]))
398
+ return sorted(x[0] for x in words.most_common()[:num_words])
399
+
400
    @staticmethod
    def get_rare_words(trees, threshold=0.05):
        """
        Walks over all of the trees and gets the least frequently occurring words.

        threshold: choose the bottom X percent
        Returns a sorted list of the rarest leaf words.
        """
        if isinstance(trees, Tree):
            trees = [trees]

        # count every leaf word across all trees
        words = Counter()
        for tree in trees:
            tree.visit_preorder(leaf = lambda x: words.update([x.label]))
        # convert the fraction into a count of word types, at least 1
        threshold = max(int(len(words) * threshold), 1)
        # most_common()[:-threshold-1:-1] walks the frequency-sorted list
        # from the rare end, taking `threshold` items
        return sorted(x[0] for x in words.most_common()[:-threshold-1:-1])
415
+
416
+ @staticmethod
417
+ def get_root_labels(trees):
418
+ return sorted(set(x.label for x in trees))
419
+
420
    @staticmethod
    def get_compound_constituents(trees, separate_root=False):
        """
        Return the sorted distinct chains of unary internal nodes, as label tuples.

        An internal node whose only child is another internal node is merged
        with that child's chain, e.g. (S (VP ...)) contributes ('S', 'VP').
        separate_root: record each tree's root label as its own (label,) tuple
        and start the traversal from the root's children instead.
        """
        constituents = set()
        stack = deque()
        for tree in trees:
            if separate_root:
                constituents.add((tree.label,))
                for child in tree.children:
                    stack.append(child)
            else:
                stack.append(tree)
        while len(stack) > 0:
            node = stack.pop()
            # tags and words never start a compound constituent
            if node.is_leaf() or node.is_preterminal():
                continue
            labels = [node.label]
            # follow the unary chain downward while the only child is
            # itself an internal node
            while len(node.children) == 1 and not node.children[0].is_preterminal():
                node = node.children[0]
                labels.append(node.label)
            constituents.add(tuple(labels))
            for child in node.children:
                stack.append(child)
        return sorted(constituents)
443
+
444
    # TODO: test different pattern
    def simplify_labels(self, pattern=CONSTITUENT_SPLIT):
        """
        Return a copy of the tree with the -=# removed

        Leaves the text of the leaves alone.
        """
        new_label = self.label
        # check len(new_label) just in case it's a tag of - or =
        # -LRB- / -RRB- are kept verbatim: their leading '-' is part of the token
        if new_label and not self.is_leaf() and len(new_label) > 1 and new_label not in ('-LRB-', '-RRB-'):
            # keep only the text before the first split character
            new_label = pattern.split(new_label)[0]
        new_children = [child.simplify_labels(pattern) for child in self.children]
        return Tree(new_label, new_children)
457
+
458
+ def reverse(self):
459
+ """
460
+ Flip a tree backwards
461
+
462
+ The intent is to train a parser backwards to see if the
463
+ forward and backwards parsers can augment each other
464
+ """
465
+ if self.is_leaf():
466
+ return Tree(self.label)
467
+
468
+ new_children = [child.reverse() for child in reversed(self.children)]
469
+ return Tree(self.label, new_children)
470
+
471
+ def remap_constituent_labels(self, label_map):
472
+ """
473
+ Copies the tree with some labels replaced.
474
+
475
+ Labels in the map are replaced with the mapped value.
476
+ Labels not in the map are unchanged.
477
+ """
478
+ if self.is_leaf():
479
+ return Tree(self.label)
480
+ if self.is_preterminal():
481
+ return Tree(self.label, Tree(self.children[0].label))
482
+ new_label = label_map.get(self.label, self.label)
483
+ return Tree(new_label, [child.remap_constituent_labels(label_map) for child in self.children])
484
+
485
+ def remap_words(self, word_map):
486
+ """
487
+ Copies the tree with some labels replaced.
488
+
489
+ Labels in the map are replaced with the mapped value.
490
+ Labels not in the map are unchanged.
491
+ """
492
+ if self.is_leaf():
493
+ new_label = word_map.get(self.label, self.label)
494
+ return Tree(new_label)
495
+ if self.is_preterminal():
496
+ return Tree(self.label, self.children[0].remap_words(word_map))
497
+ return Tree(self.label, [child.remap_words(word_map) for child in self.children])
498
+
499
+ def replace_words(self, words):
500
+ """
501
+ Replace all leaf words with the words in the given list (or iterable)
502
+
503
+ Returns a new tree
504
+ """
505
+ word_iterator = iter(words)
506
+ def recursive_replace_words(subtree):
507
+ if subtree.is_leaf():
508
+ word = next(word_iterator, None)
509
+ if word is None:
510
+ raise ValueError("Not enough words to replace all leaves")
511
+ return Tree(word)
512
+ return Tree(subtree.label, [recursive_replace_words(x) for x in subtree.children])
513
+
514
+ new_tree = recursive_replace_words(self)
515
+ if any(True for _ in word_iterator):
516
+ raise ValueError("Too many words for the given tree")
517
+ return new_tree
518
+
519
+
520
    def replace_tags(self, tags):
        """
        Return a deep copy of this tree with the preterminal labels replaced.

        tags: either another Tree (its preterminal labels are used, in order)
        or any iterable of labels.
        Raises ValueError if the tag count does not match the number of
        preterminals, or if the tree is badly structured.
        """
        if self.is_leaf():
            raise ValueError("Must call replace_tags with non-leaf")

        if isinstance(tags, Tree):
            tag_iterator = (x.label for x in tags.yield_preterminals())
        else:
            tag_iterator = iter(tags)

        new_tree = copy.deepcopy(self)
        queue = deque()
        queue.append(new_tree)
        # depth-first, left-to-right: pop from the right and push children
        # reversed, so preterminals are relabeled in sentence order
        while len(queue) > 0:
            next_node = queue.pop()
            if next_node.is_preterminal():
                try:
                    label = next(tag_iterator)
                except StopIteration:
                    raise ValueError("Not enough tags in sentence for given tree")
                next_node.label = label
            elif next_node.is_leaf():
                # a leaf should only occur under a preterminal, never be queued
                raise ValueError("Got a badly structured tree: {}".format(self))
            else:
                queue.extend(reversed(next_node.children))

        if any(True for _ in tag_iterator):
            raise ValueError("Too many tags for the given tree")

        return new_tree
549
+
550
+
551
    def prune_none(self):
        """
        Return a copy of the tree, eliminating all nodes which are in one of two categories:
            they are a preterminal -NONE-, such as appears in PTB
              *E* shows up in a VLSP dataset
            they have been pruned to 0 children by the recursive call
        """
        if self.is_leaf():
            return Tree(self.label)
        if self.is_preterminal():
            # drop empty elements: -NONE- tags and words such as *E*
            if self.label == '-NONE-' or self.children[0].label in WORDS_TO_PRUNE:
                return None
            # NOTE(review): single child passed bare, not in a list — presumably
            # the Tree constructor accepts that (same pattern as remap_constituent_labels)
            return Tree(self.label, Tree(self.children[0].label))
        # must be internal node
        new_children = [child.prune_none() for child in self.children]
        # children may themselves have been pruned away entirely
        new_children = [child for child in new_children if child is not None]
        if len(new_children) == 0:
            return None
        return Tree(self.label, new_children)
570
+
571
    def count_unary_depth(self):
        """
        Return the length of the longest chain of single-child internal
        nodes anywhere in this tree.
        """
        # tags and words contribute no unary depth
        if self.is_preterminal() or self.is_leaf():
            return 0
        if len(self.children) == 1:
            # walk down the unary chain, counting its length, then
            # compare against the deepest chain below where it ends
            t = self
            score = 0
            while not t.is_preterminal() and not t.is_leaf() and len(t.children) == 1:
                score = score + 1
                t = t.children[0]
            child_score = max(tc.count_unary_depth() for tc in t.children)
            score = max(score, child_score)
            return score
        # multiple children: take the deepest chain among the subtrees
        score = max(t.count_unary_depth() for t in self.children)
        return score
585
+
586
+ @staticmethod
587
+ def write_treebank(trees, out_file, fmt="{}"):
588
+ with open(out_file, "w", encoding="utf-8") as fout:
589
+ for tree in trees:
590
+ fout.write(fmt.format(tree))
591
+ fout.write("\n")
stanza/stanza/models/constituency/positional_encoding.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Based on
3
+ https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model
4
+ """
5
+
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
class SinusoidalEncoding(nn.Module):
    """
    Represents positions with interleaved sine (even dims) and cosine (odd dims).

    The table is kept as a buffer and lazily rebuilt whenever an index
    beyond the current maximum length is requested.
    """
    def __init__(self, model_dim, max_len):
        super().__init__()
        self.register_buffer('pe', self.build_position(model_dim, max_len))

    @staticmethod
    def build_position(model_dim, max_len, device=None):
        """Build a (max_len, model_dim) table of sinusoidal position encodings."""
        positions = torch.arange(max_len).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        angles = positions * frequencies
        table = torch.zeros(max_len, model_dim)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        return table if device is None else table.to(device=device)

    def forward(self, x):
        # x holds position indices; grow the table if any index is out of range
        if max(x) >= self.pe.shape[0]:
            # drop the reference first so the old table can be freed before
            # the replacement is allocated, in case we are near the memory limit
            device = self.pe.device
            model_dim = self.pe.shape[1]
            self.register_buffer('pe', None)
            # TODO: this may result in very poor performance
            # in the event of a model that increases size one at a time
            self.register_buffer('pe', self.build_position(model_dim, max(x) + 1, device=device))
        return self.pe[x]

    def max_len(self):
        return self.pe.shape[0]
44
+
45
+
46
class AddSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position. Adds the position to the given matrix

    Default behavior is batch_first
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x, scale=1.0):
        """
        Adds the positional encoding to the input tensor

        The tensor is expected to be of the shape B, N, D (or N, D)
        Properly masking the output tensor is up to the caller

        Raises ValueError for any other rank.  (Previously, an unexpected
        rank left `timing` unbound and crashed with a confusing NameError.)
        """
        if len(x.shape) == 3:
            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
            timing = timing.expand(x.shape[0], -1, -1)
        elif len(x.shape) == 2:
            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
        else:
            raise ValueError("AddSinusoidalEncoding expects a 2d or 3d tensor, got shape %s" % (tuple(x.shape),))
        return x + timing * scale
69
+
70
+
71
class ConcatSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position, concatenating the encoding
    onto the input and returning a wider tensor.

    Default behavior is batch_first
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x):
        if len(x.shape) == 3:
            # batched: encode along the sequence axis and broadcast over batch
            positions = torch.arange(x.shape[1], device=x.device)
            timing = self.encoding(positions).expand(x.shape[0], -1, -1)
        else:
            positions = torch.arange(x.shape[0], device=x.device)
            timing = self.encoding(positions)

        return torch.cat((x, timing), dim=-1)
stanza/stanza/models/constituency/retagging.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Refactor a few functions specifically for retagging trees
3
+
4
+ Retagging is important because the gold tags will not be available at runtime
5
+
6
+ Note that the method which does the actual retagging is in utils.py
7
+ so as to avoid unnecessary circular imports
8
+ (eg, Pipeline imports constituency/trainer which imports this which imports Pipeline)
9
+ """
10
+
11
+ import copy
12
+ import logging
13
+
14
+ from stanza import Pipeline
15
+
16
+ from stanza.models.common.foundation_cache import FoundationCache
17
+ from stanza.models.common.vocab import VOCAB_PREFIX
18
+ from stanza.resources.common import download_resources_json, load_resources_json, get_language_resources
19
+
20
+ tlogger = logging.getLogger('stanza.constituency.trainer')
21
+
22
+ # xpos tagger doesn't produce PP tag on the turin treebank,
23
+ # so instead we use upos to avoid unknown tag errors
24
+ RETAG_METHOD = {
25
+ "da": "upos", # the DDT has no xpos tags anyway
26
+ "de": "upos", # DE GSD is also missing a few punctuation tags
27
+ "es": "upos", # AnCora has half-finished xpos tags
28
+ "id": "upos", # GSD is missing a few punctuation tags - fixed in 2.12, though
29
+ "it": "upos",
30
+ "pt": "upos", # default PT model has no xpos either
31
+ "vi": "xpos", # the new version of UD can be merged with xpos from VLSP22
32
+ }
33
+
34
def add_retag_args(parser):
    """
    Arguments specifically for retagging treebanks
    """
    # NOTE: the option names and help strings below are part of the CLI
    # surface and are matched elsewhere (eg --no_retag aliases retag_package)
    parser.add_argument('--retag_package', default="default", help='Which tagger shortname to use when retagging trees. None for no retagging. Retagging is recommended, as gold tags will not be available at pipeline time')
    parser.add_argument('--retag_method', default=None, choices=['xpos', 'upos'], help='Which tags to use when retagging. Default depends on the language')
    parser.add_argument('--retag_model_path', default=None, help='Path to a retag POS model to use. Will use a downloaded Stanza model by default. Can specify multiple taggers with ; in which case the majority vote wins')
    parser.add_argument('--retag_pretrain_path', default=None, help='Use this for a pretrain path for the retagging pipeline. Generally not needed unless using a custom POS model with a custom pretrain')
    parser.add_argument('--retag_charlm_forward_file', default=None, help='Use this for a forward charlm path for the retagging pipeline. Generally not needed unless using a custom POS model with a custom charlm')
    parser.add_argument('--retag_charlm_backward_file', default=None, help='Use this for a backward charlm path for the retagging pipeline. Generally not needed unless using a custom POS model with a custom charlm')
    # --no_retag stores None into retag_package, which disables retagging entirely
    parser.add_argument('--no_retag', dest='retag_package', action="store_const", const=None, help="Don't retag the trees")
45
+
46
def postprocess_args(args):
    """
    After parsing args, unify some settings

    Chooses a language specific default for retag_method when none was
    given, then sets args['retag_xpos'] to match.

    Raises ValueError for an unrecognized retag_method.
    """
    # use a language specific default for retag_method if we know the language
    # otherwise, use xpos
    if args['retag_method'] is None and 'lang' in args and args['lang'] in RETAG_METHOD:
        args['retag_method'] = RETAG_METHOD[args['lang']]
    if args['retag_method'] is None:
        args['retag_method'] = 'xpos'

    if args['retag_method'] == 'xpos':
        args['retag_xpos'] = True
    elif args['retag_method'] == 'upos':
        args['retag_xpos'] = False
    else:
        # bugfix: this previously formatted the undefined name `xpos`,
        # which raised NameError instead of the intended ValueError
        raise ValueError("Unknown retag method {}".format(args['retag_method']))
63
+
64
def build_retag_pipeline(args):
    """
    Builds retag pipelines based on the arguments

    May alter the arguments if the pipeline is incompatible, such as
    taggers with no xpos

    Will return a list of one or more retag pipelines.
    Multiple tagger models can be specified by having them
    semi-colon separated in retag_model_path.

    Returns None when retagging is disabled (retag_package is None)
    or when running in remove_optimizer mode.
    """
    # some argument sets might not use 'mode'
    if args['retag_package'] is not None and args.get('mode', None) != 'remove_optimizer':
        download_resources_json()
        resources = load_resources_json()

        # a package of the form lang_package usually means "language lang,
        # package package", unless the full string is itself a known pos
        # package for the current language
        if '_' in args['retag_package']:
            lang, package = args['retag_package'].split('_', 1)
            lang_resources = get_language_resources(resources, lang)
            if lang_resources is None and 'lang' in args:
                lang_resources = get_language_resources(resources, args['lang'])
            if lang_resources is not None and 'pos' in lang_resources and args['retag_package'] in lang_resources['pos']:
                lang = args['lang']
                package = args['retag_package']
        else:
            if 'lang' not in args:
                raise ValueError("Retag package %s does not specify the language, and it is not clear from the arguments" % args['retag_package'])
            lang = args.get('lang', None)
            package = args['retag_package']
        # shared across all pipelines built below so pretrains/charlms load once
        foundation_cache = FoundationCache()
        retag_args = {"lang": lang,
                      "processors": "tokenize, pos",
                      "tokenize_pretokenized": True,
                      "package": {"pos": package}}
        if args['retag_pretrain_path'] is not None:
            retag_args['pos_pretrain_path'] = args['retag_pretrain_path']
        if args['retag_charlm_forward_file'] is not None:
            retag_args['pos_forward_charlm_path'] = args['retag_charlm_forward_file']
        if args['retag_charlm_backward_file'] is not None:
            retag_args['pos_backward_charlm_path'] = args['retag_charlm_backward_file']

        def build(retag_args, path):
            # build one pipeline; path is a custom POS model path, or None
            # to use the downloaded package
            retag_args = copy.deepcopy(retag_args)
            # we just downloaded the resources a moment ago
            # no need to repeatedly download
            retag_args['download_method'] = 'reuse_resources'
            if path is not None:
                retag_args['allow_unknown_language'] = True
                retag_args['pos_model_path'] = path
                tlogger.debug('Creating retag pipeline using %s', path)
            else:
                tlogger.debug('Creating retag pipeline for %s package', package)

            retag_pipeline = Pipeline(foundation_cache=foundation_cache, **retag_args)
            # fall back to upos if the chosen tagger has an empty xpos vocab
            # (only the VOCAB_PREFIX sentinel entries)
            if args['retag_xpos'] and len(retag_pipeline.processors['pos'].vocab['xpos']) == len(VOCAB_PREFIX):
                tlogger.warning("XPOS for the %s tagger is empty. Switching to UPOS", package)
                args['retag_xpos'] = False
                args['retag_method'] = 'upos'
            return retag_pipeline

        if args['retag_model_path'] is None:
            return [build(retag_args, None)]
        paths = args['retag_model_path'].split(";")
        # can be length 1 if only one tagger to work with
        return [build(retag_args, path) for path in paths]

    return None
stanza/stanza/models/constituency/state.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
class State(namedtuple('State', ['word_queue', 'transitions', 'constituents', 'gold_tree', 'gold_sequence',
                                 'sentence_length', 'num_opens', 'word_position', 'score'])):
    """
    Represents a partially completed transition parse

    Includes stack/buffers for unused words, already executed transitions, and partially build constituents
    At training time, also keeps track of the gold data we are reparsing

    num_opens is useful for tracking
        1) if the parser is in a stuck state where it is making infinite opens
        2) if a close transition is impossible because there are no previous opens

    sentence_length tracks how long the sentence is so we abort if we go infinite

    non-stack information such as sentence_length and num_opens
    will be copied from the original_state if possible, with the
    exact arguments overriding the values in the original_state

    gold_tree: the original tree, if made from a gold tree.  might be None
    gold_sequence: the original transition sequence, if available
      Note that at runtime, gold values will not be available

    word_position tracks where in the word queue we are.  cheaper than
      manipulating the list itself.  this can be handled differently
      from transitions and constituents as it is processed once
      at the start of parsing

    The word_queue should have both a start and an end word.
    Those can be None in the case of the endpoints if they are unused.
    """
    def empty_word_queue(self):
        # the first element of each stack is a sentinel with no value
        # and no parent
        return self.word_position == self.sentence_length

    def empty_transitions(self):
        # the first element of each stack is a sentinel with no value
        # and no parent
        return self.transitions.parent is None

    def has_one_constituent(self):
        # a length of 1 represents no constituents
        return self.constituents.length == 2

    @property
    def empty_constituents(self):
        # only the sentinel remains when there is no parent link
        return self.constituents.parent is None

    def num_constituents(self):
        # -1 for the sentinel value
        return self.constituents.length - 1

    @property
    def num_transitions(self):
        # -1 for the sentinel value
        return self.transitions.length - 1

    def get_word(self, pos):
        # +1 to handle the initial sentinel value
        # (which you can actually get with pos=-1)
        return self.word_queue[pos+1]

    def finished(self, model):
        # done when all words are consumed, exactly one constituent remains,
        # and that constituent carries one of the model's root labels
        return self.empty_word_queue() and self.has_one_constituent() and model.get_top_constituent(self.constituents).label in model.root_labels

    def get_tree(self, model):
        return model.get_top_constituent(self.constituents)

    def all_transitions(self, model):
        # TODO: rewrite this to be nicer / faster?  or just refactor?
        # walks the parent links from newest to oldest, then reverses
        all_transitions = []
        transitions = self.transitions
        while transitions.parent is not None:
            all_transitions.append(model.get_top_transition(transitions))
            transitions = transitions.parent
        return list(reversed(all_transitions))

    def all_constituents(self, model):
        # TODO: rewrite this to be nicer / faster?
        # walks the parent links from newest to oldest, then reverses
        all_constituents = []
        constituents = self.constituents
        while constituents.parent is not None:
            all_constituents.append(model.get_top_constituent(constituents))
            constituents = constituents.parent
        return list(reversed(all_constituents))

    def all_words(self, model):
        return [model.get_word(x) for x in self.word_queue]

    def to_string(self, model):
        # human-readable dump, resolving the stacks through the model
        return "State(\n  buffer:%s\n  transitions:%s\n  constituents:%s\n  word_position:%d num_opens:%d)" % (str(self.all_words(model)), str(self.all_transitions(model)), str(self.all_constituents(model)), self.word_position, self.num_opens)

    def __str__(self):
        return "State(\n  buffer:%s\n  transitions:%s\n  constituents:%s)" % (str(self.word_queue), str(self.transitions), str(self.constituents))
96
+
97
class MultiState(namedtuple('MultiState', ['states', 'gold_tree', 'gold_sequence', 'score'])):
    """
    Wraps one State per ensemble model; the first state answers all
    structural queries, since the structure is the same across models.
    """
    def finished(self, ensemble):
        return self.states[0].finished(ensemble.models[0])

    def get_tree(self, ensemble):
        return self.states[0].get_tree(ensemble.models[0])

    @property
    def empty_constituents(self):
        return self.states[0].empty_constituents

    def num_constituents(self):
        # NOTE(review): State.num_constituents uses constituents.length, but
        # here len() is called on the stack — confirm the stack type defines
        # __len__ consistently with its .length attribute
        return len(self.states[0].constituents) - 1

    @property
    def num_transitions(self):
        # -1 for the sentinel value
        return len(self.states[0].transitions) - 1

    @property
    def num_opens(self):
        return self.states[0].num_opens

    @property
    def sentence_length(self):
        return self.states[0].sentence_length

    def empty_word_queue(self):
        return self.states[0].empty_word_queue()

    def empty_transitions(self):
        return self.states[0].empty_transitions()

    @property
    def constituents(self):
        # warning!  if there is information in the constituents such as
        # the embedding of the constituent, this will only contain the
        # first such embedding
        # the other models' constituent states won't be returned
        return self.states[0].constituents

    @property
    def transitions(self):
        # warning!  if there is information in the transitions such as
        # the embedding of the transition, this will only contain the
        # first such embedding
        # the other models' transition states won't be returned
        return self.states[0].transitions
stanza/stanza/models/constituency/top_down_oracle.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import random
3
+
4
+ from stanza.models.constituency.dynamic_oracle import advance_past_constituents, score_candidates, DynamicOracle, RepairEnum
5
+ from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent
6
+
7
def find_constituent_end(gold_sequence, cur_index):
    """
    Find the Close which ends the next constituent opened at or after cur_index
    """
    depth = 0
    for idx in range(cur_index, len(gold_sequence)):
        transition = gold_sequence[idx]
        if isinstance(transition, OpenConstituent):
            depth = depth + 1
        elif isinstance(transition, CloseConstituent):
            depth = depth - 1
            if depth == 0:
                return idx
    raise AssertionError("Open constituent not closed starting from index %d in sequence %s" % (len(gold_sequence), gold_sequence))
21
+
22
def fix_shift_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Predicted a close when we should have shifted

    The fix here is to remove the corresponding close from later in
    the transition sequence.  The rest of the tree building is the same,
    including doing the missing Shift immediately after

    Anything else would make the situation of one precision, one
    recall error worse
    """
    if not isinstance(pred_transition, CloseConstituent):
        return None
    if not isinstance(gold_transition, Shift):
        return None

    # the Close that would have ended the enclosing constituent is dropped;
    # everything else proceeds unchanged after the predicted Close
    close_index = advance_past_constituents(gold_sequence, gold_index)
    repaired = list(gold_sequence[:gold_index])
    repaired.append(pred_transition)
    repaired.extend(gold_sequence[gold_index:close_index])
    repaired.extend(gold_sequence[close_index+1:])
    return repaired
41
+
42
def fix_open_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Predicted a close when we should have opened a constituent

    In this case, the previous constituent is now a precision and
    recall error, BUT we can salvage the constituent we were about to
    open by proceeding as if everything else is still the same.

    The next thing the model should do is open the transition it forgot about
    """
    if not isinstance(pred_transition, CloseConstituent):
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None

    # drop the Close that would have ended the enclosing constituent;
    # the skipped Open still happens right after the predicted Close
    close_index = advance_past_constituents(gold_sequence, gold_index)
    repaired = list(gold_sequence[:gold_index])
    repaired.append(pred_transition)
    repaired.extend(gold_sequence[gold_index:close_index])
    repaired.extend(gold_sequence[close_index+1:])
    return repaired
60
+
61
def fix_one_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Predicted a shift when we should have opened a constituent

    This causes a single recall error if we just pretend that
    constituent didn't exist

    Keep the shift where it was, remove the next shift
    Also, scroll ahead, find the corresponding close, cut it out

    For the corresponding multiple opens, shift error, see fix_multiple_open_shift
    """
    if not isinstance(pred_transition, Shift):
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None
    if not isinstance(gold_sequence[gold_index + 1], Shift):
        return None

    shift_index = gold_index + 1
    close_index = advance_past_constituents(gold_sequence, shift_index)
    if close_index is None:
        return None
    # gold_index held the skipped Open, close_index its matching Close,
    # shift_index the Shift replaced by the predicted one
    return (gold_sequence[:gold_index]
            + [pred_transition]
            + gold_sequence[gold_index+1:shift_index]
            + gold_sequence[shift_index+1:close_index]
            + gold_sequence[close_index+1:])
92
+
93
def fix_multiple_open_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Predicted a shift when we should have opened multiple constituents instead

    This causes a single recall error per constituent if we just
    pretend those constituents don't exist

    For each open constituent, we find the corresponding close,
    then remove both the open & close
    """
    if not isinstance(pred_transition, Shift):
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None

    # advance past the whole run of Opens to the Shift that follows them
    shift_index = gold_index
    while shift_index < len(gold_sequence) and isinstance(gold_sequence[shift_index], OpenConstituent):
        shift_index += 1
    if shift_index >= len(gold_sequence):
        raise AssertionError("Found a sequence of OpenConstituent at the end of a TOP_DOWN sequence!")
    if not isinstance(gold_sequence[shift_index], Shift):
        raise AssertionError("Expected to find a Shift after a sequence of OpenConstituent.  There should not be a %s" % gold_sequence[shift_index])

    # peel off the Opens innermost-first: each pass removes the Open just
    # before shift_index together with its matching Close, so shift_index
    # shrinks by one per iteration as the sequence contracts
    updated_sequence = gold_sequence
    while shift_index > gold_index:
        close_index = advance_past_constituents(updated_sequence, shift_index)
        if close_index is None:
            raise AssertionError("Did not find a corresponding Close for this Open")
        # cut out the corresponding open and close
        updated_sequence = updated_sequence[:shift_index-1] + updated_sequence[shift_index:close_index] + updated_sequence[close_index+1:]
        shift_index -= 1

    return updated_sequence
130
+
131
def fix_nested_open_constituent(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to predict Open(X), then Open(Y), but predicted Open(Y) instead

    We treat this as a single recall error.

    We could even go crazy and turn it into a Unary,
    such as Open(Y), Open(X), Open(Y)...
    presumably that would be very confusing to the parser
    not to mention ambiguous as to where to close the new constituent
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, OpenConstituent):
        return None

    assert len(gold_sequence) > gold_index + 1

    following = gold_sequence[gold_index + 1]
    if not isinstance(following, OpenConstituent):
        return None
    # this repair only works when exactly one level was skipped
    if following.label != pred_transition.label:
        return None

    close_index = advance_past_constituents(gold_sequence, gold_index + 1)
    assert close_index is not None
    # drop the skipped Open and its matching Close
    return gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + gold_sequence[close_index+1:]
161
+
162
def fix_shift_open_immediate_close(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened

    The biggest problem with this type of error is that the Close of
    the Open is ambiguous.  We could put it immediately before the
    next Close, immediately after the Shift, or anywhere in between.

    One unambiguous case would be if the proper sequence was Shift - Close.
    Then it is unambiguous that the only possible repair is Open - Shift - Close - Close.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None
    if not isinstance(gold_transition, Shift):
        return None

    assert len(gold_sequence) > gold_index + 1
    if not isinstance(gold_sequence[gold_index + 1], CloseConstituent):
        # this is the ambiguous case
        return None

    repaired = list(gold_sequence[:gold_index])
    repaired.extend([pred_transition, gold_transition, CloseConstituent()])
    repaired.extend(gold_sequence[gold_index + 1:])
    return repaired
185
+
186
def fix_shift_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened

    The biggest problem with this type of error is that the Close of
    the Open is ambiguous. We could put it immediately before the
    next Close, immediately after the Shift, or anywhere in between.

    In this fix, we are testing what happens if we treat this Open as a Unary transition.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, Shift):
        return None

    # a Shift can never legally be the final transition of a gold sequence
    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # this is the unambiguous case, which should already be handled
        return None

    # close the spurious bracket immediately after the shifted word,
    # so the extra Open becomes a unary bracket over one word
    return gold_sequence[:gold_index] + [pred_transition, gold_transition, CloseConstituent()] + gold_sequence[gold_index+1:]
208
+
209
def fix_shift_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened

    The biggest problem with this type of error is that the Close of
    the Open is ambiguous. We could put it immediately before the
    next Close, immediately after the Shift, or anywhere in between.

    In this fix, we put the corresponding Close for this Open at the end of the enclosing bracket.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, Shift):
        return None

    # a Shift can never legally be the final transition of a gold sequence
    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # this is the unambiguous case, which should already be handled
        return None

    # the spurious bracket swallows everything up to the Close of the
    # enclosing constituent, then gets its own Close just before it
    outer_close_index = advance_past_constituents(gold_sequence, gold_index)

    return gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:outer_close_index] + [CloseConstituent()] + gold_sequence[outer_close_index:]
233
+
234
def fix_shift_open_ambiguous_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Shift, but instead we Opened, in the ambiguous case

    Rather than picking one fixed placement for the matching Close,
    build every legal placement (after the Shift or after any
    following constituent, up to the enclosing Close) and let the
    model score which candidate repair sequence to use.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, Shift):
        return None

    # a Shift can never legally be the final transition of a gold sequence
    assert len(gold_sequence) > gold_index + 1
    if isinstance(gold_sequence[gold_index+1], CloseConstituent):
        # this is the unambiguous case, which should already be handled
        return None

    # at this point: have Opened a constituent which we don't want
    # need to figure out where to Close it
    # could close it after the shift or after any given block
    candidates = []
    current_index = gold_index
    while not isinstance(gold_sequence[current_index], CloseConstituent):
        if isinstance(gold_sequence[current_index], Shift):
            end_index = current_index
        else:
            end_index = find_constituent_end(gold_sequence, current_index)
        # each candidate is the gold sequence split into pieces, with the
        # spurious Open inserted and its Close placed after this block
        candidates.append((gold_sequence[:gold_index], [pred_transition], gold_sequence[gold_index:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))
        current_index = end_index + 1

    # candidate_idx selects which piece of the candidate tuple is being
    # scored (here, the inserted Close) - see score_candidates
    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=3)
    if best_idx == len(candidates) - 1:
        # record the "last possible position" choice as -1
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.SHIFT_OPEN_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
266
+
267
+
268
def fix_close_shift_ambiguous_immediate(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Instead of a Close, we predicted a Shift. This time, we immediately close no matter what comes after the next Shift.

    An alternate strategy would be to Close at the closing of the outer constituent.
    """
    if not isinstance(pred_transition, Shift):
        return None

    if not isinstance(gold_transition, CloseConstituent):
        return None

    # count the run of Closes the gold sequence wanted here
    num_closes = 0
    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):
        num_closes += 1

    if not isinstance(gold_sequence[gold_index + num_closes], Shift):
        # TODO: we should be able to handle this case too (an Open)
        # however, it will be rare once the parser gets going and it
        # would cause a lot of errors, anyway
        return None

    if isinstance(gold_sequence[gold_index + num_closes + 1], CloseConstituent):
        # this one should just have been satisfied in the non-ambiguous version
        return None

    # follow the erroneous Shift, then perform the delayed Closes; the
    # gold Shift itself is dropped since we have already shifted
    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[gold_index+num_closes+1:]
    return updated_sequence
296
+
297
+
298
def fix_close_shift_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Instead of a Close, we predicted a Shift. This time, we close at the end of the outer bracket no matter what comes after the next Shift.

    An alternate strategy would be to Close as soon as possible after the Shift.
    """
    if not isinstance(pred_transition, Shift):
        return None

    if not isinstance(gold_transition, CloseConstituent):
        return None

    # count the run of Closes the gold sequence wanted here
    num_closes = 0
    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):
        num_closes += 1

    if not isinstance(gold_sequence[gold_index + num_closes], Shift):
        # TODO: we should be able to handle this case too (an Open)
        # however, it will be rare once the parser gets going and it
        # would cause a lot of errors, anyway
        return None

    if isinstance(gold_sequence[gold_index + num_closes + 1], CloseConstituent):
        # this one should just have been satisfied in the non-ambiguous version
        return None

    # outer_close_index is now where the constituent which the broken constituent(s) reside inside gets closed
    outer_close_index = advance_past_constituents(gold_sequence, gold_index + num_closes)

    # defer the missed Closes until just before the outer constituent closes
    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+num_closes:outer_close_index] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[outer_close_index:]
    return updated_sequence
329
+
330
+
331
def fix_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, count_opens=False):
    """
    We were supposed to Close, but instead did a Shift

    In most cases, this will be ambiguous. There is now a constituent
    which has been missed, no matter what we do, and we are on the
    hook for eventually closing this constituent, creating a precision
    error as well. The ambiguity arises because there will be
    multiple places where the Close could occur if there are more
    constituents created between now and when the outer constituent is
    Closed.

    The non-ambiguous case is if the proper sequence was
    Close - Shift - Close
    similar cases are also non-ambiguous, such as
    Close - Close - Shift - Close
    for that matter, so is the following, although the Opens will be lost
    Close - Open - Shift - Close - Close

    count_opens is an option to make it easy to count with or without
    Open as different oracle fixes
    """
    if not isinstance(pred_transition, Shift):
        return None

    if not isinstance(gold_transition, CloseConstituent):
        return None

    # count the run of Closes the gold sequence wanted here
    num_closes = 0
    while isinstance(gold_sequence[gold_index + num_closes], CloseConstituent):
        num_closes += 1

    # We may allow unary transitions here
    # the opens will be lost in the repaired sequence
    num_opens = 0
    if count_opens:
        while isinstance(gold_sequence[gold_index + num_closes + num_opens], OpenConstituent):
            num_opens += 1

    if not isinstance(gold_sequence[gold_index + num_closes + num_opens], Shift):
        if count_opens:
            raise AssertionError("Should have found a Shift after a sequence of Opens or a Close with no Open. Started counting at %d in sequence %s" % (gold_index, gold_sequence))
        return None

    # the unambiguous case requires the Shift (and each skipped Open)
    # to be immediately followed by Closes
    if not isinstance(gold_sequence[gold_index + num_closes + num_opens + 1], CloseConstituent):
        return None
    for idx in range(num_opens):
        if not isinstance(gold_sequence[gold_index + num_closes + num_opens + idx + 1], CloseConstituent):
            return None

    # Now we know it is Close x num_closes, Shift, Close
    # Since we have erroneously predicted a Shift now, the best we can
    # do is to follow that, then add num_closes Closes
    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:gold_index+num_closes] + gold_sequence[gold_index+num_closes+num_opens*2+1:]
    return updated_sequence
386
+
387
def fix_close_shift_with_opens(*args, **kwargs):
    """
    Variant of fix_close_shift which also counts (and drops) skipped Opens
    """
    return fix_close_shift(*args, count_opens=True, **kwargs)
389
+
390
def fix_close_next_correct_predicted(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    We were supposed to Close, but instead predicted Shift when the next transition is Shift

    This differs from the previous Close-Shift in that this case does
    not have an unambiguous place to put the Close. Instead, we let
    the model predict where to put the Close

    Note that this can also work for Close-Open with the next Open correct

    Not covered (yet?) is multiple Close in a row
    """
    if not isinstance(gold_transition, CloseConstituent):
        return None
    if not isinstance(pred_transition, (Shift, OpenConstituent)):
        return None
    # only applies when the prediction matches the gold transition
    # which comes right after the skipped Close
    if gold_sequence[gold_index+1] != pred_transition:
        return None

    # build one candidate per legal position of the delayed Close:
    # after each bare Shift or complete constituent, up to the
    # enclosing Close
    candidates = []
    current_index = gold_index + 1
    while not isinstance(gold_sequence[current_index], CloseConstituent):
        if isinstance(gold_sequence[current_index], Shift):
            end_index = current_index
        else:
            end_index = find_constituent_end(gold_sequence, current_index)
        candidates.append((gold_sequence[:gold_index], gold_sequence[gold_index+1:end_index+1], [CloseConstituent()], gold_sequence[end_index+1:]))
        current_index = end_index + 1

    # candidate_idx selects which piece of the candidate tuple is being
    # scored (here, the inserted Close) - see score_candidates
    scores, best_idx, best_candidate = score_candidates(model, state, candidates, candidate_idx=3)
    if best_idx == len(candidates) - 1:
        # record the "last possible position" choice as -1
        best_idx = -1
    repair_type = RepairEnum(name=RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.name,
                             value="%d.%d" % (RepairType.CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED.value, best_idx),
                             is_correct=False)
    return repair_type, best_candidate
426
+
427
+
428
def fix_close_open_correct_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):
    """
    We were supposed to Close, but instead did an Open

    In general this is ambiguous (like close/shift), as we need to know when to close the incorrect constituent

    A case that is not ambiguous is when exactly one constituent was
    supposed to come after the Close and it matches the Open we just
    created. In that case, we treat that constituent as if it were
    part of the non-Closed constituent. For example,
    "ate (NP spaghetti) (PP with a fork)" ->
    "ate (NP spaghetti (PP with a fork))"
    (delicious)

    There is also an option to not check for the Close after the first
    constituent, in which case any number of constituents could have
    been predicted. This represents a solution of the ambiguous form
    of the Close/Open transition where the Close could occur in
    multiple places later in the sequence.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, CloseConstituent):
        return None

    # the predicted Open must match the gold Open which was supposed
    # to come right after the missed Close
    if gold_sequence[gold_index+1] != pred_transition:
        return None

    close_index = find_constituent_end(gold_sequence, gold_index+1)
    if check_close and not isinstance(gold_sequence[close_index+1], CloseConstituent):
        return None

    # at this point, we know we can put the Close at the end of the
    # Open which was accidentally added
    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index+1] + [gold_transition] + gold_sequence[close_index+1:]
    return updated_sequence
465
+
466
def fix_close_open_correct_open_ambiguous_immediate(*args, **kwargs):
    """
    Ambiguous Close/Open variant: close the spurious bracket right after the first constituent
    """
    return fix_close_open_correct_open(*args, check_close=False, **kwargs)
468
+
469
def fix_close_open_correct_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state, check_close=True):
    """
    We were supposed to Close, but instead did an Open in an ambiguous context. Here we resolve it later in the tree

    check_close is accepted for signature parity with
    fix_close_open_correct_open but is not used here.
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, CloseConstituent):
        return None

    # the predicted Open must match the gold Open which was supposed
    # to come right after the missed Close
    if gold_sequence[gold_index+1] != pred_transition:
        return None

    # this will be the index of the Close for the surrounding constituent
    close_index = advance_past_constituents(gold_sequence, gold_index+1)
    updated_sequence = gold_sequence[:gold_index] + gold_sequence[gold_index+1:close_index] + [gold_transition] + gold_sequence[close_index:]
    return updated_sequence
486
+
487
def fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    If there is an Open/Open error which is not covered by the unambiguous single recall error, we try fixing it as a Unary
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None

    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # This case is covered by the nested open repair
        return None

    # close the spurious bracket as soon as the gold constituent
    # finishes, so it becomes a unary wrapper around the gold bracket
    close_index = find_constituent_end(gold_sequence, gold_index)
    assert close_index is not None
    assert isinstance(gold_sequence[close_index], CloseConstituent)
    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:close_index] + [CloseConstituent()] + gold_sequence[close_index:]
    return updated_sequence
508
+
509
def fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    If there is an Open/Open error which is not covered by the
    unambiguous single recall error, we try fixing it by putting the
    close at the end of the outer constituent

    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None

    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # This case is covered by the nested open repair
        return None

    # let the spurious bracket stay open until the enclosing
    # constituent is about to close, then close it there
    close_index = advance_past_constituents(gold_sequence, gold_index)
    updated_sequence = gold_sequence[:gold_index] + [pred_transition] + gold_sequence[gold_index:close_index] + [CloseConstituent()] + gold_sequence[close_index:]
    return updated_sequence
531
+
532
def fix_open_open_ambiguous_random(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    If there is an Open/Open error which is not covered by the
    unambiguous single recall error, we randomly choose (50/50)
    between closing at the end of the outer constituent and treating
    the spurious Open as a Unary
    """
    if not isinstance(pred_transition, OpenConstituent):
        return None

    if not isinstance(gold_transition, OpenConstituent):
        return None

    if pred_transition == gold_transition:
        return None
    if gold_sequence[gold_index+1] == pred_transition:
        # This case is covered by the nested open repair
        return None

    # Bug fix: the delegated repairs take (..., root_labels, model, state).
    # Previously model and state were dropped from these calls, which
    # raised a TypeError whenever this repair was actually triggered.
    if random.random() < 0.5:
        return fix_open_open_ambiguous_later(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state)
    else:
        return fix_open_open_ambiguous_unary(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state)
555
+
556
+
557
def report_shift_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Report (without repairing) a Shift/Open confusion not handled by an earlier repair
    """
    if isinstance(gold_transition, Shift) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_SHIFT_OPEN, None
    return None
564
+
565
+
566
def report_close_shift(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Report (without repairing) a Close/Shift confusion not handled by an earlier repair
    """
    if isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, Shift):
        return RepairType.OTHER_CLOSE_SHIFT, None
    return None
573
+
574
def report_close_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Report (without repairing) a Close/Open confusion not handled by an earlier repair
    """
    if isinstance(gold_transition, CloseConstituent) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_CLOSE_OPEN, None
    return None
581
+
582
def report_open_open(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, model, state):
    """
    Report (without repairing) an Open/Open confusion not handled by an earlier repair
    """
    if isinstance(gold_transition, OpenConstituent) and isinstance(pred_transition, OpenConstituent):
        return RepairType.OTHER_OPEN_OPEN, None
    return None
589
+
590
+
591
class RepairType(Enum):
    """
    Keep track of which repair is used, if any, on an incorrect transition

    A test of the top-down oracle with no charlm or transformer
    (eg, word vectors only) on EN PTB3 goes as follows.
    3x training rounds, best training parameters as of Jan. 2024
    unambiguous transitions only:
    oracle scheme          dev     test
    no oracle              0.9230  0.9194
    +shift/close           0.9224  0.9180
    +open/close            0.9225  0.9193
    +open/shift (one)      0.9245  0.9207
    +open/shift (mult)     0.9243  0.9211
    +open/open nested      0.9258  0.9213
    +shift/open            0.9266  0.9229
    +close/shift (only)    0.9270  0.9230
    +close/shift w/ opens  0.9262  0.9221
    +close/open one con    0.9273  0.9230

    Potential solutions for various ambiguous transitions:

    close/open
      can close immediately after the corresponding constituent or after any number of constituents

    close/shift
      can close immediately
      can close anywhere up to the next close
      any number of missed Opens are treated as recall errors

    open/open
      could treat as unary
      could close at any number of positions after the next structures, up to the outer open's closing

    shift/open ambiguity resolutions:
      treat as unary
      treat as wrapper around the next full constituent to build
      treat as wrapper around everything to build until the next constituent

    testing one at a time in addition to the full set of unambiguous corrections:
    +close/open immediate   0.9259  0.9225
    +close/open later       0.9258  0.9257
    +close/shift immediate  0.9261  0.9219
    +close/shift later      0.9270  0.9230
    +open/open later        0.9269  0.9239
    +open/open unary        0.9275  0.9246
    +shift/open later       0.9263  0.9253
    +shift/open unary       0.9264  0.9243

    so there is some evidence that open/open or shift/open would be beneficial

    Training by randomly choosing between the open/open, 50/50
    +open/open random       0.9257  0.9235
    so that didn't work great compared to the individual transitions

    Testing deterministic resolutions of the ambiguous transitions
    vs predicting the appropriate transition to use:
    SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR,CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR,CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR
    SHIFT_OPEN_AMBIGUOUS_PREDICTED,CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED

    EN ambiguous (no charlm or transformer)  0.9268  0.9231
    EN predicted                             0.9270  0.9257
    EN none of the above                     0.9268  0.9229

    ZH ambiguous          0.9137  0.9127
    ZH predicted          0.9148  0.9141
    ZH none of the above  0.9141  0.9143

    DE ambiguous          0.9579  0.9408
    DE predicted          0.9575  0.9406
    DE none of the above  0.9581  0.9411

    ID ambiguous          0.8889  0.8794
    ID predicted          0.8911  0.8801
    ID none of the above  0.8913  0.8822

    IT ambiguous          0.8404  0.8380
    IT predicted          0.8397  0.8398
    IT none of the above  0.8400  0.8409

    VI ambiguous          0.8290  0.7676
    VI predicted          0.8287  0.7682
    VI none of the above  0.8292  0.7691
    """
    def __new__(cls, fn, correct=False, debug=False):
        """
        Enumerate values as normal, but also keep a pointer to a function which repairs that kind of error
        """
        # members get values 1..N in declaration order
        value = len(cls.__members__)
        obj = object.__new__(cls)
        obj._value_ = value + 1
        obj.fn = fn            # repair function for this error kind, or None
        obj.correct = correct  # True only for the CORRECT pseudo-repair
        obj.debug = debug      # True for the report-only members
        return obj

    @property
    def is_correct(self):
        return self.correct

    # The parser chose to close a bracket instead of shift something
    # into the bracket
    # This causes both a precision and a recall error as there is now
    # an incorrect bracket and a missing correct bracket
    # Any bracket creation here would cause more wrong brackets, though
    SHIFT_CLOSE_ERROR = (fix_shift_close,)

    OPEN_CLOSE_ERROR = (fix_open_close,)

    # open followed by shift was instead predicted to be shift
    ONE_OPEN_SHIFT_ERROR = (fix_one_open_shift,)

    # open followed by shift was instead predicted to be shift
    MULTIPLE_OPEN_SHIFT_ERROR = (fix_multiple_open_shift,)

    # should have done Open(X), Open(Y)
    # instead just did Open(Y)
    NESTED_OPEN_OPEN_ERROR = (fix_nested_open_constituent,)

    SHIFT_OPEN_ERROR = (fix_shift_open_immediate_close,)

    CLOSE_SHIFT_ERROR = (fix_close_shift,)

    CLOSE_SHIFT_WITH_OPENS_ERROR = (fix_close_shift_with_opens,)

    CLOSE_OPEN_ONE_CON_ERROR = (fix_close_open_correct_open,)

    # pseudo-repair used when the predicted transition was correct
    CORRECT = (None, True)

    # no repair function matched the error
    UNKNOWN = None

    CLOSE_OPEN_AMBIGUOUS_IMMEDIATE_ERROR = (fix_close_open_correct_open_ambiguous_immediate,)

    CLOSE_OPEN_AMBIGUOUS_LATER_ERROR = (fix_close_open_correct_open_ambiguous_later,)

    CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR = (fix_close_shift_ambiguous_immediate,)

    CLOSE_SHIFT_AMBIGUOUS_LATER_ERROR = (fix_close_shift_ambiguous_later,)

    # can potentially fix either close/shift or close/open
    # as long as the gold transition after the close
    # was the same as the transition we just predicted
    CLOSE_NEXT_CORRECT_AMBIGUOUS_PREDICTED = (fix_close_next_correct_predicted,)

    OPEN_OPEN_AMBIGUOUS_UNARY_ERROR = (fix_open_open_ambiguous_unary,)

    OPEN_OPEN_AMBIGUOUS_LATER_ERROR = (fix_open_open_ambiguous_later,)

    OPEN_OPEN_AMBIGUOUS_RANDOM_ERROR = (fix_open_open_ambiguous_random,)

    SHIFT_OPEN_AMBIGUOUS_UNARY_ERROR = (fix_shift_open_ambiguous_unary,)

    SHIFT_OPEN_AMBIGUOUS_LATER_ERROR = (fix_shift_open_ambiguous_later,)

    SHIFT_OPEN_AMBIGUOUS_PREDICTED = (fix_shift_open_ambiguous_predicted,)

    # report-only members (debug=True) which track otherwise
    # unrepaired confusions without altering the gold sequence
    OTHER_SHIFT_OPEN = (report_shift_open, False, True)

    OTHER_CLOSE_SHIFT = (report_close_shift, False, True)

    OTHER_CLOSE_OPEN = (report_close_open, False, True)

    OTHER_OPEN_OPEN = (report_open_open, False, True)
754
+
755
class TopDownOracle(DynamicOracle):
    """
    A dynamic oracle for the top-down transition scheme, backed by the RepairType repairs
    """
    def __init__(self, root_labels, oracle_level, additional_oracle_levels, deactivated_oracle_levels):
        # RepairType supplies both the repair functions and their numbering
        super().__init__(root_labels, oracle_level, RepairType, additional_oracle_levels, deactivated_oracle_levels)
stanza/stanza/models/constituency/trainer.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file includes a variety of methods needed to train new
3
+ constituency parsers. It also includes a method to load an
4
+ already-trained parser.
5
+
6
+ See the `train` method for the code block which starts from
7
+ raw treebank and returns a new parser.
8
+ `evaluate` reads a treebank and gives a score for those trees.
9
+ """
10
+
11
+ import copy
12
+ import logging
13
+ import os
14
+
15
+ import torch
16
+
17
+ from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain, NoTransformerFoundationCache
18
+ from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper, pop_peft_args
19
+ from stanza.models.constituency.base_trainer import BaseTrainer, ModelType
20
+ from stanza.models.constituency.lstm_model import LSTMModel, SentenceBoundary, StackHistory, ConstituencyComposition
21
+ from stanza.models.constituency.parse_transitions import Transition, TransitionScheme
22
+ from stanza.models.constituency.utils import build_optimizer, build_scheduler
23
+ # TODO: could put find_wordvec_pretrain, choose_charlm, etc in a more central place if it becomes widely used
24
+ from stanza.utils.training.common import find_wordvec_pretrain, choose_charlm, find_charlm_file
25
+ from stanza.resources.default_packages import default_charlms, default_pretrains
26
+
27
+ logger = logging.getLogger('stanza')
28
+ tlogger = logging.getLogger('stanza.constituency.trainer')
29
+
30
+ class Trainer(BaseTrainer):
31
+ """
32
+ Stores a constituency model and its optimizer
33
+
34
+ Not inheriting from common/trainer.py because there's no concept of change_lr (yet?)
35
+ """
36
    def __init__(self, model, optimizer=None, scheduler=None, epochs_trained=0, batches_trained=0, best_f1=0.0, best_epoch=0, first_optimizer=False):
        # all bookkeeping (epochs/batches trained, best score) is
        # handled by BaseTrainer; this subclass only specializes behavior
        super().__init__(model, optimizer, scheduler, epochs_trained, batches_trained, best_f1, best_epoch, first_optimizer)
38
+
39
    def save(self, filename, save_optimizer=True):
        """
        Save the model (and by default the optimizer) to the given path
        """
        # serialization itself lives in BaseTrainer
        super().save(filename, save_optimizer)
44
+
45
+ def get_peft_params(self):
46
+ # Hide import so that peft dependency is optional
47
+ if self.model.args.get('use_peft', False):
48
+ from peft import get_peft_model_state_dict
49
+ return get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)
50
+ return None
51
+
52
    @property
    def model_type(self):
        # identifies this trainer's architecture; presumably consumed by
        # BaseTrainer when saving/loading - confirm against base_trainer.py
        return ModelType.LSTM
55
+
56
    @staticmethod
    def find_and_load_pretrain(saved_args, foundation_cache):
        """
        Load the word vector pretrain recorded in saved_args, falling back to the language default

        Returns None if the saved model recorded no pretrain file at all.
        """
        if 'wordvec_pretrain_file' not in saved_args:
            return None
        if os.path.exists(saved_args['wordvec_pretrain_file']):
            return load_pretrain(saved_args['wordvec_pretrain_file'], foundation_cache)
        # the recorded path may not exist on this machine (eg, a model
        # trained elsewhere) - look up the default pretrain instead
        logger.info("Unable to find pretrain in %s Will try to load from the default resources instead", saved_args['wordvec_pretrain_file'])
        language = saved_args['lang']
        wordvec_pretrain = find_wordvec_pretrain(language, default_pretrains)
        return load_pretrain(wordvec_pretrain, foundation_cache)
66
+
67
    @staticmethod
    def find_and_load_charlm(charlm_file, direction, saved_args, foundation_cache):
        """
        Load the character LM from charlm_file, falling back to the default charlm for the language/dataset

        direction is which charlm ("forward" or "backward" style) to look
        up when falling back to the default resources.
        """
        try:
            return load_charlm(charlm_file, foundation_cache)
        except FileNotFoundError as e:
            # the recorded path may not exist on this machine - look up
            # the default charlm for this language and dataset instead
            logger.info("Unable to load charlm from %s Will try to load from the default resources instead", charlm_file)
            language = saved_args['lang']
            # shorthand is of the form lang_dataset
            dataset = saved_args['shorthand'].split("_")[1]
            charlm = choose_charlm(language, dataset, "default", default_charlms, {})
            charlm_file = find_charlm_file(direction, language, charlm)
            return load_charlm(charlm_file, foundation_cache)
78
+
79
    def log_num_words_known(self, words):
        # log how much of the training vocabulary the embedding covers
        tlogger.info("Number of words in the training set found in the embedding: %d out of %d", self.model.num_words_known(words), len(words))
81
+
82
    @staticmethod
    def load_optimizer(model, checkpoint, first_optimizer, filename):
        """
        Build an optimizer for the model and restore its state from the checkpoint, if one was saved

        filename is only used for error reporting.
        """
        optimizer = build_optimizer(model.args, model, first_optimizer)
        if checkpoint.get('optimizer_state_dict', None) is not None:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except ValueError as e:
                # typically means the saved state did not match the
                # freshly built optimizer's parameters
                raise ValueError("Failed to load optimizer from %s" % filename) from e
        else:
            logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
        return optimizer
93
+
94
    @staticmethod
    def load_scheduler(model, optimizer, checkpoint, first_optimizer):
        """
        Build an LR scheduler for the optimizer and restore its state from the checkpoint, if one was saved
        """
        scheduler = build_scheduler(model.args, optimizer, first_optimizer=first_optimizer)
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return scheduler
100
+
101
    @staticmethod
    def model_from_params(params, peft_params, args, foundation_cache=None, peft_name=None):
        """
        Build a new model just from the saved params and some extra args

        Refactoring allows other processors to include a constituency parser as a module

        params: the unpickled checkpoint contents - 'config', 'model' (state dict),
          'model_type', transition/constituent/tag/word vocabularies, etc
        peft_params: saved PEFT adapter weights; only used when the saved config
          has use_peft set
        args: runtime overrides applied on top of the saved config; may be None
        foundation_cache: optional cache so pretrain / charlm / transformer modules
          can be shared between loads
        peft_name: reuse an already-registered PEFT adapter name instead of
          loading a fresh adapter from peft_params
        """
        saved_args = dict(params['config'])
        # older checkpoints serialized these enum-valued options as strings;
        # convert them back to the enum members the model code expects
        if isinstance(saved_args['sentence_boundary_vectors'], str):
            saved_args['sentence_boundary_vectors'] = SentenceBoundary[saved_args['sentence_boundary_vectors']]
        if isinstance(saved_args['constituency_composition'], str):
            saved_args['constituency_composition'] = ConstituencyComposition[saved_args['constituency_composition']]
        if isinstance(saved_args['transition_stack'], str):
            saved_args['transition_stack'] = StackHistory[saved_args['transition_stack']]
        if isinstance(saved_args['constituent_stack'], str):
            saved_args['constituent_stack'] = StackHistory[saved_args['constituent_stack']]
        if isinstance(saved_args['transition_scheme'], str):
            saved_args['transition_scheme'] = TransitionScheme[saved_args['transition_scheme']]

        # some parameters which change the structure of a model have
        # to be ignored, or the model will not function when it is
        # reloaded from disk
        if args is None: args = {}
        update_args = copy.deepcopy(args)
        pop_peft_args(update_args)
        update_args.pop("bert_hidden_layers", None)
        update_args.pop("bert_model", None)
        update_args.pop("constituency_composition", None)
        update_args.pop("constituent_stack", None)
        update_args.pop("num_tree_lstm_layers", None)
        update_args.pop("transition_scheme", None)
        update_args.pop("transition_stack", None)
        update_args.pop("maxout_k", None)
        # if the pretrain or charlms are not specified, don't override the values in the model
        # (if any), since the model won't even work without loading the same charlm
        if 'wordvec_pretrain_file' in update_args and update_args['wordvec_pretrain_file'] is None:
            update_args.pop('wordvec_pretrain_file')
        if 'charlm_forward_file' in update_args and update_args['charlm_forward_file'] is None:
            update_args.pop('charlm_forward_file')
        if 'charlm_backward_file' in update_args and update_args['charlm_backward_file'] is None:
            update_args.pop('charlm_backward_file')
        # we don't pop bert_finetune, with the theory being that if
        # the saved model has bert_finetune==True we can load the bert
        # weights but then not further finetune if bert_finetune==False
        saved_args.update(update_args)

        # TODO: not needed if we rebuild the models
        # default the finetune flags for checkpoints saved before they existed
        if saved_args.get("bert_finetune", None) is None:
            saved_args["bert_finetune"] = False
        if saved_args.get("stage1_bert_finetune", None) is None:
            saved_args["stage1_bert_finetune"] = False

        model_type = params['model_type']
        if model_type == 'LSTM':
            pt = Trainer.find_and_load_pretrain(saved_args, foundation_cache)
            if saved_args.get('use_peft', False):
                # if loading a peft model, we first load the base transformer
                # then we load the weights using the saved weights in the file
                if peft_name is None:
                    bert_model, bert_tokenizer, peft_name = load_bert_with_peft(saved_args.get('bert_model', None), "constituency", foundation_cache)
                else:
                    bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)
                    bert_model = load_peft_wrapper(bert_model, peft_params, saved_args, logger, peft_name)
                bert_saved = True
            elif saved_args['bert_finetune'] or saved_args['stage1_bert_finetune'] or any(x.startswith("bert_model.") for x in params['model'].keys()):
                # if bert_finetune is True, don't use the cached model!
                # otherwise, other uses of the cached model will be ruined
                bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None))
                bert_saved = True
            else:
                bert_model, bert_tokenizer = load_bert(saved_args.get('bert_model', None), foundation_cache)
                bert_saved = False
            forward_charlm = Trainer.find_and_load_charlm(saved_args["charlm_forward_file"], "forward", saved_args, foundation_cache)
            backward_charlm = Trainer.find_and_load_charlm(saved_args["charlm_backward_file"], "backward", saved_args, foundation_cache)

            # TODO: the isinstance will be unnecessary after 1.10.0
            # older checkpoints stored transitions as their repr strings
            transitions = params['transitions']
            if all(isinstance(x, str) for x in transitions):
                transitions = [Transition.from_repr(x) for x in transitions]

            model = LSTMModel(pretrain=pt,
                              forward_charlm=forward_charlm,
                              backward_charlm=backward_charlm,
                              bert_model=bert_model,
                              bert_tokenizer=bert_tokenizer,
                              force_bert_saved=bert_saved,
                              peft_name=peft_name,
                              transitions=transitions,
                              constituents=params['constituents'],
                              tags=params['tags'],
                              words=params['words'],
                              rare_words=set(params['rare_words']),
                              root_labels=params['root_labels'],
                              constituent_opens=params['constituent_opens'],
                              unary_limit=params['unary_limit'],
                              args=saved_args)
        else:
            raise ValueError("Unknown model type {}".format(model_type))
        # strict=False: the state dict may omit frozen foundation modules
        model.load_state_dict(params['model'], strict=False)
        # model will stay on CPU if device==None
        # can be moved elsewhere later, of course
        model = model.to(args.get('device', None))
        return model
204
+
205
    @staticmethod
    def build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file):
        """
        Create a Trainer for the given training configuration.

        Four paths, checked in order:
          checkpoint: resume from a previous run's checkpoint
          finetune: load an existing model and continue training it
          relearn_structure: load an existing model, build a model with a new
            structure, and copy over whatever weights still fit
          otherwise: build a brand new model, possibly as the first stage of a
            multistage run
        """
        # TODO: turn finetune, relearn_structure, multistage into an enum?
        # finetune just means continue learning, so checkpoint is sufficient
        # relearn_structure is essentially a one stage multistage
        # multistage with a checkpoint will have the proper optimizer for that epoch
        # and no special learning mode means we are training a new model and should continue
        if args['checkpoint'] and args['checkpoint_save_name'] and os.path.exists(args['checkpoint_save_name']):
            tlogger.info("Found checkpoint to continue training: %s", args['checkpoint_save_name'])
            trainer = Trainer.load(args['checkpoint_save_name'], args, load_optimizer=True, foundation_cache=foundation_cache)
            return trainer

        # in the 'finetune' case, this will preload the models into foundation_cache,
        # so the effort is not wasted
        pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file'])
        forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])
        backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])

        if args['finetune']:
            tlogger.info("Loading model to finetune: %s", model_load_file)
            trainer = Trainer.load(model_load_file, args, load_optimizer=True, foundation_cache=NoTransformerFoundationCache(foundation_cache))
            # a new finetuning will start with a new epochs_trained count
            trainer.epochs_trained = 0
            return trainer

        if args['relearn_structure']:
            tlogger.info("Loading model to continue training with new structure from %s", model_load_file)
            temp_args = dict(args)
            # remove the pattn & lattn layers unless the saved model had them
            temp_args.pop('pattn_num_layers', None)
            temp_args.pop('lattn_d_proj', None)
            trainer = Trainer.load(model_load_file, temp_args, load_optimizer=False, foundation_cache=NoTransformerFoundationCache(foundation_cache))

            # using the model's current values works for if the new
            # dataset is the same or smaller
            # TODO: handle a larger dataset as well
            model = LSTMModel(pt,
                              forward_charlm,
                              backward_charlm,
                              trainer.model.bert_model,
                              trainer.model.bert_tokenizer,
                              trainer.model.force_bert_saved,
                              trainer.model.peft_name,
                              trainer.model.transitions,
                              trainer.model.constituents,
                              trainer.model.tags,
                              trainer.model.delta_words,
                              trainer.model.rare_words,
                              trainer.model.root_labels,
                              trainer.model.constituent_opens,
                              trainer.model.unary_limit(),
                              args)
            model = model.to(args['device'])
            # copy over whichever weights still fit the new structure
            model.copy_with_new_structure(trainer.model)
            optimizer = build_optimizer(args, model, False)
            scheduler = build_scheduler(args, optimizer)
            trainer = Trainer(model, optimizer, scheduler)
            return trainer

        if args['multistage']:
            # run adadelta over the model for half the time with no pattn or lattn
            # training then switches to a different optimizer for the rest
            # this works surprisingly well
            tlogger.info("Warming up model for %d iterations using AdaDelta to train the embeddings", args['epochs'] // 2)
            temp_args = dict(args)
            # remove the attention layers for the temporary model
            temp_args['pattn_num_layers'] = 0
            temp_args['lattn_d_proj'] = 0
            args = temp_args

        peft_name = None
        if args['use_peft']:
            peft_name = "constituency"
            bert_model, bert_tokenizer = load_bert(args['bert_model'])
            # NOTE(review): temp_args is only bound inside the multistage (or
            # relearn_structure, which returns early) branch above; use_peft
            # without multistage would raise NameError here - presumably this
            # should be args. TODO confirm
            bert_model = build_peft_wrapper(bert_model, temp_args, tlogger, adapter_name=peft_name)
        elif args['bert_finetune'] or args['stage1_bert_finetune']:
            # finetuning modifies the transformer weights, so do not share a
            # cached copy with other users of the foundation cache
            bert_model, bert_tokenizer = load_bert(args['bert_model'])
        else:
            bert_model, bert_tokenizer = load_bert(args['bert_model'], foundation_cache)
        model = LSTMModel(pt,
                          forward_charlm,
                          backward_charlm,
                          bert_model,
                          bert_tokenizer,
                          False,
                          peft_name,
                          train_transitions,
                          train_constituents,
                          tags,
                          words,
                          rare_words,
                          root_labels,
                          open_nodes,
                          unary_limit,
                          args)
        model = model.to(args['device'])

        # multistage starts with a plain AdaDelta optimizer for the first half
        optimizer = build_optimizer(args, model, build_simple_adadelta=args['multistage'])
        scheduler = build_scheduler(args, optimizer, first_optimizer=args['multistage'])

        trainer = Trainer(model, optimizer, scheduler, first_optimizer=args['multistage'])
        return trainer
stanza/stanza/models/constituency/transformer_tree_stack.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Based on
3
+
4
+ Transition-based Parsing with Stack-Transformers
5
+ Ramon Fernandez Astudillo, Miguel Ballesteros, Tahira Naseem,
6
+ Austin Blodget, and Radu Florian
7
+ https://aclanthology.org/2020.findings-emnlp.89.pdf
8
+ """
9
+
10
+ from collections import namedtuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ from stanza.models.constituency.positional_encoding import SinusoidalEncoding
16
+ from stanza.models.constituency.tree_stack import TreeStack
17
+
18
# One stack entry: the caller's value plus the cached attention key/value rows
# for this stack prefix and the attention output computed at this position
Node = namedtuple("Node", ['value', 'key_stack', 'value_stack', 'output'])
19
+
20
class TransformerTreeStack(nn.Module):
    """
    A stack whose state summary is computed with a single self-attention layer.

    Each TreeStack node caches the attention keys/values for its prefix, so a
    push only has to compute one new key/value row and rerun attention.
    """
    def __init__(self, input_size, output_size, input_dropout, length_limit=None, use_position=False, num_heads=1):
        """
        Builds the internal matrices and start parameter

        input_size: dimension of the vectors pushed onto the stack
        output_size: dimension of the attention output; the reshapes in
          attention() assume it is divisible by num_heads
        input_dropout: either a dropout probability or an nn.Module to apply
        length_limit: if set, a push beyond this depth drops the oldest
          non-start entry from the cached key/value rows
        use_position: add a sinusoidal position encoding to each pushed input

        TODO: currently only one attention head, implement MHA
        """
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        # standard 1/sqrt(d) scaling applied to the attention logits
        self.inv_sqrt_output_size = 1 / output_size ** 0.5
        self.num_heads = num_heads

        self.w_query = nn.Linear(input_size, output_size)
        self.w_key = nn.Linear(input_size, output_size)
        self.w_value = nn.Linear(input_size, output_size)

        # learned embedding standing in for the empty stack
        self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
        if isinstance(input_dropout, nn.Module):
            self.input_dropout = input_dropout
        else:
            self.input_dropout = nn.Dropout(input_dropout)

        if length_limit is not None and length_limit < 1:
            raise ValueError("length_limit < 1 makes no sense")
        self.length_limit = length_limit

        self.use_position = use_position
        if use_position:
            self.position_encoding = SinusoidalEncoding(model_dim=self.input_size, max_len=512)

    def attention(self, key, query, value, mask=None):
        """
        Calculate attention for the given key, query value

        Where B is the number of items stacked together, N is the length:
        The key should be BxNxD
        The query is BxD
        The value is BxNxD

        If mask is specified, it should be BxN of True/False values,
        where True means that location is masked out

        Reshapes and reorders are used to handle num_heads

        Return will be softmax(query x key^T) * value
        of size BxD
        """
        B = key.shape[0]
        N = key.shape[1]
        D = key.shape[2]

        H = self.num_heads

        # query is now BxDx1
        query = query.unsqueeze(2)
        # BxHxD/Hx1
        query = query.reshape((B, H, -1, 1))

        # BxNxHxD/H
        key = key.reshape((B, N, H, -1))
        # BxHxNxD/H
        key = key.transpose(1, 2)

        # BxNxHxD/H
        value = value.reshape((B, N, H, -1))
        # BxHxNxD/H
        value = value.transpose(1, 2)

        # BxHxNxD/H x BxHxD/Hx1
        # result shape: BxHxN
        attn = torch.matmul(key, query).squeeze(3) * self.inv_sqrt_output_size
        if mask is not None:
            # mask goes from BxN -> Bx1xN, then broadcast over heads
            mask = mask.unsqueeze(1)
            mask = mask.expand(-1, H, -1)
            # -inf logits softmax to zero weight at masked positions
            attn.masked_fill_(mask, float('-inf'))
        # attn shape will now be BxHx1xN
        attn = torch.softmax(attn, dim=2).unsqueeze(2)
        # BxHx1xN x BxHxNxD/H -> BxHxD/H
        output = torch.matmul(attn, value).squeeze(2)
        # merge the heads back into a single D-dim vector
        output = output.reshape(B, -1)
        return output

    def initial_state(self, initial_value=None):
        """
        Return an initial state based on a single layer of attention

        Running attention might be overkill, but it is the simplest
        way to put the Linears and start_embedding in the computation graph
        """
        start = self.start_embedding
        if self.use_position:
            position = self.position_encoding([0]).squeeze(0)
            start = start + position

        # N=1
        # shape: 1xD
        key = self.w_key(start).unsqueeze(0)

        # shape: D
        query = self.w_query(start)

        # shape: 1xD
        value = self.w_value(start).unsqueeze(0)

        # unsqueeze to make it look like we are part of a batch of size 1
        output = self.attention(key.unsqueeze(0), query.unsqueeze(0), value.unsqueeze(0)).squeeze(0)
        return TreeStack(value=Node(initial_value, key, value, output), parent=None, length=1)

    def push_states(self, stacks, values, inputs):
        """
        Push new inputs to the stacks and rerun attention on them

        Where B is the number of items stacked together, I is input_size
        stacks: B TreeStacks such as produced by initial_state and/or push_states
        values: the new items to push on the stacks such as tree nodes or anything
        inputs: BxI for the new input items

        Runs attention starting from the existing keys & values

        Returns a list of B new TreeStacks with the pushed Nodes.
        """
        device = self.w_key.weight.device

        batch_len = len(stacks) # B
        # current depth of each stack's cached key rows (may be capped by length_limit)
        positions = [x.value.key_stack.shape[0] for x in stacks]
        max_len = max(positions) # N

        if self.use_position:
            position_encodings = self.position_encoding(positions)
            inputs = inputs + position_encodings

        inputs = self.input_dropout(inputs)
        if len(inputs.shape) == 3:
            if inputs.shape[0] == 1:
                inputs = inputs.squeeze(0)
            else:
                raise ValueError("Expected the inputs to be of shape 1xBxI, got {}".format(inputs.shape))

        # shorter stacks are right-aligned: padding at the front, new row last
        new_keys = self.w_key(inputs)
        key_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)
        key_stack[:, -1, :] = new_keys
        for stack_idx, stack in enumerate(stacks):
            key_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.key_stack

        new_values = self.w_value(inputs)
        value_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)
        value_stack[:, -1, :] = new_values
        for stack_idx, stack in enumerate(stacks):
            value_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.value_stack

        query = self.w_query(inputs)

        # mask out the front padding of the shorter stacks
        mask = torch.zeros(batch_len, max_len+1, device=device, dtype=torch.bool)
        for stack_idx, stack in enumerate(stacks):
            # NOTE(review): this compares the TreeStack length, but `positions`
            # comes from key_stack, which can be shorter once length_limit has
            # trimmed entries - confirm len(stack) and positions stay in sync
            if len(stack) < max_len:
                masked = max_len - positions[stack_idx]
                mask[stack_idx, :masked] = True

        batched_output = self.attention(key_stack, query, value_stack, mask)

        new_stacks = []
        for stack_idx, (stack, node_value, new_key, new_value, output) in enumerate(zip(stacks, values, key_stack, value_stack, batched_output)):
            # max_len-len(stack) so that we ignore the padding at the start of shorter stacks
            new_key_stack = new_key[max_len-positions[stack_idx]:, :]
            new_value_stack = new_value[max_len-positions[stack_idx]:, :]
            if self.length_limit is not None and new_key_stack.shape[0] > self.length_limit + 1:
                # keep the start row (index 0) and drop the oldest real entry
                new_key_stack = torch.cat([new_key_stack[:1, :], new_key_stack[2:, :]], axis=0)
                new_value_stack = torch.cat([new_value_stack[:1, :], new_value_stack[2:, :]], axis=0)
            new_stacks.append(stack.push(value=Node(node_value, new_key_stack, new_value_stack, output)))
        return new_stacks

    def output(self, stack):
        """
        Return the last layer of the lstm_hx as the output from a stack

        Refactored so that alternate structures have an easy way of getting the output
        """
        return stack.value.output
stanza/stanza/models/constituency/transition_sequence.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Build a transition sequence from parse trees.
3
+
4
+ Supports multiple transition schemes - TOP_DOWN and variants, IN_ORDER
5
+ """
6
+
7
+ import logging
8
+
9
+ from stanza.models.common import utils
10
+ from stanza.models.constituency.parse_transitions import Shift, CompoundUnary, OpenConstituent, CloseConstituent, TransitionScheme, Finalize
11
+ from stanza.models.constituency.tree_reader import read_trees
12
+ from stanza.utils.get_tqdm import get_tqdm
13
+
14
+ tqdm = get_tqdm()
15
+
16
+ logger = logging.getLogger('stanza.constituency.trainer')
17
+
18
def yield_top_down_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
    """
    For tree (X A B C D), yield Open(X) A B C D Close

    The details are in how to treat unary transitions
    Three possibilities handled by this method:
      TOP_DOWN_UNARY: (Y (X ...)) -> Open(X) ... Close Unary(Y)
      TOP_DOWN_COMPOUND: (Y (X ...)) -> Open(Y, X) ... Close
      TOP_DOWN: (Y (X ...)) -> Open(Y) Open(X) ... Close Close
    """
    if tree.is_preterminal():
        yield Shift()
        return

    if tree.is_leaf():
        return

    if transition_scheme is TransitionScheme.TOP_DOWN_UNARY and len(tree.children) == 1:
        # collapse the whole unary chain, then tack a CompoundUnary on at the end
        unary_labels = []
        while not tree.is_preterminal() and len(tree.children) == 1:
            unary_labels.append(tree.label)
            tree = tree.children[0]
        yield from yield_top_down_sequence(tree, transition_scheme)
        yield CompoundUnary(*unary_labels)
        return

    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
        # fold unary chains directly into a single compound Open
        open_labels = [tree.label]
        while len(tree.children) == 1 and not tree.children[0].is_preterminal():
            tree = tree.children[0]
            open_labels.append(tree.label)
        yield OpenConstituent(*open_labels)
    else:
        yield OpenConstituent(tree.label)
    for child in tree.children:
        yield from yield_top_down_sequence(child, transition_scheme)
    yield CloseConstituent()
+
59
def yield_in_order_sequence(tree):
    """
    For tree (X A B C D), yield A Open(X) B C D Close

    The constituent is opened only after its first child has been produced.
    """
    if tree.is_preterminal():
        yield Shift()
        return

    if tree.is_leaf():
        return

    children = tree.children
    # first child comes before the Open
    yield from yield_in_order_sequence(children[0])
    yield OpenConstituent(tree.label)
    # remaining children come after it
    for child in children[1:]:
        yield from yield_in_order_sequence(child)
    yield CloseConstituent()
80
+
81
+
82
+
83
def yield_in_order_compound_sequence(tree, transition_scheme):
    """
    Yield an in-order transition sequence where unary chains are collapsed.

    Two schemes are handled:
      IN_ORDER_COMPOUND: a unary chain (Y (X ...)) is folded into a single
        compound OpenConstituent, eg Open(Y, X)
      IN_ORDER_UNARY: the chain's constituent is opened normally and a
        CompoundUnary with the chain labels is emitted after the subtree
        closes (or directly after the Shift for preterminals)

    The sequence ends with a Finalize carrying the root label.

    Raises ValueError if the tree has no children or more than one top
    level node, since Finalize expects a single node under the root.
    """
    def helper(tree):
        # recursively emit transitions for one subtree
        if tree.is_leaf():
            return

        # collect the labels of the unary chain above the current node
        labels = []
        while len(tree.children) == 1 and not tree.is_preterminal():
            labels.append(tree.label)
            tree = tree.children[0]

        if tree.is_preterminal():
            yield Shift()
            if len(labels) > 0:
                yield CompoundUnary(*labels)
            return

        # in-order: first child before the Open
        for transition in helper(tree.children[0]):
            yield transition

        if transition_scheme is TransitionScheme.IN_ORDER_UNARY:
            yield OpenConstituent(tree.label)
        else:
            # IN_ORDER_COMPOUND: fold the unary chain into the Open itself
            labels.append(tree.label)
            yield OpenConstituent(*labels)

        for child in tree.children[1:]:
            for transition in helper(child):
                yield transition

        yield CloseConstituent()

        # for IN_ORDER_UNARY the chain is reapplied after the Close
        if transition_scheme is TransitionScheme.IN_ORDER_UNARY and len(labels) > 0:
            yield CompoundUnary(*labels)

    if len(tree.children) == 0:
        raise ValueError("Cannot build {} on an empty tree".format(transition_scheme))
    if len(tree.children) != 1:
        raise ValueError("Cannot build {} with a tree that has two top level nodes: {}".format(transition_scheme, tree))

    for t in helper(tree.children[0]):
        yield t

    yield Finalize(tree.label)
126
+
127
def build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN_UNARY):
    """
    Turn a single tree into a list of transitions based on the TransitionScheme
    """
    if transition_scheme is TransitionScheme.IN_ORDER:
        generator = yield_in_order_sequence(tree)
    elif transition_scheme in (TransitionScheme.IN_ORDER_COMPOUND, TransitionScheme.IN_ORDER_UNARY):
        generator = yield_in_order_compound_sequence(tree, transition_scheme)
    else:
        generator = yield_top_down_sequence(tree, transition_scheme)
    return list(generator)
138
+
139
def build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN_UNARY, reverse=False):
    """
    Turn each of the trees in the treebank into a list of transitions based on the TransitionScheme

    If reverse is set, each tree is reversed before building its sequence.
    """
    if reverse:
        trees = (tree.reverse() for tree in trees)
    return [build_sequence(tree, transition_scheme) for tree in trees]
147
+
148
def all_transitions(transition_lists):
    """
    Given a list of transition lists, combine them all into a list of unique transitions.

    The result is sorted for determinism.
    """
    unique = set()
    for sequence in transition_lists:
        unique.update(sequence)
    return sorted(unique)
156
+
157
def convert_trees_to_sequences(trees, treebank_name, transition_scheme, reverse=False):
    """
    Wrap both build_treebank and all_transitions, possibly with a tqdm

    Converts trees to a list of sequences, then returns the list of known transitions
    """
    if len(trees) == 0:
        return [], []
    logger.info("Building %s transition sequences", treebank_name)
    # only show a progress bar when INFO logging is visible
    if logger.getEffectiveLevel() <= logging.INFO:
        trees = tqdm(trees)
    sequences = build_treebank(trees, transition_scheme, reverse)
    return sequences, all_transitions(sequences)
171
+
172
def main():
    """
    Convert a sample tree and print its transitions
    """
    text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"

    sample = read_trees(text)[0]
    print(sample)
    print(build_sequence(sample))
184
+
185
+ if __name__ == '__main__':
186
+ main()
stanza/stanza/models/constituency/tree_embedding.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A module to use a Constituency Parser to make an embedding for a tree
3
+
4
+ The embedding can be produced just from the words and the top of the
5
+ tree, or it can be done with a form of attention over the nodes
6
+
7
+ Can be done over an existing parse tree or unparsed text
8
+ """
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from stanza.models.constituency.trainer import Trainer
15
+
16
class TreeEmbedding(nn.Module):
    """
    Uses a constituency parser to build an embedding for a parse tree.

    The embedding concatenates the word LSTM boundary states, the
    transition stack state, and the constituent stack state; optionally an
    attention layer over the tree's internal nodes is applied on top.
    """
    def __init__(self, constituency_parser, args):
        """
        constituency_parser: an LSTMModel (or compatible) providing
          hidden_size, transition_hidden_size, num_tree_lstm_layers,
          analyze_trees, and the transition/constituent stacks
        args: dict with the config keys read below (all_words, backprop,
          node_attn, top_layer)
        """
        super(TreeEmbedding, self).__init__()

        # keep only the config values this module uses so get_params()
        # saves a minimal, reloadable config
        self.config = {
            "all_words": args["all_words"],
            "backprop": args["backprop"],
            #"batch_norm": args["batch_norm"],
            "node_attn": args["node_attn"],
            "top_layer": args["top_layer"],
        }

        self.constituency_parser = constituency_parser

        # word_lstm: hidden_size * num_tree_lstm_layers * 2 (start & end)
        # transition_stack: transition_hidden_size
        # constituent_stack: hidden_size
        self.hidden_size = self.constituency_parser.hidden_size + self.constituency_parser.transition_hidden_size
        if self.config["all_words"]:
            self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers
        else:
            self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers * 2

        if self.config["node_attn"]:
            self.query = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size)
            self.key = nn.Linear(self.hidden_size, self.constituency_parser.hidden_size)
            self.value = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size)

            # TODO: cat transition and constituent hx as well?
            self.output_size = self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers
        else:
            self.output_size = self.hidden_size

        # TODO: maybe have batch_norm, maybe use Identity
        #if self.config["batch_norm"]:
        #    self.input_norm = nn.BatchNorm1d(self.output_size)

    def embed_trees(self, inputs):
        """
        Run the parser over the inputs and build one embedding per tree.

        Returns either a batched tensor of concatenated states (no node
        attention) or a list of per-tree attention outputs.
        """
        if self.config["backprop"]:
            states = self.constituency_parser.analyze_trees(inputs)
        else:
            # freeze the parser: no gradients flow into it
            with torch.no_grad():
                states = self.constituency_parser.analyze_trees(inputs)

        constituent_lists = [x.constituents for x in states]
        states = [x.state for x in states]

        word_begin_hx = torch.stack([state.word_queue[0].hx for state in states])
        word_end_hx = torch.stack([state.word_queue[state.word_position].hx for state in states])
        transition_hx = torch.stack([self.constituency_parser.transition_stack.output(state.transitions) for state in states])
        # go down one layer to get the embedding off the top of the S, not the ROOT
        # (in terms of the typical treebank)
        # the idea being that the ROOT has no additional information
        # and may even have 0s for the embedding in certain circumstances,
        # such as after learning UNTIED_MAX long enough
        if self.config["top_layer"]:
            constituent_hx = torch.stack([self.constituency_parser.constituent_stack.output(state.constituents) for state in states])
        else:
            constituent_hx = torch.cat([constituents[-2].tree_hx for constituents in constituent_lists], dim=0)

        if self.config["all_words"]:
            # need B matrices of N x hidden_size
            key = [torch.stack([torch.cat([word.hx, thx, chx]) for word in state.word_queue], dim=0)
                   for state, thx, chx in zip(states, transition_hx, constituent_hx)]
        else:
            key = torch.cat((word_begin_hx, word_end_hx, transition_hx, constituent_hx), dim=1).unsqueeze(1)

        if not self.config["node_attn"]:
            return key
        key = [self.key(x) for x in key]

        node_hx = [torch.stack([con.tree_hx for con in constituents], dim=0) for constituents in constituent_lists]
        queries = [self.query(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]
        values = [self.value(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]
        # TODO: could pad to make faster here
        attn = [torch.matmul(q, k.transpose(0, 1)) for q, k in zip(queries, key)]
        attn = [torch.softmax(x, dim=0) for x in attn]
        previous_layer = [torch.matmul(weight.transpose(0, 1), value) for weight, value in zip(attn, values)]
        return previous_layer

    def forward(self, inputs):
        # BUGFIX: embed_trees is a method; the previous bare call
        # `embed_trees(self, inputs)` raised NameError at runtime
        return self.embed_trees(inputs)

    def get_norms(self):
        """Return human-readable parameter norm lines for logging."""
        lines = ["constituency_parser." + x for x in self.constituency_parser.get_norms()]
        for name, param in self.named_parameters():
            if param.requires_grad and not name.startswith('constituency_parser.'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        return lines


    def get_params(self, skip_modules=True):
        """Return a picklable dict of this module's weights, the parser's params, and the config."""
        model_state = self.state_dict()
        # skip all of the constituency parameters here -
        # we will add them by calling the model's get_params()
        skipped = [k for k in model_state.keys() if k.startswith("constituency_parser.")]
        for k in skipped:
            del model_state[k]

        parser = self.constituency_parser.get_params(skip_modules)

        params = {
            'model': model_state,
            'constituency': parser,
            'config': self.config,
        }
        return params

    @staticmethod
    def from_parser_file(args, foundation_cache=None):
        """Build a TreeEmbedding around a parser loaded from args['model']."""
        # BUGFIX: pass foundation_cache by keyword - positionally it would
        # land in Trainer.load's load_optimizer parameter
        constituency_parser = Trainer.load(args['model'], args, foundation_cache=foundation_cache)
        return TreeEmbedding(constituency_parser.model, args)

    @staticmethod
    def model_from_params(params, args, foundation_cache=None):
        """Rebuild a TreeEmbedding from the dict produced by get_params()."""
        # TODO: integrate with peft
        constituency_parser = Trainer.model_from_params(params['constituency'], None, args, foundation_cache)
        model = TreeEmbedding(constituency_parser, params['config'])
        # strict=False: the state dict deliberately omits the parser weights
        model.load_state_dict(params['model'], strict=False)
        return model
stanza/stanza/models/coref/config.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Describes Config, a simple namespace for config values.
2
+
3
+ For description of all config values, refer to config.toml.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List
8
+
9
+
10
+ @dataclass
11
+ class Config: # pylint: disable=too-many-instance-attributes, too-few-public-methods
12
+ """ Contains values needed to set up the coreference model. """
13
+ section: str
14
+
15
+ # TODO: can either eliminate data_dir or use it for the train/dev/test data
16
+ data_dir: str
17
+ save_dir: str
18
+ save_name: str
19
+
20
+ train_data: str
21
+ dev_data: str
22
+ test_data: str
23
+
24
+ device: str
25
+
26
+ bert_model: str
27
+ bert_window_size: int
28
+
29
+ embedding_size: int
30
+ sp_embedding_size: int
31
+ a_scoring_batch_size: int
32
+ hidden_size: int
33
+ n_hidden_layers: int
34
+
35
+ max_span_len: int
36
+
37
+ rough_k: int
38
+
39
+ lora: bool
40
+ lora_alpha: int
41
+ lora_rank: int
42
+ lora_dropout: float
43
+
44
+ full_pairwise: bool
45
+
46
+ lora_target_modules: List[str]
47
+ lora_modules_to_save: List[str]
48
+
49
+ clusters_starts_are_singletons: bool
50
+ bert_finetune: bool
51
+ dropout_rate: float
52
+ learning_rate: float
53
+ bert_learning_rate: float
54
+ # we find that setting this to a small but non-zero number
55
+ # makes the model less likely to forget how to do anything
56
+ bert_finetune_begin_epoch: float
57
+ train_epochs: int
58
+ bce_loss_weight: float
59
+
60
+ tokenizer_kwargs: Dict[str, dict]
61
+ conll_log_dir: str
62
+
63
+ save_each_checkpoint: bool
64
+ log_norms: bool
65
+ singletons: bool
66
+
stanza/stanza/models/coref/coref_config.toml ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Before you start changing anything here, read the comments.
3
+ # All of them can be found below in the "DEFAULT" section
4
+
5
+ [DEFAULT]
6
+
7
+ # The directory that contains extracted files of everything you've downloaded.
8
+ data_dir = "data/coref"
9
+
10
+ # where to put checkpoints and final models
11
+ save_dir = "saved_models/coref"
12
+ save_name = "bert-large-cased"
13
+
14
+ # Train, dev and test jsonlines
15
+ # train_data = "data/coref/en_gum-ud.train.nosgl.json"
16
+ # dev_data = "data/coref/en_gum-ud.dev.nosgl.json"
17
+ # test_data = "data/coref/en_gum-ud.test.nosgl.json"
18
+
19
+ train_data = "data/coref/corefud_concat_v1_0_langid.train.json"
20
+ dev_data = "data/coref/corefud_concat_v1_0_langid.dev.json"
21
+ test_data = "data/coref/corefud_concat_v1_0_langid.dev.json"
22
+
23
+ #train_data = "data/coref/english_train_head.jsonlines"
24
+ #dev_data = "data/coref/english_development_head.jsonlines"
25
+ #test_data = "data/coref/english_test_head.jsonlines"
26
+
27
+ # do not use the full pairwise encoding scheme
28
+ full_pairwise = false
29
+
30
+ # The device where everything is to be placed. "cuda:N"/"cpu" are supported.
31
+ device = "cuda:0"
32
+
33
+ save_each_checkpoint = false
34
+ log_norms = false
35
+
36
+ # Bert settings ======================
37
+
38
+ # Base bert model architecture and tokenizer
39
+ bert_model = "bert-large-cased"
40
+
41
+ # Controls max length of sequences passed through bert to obtain its
42
+ # contextual embeddings
43
+ # Must be less than or equal to 512
44
+ bert_window_size = 512
45
+
46
+ # General model settings =============
47
+
48
+ # Controls the dimensionality of feature embeddings
49
+ embedding_size = 20
50
+
51
+ # Controls the dimensionality of distance embeddings used by SpanPredictor
52
+ sp_embedding_size = 64
53
+
54
+ # Controls the number of spans for which anaphoricity can be scored in one
55
+ # batch. Only affects final scoring; mention extraction and rough scoring
56
+ # are less memory intensive, so they are always done in just one batch.
57
+ a_scoring_batch_size = 128
58
+
59
+ # AnaphoricityScorer FFNN parameters
60
+ hidden_size = 1024
61
+ n_hidden_layers = 1
62
+
63
+ # Do you want to support singletons?
64
+ singletons = true
65
+
66
+
67
+ # Mention extraction settings ========
68
+
69
+ # Mention extractor will check spans up to max_span_len words
70
+ # The default value is chosen to be big enough to hold any dev data span
71
+ max_span_len = 64
72
+
73
+
74
+ # Pruning settings ===================
75
+
76
+ # Controls how many pairs should be preserved per mention
77
+ # after applying rough scoring.
78
+ rough_k = 50
79
+
80
+
81
+ # Lora settings ===================
82
+
83
+ # LoRA settings
84
+ lora = false
85
+ lora_alpha = 128
86
+ lora_dropout = 0.1
87
+ lora_rank = 64
88
+ lora_target_modules = []
89
+ lora_modules_to_save = []
90
+
91
+
92
+ # Training settings ==================
93
+
94
+ # Controls whether the first dummy node predicts cluster starts or singletons
95
+ clusters_starts_are_singletons = true
96
+
97
+ # Controls whether to fine-tune bert_model
98
+ bert_finetune = true
99
+
100
+ # Controls the dropout rate throughout all models
101
+ dropout_rate = 0.3
102
+
103
+ # Bert learning rate (only used if bert_finetune is set)
104
+ bert_learning_rate = 1e-6
105
+ bert_finetune_begin_epoch = 0.5
106
+
107
+ # Task learning rate
108
+ learning_rate = 3e-4
109
+
110
+ # For how many epochs the training is done
111
+ train_epochs = 32
112
+
113
+ # Controls the weight of binary cross entropy loss added to nlml loss
114
+ bce_loss_weight = 0.5
115
+
116
+ # The directory that will contain conll prediction files
117
+ conll_log_dir = "data/conll_logs"
118
+
119
+ # =============================================================================
120
+ # Extra keyword arguments to be passed to bert tokenizers of specified models
121
+ [DEFAULT.tokenizer_kwargs]
122
+ [DEFAULT.tokenizer_kwargs.roberta-large]
123
+ "add_prefix_space" = true
124
+
125
+ [DEFAULT.tokenizer_kwargs.xlm-roberta-large]
126
+ "add_prefix_space" = true
127
+
128
+ [DEFAULT.tokenizer_kwargs.spanbert-large-cased]
129
+ "do_lower_case" = false
130
+
131
+ [DEFAULT.tokenizer_kwargs.bert-large-cased]
132
+ "do_lower_case" = false
133
+
134
+ # =============================================================================
135
+ # The sections listed here do not need to make use of all config variables
136
+ # If a variable is omitted, its default value will be used instead
137
+
138
+ [roberta]
139
+ bert_model = "roberta-large"
140
+
141
+ [roberta_lora]
142
+ bert_model = "roberta-large"
143
+ bert_learning_rate = 0.00005
144
+ lora = true
145
+ lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
146
+ lora_modules_to_save = [ "pooler" ]
147
+
148
+ [scandibert_lora]
149
+ bert_model = "vesteinn/ScandiBERT"
150
+ bert_learning_rate = 0.0002
151
+ lora = true
152
+ lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
153
+ lora_modules_to_save = [ "pooler" ]
154
+
155
+ [xlm_roberta]
156
+ bert_model = "FacebookAI/xlm-roberta-large"
157
+ bert_learning_rate = 0.00001
158
+ bert_finetune = true
159
+
160
+ [xlm_roberta_lora]
161
+ bert_model = "FacebookAI/xlm-roberta-large"
162
+ bert_learning_rate = 0.000025
163
+ lora = true
164
+ lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
165
+ lora_modules_to_save = [ "pooler" ]
166
+
167
+ [deeppavlov_slavic_bert_lora]
168
+ bert_model = "DeepPavlov/bert-base-bg-cs-pl-ru-cased"
169
+ bert_learning_rate = 0.000025
170
+ lora = true
171
+ lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
172
+ lora_modules_to_save = [ "pooler" ]
173
+
174
+ [deberta_lora]
175
+ bert_model = "microsoft/deberta-v3-large"
176
+ bert_learning_rate = 0.00001
177
+ lora = true
178
+ lora_target_modules = [ "query_proj", "value_proj", "output.dense" ]
179
+ lora_modules_to_save = [ ]
180
+
181
+ [electra]
182
+ bert_model = "google/electra-large-discriminator"
183
+ bert_learning_rate = 0.00002
184
+
185
+ [electra_lora]
186
+ bert_model = "google/electra-large-discriminator"
187
+ bert_learning_rate = 0.000025
188
+ lora = true
189
+ lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
190
+ lora_modules_to_save = [ ]
191
+
192
+ [hungarian_electra_lora]
193
+ # TODO: experiment with tokenizer options for this to see if that's
194
+ # why the results are so low using this transformer
195
+ bert_model = "NYTK/electra-small-discriminator-hungarian"
196
+ bert_learning_rate = 0.000025
197
+ lora = true
198
+ lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
199
+ lora_modules_to_save = [ ]
200
+
201
+ [muril_large_cased_lora]
202
+ bert_model = "google/muril-large-cased"
203
+ bert_learning_rate = 0.000025
204
+ lora = true
205
+ lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
206
+ lora_modules_to_save = [ "pooler" ]
207
+
208
+ [indic_bert_lora]
209
+ bert_model = "ai4bharat/indic-bert"
210
+ bert_learning_rate = 0.0005
211
+ lora = true
212
+ # indic-bert is an albert with repeating layers of different names
213
+ lora_target_modules = [ "query", "value", "dense", "ffn", "full_layer" ]
214
+ lora_modules_to_save = [ "pooler" ]
215
+
216
+ [bert_multilingual_cased_lora]
217
+ # LR sweep on a Hindi dataset
218
+ # 0.00001: 0.53238
219
+ # 0.00002: 0.54012
220
+ # 0.000025: 0.54206
221
+ # 0.00003: 0.54050
222
+ # 0.00004: 0.55081
223
+ # 0.00005: 0.55135
224
+ # 0.000075: 0.54482
225
+ # 0.0001: 0.53888
226
+ bert_model = "google-bert/bert-base-multilingual-cased"
227
+ bert_learning_rate = 0.00005
228
+ lora = true
229
+ lora_target_modules = [ "query", "value", "output.dense", "intermediate.dense" ]
230
+ lora_modules_to_save = [ "pooler" ]
231
+
232
+ [t5_lora]
233
+ bert_model = "google-t5/t5-large"
234
+ bert_learning_rate = 0.000025
235
+ bert_window_size = 1024
236
+ lora = true
237
+ lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
238
+ lora_modules_to_save = [ ]
239
+
240
+ [mt5_lora]
241
+ bert_model = "google/mt5-base"
242
+ bert_learning_rate = 0.000025
243
+ lora_alpha = 64
244
+ lora_rank = 32
245
+ lora = true
246
+ lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
247
+ lora_modules_to_save = [ ]
248
+
249
+ [deepnarrow_t5_xl_lora]
250
+ bert_model = "google/t5-efficient-xl"
251
+ bert_learning_rate = 0.00025
252
+ lora = true
253
+ lora_target_modules = [ "q", "v", "o", "wi", "wo" ]
254
+ lora_modules_to_save = [ ]
255
+
256
+ [roberta_no_finetune]
257
+ bert_model = "roberta-large"
258
+ bert_finetune = false
259
+
260
+ [roberta_no_bce]
261
+ bert_model = "roberta-large"
262
+ bce_loss_weight = 0.0
263
+
264
+ [spanbert]
265
+ bert_model = "SpanBERT/spanbert-large-cased"
266
+
267
+ [spanbert_no_bce]
268
+ bert_model = "SpanBERT/spanbert-large-cased"
269
+ bce_loss_weight = 0.0
270
+
271
+ [bert]
272
+ bert_model = "bert-large-cased"
273
+
274
+ [longformer]
275
+ bert_model = "allenai/longformer-large-4096"
276
+ bert_window_size = 2048
277
+
278
+ [debug]
279
+ bert_window_size = 384
280
+ bert_finetune = false
281
+ device = "cpu:0"
282
+
283
+ [debug_gpu]
284
+ bert_window_size = 384
285
+ bert_finetune = false
stanza/stanza/models/coref/dataset.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from torch.utils.data import Dataset
4
+
5
+ from stanza.models.coref.tokenizer_customization import TOKENIZER_FILTERS, TOKENIZER_MAPS
6
+
7
+ logger = logging.getLogger('stanza')
8
+
9
class CorefDataset(Dataset):
    """Torch Dataset of coreference documents read from a json file.

    Falls back to the older jsonlines format if the file is not valid json.
    Each document dict is augmented in place with subword tokenization info:
    ``word2subword`` (per-word (start, end) subword span), ``subwords``
    (flat subword list), and ``word_id`` (subword index -> word index).
    """

    def __init__(self, path, config, tokenizer):
        self.config = config
        self.tokenizer = tokenizer

        # by default, this doesn't filter anything (see lambda _: True);
        # however, there are some subword symbols which are standalone
        # tokens which we don't want on models like Albert; hence we
        # pass along a filter if needed.
        self.__filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,
                                                   lambda _: True)
        # optional word -> pre-tokenized subwords override for this bert model
        self.__token_map = TOKENIZER_MAPS.get(self.config.bert_model, {})

        try:
            with open(path, encoding="utf-8") as fin:
                data_f = json.load(fin)
        except json.decoder.JSONDecodeError:
            # read the old jsonlines format if necessary
            with open(path, encoding="utf-8") as fin:
                text = "[" + ",\n".join(fin) + "]"
            data_f = json.loads(text)
        logger.info("Processing %d docs from %s...", len(data_f), path)
        self.__raw = data_f
        # average number of head->span mappings per document
        self.__avg_span = sum(len(doc["head2span"]) for doc in self.__raw) / len(self.__raw)
        self.__out = []
        for doc in self.__raw:
            # json gives lists; downstream code expects hashable (start, end) tuples
            doc["span_clusters"] = [[tuple(mention) for mention in cluster]
                                    for cluster in doc["span_clusters"]]
            word2subword = []
            subwords = []
            word_id = []
            for i, word in enumerate(doc["cased_words"]):
                # use the precomputed mapping if available, otherwise tokenize
                tokenized_word = self.__token_map.get(word, self.tokenizer.tokenize(word))
                tokenized_word = list(filter(self.__filter_func, tokenized_word))
                word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))
                subwords.extend(tokenized_word)
                word_id.extend([i] * len(tokenized_word))
            doc["word2subword"] = word2subword
            doc["subwords"] = subwords
            doc["word_id"] = word_id
            self.__out.append(doc)
        logger.info("Loaded %d docs from %s.", len(data_f), path)

    @property
    def avg_span(self):
        # mean head2span count per document, computed once at load time
        return self.__avg_span

    def __getitem__(self, x):
        return self.__out[x]

    def __len__(self):
        return len(self.__out)
stanza/stanza/models/coref/pairwise_encoder.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Describes PairwiseEncodes, that transforms pairwise features, such as
2
+ distance between the mentions, same/different speaker into feature embeddings
3
+ """
4
+ from typing import List
5
+
6
+ import torch
7
+
8
+ from stanza.models.coref.config import Config
9
+ from stanza.models.coref.const import Doc
10
+
11
+
12
+ class PairwiseEncoder(torch.nn.Module):
13
+ """ A Pytorch module to obtain feature embeddings for pairwise features
14
+
15
+ Usage:
16
+ encoder = PairwiseEncoder(config)
17
+ pairwise_features = encoder(pair_indices, doc)
18
+ """
19
+ def __init__(self, config: Config):
20
+ super().__init__()
21
+ emb_size = config.embedding_size
22
+
23
+ self.genre2int = {g: gi for gi, g in enumerate(["bc", "bn", "mz", "nw",
24
+ "pt", "tc", "wb"])}
25
+ self.genre_emb = torch.nn.Embedding(len(self.genre2int), emb_size)
26
+
27
+ # each position corresponds to a bucket:
28
+ # [(0, 2), (2, 3), (3, 4), (4, 5), (5, 8),
29
+ # (8, 16), (16, 32), (32, 64), (64, float("inf"))]
30
+ self.distance_emb = torch.nn.Embedding(9, emb_size)
31
+
32
+ # two possibilities: same vs different speaker
33
+ self.speaker_emb = torch.nn.Embedding(2, emb_size)
34
+
35
+ self.dropout = torch.nn.Dropout(config.dropout_rate)
36
+
37
+ self.__full_pw = config.full_pairwise
38
+
39
+ if self.__full_pw:
40
+ self.shape = emb_size * 3 # genre, distance, speaker\
41
+ else:
42
+ self.shape = emb_size # distance only
43
+
44
+ @property
45
+ def device(self) -> torch.device:
46
+ """ A workaround to get current device (which is assumed to be the
47
+ device of the first parameter of one of the submodules) """
48
+ return next(self.genre_emb.parameters()).device
49
+
50
+ def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
51
+ top_indices: torch.Tensor,
52
+ doc: Doc) -> torch.Tensor:
53
+ word_ids = torch.arange(0, len(doc["cased_words"]), device=self.device)
54
+
55
+ # bucketing the distance (see __init__())
56
+ distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
57
+ ).clamp_min_(min=1)
58
+ log_distance = distance.to(torch.float).log2().floor_()
59
+ log_distance = log_distance.clamp_max_(max=6).to(torch.long)
60
+ distance = torch.where(distance < 5, distance - 1, log_distance + 2)
61
+ distance = self.distance_emb(distance)
62
+
63
+ if not self.__full_pw:
64
+ return self.dropout(distance)
65
+
66
+ # calculate speaker embeddings
67
+ speaker_map = torch.tensor(self._speaker_map(doc), device=self.device)
68
+ same_speaker = (speaker_map[top_indices] == speaker_map.unsqueeze(1))
69
+ same_speaker = self.speaker_emb(same_speaker.to(torch.long))
70
+
71
+
72
+ # if there is no genre information, use "wb" as the genre (which is what the
73
+ # Pipeline does
74
+ genre = torch.tensor(self.genre2int.get(doc["document_id"][:2], self.genre2int["wb"]),
75
+ device=self.device).expand_as(top_indices)
76
+ genre = self.genre_emb(genre)
77
+
78
+ return self.dropout(torch.cat((same_speaker, distance, genre), dim=2))
79
+
80
+ @staticmethod
81
+ def _speaker_map(doc: Doc) -> List[int]:
82
+ """
83
+ Returns a tensor where i-th element is the speaker id of i-th word.
84
+ """
85
+ # if speaker is not found in the doc, simply return "speaker#1" for all the speakers
86
+ # and embed them using the same ID
87
+
88
+ # speaker string -> speaker id
89
+ str2int = {s: i for i, s in enumerate(set(doc.get("speaker", ["speaker#1"
90
+ for _ in range(len(doc["deprel"]))])))}
91
+
92
+ # word id -> speaker id
93
+ return [str2int[s] for s in doc.get("speaker", ["speaker#1"
94
+ for _ in range(len(doc["deprel"]))])]
stanza/stanza/models/coref/rough_scorer.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Describes RoughScorer, a simple bilinear module to calculate rough
2
+ anaphoricity scores.
3
+ """
4
+
5
+ from typing import Tuple
6
+
7
+ import torch
8
+
9
+ from stanza.models.coref.config import Config
10
+
11
+
12
+ class RoughScorer(torch.nn.Module):
13
+ """
14
+ Is needed to give a roughly estimate of the anaphoricity of two candidates,
15
+ only top scoring candidates are considered on later steps to reduce
16
+ computational complexity.
17
+ """
18
+ def __init__(self, features: int, config: Config):
19
+ super().__init__()
20
+ self.dropout = torch.nn.Dropout(config.dropout_rate)
21
+ self.bilinear = torch.nn.Linear(features, features)
22
+
23
+ self.k = config.rough_k
24
+
25
+ def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
26
+ mentions: torch.Tensor,
27
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
28
+ """
29
+ Returns rough anaphoricity scores for candidates, which consist of
30
+ the bilinear output of the current model summed with mention scores.
31
+ """
32
+ # [n_mentions, n_mentions]
33
+ pair_mask = torch.arange(mentions.shape[0])
34
+ pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
35
+ pair_mask = torch.log((pair_mask > 0).to(torch.float))
36
+ pair_mask = pair_mask.to(mentions.device)
37
+
38
+ bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
39
+
40
+ rough_scores = pair_mask + bilinear_scores
41
+
42
+ return self._prune(rough_scores)
43
+
44
+ def _prune(self,
45
+ rough_scores: torch.Tensor
46
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
47
+ """
48
+ Selects top-k rough antecedent scores for each mention.
49
+
50
+ Args:
51
+ rough_scores: tensor of shape [n_mentions, n_mentions], containing
52
+ rough antecedent scores of each mention-antecedent pair.
53
+
54
+ Returns:
55
+ FloatTensor of shape [n_mentions, k], top rough scores
56
+ LongTensor of shape [n_mentions, k], top indices
57
+ """
58
+ top_scores, indices = torch.topk(rough_scores,
59
+ k=min(self.k, len(rough_scores)),
60
+ dim=1, sorted=False)
61
+ return top_scores, indices, rough_scores
stanza/stanza/models/coref/utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Contains functions not directly linked to coreference resolution """
2
+
3
+ from typing import List, Set
4
+
5
+ import torch
6
+
7
+ from stanza.models.coref.const import EPSILON
8
+
9
+
10
class GraphNode:
    """A node in an undirected graph, identified by an integer id.

    Keeps the set of directly connected nodes in ``links`` and a
    ``visited`` flag for use by traversal code.
    """

    def __init__(self, node_id: int):
        # public attributes read by traversal code elsewhere
        self.id = node_id
        self.visited = False
        self.links: Set["GraphNode"] = set()

    def link(self, another: "GraphNode"):
        """Add a symmetric (undirected) edge between this node and *another*."""
        another.links.add(self)
        self.links.add(another)

    def __repr__(self) -> str:
        return str(self.id)
22
+
23
+
24
def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """ Prepends zeros (or a very small value if eps is True)
    to the first (not zeroth) dimension of tensor.
    """
    dummy_shape = list(tensor.shape)
    dummy_shape[1] = 1
    # the dummy column is all zeros, or all EPSILON when requested
    fill_value = EPSILON if eps else 0
    dummy = torch.full(dummy_shape, fill_value,
                       device=tensor.device, dtype=tensor.dtype)  # type: ignore
    return torch.cat((dummy, tensor), dim=1)
stanza/stanza/models/depparse/model.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
9
+
10
+ from stanza.models.common.bert_embedding import extract_bert_embeddings
11
+ from stanza.models.common.biaffine import DeepBiaffineScorer
12
+ from stanza.models.common.foundation_cache import load_charlm
13
+ from stanza.models.common.hlstm import HighwayLSTM
14
+ from stanza.models.common.dropout import WordDropout
15
+ from stanza.models.common.utils import attach_bert_model
16
+ from stanza.models.common.vocab import CompositeVocab
17
+ from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
18
+ from stanza.models.common import utils
19
+
20
+ logger = logging.getLogger('stanza')
21
+
22
+ class Parser(nn.Module):
23
    def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
        """Build the biaffine dependency parser network.

        Args:
            args: dict of hyperparameters (embedding dims, layer sizes, dropout, ...)
            vocab: dict of vocabularies ('word', 'lemma', 'upos', 'xpos', 'feats', 'deprel')
            emb_matrix: pretrained embedding matrix, used when args['pretrain'] is set
            share_hid: stored on the instance; not otherwise used in this method
            foundation_cache: cache used when loading charlm models
            bert_model, bert_tokenizer: optional transformer attached via attach_bert_model
            force_bert_saved: whether the transformer weights must be saved with the model
            peft_name: name of the PEFT adapter to use for transformer embeddings, if any
        """
        super().__init__()

        self.vocab = vocab
        self.args = args
        self.share_hid = share_hid
        # names of submodules excluded when the model is saved (see add_unsaved_module)
        self.unsaved_modules = []

        # input layers
        # input_size accumulates the width of every enabled input feature
        input_size = 0
        if self.args['word_emb_dim'] > 0:
            # frequent word embeddings
            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)
            self.lemma_emb = nn.Embedding(len(vocab['lemma']), self.args['word_emb_dim'], padding_idx=0)
            input_size += self.args['word_emb_dim'] * 2

        if self.args['tag_emb_dim'] > 0:
            if self.args.get('use_upos', True):
                self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)
            if self.args.get('use_xpos', True):
                if not isinstance(vocab['xpos'], CompositeVocab):
                    self.xpos_emb = nn.Embedding(len(vocab['xpos']), self.args['tag_emb_dim'], padding_idx=0)
                else:
                    # composite xpos: one embedding per component, summed in forward
                    self.xpos_emb = nn.ModuleList()
                    for l in vocab['xpos'].lens():
                        self.xpos_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))
            # upos/xpos embeddings are summed in forward, so they share one slot
            if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
                input_size += self.args['tag_emb_dim']

            if self.args.get('use_ufeats', True):
                # one embedding per universal feature, summed in forward
                self.ufeats_emb = nn.ModuleList()
                for l in vocab['feats'].lens():
                    self.ufeats_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))
                input_size += self.args['tag_emb_dim']

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            if self.args.get('charlm', None):
                # pretrained character language models (forward and backward)
                if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
                    raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
                if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
                    raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
                logger.debug("Depparse model loading charmodels: %s and %s", args['charlm_forward_file'], args['charlm_backward_file'])
                # charlms are pretrained and frozen, so they are not saved with this model
                self.add_unsaved_module('charmodel_forward', load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache))
                self.add_unsaved_module('charmodel_backward', load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache))
                input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
            else:
                # train a character model from scratch instead
                self.charmodel = CharacterModel(args, vocab)
                self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
                input_size += self.args['transformed_dim']

        self.peft_name = peft_name
        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
        if self.args.get('bert_model', None):
            # TODO: refactor bert_hidden_layers between the different models
            if args.get('bert_hidden_layers', False):
                # The average will be offset by 1/N so that the default zeros
                # represents an average of the N layers
                self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
                nn.init.zeros_(self.bert_layer_mix.weight)
            else:
                # an average of layers 2, 3, 4 will be used
                # (for historic reasons)
                self.bert_layer_mix = None
            input_size += self.bert_model.config.hidden_size

        if self.args['pretrain']:
            # pretrained embeddings, by default this won't be saved into model file
            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
            input_size += self.args['transformed_dim']

        # recurrent layers
        self.parserlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
        # learned replacement vector used by word dropout
        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
        # learned initial LSTM states
        self.parserlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
        self.parserlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))

        # classifiers
        # biaffine scorers over LSTM outputs: arc scores, relation labels,
        # and optional linearization / distance auxiliary scorers
        self.unlabeled = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=args['dropout'])
        self.deprel = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], len(vocab['deprel']), pairwise=True, dropout=args['dropout'])
        if args['linearization']:
            self.linearization = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=args['dropout'])
        if args['distance']:
            self.distance = DeepBiaffineScorer(2 * self.args['hidden_dim'], 2 * self.args['hidden_dim'], self.args['deep_biaff_hidden_dim'], 1, pairwise=True, dropout=args['dropout'])

        # criterion
        self.crit = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum') # ignore padding

        self.drop = nn.Dropout(args['dropout'])
        self.worddrop = WordDropout(args['word_dropout'])
+ self.worddrop = WordDropout(args['word_dropout'])
116
+
117
+ def add_unsaved_module(self, name, module):
118
+ self.unsaved_modules += [name]
119
+ setattr(self, name, module)
120
+
121
+ def log_norms(self):
122
+ utils.log_norms(self)
123
+
124
+ def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text):
125
+ def pack(x):
126
+ return pack_padded_sequence(x, sentlens, batch_first=True)
127
+
128
+ inputs = []
129
+ if self.args['pretrain']:
130
+ pretrained_emb = self.pretrained_emb(pretrained)
131
+ pretrained_emb = self.trans_pretrained(pretrained_emb)
132
+ pretrained_emb = pack(pretrained_emb)
133
+ inputs += [pretrained_emb]
134
+
135
+ #def pad(x):
136
+ # return pad_packed_sequence(PackedSequence(x, pretrained_emb.batch_sizes), batch_first=True)[0]
137
+
138
+ if self.args['word_emb_dim'] > 0:
139
+ word_emb = self.word_emb(word)
140
+ word_emb = pack(word_emb)
141
+ lemma_emb = self.lemma_emb(lemma)
142
+ lemma_emb = pack(lemma_emb)
143
+ inputs += [word_emb, lemma_emb]
144
+
145
+ if self.args['tag_emb_dim'] > 0:
146
+ if self.args.get('use_upos', True):
147
+ pos_emb = self.upos_emb(upos)
148
+ else:
149
+ pos_emb = 0
150
+
151
+ if self.args.get('use_xpos', True):
152
+ if isinstance(self.vocab['xpos'], CompositeVocab):
153
+ for i in range(len(self.vocab['xpos'])):
154
+ pos_emb += self.xpos_emb[i](xpos[:, :, i])
155
+ else:
156
+ pos_emb += self.xpos_emb(xpos)
157
+
158
+ if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
159
+ pos_emb = pack(pos_emb)
160
+ inputs += [pos_emb]
161
+
162
+ if self.args.get('use_ufeats', True):
163
+ feats_emb = 0
164
+ for i in range(len(self.vocab['feats'])):
165
+ feats_emb += self.ufeats_emb[i](ufeats[:, :, i])
166
+ feats_emb = pack(feats_emb)
167
+
168
+ inputs += [pos_emb]
169
+
170
+ if self.args['char'] and self.args['char_emb_dim'] > 0:
171
+ if self.args.get('charlm', None):
172
+ # \n is to add a somewhat neutral "word" for the ROOT
173
+ charlm_text = [["\n"] + x for x in text]
174
+ all_forward_chars = self.charmodel_forward.build_char_representation(charlm_text)
175
+ all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))
176
+ all_backward_chars = self.charmodel_backward.build_char_representation(charlm_text)
177
+ all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))
178
+ inputs += [all_forward_chars, all_backward_chars]
179
+ else:
180
+ char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
181
+ char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)
182
+ inputs += [char_reps]
183
+
184
+ if self.bert_model is not None:
185
+ device = next(self.parameters()).device
186
+ processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=True,
187
+ num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
188
+ detach=not self.args.get('bert_finetune', False) or not self.training,
189
+ peft_name=self.peft_name)
190
+ if self.bert_layer_mix is not None:
191
+ # use a linear layer to weighted average the embedding dynamically
192
+ processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]
193
+
194
+ # we are using the first endpoint from the transformer as the "word" for ROOT
195
+ processed_bert = [x[:-1, :] for x in processed_bert]
196
+ processed_bert = pad_sequence(processed_bert, batch_first=True)
197
+ inputs += [pack(processed_bert)]
198
+
199
+ lstm_inputs = torch.cat([x.data for x in inputs], 1)
200
+
201
+ lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
202
+ lstm_inputs = self.drop(lstm_inputs)
203
+
204
+ lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)
205
+
206
+ lstm_outputs, _ = self.parserlstm(lstm_inputs, sentlens, hx=(self.parserlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.parserlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
207
+ lstm_outputs, _ = pad_packed_sequence(lstm_outputs, batch_first=True)
208
+
209
+ unlabeled_scores = self.unlabeled(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
210
+ deprel_scores = self.deprel(self.drop(lstm_outputs), self.drop(lstm_outputs))
211
+
212
+ #goldmask = head.new_zeros(*head.size(), head.size(-1)+1, dtype=torch.uint8)
213
+ #goldmask.scatter_(2, head.unsqueeze(2), 1)
214
+
215
+ if self.args['linearization'] or self.args['distance']:
216
+ head_offset = torch.arange(word.size(1), device=head.device).view(1, 1, -1).expand(word.size(0), -1, -1) - torch.arange(word.size(1), device=head.device).view(1, -1, 1).expand(word.size(0), -1, -1)
217
+
218
+ if self.args['linearization']:
219
+ lin_scores = self.linearization(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
220
+ unlabeled_scores += F.logsigmoid(lin_scores * torch.sign(head_offset).float()).detach()
221
+
222
+ if self.args['distance']:
223
+ dist_scores = self.distance(self.drop(lstm_outputs), self.drop(lstm_outputs)).squeeze(3)
224
+ dist_pred = 1 + F.softplus(dist_scores)
225
+ dist_target = torch.abs(head_offset)
226
+ dist_kld = -torch.log((dist_target.float() - dist_pred)**2/2 + 1)
227
+ unlabeled_scores += dist_kld.detach()
228
+
229
+ diag = torch.eye(head.size(-1)+1, dtype=torch.bool, device=head.device).unsqueeze(0)
230
+ unlabeled_scores.masked_fill_(diag, -float('inf'))
231
+
232
+ preds = []
233
+
234
+ if self.training:
235
+ unlabeled_scores = unlabeled_scores[:, 1:, :] # exclude attachment for the root symbol
236
+ unlabeled_scores = unlabeled_scores.masked_fill(word_mask.unsqueeze(1), -float('inf'))
237
+ unlabeled_target = head.masked_fill(word_mask[:, 1:], -1)
238
+ loss = self.crit(unlabeled_scores.contiguous().view(-1, unlabeled_scores.size(2)), unlabeled_target.view(-1))
239
+
240
+ deprel_scores = deprel_scores[:, 1:] # exclude attachment for the root symbol
241
+ #deprel_scores = deprel_scores.masked_select(goldmask.unsqueeze(3)).view(-1, len(self.vocab['deprel']))
242
+ deprel_scores = torch.gather(deprel_scores, 2, head.unsqueeze(2).unsqueeze(3).expand(-1, -1, -1, len(self.vocab['deprel']))).view(-1, len(self.vocab['deprel']))
243
+ deprel_target = deprel.masked_fill(word_mask[:, 1:], -1)
244
+ loss += self.crit(deprel_scores.contiguous(), deprel_target.view(-1))
245
+
246
+ if self.args['linearization']:
247
+ #lin_scores = lin_scores[:, 1:].masked_select(goldmask)
248
+ lin_scores = torch.gather(lin_scores[:, 1:], 2, head.unsqueeze(2)).view(-1)
249
+ lin_scores = torch.cat([-lin_scores.unsqueeze(1)/2, lin_scores.unsqueeze(1)/2], 1)
250
+ #lin_target = (head_offset[:, 1:] > 0).long().masked_select(goldmask)
251
+ lin_target = torch.gather((head_offset[:, 1:] > 0).long(), 2, head.unsqueeze(2))
252
+ loss += self.crit(lin_scores.contiguous(), lin_target.view(-1))
253
+
254
+ if self.args['distance']:
255
+ #dist_kld = dist_kld[:, 1:].masked_select(goldmask)
256
+ dist_kld = torch.gather(dist_kld[:, 1:], 2, head.unsqueeze(2))
257
+ loss -= dist_kld.sum()
258
+
259
+ loss /= wordchars.size(0) # number of words
260
+ else:
261
+ loss = 0
262
+ preds.append(F.log_softmax(unlabeled_scores, 2).detach().cpu().numpy())
263
+ preds.append(deprel_scores.max(3)[1].detach().cpu().numpy())
264
+
265
+ return loss, preds