bowphs committed on
Commit
ba68d3c
·
verified ·
1 Parent(s): 9634055

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. stanza/stanza/models/classifiers/base_classifier.py +65 -0
  2. stanza/stanza/models/classifiers/cnn_classifier.py +547 -0
  3. stanza/stanza/models/classifiers/iterate_test.py +64 -0
  4. stanza/stanza/models/classifiers/trainer.py +304 -0
  5. stanza/stanza/models/constituency/__init__.py +0 -0
  6. stanza/stanza/models/constituency/evaluate_treebanks.py +36 -0
  7. stanza/stanza/models/constituency/label_attention.py +726 -0
  8. stanza/stanza/models/constituency/lstm_tree_stack.py +91 -0
  9. stanza/stanza/models/constituency/score_converted_dependencies.py +65 -0
  10. stanza/stanza/models/constituency/text_processing.py +166 -0
  11. stanza/stanza/models/constituency/tree_reader.py +274 -0
  12. stanza/stanza/models/constituency/tree_stack.py +57 -0
  13. stanza/stanza/models/constituency/utils.py +375 -0
  14. stanza/stanza/models/coref/predict.py +55 -0
  15. stanza/stanza/models/coref/span_predictor.py +146 -0
  16. stanza/stanza/models/coref/tokenizer_customization.py +18 -0
  17. stanza/stanza/models/coref/word_encoder.py +108 -0
  18. stanza/stanza/models/depparse/data.py +233 -0
  19. stanza/stanza/models/lemma/attach_lemma_classifier.py +25 -0
  20. stanza/stanza/models/lemma/scorer.py +13 -0
  21. stanza/stanza/models/lemma/vocab.py +18 -0
  22. stanza/stanza/models/lemma_classifier/base_trainer.py +114 -0
  23. stanza/stanza/models/lemma_classifier/constants.py +14 -0
  24. stanza/stanza/models/lemma_classifier/evaluate_many.py +68 -0
  25. stanza/stanza/models/lemma_classifier/evaluate_models.py +228 -0
  26. stanza/stanza/models/lemma_classifier/prepare_dataset.py +125 -0
  27. stanza/stanza/models/lemma_classifier/train_lstm_model.py +147 -0
  28. stanza/stanza/models/lemma_classifier/train_many.py +155 -0
  29. stanza/stanza/models/lemma_classifier/train_transformer_model.py +130 -0
  30. stanza/stanza/models/lemma_classifier/transformer_model.py +89 -0
  31. stanza/stanza/models/lemma_classifier/utils.py +173 -0
  32. stanza/stanza/models/mwt/character_classifier.py +65 -0
  33. stanza/stanza/models/mwt/trainer.py +218 -0
  34. stanza/stanza/models/mwt/vocab.py +19 -0
  35. stanza/stanza/models/ner/vocab.py +56 -0
  36. stanza/stanza/models/pos/__init__.py +0 -0
  37. stanza/stanza/models/pos/build_xpos_vocab_factory.py +144 -0
  38. stanza/stanza/models/pos/data.py +387 -0
  39. stanza/stanza/models/pos/model.py +256 -0
  40. stanza/stanza/models/pos/trainer.py +179 -0
  41. stanza/stanza/models/pos/xpos_vocab_factory.py +200 -0
  42. stanza/stanza/models/pos/xpos_vocab_utils.py +48 -0
  43. stanza/stanza/models/tokenization/__init__.py +0 -0
  44. stanza/stanza/models/tokenization/data.py +432 -0
  45. stanza/stanza/models/tokenization/model.py +101 -0
  46. stanza/stanza/models/tokenization/tokenize_files.py +83 -0
  47. stanza/stanza/models/tokenization/trainer.py +102 -0
  48. stanza/stanza/utils/datasets/constituency/convert_ctb.py +224 -0
  49. stanza/stanza/utils/datasets/constituency/extract_silver_dataset.py +47 -0
  50. stanza/stanza/utils/datasets/coref/balance_languages.py +60 -0
stanza/stanza/models/classifiers/base_classifier.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import logging
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort
9
+
10
+ """
11
+ A base classifier type
12
+
13
+ Currently, has the ability to process text or other inputs in a manner
14
+ suitable for the particular model type.
15
+ In other words, the CNNClassifier processes lists of words,
16
+ and the ConstituencyClassifier processes trees
17
+ """
18
+
19
+ logger = logging.getLogger('stanza')
20
+
21
class BaseClassifier(ABC, nn.Module):
    """Shared interface for the sentence classifiers.

    Subclasses decide how a document is converted into model inputs
    (words, trees, ...); this base provides the common batched
    labeling loop over those inputs.
    """
    @abstractmethod
    def extract_sentences(self, doc):
        """
        Extract the sentences or the relevant information in the sentences from a document
        """

    def preprocess_sentences(self, sentences):
        """
        By default, don't do anything
        """
        return sentences

    def label_sentences(self, sentences, batch_size=None):
        """
        Given a list of sentences, return the model's results on that text.
        """
        self.eval()

        sentences = self.preprocess_sentences(sentences)

        if batch_size is None:
            # a single interval spanning everything; no reordering needed
            orig_idx = None
            intervals = [(0, len(sentences))]
        else:
            # sort longest first so each batch holds similar lengths,
            # remembering the permutation so results can be unsorted later
            sentences, orig_idx = sort_with_indices(sentences, key=len, reverse=True)
            intervals = split_into_batches(sentences, batch_size)

        labels = []
        for start, end in intervals:
            if start == end:
                # this can happen for empty text
                continue
            logits = self(sentences[start:end])
            labels.extend(torch.argmax(logits, dim=1).tolist())

        if orig_idx:
            sentences = unsort(sentences, orig_idx)
            labels = unsort(labels, orig_idx)

        logger.debug("Found labels")
        for label, sentence in zip(labels, sentences):
            logger.debug((label, sentence))

        return labels
stanza/stanza/models/classifiers/cnn_classifier.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import logging
3
+ import math
4
+ import os
5
+ import random
6
+ import re
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ import stanza.models.classifiers.data as data
14
+ from stanza.models.classifiers.base_classifier import BaseClassifier
15
+ from stanza.models.classifiers.config import CNNConfig
16
+ from stanza.models.classifiers.data import SentimentDatum
17
+ from stanza.models.classifiers.utils import ExtraVectors, ModelType, build_output_layers
18
+ from stanza.models.common.bert_embedding import extract_bert_embeddings
19
+ from stanza.models.common.data import get_long_tensor, sort_all
20
+ from stanza.models.common.utils import attach_bert_model
21
+ from stanza.models.common.vocab import PAD_ID, UNK_ID
22
+
23
+ """
24
+ The CNN classifier is based on Yoon Kim's work:
25
+
26
+ https://arxiv.org/abs/1408.5882
27
+
28
+ Also included are maxpool 2d, conv 2d, and a bilstm, as in
29
+
30
+ Text Classification Improved by Integrating Bidirectional LSTM
31
+ with Two-dimensional Max Pooling
32
+ https://aclanthology.org/C16-1329.pdf
33
+
34
+ The architecture is simple:
35
+
36
+ - Embedding at the bottom layer
37
+ - separate learnable entry for UNK, since many of the embeddings we have use 0 for UNK
38
+ - maybe a bilstm layer, as per a command line flag
39
+ - Some number of conv2d layers over the embedding
40
+ - Maxpool layers over small windows, window size being a parameter
41
+ - FC layer to the classification layer
42
+
43
+ One experiment which was run and found to be a bit of a negative was
44
+ putting a layer on top of the pretrain. You would think that might
45
+ help, but dev performance went down for each variation of
46
+ - trans(emb)
47
+ - relu(trans(emb))
48
+ - dropout(trans(emb))
49
+ - dropout(relu(trans(emb)))
50
+ """
51
+
52
+ logger = logging.getLogger('stanza')
53
+ tlogger = logging.getLogger('stanza.classifiers.trainer')
54
+
55
class CNNClassifier(BaseClassifier):
    """CNN sentence classifier in the style of Yoon Kim (2014).

    Word embeddings (plus optional extra/delta embeddings, charlm,
    elmo and bert features) are concatenated, optionally run through a
    bilstm, then passed through conv + maxpool layers and FC layers to
    produce raw class logits.
    """
    def __init__(self, pretrain, extra_vocab, labels,
                 charmodel_forward, charmodel_backward, elmo_model, bert_model, bert_tokenizer, force_bert_saved, peft_name,
                 args):
        """
        pretrain is a pretrained word embedding. should have .emb and .vocab

        extra_vocab is a collection of words in the training data to
        be used for the delta word embedding, if used. can be set to
        None if delta word embedding is not used.

        labels is the list of labels we expect in the training data.
        Used to derive the number of classes. Saving it in the model
        will let us check that test data has the same labels

        args is either the complete arguments when training, or the
        subset of arguments stored in the model save file
        """
        super(CNNClassifier, self).__init__()
        self.labels = labels
        bert_finetune = args.bert_finetune
        use_peft = args.use_peft
        # finetuning the transformer implies its weights must be saved
        force_bert_saved = force_bert_saved or bert_finetune
        logger.debug("bert_finetune %s / force_bert_saved %s", bert_finetune, force_bert_saved)

        # this may change when loaded in a new Pipeline, so it's not part of the config
        self.peft_name = peft_name

        # we build a separate config out of the args so that we can easily save it in torch
        self.config = CNNConfig(filter_channels = args.filter_channels,
                                filter_sizes = args.filter_sizes,
                                fc_shapes = args.fc_shapes,
                                dropout = args.dropout,
                                num_classes = len(labels),
                                wordvec_type = args.wordvec_type,
                                extra_wordvec_method = args.extra_wordvec_method,
                                extra_wordvec_dim = args.extra_wordvec_dim,
                                extra_wordvec_max_norm = args.extra_wordvec_max_norm,
                                char_lowercase = args.char_lowercase,
                                charlm_projection = args.charlm_projection,
                                has_charlm_forward = charmodel_forward is not None,
                                has_charlm_backward = charmodel_backward is not None,
                                use_elmo = args.use_elmo,
                                elmo_projection = args.elmo_projection,
                                bert_model = args.bert_model,
                                bert_finetune = bert_finetune,
                                bert_hidden_layers = args.bert_hidden_layers,
                                force_bert_saved = force_bert_saved,

                                use_peft = use_peft,
                                lora_rank = args.lora_rank,
                                lora_alpha = args.lora_alpha,
                                lora_dropout = args.lora_dropout,
                                lora_modules_to_save = args.lora_modules_to_save,
                                lora_target_modules = args.lora_target_modules,

                                bilstm = args.bilstm,
                                bilstm_hidden_dim = args.bilstm_hidden_dim,
                                maxpool_width = args.maxpool_width,
                                model_type = ModelType.CNN)

        self.char_lowercase = args.char_lowercase

        # names of attributes that must not go into the checkpoint
        # (pretrained embedding, charlms, elmo, possibly bert)
        self.unsaved_modules = []

        emb_matrix = pretrain.emb
        self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
        self.add_unsaved_module('elmo_model', elmo_model)
        self.vocab_size = emb_matrix.shape[0]
        self.embedding_dim = emb_matrix.shape[1]

        self.add_unsaved_module('forward_charlm', charmodel_forward)
        if charmodel_forward is not None:
            tlogger.debug("Got forward char model of dimension {}".format(charmodel_forward.hidden_dim()))
            if not charmodel_forward.is_forward_lm:
                raise ValueError("Got a backward charlm as a forward charlm!")
        self.add_unsaved_module('backward_charlm', charmodel_backward)
        if charmodel_backward is not None:
            tlogger.debug("Got backward char model of dimension {}".format(charmodel_backward.hidden_dim()))
            if charmodel_backward.is_forward_lm:
                raise ValueError("Got a forward charlm as a backward charlm!")

        # sets self.bert_model / self.bert_tokenizer (used below)
        attach_bert_model(self, bert_model, bert_tokenizer, self.config.use_peft, force_bert_saved)

        # The Pretrain has PAD and UNK already (indices 0 and 1), but we
        # possibly want to train UNK while freezing the rest of the embedding
        # note that the /10.0 operation has to be inside nn.Parameter unless
        # you want to spend a long time debugging this
        self.unk = nn.Parameter(torch.randn(self.embedding_dim) / np.sqrt(self.embedding_dim) / 10.0)

        # replacing NBSP picks up a whole bunch of words for VI
        self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pretrain.vocab) }

        if self.config.extra_wordvec_method is not ExtraVectors.NONE:
            if not extra_vocab:
                raise ValueError("Should have had extra_vocab set for extra_wordvec_method {}".format(self.config.extra_wordvec_method))
            if not args.extra_wordvec_dim:
                # default the delta embedding to the same width as the pretrain
                self.config.extra_wordvec_dim = self.embedding_dim
            if self.config.extra_wordvec_method is ExtraVectors.SUM:
                if self.config.extra_wordvec_dim != self.embedding_dim:
                    raise ValueError("extra_wordvec_dim must equal embedding_dim for {}".format(self.config.extra_wordvec_method))

            self.extra_vocab = list(extra_vocab)
            self.extra_vocab_map = { word: i for i, word in enumerate(self.extra_vocab) }
            # TODO: possibly add regularization specifically on the extra embedding?
            # note: it looks like a bug that this doesn't add UNK or PAD, but actually
            # those are expected to already be the first two entries
            self.extra_embedding = nn.Embedding(num_embeddings = len(extra_vocab),
                                                embedding_dim = self.config.extra_wordvec_dim,
                                                max_norm = self.config.extra_wordvec_max_norm,
                                                padding_idx = 0)
            tlogger.debug("Extra embedding size: {}".format(self.extra_embedding.weight.shape))
        else:
            self.extra_vocab = None
            self.extra_vocab_map = None
            self.config.extra_wordvec_dim = 0
            self.extra_embedding = None

        # Pytorch is "aware" of the existence of the nn.Modules inside
        # an nn.ModuleList in terms of parameters() etc
        if self.config.extra_wordvec_method is ExtraVectors.NONE:
            total_embedding_dim = self.embedding_dim
        elif self.config.extra_wordvec_method is ExtraVectors.SUM:
            total_embedding_dim = self.embedding_dim
        elif self.config.extra_wordvec_method is ExtraVectors.CONCAT:
            total_embedding_dim = self.embedding_dim + self.config.extra_wordvec_dim
        else:
            raise ValueError("unable to handle {}".format(self.config.extra_wordvec_method))

        if charmodel_forward is not None:
            if args.charlm_projection:
                self.charmodel_forward_projection = nn.Linear(charmodel_forward.hidden_dim(), args.charlm_projection)
                total_embedding_dim += args.charlm_projection
            else:
                self.charmodel_forward_projection = None
                total_embedding_dim += charmodel_forward.hidden_dim()

        if charmodel_backward is not None:
            if args.charlm_projection:
                self.charmodel_backward_projection = nn.Linear(charmodel_backward.hidden_dim(), args.charlm_projection)
                total_embedding_dim += args.charlm_projection
            else:
                self.charmodel_backward_projection = None
                total_embedding_dim += charmodel_backward.hidden_dim()

        if self.config.use_elmo:
            if elmo_model is None:
                raise ValueError("Model requires elmo, but elmo_model not passed in")
            # probe the elmo model with a dummy sentence to learn its width
            elmo_dim = elmo_model.sents2elmo([["Test"]])[0].shape[1]

            # this mapping will combine 3 layers of elmo to 1 layer of features
            self.elmo_combine_layers = nn.Linear(in_features=3, out_features=1, bias=False)
            if self.config.elmo_projection:
                self.elmo_projection = nn.Linear(in_features=elmo_dim, out_features=self.config.elmo_projection)
                total_embedding_dim = total_embedding_dim + self.config.elmo_projection
            else:
                total_embedding_dim = total_embedding_dim + elmo_dim

        if bert_model is not None:
            if self.config.bert_hidden_layers:
                # The average will be offset by 1/N so that the default zeros
                # represents an average of the N layers
                if self.config.bert_hidden_layers > bert_model.config.num_hidden_layers:
                    # limit ourselves to the number of layers actually available
                    # note that we can +1 because of the initial embedding layer
                    self.config.bert_hidden_layers = bert_model.config.num_hidden_layers + 1
                self.bert_layer_mix = nn.Linear(self.config.bert_hidden_layers, 1, bias=False)
                nn.init.zeros_(self.bert_layer_mix.weight)
            else:
                # an average of layers 2, 3, 4 will be used
                # (for historic reasons)
                self.bert_layer_mix = None

            if bert_tokenizer is None:
                raise ValueError("Cannot have a bert model without a tokenizer")
            self.bert_dim = self.bert_model.config.hidden_size
            total_embedding_dim += self.bert_dim

        if self.config.bilstm:
            # bidirectional, so the conv layers see 2x the hidden dim
            conv_input_dim = self.config.bilstm_hidden_dim * 2
            self.bilstm = nn.LSTM(batch_first=True,
                                  input_size=total_embedding_dim,
                                  hidden_size=self.config.bilstm_hidden_dim,
                                  num_layers=2,
                                  bidirectional=True,
                                  dropout=0.2)
        else:
            conv_input_dim = total_embedding_dim
            self.bilstm = None

        self.fc_input_size = 0
        self.conv_layers = nn.ModuleList()
        # widest filter; phrases shorter than this get padded up to it
        self.max_window = 0
        for filter_idx, filter_size in enumerate(self.config.filter_sizes):
            if isinstance(filter_size, int):
                # full-width filter: spans the whole embedding dimension
                self.max_window = max(self.max_window, filter_size)
                if isinstance(self.config.filter_channels, int):
                    filter_channels = self.config.filter_channels
                else:
                    filter_channels = self.config.filter_channels[filter_idx]
                fc_delta = filter_channels // self.config.maxpool_width
                tlogger.debug("Adding full width filter %d. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
                self.fc_input_size += fc_delta
                self.conv_layers.append(nn.Conv2d(in_channels=1,
                                                  out_channels=filter_channels,
                                                  kernel_size=(filter_size, conv_input_dim)))
            elif isinstance(filter_size, tuple) and len(filter_size) == 2:
                # 2d filter: strides across the embedding dimension as well
                filter_height, filter_width = filter_size
                self.max_window = max(self.max_window, filter_width)
                if isinstance(self.config.filter_channels, int):
                    # scale channels down so the total output stays comparable
                    filter_channels = max(1, self.config.filter_channels // (conv_input_dim // filter_width))
                else:
                    filter_channels = self.config.filter_channels[filter_idx]
                fc_delta = filter_channels * (conv_input_dim // filter_width) // self.config.maxpool_width
                tlogger.debug("Adding filter %s. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
                self.fc_input_size += fc_delta
                self.conv_layers.append(nn.Conv2d(in_channels=1,
                                                  out_channels=filter_channels,
                                                  stride=(1, filter_width),
                                                  kernel_size=(filter_height, filter_width)))
            else:
                raise ValueError("Expected int or 2d tuple for conv size")

        tlogger.debug("Input dim to FC layers: %d", self.fc_input_size)
        self.fc_layers = build_output_layers(self.fc_input_size, self.config.fc_shapes, self.config.num_classes)

        self.dropout = nn.Dropout(self.config.dropout)

    def add_unsaved_module(self, name, module):
        """Attach a module that is excluded from the checkpoint.

        charlms and (when not using peft) the bert model are also
        frozen here, since they are never trained by this classifier.
        """
        self.unsaved_modules += [name]
        setattr(self, name, module)

        if module is not None and (name in ('forward_charlm', 'backward_charlm') or
                                   (name == 'bert_model' and not self.config.use_peft)):
            # if we are using peft, we should not save the transformer directly
            # instead, the peft parameters only will be saved later
            for _, parameter in module.named_parameters():
                parameter.requires_grad = False

    def is_unsaved_module(self, name):
        """Return True if a state_dict key belongs to an unsaved module."""
        return name.split('.')[0] in self.unsaved_modules

    def log_configuration(self):
        """
        Log some essential information about the model configuration to the training logger
        """
        tlogger.info("Filter sizes: %s" % str(self.config.filter_sizes))
        tlogger.info("Filter channels: %s" % str(self.config.filter_channels))
        tlogger.info("Intermediate layers: %s" % str(self.config.fc_shapes))

    def log_norms(self):
        """Log the norm of each trainable parameter (charlms excluded)."""
        lines = ["NORMS FOR MODEL PARAMTERS"]
        for name, param in self.named_parameters():
            if param.requires_grad and name.split(".")[0] not in ('forward_charlm', 'backward_charlm'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        logger.info("\n".join(lines))

    def build_char_reps(self, inputs, max_phrase_len, charlm, projection, begin_paddings, device):
        """Run a charlm over the inputs and pad to (batch, max_phrase_len, dim).

        Each phrase's representation is placed at its begin_paddings
        offset so it lines up with the padded word embeddings.
        """
        char_reps = charlm.build_char_representation(inputs)
        if projection is not None:
            char_reps = [projection(x) for x in char_reps]
        char_inputs = torch.zeros((len(inputs), max_phrase_len, char_reps[0].shape[-1]), device=device)
        for idx, rep in enumerate(char_reps):
            start = begin_paddings[idx]
            end = start + rep.shape[0]
            char_inputs[idx, start:end, :] = rep
        return char_inputs

    def extract_bert_embeddings(self, inputs, max_phrase_len, begin_paddings, device):
        """Embed the inputs with bert and pad to (batch, max_phrase_len, dim).

        When bert_layer_mix is set, a learned combination of the hidden
        layers is used; otherwise the extraction's default layers apply.
        """
        bert_embeddings = extract_bert_embeddings(self.config.bert_model, self.bert_tokenizer, self.bert_model, inputs, device,
                                                  keep_endpoints=False,
                                                  num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
                                                  detach=not self.config.bert_finetune,
                                                  peft_name=self.peft_name)
        if self.bert_layer_mix is not None:
            # add the average so that the default behavior is to
            # take an average of the N layers, and anything else
            # other than that needs to be learned
            bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]
        bert_inputs = torch.zeros((len(inputs), max_phrase_len, bert_embeddings[0].shape[-1]), device=device)
        for idx, rep in enumerate(bert_embeddings):
            start = begin_paddings[idx]
            end = start + rep.shape[0]
            bert_inputs[idx, start:end, :] = rep
        return bert_inputs

    def forward(self, inputs):
        """Return raw logits for a batch of phrases.

        inputs may be lists of words or SentimentDatum objects
        (in which case their .text is used).
        """
        # assume all pieces are on the same device
        device = next(self.parameters()).device

        vocab_map = self.vocab_map
        def map_word(word):
            # look up the word as-is, then without a trailing
            # apostrophe, then lowercased, before giving up as UNK
            idx = vocab_map.get(word, None)
            if idx is not None:
                return idx
            if word[-1] == "'":
                idx = vocab_map.get(word[:-1], None)
                if idx is not None:
                    return idx
            return vocab_map.get(word.lower(), UNK_ID)

        inputs = [x.text if isinstance(x, SentimentDatum) else x for x in inputs]
        # we will pad each phrase so either it matches the longest
        # conv or the longest phrase in the input, whichever is longer
        max_phrase_len = max(len(x) for x in inputs)
        if self.max_window > max_phrase_len:
            max_phrase_len = self.max_window

        batch_indices = []
        batch_unknowns = []
        extra_batch_indices = []
        begin_paddings = []
        end_paddings = []

        elmo_batch_words = []

        for phrase in inputs:
            # we use random at training time to try to learn different
            # positions of padding. at test time, though, we want to
            # have consistent results, so we set that to 0 begin_pad
            if self.training:
                begin_pad_width = random.randint(0, max_phrase_len - len(phrase))
            else:
                begin_pad_width = 0
            end_pad_width = max_phrase_len - begin_pad_width - len(phrase)

            begin_paddings.append(begin_pad_width)
            end_paddings.append(end_pad_width)

            # the initial lists are the length of the begin padding
            sentence_indices = [PAD_ID] * begin_pad_width
            sentence_indices.extend([map_word(x) for x in phrase])
            sentence_indices.extend([PAD_ID] * end_pad_width)

            # the "unknowns" will be the locations of the unknown words.
            # these locations will get the specially trained unknown vector
            # TODO: split UNK based on part of speech? might be an interesting experiment
            sentence_unknowns = [idx for idx, word in enumerate(sentence_indices) if word == UNK_ID]

            batch_indices.append(sentence_indices)
            batch_unknowns.append(sentence_unknowns)

            if self.extra_vocab:
                extra_sentence_indices = [PAD_ID] * begin_pad_width
                for word in phrase:
                    if word in self.extra_vocab_map:
                        # the extra vocab is initialized from the
                        # words in the training set, which means there
                        # would be no unknown words. to occasionally
                        # train the extra vocab's unknown words, we
                        # replace 1% of the words with UNK
                        # we don't do that for the original embedding
                        # on the assumption that there may be some
                        # unknown words in the training set anyway
                        # TODO: maybe train unk for the original embedding?
                        if self.training and random.random() < 0.01:
                            extra_sentence_indices.append(UNK_ID)
                        else:
                            extra_sentence_indices.append(self.extra_vocab_map[word])
                    else:
                        extra_sentence_indices.append(UNK_ID)
                extra_sentence_indices.extend([PAD_ID] * end_pad_width)
                extra_batch_indices.append(extra_sentence_indices)

            if self.config.use_elmo:
                elmo_phrase_words = [""] * begin_pad_width
                for word in phrase:
                    elmo_phrase_words.append(word)
                elmo_phrase_words.extend([""] * end_pad_width)
                elmo_batch_words.append(elmo_phrase_words)

        # creating a single large list with all the indices lets us
        # create a single tensor, which is much faster than creating
        # many tiny tensors
        # we can convert this to the input to the CNN
        # it is padded at one or both ends so that it is now num_phrases x max_len x emb_size
        # there are two ways in which this padding is suboptimal
        # the first is that for short sentences, smaller windows will
        # be padded to the point that some windows are entirely pad
        # the second is that a sentence S will have more or less padding
        # depending on what other sentences are in its batch
        # we assume these effects are pretty minimal
        batch_indices = torch.tensor(batch_indices, requires_grad=False, device=device)
        input_vectors = self.embedding(batch_indices)
        # we use the random unk so that we are not necessarily
        # learning to match 0s for unk
        for phrase_num, sentence_unknowns in enumerate(batch_unknowns):
            input_vectors[phrase_num][sentence_unknowns] = self.unk

        if self.extra_vocab:
            extra_batch_indices = torch.tensor(extra_batch_indices, requires_grad=False, device=device)
            extra_input_vectors = self.extra_embedding(extra_batch_indices)
            if self.config.extra_wordvec_method is ExtraVectors.CONCAT:
                all_inputs = [input_vectors, extra_input_vectors]
            elif self.config.extra_wordvec_method is ExtraVectors.SUM:
                all_inputs = [input_vectors + extra_input_vectors]
            else:
                raise ValueError("unable to handle {}".format(self.config.extra_wordvec_method))
        else:
            all_inputs = [input_vectors]

        if self.forward_charlm is not None:
            char_reps_forward = self.build_char_reps(inputs, max_phrase_len, self.forward_charlm, self.charmodel_forward_projection, begin_paddings, device)
            all_inputs.append(char_reps_forward)

        if self.backward_charlm is not None:
            char_reps_backward = self.build_char_reps(inputs, max_phrase_len, self.backward_charlm, self.charmodel_backward_projection, begin_paddings, device)
            all_inputs.append(char_reps_backward)

        if self.config.use_elmo:
            # this will be N arrays of 3xMx1024 where M is the number of words
            # and N is the number of sentences (and 1024 is actually the number of weights)
            elmo_arrays = self.elmo_model.sents2elmo(elmo_batch_words, output_layer=-2)
            elmo_tensors = [torch.tensor(x).to(device=device) for x in elmo_arrays]
            # elmo_tensor will now be Nx3xMx1024
            elmo_tensor = torch.stack(elmo_tensors)
            # Nx1024xMx3
            elmo_tensor = torch.transpose(elmo_tensor, 1, 3)
            # NxMx1024x3
            elmo_tensor = torch.transpose(elmo_tensor, 1, 2)
            # NxMx1024x1
            elmo_tensor = self.elmo_combine_layers(elmo_tensor)
            # NxMx1024
            elmo_tensor = elmo_tensor.squeeze(3)
            if self.config.elmo_projection:
                elmo_tensor = self.elmo_projection(elmo_tensor)
            all_inputs.append(elmo_tensor)

        if self.bert_model is not None:
            bert_embeddings = self.extract_bert_embeddings(inputs, max_phrase_len, begin_paddings, device)
            all_inputs.append(bert_embeddings)

        # still works even if there's just one item
        input_vectors = torch.cat(all_inputs, dim=2)

        if self.config.bilstm:
            input_vectors, _ = self.bilstm(self.dropout(input_vectors))

        # reshape to fit the input tensors
        x = input_vectors.unsqueeze(1)

        conv_outs = []
        for conv, filter_size in zip(self.conv_layers, self.config.filter_sizes):
            if isinstance(filter_size, int):
                conv_out = self.dropout(F.relu(conv(x).squeeze(3)))
                conv_outs.append(conv_out)
            else:
                # 2d filters leave an extra width axis; fold it into the channels
                conv_out = conv(x).transpose(2, 3).flatten(1, 2)
                conv_out = self.dropout(F.relu(conv_out))
                conv_outs.append(conv_out)
        pool_outs = [F.max_pool2d(out, (self.config.maxpool_width, out.shape[2])).squeeze(2) for out in conv_outs]
        pooled = torch.cat(pool_outs, dim=1)

        previous_layer = pooled
        for fc in self.fc_layers[:-1]:
            previous_layer = self.dropout(F.relu(fc(previous_layer)))
        out = self.fc_layers[-1](previous_layer)
        # note that we return the raw logits rather than use a softmax
        # https://discuss.pytorch.org/t/multi-class-cross-entropy-loss-and-softmax-in-pytorch/24920/4
        return out

    def get_params(self, skip_modules=True):
        """Build the dict of state to checkpoint (model weights, config, labels)."""
        model_state = self.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [k for k in model_state.keys() if self.is_unsaved_module(k)]
            for k in skipped:
                del model_state[k]

        # enums are stored by name so the config survives serialization
        config = dataclasses.asdict(self.config)
        config['wordvec_type'] = config['wordvec_type'].name
        config['extra_wordvec_method'] = config['extra_wordvec_method'].name
        config['model_type'] = config['model_type'].name

        params = {
            'model': model_state,
            'config': config,
            'labels': self.labels,
            'extra_vocab': self.extra_vocab,
        }
        if self.config.use_peft:
            # Hide import so that peft dependency is optional
            from peft import get_peft_model_state_dict
            params["bert_lora"] = get_peft_model_state_dict(self.bert_model, adapter_name=self.peft_name)
        return params

    def preprocess_data(self, sentences):
        # NOTE(review): BaseClassifier.label_sentences calls
        # preprocess_sentences(), not preprocess_data() — confirm whether
        # this was meant to override that hook, since as named it is not
        # invoked by the shared labeling loop
        sentences = [data.update_text(s, self.config.wordvec_type) for s in sentences]
        return sentences

    def extract_sentences(self, doc):
        # TODO: tokens or words better here?
        return [[token.text for token in sentence.tokens] for sentence in doc.sentences]
stanza/stanza/models/classifiers/iterate_test.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Iterate test."""
2
+ import argparse
3
+ import glob
4
+ import logging
5
+
6
+ import stanza.models.classifier as classifier
7
+ import stanza.models.classifiers.cnn_classifier as cnn_classifier
8
+ from stanza.models.common import utils
9
+
10
+ from stanza.utils.confusion import format_confusion, confusion_to_accuracy
11
+
12
+ """
13
+ A script for running the same test file on several different classifiers.
14
+
15
+ For each one, it will output the accuracy and, if possible, the confusion matrix.
16
+
17
+ Includes the arguments for pretrain, which allows for passing in a
18
+ different directory for the pretrain file.
19
+
20
+ Example command line:
21
+ python3 -m stanza.models.classifiers.iterate_test --test_file extern_data/sentiment/sst-processed/threeclass/test-threeclass-roots.txt --glob "saved_models/classifier/FC41_3class_en_ewt_FS*ACC66*"
22
+ """
23
+
24
+ logger = logging.getLogger('stanza')
25
+
26
+
27
+ def parse_args():
28
+ """Add and parse arguments."""
29
+ parser = classifier.build_parser()
30
+
31
+ parser.add_argument('--glob', type=str, default='saved_models/classifier/*classifier*pt', help='Model file(s) to test.')
32
+
33
+ args = parser.parse_args()
34
+ return args
35
+
36
+ args = parse_args()
37
+ seed = utils.set_random_seed(args.seed)
38
+
39
+ model_files = []
40
+ for glob_piece in args.glob.split():
41
+ model_files.extend(glob.glob(glob_piece))
42
+ model_files = sorted(set(model_files))
43
+
44
+ test_set = data.read_dataset(args.test_file, args.wordvec_type, min_len=None)
45
+ logger.info("Using test set: %s" % args.test_file)
46
+
47
+ device = None
48
+ for load_name in model_files:
49
+ args.load_name = load_name
50
+ model = classifier.load_model(args)
51
+
52
+ logger.info("Testing %s" % load_name)
53
+ model = cnn_classifier.load(load_name, pretrain)
54
+ if device is None:
55
+ device = next(model.parameters()).device
56
+ logger.info("Current device: %s" % device)
57
+
58
+ labels = model.labels
59
+ classifier.check_labels(labels, test_set)
60
+
61
+ confusion = classifier.confusion_dataset(model, test_set, device=device)
62
+ correct, total = confusion_to_accuracy(confusion)
63
+ logger.info(" Results: %d correct of %d examples. Accuracy: %f" % (correct, total, correct / total))
64
+ logger.info("Confusion matrix:\n{}".format(format_confusion(confusion, model.labels)))
stanza/stanza/models/classifiers/trainer.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Organizes the model itself and its optimizer in one place
3
+
4
+ Saving the optimizer allows for easy restarting of training
5
+ """
6
+
7
+ import logging
8
+ import os
9
+ import torch
10
+ import torch.optim as optim
11
+ from types import SimpleNamespace
12
+
13
+ import stanza.models.classifiers.data as data
14
+ import stanza.models.classifiers.cnn_classifier as cnn_classifier
15
+ import stanza.models.classifiers.constituency_classifier as constituency_classifier
16
+ from stanza.models.classifiers.config import CNNConfig, ConstituencyConfig
17
+ from stanza.models.classifiers.utils import ModelType, WVType, ExtraVectors
18
+ from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain
19
+ from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
20
+ from stanza.models.common.pretrain import Pretrain
21
+ from stanza.models.common.utils import get_split_optimizer
22
+ from stanza.models.constituency.tree_embedding import TreeEmbedding
23
+
24
+ from pickle import UnpicklingError
25
+ import warnings
26
+
27
+ logger = logging.getLogger('stanza')
28
+
29
+ class Trainer:
30
+ """
31
+ Stores a constituency model and its optimizer
32
+ """
33
+
34
+ def __init__(self, model, optimizer=None, epochs_trained=0, global_step=0, best_score=None):
35
+ self.model = model
36
+ self.optimizer = optimizer
37
+ # we keep track of position in the learning so that we can
38
+ # checkpoint & restart if needed without restarting the epoch count
39
+ self.epochs_trained = epochs_trained
40
+ self.global_step = global_step
41
+ # save the best dev score so that when reloading a checkpoint
42
+ # of a model, we know how far we got
43
+ self.best_score = best_score
44
+
45
+ def save(self, filename, epochs_trained=None, skip_modules=True, save_optimizer=True):
46
+ """
47
+ save the current model, optimizer, and other state to filename
48
+
49
+ epochs_trained can be passed as a parameter to handle saving at the end of an epoch
50
+ """
51
+ if epochs_trained is None:
52
+ epochs_trained = self.epochs_trained
53
+ save_dir = os.path.split(filename)[0]
54
+ os.makedirs(save_dir, exist_ok=True)
55
+ model_params = self.model.get_params(skip_modules)
56
+ params = {
57
+ 'params': model_params,
58
+ 'epochs_trained': epochs_trained,
59
+ 'global_step': self.global_step,
60
+ 'best_score': self.best_score,
61
+ }
62
+ if save_optimizer and self.optimizer is not None:
63
+ params['optimizer_state_dict'] = {opt_name: opt.state_dict() for opt_name, opt in self.optimizer.items()}
64
+ torch.save(params, filename, _use_new_zipfile_serialization=False)
65
+ logger.info("Model saved to {}".format(filename))
66
+
67
+ @staticmethod
68
+ def load(filename, args, foundation_cache=None, load_optimizer=False):
69
+ if not os.path.exists(filename):
70
+ if args.save_dir is None:
71
+ raise FileNotFoundError("Cannot find model in {} and args.save_dir is None".format(filename))
72
+ elif os.path.exists(os.path.join(args.save_dir, filename)):
73
+ filename = os.path.join(args.save_dir, filename)
74
+ else:
75
+ raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
76
+ try:
77
+ # TODO: can remove the try/except once the new version is out
78
+ #checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
79
+ try:
80
+ checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
81
+ except UnpicklingError as e:
82
+ checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
83
+ warnings.warn("The saved classifier has an old format using SimpleNamespace and/or Enum instead of a dict to store config. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained classifier using this version ASAP.")
84
+ except BaseException:
85
+ logger.exception("Cannot load model from {}".format(filename))
86
+ raise
87
+ logger.debug("Loaded model {}".format(filename))
88
+
89
+ epochs_trained = checkpoint.get('epochs_trained', 0)
90
+ global_step = checkpoint.get('global_step', 0)
91
+ best_score = checkpoint.get('best_score', None)
92
+
93
+ # TODO: can remove this block once all models are retrained
94
+ if 'params' not in checkpoint:
95
+ model_params = {
96
+ 'model': checkpoint['model'],
97
+ 'config': checkpoint['config'],
98
+ 'labels': checkpoint['labels'],
99
+ 'extra_vocab': checkpoint['extra_vocab'],
100
+ }
101
+ else:
102
+ model_params = checkpoint['params']
103
+ # TODO: this can be removed once v1.10.0 is out
104
+ if isinstance(model_params['config'], SimpleNamespace):
105
+ model_params['config'] = vars(model_params['config'])
106
+ # TODO: these isinstance can go away after 1.10.0
107
+ model_type = model_params['config']['model_type']
108
+ if isinstance(model_type, str):
109
+ model_type = ModelType[model_type]
110
+ model_params['config']['model_type'] = model_type
111
+
112
+ if model_type == ModelType.CNN:
113
+ # TODO: these updates are only necessary during the
114
+ # transition to the @dataclass version of the config
115
+ # Once those are all saved, it is no longer necessary
116
+ # to patch existing models (since they will all be patched)
117
+ if 'has_charlm_forward' not in model_params['config']:
118
+ model_params['config']['has_charlm_forward'] = args.charlm_forward_file is not None
119
+ if 'has_charlm_backward' not in model_params['config']:
120
+ model_params['config']['has_charlm_backward'] = args.charlm_backward_file is not None
121
+ for argname in ['bert_hidden_layers', 'bert_finetune', 'force_bert_saved', 'use_peft',
122
+ 'lora_rank', 'lora_alpha', 'lora_dropout', 'lora_modules_to_save', 'lora_target_modules']:
123
+ model_params['config'][argname] = model_params['config'].get(argname, None)
124
+ # TODO: these isinstance can go away after 1.10.0
125
+ if isinstance(model_params['config']['wordvec_type'], str):
126
+ model_params['config']['wordvec_type'] = WVType[model_params['config']['wordvec_type']]
127
+ if isinstance(model_params['config']['extra_wordvec_method'], str):
128
+ model_params['config']['extra_wordvec_method'] = ExtraVectors[model_params['config']['extra_wordvec_method']]
129
+ model_params['config'] = CNNConfig(**model_params['config'])
130
+
131
+ pretrain = Trainer.load_pretrain(args, foundation_cache)
132
+ elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
133
+
134
+ if model_params['config'].has_charlm_forward:
135
+ charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
136
+ else:
137
+ charmodel_forward = None
138
+ if model_params['config'].has_charlm_backward:
139
+ charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
140
+ else:
141
+ charmodel_backward = None
142
+
143
+ bert_model = model_params['config'].bert_model
144
+ # TODO: can get rid of the getattr after rebuilding all models
145
+ use_peft = getattr(model_params['config'], 'use_peft', False)
146
+ force_bert_saved = getattr(model_params['config'], 'force_bert_saved', False)
147
+ peft_name = None
148
+ if use_peft:
149
+ # if loading a peft model, we first load the base transformer
150
+ # the CNNClassifier code wraps the transformer in peft
151
+ # after creating the CNNClassifier with the peft wrapper,
152
+ # we *then* load the weights
153
+ bert_model, bert_tokenizer, peft_name = load_bert_with_peft(bert_model, "classifier", foundation_cache)
154
+ bert_model = load_peft_wrapper(bert_model, model_params['bert_lora'], vars(model_params['config']), logger, peft_name)
155
+ elif force_bert_saved:
156
+ bert_model, bert_tokenizer = load_bert(bert_model)
157
+ else:
158
+ bert_model, bert_tokenizer = load_bert(bert_model, foundation_cache)
159
+ model = cnn_classifier.CNNClassifier(pretrain=pretrain,
160
+ extra_vocab=model_params['extra_vocab'],
161
+ labels=model_params['labels'],
162
+ charmodel_forward=charmodel_forward,
163
+ charmodel_backward=charmodel_backward,
164
+ elmo_model=elmo_model,
165
+ bert_model=bert_model,
166
+ bert_tokenizer=bert_tokenizer,
167
+ force_bert_saved=force_bert_saved,
168
+ peft_name=peft_name,
169
+ args=model_params['config'])
170
+ elif model_type == ModelType.CONSTITUENCY:
171
+ # the constituency version doesn't have a peft feature yet
172
+ use_peft = False
173
+ pretrain_args = {
174
+ 'wordvec_pretrain_file': args.wordvec_pretrain_file,
175
+ 'charlm_forward_file': args.charlm_forward_file,
176
+ 'charlm_backward_file': args.charlm_backward_file,
177
+ }
178
+ # TODO: integrate with peft for the constituency version
179
+ tree_embedding = TreeEmbedding.model_from_params(model_params['tree_embedding'], pretrain_args, foundation_cache)
180
+ model_params['config'] = ConstituencyConfig(**model_params['config'])
181
+ model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
182
+ labels=model_params['labels'],
183
+ args=model_params['config'])
184
+ else:
185
+ raise ValueError("Unknown model type {}".format(model_type))
186
+ model.load_state_dict(model_params['model'], strict=False)
187
+ model = model.to(args.device)
188
+
189
+ logger.debug("-- MODEL CONFIG --")
190
+ for k in model.config.__dict__:
191
+ logger.debug(" --{}: {}".format(k, model.config.__dict__[k]))
192
+
193
+ logger.debug("-- MODEL LABELS --")
194
+ logger.debug(" {}".format(" ".join(model.labels)))
195
+
196
+ optimizer = None
197
+ if load_optimizer:
198
+ optimizer = Trainer.build_optimizer(model, args)
199
+ if checkpoint.get('optimizer_state_dict', None) is not None:
200
+ for opt_name, opt_state_dict in checkpoint['optimizer_state_dict'].items():
201
+ optimizer[opt_name].load_state_dict(opt_state_dict)
202
+ else:
203
+ logger.info("Attempted to load optimizer to resume training, but optimizer not saved. Creating new optimizer")
204
+
205
+ trainer = Trainer(model, optimizer, epochs_trained, global_step, best_score)
206
+
207
+ return trainer
208
+
209
+
210
+ def load_pretrain(args, foundation_cache):
211
+ if args.wordvec_pretrain_file:
212
+ pretrain_file = args.wordvec_pretrain_file
213
+ elif args.wordvec_type:
214
+ pretrain_file = '{}/{}.{}.pretrain.pt'.format(args.save_dir, args.shorthand, args.wordvec_type.name.lower())
215
+ else:
216
+ raise RuntimeError("TODO: need to get the wv type back from get_wordvec_file")
217
+
218
+ logger.debug("Looking for pretrained vectors in {}".format(pretrain_file))
219
+ if os.path.exists(pretrain_file):
220
+ return load_pretrain(pretrain_file, foundation_cache)
221
+ elif args.wordvec_raw_file:
222
+ vec_file = args.wordvec_raw_file
223
+ logger.debug("Pretrain not found. Looking in {}".format(vec_file))
224
+ else:
225
+ vec_file = utils.get_wordvec_file(args.wordvec_dir, args.shorthand, args.wordvec_type.name.lower())
226
+ logger.debug("Pretrain not found. Looking in {}".format(vec_file))
227
+ pretrain = Pretrain(pretrain_file, vec_file, args.pretrain_max_vocab)
228
+ logger.debug("Embedding shape: %s" % str(pretrain.emb.shape))
229
+ return pretrain
230
+
231
+
232
+ @staticmethod
233
+ def build_new_model(args, train_set):
234
+ """
235
+ Load pretrained pieces and then build a new model
236
+ """
237
+ if train_set is None:
238
+ raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")
239
+
240
+ labels = data.dataset_labels(train_set)
241
+
242
+ if args.model_type == ModelType.CNN:
243
+ pretrain = Trainer.load_pretrain(args, foundation_cache=None)
244
+ elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
245
+ charmodel_forward = load_charlm(args.charlm_forward_file)
246
+ charmodel_backward = load_charlm(args.charlm_backward_file)
247
+ peft_name = None
248
+ bert_model, bert_tokenizer = load_bert(args.bert_model)
249
+
250
+ use_peft = getattr(args, "use_peft", False)
251
+ if use_peft:
252
+ peft_name = "sentiment"
253
+ bert_model = build_peft_wrapper(bert_model, vars(args), logger, adapter_name=peft_name)
254
+
255
+ extra_vocab = data.dataset_vocab(train_set)
256
+ force_bert_saved = args.bert_finetune
257
+ model = cnn_classifier.CNNClassifier(pretrain=pretrain,
258
+ extra_vocab=extra_vocab,
259
+ labels=labels,
260
+ charmodel_forward=charmodel_forward,
261
+ charmodel_backward=charmodel_backward,
262
+ elmo_model=elmo_model,
263
+ bert_model=bert_model,
264
+ bert_tokenizer=bert_tokenizer,
265
+ force_bert_saved=force_bert_saved,
266
+ peft_name=peft_name,
267
+ args=args)
268
+ model = model.to(args.device)
269
+ elif args.model_type == ModelType.CONSTITUENCY:
270
+ # this passes flags such as "constituency_backprop" from
271
+ # the classifier to the TreeEmbedding as the "backprop" flag
272
+ parser_args = { x[len("constituency_"):]: y for x, y in vars(args).items() if x.startswith("constituency_") }
273
+ parser_args.update({
274
+ "wordvec_pretrain_file": args.wordvec_pretrain_file,
275
+ "charlm_forward_file": args.charlm_forward_file,
276
+ "charlm_backward_file": args.charlm_backward_file,
277
+ "bert_model": args.bert_model,
278
+ # we found that finetuning from the classifier output
279
+ # all the way to the bert layers caused the bert model
280
+ # to go astray
281
+ # could make this an option... but it is much less accurate
282
+ # with the Bert finetuning
283
+ # noting that the constituency parser itself works better
284
+ # after finetuning, of course
285
+ "bert_finetune": False,
286
+ "stage1_bert_finetune": False,
287
+ })
288
+ logger.info("Building constituency classifier using %s as the base model" % args.constituency_model)
289
+ tree_embedding = TreeEmbedding.from_parser_file(parser_args)
290
+ model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
291
+ labels=labels,
292
+ args=args)
293
+ model = model.to(args.device)
294
+ else:
295
+ raise ValueError("Unhandled model type {}".format(args.model_type))
296
+
297
+ optimizer = Trainer.build_optimizer(model, args)
298
+
299
+ return Trainer(model, optimizer)
300
+
301
+
302
+ @staticmethod
303
+ def build_optimizer(model, args):
304
+ return get_split_optimizer(args.optim.lower(), model, args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, bert_learning_rate=args.bert_learning_rate, bert_weight_decay=args.weight_decay * args.bert_weight_decay, is_peft=args.use_peft)
stanza/stanza/models/constituency/__init__.py ADDED
File without changes
stanza/stanza/models/constituency/evaluate_treebanks.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Read multiple treebanks, score the results.
3
+
4
+ Reports the k-best score if multiple predicted treebanks are given.
5
+ """
6
+
7
+ import argparse
8
+
9
+ from stanza.models.constituency import tree_reader
10
+ from stanza.server.parser_eval import EvaluateParser, ParseResult
11
+
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser(description='Get scores for one or more treebanks against the gold')
15
+ parser.add_argument('gold', type=str, help='Which file to load as the gold trees')
16
+ parser.add_argument('pred', type=str, nargs='+', help='Which file(s) are the predictions. If more than one is given, the evaluation will be "k-best" with the first prediction treated as the canonical')
17
+ args = parser.parse_args()
18
+
19
+ print("Loading gold treebank: " + args.gold)
20
+ gold = tree_reader.read_treebank(args.gold)
21
+ print("Loading predicted treebanks: " + args.pred)
22
+ pred = [tree_reader.read_treebank(x) for x in args.pred]
23
+
24
+ full_results = [ParseResult(parses[0], [*parses[1:]])
25
+ for parses in zip(gold, *pred)]
26
+
27
+ if len(pred) <= 1:
28
+ kbest = None
29
+ else:
30
+ kbest = len(pred)
31
+
32
+ with EvaluateParser(kbest=kbest) as evaluator:
33
+ response = evaluator.process(full_results)
34
+
35
+ if __name__ == '__main__':
36
+ main()
stanza/stanza/models/constituency/label_attention.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import functools
3
+ import sys
4
+ import torch
5
+ from torch.autograd import Variable
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+
9
+ # publicly available versions alternate between torch.uint8 and torch.bool,
10
+ # but that is for older versions of torch anyway
11
+ DTYPE = torch.bool
12
+
13
+ class BatchIndices:
14
+ """
15
+ Batch indices container class (used to implement packed batches)
16
+ """
17
+ def __init__(self, batch_idxs_np, device):
18
+ self.batch_idxs_np = batch_idxs_np
19
+ self.batch_idxs_torch = torch.as_tensor(batch_idxs_np, dtype=torch.long, device=device)
20
+
21
+ self.batch_size = int(1 + np.max(batch_idxs_np))
22
+
23
+ batch_idxs_np_extra = np.concatenate([[-1], batch_idxs_np, [-1]])
24
+ self.boundaries_np = np.nonzero(batch_idxs_np_extra[1:] != batch_idxs_np_extra[:-1])[0]
25
+
26
+ #print(f"boundaries_np: {self.boundaries_np}")
27
+ #print(f"boundaries_np[1:]: {self.boundaries_np[1:]}")
28
+ #print(f"boundaries_np[:-1]: {self.boundaries_np[:-1]}")
29
+ self.seq_lens_np = self.boundaries_np[1:] - self.boundaries_np[:-1]
30
+ #print(f"seq_lens_np: {self.seq_lens_np}")
31
+ #print(f"batch_size: {self.batch_size}")
32
+ assert len(self.seq_lens_np) == self.batch_size
33
+ self.max_len = int(np.max(self.boundaries_np[1:] - self.boundaries_np[:-1]))
34
+
35
+
36
+ class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
37
+ @classmethod
38
+ def forward(cls, ctx, input, batch_idxs, p=0.5, train=False, inplace=False):
39
+ if p < 0 or p > 1:
40
+ raise ValueError("dropout probability has to be between 0 and 1, "
41
+ "but got {}".format(p))
42
+
43
+ ctx.p = p
44
+ ctx.train = train
45
+ ctx.inplace = inplace
46
+
47
+ if ctx.inplace:
48
+ ctx.mark_dirty(input)
49
+ output = input
50
+ else:
51
+ output = input.clone()
52
+
53
+ if ctx.p > 0 and ctx.train:
54
+ ctx.noise = input.new().resize_(batch_idxs.batch_size, input.size(1))
55
+ if ctx.p == 1:
56
+ ctx.noise.fill_(0)
57
+ else:
58
+ ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
59
+ ctx.noise = ctx.noise[batch_idxs.batch_idxs_torch, :]
60
+ output.mul_(ctx.noise)
61
+
62
+ return output
63
+
64
+ @staticmethod
65
+ def backward(ctx, grad_output):
66
+ if ctx.p > 0 and ctx.train:
67
+ return grad_output.mul(ctx.noise), None, None, None, None
68
+ else:
69
+ return grad_output, None, None, None, None
70
+
71
+ #
72
+ class FeatureDropout(nn.Module):
73
+ """
74
+ Feature-level dropout: takes an input of size len x num_features and drops
75
+ each feature with probabibility p. A feature is dropped across the full
76
+ portion of the input that corresponds to a single batch element.
77
+ """
78
+ def __init__(self, p=0.5, inplace=False):
79
+ super().__init__()
80
+ if p < 0 or p > 1:
81
+ raise ValueError("dropout probability has to be between 0 and 1, "
82
+ "but got {}".format(p))
83
+ self.p = p
84
+ self.inplace = inplace
85
+
86
+ def forward(self, input, batch_idxs):
87
+ return FeatureDropoutFunction.apply(input, batch_idxs, self.p, self.training, self.inplace)
88
+
89
+
90
+
91
+ class LayerNormalization(nn.Module):
92
+ def __init__(self, d_hid, eps=1e-3, affine=True):
93
+ super(LayerNormalization, self).__init__()
94
+
95
+ self.eps = eps
96
+ self.affine = affine
97
+ if self.affine:
98
+ self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
99
+ self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
100
+
101
+ def forward(self, z):
102
+ if z.size(-1) == 1:
103
+ return z
104
+
105
+ mu = torch.mean(z, keepdim=True, dim=-1)
106
+ sigma = torch.std(z, keepdim=True, dim=-1)
107
+ ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
108
+ if self.affine:
109
+ ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
110
+
111
+ return ln_out
112
+
113
+
114
+
115
+ class ScaledDotProductAttention(nn.Module):
116
+ def __init__(self, d_model, attention_dropout=0.1):
117
+ super(ScaledDotProductAttention, self).__init__()
118
+ self.temper = d_model ** 0.5
119
+ self.dropout = nn.Dropout(attention_dropout)
120
+ self.softmax = nn.Softmax(dim=-1)
121
+
122
+ def forward(self, q, k, v, attn_mask=None):
123
+ # q: [batch, slot, feat] or (batch * d_l) x max_len x d_k
124
+ # k: [batch, slot, feat] or (batch * d_l) x max_len x d_k
125
+ # v: [batch, slot, feat] or (batch * d_l) x max_len x d_v
126
+ # q in LAL is (batch * d_l) x 1 x d_k
127
+
128
+ attn = torch.bmm(q, k.transpose(1, 2)) / self.temper # (batch * d_l) x max_len x max_len
129
+ # in LAL, gives: (batch * d_l) x 1 x max_len
130
+ # attention weights from each word to each word, for each label
131
+ # in best model (repeated q): attention weights from label (as vector weights) to each word
132
+
133
+ if attn_mask is not None:
134
+ assert attn_mask.size() == attn.size(), \
135
+ 'Attention mask shape {} mismatch ' \
136
+ 'with Attention logit tensor shape ' \
137
+ '{}.'.format(attn_mask.size(), attn.size())
138
+
139
+ attn.data.masked_fill_(attn_mask, -float('inf'))
140
+
141
+ attn = self.softmax(attn)
142
+ # Note that this makes the distribution not sum to 1. At some point it
143
+ # may be worth researching whether this is the right way to apply
144
+ # dropout to the attention.
145
+ # Note that the t2t code also applies dropout in this manner
146
+ attn = self.dropout(attn)
147
+ output = torch.bmm(attn, v) # (batch * d_l) x max_len x d_v
148
+ # in LAL, gives: (batch * d_l) x 1 x d_v
149
+
150
+ return output, attn
151
+
152
+
153
+ class MultiHeadAttention(nn.Module):
154
+ """
155
+ Multi-head attention module
156
+ """
157
+
158
+ def __init__(self, n_head, d_model, d_k, d_v, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
159
+ super(MultiHeadAttention, self).__init__()
160
+
161
+ self.n_head = n_head
162
+ self.d_k = d_k
163
+ self.d_v = d_v
164
+
165
+ if not d_positional:
166
+ self.partitioned = False
167
+ else:
168
+ self.partitioned = True
169
+
170
+ if self.partitioned:
171
+ self.d_content = d_model - d_positional
172
+ self.d_positional = d_positional
173
+
174
+ self.w_qs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
175
+ self.w_ks1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
176
+ self.w_vs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_v // 2))
177
+
178
+ self.w_qs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
179
+ self.w_ks2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
180
+ self.w_vs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_v // 2))
181
+
182
+ init.xavier_normal_(self.w_qs1)
183
+ init.xavier_normal_(self.w_ks1)
184
+ init.xavier_normal_(self.w_vs1)
185
+
186
+ init.xavier_normal_(self.w_qs2)
187
+ init.xavier_normal_(self.w_ks2)
188
+ init.xavier_normal_(self.w_vs2)
189
+ else:
190
+ self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
191
+ self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
192
+ self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))
193
+
194
+ init.xavier_normal_(self.w_qs)
195
+ init.xavier_normal_(self.w_ks)
196
+ init.xavier_normal_(self.w_vs)
197
+
198
+ self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
199
+ self.layer_norm = LayerNormalization(d_model)
200
+
201
+ if not self.partitioned:
202
+ # The lack of a bias term here is consistent with the t2t code, though
203
+ # in my experiments I have never observed this making a difference.
204
+ self.proj = nn.Linear(n_head*d_v, d_model, bias=False)
205
+ else:
206
+ self.proj1 = nn.Linear(n_head*(d_v//2), self.d_content, bias=False)
207
+ self.proj2 = nn.Linear(n_head*(d_v//2), self.d_positional, bias=False)
208
+
209
+ self.residual_dropout = FeatureDropout(residual_dropout)
210
+
211
+ def split_qkv_packed(self, inp, qk_inp=None):
212
+ v_inp_repeated = inp.repeat(self.n_head, 1).view(self.n_head, -1, inp.size(-1)) # n_head x len_inp x d_model
213
+ if qk_inp is None:
214
+ qk_inp_repeated = v_inp_repeated
215
+ else:
216
+ qk_inp_repeated = qk_inp.repeat(self.n_head, 1).view(self.n_head, -1, qk_inp.size(-1))
217
+
218
+ if not self.partitioned:
219
+ q_s = torch.bmm(qk_inp_repeated, self.w_qs) # n_head x len_inp x d_k
220
+ k_s = torch.bmm(qk_inp_repeated, self.w_ks) # n_head x len_inp x d_k
221
+ v_s = torch.bmm(v_inp_repeated, self.w_vs) # n_head x len_inp x d_v
222
+ else:
223
+ q_s = torch.cat([
224
+ torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_qs1),
225
+ torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_qs2),
226
+ ], -1)
227
+ k_s = torch.cat([
228
+ torch.bmm(qk_inp_repeated[:,:,:self.d_content], self.w_ks1),
229
+ torch.bmm(qk_inp_repeated[:,:,self.d_content:], self.w_ks2),
230
+ ], -1)
231
+ v_s = torch.cat([
232
+ torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),
233
+ torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),
234
+ ], -1)
235
+ return q_s, k_s, v_s
236
+
237
+ def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
238
+ # Input is padded representation: n_head x len_inp x d
239
+ # Output is packed representation: (n_head * mb_size) x len_padded x d
240
+ # (along with masks for the attention and output)
241
+ n_head = self.n_head
242
+ d_k, d_v = self.d_k, self.d_v
243
+
244
+ len_padded = batch_idxs.max_len
245
+ mb_size = batch_idxs.batch_size
246
+ q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
247
+ k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
248
+ v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
249
+ invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)
250
+
251
+ for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
252
+ q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
253
+ k_padded[:,i,:end-start,:] = k_s[:,start:end,:]
254
+ v_padded[:,i,:end-start,:] = v_s[:,start:end,:]
255
+ invalid_mask[i, :end-start].fill_(False)
256
+
257
+ return(
258
+ q_padded.view(-1, len_padded, d_k),
259
+ k_padded.view(-1, len_padded, d_k),
260
+ v_padded.view(-1, len_padded, d_v),
261
+ invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1),
262
+ (~invalid_mask).repeat(n_head, 1),
263
+ )
264
+
265
+ def combine_v(self, outputs):
266
+ # Combine attention information from the different heads
267
+ n_head = self.n_head
268
+ outputs = outputs.view(n_head, -1, self.d_v) # n_head x len_inp x d_kv
269
+
270
+ if not self.partitioned:
271
+ # Switch from n_head x len_inp x d_v to len_inp x (n_head * d_v)
272
+ outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, n_head * self.d_v)
273
+
274
+ # Project back to residual size
275
+ outputs = self.proj(outputs)
276
+ else:
277
+ d_v1 = self.d_v // 2
278
+ outputs1 = outputs[:,:,:d_v1]
279
+ outputs2 = outputs[:,:,d_v1:]
280
+ outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, n_head * d_v1)
281
+ outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, n_head * d_v1)
282
+ outputs = torch.cat([
283
+ self.proj1(outputs1),
284
+ self.proj2(outputs2),
285
+ ], -1)
286
+
287
+ return outputs
288
+
289
def forward(self, inp, batch_idxs, qk_inp=None):
    """
    Multi-head attention over a packed batch of sentences.

    inp: packed representation, len_inp x d_model
    batch_idxs: sentence boundary bookkeeping used for padding/unpadding
    qk_inp: optional separate input for queries/keys (values come from inp)

    Returns (layer_norm(attention output + residual), padded attention weights).
    """
    residual = inp

    # While still using a packed representation, project to obtain the
    # query/key/value for each head
    q_s, k_s, v_s = self.split_qkv_packed(inp, qk_inp=qk_inp)
    # n_head x len_inp x d_kv

    # Switch to padded representation, perform attention, then switch back
    q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)
    # (n_head * batch) x len_padded x d_kv

    outputs_padded, attns_padded = self.attention(
        q_padded, k_padded, v_padded,
        attn_mask=attn_mask,
    )
    # output_mask selects only the real (non-padding) positions
    outputs = outputs_padded[output_mask]
    # (n_head * len_inp) x d_kv
    outputs = self.combine_v(outputs)
    # len_inp x d_model

    outputs = self.residual_dropout(outputs, batch_idxs)

    return self.layer_norm(outputs + residual), attns_padded
313
+
314
+ #
315
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed forward layer.

    Expands each position to d_ff, applies ReLU (with feature dropout),
    contracts back to d_hid, then adds the residual and layer-normalizes.
    """

    def __init__(self, d_hid, d_ff, relu_dropout=0.1, residual_dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        # expansion / contraction projections
        self.w_1 = nn.Linear(d_hid, d_ff)
        self.w_2 = nn.Linear(d_ff, d_hid)
        self.layer_norm = LayerNormalization(d_hid)
        self.relu_dropout = FeatureDropout(relu_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)
        self.relu = nn.ReLU()

    def forward(self, x, batch_idxs):
        hidden = self.relu_dropout(self.relu(self.w_1(x)), batch_idxs)
        projected = self.residual_dropout(self.w_2(hidden), batch_idxs)
        # residual connection followed by layer normalization
        return self.layer_norm(projected + x)
343
+
344
+ #
345
class PartitionedPositionwiseFeedForward(nn.Module):
    """
    Feed-forward layer that keeps content and positional features separate.

    The leading d_content features and the trailing d_positional features are
    each processed by their own linear/ReLU/linear stack (each d_ff//2 wide),
    then re-joined before the residual add and layer norm.
    """

    def __init__(self, d_hid, d_ff, d_positional, relu_dropout=0.1, residual_dropout=0.1):
        super().__init__()
        self.d_content = d_hid - d_positional
        # content half
        self.w_1c = nn.Linear(self.d_content, d_ff//2)
        self.w_2c = nn.Linear(d_ff//2, self.d_content)
        # positional half
        self.w_1p = nn.Linear(d_positional, d_ff//2)
        self.w_2p = nn.Linear(d_ff//2, d_positional)
        self.layer_norm = LayerNormalization(d_hid)
        self.relu_dropout = FeatureDropout(relu_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)
        self.relu = nn.ReLU()

    def forward(self, x, batch_idxs):
        content = x[:, :self.d_content]
        positional = x[:, self.d_content:]

        # content path first, then positional, to keep dropout RNG order stable
        hidden_c = self.relu_dropout(self.relu(self.w_1c(content)), batch_idxs)
        out_c = self.w_2c(hidden_c)

        hidden_p = self.relu_dropout(self.relu(self.w_1p(positional)), batch_idxs)
        out_p = self.w_2p(hidden_p)

        combined = self.residual_dropout(torch.cat([out_c, out_p], -1), batch_idxs)
        return self.layer_norm(combined + x)
375
+
376
class LabelAttention(nn.Module):
    """
    Single-head Attention layer for label-specific representations

    Each of the d_l labels attends over the input sequence with its own
    key/value projections; queries are either learned per-label vectors
    (the default) or per-label matrices applied to the input (q_as_matrix).
    """

    def __init__(self, d_model, d_k, d_v, d_l, d_proj, combine_as_self, use_resdrop=True, q_as_matrix=False, residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
        super(LabelAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.d_l = d_l # Number of Labels
        self.d_model = d_model # Model Dimensionality
        self.d_proj = d_proj # Projection dimension of each label output
        self.use_resdrop = use_resdrop # Using Residual Dropout?
        self.q_as_matrix = q_as_matrix # Using a Matrix of Q to be multiplied with input instead of learned q vectors
        self.combine_as_self = combine_as_self # Using the Combination Method of Self-Attention

        # partitioned: split features into content / positional halves,
        # each with separate projections (as in the partitioned transformer)
        if not d_positional:
            self.partitioned = False
        else:
            self.partitioned = True

        if self.partitioned:
            if d_model <= d_positional:
                raise ValueError("Unable to build LabelAttention. d_model %d <= d_positional %d" % (d_model, d_positional))
            self.d_content = d_model - d_positional
            self.d_positional = d_positional

            # "1" parameters act on the content half, "2" on the positional half
            if self.q_as_matrix:
                self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
            else:
                self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
            self.w_ks1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
            self.w_vs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_v // 2), requires_grad=True)

            if self.q_as_matrix:
                self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
            else:
                self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
            self.w_ks2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
            self.w_vs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_v // 2), requires_grad=True)

            init.xavier_normal_(self.w_qs1)
            init.xavier_normal_(self.w_ks1)
            init.xavier_normal_(self.w_vs1)

            init.xavier_normal_(self.w_qs2)
            init.xavier_normal_(self.w_ks2)
            init.xavier_normal_(self.w_vs2)
        else:
            if self.q_as_matrix:
                self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
            else:
                self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_k), requires_grad=True)
            self.w_ks = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
            self.w_vs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_v), requires_grad=True)

            init.xavier_normal_(self.w_qs)
            init.xavier_normal_(self.w_ks)
            init.xavier_normal_(self.w_vs)

        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
        if self.combine_as_self:
            self.layer_norm = LayerNormalization(d_model)
        else:
            self.layer_norm = LayerNormalization(self.d_proj)

        if not self.partitioned:
            # The lack of a bias term here is consistent with the t2t code, though
            # in my experiments I have never observed this making a difference.
            if self.combine_as_self:
                self.proj = nn.Linear(self.d_l * d_v, d_model, bias=False)
            else:
                self.proj = nn.Linear(d_v, d_model, bias=False) # input dimension does not match, should be d_l * d_v
        else:
            if self.combine_as_self:
                self.proj1 = nn.Linear(self.d_l*(d_v//2), self.d_content, bias=False)
                self.proj2 = nn.Linear(self.d_l*(d_v//2), self.d_positional, bias=False)
            else:
                self.proj1 = nn.Linear(d_v//2, self.d_content, bias=False)
                self.proj2 = nn.Linear(d_v//2, self.d_positional, bias=False)
        if not self.combine_as_self:
            # compresses each per-label d_model vector down to d_proj
            self.reduce_proj = nn.Linear(d_model, self.d_proj, bias=False)

        self.residual_dropout = FeatureDropout(residual_dropout)

    def split_qkv_packed(self, inp, k_inp=None):
        """
        Project the packed input into per-label queries, keys and values.

        inp: len_inp x d_model packed representation (values always come from inp)
        k_inp: optional alternate input for keys (and matrix queries)

        Returns (q_s, k_s, v_s); k_s and v_s are d_l x len_inp x d,
        q_s is d_l x 1 x d_k for learned query vectors (d_l x len_inp x d_k
        when q_as_matrix).
        """
        len_inp = inp.size(0)
        v_inp_repeated = inp.repeat(self.d_l, 1).view(self.d_l, -1, inp.size(-1)) # d_l x len_inp x d_model
        if k_inp is None:
            k_inp_repeated = v_inp_repeated
        else:
            k_inp_repeated = k_inp.repeat(self.d_l, 1).view(self.d_l, -1, k_inp.size(-1)) # d_l x len_inp x d_model

        if not self.partitioned:
            if self.q_as_matrix:
                q_s = torch.bmm(k_inp_repeated, self.w_qs) # d_l x len_inp x d_k
            else:
                q_s = self.w_qs.unsqueeze(1) # d_l x 1 x d_k
            k_s = torch.bmm(k_inp_repeated, self.w_ks) # d_l x len_inp x d_k
            v_s = torch.bmm(v_inp_repeated, self.w_vs) # d_l x len_inp x d_v
        else:
            # content and positional halves are projected separately, then
            # concatenated along the feature dimension
            if self.q_as_matrix:
                q_s = torch.cat([
                    torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_qs1),
                    torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_qs2),
                ], -1)
            else:
                q_s = torch.cat([
                    self.w_qs1.unsqueeze(1),
                    self.w_qs2.unsqueeze(1),
                ], -1)
            k_s = torch.cat([
                torch.bmm(k_inp_repeated[:,:,:self.d_content], self.w_ks1),
                torch.bmm(k_inp_repeated[:,:,self.d_content:], self.w_ks2),
            ], -1)
            v_s = torch.cat([
                torch.bmm(v_inp_repeated[:,:,:self.d_content], self.w_vs1),
                torch.bmm(v_inp_repeated[:,:,self.d_content:], self.w_vs2),
            ], -1)
        return q_s, k_s, v_s

    def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs):
        # Input is padded representation: n_head x len_inp x d
        # Output is packed representation: (n_head * mb_size) x len_padded x d
        # (along with masks for the attention and output)
        n_head = self.d_l
        d_k, d_v = self.d_k, self.d_v

        len_padded = batch_idxs.max_len
        mb_size = batch_idxs.batch_size
        if self.q_as_matrix:
            q_padded = q_s.new_zeros((n_head, mb_size, len_padded, d_k))
        else:
            # learned query vectors: just tile them per batch element
            q_padded = q_s.repeat(mb_size, 1, 1) # (d_l * mb_size) x 1 x d_k
        k_padded = k_s.new_zeros((n_head, mb_size, len_padded, d_k))
        v_padded = v_s.new_zeros((n_head, mb_size, len_padded, d_v))
        invalid_mask = q_s.new_ones((mb_size, len_padded), dtype=DTYPE)

        # copy each sentence's span out of the packed representation
        for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
            if self.q_as_matrix:
                q_padded[:,i,:end-start,:] = q_s[:,start:end,:]
            k_padded[:,i,:end-start,:] = k_s[:,start:end,:]
            v_padded[:,i,:end-start,:] = v_s[:,start:end,:]
            invalid_mask[i, :end-start].fill_(False)

        if self.q_as_matrix:
            q_padded = q_padded.view(-1, len_padded, d_k)
            attn_mask = invalid_mask.unsqueeze(1).expand(mb_size, len_padded, len_padded).repeat(n_head, 1, 1)
        else:
            attn_mask = invalid_mask.unsqueeze(1).repeat(n_head, 1, 1)

        output_mask = (~invalid_mask).repeat(n_head, 1)

        return(
            q_padded,
            k_padded.view(-1, len_padded, d_k),
            v_padded.view(-1, len_padded, d_v),
            attn_mask,
            output_mask,
        )

    def combine_v(self, outputs):
        """
        Combine the per-label attention outputs back toward model space.

        With combine_as_self, behaves like multi-head attention (flatten
        labels into the feature dim); otherwise keeps a separate d_model
        vector per label: len_inp x d_l x d_model.
        """
        # Combine attention information from the different labels
        d_l = self.d_l
        outputs = outputs.view(d_l, -1, self.d_v) # d_l x len_inp x d_v

        if not self.partitioned:
            # Switch from d_l x len_inp x d_v to len_inp x d_l x d_v
            if self.combine_as_self:
                outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, d_l * self.d_v)
            else:
                outputs = torch.transpose(outputs, 0, 1)#.contiguous() #.view(-1, d_l * self.d_v)
            # Project back to residual size
            outputs = self.proj(outputs) # Becomes len_inp x d_l x d_model
        else:
            d_v1 = self.d_v // 2
            outputs1 = outputs[:,:,:d_v1]
            outputs2 = outputs[:,:,d_v1:]
            if self.combine_as_self:
                outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, d_l * d_v1)
                outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, d_l * d_v1)
            else:
                outputs1 = torch.transpose(outputs1, 0, 1)#.contiguous() #.view(-1, d_l * d_v1)
                outputs2 = torch.transpose(outputs2, 0, 1)#.contiguous() #.view(-1, d_l * d_v1)
            outputs = torch.cat([
                self.proj1(outputs1),
                self.proj2(outputs2),
            ], -1)#.contiguous()

        return outputs

    def forward(self, inp, batch_idxs, k_inp=None):
        """
        Attend over the packed input once per label.

        Returns (outputs, attns_padded).  With combine_as_self the output is
        len_inp x d_model; otherwise each label's vector is reduced to d_proj
        and flattened, giving len_inp x (d_l * d_proj).
        """
        residual = inp # len_inp x d_model
        #print()
        #print(f"inp.shape: {inp.shape}")
        len_inp = inp.size(0)
        #print(f"len_inp: {len_inp}")

        # While still using a packed representation, project to obtain the
        # query/key/value for each head
        q_s, k_s, v_s = self.split_qkv_packed(inp, k_inp=k_inp)
        # d_l x len_inp x d_k
        # q_s is d_l x 1 x d_k

        # Switch to padded representation, perform attention, then switch back
        q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs)
        # q_padded, k_padded, v_padded: (d_l * batch_size) x max_len x d_kv
        # q_s is (d_l * batch_size) x 1 x d_kv

        outputs_padded, attns_padded = self.attention(
            q_padded, k_padded, v_padded,
            attn_mask=attn_mask,
        )
        # outputs_padded: (d_l * batch_size) x max_len x d_kv
        # in LAL: (d_l * batch_size) x 1 x d_kv
        # on the best model, this is one value vector per label that is repeated max_len times
        if not self.q_as_matrix:
            outputs_padded = outputs_padded.repeat(1,output_mask.size(-1),1)
        outputs = outputs_padded[output_mask]
        # outputs: (d_l * len_inp) x d_kv or LAL: (d_l * len_inp) x d_kv
        # output_mask: (d_l * batch_size) x max_len
        outputs = self.combine_v(outputs)
        #print(f"outputs shape: {outputs.shape}")
        # outputs: len_inp x d_l x d_model, whereas a normal self-attention layer gets len_inp x d_model
        if self.use_resdrop:
            if self.combine_as_self:
                outputs = self.residual_dropout(outputs, batch_idxs)
            else:
                # apply residual dropout per label slice
                outputs = torch.cat([self.residual_dropout(outputs[:,i,:], batch_idxs).unsqueeze(1) for i in range(self.d_l)], 1)
        if self.combine_as_self:
            outputs = self.layer_norm(outputs + inp)
        else:
            # add the residual to every label's representation
            for l in range(self.d_l):
                outputs[:, l, :] = outputs[:, l, :] + inp

            outputs = self.reduce_proj(outputs) # len_inp x d_l x d_proj
            outputs = self.layer_norm(outputs) # len_inp x d_l x d_proj
            outputs = outputs.view(len_inp, -1).contiguous() # len_inp x (d_l * d_proj)

        return outputs, attns_padded
616
+
617
+
618
+ #
619
class LabelAttentionModule(nn.Module):
    """
    Label Attention Module for label-specific representations
    The module can be used right after the Partitioned Attention, or it can be experimented with for the transition stack
    """
    #
    def __init__(self,
                 d_model,
                 d_input_proj,
                 d_k,
                 d_v,
                 d_l,
                 d_proj,
                 combine_as_self,
                 use_resdrop=True,
                 q_as_matrix=False,
                 residual_dropout=0.1,
                 attention_dropout=0.1,
                 d_positional=None,
                 d_ff=2048,
                 relu_dropout=0.2,
                 lattn_partitioned=True):
        """
        d_model: width of the incoming word embeddings
        d_input_proj: if set, project the content part of the input down to
            this width before attention (positional features pass through)
        d_k, d_v, d_l, d_proj, combine_as_self, use_resdrop, q_as_matrix,
        residual_dropout, attention_dropout: forwarded to LabelAttention
        d_positional / lattn_partitioned: partitioning of the feature space;
            when not partitioned, the positional width is treated as 0
        d_ff, relu_dropout: settings for the trailing feed-forward layer
        """
        super().__init__()
        # width of the concatenated per-label outputs
        self.ff_dim = d_proj * d_l

        if not lattn_partitioned:
            self.d_positional = 0
        else:
            self.d_positional = d_positional if d_positional else 0

        if d_input_proj:
            if d_input_proj <= self.d_positional:
                raise ValueError("Illegal argument for d_input_proj: d_input_proj %d is smaller than d_positional %d" % (d_input_proj, self.d_positional))
            # only the content portion is projected; positional stays intact
            self.input_projection = nn.Linear(d_model - self.d_positional, d_input_proj - self.d_positional, bias=False)
            d_input = d_input_proj
        else:
            self.input_projection = None
            d_input = d_model

        self.label_attention = LabelAttention(d_input,
                                              d_k,
                                              d_v,
                                              d_l,
                                              d_proj,
                                              combine_as_self,
                                              use_resdrop,
                                              q_as_matrix,
                                              residual_dropout,
                                              attention_dropout,
                                              self.d_positional)

        if not lattn_partitioned:
            self.lal_ff = PositionwiseFeedForward(self.ff_dim,
                                                  d_ff,
                                                  relu_dropout,
                                                  residual_dropout)
        else:
            self.lal_ff = PartitionedPositionwiseFeedForward(self.ff_dim,
                                                             d_ff,
                                                             self.d_positional,
                                                             relu_dropout,
                                                             residual_dropout)

    def forward(self, word_embeddings, tagged_word_lists):
        """
        word_embeddings: list of per-sentence tensors (words x features)
        tagged_word_lists: unused here -- kept for interface compatibility
            with other modules called the same way

        Returns a list (one entry per sentence) of stacked label-specific
        representations produced by the label attention + feed-forward stack.
        """
        if self.input_projection:
            if self.d_positional > 0:
                # project only the content features, keep positional tail
                word_embeddings = [torch.cat((self.input_projection(sentence[:, :-self.d_positional]),
                                              sentence[:, -self.d_positional:]), dim=1)
                                   for sentence in word_embeddings]
            else:
                word_embeddings = [self.input_projection(sentence) for sentence in word_embeddings]
        # Extract Labeled Representation
        packed_len = sum(sentence.shape[0] for sentence in word_embeddings)
        # map each packed word position back to its sentence index
        batch_idxs = np.zeros(packed_len, dtype=int)

        batch_size = len(word_embeddings)
        i = 0

        sentence_lengths = [0] * batch_size
        for sentence_idx, sentence in enumerate(word_embeddings):
            sentence_lengths[sentence_idx] = len(sentence)
            for word in sentence:
                batch_idxs[i] = sentence_idx
                i += 1

        # keep the raw numpy mapping before wrapping it in BatchIndices
        batch_indices = batch_idxs
        batch_idxs = BatchIndices(batch_idxs, word_embeddings[0].device)

        # repack all sentences into one flat sequence of word embeddings
        new_embeds = []
        for sentence_idx, batch in enumerate(word_embeddings):
            for word_idx, embed in enumerate(batch):
                if word_idx < sentence_lengths[sentence_idx]:
                    new_embeds.append(embed)

        new_word_embeddings = torch.stack(new_embeds)

        labeled_representations, _ = self.label_attention(new_word_embeddings, batch_idxs)
        labeled_representations = self.lal_ff(labeled_representations, batch_idxs)
        final_labeled_representations = [[] for i in range(batch_size)]

        # scatter the packed outputs back to per-sentence lists
        for idx, embed in enumerate(labeled_representations):
            final_labeled_representations[batch_indices[idx]].append(embed)

        for idx, representation in enumerate(final_labeled_representations):
            final_labeled_representations[idx] = torch.stack(representation)

        return final_labeled_representations
726
+
stanza/stanza/models/constituency/lstm_tree_stack.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Keeps an LSTM in TreeStack form.
3
+
4
+ The TreeStack nodes keep the hx and cx for the LSTM, along with a
5
+ "value" which represents whatever the user needs to store.
6
+
7
+ The TreeStacks can be popped to get back to the previous LSTM state.
8
+
9
+ The module itself implements three methods: initial_state, push_states, output
10
+ """
11
+
12
+ from collections import namedtuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from stanza.models.constituency.tree_stack import TreeStack
18
+
19
+ Node = namedtuple("Node", ['value', 'lstm_hx', 'lstm_cx'])
20
+
21
class LSTMTreeStack(nn.Module):
    """
    Maintains an LSTM state inside a persistent TreeStack.

    Each stack node stores a user value alongside the LSTM hx/cx at that
    point, so popping a node restores the previous LSTM state for free.
    """
    def __init__(self, input_size, hidden_size, num_lstm_layers, dropout, uses_boundary_vector, input_dropout):
        """
        Prepare LSTM and parameters

        input_size: dimension of the inputs to the LSTM
        hidden_size: LSTM internal & output dimension
        num_lstm_layers: how many layers of LSTM to use
        dropout: value of the LSTM dropout
        uses_boundary_vector: if set, learn a start_embedding parameter. otherwise, use zeros
        input_dropout: an nn.Module to dropout inputs. TODO: allow a float parameter as well
        """
        super().__init__()

        self.uses_boundary_vector = uses_boundary_vector

        # The start embedding needs to be input_size as we put it through the LSTM
        if uses_boundary_vector:
            self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
        else:
            # registered as buffers so they follow the module across devices
            self.register_buffer('input_zeros', torch.zeros(num_lstm_layers, 1, input_size))
            self.register_buffer('hidden_zeros', torch.zeros(num_lstm_layers, 1, hidden_size))

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_lstm_layers, dropout=dropout)
        self.input_dropout = input_dropout


    def initial_state(self, initial_value=None):
        """
        Return an initial state, either based on zeros or based on the initial embedding and LSTM

        Note that LSTM start operation is already batched, in a sense
        The subsequent batch built this way will be used for batch_size trees

        Returns a stack with None value, hx & cx either based on the
        start_embedding or zeros, and no parent.
        """
        if self.uses_boundary_vector:
            # run the learned boundary embedding through the LSTM once to
            # produce the starting hidden/cell state
            start = self.start_embedding.unsqueeze(0).unsqueeze(0)
            output, (hx, cx) = self.lstm(start)
            start = output[0, 0, :]
        else:
            start = self.input_zeros
            hx = self.hidden_zeros
            cx = self.hidden_zeros
        return TreeStack(value=Node(initial_value, hx, cx), parent=None, length=1)

    def push_states(self, stacks, values, inputs):
        """
        Starting from a list of current stacks, put the inputs through the LSTM and build new stack nodes.

        B = stacks.len() = values.len()

        inputs must be of shape 1 x B x input_size
        """
        inputs = self.input_dropout(inputs)

        # gather the per-stack LSTM states into one batched state
        hx = torch.cat([t.value.lstm_hx for t in stacks], axis=1)
        cx = torch.cat([t.value.lstm_cx for t in stacks], axis=1)
        output, (hx, cx) = self.lstm(inputs, (hx, cx))
        # split the batched state back out, one slice per stack
        new_stacks = [stack.push(Node(transition, hx[:, i:i+1, :], cx[:, i:i+1, :]))
                      for i, (stack, transition) in enumerate(zip(stacks, values))]
        return new_stacks

    def output(self, stack):
        """
        Return the last layer of the lstm_hx as the output from a stack

        Refactored so that alternate structures have an easy way of getting the output
        """
        return stack.value.lstm_hx[-1, 0, :]
stanza/stanza/models/constituency/score_converted_dependencies.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script which processes a dependency file by using the constituency parser, then converting with the CoreNLP converter
3
+
4
+ Currently this does not have the constituency parser as an option,
5
+ although that is easy to add.
6
+
7
+ Only English is supported, as only English is available in the CoreNLP converter
8
+ """
9
+
10
+ import argparse
11
+ import os
12
+ import tempfile
13
+
14
+ import stanza
15
+ from stanza.models.constituency import retagging
16
+ from stanza.models.depparse import scorer
17
+ from stanza.utils.conll import CoNLL
18
+
19
def score_converted_dependencies(args):
    """
    Parse args['eval_file'] with a constituency parser, convert the trees to
    dependencies via the CoreNLP converter, and score against the gold file.

    Only English is supported, since only English is available in the converter.
    """
    if args['lang'] != 'en':
        raise ValueError("Converting and scoring dependencies is currently only supported for English")

    constituency_package = args['constituency_package']
    pipeline = stanza.Pipeline(lang=args['lang'],
                               tokenize_pretokenized=True,
                               package={'pos': args['retag_package'], 'depparse': 'converter', 'constituency': constituency_package},
                               processors='tokenize, pos, constituency, depparse')

    gold_doc = CoNLL.conll2doc(args['eval_file'])
    predicted_doc = pipeline(gold_doc)
    print("Processed %d sentences" % len(predicted_doc.sentences))
    # reload - the pipeline clobbered the gold values
    gold_doc = CoNLL.conll2doc(args['eval_file'])

    scorer.score_named_dependencies(predicted_doc, gold_doc)
    with tempfile.TemporaryDirectory() as tempdir:
        converted_path = os.path.join(tempdir, "converted.conll")
        CoNLL.write_doc2conll(predicted_doc, converted_path)
        _, _, score = scorer.score(converted_path, args['eval_file'])

    print("Parser score:")
    print("{} {:.2f}".format(constituency_package, score*100))
46
+
47
+
48
def main():
    """Command-line entry point: build args, then convert and score."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--lang', default='en', type=str, help='Language')
    parser.add_argument('--eval_file', default="extern_data/ud2/ud-treebanks-v2.13/UD_English-EWT/en_ewt-ud-test.conllu", help='Input file for data loader.')
    parser.add_argument('--constituency_package', default="ptb3-revised_electra-large", help='Which constituency parser to use for converting')
    retagging.add_retag_args(parser)

    args = vars(parser.parse_args())
    retagging.postprocess_args(args)

    score_converted_dependencies(args)
62
+
63
+ if __name__ == '__main__':
64
+ main()
65
+
stanza/stanza/models/constituency/text_processing.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import logging
4
+
5
+ from stanza.models.common import utils
6
+ from stanza.models.constituency.utils import retag_tags
7
+ from stanza.models.constituency.trainer import Trainer
8
+ from stanza.models.constituency.tree_reader import read_trees
9
+ from stanza.utils.get_tqdm import get_tqdm
10
+
11
+ logger = logging.getLogger('stanza')
12
+ tqdm = get_tqdm()
13
+
14
def read_tokenized_file(tokenized_file):
    """
    Read one pretokenized sentence per non-blank line of a file.

    Words composed entirely of underscores are kept as-is; in any other word,
    underscores become spaces (used for languages such as VI where syllables
    of a word are joined with _).

    Returns (sentences, ids): each sentence is a list of words, and ids is a
    matching list of None placeholders.
    """
    with open(tokenized_file, encoding='utf-8') as fin:
        stripped = [line.strip() for line in fin]

    docs = []
    for sentence in stripped:
        if not sentence:
            continue
        words = []
        for word in sentence.split():
            if all(ch == '_' for ch in word):
                words.append(word)
            else:
                words.append(word.replace("_", " "))
        docs.append(words)
    return docs, [None] * len(docs)
25
+
26
def read_xml_tree_file(tree_file):
    """
    Read sentences from a file of the format unique to VLSP test sets

    in particular, it should be multiple blocks of

    <s id=1>
    (tree ...)
    </s>

    Returns (docs, ids): for each block, docs holds the tree's leaf words
    (with _ converted to space, except words that are all underscores) and
    ids holds the integer id from the <s> tag, or None if there was none.
    """
    with open(tree_file, encoding='utf-8') as fin:
        lines = fin.readlines()
    lines = [x.strip() for x in lines]
    lines = [x for x in lines if x]
    docs = []
    ids = []
    # state while scanning: id of the current block and its accumulated lines
    tree_id = None
    tree_text = []
    for line in lines:
        if line.startswith("<s"):
            # pull the id out of a tag such as <s id=1>
            tree_id = line.split("=")
            if len(tree_id) > 1:
                tree_id = tree_id[1]
                if tree_id.endswith(">"):
                    tree_id = tree_id[:-1]
                tree_id = int(tree_id)
            else:
                tree_id = None
        elif line.startswith("</s"):
            # end of a block: parse the collected lines as one tree
            if len(tree_text) == 0:
                raise ValueError("Found a blank tree in %s" % tree_file)
            ids.append(tree_id)
            tree_text = "\n".join(tree_text)
            trees = read_trees(tree_text)
            # TODO: perhaps the processing can be put into read_trees instead
            trees = [t.prune_none().simplify_labels() for t in trees]
            if len(trees) != 1:
                raise ValueError("Found a tree with %d trees in %s" % (len(trees), tree_file))
            tree = trees[0]
            text = tree.leaf_labels()
            text = [word if all(x == '_' for x in word) else word.replace("_", " ") for word in text]
            docs.append(text)
            tree_text = []
            tree_id = None
        else:
            tree_text.append(line)

    return docs, ids
74
+
75
+
76
def parse_tokenized_sentences(args, model, retag_pipeline, sentences):
    """
    Retag the given tokenized sentences, then parse them.

    Returns a list of ParseResult objects, one per input sentence.
    """
    tags = retag_tags(sentences, retag_pipeline, model.uses_xpos())
    tagged_sentences = [list(zip(s_words, s_tags)) for s_words, s_tags in zip(sentences, tags)]
    logger.info("Retagging finished. Parsing tagged text")
    assert len(tagged_sentences) == len(sentences)

    return model.parse_sentences_no_grad(iter(tqdm(tagged_sentences)),
                                         model.build_batch_from_tagged_words,
                                         args['eval_batch_size'],
                                         model.predict,
                                         keep_scores=False)
87
+
88
def parse_text(args, model, retag_pipeline, tokenized_file=None, predict_file=None):
    """
    Use the given model to parse text and write it

    refactored so it can be used elsewhere, such as Ensemble

    tokenized_file / predict_file override args['tokenized_file'] and
    args['predict_file'] when given.  Input can alternatively come from
    args['xml_tree_file'].  Output goes to predict_file, or stdout when
    no output file is configured (via utils.output_stream).
    """
    model.eval()

    if predict_file is None:
        if args['predict_file']:
            predict_file = args['predict_file']
            if args['predict_dir']:
                predict_file = os.path.join(args['predict_dir'], predict_file)

    if tokenized_file is None:
        tokenized_file = args['tokenized_file']

    docs, ids = None, None
    if tokenized_file is not None:
        docs, ids = read_tokenized_file(tokenized_file)
    elif args['xml_tree_file']:
        logger.info("Reading trees from %s" % args['xml_tree_file'])
        docs, ids = read_xml_tree_file(args['xml_tree_file'])

    if not docs:
        logger.error("No sentences to process!")
        return

    logger.info("Processing %d sentences", len(docs))

    with utils.output_stream(predict_file) as fout:
        # parse in chunks so output is produced incrementally for large inputs
        chunk_size = 10000
        for chunk_start in range(0, len(docs), chunk_size):
            chunk = docs[chunk_start:chunk_start+chunk_size]
            ids_chunk = ids[chunk_start:chunk_start+chunk_size]
            logger.info("Processing trees %d to %d", chunk_start, chunk_start+len(chunk))
            treebank = parse_tokenized_sentences(args, model, retag_pipeline, chunk)

            for result, tree_id in zip(treebank, ids_chunk):
                # keep the best prediction for each sentence
                tree = result.predictions[0].tree
                if tree_id is not None:
                    tree.tree_id = tree_id
                fout.write(args['predict_format'].format(tree))
                fout.write("\n")
132
+
133
def parse_dir(args, model, retag_pipeline, tokenized_dir, predict_dir):
    """
    Parse every file in tokenized_dir, writing one .mrg file per input into predict_dir.
    """
    os.makedirs(predict_dir, exist_ok=True)
    for filename in os.listdir(tokenized_dir):
        input_path = os.path.join(tokenized_dir, filename)
        base_name, _ = os.path.splitext(filename)
        output_path = os.path.join(predict_dir, base_name + ".mrg")
        logger.info("Processing %s to %s", input_path, output_path)
        parse_text(args, model, retag_pipeline, tokenized_file=input_path, predict_file=output_path)
140
+
141
+
142
def load_model_parse_text(args, model_file, retag_pipeline):
    """
    Load a model, then parse text and write it to stdout or args['predict_file']

    args: dict of options (pretrain/charlm paths, device, input/output settings)
    model_file: path to the saved constituency model
    retag_pipeline: a list of Pipeline meant to use for retagging; may be empty

    Raises ValueError if --tokenized_dir is given without --predict_dir.
    """
    # Imported at function scope: FoundationCache was previously referenced
    # without any import, which raised NameError whenever retag_pipeline was
    # empty.  TODO confirm this module path against the rest of the package.
    from stanza.models.common.foundation_cache import FoundationCache

    # reuse the retag pipeline's cache when available so the pretrained
    # resources are only loaded once
    foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
    load_args = {
        'wordvec_pretrain_file': args['wordvec_pretrain_file'],
        'charlm_forward_file': args['charlm_forward_file'],
        'charlm_backward_file': args['charlm_backward_file'],
        'device': args['device'],
    }
    trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)
    model = trainer.model
    model.eval()
    logger.info("Loaded model from %s", model_file)

    if args['tokenized_dir']:
        if not args['predict_dir']:
            raise ValueError("Must specify --predict_dir to go with --tokenized_dir")
        parse_dir(args, model, retag_pipeline, args['tokenized_dir'], args['predict_dir'])
    else:
        parse_text(args, model, retag_pipeline)
166
+
stanza/stanza/models/constituency/tree_reader.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reads ParseTree objects from a file, string, or similar input
3
+
4
+ Works by first splitting the input into (, ), and all other tokens,
5
+ then recursively processing those tokens into trees.
6
+ """
7
+
8
+ from collections import deque
9
+ import logging
10
+ import os
11
+ import re
12
+
13
+ from stanza.models.constituency.parse_tree import Tree
14
+ from stanza.utils.get_tqdm import get_tqdm
15
+
16
+ tqdm = get_tqdm()
17
+
18
+ OPEN_PAREN = "("
19
+ CLOSE_PAREN = ")"
20
+
21
+ logger = logging.getLogger('stanza.constituency')
22
+
23
+ # A few specific exception types to clarify parsing errors
24
+ # They store the line number where the error occurred
25
+
26
class UnclosedTreeError(ValueError):
    """
    Raised when a tree is missing close brackets, e.g.  (Foo
    """
    def __init__(self, line_num):
        message = "Found an unfinished tree (missing close brackets). Tree started on line %d" % line_num
        super().__init__(message)
        # line where the broken tree began, for error reporting
        self.line_num = line_num
33
+
34
class ExtraCloseTreeError(ValueError):
    """
    Raised when a tree has extra close brackets, e.g.  (Foo))
    """
    def __init__(self, line_num):
        message = "Found a broken tree (extra close brackets). Tree started on line %d" % line_num
        super().__init__(message)
        # line where the broken tree began, for error reporting
        self.line_num = line_num
41
+
42
class UnlabeledTreeError(ValueError):
    """
    Raised when a tree has a node with no label, such as ((Foo) (Bar))

    This does not actually happen at the root, btw, as ROOT is silently added
    """
    def __init__(self, line_num):
        message = "Found a tree with no label on a node! Line number %d" % line_num
        super().__init__(message)
        # line where the unlabeled node was found
        self.line_num = line_num
51
+
52
class MixedTreeError(ValueError):
    """
    Raised when leaf and constituent children are mixed in the same node
    """
    def __init__(self, line_num, child_label, children):
        message = "Found a tree with both text children and bracketed children! Line number {} Child label {} Children {}".format(line_num, child_label, children)
        super().__init__(message)
        # context kept for callers which want to recover or report details
        self.line_num = line_num
        self.child_label = child_label
        self.children = children
61
+
62
def normalize(text):
    """Convert PTB bracket escapes -LRB- / -RRB- back to literal parentheses."""
    for escape, char in (("-LRB-", "("), ("-RRB-", ")")):
        text = text.replace(escape, char)
    return text
64
+
65
def read_single_tree(token_iterator, broken_ok):
    """
    Build a tree from the tokens in the token_iterator

    token_iterator: a TokenIterator positioned just after an open paren
    broken_ok: if True, recover from unlabeled or mixed nodes instead of raising

    Returns the finished Tree (with a ROOT silently added if the outermost
    bracket had no label).  Raises UnclosedTreeError, UnlabeledTreeError,
    or MixedTreeError for malformed input.
    """
    # we were called here at a open paren, so start the stack of
    # children with one empty list already on it
    children_stack = deque()
    children_stack.append([])
    text_stack = deque()
    text_stack.append([])

    token = next(token_iterator, None)
    # remember which line this tree started on, for error messages
    token_iterator.set_mark()
    while token is not None:
        if token == OPEN_PAREN:
            # new bracketed node: start fresh child / text accumulators
            children_stack.append([])
            text_stack.append([])
        elif token == CLOSE_PAREN:
            # finish the current node using whatever text and children accumulated
            text = text_stack.pop()
            children = children_stack.pop()
            if text:
                pieces = " ".join(text).split()
                if len(pieces) == 1:
                    # label only, e.g. "(NP ...)" -- attach the bracketed children
                    child = Tree(pieces[0], children)
                else:
                    # the assumption here is that a language such as VI may
                    # have spaces in the words, but it still represents
                    # just one child
                    label = pieces[0]
                    child_label = " ".join(pieces[1:])
                    if children:
                        if broken_ok:
                            # text and brackets mixed: keep both, text last
                            child = Tree(label, children + [Tree(normalize(child_label))])
                        else:
                            raise MixedTreeError(token_iterator.line_num, child_label, children)
                    else:
                        # NOTE(review): passes a single Tree rather than a list --
                        # presumably the Tree constructor accepts either; confirm
                        child = Tree(label, Tree(normalize(child_label)))
                if not children_stack:
                    # closed the outermost bracket: the tree is complete
                    return child
            else:
                if not children_stack:
                    # outermost bracket had no label: silently add ROOT
                    return Tree("ROOT", children)
                elif broken_ok:
                    # tolerate a labelless interior node
                    child = Tree(None, children)
                else:
                    raise UnlabeledTreeError(token_iterator.line_num)
            children_stack[-1].append(child)
        else:
            # plain text token: accumulate for the current node
            text_stack[-1].append(token)
        token = next(token_iterator, None)
    # ran out of tokens with brackets still open
    raise UnclosedTreeError(token_iterator.get_mark())
116
+
117
+ LINE_SPLIT_RE = re.compile(r"([()])")
118
+
119
+
120
class TokenIterator:
    """
    A specific iterator for reading trees from a tree file

    The idea is that this will keep track of which line
    we are processing, so that an error can be logged
    from the correct line

    Subclasses are responsible for setting self.line_iterator (an iterator
    of raw text lines) before iteration begins -- see TextTokenIterator
    and FileTokenIterator.
    """
    def __init__(self):
        # current per-line token stream; starts empty so the first __next__
        # immediately pulls a line from line_iterator
        self.token_iterator = iter([])
        # 0-based line counter, -1 before the first line is read
        self.line_num = -1
        # line number remembered via set_mark(), used in error messages
        self.mark = None

    def set_mark(self):
        """
        The mark is used for determining where the start of a tree occurs for an error
        """
        self.mark = self.line_num

    def get_mark(self):
        if self.mark is None:
            raise ValueError("No mark set!")
        return self.mark

    def __iter__(self):
        return self

    def __next__(self):
        # exhaust the tokens of the current line, then advance to the
        # next non-empty line and split it into (, ), and text tokens
        n = next(self.token_iterator, None)
        while n is None:
            self.line_num = self.line_num + 1
            # NOTE(review): a plain file/list iterator signals exhaustion by
            # raising StopIteration; the None check below only matters for
            # line iterators which can yield None -- confirm intent
            line = next(self.line_iterator)
            if line is None:
                raise StopIteration
            line = line.strip()
            if not line:
                continue

            # split around parens, keeping them as their own tokens
            pieces = LINE_SPLIT_RE.split(line)
            pieces = [x.strip() for x in pieces]
            pieces = [x for x in pieces if x]
            self.token_iterator = iter(pieces)
            n = next(self.token_iterator, None)

        return n
165
+
166
+
167
class TextTokenIterator(TokenIterator):
    """TokenIterator which reads from an in-memory string, one line per newline split."""
    def __init__(self, text, use_tqdm=True):
        super().__init__()

        self.lines = text.split("\n")
        self.num_lines = len(self.lines)
        # only show a progress bar for large inputs
        show_progress = use_tqdm and self.num_lines > 1000
        source = tqdm(self.lines) if show_progress else self.lines
        self.line_iterator = iter(source)
177
+
178
+
179
class FileTokenIterator(TokenIterator):
    """
    TokenIterator which reads lines from a file; use as a context manager.

    Counts the lines first so that tqdm can display a total for large files.
    """
    def __init__(self, filename):
        super().__init__()
        self.filename = filename
        # set in __enter__; initialized here so __exit__ is safe even if
        # __enter__ was never called or failed before opening the file
        self.file_obj = None

    def __enter__(self):
        # TODO: use the file_size instead of counting the lines
        # file_size = Path(self.filename).stat().st_size
        with open(self.filename) as fin:
            num_lines = sum(1 for _ in fin)

        self.file_obj = open(self.filename)
        if num_lines > 1000:
            self.line_iterator = iter(tqdm(self.file_obj, total=num_lines))
        else:
            self.line_iterator = iter(self.file_obj)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self.file_obj:
            self.file_obj.close()
            # drop the handle so a second __exit__ is harmless
            self.file_obj = None
200
+
201
def read_token_iterator(token_iterator, broken_ok, tree_callback):
    """
    Read every tree available from token_iterator.

    broken_ok: passed through to read_single_tree
    tree_callback: optional transform applied to each tree; trees for which
      it returns None are dropped

    Returns a list of trees.  Raises on stray close parens or bare text
    between trees.
    """
    trees = []
    token = next(token_iterator, None)
    while token:
        if token == CLOSE_PAREN:
            raise ExtraCloseTreeError(token_iterator.line_num)
        if token != OPEN_PAREN:
            raise ValueError("Tree document had text between trees! Line number %d" % token_iterator.line_num)
        next_tree = read_single_tree(token_iterator, broken_ok=broken_ok)
        if next_tree is None:
            raise ValueError("Tree reader somehow created a None tree! Line number %d" % token_iterator.line_num)
        if tree_callback is None:
            trees.append(next_tree)
        else:
            transformed = tree_callback(next_tree)
            if transformed is not None:
                trees.append(transformed)
        token = next(token_iterator, None)

    return trees
222
+
223
+
224
def read_trees(text, broken_ok=False, tree_callback=None, use_tqdm=True):
    """
    Reads multiple trees from the text

    TODO: some of the error cases we hit can be recovered from
    """
    return read_token_iterator(TextTokenIterator(text, use_tqdm),
                               broken_ok=broken_ok,
                               tree_callback=tree_callback)
232
+
233
def read_tree_file(filename, broken_ok=False, tree_callback=None):
    """
    Read all of the trees in the given file
    """
    with FileTokenIterator(filename) as token_iterator:
        return read_token_iterator(token_iterator,
                                   broken_ok=broken_ok,
                                   tree_callback=tree_callback)
240
+
241
def read_directory(dirname, broken_ok=False, tree_callback=None):
    """
    Read all of the trees in all of the files in a directory

    Files are processed in sorted filename order.
    """
    all_trees = []
    for entry in sorted(os.listdir(dirname)):
        path = os.path.join(dirname, entry)
        all_trees.extend(read_tree_file(path, broken_ok, tree_callback))
    return all_trees
250
+
251
def read_treebank(filename, tree_callback=None):
    """
    Read a treebank and alter the trees to be a simpler format for learning to parse
    """
    logger.info("Reading trees from %s", filename)
    raw_trees = read_tree_file(filename, tree_callback=tree_callback)
    trees = [tree.prune_none().simplify_labels() for tree in raw_trees]

    # the parser assumes a unary transition at the ROOT
    bad_trees = [tree for tree in trees if len(tree.children) > 1]
    if bad_trees:
        raise ValueError("Found {} tree(s) which had non-unary transitions at the ROOT. First illegal tree: {:P}".format(len(bad_trees), bad_trees[0]))

    return trees
264
+
265
def main():
    """
    Reads a sample tree and prints the result
    """
    sample = "( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    print(read_trees(sample))
272
+
273
+ if __name__ == '__main__':
274
+ main()
stanza/stanza/models/constituency/tree_stack.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A utility class for keeping track of intermediate parse states
3
+ """
4
+
5
+ from collections import namedtuple
6
+
7
class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])):
    """
    An immutable, persistent stack node: (value, parent, length).

    Because nodes never change, many stacks can share a common tail.
    That lets a beam-search parser branch several ways from the same
    base stack, and lets oracle training explore both a gold and a
    non-gold transition from one state without copying anything.

    An example usage is when K constituents are removed at once to
    create a new constituent: the LSTM tracking constituent values can
    resume from the Kth output with the new value.

    value can be a transition, a word, or a partially built constituent.
    Implemented as a namedtuple to make it a bit more efficient.
    """
    def pop(self):
        """Return the stack with the top element removed."""
        return self.parent

    def push(self, value):
        """Return a new stack node holding value, whose parent is this node."""
        return TreeStack(value, self, self.length + 1)

    def __iter__(self):
        """Yield values from the top of the stack down to the bottom."""
        node = self
        while node.parent is not None:
            yield node.value
            node = node.parent
        yield node.value

    def __reversed__(self):
        """Yield values from the bottom of the stack up to the top."""
        yield from reversed(list(self))

    def __str__(self):
        return "TreeStack(%s)" % ", ".join(str(item) for item in self)

    def __len__(self):
        return self.length
stanza/stanza/models/constituency/utils.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Collects a few of the conparser utility methods which don't belong elsewhere
3
+ """
4
+
5
+ from collections import Counter
6
+ import logging
7
+ import warnings
8
+
9
+ import torch.nn as nn
10
+ from torch import optim
11
+
12
+ from stanza.models.common.doc import TEXT, Document
13
+ from stanza.models.common.utils import get_optimizer
14
+ from stanza.models.constituency.base_model import SimpleModel
15
+ from stanza.models.constituency.parse_transitions import TransitionScheme
16
+ from stanza.models.constituency.parse_tree import Tree
17
+ from stanza.utils.get_tqdm import get_tqdm
18
+
19
+ tqdm = get_tqdm()
20
+
21
+ DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0000007 , "mirror_madgrad": 0.00005 }
22
+ DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
23
+ DEFAULT_LEARNING_RHO = 0.9
24
+ DEFAULT_MOMENTUM = { "madgrad": 0.9, "mirror_madgrad": 0.9, "sgd": 0.9 }
25
+
26
+ tlogger = logging.getLogger('stanza.constituency.trainer')
27
+
28
+ # madgrad experiment for weight decay
29
+ # with learning_rate set to 0.0000007 and momentum 0.9
30
+ # on en_wsj, with a baseline model trained on adadela for 200,
31
+ # then madgrad used to further improve that model
32
+ # 0.00000002.out: 0.9590347746438835
33
+ # 0.00000005.out: 0.9591378819960182
34
+ # 0.0000001.out: 0.9595450596319405
35
+ # 0.0000002.out: 0.9594603134479271
36
+ # 0.0000005.out: 0.9591317672706594
37
+ # 0.000001.out: 0.9592548741021389
38
+ # 0.000002.out: 0.9598395477013945
39
+ # 0.000003.out: 0.9594974271553495
40
+ # 0.000004.out: 0.9596665982603754
41
+ # 0.000005.out: 0.9591620720706487
42
+ DEFAULT_WEIGHT_DECAY = { "adamw": 0.05, "adadelta": 0.02, "sgd": 0.01, "adabelief": 1.2e-6, "madgrad": 2e-6, "mirror_madgrad": 2e-6 }
43
+
44
def retag_tags(doc, pipelines, xpos):
    """
    Returns a list of list of tags for the items in doc

    doc can be anything which feeds into the pipeline(s)
    pipelines are a list of 1 or more retag pipelines
    if multiple pipelines are given, majority vote wins
    xpos: if True use xpos tags, otherwise upos
    """
    tag_lists = []
    for pipeline in pipelines:
        # NOTE(review): each pipeline's output Document is fed to the next
        # pipeline -- presumably they share tokenization; confirm
        doc = pipeline(doc)
        tag_lists.append([[x.xpos if xpos else x.upos for x in sentence.words] for sentence in doc.sentences])
    # tag_lists: for N pipeline, S sentences
    # we now have N lists of S sentences each
    # for sentence in zip(*tag_lists): N lists of |s| tags for this given sentence s
    # for tag in zip(*sentence): N predicted tags.
    # most common one in the Counter will be chosen
    tag_lists = [[Counter(tag).most_common(1)[0][0] for tag in zip(*sentence)]
                 for sentence in zip(*tag_lists)]
    return tag_lists
64
+
65
def retag_trees(trees, pipelines, xpos=True):
    """
    Retag all of the trees using the given processor

    trees: list of parse trees whose preterminal tags will be replaced
    pipelines: retag pipelines, majority vote if more than one
    xpos: use xpos tags if True, otherwise upos

    Returns a list of new trees
    """
    if len(trees) == 0:
        return trees

    new_trees = []
    # process in chunks so the pipeline sees batches of bounded size
    chunk_size = 1000
    with tqdm(total=len(trees)) as pbar:
        for chunk_start in range(0, len(trees), chunk_size):
            chunk_end = min(chunk_start + chunk_size, len(trees))
            chunk = trees[chunk_start:chunk_end]
            sentences = []
            try:
                # one pretokenized sentence per tree, words taken from the leaves
                for idx, tree in enumerate(chunk):
                    tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()]
                    sentences.append(tokens)
            except ValueError as e:
                raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e

            doc = Document(sentences)
            tag_lists = retag_tags(doc, pipelines, xpos)

            for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)):
                try:
                    if any(tag is None for tag in tags):
                        raise RuntimeError("Tagged tree #{} with a None tag!\n{}\n{}".format(tree_idx, tree, tags))
                    new_tree = tree.replace_tags(tags)
                    new_trees.append(new_tree)
                    pbar.update(1)
                except ValueError as e:
                    raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e
    # sanity check: every input tree must have produced exactly one output tree
    if len(new_trees) != len(trees):
        raise AssertionError("Retagged tree counts did not match: {} vs {}".format(len(new_trees), len(trees)))
    return new_trees
103
+
104
+
105
+ # experimental results on nonlinearities
106
+ # this is on a VI dataset, VLSP_22, using 1/10th of the data as a dev set
107
+ # (no released test set at the time of the experiment)
108
+ # original non-Bert tagger, with 1 iteration each instead of averaged over 5
109
+ # considering the number of experiments and the length of time they would take
110
+ #
111
+ # Gelu had the highest score, which tracks with other experiments run.
112
+ # Note that publicly released models have typically used Relu
113
+ # on account of the runtime speed improvement
114
+ #
115
+ # Anyway, a larger experiment of 5x models on gelu or relu, using the
116
+ # Roberta POS tagger and a corpus of silver trees, resulted in 0.8270
117
+ # for relu and 0.8248 for gelu. So it is not even clear that
118
+ # switching to gelu would be an accuracy improvement.
119
+ #
120
+ # Gelu: 82.32
121
+ # Relu: 82.14
122
+ # Mish: 81.95
123
+ # Relu6: 81.91
124
+ # Silu: 81.90
125
+ # ELU: 81.73
126
+ # Hardswish: 81.67
127
+ # Softsign: 81.63
128
+ # Hardtanh: 81.44
129
+ # Celu: 81.43
130
+ # Selu: 81.17
131
+ # TODO: need to redo the prelu experiment with
132
+ # possibly different numbers of parameters
133
+ # and proper weight decay
134
+ # Prelu: 80.95 (terminated early)
135
+ # Softplus: 80.94
136
+ # Logsigmoid: 80.91
137
+ # Hardsigmoid: 79.03
138
+ # RReLU: 77.00
139
+ # Hardshrink: failed
140
+ # Softshrink: failed
141
+ NONLINEARITY = {
142
+ 'celu': nn.CELU,
143
+ 'elu': nn.ELU,
144
+ 'gelu': nn.GELU,
145
+ 'hardshrink': nn.Hardshrink,
146
+ 'hardtanh': nn.Hardtanh,
147
+ 'leaky_relu': nn.LeakyReLU,
148
+ 'logsigmoid': nn.LogSigmoid,
149
+ 'prelu': nn.PReLU,
150
+ 'relu': nn.ReLU,
151
+ 'relu6': nn.ReLU6,
152
+ 'rrelu': nn.RReLU,
153
+ 'selu': nn.SELU,
154
+ 'softplus': nn.Softplus,
155
+ 'softshrink': nn.Softshrink,
156
+ 'softsign': nn.Softsign,
157
+ 'tanhshrink': nn.Tanhshrink,
158
+ 'tanh': nn.Tanh,
159
+ }
160
+
161
+ # separating these out allows for backwards compatibility with earlier versions of pytorch
162
+ # NOTE torch compatibility: if we ever *release* models with these
163
+ # activation functions, we will need to break that compatibility
164
+
165
+ nonlinearity_list = [
166
+ 'GLU',
167
+ 'Hardsigmoid',
168
+ 'Hardswish',
169
+ 'Mish',
170
+ 'SiLU',
171
+ ]
172
+
173
+ for nonlinearity in nonlinearity_list:
174
+ if hasattr(nn, nonlinearity):
175
+ NONLINEARITY[nonlinearity.lower()] = getattr(nn, nonlinearity)
176
+
177
def build_nonlinearity(nonlinearity):
    """
    Look up "nonlinearity" in a map from function name to function, build the appropriate layer.

    Raises ValueError for names not present in NONLINEARITY.
    """
    try:
        layer_type = NONLINEARITY[nonlinearity]
    except KeyError:
        raise ValueError('Chosen value of nonlinearity, "%s", not handled' % nonlinearity) from None
    return layer_type()
184
+
185
def build_optimizer(args, model, build_simple_adadelta=False):
    """
    Build an optimizer based on the arguments given

    If we are "multistage" training and epochs_trained < epochs // 2,
    we build an AdaDelta optimizer instead of whatever was requested
    The build_simple_adadelta parameter controls this
    """
    # bert learning rate stays 0 unless finetuning is requested
    bert_learning_rate = 0.0
    bert_weight_decay = args['bert_weight_decay']
    if build_simple_adadelta:
        # stage 1 of multistage training: fixed AdaDelta settings
        optim_type = 'adadelta'
        bert_finetune = args.get('stage1_bert_finetune', False)
        if bert_finetune:
            bert_learning_rate = args['stage1_bert_learning_rate']
        learning_beta2 = 0.999 # doesn't matter for AdaDelta
        learning_eps = DEFAULT_LEARNING_EPS['adadelta']
        learning_rate = args['stage1_learning_rate']
        learning_rho = DEFAULT_LEARNING_RHO
        momentum = None # also doesn't matter for AdaDelta
        weight_decay = DEFAULT_WEIGHT_DECAY['adadelta']
    else:
        # normal path: take all hyperparameters from args
        optim_type = args['optim'].lower()
        bert_finetune = args.get('bert_finetune', False)
        if bert_finetune:
            bert_learning_rate = args['bert_learning_rate']
        learning_beta2 = args['learning_beta2']
        learning_eps = args['learning_eps']
        learning_rate = args['learning_rate']
        learning_rho = args['learning_rho']
        momentum = args['learning_momentum']
        weight_decay = args['learning_weight_decay']

    # TODO: allow rho as an arg for AdaDelta
    return get_optimizer(name=optim_type,
                         model=model,
                         lr=learning_rate,
                         betas=(0.9, learning_beta2),
                         eps=learning_eps,
                         momentum=momentum,
                         weight_decay=weight_decay,
                         bert_learning_rate=bert_learning_rate,
                         # bert weight decay is expressed as a multiplier
                         # on the base weight decay
                         bert_weight_decay=weight_decay*bert_weight_decay,
                         is_peft=args.get('use_peft', False),
                         bert_finetune_layers=args['bert_finetune_layers'],
                         opt_logger=tlogger)
231
+
232
def build_scheduler(args, optimizer, first_optimizer=False):
    """
    Build the scheduler for the conparser based on its args

    Used to use a warmup for learning rate, but that wasn't working very well
    Now, we just use a ReduceLROnPlateau, which does quite well

    first_optimizer: use the stage 1 min_lr instead of the regular one
    """
    min_lr = args['stage1_learning_rate_min_lr'] if first_optimizer else args['learning_rate_min_lr']
    return optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                mode='max',
                                                factor=args['learning_rate_factor'],
                                                patience=args['learning_rate_patience'],
                                                cooldown=args['learning_rate_cooldown'],
                                                min_lr=min_lr)
256
+
257
def initialize_linear(linear, nonlinearity, bias):
    """
    Initializes the bias to a positive value, hopefully preventing dead neurons

    Only applies to relu-family nonlinearities; other activations are untouched.
    """
    if nonlinearity not in ('relu', 'leaky_relu'):
        return
    nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity)
    upper = 1 / (bias * 2) ** 0.5
    nn.init.uniform_(linear.bias, 0, upper)
264
+
265
def add_predict_output_args(parser):
    """
    Args specifically for the output location of data
    """
    parser.add_argument(
        '--predict_dir', type=str, default=".",
        help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging')
    parser.add_argument(
        '--predict_file', type=str, default=None,
        help='Base name for writing predictions')
    parser.add_argument(
        '--predict_format', type=str, default="{:_O}",
        help='Format to use when writing predictions')
    parser.add_argument(
        '--predict_output_gold_tags', default=False, action='store_true',
        help='Output gold tags as part of the evaluation - useful for putting the trees through EvalB')
274
+
275
def postprocess_predict_output_args(args):
    """Wrap a bare format spec (such as "_O" or "_OVi") in braces so it works with str.format."""
    fmt = args['predict_format']
    is_bare = len(fmt) <= 2
    is_bare_vi = len(fmt) <= 3 and fmt.endswith("Vi")
    if is_bare or is_bare_vi:
        args['predict_format'] = "{:" + fmt + "}"
278
+
279
+
280
def get_open_nodes(trees, transition_scheme):
    """
    Return a list of all open nodes in the given dataset.
    Depending on the parameters, may be single or compound open transitions.
    """
    if transition_scheme is TransitionScheme.IN_ORDER_COMPOUND:
        return Tree.get_compound_constituents(trees, separate_root=True)
    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
        return Tree.get_compound_constituents(trees)
    # non-compound schemes: each label is its own singleton transition
    return [(label,) for label in Tree.get_unique_constituent_labels(trees)]
291
+
292
+
293
def verify_transitions(trees, sequences, transition_scheme, unary_limit, reverse, name, root_labels):
    """
    Given a list of trees and their transition sequences, verify that the sequences rebuild the trees

    trees: gold trees
    sequences: one transition sequence per tree
    name: dataset name used in error messages
    reverse: if True, the rebuilt tree is reversed before comparison

    Raises RuntimeError on the first illegal transition or mismatched rebuild.
    """
    model = SimpleModel(transition_scheme, unary_limit, reverse, root_labels)
    tlogger.info("Verifying the transition sequences for %d trees", len(trees))

    data = zip(trees, sequences)
    # only pay for a progress bar when INFO logging is on
    if tlogger.getEffectiveLevel() <= logging.INFO:
        data = tqdm(zip(trees, sequences), total=len(trees))

    for tree_idx, (tree, sequence) in enumerate(data):
        # TODO: make the SimpleModel have a parse operation?
        state = model.initial_state_from_gold_trees([tree])[0]
        for idx, trans in enumerate(sequence):
            # check legality before applying, so the error points at the bad step
            if not trans.is_legal(state, model):
                raise RuntimeError("Tree {} of {} failed: transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(tree_idx, name, idx, trans, tree, sequence))
            state = trans.apply(state, model)
        result = model.get_top_constituent(state.constituents)
        if reverse:
            result = result.reverse()
        # the rebuilt tree must match the gold tree exactly
        if tree != result:
            raise RuntimeError("Tree {} of {} failed: transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree_idx, name, tree, sequence, result))
316
+
317
def check_constituents(train_constituents, trees, treebank_name, fail=True):
    """
    Check that all the constituents in the other dataset are known in the train set

    train_constituents: constituent labels seen in training
    trees: the dataset to check
    treebank_name: used in the error message
    fail: if True, raise RuntimeError on an unknown label; otherwise warn
    """
    constituents = Tree.get_unique_constituent_labels(trees)
    for con in constituents:
        if con in train_constituents:
            continue
        # count how many trees contain the unknown label and find the first one
        first_error = None
        num_errors = 0
        for tree_idx, tree in enumerate(trees):
            # separate name so we don't shadow the outer loop's label list
            tree_constituents = Tree.get_unique_constituent_labels(tree)
            if con in tree_constituents:
                num_errors += 1
                if first_error is None:
                    first_error = tree_idx
        # fixed message grammar: "don't exist" -> "does not exist", "occured" -> "occurred"
        error = "Found constituent label {} in the {} set which does not exist in the train set. This constituent label occurred in {} trees, with the first tree index at {} counting from 1\nThe error tree (which may have POS tags changed from the retagger and may be missing functional tags or empty nodes) is:\n{:P}".format(con, treebank_name, num_errors, (first_error+1), trees[first_error])
        if fail:
            raise RuntimeError(error)
        warnings.warn(error)
337
+
338
def check_root_labels(root_labels, other_trees, treebank_name):
    """
    Check that all the root states in the other dataset are known in the train set
    """
    unknown = [state for state in Tree.get_root_labels(other_trees) if state not in root_labels]
    if unknown:
        raise RuntimeError("Found root state {} in the {} set which is not a ROOT state in the train set".format(unknown[0], treebank_name))
345
+
346
def remove_duplicate_trees(trees, treebank_name):
    """
    Filter duplicates from the given dataset, keeping the first occurrence of each tree

    Trees are compared by their string representation.
    """
    seen = set()
    unique_trees = []
    for tree in trees:
        key = "{}".format(tree)
        if key not in seen:
            seen.add(key)
            unique_trees.append(tree)
    if len(unique_trees) < len(trees):
        tlogger.info("Filtered %d duplicates from %s dataset", (len(trees) - len(unique_trees)), treebank_name)
    return unique_trees
361
+
362
def remove_singleton_trees(trees):
    """
    remove trees which are just a root and a single word

    TODO: remove these trees in the conversion instead of here
    """
    def has_structure(tree):
        # more than one child at the root: definitely structure
        if len(tree.children) > 1:
            return True
        if len(tree.children) != 1:
            return False
        child = tree.children[0]
        if len(child.children) > 1:
            return True
        # unary chain at least three levels deep is also kept
        return len(child.children) == 1 and len(child.children[0].children) >= 1

    filtered = [tree for tree in trees if has_structure(tree)]
    if len(trees) - len(filtered) > 0:
        tlogger.info("Eliminated %d trees with missing structure", (len(trees) - len(filtered)))
    return filtered
375
+
stanza/stanza/models/coref/predict.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import json
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ from stanza.models.coref.model import CorefModel
8
+
9
+
10
if __name__ == "__main__":
    # Command line: experiment name, input json(lines) file, output file.
    argparser = argparse.ArgumentParser()
    argparser.add_argument("experiment")
    argparser.add_argument("input_file")
    argparser.add_argument("output_file")
    argparser.add_argument("--config-file", default="config.toml")
    argparser.add_argument("--batch-size", type=int,
                           help="Adjust to override the config value if you're"
                                " experiencing out-of-memory issues")
    argparser.add_argument("--weights",
                           help="Path to file with weights to load."
                                " If not supplied, in the latest"
                                " weights of the experiment will be loaded;"
                                " if there aren't any, an error is raised.")
    args = argparser.parse_args()

    # load the model on CPU, skipping the optimizer/scheduler state
    # which is only needed for training
    model = CorefModel.load_model(path=args.weights,
                                  map_location="cpu",
                                  ignore={"bert_optimizer", "general_optimizer",
                                          "bert_scheduler", "general_scheduler"})
    if args.batch_size:
        model.config.a_scoring_batch_size = args.batch_size
    # put the model in inference mode
    model.training = False

    # input may be a single JSON array or the older jsonlines format
    try:
        with open(args.input_file, encoding="utf-8") as fin:
            input_data = json.load(fin)
    except json.decoder.JSONDecodeError:
        # read the old jsonlines format if necessary
        with open(args.input_file, encoding="utf-8") as fin:
            text = "[" + ",\n".join(fin) + "]"
        input_data = json.loads(text)
    docs = [model.build_doc(doc) for doc in input_data]

    with torch.no_grad():
        for doc in tqdm(docs, unit="docs"):
            result = model.run(doc)
            doc["span_clusters"] = result.span_clusters
            doc["word_clusters"] = result.word_clusters

            # drop intermediate tokenization fields before writing output
            for key in ("word2subword", "subwords", "word_id", "head2span"):
                del doc[key]

    # NOTE(review): documents are dumped back-to-back with no separator --
    # presumably a downstream reader tolerates concatenated JSON, but this is
    # not valid jsonlines; confirm whether a "\n" should be written per doc
    with open(args.output_file, mode="w") as fout:
        for doc in docs:
            json.dump(doc, fout)
stanza/stanza/models/coref/span_predictor.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Describes SpanPredictor which aims to predict spans by taking as input
2
+ head word and context embeddings.
3
+ """
4
+
5
+ from typing import List, Optional, Tuple
6
+
7
+ from stanza.models.coref.const import Doc, Span
8
+ import torch
9
+
10
+
11
class SpanPredictor(torch.nn.Module):
    """Scores, for each span head word, which words start and end its span.

    forward() produces a [n_heads, n_words, 2] start/end score tensor;
    candidate words are restricted to the head's own sentence.
    """

    def __init__(self, input_size: int, distance_emb_size: int):
        super().__init__()
        # NOTE(review): the hard-coded "+ 64" assumes
        # distance_emb_size == 64 -- confirm against the calling config.
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + 64, input_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(input_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 64),
        )
        # Two stacked 1-d convs (kernel 3, stride 1, padding 1) let a word's
        # score depend on neighbors; final 2 channels are (start, end).
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.ffnn.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                doc: Doc,
                words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids

        Args:
            doc (Doc): the document data
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
        """
        # Obtain distance embedding indices, [n_heads, n_words]
        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
        emb_ids = relative_positions + 63   # make all valid distances positive
        # "+" on bool tensors acts as elementwise OR here: anything outside
        # [0, 126] shares the single "too_far" embedding slot
        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127   # "too_far"

        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        sent_id = torch.tensor(doc["sent_id"], device=words.device)
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))

        # To save memory, only pass candidates from one sentence for each head
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_heads, input_size * 2 + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat((
            words[heads_ids[rows]],
            words[cols],
            self.emb(emb_ids[rows, cols]),
        ), dim=1)

        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]

        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]

        # scatter the per-sentence scores back into the full word grid;
        # words outside the head's sentence keep -inf
        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
        scores[rows, cols] = res[padding_mask]

        # Make sure that start <= head <= end during inference
        # (log of a {0,1} mask adds 0 for valid positions, -inf otherwise)
        if not self.training:
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
            return scores + valid_positions
        return scores

    def get_training_data(self,
                          doc: Doc,
                          words: torch.Tensor
                          ) -> Tuple[Optional[torch.Tensor],
                                     Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """ Returns span starts/ends for gold mentions in the document.

        Returns (None, None) when the document has no gold head2span data.
        """
        head2span = sorted(doc["head2span"])
        if not head2span:
            return None, None
        heads, starts, ends = zip(*head2span)
        heads = torch.tensor(heads, device=self.device)
        starts = torch.tensor(starts, device=self.device)
        # assumes head2span stores exclusive end indices; "- 1" makes them
        # inclusive to match the per-word scores -- TODO confirm upstream
        ends = torch.tensor(ends, device=self.device) - 1
        return self(doc, words, heads), (starts, ends)

    def predict(self,
                doc: Doc,
                words: torch.Tensor,
                clusters: List[List[int]]) -> List[List[Span]]:
        """
        Predicts span clusters based on the word clusters.

        Args:
            doc (Doc): the document data
            words (torch.Tensor): [n_words, emb_size] matrix containing
                embeddings for each of the words in the text
            clusters (List[List[int]]): a list of clusters where each cluster
                is a list of word indices

        Returns:
            List[List[Span]]: span clusters
        """
        if not clusters:
            return []

        heads_ids = torch.tensor(
            sorted(i for cluster in clusters for i in cluster),
            device=self.device
        )

        scores = self(doc, words, heads_ids)
        # argmax over words; "+ 1" turns the inclusive end back exclusive
        starts = scores[:, :, 0].argmax(dim=1).tolist()
        ends = (scores[:, :, 1].argmax(dim=1) + 1).tolist()

        head2span = {
            head: (start, end)
            for head, start, end in zip(heads_ids.tolist(), starts, ends)
        }

        return [[head2span[head] for head in cluster]
                for cluster in clusters]
stanza/stanza/models/coref/tokenizer_customization.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Customization hooks for transformers.AutoTokenizer.

Some tokenizers are designed for raw text, but the OntoNotes documents
arrive already split into words; the tables below patch over that
mismatch.  Everything here is consumed by
coref_model.CorefModel._get_docs.
"""


def _keep_nonempty_subword(token):
    """Reject the bare SentencePiece underline (U+2581, not an ASCII "_")."""
    return token != "▁"


# Per-model predicates that filter unwanted tokens out of tokenizer output
TOKENIZER_FILTERS = {
    "albert-xxlarge-v2": _keep_nonempty_subword,
    "albert-large-v2": _keep_nonempty_subword,
}

# Per-model hard-coded word -> token-list mappings that bypass the tokenizer
TOKENIZER_MAPS = {
    "roberta-large": {
        ".": ["."],
        ",": [","],
        "!": ["!"],
        "?": ["?"],
        ":": [":"],
        ";": [";"],
        "'s": ["'s"],
    },
}
stanza/stanza/models/coref/word_encoder.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Describes WordEncoder. Extracts mention vectors from bert-encoded text.
2
+ """
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+
8
+ from stanza.models.coref.config import Config
9
+ from stanza.models.coref.const import Doc
10
+
11
+
12
class WordEncoder(torch.nn.Module): # pylint: disable=too-many-instance-attributes
    """ Receives bert contextual embeddings of a text, extracts all the
    possible mentions in that text.

    Each word's representation is an attention-weighted sum of the bert
    embeddings of its own subtokens. """

    def __init__(self, features: int, config: Config):
        """
        Args:
            features (int): the number of featues in the input embeddings
            config (Config): the configuration of the current session
        """
        super().__init__()
        # one scalar attention score per subtoken
        self.attn = torch.nn.Linear(in_features=features, out_features=1)
        self.dropout = torch.nn.Dropout(config.dropout_rate)

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.attn.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                doc: Doc,
                x: torch.Tensor,
                ) -> Tuple[torch.Tensor, ...]:
        """
        Extracts word representations from text.

        Args:
            doc: the document data
            x: a tensor containing bert output, shape (n_subtokens, bert_dim)

        Returns:
            words: a Tensor of shape [n_words, mention_emb];
                mention representations
            cluster_ids: tensor of shape [n_words], containing cluster indices
                for each word. Non-coreferent words have cluster id of zero.
        """
        # doc["word2subword"] maps each word to its (start, end) subword range
        word_boundaries = torch.tensor(doc["word2subword"], device=self.device)
        starts = word_boundaries[:, 0]
        ends = word_boundaries[:, 1]

        # [n_mentions, features]
        # attention-weighted sum of subword embeddings for each word
        words = self._attn_scores(x, starts, ends).mm(x)

        words = self.dropout(words)

        return (words, self._cluster_ids(doc))

    def _attn_scores(self,
                     bert_out: torch.Tensor,
                     word_starts: torch.Tensor,
                     word_ends: torch.Tensor) -> torch.Tensor:
        """ Calculates attention scores for each of the mentions.

        Args:
            bert_out (torch.Tensor): [n_subwords, bert_emb], bert embeddings
                for each of the subwords in the document
            word_starts (torch.Tensor): [n_words], start indices of words
            word_ends (torch.Tensor): [n_words], end indices of words
                (exclusive, given the "< word_ends" comparison below)

        Returns:
            torch.Tensor: [n_words, n_subtokens] attention weights; each row
                sums to 1 and is non-zero only on that word's own subtokens
        """
        n_subtokens = len(bert_out)
        n_words = len(word_starts)

        # [n_mentions, n_subtokens]
        # with 0 at positions belonging to the words and -inf elsewhere
        # (log of a {0,1} mask yields exactly that)
        attn_mask = torch.arange(0, n_subtokens, device=self.device).expand((n_words, n_subtokens))
        attn_mask = ((attn_mask >= word_starts.unsqueeze(1))
                     * (attn_mask < word_ends.unsqueeze(1)))
        attn_mask = torch.log(attn_mask.to(torch.float))

        attn_scores = self.attn(bert_out).T # [1, n_subtokens]
        attn_scores = attn_scores.expand((n_words, n_subtokens))
        attn_scores = attn_mask + attn_scores
        del attn_mask
        # softmax over masked scores: out-of-word positions get weight 0
        return torch.softmax(attn_scores, dim=1) # [n_words, n_subtokens]

    def _cluster_ids(self, doc: Doc) -> torch.Tensor:
        """
        Args:
            doc: document information

        Returns:
            torch.Tensor of shape [n_word], containing cluster indices for
                each word. Non-coreferent words have cluster id of zero.
        """
        # cluster ids start at 1 so that 0 can mean "not in any cluster"
        word2cluster = {word_i: i
                        for i, cluster in enumerate(doc["word_clusters"], start=1)
                        for word_i in cluster}

        return torch.tensor(
            [word2cluster.get(word_i, 0)
             for word_i in range(len(doc["cased_words"]))],
            device=self.device
        )
stanza/stanza/models/depparse/data.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import logging
3
+ import torch
4
+
5
+ from stanza.models.common.bert_embedding import filter_data, needs_length_filter
6
+ from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
7
+ from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, ROOT_ID, CompositeVocab, CharVocab
8
+ from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
9
+ from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
10
+ from stanza.models.common.doc import *
11
+
12
+ logger = logging.getLogger('stanza')
13
+
14
def data_to_batches(data, batch_size, eval_mode, sort_during_eval, min_length_to_batch_separately):
    """Group sentences (each ``x`` with ``x[0]`` the token sequence) into batches.

    In training mode the data is roughly length-sorted (ascending or
    descending at random, for variety).  In eval mode it is length-sorted
    only when sort_during_eval is true.  Sentences longer than
    min_length_to_batch_separately each get a batch of their own; otherwise
    sentences accumulate into a batch until its total token count would
    exceed batch_size.

    Returns (batches, original_order); original_order is None unless the
    data was sorted for eval, in which case it records each sentence's
    position before sorting.
    """
    batches = []
    batch = []
    batch_len = 0

    def flush():
        # close the current batch if it holds anything
        nonlocal batch, batch_len
        if batch_len > 0:
            batches.append(batch)
            batch = []
            batch_len = 0

    if not eval_mode:
        # sort sentences (roughly) by length for better memory utilization
        data = sorted(data, key=lambda item: len(item[0]), reverse=random.random() > .5)
        data_orig_idx = None
    elif sort_during_eval:
        (data, ), data_orig_idx = sort_all([data], [len(item[0]) for item in data])
    else:
        data_orig_idx = None

    for item in data:
        sent_len = len(item[0])
        if min_length_to_batch_separately is not None and sent_len > min_length_to_batch_separately:
            flush()
            batches.append([item])
        else:
            if batch_len > 0 and sent_len + batch_len > batch_size:
                flush()
            batch.append(item)
            batch_len += sent_len

    flush()

    return batches, data_orig_idx
62
+
63
+
64
class DataLoader:
    """Loads, preprocesses, and batches dependency-parsing data from a doc.

    Each batch produced by __getitem__ is a tuple of padded id tensors plus
    bookkeeping (original sort order, sentence/word lengths, raw text); see
    __getitem__ for the exact layout.
    """

    def __init__(self, doc, batch_size, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, min_length_to_batch_separately=None, bert_tokenizer=None):
        self.batch_size = batch_size
        self.min_length_to_batch_separately=min_length_to_batch_separately
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc
        data = self.load_doc(doc)

        # handle vocab
        if vocab is None:
            self.vocab = self.init_vocab(data)
        else:
            self.vocab = vocab

        # filter out the long sentences if bert is used
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)
        # shuffle for training
        if self.shuffled:
            random.shuffle(data)
        self.num_examples = len(data)

        # chunk into batches
        self.data = self.chunk_batches(data)
        logger.debug("{} batches created.".format(len(self.data)))

    def init_vocab(self, data):
        """Build the MultiVocab (char/word/upos/xpos/feats/lemma/deprel) from training data."""
        assert self.eval == False # for eval vocab must exist
        charvocab = CharVocab(data, self.args['shorthand'])
        wordvocab = WordVocab(data, self.args['shorthand'], cutoff=7, lower=True)
        uposvocab = WordVocab(data, self.args['shorthand'], idx=1)
        xposvocab = xpos_vocab_factory(data, self.args['shorthand'])
        featsvocab = FeatureVocab(data, self.args['shorthand'], idx=3)
        lemmavocab = WordVocab(data, self.args['shorthand'], cutoff=7, idx=4, lower=True)
        deprelvocab = WordVocab(data, self.args['shorthand'], idx=6)
        vocab = MultiVocab({'char': charvocab,
                            'word': wordvocab,
                            'upos': uposvocab,
                            'xpos': xposvocab,
                            'feats': featsvocab,
                            'lemma': lemmavocab,
                            'deprel': deprelvocab})
        return vocab

    def preprocess(self, data, vocab, pretrain_vocab, args):
        """Convert each sentence to id sequences, prepending a ROOT token.

        Per sentence the output fields are: word ids, char ids, upos, xpos,
        feats, pretrained-embedding ids (PAD if no pretrain vocab), lemma
        ids, head indices, deprel ids, and the raw word text.
        """
        processed = []
        # composite vocabs need one ROOT_ID per sub-field
        xpos_replacement = [[ROOT_ID] * len(vocab['xpos'])] if isinstance(vocab['xpos'], CompositeVocab) else [ROOT_ID]
        feats_replacement = [[ROOT_ID] * len(vocab['feats'])]
        for sent in data:
            processed_sent = [[ROOT_ID] + vocab['word'].map([w[0] for w in sent])]
            processed_sent += [[[ROOT_ID]] + [vocab['char'].map([x for x in w[0]]) for w in sent]]
            processed_sent += [[ROOT_ID] + vocab['upos'].map([w[1] for w in sent])]
            processed_sent += [xpos_replacement + vocab['xpos'].map([w[2] for w in sent])]
            processed_sent += [feats_replacement + vocab['feats'].map([w[3] for w in sent])]
            if pretrain_vocab is not None:
                # always use lowercase lookup in pretrained vocab
                processed_sent += [[ROOT_ID] + pretrain_vocab.map([w[0].lower() for w in sent])]
            else:
                processed_sent += [[ROOT_ID] + [PAD_ID] * len(sent)]
            processed_sent += [[ROOT_ID] + vocab['lemma'].map([w[4] for w in sent])]
            # head indices; during eval a malformed head becomes 0 rather than crashing
            processed_sent += [[to_int(w[5], ignore_error=self.eval) for w in sent]]
            processed_sent += [vocab['deprel'].map([w[6] for w in sent])]
            processed_sent.append([w[0] for w in sent])
            processed.append(processed_sent)
        return processed

    def __len__(self):
        # number of batches, not number of sentences
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        assert len(batch) == 10

        # sort sentences by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # sort words by lens for easy char-RNN operations
        batch_words = [w for sent in batch[1] for w in sent]
        word_lens = [len(x) for x in batch_words]
        batch_words, word_orig_idx = sort_all([batch_words], word_lens)
        batch_words = batch_words[0]
        word_lens = [len(x) for x in batch_words]

        # convert to tensors
        words = batch[0]
        words = get_long_tensor(words, batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(batch_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos = get_long_tensor(batch[2], batch_size)
        xpos = get_long_tensor(batch[3], batch_size)
        ufeats = get_long_tensor(batch[4], batch_size)
        pretrained = get_long_tensor(batch[5], batch_size)
        sentlens = [len(x) for x in batch[0]]
        lemma = get_long_tensor(batch[6], batch_size)
        head = get_long_tensor(batch[7], batch_size)
        deprel = get_long_tensor(batch[8], batch_size)
        text = batch[9]
        return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, orig_idx, word_orig_idx, sentlens, word_lens, text

    def load_doc(self, doc):
        """Extract per-token fields from the Document as sentence lists."""
        data = doc.get([TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL], as_sentences=True)
        data = self.resolve_none(data)
        return data

    def resolve_none(self, data):
        # replace None to '_'
        for sent_idx in range(len(data)):
            for tok_idx in range(len(data[sent_idx])):
                for feat_idx in range(len(data[sent_idx][tok_idx])):
                    if data[sent_idx][tok_idx][feat_idx] is None:
                        data[sent_idx][tok_idx][feat_idx] = '_'
        return data

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def set_batch_size(self, batch_size):
        # NOTE: only takes effect at the next chunk_batches/reshuffle call;
        # existing batches are not re-chunked here
        self.batch_size = batch_size

    def reshuffle(self):
        """Re-chunk the flattened data into batches and shuffle batch order."""
        data = [y for x in self.data for y in x]
        self.data = self.chunk_batches(data)
        random.shuffle(self.data)

    def chunk_batches(self, data):
        """Delegate to data_to_batches, remembering the pre-sort order for eval."""
        batches, data_orig_idx = data_to_batches(data=data, batch_size=self.batch_size,
                                                 eval_mode=self.eval, sort_during_eval=self.sort_during_eval,
                                                 min_length_to_batch_separately=self.min_length_to_batch_separately)
        # data_orig_idx might be None at train time, since we don't anticipate unsorting
        self.data_orig_idx = data_orig_idx
        return batches
223
+
224
def to_int(string, ignore_error=False):
    """Parse *string* as an int.

    With ignore_error=True a non-numeric string yields 0 instead of
    raising ValueError.
    """
    try:
        return int(string)
    except ValueError:
        if not ignore_error:
            raise
        return 0
233
+
stanza/stanza/models/lemma/attach_lemma_classifier.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from stanza.models.lemma.trainer import Trainer
4
+ from stanza.models.lemma_classifier.base_model import LemmaClassifier
5
+
6
def attach_classifier(input_filename, output_filename, classifiers):
    """Load a saved lemmatizer, append the given contextual lemma
    classifiers to it, and save the combined model."""
    trainer = Trainer(model_file=input_filename)
    for path in classifiers:
        trainer.contextual_lemmatizers.append(LemmaClassifier.load(path))
    trainer.save(output_filename)
14
+
15
def main(args=None):
    """Command line entry point: attach lemma classifiers to a lemmatizer."""
    parser = argparse.ArgumentParser()
    # all three options are required strings; --classifier takes one or more
    for flag, extra, description in (
            ('--input', {}, 'Which lemmatizer to start from'),
            ('--output', {}, 'Where to save the lemmatizer'),
            ('--classifier', {'nargs': '+'}, 'Lemma classifier to attach'),
    ):
        parser.add_argument(flag, type=str, required=True, help=description, **extra)
    parsed = parser.parse_args(args)

    attach_classifier(parsed.input, parsed.output, parsed.classifier)

if __name__ == '__main__':
    main()
stanza/stanza/models/lemma/scorer.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils and wrappers for scoring lemmatizers.
3
+ """
4
+
5
+ from stanza.models.common.utils import ud_scores
6
+
7
def score(system_conllu_file, gold_conllu_file):
    """Score lemmatizer output against gold; returns (precision, recall, f1)."""
    lemma_eval = ud_scores(gold_conllu_file, system_conllu_file)["Lemmas"]
    return lemma_eval.precision, lemma_eval.recall, lemma_eval.f1
13
+
stanza/stanza/models/lemma/vocab.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+
3
+ from stanza.models.common.vocab import BaseVocab, BaseMultiVocab
4
+ from stanza.models.common.seq2seq_constant import VOCAB_PREFIX
5
+
6
class Vocab(BaseVocab):
    """Lemma vocabulary whose units are ordered by descending frequency."""

    def build_vocab(self):
        """Index units by descending count, after the special-token prefix.

        Counter.most_common() sorts by count descending with ties kept in
        first-seen order, matching a stable sort of the counter's keys.
        """
        counts = Counter(self.data)
        by_frequency = [unit for unit, _ in counts.most_common()]
        self._id2unit = VOCAB_PREFIX + by_frequency
        self._unit2id = {unit: idx for idx, unit in enumerate(self._id2unit)}
11
+
12
class MultiVocab(BaseMultiVocab):
    """A named collection of Vocab objects."""

    @classmethod
    def load_state_dict(cls, state_dict):
        """Rebuild a MultiVocab from serialized state, one Vocab per entry."""
        loaded = cls()
        for name, vocab_state in state_dict.items():
            loaded[name] = Vocab.load_state_dict(vocab_state)
        return loaded
stanza/stanza/models/lemma_classifier/base_trainer.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from abc import ABC, abstractmethod
3
+ import logging
4
+ import os
5
+ from typing import List, Tuple, Any, Mapping
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+
11
+ from stanza.models.common.utils import default_device
12
+ from stanza.models.lemma_classifier import utils
13
+ from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
14
+ from stanza.models.lemma_classifier.evaluate_models import evaluate_model
15
+ from stanza.utils.get_tqdm import get_tqdm
16
+
17
+ tqdm = get_tqdm()
18
+ logger = logging.getLogger('stanza.lemmaclassifier')
19
+
20
class BaseLemmaClassifierTrainer(ABC):
    """Shared training loop for lemma classifier models.

    NOTE(review): subclasses are expected to initialize self.weighted_loss,
    self.lr, and self.criterion before train() is called -- those attributes
    are read but never set here; confirm in the concrete trainers.
    """

    def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
        """
        If applicable, this function will update the loss function of the LemmaClassifierLSTM model to become BCEWithLogitsLoss.
        The weights are determined by the counts of the classes in the dataset. The weights are inversely proportional to the
        frequency of the class in the set. E.g. classes with lower frequency will have higher weight.
        """
        weights = [0 for _ in label_decoder.keys()]  # each key in the label decoder is one class, we have one weight per class
        total_samples = sum(counts.values())
        # assumes counts is keyed by contiguous class indices 0..n-1 -- a
        # gap would leave a zero weight; TODO confirm with Dataset.counts
        for class_idx in counts:
            weights[class_idx] = total_samples / (counts[class_idx] * len(counts))  # weight_i = total / (# examples in class i * num classes)
        weights = torch.tensor(weights)
        logger.info(f"Using weights {weights} for weighted loss.")
        self.criterion = nn.BCEWithLogitsLoss(weight=weights)

    @abstractmethod
    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        """
        Build a model using pieces of the dataset to determine some of the model shape
        """

    def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, train_file: str) -> None:
        """
        Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.

        Args:
            num_epochs (int): Number of training epochs
            save_name (str): Path to file where trained model should be saved.
            args (Mapping): extra options; 'batch_size' and 'force' are read here.
            eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
            train_file (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.

        Raises:
            ValueError: if train_file is empty/None.
            FileExistsError: if save_name exists and args['force'] is not set.
        """
        # Put model on GPU (if possible)
        device = default_device()

        if not train_file:
            raise ValueError("Cannot train model - no train_file supplied!")

        dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
        label_decoder = dataset.label_decoder
        upos_to_id = dataset.upos_to_id
        self.output_dim = len(label_decoder)
        logger.info(f"Loaded dataset successfully from {train_file}")
        logger.info(f"Using label decoder: {label_decoder}  Output dimension: {self.output_dim}")
        logger.info(f"Target words: {dataset.target_words}")

        self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words, set(dataset.target_upos))
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.model.to(device)
        logger.info(f"Training model on device: {device}. {next(self.model.parameters()).device}")

        # refuse to clobber an existing model unless explicitly forced
        if os.path.exists(save_name) and not args.get('force', False):
            raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. Aborting...")

        if self.weighted_loss:
            self.configure_weighted_loss(label_decoder, dataset.counts)

        # Put the criterion on GPU too
        logger.debug(f"Criterion on {next(self.model.parameters()).device}")
        self.criterion = self.criterion.to(next(self.model.parameters()).device)

        # NOTE(review): best_model is never assigned after this; only best_f1
        # is used for checkpoint selection
        best_model, best_f1 = None, float("-inf")  # Used for saving checkpoints of the model
        for epoch in range(num_epochs):
            # go over entire dataset with each epoch
            for sentences, positions, upos_tags, labels in tqdm(dataset):
                assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

                self.optimizer.zero_grad()
                outputs = self.model(positions, sentences, upos_tags)

                # Compute loss, which is different if using CE or BCEWithLogitsLoss
                if self.weighted_loss:  # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
                    # TODO: three classes?  (this one-hot construction is hard-coded for the binary case)
                    targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)
                    # should be shape size (batch_size, 2)
                else:  # CELoss accepts target as just raw label
                    targets = labels.to(device)

                loss = self.criterion(outputs, targets)

                loss.backward()
                self.optimizer.step()

            logger.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
            if eval_file:
                # Evaluate model on dev set to see if it should be saved.
                _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
                logger.info(f"Weighted f1 for model: {f1}")
                if f1 > best_f1:
                    best_f1 = f1
                    self.model.save(save_name)
                    logger.info(f"New best model: weighted f1 score of {f1}.")
            else:
                # no dev set: save unconditionally at the end of each epoch
                self.model.save(save_name)
114
+
stanza/stanza/models/lemma_classifier/constants.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from enum import Enum

UNKNOWN_TOKEN = "unk" # token name for unknown tokens
UNKNOWN_TOKEN_IDX = -1 # custom index we apply to unknown tokens

# TODO: ModelType could just be LSTM and TRANSFORMER
# and then the transformer baseline would have the transformer as another argument
class ModelType(Enum):
    """Which architecture a lemma classifier uses (see TODO above)."""
    LSTM = 1
    TRANSFORMER = 2
    BERT = 3
    ROBERTA = 4

# default number of examples per batch during training/evaluation
DEFAULT_BATCH_SIZE = 16
stanza/stanza/models/lemma_classifier/evaluate_many.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils to evaluate many models of the same type at once
3
+ """
4
+ import argparse
5
+ import os
6
+ import logging
7
+
8
+ from stanza.models.lemma_classifier.evaluate_models import main as evaluate_main
9
+
10
+
11
+ logger = logging.getLogger('stanza.lemmaclassifier')
12
+
13
def evaluate_n_models(path_to_models_dir, args):
    """Evaluate every model file in a directory and average their metrics.

    Args:
        path_to_models_dir: directory whose entries are each a saved model file
        args: parsed argparse namespace forwarded to evaluate_models.main;
            args.save_name is overwritten with each model path in turn

    Returns:
        dict mapping each lemma reported by the evaluations to its average
        f1 (as a percentage), plus averaged "accuracy" and "weighted_f1"
        (as fractions).

    Raises:
        ValueError: if the directory contains no models (previously this
            crashed with ZeroDivisionError).
    """
    total_results = {
        "be": 0.0,
        "have": 0.0,
        "accuracy": 0.0,
        "weighted_f1": 0.0
    }
    paths = os.listdir(path_to_models_dir)
    num_models = len(paths)
    if num_models == 0:
        raise ValueError(f"No models found to evaluate in {path_to_models_dir}")
    for model_path in paths:
        full_path = os.path.join(path_to_models_dir, model_path)
        args.save_name = full_path
        mcc_results, confusion, acc, weighted_f1 = evaluate_main(predefined_args=args)

        # Accumulate every lemma the evaluation reports.  The previous
        # version only pre-seeded "be" and "have" and raised KeyError for
        # any other lemma in mcc_results.
        for lemma, lemma_results in mcc_results.items():
            lemma_f1 = lemma_results["f1"] * 100
            total_results[lemma] = total_results.get(lemma, 0.0) + lemma_f1

        total_results["accuracy"] += acc
        total_results["weighted_f1"] += weighted_f1

    for key in total_results:
        total_results[key] /= num_models

    # NOTE: per-lemma f1s are already percentages; accuracy and weighted_f1
    # are fractions and get scaled by 100 only for this log line.
    logger.info(f"Models in {path_to_models_dir} had average weighted f1 of {100 * total_results['weighted_f1']}.\nLemma 'be' had f1: {total_results['be']}\nLemma 'have' had f1: {total_results['have']}.\nAccuracy: {100 * total_results['accuracy']}.\n ({num_models} models evaluated).")
    return total_results
43
+
44
+
45
def main():
    """CLI entry point: parse shared evaluation settings, then average
    metrics over all model files found under --base_path."""
    here = os.path.dirname(__file__)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vocab_size", type=int, default=10000,
        help="Number of tokens in vocab")
    parser.add_argument(
        "--embedding_dim", type=int, default=100,
        help="Number of dimensions in word embeddings (currently using GloVe)")
    parser.add_argument(
        "--hidden_dim", type=int, default=256,
        help="Size of hidden layer")
    parser.add_argument(
        '--wordvec_pretrain_file', type=str, default=None,
        help='Exact name of the pretrain file to read')
    parser.add_argument(
        "--charlm", action='store_true', default=False,
        help="Whether not to use the charlm embeddings")
    parser.add_argument(
        '--charlm_shorthand', type=str, default=None,
        help="Shorthand for character-level language model training corpus.")
    parser.add_argument(
        "--charlm_forward_file", type=str,
        default=os.path.join(here, "charlm_files", "1billion_forward.pt"),
        help="Path to forward charlm file")
    parser.add_argument(
        "--charlm_backward_file", type=str,
        default=os.path.join(here, "charlm_files", "1billion_backwards.pt"),
        help="Path to backward charlm file")
    parser.add_argument(
        "--save_name", type=str,
        default=os.path.join(here, "saved_models", "lemma_classifier_model.pt"),
        help="Path to model save file")
    parser.add_argument(
        "--model_type", type=str, default="roberta",
        help="Which transformer to use ('bert' or 'roberta' or 'lstm')")
    parser.add_argument(
        "--bert_model", type=str, default=None,
        help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument(
        "--eval_file", type=str,
        help="path to evaluation file")

    # Args specific to several model eval
    parser.add_argument(
        "--base_path", type=str, default=None,
        help="path to dir for eval")

    args = parser.parse_args()
    evaluate_n_models(args.base_path, args)


if __name__ == "__main__":
    main()
stanza/stanza/models/lemma_classifier/evaluate_models.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ parentdir = os.path.dirname(__file__)
5
+ parentdir = os.path.dirname(parentdir)
6
+ parentdir = os.path.dirname(parentdir)
7
+ sys.path.append(parentdir)
8
+
9
+ import logging
10
+ import argparse
11
+ import os
12
+
13
+ from typing import Any, List, Tuple, Mapping
14
+ from collections import defaultdict
15
+ from numpy import random
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ import stanza
21
+
22
+ from stanza.models.common.utils import default_device
23
+ from stanza.models.lemma_classifier import utils
24
+ from stanza.models.lemma_classifier.base_model import LemmaClassifier
25
+ from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
26
+ from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
27
+ from stanza.utils.confusion import format_confusion
28
+ from stanza.utils.get_tqdm import get_tqdm
29
+
30
+ tqdm = get_tqdm()
31
+
32
+ logger = logging.getLogger('stanza.lemmaclassifier')
33
+
34
+
35
def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: Mapping[int, Mapping[int, int]]) -> float:
    """
    Compute the weighted F1 score across an evaluation set.

    Each class's F1 is weighted by the number of gold examples of that class
    found in `confusion`, so classes with more evaluation examples contribute
    proportionally more to the aggregate score.
    """
    # number of gold examples per class, taken from the confusion matrix rows
    class_counts = {class_id: sum(confusion.get(class_id).values()) for class_id in mcc_results}
    total_examples = sum(class_counts.values())
    weighted_sum = sum(mcc_results.get(class_id).get("f1") * count
                       for class_id, count in class_counts.items())
    return weighted_sum / total_examples
52
+
53
+
54
def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True):
    """
    Score a model's predicted tags against gold tags, computing precision, recall
    and F1 for every class observed in the gold data.

    Precision = true positives / (true positives + false positives)
    Recall = true positives / (true positives + false negatives)
    F1 = 2 * (Precision * Recall) / (Precision + Recall)

    Returns:
        1. Multi class result dictionary, mapping each class to a map of its
           "f1", "precision", and "recall" scores.
        2. Confusion matrix: confusion[gold][pred] is the count of that (gold, pred) pair.
        3. The weighted F1 across all classes (see get_weighted_f1).
    """
    assert len(gold_tag_sequences) == len(pred_tag_sequences), \
        f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}"

    # map label ids back to label strings so the confusion matrix is readable
    id_to_label = {idx: label for label, idx in label_decoder.items()}
    confusion = defaultdict(lambda: defaultdict(int))
    for gold, pred in zip(gold_tag_sequences, pred_tag_sequences):
        confusion[id_to_label[gold]][id_to_label[pred]] += 1

    multi_class_result = defaultdict(lambda: defaultdict(float))
    for gold_tag in confusion.keys():
        true_pos = confusion.get(gold_tag, {}).get(gold_tag, 0)
        # column sum: everything the model labeled as gold_tag
        predicted_total = sum(confusion.get(other, {}).get(gold_tag, 0) for other in confusion.keys())
        # row sum: every gold occurrence of gold_tag
        gold_total = sum(confusion.get(gold_tag, {}).values())

        prec = true_pos / predicted_total if predicted_total else 0.0
        recall = true_pos / gold_total if gold_total else 0.0
        f1 = 2 * (prec * recall) / (prec + recall) if (prec + recall) else 0.0

        multi_class_result[gold_tag] = {
            "precision": prec,
            "recall": recall,
            "f1": f1
        }

    if verbose:
        for lemma in multi_class_result:
            logger.info(f"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}")

    weighted_f1 = get_weighted_f1(multi_class_result, confusion)

    return multi_class_result, confusion, weighted_f1
109
+
110
+
111
def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[int]] = None) -> torch.Tensor:
    """
    Runs a lemma classifier on a batch of examples, given the position index of the
    target token in each sentence, and returns the predicted class per example.

    Args:
        model (LemmaClassifier): A trained classifier able to predict on a target token.
        position_indices (Tensor[int]): zero-indexed position of the target token for
            each example in the batch.
        sentences (List[List[str]]): tokenized input sentences, one list per example.
        upos_tags (List[List[int]], optional): UPOS tag ids per sentence; defaults to
            an empty list when omitted.

    Returns:
        torch.Tensor: the predicted class index for each example, shape (batch_size,).
    """
    # Use a None sentinel instead of a mutable default argument ([] would be
    # shared across calls and could be mutated downstream by the model).
    if upos_tags is None:
        upos_tags = []
    with torch.no_grad():
        logits = model(position_indices, sentences, upos_tags)  # (batch_size, output_size)
        predicted_class = torch.argmax(logits, dim=1)  # (batch_size,)

    return predicted_class
128
+
129
+
130
def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]:
    """
    Run a trained lemma classifier over an evaluation file and score it.

    Args:
        model (LemmaClassifierLSTM or LemmaClassifierWithTransformer): the model to evaluate.
        eval_path (str): Path to the saved evaluation dataset.
        verbose (bool, optional): True if `evaluate_sequences()` should log per-class
            F1, precision and recall. Defaults to True.
        is_training (bool, optional): When True the model is left in training mode
            instead of being switched to eval mode.

    Returns:
        1. Multi-class results: class id -> {"f1"/"precision"/"recall": value}.
        2. Confusion matrix: gold tag -> {predicted tag: count}.
        3. Accuracy (num correct / total examples) over the evaluation set.
        4. Weighted F1 over the evaluation set.
    """
    device = default_device()
    model.to(device)

    # keep training mode intact when evaluating mid-training
    if not is_training:
        model.eval()

    dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False)

    logger.info(f"Evaluating on evaluation file {eval_path}")

    gold_tags = dataset.labels
    pred_tags = []
    correct = 0
    total = 0

    for sentences, pos_indices, upos_tags, labels in tqdm(dataset, "Evaluating examples from data file"):
        batch_preds = model_predict(model, pos_indices, sentences, upos_tags)  # (batch_size,)
        matches = batch_preds == labels.to(device)
        correct += torch.sum(matches)
        total += len(matches)
        pred_tags.extend(batch_preds.tolist())

    logger.info("Finished evaluating on dataset. Computing scores...")
    accuracy = correct / total

    mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose)
    if verbose:
        logger.info(f"Accuracy: {accuracy} ({correct}/{total})")
        logger.info(f"Label decoder: {dataset.label_decoder}")

    return mc_results, confusion, accuracy, weighted_f1
181
+
182
+
183
def main(args=None, predefined_args=None):
    """
    Parse arguments, load a saved LemmaClassifier, and evaluate it on a dataset.

    Args:
        args: optional argv-style list passed to argparse.
        predefined_args: a pre-parsed namespace; when given, argv parsing is skipped.

    Returns:
        (multi-class results, confusion matrix, accuracy, weighted F1)
    """
    # TODO: can unify this script with train_lstm_model.py?
    # TODO: can save the model type in the model .pt, then
    # automatically figure out what type of model we are using by
    # looking in the file
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab")
    parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)")
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--eval_file", type=str, help="path to evaluation file")

    args = parser.parse_args(args) if not predefined_args else predefined_args

    # BUGFIX: this is the evaluation script, not the training script
    logger.info("Running evaluation script with the following args:")
    args = vars(args)
    for arg in args:
        logger.info(f"{arg}: {args[arg]}")
    logger.info("------------------------------------------------------------")

    logger.info(f"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}")
    model = LemmaClassifier.load(args['save_name'], args)

    mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, args['eval_file'])

    logger.info(f"MCC Results: {dict(mcc_results)}")
    logger.info("______________________________________________")
    # lazy %-formatting (no stray f-string prefix) so the table is only rendered when emitted
    logger.info("Confusion:\n%s", format_confusion(confusion))
    logger.info("______________________________________________")
    logger.info(f"Accuracy: {acc}")
    logger.info("______________________________________________")
    logger.info(f"Weighted f1: {weighted_f1}")

    return mcc_results, confusion, acc, weighted_f1


if __name__ == "__main__":
    main()
stanza/stanza/models/lemma_classifier/prepare_dataset.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ import stanza
7
+ from stanza.models.lemma_classifier import utils
8
+
9
+ from typing import List, Tuple, Any
10
+
11
+ """
12
+ The code in this file processes a CoNLL dataset by taking its sentences and filtering out all sentences that do not contain the target token.
13
+ Furthermore, it will store tuples of the Stanza document object, the position index of the target token, and its lemma.
14
+ """
15
+
16
+
17
def load_doc_from_conll_file(path: str):
    """
    Load a Stanza Document object from a CoNLL file of annotated sentences.

    Args:
        path (str): path to the CoNLL file to read

    Returns:
        The parsed stanza Document.
    """
    # Import the submodule explicitly: this file only does `import stanza`, which
    # does not guarantee `stanza.utils.conll` is reachable as an attribute.
    from stanza.utils.conll import CoNLL
    return CoNLL.conll2doc(path)
22
+
23
+
24
class DataProcessor:
    """
    Filters a CoNLL document down to sentences that contain a target token and
    collects one example per occurrence: the sentence tokens, their UPOS tags,
    the occurrence's index, and its gold lemma.
    """

    def __init__(self, target_word: str, target_upos: List[str], allowed_lemmas: str):
        """
        Args:
            target_word (str): regex matched (fullmatch) against each token's text
            target_upos (List[str]): UPOS tags the target token must carry
            allowed_lemmas (str): regex of lemmas to keep; ".*" keeps all lemmas
        """
        self.target_word = target_word
        self.target_word_regex = re.compile(target_word)
        self.target_upos = target_upos
        self.allowed_lemmas = re.compile(allowed_lemmas)

    def keep_sentence(self, sentence):
        """Return True if the sentence contains at least one target token with a target UPOS."""
        for word in sentence.words:
            if self.target_word_regex.fullmatch(word.text) and word.upos in self.target_upos:
                return True
        return False

    def find_all_occurrences(self, sentence) -> List[int]:
        """
        Finds all occurrences of self.target_word in the sentence's words and
        returns the index(es) of such occurrences.
        """
        occurrences = []
        for idx, token in enumerate(sentence.words):
            if self.target_word_regex.fullmatch(token.text) and token.upos in self.target_upos:
                occurrences.append(idx)
        return occurrences

    @staticmethod
    def write_output_file(save_name, target_upos, sentences):
        """Write the collected examples as a JSON object with "upos" and "sentences" keys."""
        with open(save_name, "w+", encoding="utf-8") as output_f:
            output_f.write("{\n")
            output_f.write(' "upos": %s,\n' % json.dumps(target_upos))
            output_f.write(' "sentences": [')
            wrote_sentence = False
            for sentence in sentences:
                if not wrote_sentence:
                    output_f.write("\n ")
                    wrote_sentence = True
                else:
                    output_f.write(",\n ")
                output_f.write(json.dumps(sentence))
            output_f.write("\n ]\n}\n")

    def process_document(self, doc, save_name: str) -> List[dict]:
        """
        Takes any sentence from `doc` that meets `keep_sentence` and records its
        tokens, UPOS tags, the target word's index, and its lemma.

        Sentences containing `self.target_word` multiple times produce one
        example per occurrence.

        Args:
            doc (Stanza.doc): Document object that represents the file to be analyzed
            save_name (str): Path to the output file; falsy to skip writing

        Returns:
            The list of example dicts that were collected.
        """
        sentences = []
        for sentence in doc.sentences:
            if not self.keep_sentence(sentence):
                continue
            tokens = [token.text for token in sentence.words]
            # the tags are the same for every occurrence, so compute them once per sentence
            upos_tags = [word.upos for word in sentence.words]
            for idx in self.find_all_occurrences(sentence):
                lemma = sentence.words[idx].lemma
                # skip tokens without an annotated lemma: re.fullmatch(None) raises TypeError
                if lemma is None or not self.allowed_lemmas.fullmatch(lemma):
                    continue
                sentences.append({
                    "words": tokens,
                    "upos_tags": upos_tags,
                    "index": idx,
                    "lemma": lemma
                })

        if save_name:
            self.write_output_file(save_name, self.target_upos, sentences)
        return sentences
98
+
99
def main(args=None):
    """Command line entry point: parse arguments, then extract target-token examples from a CoNLL file."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--conll_path", type=str, default=os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu"), help="path to the conll file to translate")
    parser.add_argument("--target_word", type=str, default="'s", help="Token to classify on, e.g. 's.")
    parser.add_argument("--target_upos", type=str, default="AUX", help="upos on target token")
    parser.add_argument("--output_path", type=str, default="test_output.txt", help="Path for output file")
    parser.add_argument("--allowed_lemmas", type=str, default=".*", help="A regex for allowed lemmas. If not set, all lemmas are allowed")

    parsed = parser.parse_args(args)

    for name, value in vars(parsed).items():
        print(f"{name}: {value}")

    doc = load_doc_from_conll_file(parsed.conll_path)
    processor = DataProcessor(target_word=parsed.target_word, target_upos=[parsed.target_upos], allowed_lemmas=parsed.allowed_lemmas)

    return processor.process_document(doc, parsed.output_path)

if __name__ == "__main__":
    main()
stanza/stanza/models/lemma_classifier/train_lstm_model.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The code in this file works to train a lemma classifier for 's
3
+ """
4
+
5
+ import argparse
6
+ import logging
7
+ import os
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from stanza.models.common.foundation_cache import load_pretrain
13
+ from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer
14
+ from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
15
+ from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM
16
+
17
+ logger = logging.getLogger('stanza.lemmaclassifier')
18
+
19
class LemmaClassifierTrainer(BaseLemmaClassifierTrainer):
    """
    Class to assist with training a LemmaClassifierLSTM
    """

    def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = False, charlm_forward_file: str = None, charlm_backward_file: str = None, lr: float = 0.001, loss_func: str = None):
        """
        Initializes the LemmaClassifierTrainer class.

        Args:
            model_args (dict): Various model shape parameters
            embedding_file (str): What word embeddings file to use.  Use a Stanza pretrain .pt
            use_charlm (bool, optional): Whether to use charlm embeddings as well. Defaults to False.
            charlm_forward_file (str): Path to the forward pass embeddings for the charlm
            charlm_backward_file (str): Path to the backward pass embeddings for the charlm
            lr (float): Learning rate, defaults to 0.001.
            loss_func (str): Which loss function to use (either 'ce' or 'weighted_bce')

        Raises:
            FileNotFoundError: If the forward charlm file is given but not present
            FileNotFoundError: If the backward charlm file is given but not present
            ValueError: If loss_func is not one of 'ce' or 'weighted_bce'
        """
        super().__init__()

        self.model_args = model_args

        # Load word embeddings
        pt = load_pretrain(embedding_file)
        self.pt_embedding = pt

        # Load CharLM embeddings
        # only validate the charlm paths when charlm use is actually requested
        if use_charlm and charlm_forward_file is not None and not os.path.exists(charlm_forward_file):
            raise FileNotFoundError(f"Could not find forward charlm file: {charlm_forward_file}")
        if use_charlm and charlm_backward_file is not None and not os.path.exists(charlm_backward_file):
            raise FileNotFoundError(f"Could not find backward charlm file: {charlm_backward_file}")

        # TODO: just pass around the args instead
        self.use_charlm = use_charlm
        self.charlm_forward_file = charlm_forward_file
        self.charlm_backward_file = charlm_backward_file
        self.lr = lr

        # Find loss function
        if loss_func == "ce":
            self.criterion = nn.CrossEntropyLoss()
            self.weighted_loss = False
            logger.debug("Using CE loss")
        elif loss_func == "weighted_bce":
            self.criterion = nn.BCEWithLogitsLoss()
            self.weighted_loss = True # used to add weights during train time.
            logger.debug("Using Weighted BCE loss")
        else:
            raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        # Build the LSTM classifier from the stored shape args and embeddings.
        # NOTE(review): self.output_dim is presumably set by BaseLemmaClassifierTrainer
        # before this is called — confirm in base_trainer.py
        return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,
                                   use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file)
78
+
79
def build_argparse():
    """Construct the argument parser for the LSTM lemma classifier training script."""
    here = os.path.dirname(__file__)
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(here, "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(here, "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(here, "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.")
    parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')
    parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.")
    parser.add_argument("--save_name", type=str, default=os.path.join(here, "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--train_file", type=str, default=os.path.join(here, "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
    parser.add_argument("--eval_file", type=str, default=os.path.join(here, "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file')
    return parser
99
+
100
def main(args=None, predefined_args=None):
    """
    Command line entry point for training a LemmaClassifierLSTM.

    predefined_args, when given, takes precedence over argv parsing.
    Returns the trainer after training completes.
    """
    parser = build_argparse()
    args = parser.parse_args(args) if predefined_args is None else predefined_args
    opts = vars(args)

    save_name = opts['save_name']
    train_file = opts['train_file']

    # refuse to clobber an existing model unless --force was given
    if os.path.exists(save_name) and not opts.get('force', False):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logger.info("Running training script with the following args:")
    for name in opts:
        logger.info(f"{name}: {opts[name]}")
    logger.info("------------------------------------------------------------")

    trainer = LemmaClassifierTrainer(model_args=opts,
                                     embedding_file=opts['wordvec_pretrain_file'],
                                     use_charlm=opts['use_charlm'],
                                     charlm_forward_file=opts['charlm_forward_file'],
                                     charlm_backward_file=opts['charlm_backward_file'],
                                     lr=opts['lr'],
                                     loss_func="weighted_bce" if opts['weighted_loss'] else "ce",
                                     )

    trainer.train(
        num_epochs=opts['num_epochs'], save_name=save_name, args=opts, eval_file=opts['eval_file'], train_file=train_file
    )

    return trainer

if __name__ == "__main__":
    main()
147
+
stanza/stanza/models/lemma_classifier/train_many.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils for training and evaluating multiple models simultaneously
3
+ """
4
+
5
+ import argparse
6
+ import os
7
+
8
+ from stanza.models.lemma_classifier.train_lstm_model import main as train_lstm_main
9
+ from stanza.models.lemma_classifier.train_transformer_model import main as train_tfmr_main
10
+ from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
11
+
12
+
13
# Hyperparameter sweep values: each key names the tunable selected via
# --change_param, mapped to the list of values tried by train_n_models.
change_params_map = {
    "lstm_layer": [16, 32, 64, 128, 256, 512],
    "upos_emb_dim": [5, 10, 20, 30],
    "training_size": [150, 300, 450, 600, 'full'],
} # TODO: Add attention
18
+
19
def train_n_models(num_models: int, base_path: str, args):
    """
    Train `num_models` LSTM lemma classifiers for each value of the hyperparameter
    named by args.change_param, saving each model under base_path.

    Args:
        num_models (int): how many models to train per configuration
        base_path (str): directory where model save files are written
        args: the parsed argparse namespace; mutated per run (save_name etc.)
    """
    if args.change_param == "lstm_layer":
        for num_layers in change_params_map.get("lstm_layer", None):
            for i in range(num_models):
                args.save_name = os.path.join(base_path, f"{num_layers}_{i}.pt")
                args.hidden_dim = num_layers
                train_lstm_main(predefined_args=args)

    if args.change_param == "upos_emb_dim":
        # BUGFIX: change_params_map is a dict and was previously *called* here
        # (change_params_map("upos_emb_dim", None)), raising TypeError
        for upos_dim in change_params_map.get("upos_emb_dim", None):
            for i in range(num_models):
                args.save_name = os.path.join(base_path, f"dim_{upos_dim}_{i}.pt")
                args.upos_emb_dim = upos_dim
                train_lstm_main(predefined_args=args)

    if args.change_param == "training_size":
        for size in change_params_map.get("training_size", None):
            for i in range(num_models):
                args.save_name = os.path.join(base_path, f"{size}_examples_{i}.pt")
                # NOTE(review): the training file does not vary with `size`;
                # presumably it should point at a size-specific split — confirm
                args.train_file = os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt")
                train_lstm_main(predefined_args=args)

    if args.change_param == "base":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"lstm_model_{i}.pt")
            args.weighted_loss = False
            train_lstm_main(predefined_args=args)

            # weighted_loss was just cleared above, so also train the weighted variant
            if not args.weighted_loss:
                args.weighted_loss = True
                args.save_name = os.path.join(base_path, f"lstm_model_wloss_{i}.pt")
                train_lstm_main(predefined_args=args)

    if args.change_param == "base_charlm":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"lstm_charlm_{i}.pt")
            train_lstm_main(predefined_args=args)

    if args.change_param == "base_charlm_upos":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"lstm_charlm_upos_{i}.pt")
            train_lstm_main(predefined_args=args)

    if args.change_param == "base_upos":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"lstm_upos_{i}.pt")
            train_lstm_main(predefined_args=args)

    if args.change_param == "attn_model":
        for i in range(num_models):
            args.save_name = os.path.join(base_path, f"attn_model_{args.num_heads}_heads_{i}.pt")
            train_lstm_main(predefined_args=args)
82
+
83
def train_n_tfmrs(num_models: int, base_path: str, args):
    """
    Train `num_models` transformer baselines of the kind named by
    args.change_param ("bert" or "roberta"), each once with CE loss and
    once with weighted BCE loss.
    """
    if args.multi_train_type != "tfmr":
        return
    if args.change_param not in ("bert", "roberta"):
        return

    prefix = args.change_param
    for i in range(num_models):
        # one plain cross-entropy run, then one weighted-BCE run per model index
        for loss_fn, filename in (("ce", f"{prefix}_{i}.pt"), ("weighted_bce", f"{prefix}_wloss_{i}.pt")):
            args.save_name = os.path.join(base_path, filename)
            args.loss_fn = loss_fn
            train_tfmr_main(predefined_args=args)
110
+
111
+
112
def main():
    """Entry point: parse the shared arguments and dispatch to LSTM or transformer multi-training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file")
    parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file")
    parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.")
    parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.')
    parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.")
    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    # Tfmr-specific args
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
    # Multi-model train args
    parser.add_argument("--multi_train_type", type=str, default="lstm", help="Whether you are attempting to multi-train an LSTM or transformer")
    parser.add_argument("--multi_train_count", type=int, default=5, help="Number of each model to build")
    parser.add_argument("--base_path", type=str, default=None, help="Path to start generating model type for.")
    parser.add_argument("--change_param", type=str, default=None, help="Which hyperparameter to change when training")

    args = parser.parse_args()

    # dispatch table instead of an if/elif chain
    trainers = {
        "lstm": train_n_models,
        "tfmr": train_n_tfmrs,
    }
    train_fn = trainers.get(args.multi_train_type)
    if train_fn is None:
        raise ValueError(f"Improper input {args.multi_train_type}")
    train_fn(num_models=args.multi_train_count,
             base_path=args.base_path,
             args=args)

if __name__ == "__main__":
    main()
stanza/stanza/models/lemma_classifier/train_transformer_model.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file contains code used to train a baseline transformer model to classify on a lemma of a particular token.
3
+ """
4
+
5
+ import argparse
6
+ import os
7
+ import sys
8
+ import logging
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+
14
+ from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer
15
+ from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
16
+ from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer
17
+ from stanza.models.common.utils import default_device
18
+
19
+ logger = logging.getLogger('stanza.lemmaclassifier')
20
+
21
class TransformerBaselineTrainer(BaseLemmaClassifierTrainer):
    """
    Trainer for a baseline transformer model that classifies token lemmas.

    The model architecture itself is defined in `transformer_model.py`
    in this directory.
    """

    def __init__(self, model_args: dict, transformer_name: str = "roberta", loss_func: str = "ce", lr: float = 0.001):
        """
        Creates the Trainer object.

        Args:
            model_args (dict): arguments forwarded to the underlying model
            transformer_name (str, optional): Which transformer to use for embeddings. Defaults to "roberta".
            loss_func (str, optional): Which loss function to use (either 'ce' or 'weighted_bce'). Defaults to "ce".
            lr (float, optional): Learning rate for the optimizer. Defaults to 0.001.

        Raises:
            ValueError: if loss_func is not one of the supported names
        """
        super().__init__()

        self.model_args = model_args

        # Select the training criterion; weighted_bce adds class weights at train time
        if loss_func == "ce":
            self.weighted_loss = False
            self.criterion = nn.CrossEntropyLoss()
        elif loss_func == "weighted_bce":
            self.weighted_loss = True  # weights are attached during training
            self.criterion = nn.BCEWithLogitsLoss()
        else:
            raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

        self.transformer_name = transformer_name
        self.lr = lr

    def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> torch.optim:
        """
        Build an Adam optimizer with separate learning rates for the
        transformer layer and the MLP layer of the model.

        Returns (torch.optim): An Adam optimizer with per-layer learning rates.

        Currently unused - could be refactored into the parent class's train method,
        or the parent class could call a build_optimizer and this subclass would use the optimizer.
        """
        groups = {"transformer": [], "mlp": []}
        for param_name, param in self.model.named_parameters():
            if "transformer" in param_name:
                groups["transformer"].append(param)
            elif "mlp" in param_name:
                groups["mlp"].append(param)
        return optim.Adam([
            {"params": groups["transformer"], "lr": transformer_lr},
            {"params": groups["mlp"], "lr": mlp_lr},
        ])

    def build_model(self, label_decoder, upos_to_id, known_words, target_words, target_upos):
        # upos_to_id and known_words are unused by the transformer baseline;
        # they exist for interface parity with the LSTM trainer
        model = LemmaClassifierWithTransformer(model_args=self.model_args,
                                               output_dim=self.output_dim,
                                               transformer_name=self.transformer_name,
                                               label_decoder=label_decoder,
                                               target_words=target_words,
                                               target_upos=target_upos)
        return model
77
+
78
+
79
def main(args=None, predefined_args=None):
    """
    Parse command-line arguments (unless predefined_args is supplied) and
    train a transformer-based lemma classifier.

    Returns:
        The trainer after training completes.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "saved_models", "big_model_roberta_weighted_loss.pt"), help="Path to model save file")
    parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs")
    parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_train.txt"), help="Full path to training file")
    parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')")
    parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta")
    parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch")
    parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.")
    parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file')

    if predefined_args is None:
        args = parser.parse_args(args)
    else:
        args = predefined_args

    save_name = args.save_name
    num_epochs = args.num_epochs
    train_file = args.train_file
    loss_fn = args.loss_fn
    eval_file = args.eval_file
    lr = args.lr

    args = vars(args)

    # Resolve the shorthand model_type into a concrete HF transformer name
    default_transformers = {'bert': 'bert-base-uncased', 'roberta': 'roberta-base'}
    model_type = args['model_type']
    if model_type in default_transformers:
        args['bert_model'] = default_transformers[model_type]
    elif model_type == 'transformer':
        if args['bert_model'] is None:
            raise ValueError("Need to specify a bert_model for model_type transformer!")
    else:
        raise ValueError("Unknown model type " + model_type)

    # Refuse to clobber an existing model unless --force was given
    if os.path.exists(save_name) and not args.get('force', False):
        raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...")
    if not os.path.exists(train_file):
        raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.")

    logger.info("Running training script with the following args:")
    for key, value in args.items():
        logger.info(f"{key}: {value}")
    logger.info("------------------------------------------------------------")

    trainer = TransformerBaselineTrainer(model_args=args, transformer_name=args['bert_model'], loss_func=loss_fn, lr=lr)
    trainer.train(num_epochs=num_epochs, save_name=save_name, train_file=train_file, args=args, eval_file=eval_file)
    return trainer
128
+
129
+ if __name__ == "__main__":
130
+ main()
stanza/stanza/models/lemma_classifier/transformer_model.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ import sys
5
+ import logging
6
+
7
+ from transformers import AutoTokenizer, AutoModel
8
+ from typing import Mapping, List, Tuple, Any
9
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
10
+ from stanza.models.common.bert_embedding import extract_bert_embeddings
11
+ from stanza.models.lemma_classifier.base_model import LemmaClassifier
12
+ from stanza.models.lemma_classifier.constants import ModelType
13
+
14
+ logger = logging.getLogger('stanza.lemmaclassifier')
15
+
16
class LemmaClassifierWithTransformer(LemmaClassifier):
    def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set, target_upos: set):
        """
        Model architecture:

        Use a transformer (BERT or RoBERTa) to extract contextual embedding over a sentence.
        Get the embedding for the word that is to be classified on, and feed the embedding
        as input to an MLP classifier that has 2 linear layers, and a prediction head.

        Args:
            model_args (dict): args for the model
            output_dim (int): Dimension of the output from the MLP
            transformer_name (str): name of the HF transformer to use
            label_decoder (dict): a map of the labels available to the model
            target_words (set(str)): a set of the words which might need lemmatization
            target_upos (set(str)): upos tags associated with the target words (stored by the base class)
        """
        super(LemmaClassifierWithTransformer, self).__init__(label_decoder, target_words, target_upos)
        self.model_args = model_args

        # Choose transformer
        self.transformer_name = transformer_name
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True)
        # registered as an *unsaved* module: get_save_dict below strips its weights
        # from the checkpoint, so they are re-downloaded from HF on load instead
        self.add_unsaved_module("transformer", AutoModel.from_pretrained(transformer_name))
        config = self.transformer.config

        # MLP input size matches the transformer's hidden size
        embedding_size = config.hidden_size

        # define an MLP layer
        self.mlp = nn.Sequential(
            nn.Linear(embedding_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def get_save_dict(self):
        """
        Build a checkpoint dict; parameters belonging to unsaved modules
        (the pretrained transformer) are removed before returning.
        """
        save_dict = {
            "params": self.state_dict(),
            "label_decoder": self.label_decoder,
            "target_words": list(self.target_words),
            "target_upos": list(self.target_upos),
            "model_type": self.model_type().name,
            "args": self.model_args,
        }
        # drop the transformer weights - they are restored from HF at load time
        skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)]
        for k in skipped:
            del save_dict["params"][k]
        return save_dict

    def convert_tags(self, upos_tags: List[List[str]]):
        # this model does not use upos tags; returning None keeps the
        # interface shared with models that do
        return None

    def forward(self, idx_positions: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):
        """
        Computes the forward pass of the transformer baselines

        Args:
            idx_positions (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.
            sentences (List[List[str]]): A list of the token-split sentences of the input data.
            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence - not used in this model, here for compatibility

        Returns:
            torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences.
        """
        device = next(self.transformer.parameters()).device
        bert_embeddings = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device,
                                                  keep_endpoints=False, num_layers=1, detach=True)
        # pick out the embedding of the target token in each sentence
        embeddings = [emb[idx] for idx, emb in zip(idx_positions, bert_embeddings)]
        # NOTE(review): the trailing [:, :, 0] looks like it squeezes out the
        # single layer requested via num_layers=1 (i.e. each embedding is
        # (hidden, num_layers)) - confirm against extract_bert_embeddings
        embeddings = torch.stack(embeddings, dim=0)[:, :, 0]
        # pass to the MLP
        output = self.mlp(embeddings)
        return output

    def model_type(self):
        # used by the saving/loading machinery to pick the right class
        return ModelType.TRANSFORMER
stanza/stanza/models/lemma_classifier/utils.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ from typing import List, Tuple, Any, Mapping
7
+
8
+ import stanza
9
+ import torch
10
+
11
+ from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
12
+
13
+ logger = logging.getLogger('stanza.lemmaclassifier')
14
+
15
class Dataset:
    def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True):
        """
        Loads a JSON data file and prepares it for batched iteration over
        tokenized sentences, target-token indices, UPOS IDs, and labels.

        Args:
            data_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.
            batch_size (int): Size of each batch of examples
            get_counts (optional, bool): Whether there should be a map of the label index to counts
            label_decoder (optional, dict): an existing label-to-ID map to extend (a copy is made); new labels found in the file get fresh IDs
            shuffle (optional, bool): whether __iter__ shuffles examples each pass

        Attributes set on the instance:
            sentences (List[List[str]]): token-split sentences
            indices (List[int]): position of the target token in each sentence
            upos_ids (List[List[int]]): UPOS IDs per token (padded later, not a tensor)
            labels (List[int]): label ID for each target token's lemma
            counts (Counter): label ID -> count (populated only if get_counts)
            label_decoder (Mapping[str, int]): label -> ID
            upos_to_id (Mapping[str, int]): UPOS tag -> ID
            target_upos: the 'upos' entry of the input JSON
            known_words (List[str]): all lowercased words seen, sorted
            target_words (set(str)): lowercased words at the target positions
        """

        if data_path is None or not os.path.exists(data_path):
            raise FileNotFoundError(f"Data file {data_path} could not be found.")

        if label_decoder is None:
            label_decoder = {}
        else:
            # if labels in the test set aren't in the original model,
            # the model will never predict those labels,
            # but we can still use those labels in a confusion matrix
            label_decoder = dict(label_decoder)

        logger.debug("Final label decoder: %s Should be strings to ints", label_decoder)

        # words which we are analyzing
        target_words = set()

        # all known words in the dataset, not just target words
        known_words = set()

        with open(data_path, "r+", encoding="utf-8") as fin:
            sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), {}

            input_json = json.load(fin)
            sentences_data = input_json['sentences']
            self.target_upos = input_json['upos']

            for idx, sentence in enumerate(sentences_data):
                # TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatibility reasons
                words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma")
                if None in [words, target_idx, upos_tags, label]:
                    raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}")

                # assign IDs in first-seen order so the mapping is reproducible
                label_id = label_decoder.get(label, None)
                if label_id is None:
                    label_decoder[label] = len(label_decoder) # create a new ID for the unknown label

                converted_upos_tags = [] # convert upos tags to upos IDs
                for upos_tag in upos_tags:
                    if upos_tag not in upos_to_id:
                        upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag
                    converted_upos_tags.append(upos_to_id[upos_tag])

                sentences.append(words)
                indices.append(target_idx)
                upos_ids.append(converted_upos_tags)
                labels.append(label_decoder[label])

                if get_counts:
                    counts[label_decoder[label]] += 1

                target_words.add(words[target_idx])
                known_words.update(words)

        self.sentences = sentences
        self.indices = indices
        self.upos_ids = upos_ids
        self.labels = labels

        self.counts = counts
        self.label_decoder = label_decoder
        self.upos_to_id = upos_to_id

        self.batch_size = batch_size
        self.shuffle = shuffle

        # lowercase both word collections for case-insensitive lookup later
        self.known_words = [x.lower() for x in sorted(known_words)]
        self.target_words = set(x.lower() for x in target_words)

    def __len__(self):
        """
        Number of batches, rounded up to nearest batch
        """
        return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0)

    def __iter__(self):
        """
        Yield batches of (sentences, index tensor, upos ID lists, label tensor).

        When self.shuffle is set, the example order is re-randomized on every
        fresh iteration; upos_ids stay as plain lists so they can be padded later.
        """
        num_sentences = len(self.sentences)
        indices = list(range(num_sentences))
        if self.shuffle:
            random.shuffle(indices)
        for i in range(self.__len__()):
            batch_start = self.batch_size * i
            # the final batch may be smaller than batch_size
            batch_end = min(batch_start + self.batch_size, num_sentences)

            batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]]
            batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]])
            batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]]
            batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]])
            yield batch_sentences, batch_indices, batch_upos_ids, batch_labels
123
+
124
def extract_unknown_token_indices(tokenized_indices: torch.tensor, unknown_token_idx: int) -> List[int]:
    """
    Find every position in `tokenized_indices` whose value is `unknown_token_idx`.

    Args:
        tokenized_indices (torch.tensor): A tensor filled with tokenized indices of words that have been mapped to vector indices.
        unknown_token_idx (int): The special index for which unknown tokens are marked in the word vectors.

    Returns:
        List[int]: positions in `tokenized_indices` holding `unknown_token_idx`
    """
    matches = []
    for position, token_index in enumerate(tokenized_indices):
        if token_index == unknown_token_idx:
            matches.append(position)
    return matches
136
+
137
+
138
def get_device():
    """
    Get the device to run computations on.

    Preference order: CUDA, then Apple MPS, then CPU.

    Bug fixes vs the previous version: `torch.cuda.is_available` was
    referenced without calling it (a function object is always truthy),
    and the following independent if/else unconditionally overwrote the
    CUDA choice with mps or cpu, so CUDA was never actually selected.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
150
+
151
+
152
def round_up_to_multiple(number, multiple):
    """
    Round `number` up to the nearest multiple of `multiple`.

    Args:
        number: the value to round
        multiple: the step to round up to

    Returns:
        The smallest multiple of `multiple` that is >= `number`.

    Raises:
        ValueError: if `multiple` is zero.

    Bug fix: the previous version *returned an error string* when
    multiple == 0, which would silently poison arithmetic at the call
    site; raising makes the misuse explicit.
    """
    if multiple == 0:
        raise ValueError("The second number (multiple) cannot be zero.")

    # Calculate the remainder when dividing the number by the multiple
    remainder = number % multiple

    # If remainder is non-zero, round up to the next multiple
    if remainder != 0:
        return number + (multiple - remainder)
    return number  # already a multiple - no rounding needed
166
+
167
+
168
def main():
    """
    Smoke test: load the default dev set and report basic statistics.

    Bug fix: the previous version called `load_dataset`, which no longer
    exists in this module - loading is done by the Dataset class above.
    """
    default_test_path = os.path.join(os.path.dirname(__file__), "test_sets", "processed_ud_en", "combined_dev.txt") # get the GUM stuff
    dataset = Dataset(default_test_path, get_counts=True)
    logger.info("Loaded %d sentences with label counts %s", len(dataset.sentences), dataset.counts)
171
+
172
+ if __name__ == "__main__":
173
+ main()
stanza/stanza/models/mwt/character_classifier.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Classify characters based on an LSTM with learned character representations
3
+ """
4
+
5
+ import logging
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ import stanza.models.common.seq2seq_constant as constant
11
+
12
+ logger = logging.getLogger('stanza')
13
+
14
class CharacterClassifier(nn.Module):
    """
    A bidirectional LSTM over learned character embeddings which emits,
    for every character position, 2 logits (used by the MWT trainer to
    decide whether to cut the token at that character).
    """
    def __init__(self, args):
        """
        Args:
            args (dict): must contain 'vocab_size', 'emb_dim', 'hidden_dim',
                'num_layers', and 'dropout'; 'emb_dropout' is optional
        """
        super().__init__()

        self.vocab_size = args['vocab_size']
        self.emb_dim = args['emb_dim']
        self.hidden_dim = args['hidden_dim']
        self.nlayers = args['num_layers'] # lstm encoder layers
        self.pad_token = constant.PAD_ID
        self.enc_hidden_dim = self.hidden_dim // 2 # since it is bidirectional

        # two logits per character: cut / don't cut
        self.num_outputs = 2

        self.args = args

        self.emb_dropout = args.get('emb_dropout', 0.0)
        self.emb_drop = nn.Dropout(self.emb_dropout)
        self.dropout = args['dropout']

        # padding_idx = PAD_ID so pad positions embed to zero
        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
        self.input_dim = self.emb_dim
        # LSTM dropout only applies between layers, hence the nlayers > 1 guard
        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
                bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)

        # per-position classification head over the concatenated directions
        self.output_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.num_outputs))

    def encode(self, enc_inputs, lens):
        """ Encode source sequence. """
        # pack so the LSTM skips pad positions; final hidden states are unused
        packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
        packed_h_in, (hn, cn) = self.encoder(packed_inputs)
        return packed_h_in

    def embed(self, src, src_mask):
        # the input data could have characters outside the known range
        # of characters in cases where the vocabulary was temporarily
        # expanded (note that this model does nothing with those chars)
        embed_src = src.clone()
        embed_src[embed_src >= self.vocab_size] = constant.UNK_ID
        enc_inputs = self.emb_drop(self.embedding(embed_src))
        batch_size = enc_inputs.size(0)
        # NOTE(review): lengths are computed as the count of positions equal
        # to PAD_ID in the mask, which presumes src_mask marks *valid*
        # positions with PAD_ID (same convention as the mwt Trainer) - confirm
        # against the data loader
        src_lens = list(src_mask.data.eq(self.pad_token).long().sum(1))
        return enc_inputs, batch_size, src_lens, src_mask

    def forward(self, src, src_mask):
        """
        Return per-character logits of shape (batch, seq_len, 2).
        """
        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask)
        encoded = self.encode(enc_inputs, src_lens)
        # unpack back to a padded (batch, seq_len, hidden) tensor
        encoded, _ = nn.utils.rnn.pad_packed_sequence(encoded, batch_first=True)
        logits = self.output_layer(encoded)
        return logits
stanza/stanza/models/mwt/trainer.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A trainer class to handle training and testing of models.
3
+ """
4
+
5
+ import sys
6
+ import numpy as np
7
+ from collections import Counter
8
+ import logging
9
+ import torch
10
+ from torch import nn
11
+ import torch.nn.init as init
12
+
13
+ import stanza.models.common.seq2seq_constant as constant
14
+ from stanza.models.common.trainer import Trainer as BaseTrainer
15
+ from stanza.models.common.seq2seq_model import Seq2SeqModel
16
+ from stanza.models.common import utils, loss
17
+ from stanza.models.mwt.character_classifier import CharacterClassifier
18
+ from stanza.models.mwt.vocab import Vocab
19
+
20
+ logger = logging.getLogger('stanza')
21
+
22
def unpack_batch(batch, device):
    """ Unpack a batch from the data loader.

    Moves the first four entries (the tensors) to the given device,
    leaving None entries alone; the last two entries are the original
    text and the original ordering indices.
    """
    tensors = []
    for item in batch[:4]:
        tensors.append(item if item is None else item.to(device))
    return tensors, batch[4], batch[5]
28
+
29
class Trainer(BaseTrainer):
    """ A trainer for training models. """
    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None):
        """
        Either load a saved trainer from model_file, or build a fresh one
        from args/vocab.  The model is a CharacterClassifier when
        force_exact_pieces is set, otherwise a Seq2SeqModel; with
        dict_only there is no neural model at all.
        """
        if model_file is not None:
            # load from file
            self.load(model_file)
        else:
            self.args = args
            if args['dict_only']:
                self.model = None
            elif args.get('force_exact_pieces', False):
                self.model = CharacterClassifier(args)
            else:
                self.model = Seq2SeqModel(args, emb_matrix=emb_matrix)
            self.vocab = vocab
            self.expansion_dict = dict()
        if not self.args['dict_only']:
            self.model = self.model.to(device)
            if self.args.get('force_exact_pieces', False):
                self.crit = nn.CrossEntropyLoss()
            else:
                self.crit = loss.SequenceLoss(self.vocab.size).to(device)
            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'])

    def update(self, batch, eval=False):
        """
        Run one forward (and, unless eval, backward + optimizer) step on a
        batch; returns the scalar loss value.
        """
        device = next(self.model.parameters()).device
        # ignore the original text when training
        # can try to learn the correct values, even if we eventually
        # copy directly from the original text
        inputs, _, orig_idx = unpack_batch(batch, device)
        src, src_mask, tgt_in, tgt_out = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()
        if self.args.get('force_exact_pieces', False):
            # character classifier path: per-character cut/no-cut loss,
            # packed so pad positions do not contribute
            log_probs = self.model(src, src_mask)
            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
            packed_output = nn.utils.rnn.pack_padded_sequence(log_probs, src_lens, batch_first=True)
            packed_tgt = nn.utils.rnn.pack_padded_sequence(tgt_in, src_lens, batch_first=True)
            # NOTE: the local name `loss` shadows the imported `loss` module
            # within this method
            loss = self.crit(packed_output.data, packed_tgt.data)
        else:
            # seq2seq path: token-level sequence loss over the full vocab
            log_probs, _ = self.model(src, src_mask, tgt_in)
            loss = self.crit(log_probs.view(-1, self.vocab.size), tgt_out.view(-1))
        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()
        return loss_val

    def predict(self, batch, unsort=True, never_decode_unk=False, vocab=None):
        """
        Predict the expansion of each MWT in the batch; returns a list of
        strings (tokens separated by spaces within each expansion).
        """
        if vocab is None:
            vocab = self.vocab

        device = next(self.model.parameters()).device
        inputs, orig_text, orig_idx = unpack_batch(batch, device)
        src, src_mask, tgt, tgt_mask = inputs

        self.model.eval()
        batch_size = src.size(0)  # NOTE(review): currently unused
        if self.args.get('force_exact_pieces', False):
            # classifier path: insert a space wherever logit 1 beats logit 0
            log_probs = self.model(src, src_mask)
            cuts = log_probs[:, :, 1] > log_probs[:, :, 0]
            src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
            pred_tokens = []
            for src_ids, cut, src_len in zip(src, cuts, src_lens):
                src_chars = vocab.unmap(src_ids)
                pred_seq = []
                # range skips the sequence start/end marker positions
                for char_idx in range(1, src_len-1):
                    if cut[char_idx]:
                        pred_seq.append(' ')
                    pred_seq.append(src_chars[char_idx])
                pred_seq = "".join(pred_seq).strip()
                pred_tokens.append(pred_seq)
        else:
            # seq2seq path: beam-decode character sequences
            preds, _ = self.model.predict(src, src_mask, self.args['beam_size'], never_decode_unk=never_decode_unk)
            pred_seqs = [vocab.unmap(ids) for ids in preds] # unmap to tokens
            pred_seqs = utils.prune_decoded_seqs(pred_seqs)

            pred_tokens = ["".join(seq) for seq in pred_seqs] # join chars to be tokens
            # if any tokens are predicted to expand to blank,
            # that is likely an error. use the original text
            # this originally came up with the Spanish model turning 's' into a blank
            # furthermore, if there are no spaces predicted by the seq2seq,
            # might as well use the original in case the seq2seq went crazy
            # this particular error came up training a Hebrew MWT
            pred_tokens = [x if x and ' ' in x else y for x, y in zip(pred_tokens, orig_text)]
        if unsort:
            # restore the caller's original example order
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens

    def train_dict(self, pairs):
        """ Train a MWT expander given training word-expansion pairs. """
        # accumulate counter
        ctr = Counter()
        ctr.update([(p[0], p[1]) for p in pairs])
        seen = set()
        # find the most frequent mappings: most_common order means the first
        # expansion recorded for each word is its most frequent one
        for p, _ in ctr.most_common():
            w, l = p
            if w not in seen and w != l:
                self.expansion_dict[w] = l
                seen.add(w)
        return

    def dict_expansion(self, word):
        """
        Check the expansion dictionary for the word along with a couple common lowercasings of the word

        (Leadingcase and UPPERCASE)

        Returns None if no expansion is known.
        """
        expansion = self.expansion_dict.get(word)
        if expansion is not None:
            return expansion

        if word.isupper():
            expansion = self.expansion_dict.get(word.lower())
            if expansion is not None:
                return expansion.upper()

        if word[0].isupper() and word[1:].islower():
            expansion = self.expansion_dict.get(word.lower())
            if expansion is not None:
                return expansion[0].upper() + expansion[1:]

        # could build a truecasing model of some kind to handle cRaZyCaSe...
        # but that's probably too much effort
        return None

    def predict_dict(self, words):
        """ Predict a list of expansions given words.

        Words without a known expansion are passed through unchanged.
        """
        expansions = []
        for w in words:
            expansion = self.dict_expansion(w)
            if expansion is not None:
                expansions.append(expansion)
            else:
                expansions.append(w)
        return expansions

    def ensemble(self, cands, other_preds):
        """ Ensemble the dict with statistical model predictions.

        The dictionary expansion wins whenever one exists; otherwise the
        model's prediction is used.
        """
        expansions = []
        assert len(cands) == len(other_preds)
        for c, pred in zip(cands, other_preds):
            expansion = self.dict_expansion(c)
            if expansion is not None:
                expansions.append(expansion)
            else:
                expansions.append(pred)
        return expansions

    def save(self, filename):
        """
        Serialize model weights, the expansion dict, vocab, and config;
        failures are logged but deliberately not fatal.
        """
        params = {
            'model': self.model.state_dict() if self.model is not None else None,
            'dict': self.expansion_dict,
            'vocab': self.vocab.state_dict(),
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException:
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename):
        """
        Restore a trainer previously written by save(); raises on failure.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        self.expansion_dict = checkpoint['dict']
        if not self.args['dict_only']:
            if self.args.get('force_exact_pieces', False):
                self.model = CharacterClassifier(self.args)
            else:
                self.model = Seq2SeqModel(self.args)
            # could remove strict=False after rebuilding all models,
            # or could switch to 1.6.0 torch with the buffer in seq2seq persistent=False
            self.model.load_state_dict(checkpoint['model'], strict=False)
        else:
            self.model = None
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
+
stanza/stanza/models/mwt/vocab.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+
3
+ from stanza.models.common.vocab import BaseVocab
4
+ import stanza.models.common.seq2seq_constant as constant
5
+
6
class Vocab(BaseVocab):
    """ Character vocab for the MWT expander, built from (src, tgt) pairs. """

    def build_vocab(self):
        """ Count every character in both sides of the training pairs and
        order units by descending frequency (after the special prefix). """
        char_counts = Counter("".join(src + tgt for src, tgt in self.data))
        # stable sort keeps first-seen order among equally frequent chars
        ordered_units = sorted(char_counts, key=char_counts.get, reverse=True)
        self._id2unit = constant.VOCAB_PREFIX + ordered_units
        self._unit2id = {unit: unit_id for unit_id, unit in enumerate(self._id2unit)}

    def add_unit(self, unit):
        """ Append a previously unseen unit; no-op if already present. """
        if unit in self._unit2id:
            return
        self._unit2id[unit] = len(self._id2unit)
        self._id2unit.append(unit)
stanza/stanza/models/ner/vocab.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter, OrderedDict
2
+
3
+ from stanza.models.common.vocab import BaseVocab, BaseMultiVocab, CharVocab, CompositeVocab
4
+ from stanza.models.common.vocab import VOCAB_PREFIX
5
+ from stanza.models.common.pretrain import PretrainedWordVocab
6
+ from stanza.models.pos.vocab import WordVocab
7
+
8
class TagVocab(BaseVocab):
    """ A vocab for the output tag sequence. """
    def build_vocab(self):
        # count tags at self.idx across all sentences, then order by
        # descending frequency (stable sort preserves first-seen tie order)
        tag_counts = Counter(w[self.idx] for sent in self.data for w in sent)
        self._id2unit = VOCAB_PREFIX + sorted(tag_counts, key=tag_counts.get, reverse=True)
        self._unit2id = {tag: tag_id for tag_id, tag in enumerate(self._id2unit)}
15
+
16
def convert_tag_vocab(state_dict):
    """
    Convert a saved TagVocab state dict into an equivalent CompositeVocab.

    Raises AssertionError if the state dict has 'lower' set, or if the
    rebuilt vocab does not reproduce the original unit list exactly.
    """
    if state_dict['lower']:
        raise AssertionError("Did not expect an NER vocab with 'lower' set to True")
    # drop the special prefix entries; CompositeVocab re-adds its own
    items = state_dict['_id2unit'][len(VOCAB_PREFIX):]
    # this looks silly, but the vocab builder treats this as words with multiple fields
    # (we set it to look for field 0 with idx=0)
    # and then the label field is expected to be a list or tuple of items
    items = [[[[x]]] for x in items]
    vocab = CompositeVocab(data=items, lang=state_dict['lang'], idx=0, sep=None)
    # sanity check: the conversion must be lossless and order-preserving
    if len(vocab._id2unit[0]) != len(state_dict['_id2unit']):
        raise AssertionError("Failed to construct a new vocab of the same length as the original")
    if vocab._id2unit[0] != state_dict['_id2unit']:
        raise AssertionError("Failed to construct a new vocab in the same order as the original")
    return vocab
30
+
31
class MultiVocab(BaseMultiVocab):
    def state_dict(self):
        """ Also save a vocab name to class name mapping in state dict. """
        state = OrderedDict()
        key2class = OrderedDict()
        for vocab_name, vocab in self._vocabs.items():
            state[vocab_name] = vocab.state_dict()
            key2class[vocab_name] = type(vocab).__name__
        state['_key2class'] = key2class
        return state

    @classmethod
    def load_state_dict(cls, state_dict):
        """ Rebuild each sub-vocab using the saved class-name mapping;
        legacy TagVocab entries are converted to CompositeVocab. """
        loaders = {
            'CharVocab': CharVocab.load_state_dict,
            'PretrainedWordVocab': PretrainedWordVocab.load_state_dict,
            'TagVocab': convert_tag_vocab,
            'CompositeVocab': CompositeVocab.load_state_dict,
            'WordVocab': WordVocab.load_state_dict,
        }
        new = cls()
        assert '_key2class' in state_dict, "Cannot find class name mapping in state dict!"
        key2class = state_dict.pop('_key2class')
        for vocab_name, sub_state in state_dict.items():
            new[vocab_name] = loaders[key2class[vocab_name]](sub_state)
        return new
56
+
stanza/stanza/models/pos/__init__.py ADDED
File without changes
stanza/stanza/models/pos/build_xpos_vocab_factory.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from collections import defaultdict
3
+ import logging
4
+ import os
5
+ import re
6
+ import sys
7
+ from zipfile import ZipFile
8
+
9
+ from stanza.models.common.constant import treebank_to_short_name
10
+ from stanza.models.pos.xpos_vocab_utils import DEFAULT_KEY, choose_simplest_factory, XPOSType
11
+ from stanza.models.common.doc import *
12
+ from stanza.utils.conll import CoNLL
13
+ from stanza.utils import default_paths
14
+
15
+ SHORTNAME_RE = re.compile("[a-z-]+_[a-z0-9]+")
16
+ DATA_DIR = default_paths.get_default_paths()['POS_DATA_DIR']
17
+
18
+ logger = logging.getLogger('stanza')
19
+
20
def get_xpos_factory(shorthand, fn):
    """Pick the simplest XPOS vocab description for one treebank.

    shorthand: the dataset shorthand (e.g. en_ewt) used to locate the
      prepared training file under DATA_DIR
    fn: the treebank's full name, used only in the error message

    Reads the training data either from a .conllu file or, failing that,
    from a zip of conllu files (using the first file which actually has
    xpos tags), then delegates to choose_simplest_factory.

    Raises FileNotFoundError if no prepared training data exists, and
    ValueError if a zip was found but none of its files had xpos tags.
    """
    logger.info('Resolving vocab option for {}...'.format(shorthand))
    doc = None
    train_file = os.path.join(DATA_DIR, '{}.train.in.conllu'.format(shorthand))
    if os.path.exists(train_file):
        doc = CoNLL.conll2doc(input_file=train_file)
    else:
        zip_file = os.path.join(DATA_DIR, '{}.train.in.zip'.format(shorthand))
        if os.path.exists(zip_file):
            with ZipFile(zip_file) as zin:
                for train_file in zin.namelist():
                    doc = CoNLL.conll2doc(input_file=train_file, zip_file=zip_file)
                    if any(word.xpos for sentence in doc.sentences for word in sentence.words):
                        break
                else:
                    raise ValueError('Found training data in {}, but none of the files contained had xpos'.format(zip_file))

    if doc is None:
        # without the training file, there's not much we can do
        raise FileNotFoundError('Training data for {} not found. To generate the XPOS vocabulary '
                                'for this treebank properly, please run the following command first:\n'
                                ' python3 stanza/utils/datasets/prepare_pos_treebank.py {}'.format(fn, fn))

    data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
    return choose_simplest_factory(data, shorthand)
47
+
48
def main():
    """Regenerate stanza/models/pos/xpos_vocab_factory.py from prepared treebanks.

    For each treebank we find the XPOS vocab configuration that minimizes the
    total number of classes the tagger must predict, then emit a python module
    mapping shorthand -> XPOSDescription along with the xpos_vocab_factory
    function itself.

    Fix: removed the unused local variable `first`, which was set but never read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--treebanks', type=str, default=DATA_DIR, help="Treebanks to process - directory with processed datasets or a file with a list")
    parser.add_argument('--output_file', type=str, default="stanza/models/pos/xpos_vocab_factory.py", help="Where to write the results")
    args = parser.parse_args()

    output_file = args.output_file
    if os.path.isdir(args.treebanks):
        # if the path is a directory of datasets (which is the default if --treebanks is not set)
        # we use those datasets to prepare the xpos factories
        treebanks = os.listdir(args.treebanks)
        treebanks = [x.split(".", maxsplit=1)[0] for x in treebanks]
        treebanks = sorted(set(treebanks))
    elif os.path.exists(args.treebanks):
        # maybe it's a file with a list of names
        with open(args.treebanks) as fin:
            treebanks = sorted(set([x.strip() for x in fin.readlines() if x.strip()]))
    else:
        raise ValueError("Cannot figure out which treebanks to use. Please set the --treebanks parameter")

    logger.info("Processing the following treebanks: %s" % " ".join(treebanks))

    # keep both the full name (for error messages) and the shorthand (for keys)
    shorthands = []
    fullnames = []
    for treebank in treebanks:
        fullnames.append(treebank)
        if SHORTNAME_RE.match(treebank):
            shorthands.append(treebank)
        else:
            shorthands.append(treebank_to_short_name(treebank))

    # For each treebank, we would like to find the XPOS Vocab configuration that minimizes
    # the number of total classes needed to predict by all tagger classifiers. This is
    # achieved by enumerating different options of separators that different treebanks might
    # use, and comparing that to treating the XPOS tags as separate categories (using a
    # WordVocab).
    mapping = defaultdict(list)
    for sh, fn in zip(shorthands, fullnames):
        factory = get_xpos_factory(sh, fn)
        mapping[factory].append(sh)
        # a couple datasets are known under an alternate shorthand as well
        if sh == 'zh-hans_gsdsimp':
            mapping[factory].append('zh_gsdsimp')
        elif sh == 'no_bokmaal':
            mapping[factory].append('nb_bokmaal')

    # en_test is a tiny dataset used by the unittests
    mapping[DEFAULT_KEY].append('en_test')

    # Generate code. This takes the XPOS vocabulary classes selected above, and generates the
    # actual factory class as seen in models.pos.xpos_vocab_factory.
    with open(output_file, 'w') as f:
        max_len = max(max(len(x) for x in mapping[key]) for key in mapping)
        print('''# This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.
# Please don't edit it!

import logging

from stanza.models.pos.vocab import WordVocab, XPOSVocab
from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory

# using a sublogger makes it easier to test in the unittests
logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')

XPOS_DESCRIPTIONS = {''', file=f)

        for key_idx, key in enumerate(mapping):
            if key_idx > 0:
                print(file=f)
            for shorthand in sorted(mapping[key]):
                # +2 to max_len for the ''
                # this format string is left justified (either would be okay, probably)
                if key.sep is None:
                    sep = 'None'
                else:
                    sep = "'%s'" % key.sep
                print(("    {:%ds}: XPOSDescription({}, {})," % (max_len+2)).format("'%s'" % shorthand, key.xpos_type, sep), file=f)

        print('''}

def xpos_vocab_factory(data, shorthand):
    if shorthand not in XPOS_DESCRIPTIONS:
        logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand)
    desc = choose_simplest_factory(data, shorthand)
    if shorthand in XPOS_DESCRIPTIONS:
        if XPOS_DESCRIPTIONS[shorthand] != desc:
            # log instead of throw
            # otherwise, updating datasets would be unpleasant
            logger.error("XPOS tagset in %s has apparently changed! Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)
    else:
        logger.warning("Chose %s for the xpos factory for %s", desc, shorthand)
    return build_xpos_vocab(desc, data, shorthand)
''', file=f)

    logger.info('Done!')
142
+
143
# standard script entry point: regenerate the xpos vocab factory module
if __name__ == "__main__":
    main()
stanza/stanza/models/pos/data.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import logging
3
+ import copy
4
+ import torch
5
+ from collections import namedtuple
6
+
7
+ from torch.utils.data import DataLoader as DL
8
+ from torch.utils.data.sampler import Sampler
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ from stanza.models.common.bert_embedding import filter_data, needs_length_filter
12
+ from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
13
+ from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, CharVocab
14
+ from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
15
+ from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
16
+ from stanza.models.common.doc import *
17
+
18
+ logger = logging.getLogger('stanza')
19
+
20
# One preprocessed sentence: parallel lists of word ids, per-word char ids,
# tag/feature ids, pretrained-embedding ids, and the raw token text
DataSample = namedtuple("DataSample", "word char upos xpos feats pretrain text")
# One collated minibatch: padded tensors plus masks, sort/restore indices,
# lengths, raw text, and the original dataset indices (built in Dataset.__collate_fn)
DataBatch = namedtuple("DataBatch", "words words_mask wordchars wordchars_mask upos xpos ufeats pretrained orig_idx word_orig_idx lens word_lens text idx")
22
+
23
class Dataset:
    """Preprocessed POS tagging data for one document.

    Sentences are converted to DataSample tuples of vocab ids at construction
    time; padding/sorting into DataBatch tuples happens in __collate_fn via
    the DataLoader returned from to_loader().
    """
    def __init__(self, doc, args, pretrain, vocab=None, evaluation=False, sort_during_eval=False, bert_tokenizer=None, **kwargs):
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc

        # build the vocab from this doc unless one was supplied (e.g. at eval time)
        if vocab is None:
            self.vocab = Dataset.init_vocab([doc], args)
        else:
            self.vocab = vocab

        # a column counts as present unless it is entirely empty ('_' or None)
        self.has_upos = not all(x is None or x == '_' for x in doc.get(UPOS, as_sentences=False))
        self.has_xpos = not all(x is None or x == '_' for x in doc.get(XPOS, as_sentences=False))
        self.has_feats = not all(x is None or x == '_' for x in doc.get(FEATS, as_sentences=False))

        data = self.load_doc(self.doc)
        # filter out the long sentences if bert is used
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # handle pretrain; pretrain vocab is used when args['pretrain'] == True and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)

        self.data = data

        self.num_examples = len(data)
        # upos id(s) treated as punctuation by the no-punct augmentation in __mask
        self.__punct_tags = self.vocab["upos"].map(["PUNCT"])
        self.augment_nopunct = self.args.get("augment_nopunct", 0.0)

    @staticmethod
    def init_vocab(docs, args):
        """Build the MultiVocab (char/word/upos/xpos/feats) from one or more docs."""
        data = [x for doc in docs for x in Dataset.load_doc(doc)]
        charvocab = CharVocab(data, args['shorthand'])
        wordvocab = WordVocab(data, args['shorthand'], cutoff=args['word_cutoff'], lower=True)
        uposvocab = WordVocab(data, args['shorthand'], idx=1)
        xposvocab = xpos_vocab_factory(data, args['shorthand'])
        try:
            featsvocab = FeatureVocab(data, args['shorthand'], idx=3)
        except ValueError as e:
            raise ValueError("Unable to build features vocab. Please check the Features column of your data for an error which may match the following description.") from e
        vocab = MultiVocab({'char': charvocab,
                            'word': wordvocab,
                            'upos': uposvocab,
                            'xpos': xposvocab,
                            'feats': featsvocab})
        return vocab

    def preprocess(self, data, vocab, pretrain_vocab, args):
        """Map each sentence's (text, upos, xpos, feats) fields to vocab ids.

        Returns a list of DataSample tuples; each field is wrapped in an extra
        list (see TODO in __getitem__ about storing single lists per entry).
        """
        processed = []
        for sent in data:
            processed_sent = DataSample(
                word = [vocab['word'].map([w[0] for w in sent])],
                char = [[vocab['char'].map([x for x in w[0]]) for w in sent]],
                upos = [vocab['upos'].map([w[1] for w in sent])],
                xpos = [vocab['xpos'].map([w[2] for w in sent])],
                feats = [vocab['feats'].map([w[3] for w in sent])],
                pretrain = ([pretrain_vocab.map([w[0].lower() for w in sent])]
                            if pretrain_vocab is not None
                            else [[PAD_ID] * len(sent)]),
                text = [w[0] for w in sent]
            )
            processed.append(processed_sent)

        return processed

    def __len__(self):
        return len(self.data)

    def __mask(self, upos):
        """Returns a torch boolean about which elements should be masked out"""

        # creates all false mask
        mask = torch.zeros_like(upos, dtype=torch.bool)

        ### augmentation 1: punctuation augmentation ###
        # tags that needs to be checked, currently only PUNCT
        # with probability augment_nopunct, drop a sentence-final punctuation
        if random.uniform(0,1) < self.augment_nopunct:
            for i in self.__punct_tags:
                # generate a mask for the last element
                last_element = torch.zeros_like(upos, dtype=torch.bool)
                last_element[..., -1] = True
                # we or the bitmask against the existing mask
                # if it satisfies, we remove the word by masking it
                # to true
                #
                # if your input is just a lone punctuation, we perform
                # no masking
                if not torch.all(upos.eq(torch.tensor([[i]]))):
                    mask |= ((upos == i) & (last_element))

        return mask

    def __getitem__(self, key):
        """Retrieves a sample from the dataset.

        Retrieves a sample from the dataset. This function, for the
        most part, is spent performing ad-hoc data augmentation and
        restoration. It receives a DataSample object from the storage,
        and returns an almost-identical DataSample object that may
        have been augmented with /possibly/ (depending on augment_punct
        settings) PUNCT chopped.

        **Important Note**
        ------------------
        If you would like to load the data into a model, please convert
        this Dataset object into a DataLoader via self.to_loader(). Then,
        you can use the resulting object like any other PyTorch data
        loader. As masks are calculated ad-hoc given the batch, the samples
        returned from this object doesn't have the appropriate masking.

        Motivation
        ----------
        Why is this here? Every time you call next(iter(dataloader)), it calls
        this function. Therefore, if we augmented each sample on each iteration,
        the model will see dynamically generated augmentation.
        Furthermore, PyTorch dataloader handles shuffling natively.

        Parameters
        ----------
        key : int
            the integer ID to from which to retrieve the key.

        Returns
        -------
        DataSample
            The sample of data you requested, with augmentation.
        """
        # get a sample of the input data
        sample = self.data[key]

        # some data augmentation requires constructing a mask based on upos.
        # For instance, sometimes we'd like to mask out ending sentence punctuation.
        # We copy the other items here so that any edits made because
        # of the mask don't clobber the version owned by the Dataset
        # convert to tensors
        # TODO: only store single lists per data entry?
        words = torch.tensor(sample.word[0])
        # convert the rest to tensors
        upos = torch.tensor(sample.upos[0]) if self.has_upos else None
        xpos = torch.tensor(sample.xpos[0]) if self.has_xpos else None
        ufeats = torch.tensor(sample.feats[0]) if self.has_feats else None
        pretrained = torch.tensor(sample.pretrain[0])

        # and deal with char & raw_text
        char = sample.char[0]
        raw_text = sample.text

        # some data augmentation requires constructing a mask based on
        # which upos. For instance, sometimes we'd like to mask out ending
        # sentence punctuation. The mask is True if we want to remove the element
        if self.has_upos and upos is not None and not self.eval:
            # perform actual masking
            mask = self.__mask(upos)
        else:
            # dummy mask that's all false
            mask = None
        if mask is not None:
            mask_index = mask.nonzero()

            # mask out the elements that we need to mask out
            # NOTE(review): tensors are masked with PAD_ID in place, while the
            # char / raw_text lists have the element removed; in practice only
            # the final punctuation token is ever masked (see __mask)
            for mask in mask_index:
                mask = mask.item()
                words[mask] = PAD_ID
                if upos is not None:
                    upos[mask] = PAD_ID
                if xpos is not None:
                    # TODO: test the multi-dimension xpos
                    xpos[mask, ...] = PAD_ID
                if ufeats is not None:
                    ufeats[mask, ...] = PAD_ID
                pretrained[mask] = PAD_ID
                char = char[:mask] + char[mask+1:]
                raw_text = raw_text[:mask] + raw_text[mask+1:]

        # get each character from the input sentence
        # chars = [w for sent in char for w in sent]

        return DataSample(words, char, upos, xpos, ufeats, pretrained, raw_text), key

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def to_loader(self, **kwargs):
        """Converts self to a DataLoader """

        return DL(self,
                  collate_fn=Dataset.__collate_fn,
                  **kwargs)

    def to_length_limited_loader(self, batch_size, maximum_tokens):
        """A DataLoader whose batches are capped by both sentence count and total tokens."""
        sampler = LengthLimitedBatchSampler(self, batch_size, maximum_tokens)
        return DL(self,
                  collate_fn=Dataset.__collate_fn,
                  batch_sampler = sampler)

    @staticmethod
    def __collate_fn(data):
        """Function used by DataLoader to pack data"""
        (data, idx) = zip(*data)
        (words, wordchars, upos, xpos, ufeats, pretrained, text) = zip(*data)

        # collate_fn is given a list of length batch size
        batch_size = len(data)

        # sort sentences by lens for easy RNN operations
        lens = [torch.sum(x != PAD_ID) for x in words]
        (words, wordchars, upos, xpos,
         ufeats, pretrained, text), orig_idx = sort_all((words, wordchars, upos, xpos,
                                                         ufeats, pretrained, text), lens)
        lens = [torch.sum(x != PAD_ID) for x in words] # we need to reinterpret lengths for the RNN

        # combine all words into one large list, and sort for easy charRNN ops
        wordchars = [w for sent in wordchars for w in sent]
        word_lens = [len(x) for x in wordchars]
        (wordchars,), word_orig_idx = sort_all([wordchars], word_lens)
        word_lens = [len(x) for x in wordchars] # we need to reinterpret lengths for the RNN

        # We now pad everything
        # upos/xpos/ufeats may be None per-sample when that column is absent
        words = pad_sequence(words, True, PAD_ID)
        if None not in upos:
            upos = pad_sequence(upos, True, PAD_ID)
        else:
            upos = None
        if None not in xpos:
            xpos = pad_sequence(xpos, True, PAD_ID)
        else:
            xpos = None
        if None not in ufeats:
            ufeats = pad_sequence(ufeats, True, PAD_ID)
        else:
            ufeats = None
        pretrained = pad_sequence(pretrained, True, PAD_ID)
        wordchars = get_long_tensor(wordchars, len(word_lens))

        # and finally create masks for the padding indices
        words_mask = torch.eq(words, PAD_ID)
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        return DataBatch(words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
                         pretrained, orig_idx, word_orig_idx, lens, word_lens, text, idx)

    @staticmethod
    def load_doc(doc):
        """Extract (text, upos, xpos, feats) per sentence, with None replaced by '_'."""
        data = doc.get([TEXT, UPOS, XPOS, FEATS], as_sentences=True)
        data = Dataset.resolve_none(data)
        return data

    @staticmethod
    def resolve_none(data):
        # replace None to '_'
        for sent_idx in range(len(data)):
            for tok_idx in range(len(data[sent_idx])):
                for feat_idx in range(len(data[sent_idx][tok_idx])):
                    if data[sent_idx][tok_idx][feat_idx] is None:
                        data[sent_idx][tok_idx][feat_idx] = '_'
        return data
292
+
293
class LengthLimitedBatchSampler(Sampler):
    """Yields batches of indices capped both by sentence count and total tokens.

    A batch closes once it holds batch_size sentences, or once adding the next
    sentence would push its token total past maximum_tokens.  The intent is to
    avoid GPU OOM when one sentence is significantly longer than expected,
    leaving a batch too large to fit in the GPU.

    A sentence which is by itself longer than maximum_tokens is emitted as a
    batch of its own.
    """
    def __init__(self, data, batch_size, maximum_tokens):
        """
        Precalculate the batches, making it so len and iter just read off the precalculated batches
        """
        self.data = data
        self.batch_size = batch_size
        self.maximum_tokens = maximum_tokens

        self.batches = []
        pending = []
        pending_tokens = 0

        for sample, sample_idx in data:
            num_tokens = len(sample.word)
            # oversized sentences get a dedicated batch
            if maximum_tokens and num_tokens > maximum_tokens:
                if pending:
                    self.batches.append(pending)
                    pending = []
                    pending_tokens = 0
                self.batches.append([sample_idx])
                continue
            # close the current batch if this sentence would not fit
            too_many = len(pending) + 1 > batch_size
            too_long = maximum_tokens and pending_tokens + num_tokens > maximum_tokens
            if too_many or too_long:
                self.batches.append(pending)
                pending = []
                pending_tokens = 0
            pending.append(sample_idx)
            pending_tokens += num_tokens

        if pending:
            self.batches.append(pending)

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        # yield a fresh copy of each precomputed batch
        yield from (list(batch) for batch in self.batches)
342
+
343
+
344
class ShuffledDataset:
    """A wrapper around one or more datasets which shuffles the data in batch_size chunks

    This means that if multiple datasets are passed in, the batches
    from each dataset are shuffled together, with one batch being
    entirely members of the same dataset.

    The main use case of this is that in the tagger, there are cases
    where batches from different datasets will have different
    properties, such as having or not having UPOS tags. We found that
    it is actually somewhat tricky to make the model's loss function
    (in model.py) properly represent batches with mixed w/ and w/o
    property, whereas keeping one entire batch together makes it a lot
    easier to process.

    The mechanism for the shuffling is that the iterator first makes a
    list long enough to represent each batch from each dataset,
    tracking the index of the dataset it is coming from, then shuffles
    that list. Another alternative would be to use a weighted
    randomization approach, but this is very simple and the memory
    requirements are not too onerous.

    Note that the batch indices are wasteful in the case of only one
    underlying dataset, which is actually the most common use case,
    but the overhead is small enough that it probably isn't worth
    special casing the one dataset version.
    """
    def __init__(self, datasets, batch_size):
        self.batch_size = batch_size
        self.datasets = datasets
        # one shuffling DataLoader per dataset; iteration interleaves their batches
        self.loaders = [x.to_loader(batch_size=self.batch_size, shuffle=True) for x in self.datasets]

    def __iter__(self):
        iterators = [iter(x) for x in self.loaders]
        lengths = [len(x) for x in self.loaders]
        indices = [[x] * y for x, y in enumerate(lengths)]
        indices = [idx for inner in indices for idx in inner]
        random.shuffle(indices)

        for idx in indices:
            yield(next(iterators[idx]))

    def __len__(self):
        # Number of *batches* produced by __iter__, not the number of samples.
        # (Previously this summed len(dataset), which disagreed with the number
        # of items actually yielded whenever batch_size > 1.)
        return sum(len(x) for x in self.loaders)
stanza/stanza/models/pos/model.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pack_sequence, pad_sequence, PackedSequence
9
+
10
+ from stanza.models.common.bert_embedding import extract_bert_embeddings
11
+ from stanza.models.common.biaffine import BiaffineScorer
12
+ from stanza.models.common.foundation_cache import load_bert, load_charlm
13
+ from stanza.models.common.hlstm import HighwayLSTM
14
+ from stanza.models.common.dropout import WordDropout
15
+ from stanza.models.common.utils import attach_bert_model
16
+ from stanza.models.common.vocab import CompositeVocab
17
+ from stanza.models.common.char_model import CharacterModel
18
+ from stanza.models.common import utils
19
+
20
+ logger = logging.getLogger('stanza')
21
+
22
+ class Tagger(nn.Module):
23
    def __init__(self, args, vocab, emb_matrix=None, share_hid=False, foundation_cache=None, bert_model=None, bert_tokenizer=None, force_bert_saved=False, peft_name=None):
        """Build the POS tagger network.

        args: dict of hyperparameters and feature switches (embedding dims,
            charlm/bert settings, dropout rates, etc.)
        vocab: MultiVocab providing 'word', 'upos', 'xpos', 'feats' (and 'char')
        emb_matrix: pretrained embedding weights, used when args['pretrain'] is set
        share_hid: if True, the xpos/ufeats classifiers reuse the upos hidden layer
        foundation_cache: cache for charlm models so repeated loads are avoided
        bert_model / bert_tokenizer: externally loaded transformer, attached via
            attach_bert_model
        force_bert_saved: passed to attach_bert_model — presumably forces the
            transformer weights to be saved with the model (TODO confirm)
        peft_name: adapter name used when extracting bert embeddings with PEFT
        """
        super().__init__()

        self.vocab = vocab
        self.args = args
        self.share_hid = share_hid
        # names listed here are excluded from the saved checkpoint (see add_unsaved_module)
        self.unsaved_modules = []

        # input layers
        # input_size accumulates the width of every enabled feature source
        input_size = 0
        if self.args['word_emb_dim'] > 0:
            # frequent word embeddings
            self.word_emb = nn.Embedding(len(vocab['word']), self.args['word_emb_dim'], padding_idx=0)
            input_size += self.args['word_emb_dim']

        if not share_hid:
            # upos embeddings
            self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)

        if self.args['char'] and self.args['char_emb_dim'] > 0:
            # either a pretrained character language model ...
            if self.args.get('charlm', None):
                if args['charlm_forward_file'] is None or not os.path.exists(args['charlm_forward_file']):
                    raise FileNotFoundError('Could not find forward character model: {} Please specify with --charlm_forward_file'.format(args['charlm_forward_file']))
                if args['charlm_backward_file'] is None or not os.path.exists(args['charlm_backward_file']):
                    raise FileNotFoundError('Could not find backward character model: {} Please specify with --charlm_backward_file'.format(args['charlm_backward_file']))
                logger.debug("POS model loading charmodels: %s and %s", args['charlm_forward_file'], args['charlm_backward_file'])
                self.add_unsaved_module('charmodel_forward', load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache))
                self.add_unsaved_module('charmodel_backward', load_charlm(args['charlm_backward_file'], foundation_cache=foundation_cache))
                # optionally add a input transformation layer
                if self.args.get('charlm_transform_dim', 0):
                    self.charmodel_forward_transform = nn.Linear(self.charmodel_forward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)
                    self.charmodel_backward_transform = nn.Linear(self.charmodel_backward.hidden_dim(), self.args['charlm_transform_dim'], bias=False)
                    input_size += self.args['charlm_transform_dim'] * 2
                else:
                    self.charmodel_forward_transform = None
                    self.charmodel_backward_transform = None
                    input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()
            else:
                # ... or a trainable character model built from the char vocab
                bidirectional = args.get('char_bidirectional', False)
                self.charmodel = CharacterModel(args, vocab, bidirectional=bidirectional)
                if bidirectional:
                    self.trans_char = nn.Linear(self.args['char_hidden_dim'] * 2, self.args['transformed_dim'], bias=False)
                else:
                    self.trans_char = nn.Linear(self.args['char_hidden_dim'], self.args['transformed_dim'], bias=False)
                input_size += self.args['transformed_dim']

        self.peft_name = peft_name
        attach_bert_model(self, bert_model, bert_tokenizer, self.args.get('use_peft', False), force_bert_saved)
        if self.args.get('bert_model', None):
            # TODO: refactor bert_hidden_layers between the different models
            if args.get('bert_hidden_layers', False):
                # The average will be offset by 1/N so that the default zeros
                # represents an average of the N layers
                self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
                nn.init.zeros_(self.bert_layer_mix.weight)
            else:
                # an average of layers 2, 3, 4 will be used
                # (for historic reasons)
                self.bert_layer_mix = None
            input_size += self.bert_model.config.hidden_size

        if self.args['pretrain']:
            # pretrained embeddings, by default this won't be saved into model file
            self.add_unsaved_module('pretrained_emb', nn.Embedding.from_pretrained(emb_matrix, freeze=True))
            self.trans_pretrained = nn.Linear(emb_matrix.shape[1], self.args['transformed_dim'], bias=False)
            input_size += self.args['transformed_dim']

        # recurrent layers
        self.taggerlstm = HighwayLSTM(input_size, self.args['hidden_dim'], self.args['num_layers'], batch_first=True, bidirectional=True, dropout=self.args['dropout'], rec_dropout=self.args['rec_dropout'], highway_func=torch.tanh)
        self.drop_replacement = nn.Parameter(torch.randn(input_size) / np.sqrt(input_size))
        # learned initial hidden/cell states for the bidirectional LSTM
        self.taggerlstm_h_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))
        self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']))

        # classifiers
        self.upos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'])
        self.upos_clf = nn.Linear(self.args['deep_biaff_hidden_dim'], len(vocab['upos']))
        self.upos_clf.weight.data.zero_()
        self.upos_clf.bias.data.zero_()

        if share_hid:
            clf_constructor = lambda insize, outsize: nn.Linear(insize, outsize)
        else:
            # without shared hidden layers, xpos/ufeats are scored biaffinely
            # against the predicted upos embedding
            self.xpos_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['deep_biaff_hidden_dim'] if not isinstance(vocab['xpos'], CompositeVocab) else self.args['composite_deep_biaff_hidden_dim'])
            self.ufeats_hid = nn.Linear(self.args['hidden_dim'] * 2, self.args['composite_deep_biaff_hidden_dim'])
            clf_constructor = lambda insize, outsize: BiaffineScorer(insize, self.args['tag_emb_dim'], outsize)

        if isinstance(vocab['xpos'], CompositeVocab):
            # one classifier per xpos component
            self.xpos_clf = nn.ModuleList()
            for l in vocab['xpos'].lens():
                self.xpos_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))
        else:
            self.xpos_clf = clf_constructor(self.args['deep_biaff_hidden_dim'], len(vocab['xpos']))
            if share_hid:
                self.xpos_clf.weight.data.zero_()
                self.xpos_clf.bias.data.zero_()

        # one classifier per ufeats component
        self.ufeats_clf = nn.ModuleList()
        for l in vocab['feats'].lens():
            if share_hid:
                self.ufeats_clf.append(clf_constructor(self.args['deep_biaff_hidden_dim'], l))
                self.ufeats_clf[-1].weight.data.zero_()
                self.ufeats_clf[-1].bias.data.zero_()
            else:
                self.ufeats_clf.append(clf_constructor(self.args['composite_deep_biaff_hidden_dim'], l))

        # criterion
        self.crit = nn.CrossEntropyLoss(ignore_index=0) # ignore padding

        self.drop = nn.Dropout(args['dropout'])
        self.worddrop = WordDropout(args['word_dropout'])
134
+ def add_unsaved_module(self, name, module):
135
+ self.unsaved_modules += [name]
136
+ setattr(self, name, module)
137
+
138
    def log_norms(self):
        # delegate to the shared utility which logs parameter norms,
        # useful when debugging training instabilities
        utils.log_norms(self)
140
+
141
+ def forward(self, word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text):
142
+
143
+ def pack(x):
144
+ return pack_padded_sequence(x, sentlens, batch_first=True)
145
+
146
+ inputs = []
147
+ if self.args['word_emb_dim'] > 0:
148
+ word_emb = self.word_emb(word)
149
+ word_emb = pack(word_emb)
150
+ inputs += [word_emb]
151
+
152
+ if self.args['pretrain']:
153
+ pretrained_emb = self.pretrained_emb(pretrained)
154
+ pretrained_emb = self.trans_pretrained(pretrained_emb)
155
+ pretrained_emb = pack(pretrained_emb)
156
+ inputs += [pretrained_emb]
157
+
158
+ def pad(x):
159
+ return pad_packed_sequence(PackedSequence(x, inputs[0].batch_sizes), batch_first=True)[0]
160
+
161
+ if self.args['char'] and self.args['char_emb_dim'] > 0:
162
+ if self.args.get('charlm', None):
163
+ all_forward_chars = self.charmodel_forward.build_char_representation(text)
164
+ assert isinstance(all_forward_chars, list)
165
+ if self.charmodel_forward_transform is not None:
166
+ all_forward_chars = [self.charmodel_forward_transform(x) for x in all_forward_chars]
167
+ all_forward_chars = pack(pad_sequence(all_forward_chars, batch_first=True))
168
+
169
+ all_backward_chars = self.charmodel_backward.build_char_representation(text)
170
+ if self.charmodel_backward_transform is not None:
171
+ all_backward_chars = [self.charmodel_backward_transform(x) for x in all_backward_chars]
172
+ all_backward_chars = pack(pad_sequence(all_backward_chars, batch_first=True))
173
+
174
+ inputs += [all_forward_chars, all_backward_chars]
175
+ else:
176
+ char_reps = self.charmodel(wordchars, wordchars_mask, word_orig_idx, sentlens, wordlens)
177
+ char_reps = PackedSequence(self.trans_char(self.drop(char_reps.data)), char_reps.batch_sizes)
178
+ inputs += [char_reps]
179
+
180
+ if self.bert_model is not None:
181
+ device = next(self.parameters()).device
182
+ processed_bert = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, text, device, keep_endpoints=False,
183
+ num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None,
184
+ detach=not self.args.get('bert_finetune', False) or not self.training,
185
+ peft_name=self.peft_name)
186
+
187
+ if self.bert_layer_mix is not None:
188
+ # add the average so that the default behavior is to
189
+ # take an average of the N layers, and anything else
190
+ # other than that needs to be learned
191
+ # TODO: refactor this
192
+ processed_bert = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in processed_bert]
193
+
194
+ processed_bert = pad_sequence(processed_bert, batch_first=True)
195
+ inputs += [pack(processed_bert)]
196
+
197
+ lstm_inputs = torch.cat([x.data for x in inputs], 1)
198
+ lstm_inputs = self.worddrop(lstm_inputs, self.drop_replacement)
199
+ lstm_inputs = self.drop(lstm_inputs)
200
+ lstm_inputs = PackedSequence(lstm_inputs, inputs[0].batch_sizes)
201
+
202
+ lstm_outputs, _ = self.taggerlstm(lstm_inputs, sentlens, hx=(self.taggerlstm_h_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous(), self.taggerlstm_c_init.expand(2 * self.args['num_layers'], word.size(0), self.args['hidden_dim']).contiguous()))
203
+ lstm_outputs = lstm_outputs.data
204
+
205
+ upos_hid = F.relu(self.upos_hid(self.drop(lstm_outputs)))
206
+ upos_pred = self.upos_clf(self.drop(upos_hid))
207
+
208
+ preds = [pad(upos_pred).max(2)[1]]
209
+
210
+ if upos is not None:
211
+ upos = pack(upos).data
212
+ loss = self.crit(upos_pred.view(-1, upos_pred.size(-1)), upos.view(-1))
213
+ else:
214
+ loss = 0.0
215
+
216
+ if self.share_hid:
217
+ xpos_hid = upos_hid
218
+ ufeats_hid = upos_hid
219
+
220
+ clffunc = lambda clf, hid: clf(self.drop(hid))
221
+ else:
222
+ xpos_hid = F.relu(self.xpos_hid(self.drop(lstm_outputs)))
223
+ ufeats_hid = F.relu(self.ufeats_hid(self.drop(lstm_outputs)))
224
+
225
+ if self.training and upos is not None:
226
+ upos_emb = self.upos_emb(upos)
227
+ else:
228
+ upos_emb = self.upos_emb(upos_pred.max(1)[1])
229
+
230
+ clffunc = lambda clf, hid: clf(self.drop(hid), self.drop(upos_emb))
231
+
232
+ if xpos is not None: xpos = pack(xpos).data
233
+ if isinstance(self.vocab['xpos'], CompositeVocab):
234
+ xpos_preds = []
235
+ for i in range(len(self.vocab['xpos'])):
236
+ xpos_pred = clffunc(self.xpos_clf[i], xpos_hid)
237
+ if xpos is not None:
238
+ loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos[:, i].view(-1))
239
+ xpos_preds.append(pad(xpos_pred).max(2, keepdim=True)[1])
240
+ preds.append(torch.cat(xpos_preds, 2))
241
+ else:
242
+ xpos_pred = clffunc(self.xpos_clf, xpos_hid)
243
+ if xpos is not None:
244
+ loss += self.crit(xpos_pred.view(-1, xpos_pred.size(-1)), xpos.view(-1))
245
+ preds.append(pad(xpos_pred).max(2)[1])
246
+
247
+ ufeats_preds = []
248
+ if ufeats is not None: ufeats = pack(ufeats).data
249
+ for i in range(len(self.vocab['feats'])):
250
+ ufeats_pred = clffunc(self.ufeats_clf[i], ufeats_hid)
251
+ if ufeats is not None:
252
+ loss += self.crit(ufeats_pred.view(-1, ufeats_pred.size(-1)), ufeats[:, i].view(-1))
253
+ ufeats_preds.append(pad(ufeats_pred).max(2, keepdim=True)[1])
254
+ preds.append(torch.cat(ufeats_preds, 2))
255
+
256
+ return loss, preds
stanza/stanza/models/pos/trainer.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A trainer class to handle training and testing of models.
3
+ """
4
+
5
+ import sys
6
+ import logging
7
+ import torch
8
+ from torch import nn
9
+
10
+ from stanza.models.common.trainer import Trainer as BaseTrainer
11
+ from stanza.models.common import utils, loss
12
+ from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
13
+ from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
14
+ from stanza.models.pos.model import Tagger
15
+ from stanza.models.pos.vocab import MultiVocab
16
+
17
+ logger = logging.getLogger('stanza')
18
+
19
+ def unpack_batch(batch, device):
20
+ """ Unpack a batch from the data loader. """
21
+ inputs = [b.to(device) if b is not None else None for b in batch[:8]]
22
+ orig_idx = batch[8]
23
+ word_orig_idx = batch[9]
24
+ sentlens = batch[10]
25
+ wordlens = batch[11]
26
+ text = batch[12]
27
+ return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text
28
+
29
+ class Trainer(BaseTrainer):
30
+ """ A trainer for training models. """
31
+ def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, foundation_cache=None):
32
+ if model_file is not None:
33
+ # load everything from file
34
+ self.load(model_file, pretrain, args=args, foundation_cache=foundation_cache)
35
+ else:
36
+ # build model from scratch
37
+ self.args = args
38
+ self.vocab = vocab
39
+
40
+ bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
41
+ peft_name = None
42
+ if self.args['use_peft']:
43
+ # fine tune the bert if we're using peft
44
+ self.args['bert_finetune'] = True
45
+ peft_name = "pos"
46
+ bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)
47
+
48
+ self.model = Tagger(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, share_hid=args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)
49
+
50
+ self.model = self.model.to(device)
51
+ self.optimizers = utils.get_split_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, weight_decay=self.args.get('initial_weight_decay', None), bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get("peft", False))
52
+
53
+ self.schedulers = {}
54
+
55
+ if self.args.get('bert_finetune', None):
56
+ import transformers
57
+ warmup_scheduler = transformers.get_linear_schedule_with_warmup(
58
+ self.optimizers["bert_optimizer"],
59
+ # todo late starting?
60
+ 0, self.args["max_steps"])
61
+ self.schedulers["bert_scheduler"] = warmup_scheduler
62
+
63
+ def update(self, batch, eval=False):
64
+ device = next(self.model.parameters()).device
65
+ inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
66
+ word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs
67
+
68
+ if eval:
69
+ self.model.eval()
70
+ else:
71
+ self.model.train()
72
+ for optimizer in self.optimizers.values():
73
+ optimizer.zero_grad()
74
+ loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)
75
+ if loss == 0.0:
76
+ return loss
77
+
78
+ loss_val = loss.data.item()
79
+ if eval:
80
+ return loss_val
81
+
82
+ loss.backward()
83
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
84
+
85
+ for optimizer in self.optimizers.values():
86
+ optimizer.step()
87
+ for scheduler in self.schedulers.values():
88
+ scheduler.step()
89
+ return loss_val
90
+
91
+ def predict(self, batch, unsort=True):
92
+ device = next(self.model.parameters()).device
93
+ inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
94
+ word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained = inputs
95
+
96
+ self.model.eval()
97
+ batch_size = word.size(0)
98
+ _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, word_orig_idx, sentlens, wordlens, text)
99
+ upos_seqs = [self.vocab['upos'].unmap(sent) for sent in preds[0].tolist()]
100
+ xpos_seqs = [self.vocab['xpos'].unmap(sent) for sent in preds[1].tolist()]
101
+ feats_seqs = [self.vocab['feats'].unmap(sent) for sent in preds[2].tolist()]
102
+
103
+ pred_tokens = [[[upos_seqs[i][j], xpos_seqs[i][j], feats_seqs[i][j]] for j in range(sentlens[i])] for i in range(batch_size)]
104
+ if unsort:
105
+ pred_tokens = utils.unsort(pred_tokens, orig_idx)
106
+ return pred_tokens
107
+
108
+ def save(self, filename, skip_modules=True):
109
+ model_state = self.model.state_dict()
110
+ # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
111
+ if skip_modules:
112
+ skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
113
+ for k in skipped:
114
+ del model_state[k]
115
+ params = {
116
+ 'model': model_state,
117
+ 'vocab': self.vocab.state_dict(),
118
+ 'config': self.args
119
+ }
120
+ if self.args.get('use_peft', False):
121
+ # Hide import so that peft dependency is optional
122
+ from peft import get_peft_model_state_dict
123
+ params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)
124
+
125
+ try:
126
+ torch.save(params, filename, _use_new_zipfile_serialization=False)
127
+ logger.info("Model saved to {}".format(filename))
128
+ except (KeyboardInterrupt, SystemExit):
129
+ raise
130
+ except Exception as e:
131
+ logger.warning(f"Saving failed... {e} continuing anyway.")
132
+
133
+ def load(self, filename, pretrain, args=None, foundation_cache=None):
134
+ """
135
+ Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input,
136
+ and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
137
+ """
138
+ try:
139
+ checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
140
+ except BaseException:
141
+ logger.error("Cannot load model from {}".format(filename))
142
+ raise
143
+ self.args = checkpoint['config']
144
+ if args is not None: self.args.update(args)
145
+
146
+ # preserve old models which were created before transformers were added
147
+ if 'bert_model' not in self.args:
148
+ self.args['bert_model'] = None
149
+
150
+ lora_weights = checkpoint.get('bert_lora')
151
+ if lora_weights:
152
+ logger.debug("Found peft weights for POS; loading a peft adapter")
153
+ self.args["use_peft"] = True
154
+
155
+ # TODO: refactor this common block of code with NER
156
+ force_bert_saved = False
157
+ peft_name = None
158
+ if self.args.get('use_peft', False):
159
+ force_bert_saved = True
160
+ bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "pos", foundation_cache)
161
+ bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)
162
+ logger.debug("Loaded peft with name %s", peft_name)
163
+ else:
164
+ if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
165
+ logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
166
+ foundation_cache = NoTransformerFoundationCache(foundation_cache)
167
+ force_bert_saved = True
168
+ bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)
169
+
170
+ self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
171
+ # load model
172
+ emb_matrix = None
173
+ if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None
174
+ emb_matrix = pretrain.emb
175
+ if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
176
+ logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
177
+ foundation_cache = NoTransformerFoundationCache(foundation_cache)
178
+ self.model = Tagger(self.args, self.vocab, emb_matrix=emb_matrix, share_hid=self.args['share_hid'], foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)
179
+ self.model.load_state_dict(checkpoint['model'], strict=False)
stanza/stanza/models/pos/xpos_vocab_factory.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is the XPOS factory method generated automatically from stanza.models.pos.build_xpos_vocab_factory.
2
+ # Please don't edit it!
3
+
4
+ import logging
5
+
6
+ from stanza.models.pos.vocab import WordVocab, XPOSVocab
7
+ from stanza.models.pos.xpos_vocab_utils import XPOSDescription, XPOSType, build_xpos_vocab, choose_simplest_factory
8
+
9
+ # using a sublogger makes it easier to test in the unittests
10
+ logger = logging.getLogger('stanza.models.pos.xpos_vocab_factory')
11
+
12
+ XPOS_DESCRIPTIONS = {
13
+ 'af_afribooms' : XPOSDescription(XPOSType.XPOS, ''),
14
+ 'ar_padt' : XPOSDescription(XPOSType.XPOS, ''),
15
+ 'bg_btb' : XPOSDescription(XPOSType.XPOS, ''),
16
+ 'ca_ancora' : XPOSDescription(XPOSType.XPOS, ''),
17
+ 'cs_cac' : XPOSDescription(XPOSType.XPOS, ''),
18
+ 'cs_cltt' : XPOSDescription(XPOSType.XPOS, ''),
19
+ 'cs_fictree' : XPOSDescription(XPOSType.XPOS, ''),
20
+ 'cs_pdt' : XPOSDescription(XPOSType.XPOS, ''),
21
+ 'en_partut' : XPOSDescription(XPOSType.XPOS, ''),
22
+ 'es_ancora' : XPOSDescription(XPOSType.XPOS, ''),
23
+ 'es_combined' : XPOSDescription(XPOSType.XPOS, ''),
24
+ 'fr_partut' : XPOSDescription(XPOSType.XPOS, ''),
25
+ 'gd_arcosg' : XPOSDescription(XPOSType.XPOS, ''),
26
+ 'gl_ctg' : XPOSDescription(XPOSType.XPOS, ''),
27
+ 'gl_treegal' : XPOSDescription(XPOSType.XPOS, ''),
28
+ 'grc_perseus' : XPOSDescription(XPOSType.XPOS, ''),
29
+ 'hr_set' : XPOSDescription(XPOSType.XPOS, ''),
30
+ 'is_gc' : XPOSDescription(XPOSType.XPOS, ''),
31
+ 'is_icepahc' : XPOSDescription(XPOSType.XPOS, ''),
32
+ 'is_modern' : XPOSDescription(XPOSType.XPOS, ''),
33
+ 'it_combined' : XPOSDescription(XPOSType.XPOS, ''),
34
+ 'it_isdt' : XPOSDescription(XPOSType.XPOS, ''),
35
+ 'it_markit' : XPOSDescription(XPOSType.XPOS, ''),
36
+ 'it_parlamint' : XPOSDescription(XPOSType.XPOS, ''),
37
+ 'it_partut' : XPOSDescription(XPOSType.XPOS, ''),
38
+ 'it_postwita' : XPOSDescription(XPOSType.XPOS, ''),
39
+ 'it_twittiro' : XPOSDescription(XPOSType.XPOS, ''),
40
+ 'it_vit' : XPOSDescription(XPOSType.XPOS, ''),
41
+ 'la_perseus' : XPOSDescription(XPOSType.XPOS, ''),
42
+ 'la_udante' : XPOSDescription(XPOSType.XPOS, ''),
43
+ 'lt_alksnis' : XPOSDescription(XPOSType.XPOS, ''),
44
+ 'lv_lvtb' : XPOSDescription(XPOSType.XPOS, ''),
45
+ 'ro_nonstandard' : XPOSDescription(XPOSType.XPOS, ''),
46
+ 'ro_rrt' : XPOSDescription(XPOSType.XPOS, ''),
47
+ 'ro_simonero' : XPOSDescription(XPOSType.XPOS, ''),
48
+ 'sk_snk' : XPOSDescription(XPOSType.XPOS, ''),
49
+ 'sl_ssj' : XPOSDescription(XPOSType.XPOS, ''),
50
+ 'sl_sst' : XPOSDescription(XPOSType.XPOS, ''),
51
+ 'sr_set' : XPOSDescription(XPOSType.XPOS, ''),
52
+ 'ta_ttb' : XPOSDescription(XPOSType.XPOS, ''),
53
+ 'uk_iu' : XPOSDescription(XPOSType.XPOS, ''),
54
+
55
+ 'be_hse' : XPOSDescription(XPOSType.WORD, None),
56
+ 'bxr_bdt' : XPOSDescription(XPOSType.WORD, None),
57
+ 'cop_scriptorium': XPOSDescription(XPOSType.WORD, None),
58
+ 'cu_proiel' : XPOSDescription(XPOSType.WORD, None),
59
+ 'cy_ccg' : XPOSDescription(XPOSType.WORD, None),
60
+ 'da_ddt' : XPOSDescription(XPOSType.WORD, None),
61
+ 'de_gsd' : XPOSDescription(XPOSType.WORD, None),
62
+ 'de_hdt' : XPOSDescription(XPOSType.WORD, None),
63
+ 'el_gdt' : XPOSDescription(XPOSType.WORD, None),
64
+ 'el_gud' : XPOSDescription(XPOSType.WORD, None),
65
+ 'en_atis' : XPOSDescription(XPOSType.WORD, None),
66
+ 'en_combined' : XPOSDescription(XPOSType.WORD, None),
67
+ 'en_craft' : XPOSDescription(XPOSType.WORD, None),
68
+ 'en_eslspok' : XPOSDescription(XPOSType.WORD, None),
69
+ 'en_ewt' : XPOSDescription(XPOSType.WORD, None),
70
+ 'en_genia' : XPOSDescription(XPOSType.WORD, None),
71
+ 'en_gum' : XPOSDescription(XPOSType.WORD, None),
72
+ 'en_gumreddit' : XPOSDescription(XPOSType.WORD, None),
73
+ 'en_mimic' : XPOSDescription(XPOSType.WORD, None),
74
+ 'en_test' : XPOSDescription(XPOSType.WORD, None),
75
+ 'es_gsd' : XPOSDescription(XPOSType.WORD, None),
76
+ 'et_edt' : XPOSDescription(XPOSType.WORD, None),
77
+ 'et_ewt' : XPOSDescription(XPOSType.WORD, None),
78
+ 'eu_bdt' : XPOSDescription(XPOSType.WORD, None),
79
+ 'fa_perdt' : XPOSDescription(XPOSType.WORD, None),
80
+ 'fa_seraji' : XPOSDescription(XPOSType.WORD, None),
81
+ 'fi_tdt' : XPOSDescription(XPOSType.WORD, None),
82
+ 'fr_combined' : XPOSDescription(XPOSType.WORD, None),
83
+ 'fr_gsd' : XPOSDescription(XPOSType.WORD, None),
84
+ 'fr_parisstories': XPOSDescription(XPOSType.WORD, None),
85
+ 'fr_rhapsodie' : XPOSDescription(XPOSType.WORD, None),
86
+ 'fr_sequoia' : XPOSDescription(XPOSType.WORD, None),
87
+ 'fro_profiterole': XPOSDescription(XPOSType.WORD, None),
88
+ 'ga_idt' : XPOSDescription(XPOSType.WORD, None),
89
+ 'ga_twittirish' : XPOSDescription(XPOSType.WORD, None),
90
+ 'got_proiel' : XPOSDescription(XPOSType.WORD, None),
91
+ 'grc_proiel' : XPOSDescription(XPOSType.WORD, None),
92
+ 'grc_ptnk' : XPOSDescription(XPOSType.WORD, None),
93
+ 'gv_cadhan' : XPOSDescription(XPOSType.WORD, None),
94
+ 'hbo_ptnk' : XPOSDescription(XPOSType.WORD, None),
95
+ 'he_combined' : XPOSDescription(XPOSType.WORD, None),
96
+ 'he_htb' : XPOSDescription(XPOSType.WORD, None),
97
+ 'he_iahltknesset': XPOSDescription(XPOSType.WORD, None),
98
+ 'he_iahltwiki' : XPOSDescription(XPOSType.WORD, None),
99
+ 'hi_hdtb' : XPOSDescription(XPOSType.WORD, None),
100
+ 'hsb_ufal' : XPOSDescription(XPOSType.WORD, None),
101
+ 'hu_szeged' : XPOSDescription(XPOSType.WORD, None),
102
+ 'hy_armtdp' : XPOSDescription(XPOSType.WORD, None),
103
+ 'hy_bsut' : XPOSDescription(XPOSType.WORD, None),
104
+ 'hyw_armtdp' : XPOSDescription(XPOSType.WORD, None),
105
+ 'id_csui' : XPOSDescription(XPOSType.WORD, None),
106
+ 'it_old' : XPOSDescription(XPOSType.WORD, None),
107
+ 'ka_glc' : XPOSDescription(XPOSType.WORD, None),
108
+ 'kk_ktb' : XPOSDescription(XPOSType.WORD, None),
109
+ 'kmr_mg' : XPOSDescription(XPOSType.WORD, None),
110
+ 'kpv_lattice' : XPOSDescription(XPOSType.WORD, None),
111
+ 'ky_ktmu' : XPOSDescription(XPOSType.WORD, None),
112
+ 'la_proiel' : XPOSDescription(XPOSType.WORD, None),
113
+ 'lij_glt' : XPOSDescription(XPOSType.WORD, None),
114
+ 'lt_hse' : XPOSDescription(XPOSType.WORD, None),
115
+ 'lzh_kyoto' : XPOSDescription(XPOSType.WORD, None),
116
+ 'mr_ufal' : XPOSDescription(XPOSType.WORD, None),
117
+ 'mt_mudt' : XPOSDescription(XPOSType.WORD, None),
118
+ 'myv_jr' : XPOSDescription(XPOSType.WORD, None),
119
+ 'nb_bokmaal' : XPOSDescription(XPOSType.WORD, None),
120
+ 'nds_lsdc' : XPOSDescription(XPOSType.WORD, None),
121
+ 'nn_nynorsk' : XPOSDescription(XPOSType.WORD, None),
122
+ 'nn_nynorsklia' : XPOSDescription(XPOSType.WORD, None),
123
+ 'no_bokmaal' : XPOSDescription(XPOSType.WORD, None),
124
+ 'orv_birchbark' : XPOSDescription(XPOSType.WORD, None),
125
+ 'orv_rnc' : XPOSDescription(XPOSType.WORD, None),
126
+ 'orv_torot' : XPOSDescription(XPOSType.WORD, None),
127
+ 'ota_boun' : XPOSDescription(XPOSType.WORD, None),
128
+ 'pcm_nsc' : XPOSDescription(XPOSType.WORD, None),
129
+ 'pt_bosque' : XPOSDescription(XPOSType.WORD, None),
130
+ 'pt_cintil' : XPOSDescription(XPOSType.WORD, None),
131
+ 'pt_dantestocks' : XPOSDescription(XPOSType.WORD, None),
132
+ 'pt_gsd' : XPOSDescription(XPOSType.WORD, None),
133
+ 'pt_petrogold' : XPOSDescription(XPOSType.WORD, None),
134
+ 'pt_porttinari' : XPOSDescription(XPOSType.WORD, None),
135
+ 'qpm_philotis' : XPOSDescription(XPOSType.WORD, None),
136
+ 'qtd_sagt' : XPOSDescription(XPOSType.WORD, None),
137
+ 'ru_gsd' : XPOSDescription(XPOSType.WORD, None),
138
+ 'ru_poetry' : XPOSDescription(XPOSType.WORD, None),
139
+ 'ru_syntagrus' : XPOSDescription(XPOSType.WORD, None),
140
+ 'ru_taiga' : XPOSDescription(XPOSType.WORD, None),
141
+ 'sa_vedic' : XPOSDescription(XPOSType.WORD, None),
142
+ 'sme_giella' : XPOSDescription(XPOSType.WORD, None),
143
+ 'swl_sslc' : XPOSDescription(XPOSType.WORD, None),
144
+ 'sq_staf' : XPOSDescription(XPOSType.WORD, None),
145
+ 'te_mtg' : XPOSDescription(XPOSType.WORD, None),
146
+ 'tr_atis' : XPOSDescription(XPOSType.WORD, None),
147
+ 'tr_boun' : XPOSDescription(XPOSType.WORD, None),
148
+ 'tr_framenet' : XPOSDescription(XPOSType.WORD, None),
149
+ 'tr_imst' : XPOSDescription(XPOSType.WORD, None),
150
+ 'tr_kenet' : XPOSDescription(XPOSType.WORD, None),
151
+ 'tr_penn' : XPOSDescription(XPOSType.WORD, None),
152
+ 'tr_tourism' : XPOSDescription(XPOSType.WORD, None),
153
+ 'ug_udt' : XPOSDescription(XPOSType.WORD, None),
154
+ 'uk_parlamint' : XPOSDescription(XPOSType.WORD, None),
155
+ 'vi_vtb' : XPOSDescription(XPOSType.WORD, None),
156
+ 'wo_wtb' : XPOSDescription(XPOSType.WORD, None),
157
+ 'xcl_caval' : XPOSDescription(XPOSType.WORD, None),
158
+ 'zh-hans_gsdsimp': XPOSDescription(XPOSType.WORD, None),
159
+ 'zh-hant_gsd' : XPOSDescription(XPOSType.WORD, None),
160
+ 'zh_gsdsimp' : XPOSDescription(XPOSType.WORD, None),
161
+
162
+ 'en_lines' : XPOSDescription(XPOSType.XPOS, '-'),
163
+ 'fo_farpahc' : XPOSDescription(XPOSType.XPOS, '-'),
164
+ 'ja_gsd' : XPOSDescription(XPOSType.XPOS, '-'),
165
+ 'ja_gsdluw' : XPOSDescription(XPOSType.XPOS, '-'),
166
+ 'sv_lines' : XPOSDescription(XPOSType.XPOS, '-'),
167
+ 'ur_udtb' : XPOSDescription(XPOSType.XPOS, '-'),
168
+
169
+ 'fi_ftb' : XPOSDescription(XPOSType.XPOS, ','),
170
+ 'orv_ruthenian' : XPOSDescription(XPOSType.XPOS, ','),
171
+
172
+ 'id_gsd' : XPOSDescription(XPOSType.XPOS, '+'),
173
+ 'ko_gsd' : XPOSDescription(XPOSType.XPOS, '+'),
174
+ 'ko_kaist' : XPOSDescription(XPOSType.XPOS, '+'),
175
+ 'ko_ksl' : XPOSDescription(XPOSType.XPOS, '+'),
176
+ 'qaf_arabizi' : XPOSDescription(XPOSType.XPOS, '+'),
177
+
178
+ 'la_ittb' : XPOSDescription(XPOSType.XPOS, '|'),
179
+ 'la_llct' : XPOSDescription(XPOSType.XPOS, '|'),
180
+ 'nl_alpino' : XPOSDescription(XPOSType.XPOS, '|'),
181
+ 'nl_lassysmall' : XPOSDescription(XPOSType.XPOS, '|'),
182
+ 'sv_talbanken' : XPOSDescription(XPOSType.XPOS, '|'),
183
+
184
+ 'pl_lfg' : XPOSDescription(XPOSType.XPOS, ':'),
185
+ 'pl_pdb' : XPOSDescription(XPOSType.XPOS, ':'),
186
+ }
187
+
188
+ def xpos_vocab_factory(data, shorthand):
189
+ if shorthand not in XPOS_DESCRIPTIONS:
190
+ logger.warning("%s is not a known dataset. Examining the data to choose which xpos vocab to use", shorthand)
191
+ desc = choose_simplest_factory(data, shorthand)
192
+ if shorthand in XPOS_DESCRIPTIONS:
193
+ if XPOS_DESCRIPTIONS[shorthand] != desc:
194
+ # log instead of throw
195
+ # otherwise, updating datasets would be unpleasant
196
+ logger.error("XPOS tagset in %s has apparently changed! Was %s, is now %s", shorthand, XPOS_DESCRIPTIONS[shorthand], desc)
197
+ else:
198
+ logger.warning("Chose %s for the xpos factory for %s", desc, shorthand)
199
+ return build_xpos_vocab(desc, data, shorthand)
200
+
stanza/stanza/models/pos/xpos_vocab_utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from enum import Enum
3
+ import logging
4
+ import os
5
+
6
+ from stanza.models.common.vocab import VOCAB_PREFIX
7
+ from stanza.models.pos.vocab import XPOSVocab, WordVocab
8
+
9
+ class XPOSType(Enum):
10
+ XPOS = 1
11
+ WORD = 2
12
+
13
+ XPOSDescription = namedtuple('XPOSDescription', ['xpos_type', 'sep'])
14
+ DEFAULT_KEY = XPOSDescription(XPOSType.WORD, None)
15
+
16
+ logger = logging.getLogger('stanza')
17
+
18
+ def filter_data(data, idx):
19
+ data_filtered = []
20
+ for sentence in data:
21
+ flag = True
22
+ for token in sentence:
23
+ if token[idx] is None:
24
+ flag = False
25
+ if flag: data_filtered.append(sentence)
26
+ return data_filtered
27
+
28
+ def choose_simplest_factory(data, shorthand):
29
+ logger.info(f'Original length = {len(data)}')
30
+ data = filter_data(data, idx=2)
31
+ logger.info(f'Filtered length = {len(data)}')
32
+ vocab = WordVocab(data, shorthand, idx=2, ignore=["_"])
33
+ key = DEFAULT_KEY
34
+ best_size = len(vocab) - len(VOCAB_PREFIX)
35
+ if best_size > 20:
36
+ for sep in ['', '-', '+', '|', ',', ':']: # separators
37
+ vocab = XPOSVocab(data, shorthand, idx=2, sep=sep)
38
+ length = sum(len(x) - len(VOCAB_PREFIX) for x in vocab._id2unit.values())
39
+ if length < best_size:
40
+ key = XPOSDescription(XPOSType.XPOS, sep)
41
+ best_size = length
42
+ return key
43
+
44
+ def build_xpos_vocab(description, data, shorthand):
45
+ if description.xpos_type is XPOSType.WORD:
46
+ return WordVocab(data, shorthand, idx=2, ignore=["_"])
47
+
48
+ return XPOSVocab(data, shorthand, idx=2, sep=description.sep)
stanza/stanza/models/tokenization/__init__.py ADDED
File without changes
stanza/stanza/models/tokenization/data.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bisect import bisect_right
2
+ from copy import copy
3
+ import numpy as np
4
+ import random
5
+ import logging
6
+ import re
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from .vocab import Vocab
10
+
11
+ from stanza.models.common.utils import sort_with_indices, unsort
12
+
13
+ logger = logging.getLogger('stanza')
14
+
15
+ def filter_consecutive_whitespaces(para):
16
+ filtered = []
17
+ for i, (char, label) in enumerate(para):
18
+ if i > 0:
19
+ if char == ' ' and para[i-1][0] == ' ':
20
+ continue
21
+
22
+ filtered.append((char, label))
23
+
24
+ return filtered
25
+
26
+ NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
27
+ # this was (r'^([\d]+[,\.]*)+$')
28
+ # but the runtime on that can explode exponentially
29
+ # for example, on 111111111111111111111111a
30
+ NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
31
+ WHITESPACE_RE = re.compile(r'\s')
32
+
33
+ class TokenizationDataset:
34
+ def __init__(self, tokenizer_args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):
35
+ super().__init__(*args, **kwargs) # forwards all unused arguments
36
+ self.args = tokenizer_args
37
+ self.eval = evaluation
38
+ self.dictionary = dictionary
39
+ self.vocab = vocab
40
+
41
+ # get input files
42
+ txt_file = input_files['txt']
43
+ label_file = input_files['label']
44
+
45
+ # Load data and process it
46
+ # set up text from file or input string
47
+ assert txt_file is not None or input_text is not None
48
+ if input_text is None:
49
+ with open(txt_file) as f:
50
+ text = ''.join(f.readlines()).rstrip()
51
+ else:
52
+ text = input_text
53
+
54
+ text_chunks = NEWLINE_WHITESPACE_RE.split(text)
55
+ text_chunks = [pt.rstrip() for pt in text_chunks]
56
+ text_chunks = [pt for pt in text_chunks if pt]
57
+ if label_file is not None:
58
+ with open(label_file) as f:
59
+ labels = ''.join(f.readlines()).rstrip()
60
+ labels = NEWLINE_WHITESPACE_RE.split(labels)
61
+ labels = [pt.rstrip() for pt in labels]
62
+ labels = [map(int, pt) for pt in labels if pt]
63
+ else:
64
+ labels = [[0 for _ in pt] for pt in text_chunks]
65
+
66
+ skip_newline = self.args.get('skip_newline', False)
67
+ self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces
68
+ for char, label in zip(pt, pc) if not (skip_newline and char == '\n')] # check if newline needs to be eaten
69
+ for pt, pc in zip(text_chunks, labels)]
70
+
71
+ # remove consecutive whitespaces
72
+ self.data = [filter_consecutive_whitespaces(x) for x in self.data]
73
+
74
+ def labels(self):
75
+ """
76
+ Returns a list of the labels for all of the sentences in this DataLoader
77
+
78
+ Used at eval time to compare to the results, for example
79
+ """
80
+ return [np.array(list(x[1] for x in sent)) for sent in self.data]
81
+
82
+ def extract_dict_feat(self, para, idx):
83
+ """
84
+ This function is to extract dictionary features for each character
85
+ """
86
+ length = len(para)
87
+
88
+ dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]
89
+ dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]
90
+ forward_word = para[idx][0]
91
+ backward_word = para[idx][0]
92
+ prefix = True
93
+ suffix = True
94
+ for window in range(1,self.args['num_dict_feat']+1):
95
+ # concatenate each character and check if words found in dict not, stop if prefix not found
96
+ #check if idx+t is out of bound and if the prefix is already not found
97
+ if (idx + window) <= length-1 and prefix:
98
+ forward_word += para[idx+window][0].lower()
99
+ #check in json file if the word is present as prefix or word or None.
100
+ feat = 1 if forward_word in self.dictionary["words"] else 0
101
+ #if the return value is not 2 or 3 then the checking word is not a valid word in dict.
102
+ dict_forward_feats[window-1] = feat
103
+ #if the dict return 0 means no prefixes found, thus, stop looking for forward.
104
+ if forward_word not in self.dictionary["prefixes"]:
105
+ prefix = False
106
+ #backward check: similar to forward
107
+ if (idx - window) >= 0 and suffix:
108
+ backward_word = para[idx-window][0].lower() + backward_word
109
+ feat = 1 if backward_word in self.dictionary["words"] else 0
110
+ dict_backward_feats[window-1] = feat
111
+ if backward_word not in self.dictionary["suffixes"]:
112
+ suffix = False
113
+ #if cannot find both prefix and suffix, then exit the loop
114
+ if not prefix and not suffix:
115
+ break
116
+
117
+ return dict_forward_feats + dict_backward_feats
118
+
119
    def para_to_sentences(self, para):
        """Convert a paragraph to a list of processed sentences.

        `para` is a sequence of (unit, label) pairs.  Returns a list of
        tuples, one per sentence:
            (np.array of unit ids, np.array of labels,
             np.array of feature vectors, list of raw unit strings)

        At eval time (self.eval True) sentence-end labels are ignored and
        the whole paragraph is flushed as a single "sentence" at the end.
        """
        res = []
        funcs = []
        # build one featurizer per requested per-unit feature function;
        # position-dependent features are handled separately below
        for feat_func in self.args['feat_funcs']:
            if feat_func == 'end_of_para' or feat_func == 'start_of_para':
                # skip for position-dependent features
                continue
            if feat_func == 'space_before':
                func = lambda x: 1 if x.startswith(' ') else 0
            elif feat_func == 'capitalized':
                func = lambda x: 1 if x[0].isupper() else 0
            elif feat_func == 'numeric':
                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
            else:
                raise ValueError('Feature function "{}" is undefined.'.format(feat_func))

            funcs.append(func)

        # stacking all featurize functions
        composite_func = lambda x: [f(x) for f in funcs]

        def process_sentence(sent_units, sent_labels, sent_feats):
            # convert one finished sentence into (ids, labels, feats, raw units)
            return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                    np.array(sent_labels),
                    np.array(sent_feats),
                    list(sent_units))

        use_end_of_para = 'end_of_para' in self.args['feat_funcs']
        use_start_of_para = 'start_of_para' in self.args['feat_funcs']
        use_dictionary = self.args['use_dictionary']
        current_units = []
        current_labels = []
        current_feats = []
        for i, (unit, label) in enumerate(para):
            feats = composite_func(unit)
            # position-dependent features
            if use_end_of_para:
                f = 1 if i == len(para)-1 else 0
                feats.append(f)
            if use_start_of_para:
                f = 1 if i == 0 else 0
                feats.append(f)

            # if dictionary feature is selected, append prefix/suffix lookup features
            if use_dictionary:
                dict_feats = self.extract_dict_feat(para, i)
                feats = feats + dict_feats

            current_units.append(unit)
            current_labels.append(label)
            current_feats.append(feats)
            # labels 2 and 4 mark sentence-final units (4 also ends an MWT)
            if not self.eval and (label == 2 or label == 4): # end of sentence
                if len(current_units) <= self.args['max_seqlen']:
                    # get rid of sentences that are too long during training of the tokenizer
                    res.append(process_sentence(current_units, current_labels, current_feats))
                current_units.clear()
                current_labels.clear()
                current_feats.clear()

        # flush the trailing (possibly unterminated) sentence
        if len(current_units) > 0:
            if self.eval or len(current_units) <= self.args['max_seqlen']:
                res.append(process_sentence(current_units, current_labels, current_feats))

        return res
184
+
185
+ def advance_old_batch(self, eval_offsets, old_batch):
186
+ """
187
+ Advance to a new position in a batch where we have partially processed the batch
188
+
189
+ If we have previously built a batch of data and made predictions on them, then when we are trying to make
190
+ prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
191
+ and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
192
+ In this case, eval_offsets index within the old_batch to advance the strings to process.
193
+ """
194
+ unkid = self.vocab.unit2id('<UNK>')
195
+ padid = self.vocab.unit2id('<PAD>')
196
+
197
+ ounits, olabels, ofeatures, oraw = old_batch
198
+ feat_size = ofeatures.shape[-1]
199
+ lens = (ounits != padid).sum(1).tolist()
200
+ pad_len = max(l-i for i, l in zip(eval_offsets, lens))
201
+
202
+ units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)
203
+ labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)
204
+ features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)
205
+ raw_units = []
206
+
207
+ for i in range(len(ounits)):
208
+ eval_offsets[i] = min(eval_offsets[i], lens[i])
209
+ units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]]
210
+ labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]]
211
+ features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]]
212
+ raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + ['<PAD>'] * (pad_len - lens[i] + eval_offsets[i]))
213
+
214
+ return units, labels, features, raw_units
215
+
216
class DataLoader(TokenizationDataset):
    """
    This is the training version of the dataset.

    On top of the base class, it builds (or reuses) a Vocab, pre-converts
    every paragraph to processed sentences, and provides shuffling plus
    batching via next().
    """
    # NOTE(review): the mutable default dict for input_files is shared across
    # calls — safe only if never mutated; confirm in TokenizationDataset
    def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None):
        super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)

        self.vocab = vocab if vocab is not None else self.init_vocab()

        # data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels.
        # At evaluation time, each paragraph is treated as single "sentence" as we don't know a priori where
        # sentence breaks occur. We make prediction from left to right for each paragraph and move forward to
        # the last predicted sentence break to start afresh.
        self.sentences = [self.para_to_sentences(para) for para in self.data]

        self.init_sent_ids()
        logger.debug(f"{len(self.sentence_ids)} sentences loaded.")

    def __len__(self):
        # number of (paragraph, sentence) pairs across the whole dataset
        return len(self.sentence_ids)

    def init_vocab(self):
        """Build a fresh unit vocabulary from the raw data."""
        vocab = Vocab(self.data, self.args['lang'])
        return vocab

    def init_sent_ids(self):
        """Rebuild the flat (paragraph, sentence) index and cumulative length table."""
        self.sentence_ids = []
        # cumlen[k] is the total number of units before the k-th indexed sentence
        self.cumlen = [0]
        for i, para in enumerate(self.sentences):
            for j in range(len(para)):
                self.sentence_ids += [(i, j)]
                self.cumlen += [self.cumlen[-1] + len(self.sentences[i][j][0])]

    def has_mwt(self):
        """Return True if any unit label in the data marks an MWT (label > 2)."""
        # presumably this only needs to be called either 0 or 1 times,
        # 1 when training and 0 any other time, so no effort is put
        # into caching the result
        for sentence in self.data:
            for word in sentence:
                if word[1] > 2:
                    return True
        return False

    def shuffle(self):
        """Shuffle the sentences within each paragraph, then rebuild the index."""
        for para in self.sentences:
            random.shuffle(para)
        self.init_sent_ids()

    def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):
        ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction.

        eval_offsets: at eval time, per-row character offsets into the
            concatenated data from which to start; None means sample random
            training sentences instead.
        unit_dropout: probability of replacing a unit with <UNK> (training only)
        feat_unit_dropout: probability of zeroing a unit's whole dictionary
            feature vector (training only)

        Returns (units, labels, features, raw_units) with the first three as
        padded torch tensors and raw_units as a list of lists of strings.
        '''
        feat_size = len(self.sentences[0][0][2][0])
        unkid = self.vocab.unit2id('<UNK>')
        padid = self.vocab.unit2id('<PAD>')

        def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
            # At eval time, this combines sentences in paragraph (indexed by id_pair[0]) starting sentence (indexed
            # by id_pair[1]) into a long string for evaluation. At training time, we just select random sentences
            # from the entire dataset until we reach max_seqlen.
            pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
            sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]

            drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))
            drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
            total_len = len(sentences[0][0])

            assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! {}'.format(self.args['max_seqlen'], total_len, ' '.join(["{}/{}".format(*x) for x in zip(self.sentences[pid][sid])]))
            if self.eval:
                # append subsequent sentences from the same paragraph
                for sid1 in range(sid+1, len(self.sentences[pid])):
                    total_len += len(self.sentences[pid][sid1][0])
                    sentences.append(self.sentences[pid][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break
            else:
                # append random sentences from anywhere in the dataset
                while True:
                    pid1, sid1 = random.choice(self.sentence_ids)
                    total_len += len(self.sentences[pid1][sid1][0])
                    sentences.append(self.sentences[pid1][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break

            if drop_sents and len(sentences) > 1:
                if total_len > self.args['max_seqlen']:
                    sentences = sentences[:-1]
                if len(sentences) > 1:
                    p = [.5 ** i for i in range(1, len(sentences) + 1)] # drop a large number of sentences with smaller probability
                    cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]
                    sentences = sentences[:cutoff+1]

            units = np.concatenate([s[0] for s in sentences])
            labels = np.concatenate([s[1] for s in sentences])
            feats = np.concatenate([s[2] for s in sentences])
            raw_units = [x for s in sentences for x in s[3]]

            if not self.eval:
                # truncate to max_seqlen during training
                cutoff = self.args['max_seqlen']
                units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]

            if drop_last_char: # can only happen in non-eval mode
                if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):
                    # training text ended with a sentence end position
                    # and that word was a single character
                    # and the previous character ended the word
                    units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]
                    # word end -> sentence end, mwt end -> sentence mwt end
                    labels[-1] = labels[-1] + 1

            return units, labels, feats, raw_units

        if eval_offsets is not None:
            # find max padding length
            pad_len = 0
            for eval_offset in eval_offsets:
                if eval_offset < self.cumlen[-1]:
                    pair_id = bisect_right(self.cumlen, eval_offset) - 1
                    pair = self.sentence_ids[pair_id]
                    pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0]))

            pad_len += 1
            # map each flat offset back to its (paragraph, sentence) pair and in-sentence offset
            id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]
            pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]
            offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]

            offsets_pairs = list(zip(offsets, pairs))
        else:
            id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))
            offsets_pairs = [(0, x) for x in id_pairs]
            pad_len = self.args['max_seqlen']

        # put everything into padded and nicely shaped NumPy arrays and eventually convert to PyTorch tensors
        units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)
        labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)
        features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)
        raw_units = []
        for i, (offset, pair) in enumerate(offsets_pairs):
            u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)
            units[i, :len(u_)] = u_
            labels[i, :len(l_)] = l_
            features[i, :len(f_), :] = f_
            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

        if unit_dropout > 0 and not self.eval:
            # dropout characters/units at training time and replace them with UNKs
            mask = np.random.random_sample(units.shape) < unit_dropout
            mask[units == padid] = 0
            units[mask] = unkid
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask[i, j]:
                        raw_units[i][j] = '<UNK>'

        # dropout unit feature vector in addition to only torch.dropout in the model.
        # experiments showed that only torch.dropout hurts the model
        # we believe it is because the dict feature vector is mostly scarse so it makes
        # more sense to drop out the whole vector instead of only single element.
        if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
            mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
            mask_feat[units == padid] = 0
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask_feat[i,j]:
                        features[i,j,:] = 0

        units = torch.from_numpy(units)
        labels = torch.from_numpy(labels)
        features = torch.from_numpy(features)

        return units, labels, features, raw_units
385
+
386
class SortedDataset(Dataset):
    """
    Holds a TokenizationDataset for use in a torch DataLoader

    The torch DataLoader is different from the DataLoader defined here
    and allows for cpu & gpu parallelism. Updating output_predictions
    to use this class as a wrapper to a TokenizationDataset means the
    calculation of features can happen in parallel, saving quite a
    bit of time.
    """
    def __init__(self, dataset):
        super().__init__()

        self.dataset = dataset
        # sort paragraphs by length so batches have similar lengths;
        # self.indices remembers the permutation for unsort()
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # feature extraction happens here, so torch worker processes can parallelize it
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        """Restore arr (parallel to the sorted data) to the original paragraph order."""
        return unsort(arr, self.indices)

    def collate(self, samples):
        """Pad a list of single-sentence samples into one (units, labels, features, raw) batch."""
        if any(len(x) > 1 for x in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        feat_size = samples[0][0][2].shape[-1]
        padid = self.dataset.vocab.unit2id('<PAD>')

        # +1 so that all samples end with at least one pad
        pad_len = max(len(x[0][3]) for x in samples) + 1

        units = torch.full((len(samples), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32)
        raw_units = []
        for i, sample in enumerate(samples):
            u_, l_, f_, r_ = sample[0]
            units[i, :len(u_)] = torch.from_numpy(u_)
            labels[i, :len(l_)] = torch.from_numpy(l_)
            features[i, :len(f_), :] = torch.from_numpy(f_)
            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

        return units, labels, features, raw_units
432
+
stanza/stanza/models/tokenization/model.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+
5
class Tokenizer(nn.Module):
    """Character-level BiLSTM tokenizer model.

    Embeds input units, concatenates per-unit features, runs a BiLSTM
    (optionally with convolutional residual connections and a second,
    hierarchical BiLSTM), and emits per-position log-probabilities over
    token / sentence (/ MWT) boundary classes.
    """
    def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
        """
        args: configuration dict (feat_dim, rnn_layers, conv_res, hierarchical,
            hier_invtemp, use_mwt, tok_noise, ...)
        nchars: vocabulary size for the unit embedding
        emb_dim: unit embedding dimension
        hidden_dim: per-direction LSTM hidden size
        dropout: dropout rate for embeddings / RNN outputs
        feat_dropout: dropout rate applied to the feature vectors
        """
        super().__init__()

        self.args = args
        feat_dim = args['feat_dim']

        self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)

        self.rnn = nn.LSTM(emb_dim + feat_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)

        if self.args['conv_res'] is not None:
            # residual 1-d convolutions over the embedded input, one per
            # comma-separated kernel size in args['conv_res']
            self.conv_res = nn.ModuleList()
            self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')]

            for si, size in enumerate(self.conv_sizes):
                l = nn.Conv1d(emb_dim + feat_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))
                self.conv_res.append(l)

            if self.args.get('hier_conv_res', False):
                # 1x1 conv combining all conv outputs before the residual add
                self.conv_res2 = nn.Conv1d(hidden_dim * 2 * len(self.conv_sizes), hidden_dim * 2, 1)

        self.tok_clf = nn.Linear(hidden_dim * 2, 1)
        self.sent_clf = nn.Linear(hidden_dim * 2, 1)
        if self.args['use_mwt']:
            self.mwt_clf = nn.Linear(hidden_dim * 2, 1)

        if args['hierarchical']:
            # second BiLSTM whose classifiers additively refine the first-stage scores
            in_dim = hidden_dim * 2
            self.rnn2 = nn.LSTM(in_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
            self.tok_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            self.sent_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            if self.args['use_mwt']:
                self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.dropout_feat = nn.Dropout(feat_dropout)

        self.toknoise = nn.Dropout(self.args['tok_noise'])

    def forward(self, x, feats):
        """Return per-position class log-probabilities.

        x: unit id tensor; feats: per-unit feature tensor, concatenated to the
        embeddings along the last dimension.  Output classes (last dim) are
        [non-token, token-only, token+sent] or, with use_mwt,
        [non-token, token, token+sent, token+mwt, token+sent+mwt].
        """
        emb = self.embeddings(x)
        emb = self.dropout(emb)
        feats = self.dropout_feat(feats)


        emb = torch.cat([emb, feats], 2)

        inp, _ = self.rnn(emb)

        if self.args['conv_res'] is not None:
            conv_input = emb.transpose(1, 2).contiguous()
            if not self.args.get('hier_conv_res', False):
                # add each conv output directly onto the RNN output (residual)
                for l in self.conv_res:
                    inp = inp + l(conv_input).transpose(1, 2).contiguous()
            else:
                # combine conv outputs with a 1x1 conv first, then add residually
                hid = []
                for l in self.conv_res:
                    hid += [l(conv_input)]
                hid = torch.cat(hid, 1)
                hid = F.relu(hid)
                hid = self.dropout(hid)
                inp = inp + self.conv_res2(hid).transpose(1, 2).contiguous()

        inp = self.dropout(inp)

        # first-stage boundary scores (logits)
        tok0 = self.tok_clf(inp)
        sent0 = self.sent_clf(inp)
        if self.args['use_mwt']:
            mwt0 = self.mwt_clf(inp)

        if self.args['hierarchical']:
            if self.args['hier_invtemp'] > 0:
                # gate the second RNN's input by the (noised) predicted
                # non-token probability, sharpened by hier_invtemp
                inp2, _ = self.rnn2(inp * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp']))))
            else:
                inp2, _ = self.rnn2(inp)

            inp2 = self.dropout(inp2)

            # second stage refines the first-stage logits additively
            tok0 = tok0 + self.tok_clf2(inp2)
            sent0 = sent0 + self.sent_clf2(inp2)
            if self.args['use_mwt']:
                mwt0 = mwt0 + self.mwt_clf2(inp2)

        # convert logits to log-probabilities of each binary decision
        nontok = F.logsigmoid(-tok0)
        tok = F.logsigmoid(tok0)
        nonsent = F.logsigmoid(-sent0)
        sent = F.logsigmoid(sent0)
        if self.args['use_mwt']:
            nonmwt = F.logsigmoid(-mwt0)
            mwt = F.logsigmoid(mwt0)

        # joint log-probabilities for each output class (sum = independence assumption)
        if self.args['use_mwt']:
            pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)
        else:
            pred = torch.cat([nontok, tok+nonsent, tok+sent], 2)

        return pred
stanza/stanza/models/tokenization/tokenize_files.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Use a Stanza tokenizer to turn a text file into one tokenized paragraph per line
2
+
3
+ For example, the output of this script is suitable for Glove
4
+
5
+ Currently this *only* supports tokenization, no MWT splitting.
6
+ It also would be beneficial to have an option to convert spaces into
7
+ NBSP, underscore, or some other marker to make it easier to process
8
+ languages such as VI which have spaces in them
9
+ """
10
+
11
+
12
+ import argparse
13
+ import io
14
+ import os
15
+ import time
16
+ import re
17
+ import zipfile
18
+
19
+ import torch
20
+
21
+ import stanza
22
+ from stanza.models.common.utils import open_read_text, default_device
23
+ from stanza.models.tokenization.data import TokenizationDataset
24
+ from stanza.models.tokenization.utils import output_predictions
25
+ from stanza.pipeline.tokenize_processor import TokenizeProcessor
26
+ from stanza.utils.get_tqdm import get_tqdm
27
+
28
+ tqdm = get_tqdm()
29
+
30
+ NEWLINE_SPLIT_RE = re.compile(r"\n\s*\n")
31
+
32
def tokenize_to_file(tokenizer, fin, fout, chunk_size=500):
    """Tokenize all text read from fin, writing one paragraph per line to fout.

    Paragraphs (blank-line separated pieces of the input) are processed in
    batches of chunk_size documents to bound memory use; sentences within a
    paragraph are joined by single spaces.
    """
    documents = NEWLINE_SPLIT_RE.split(fin.read())
    for start in tqdm(range(0, len(documents), chunk_size), leave=False):
        batch = documents[start:start + chunk_size]
        in_docs = [stanza.Document([], text=text) for text in batch]
        for document in tokenizer.bulk_process(in_docs):
            sentence_texts = [" ".join(token.text for token in sentence.tokens)
                              for sentence in document.sentences]
            fout.write(" ".join(sentence_texts))
            fout.write("\n")
46
+
47
def main(args=None):
    """Tokenize one or more input files (plain text or .zip archives of text)
    into a single output file, one paragraph per line.

    args: optional list of command line arguments (None uses sys.argv).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, default="sd", help="Which language to use for tokenization")
    parser.add_argument("--tokenize_model_path", type=str, default=None, help="Specific tokenizer model to use")
    parser.add_argument("input_files", type=str, nargs="+", help="Which input files to tokenize")
    parser.add_argument("--output_file", type=str, default="glove.txt", help="Where to write the tokenized output")
    parser.add_argument("--model_dir", type=str, default=None, help="Where to get models for a Pipeline (None => default models dir)")
    parser.add_argument("--chunk_size", type=int, default=500, help="How many 'documents' to use in a chunk when tokenizing. This is separate from the tokenizer batching - this limits how much memory gets used at once, since we don't need to store an entire file in memory at once")
    args = parser.parse_args(args=args)

    if os.path.exists(args.output_file):
        print("Cowardly refusing to overwrite existing output file %s" % args.output_file)
        return

    if args.tokenize_model_path:
        # build a bare TokenizeProcessor from the given model file
        config = { "model_path": args.tokenize_model_path,
                   "check_requirements": False }
        tokenizer = TokenizeProcessor(config, pipeline=None, device=default_device())
    else:
        # otherwise, use the downloaded tokenizer for the requested language
        pipe = stanza.Pipeline(lang=args.lang, processors="tokenize", model_dir=args.model_dir)
        tokenizer = pipe.processors["tokenize"]

    with open(args.output_file, "w", encoding="utf-8") as fout:
        for filename in tqdm(args.input_files):
            if filename.endswith(".zip"):
                with zipfile.ZipFile(filename) as zin:
                    input_names = zin.namelist()
                    for input_name in tqdm(input_names, leave=False):
                        # bugfix: previously opened input_names[0] for every
                        # entry, tokenizing the first archive member repeatedly
                        with zin.open(input_name) as fin:
                            fin = io.TextIOWrapper(fin, encoding='utf-8')
                            tokenize_to_file(tokenizer, fin, fout, args.chunk_size)
            else:
                with open_read_text(filename, encoding="utf-8") as fin:
                    # bugfix: honor --chunk_size instead of silently ignoring it
                    tokenize_to_file(tokenizer, fin, fout, args.chunk_size)
81
+
82
+ if __name__ == '__main__':
83
+ main()
stanza/stanza/models/tokenization/trainer.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import logging
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+
7
+ from stanza.models.common import utils
8
+ from stanza.models.common.trainer import Trainer as BaseTrainer
9
+ from stanza.models.tokenization.utils import create_dictionary
10
+
11
+ from .model import Tokenizer
12
+ from .vocab import Vocab
13
+
14
+ logger = logging.getLogger('stanza')
15
+
16
class Trainer(BaseTrainer):
    """Trainer wrapper for the tokenizer model: owns the model, vocab,
    optional dictionary/lexicon, optimizer and loss, and handles
    save/load of checkpoints.
    """
    def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None):
        if model_file is not None:
            # load everything from file
            self.load(model_file)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab
            self.lexicon = list(lexicon) if lexicon is not None else None
            self.dictionary = dictionary
            self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
        self.model = self.model.to(device)
        # labels padded with -1 are ignored by the loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
        self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay'])
        self.feat_funcs = self.args.get('feat_funcs', None)
        self.lang = self.args['lang'] # language determines how token normalization is done

    def update(self, inputs):
        """Run one training step on a batch; returns the scalar loss value."""
        self.model.train()
        units, labels, features, _ = inputs

        device = next(self.model.parameters()).device
        units = units.to(device)
        labels = labels.to(device)
        features = features.to(device)

        pred = self.model(units, features)

        self.optimizer.zero_grad()
        classes = pred.size(2)
        loss = self.criterion(pred.view(-1, classes), labels.view(-1))

        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()

        return loss.item()

    def predict(self, inputs):
        """Run the model on a batch; returns predictions as a numpy array."""
        self.model.eval()
        units, _, features, _ = inputs

        device = next(self.model.parameters()).device
        units = units.to(device)
        features = features.to(device)

        pred = self.model(units, features)

        return pred.data.cpu().numpy()

    def save(self, filename):
        """Save model weights, vocab, lexicon and config to filename."""
        params = {
            'model': self.model.state_dict() if self.model is not None else None,
            'vocab': self.vocab.state_dict(),
            # save and load lexicon as list instead of set so
            # we can use weights_only=True
            'lexicon': list(self.lexicon) if self.lexicon is not None else None,
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException:
            # saving is best-effort; training continues even if the disk write fails
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename):
        """Load model weights, vocab, lexicon and config from filename.

        Rebuilds the dictionary from the lexicon if one was saved.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if self.args.get('use_mwt', None) is None:
            # Default to True as many currently saved models
            # were built with mwt layers
            self.args['use_mwt'] = True
        self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
        self.model.load_state_dict(checkpoint['model'])
        self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
        self.lexicon = checkpoint['lexicon']

        if self.lexicon is not None:
            self.lexicon = set(self.lexicon)
            self.dictionary = create_dictionary(self.lexicon)
        else:
            self.dictionary = None
stanza/stanza/utils/datasets/constituency/convert_ctb.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import glob
3
+ import os
4
+ import re
5
+
6
+ import xml.etree.ElementTree as ET
7
+
8
+ from stanza.models.constituency import tree_reader
9
+ from stanza.utils.datasets.constituency.utils import write_dataset
10
+ from stanza.utils.get_tqdm import get_tqdm
11
+
12
+ tqdm = get_tqdm()
13
+
14
class Version(Enum):
    """Which release (and split scheme) of the Chinese Treebank is being converted."""
    V51 = 1
    V51b = 2
    V90 = 3
18
+
19
def filenum_to_shard_51(filenum):
    """Map a CTB 5.1 file number to its dataset shard (0=train, 1=dev, 2=test).

    Raises ValueError for file numbers outside all known ranges.
    """
    # inclusive (low, high) ranges per shard; the ranges are disjoint
    shard_ranges = (
        (1, 815, 0),
        (1001, 1136, 0),
        (886, 931, 1),
        (1148, 1151, 1),
        (816, 885, 2),
        (1137, 1147, 2),
    )
    for low, high, shard in shard_ranges:
        if low <= filenum <= high:
            return shard
    raise ValueError("Unhandled filenum %d" % filenum)
36
+
37
def filenum_to_shard_51_basic(filenum):
    """Map a CTB 5.1 file number to its shard under the "basic" split.

    Returns 0/1/2 for train/dev/test, None for files that are skipped
    entirely (400-439), and raises ValueError for unknown file numbers.
    """
    # inclusive (low, high) ranges; disjoint, so check order does not matter
    shard_ranges = (
        (1, 270, 0),
        (440, 1151, 0),
        (301, 325, 1),
        (271, 300, 2),
        (400, 439, None),
    )
    for low, high, shard in shard_ranges:
        if low <= filenum <= high:
            return shard
    raise ValueError("Unhandled filenum %d" % filenum)
53
+
54
def filenum_to_shard_90(filenum):
    """Map a CTB 9.0 file number to its dataset shard.

    Returns 2 for test, 1 for dev, 0 for train, and None for any file
    number outside all known ranges (the caller skips those files).

    Note: the test check for 900-931 intentionally precedes the train
    check for 81-900, so filenum 900 lands in the test shard.
    """
    # test shard
    if filenum >= 1 and filenum <= 40:
        return 2
    if filenum >= 900 and filenum <= 931:
        return 2
    if filenum in (1018, 1020, 1036, 1044, 1060, 1061, 1072, 1118, 1119, 1132, 1141, 1142, 1148):
        return 2
    if filenum >= 2165 and filenum <= 2180:
        return 2
    if filenum >= 2295 and filenum <= 2310:
        return 2
    if filenum >= 2570 and filenum <= 2602:
        return 2
    if filenum >= 2800 and filenum <= 2819:
        return 2
    if filenum >= 3110 and filenum <= 3145:
        return 2

    # dev shard
    if filenum >= 41 and filenum <= 80:
        return 1
    if filenum >= 1120 and filenum <= 1129:
        return 1
    if filenum >= 2140 and filenum <= 2159:
        return 1
    if filenum >= 2280 and filenum <= 2294:
        return 1
    if filenum >= 2550 and filenum <= 2569:
        return 1
    if filenum >= 2775 and filenum <= 2799:
        return 1
    if filenum >= 3080 and filenum <= 3109:
        return 1

    # train shard
    if filenum >= 81 and filenum <= 900:
        return 0
    if filenum >= 1001 and filenum <= 1017:
        return 0
    if filenum in (1019, 1130, 1131):
        return 0
    if filenum >= 1021 and filenum <= 1035:
        return 0
    if filenum >= 1037 and filenum <= 1043:
        return 0
    if filenum >= 1045 and filenum <= 1059:
        return 0
    if filenum >= 1062 and filenum <= 1071:
        return 0
    if filenum >= 1073 and filenum <= 1117:
        return 0
    if filenum >= 1133 and filenum <= 1140:
        return 0
    if filenum >= 1143 and filenum <= 1147:
        return 0
    if filenum >= 1149 and filenum <= 2139:
        return 0
    if filenum >= 2160 and filenum <= 2164:
        return 0
    if filenum >= 2181 and filenum <= 2279:
        return 0
    if filenum >= 2311 and filenum <= 2549:
        return 0
    if filenum >= 2603 and filenum <= 2774:
        return 0
    if filenum >= 2820 and filenum <= 3079:
        return 0
    if filenum >= 4000 and filenum <= 7017:
        return 0

    # previously this function fell off the end; make the "skip this file"
    # result explicit (convert_ctb treats a None shard as "skip")
    return None
123
+
124
def collect_trees_s(root):
    """Yield (text, ID) for every <S> element in the XML tree, depth-first."""
    if root.tag == 'S':
        yield root.text, root.attrib['ID']

    for child in root:
        yield from collect_trees_s(child)
131
+
132
def collect_trees_text(root):
    """Yield (text, None) for each non-empty <TEXT> or <TURN> element, depth-first.

    A tag can match at most one of TEXT/TURN, so a single membership test
    is equivalent to the two separate checks.
    """
    if root.tag in ('TEXT', 'TURN') and len(root.text.strip()) > 0:
        yield root.text, None

    for child in root:
        yield from collect_trees_text(child)
142
+
143
+
144
# matches CTB <S ID=...> tags with an unquoted ID value; convert_ctb rewrites
# them with quoted attributes so the file parses as valid XML
id_re = re.compile("<S ID=([0-9a-z]+)>")
# matches <su>/<msg> opening tags (found in some CTB 9.0 files); convert_ctb
# strips these from the text before XML parsing
su_re = re.compile("<(su|msg) id=([0-9a-zA-Z_=]+)>")
146
+
147
def convert_ctb(input_dir, output_dir, dataset_name, version):
    """Convert a directory of CTB files into train/dev/test constituency datasets.

    input_dir: directory containing the raw CTB files (one per file number)
    output_dir: where write_dataset puts the converted splits
    dataset_name: name passed through to write_dataset
    version: a Version enum value selecting encoding, cleanup hacks, and the
        file-number-to-shard mapping

    Each file is cleaned up into parseable XML (the exact fixes depend on the
    CTB version and file number), the trees are extracted, pruned of -NONE-
    nodes, label-simplified, and appended to the shard chosen by the
    appropriate filenum_to_shard_* function (a None shard means skip).
    """
    input_files = glob.glob(os.path.join(input_dir, "*"))

    # train, dev, test
    datasets = [[], [], []]

    # sort by the numeric portion of the filename so processing is deterministic
    sorted_filenames = []
    for input_filename in input_files:
        base_filename = os.path.split(input_filename)[1]
        filenum = int(os.path.splitext(base_filename)[0].split("_")[1])
        sorted_filenames.append((filenum, input_filename))
    sorted_filenames.sort()

    for filenum, filename in tqdm(sorted_filenames):
        if version in (Version.V51, Version.V51b):
            # CTB 5.1 is GB2312 encoded, with some broken bytes ignored
            with open(filename, errors='ignore', encoding="gb2312") as fin:
                text = fin.read()
        elif version is Version.V90:
            with open(filename, encoding="utf-8") as fin:
                text = fin.read()
            # per-file XML cleanup hacks for the 9.0 data
            if text.find("<TURN>") >= 0 and text.find("</TURN>") < 0:
                text = text.replace("<TURN>", "")
            if filenum in (4205, 4208, 4289):
                text = text.replace("<)", "&lt;)").replace(">)", "&gt;)")
            if filenum >= 4000 and filenum <= 4411:
                if text.find("<segment") >= 0:
                    text = text.replace("<segment id=", "<S ID=").replace("</segment>", "</S>")
                elif text.find("<seg") < 0:
                    text = "<TEXT>\n%s</TEXT>\n" % text
                else:
                    text = text.replace("<seg id=", "<S ID=").replace("</seg>", "</S>")
                text = "<foo>\n%s</foo>\n" % text
            if filenum >= 5000 and filenum <= 5558 or filenum >= 6000 and filenum <= 6700 or filenum >= 7000 and filenum <= 7017:
                text = su_re.sub("", text)
            if filenum in (6066, 6453):
                text = text.replace("<", "&lt;").replace(">", "&gt;")
                text = "<foo><TEXT>\n%s</TEXT></foo>\n" % text
        else:
            raise ValueError("Unknown CTB version %s" % version)
        # quote the unquoted ID attributes and escape bare ampersands
        text = id_re.sub(r'<S ID="\1">', text)
        text = text.replace("&", "&amp;")

        try:
            xml_root = ET.fromstring(text)
        except Exception as e:
            print(text[:1000])
            raise RuntimeError("Cannot xml process %s" % filename) from e
        trees = [x for x in collect_trees_s(xml_root)]
        if version is Version.V90 and len(trees) == 0:
            # some 9.0 files keep trees in TEXT/TURN blocks instead of S tags
            trees = [x for x in collect_trees_text(xml_root)]

        if version in (Version.V51, Version.V51b):
            # tree 4366 of file 414 is known bad and skipped
            trees = [x[0] for x in trees if filenum != 414 or x[1] != "4366"]
        else:
            trees = [x[0] for x in trees]

        trees = "\n".join(trees)
        try:
            trees = tree_reader.read_trees(trees, use_tqdm=False)
        except ValueError as e:
            print(text[:300])
            # chain the original parse error for easier debugging
            # (matches the XML error handling above)
            raise RuntimeError("Could not process the tree text in %s" % filename) from e
        trees = [t.prune_none().simplify_labels() for t in trees]

        assert len(trees) > 0, "No trees in %s" % filename

        if version is Version.V51:
            shard = filenum_to_shard_51(filenum)
        elif version is Version.V51b:
            shard = filenum_to_shard_51_basic(filenum)
        else:
            shard = filenum_to_shard_90(filenum)
        if shard is None:
            continue
        datasets[shard].extend(trees)


    write_dataset(datasets, output_dir, dataset_name)
stanza/stanza/utils/datasets/constituency/extract_silver_dataset.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ After running build_silver_dataset.py, this extracts the trees of a certain match level
3
+
4
+ For example
5
+
6
+ python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score 0 --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_0.mrg
7
+
8
+ for i in `echo 0 1 2 3 4 5 6 7 8 9 10`; do python3 stanza/utils/datasets/constituency/extract_silver_dataset.py --parsed_trees /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/a*.trees --keep_score $i --output_file /u/nlp/data/constituency-parser/italian/2024_it_vit_electra/it_silver_$i.mrg; done
9
+ """
10
+
11
+ import argparse
12
+ import json
13
+
14
def parse_args():
    """Parse the command line options for extracting silver trees.

    Returns:
        argparse.Namespace with parsed_trees (list of input filenames),
        keep_score (int agreement level, or None to keep everything),
        and output_file (destination path, or None for stdout).
    """
    parser = argparse.ArgumentParser(description="After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy")
    parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.')
    parser.add_argument('--keep_score', type=int, default=None, help='Which agreement level to keep. None keeps all')
    parser.add_argument('--output_file', type=str, default=None, help='Where to put the output file')
    args = parser.parse_args()

    return args


def main():
    """Filter build_silver_dataset json output down to the kept trees.

    Each input line is a json object with at least a 'tree' field (the
    bracketed tree text) and a 'count' field (the agreement score from
    build_silver_dataset).  A tree is kept when --keep_score is unset or
    equals its 'count'.  Kept trees go to --output_file, one per line,
    or to stdout when no output file is given.
    """
    args = parse_args()

    trees = []
    for filename in args.parsed_trees:
        with open(filename, encoding='utf-8') as fin:
            # iterate the file lazily instead of materializing readlines()
            for line in fin:
                record = json.loads(line)
                if args.keep_score is None or record['count'] == args.keep_score:
                    trees.append(record['tree'])

    if args.output_file is None:
        for tree in trees:
            print(tree)
    else:
        with open(args.output_file, 'w', encoding='utf-8') as fout:
            for tree in trees:
                fout.write(tree)
                fout.write('\n')

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/coref/balance_languages.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ balance_concat.py
3
+ create a test set from a dev set which is language balanced
4
+ """
5
+
6
+ import json
7
+ from collections import defaultdict
8
+
9
+ from random import Random
10
+
11
+ # fix random seed for reproducability
12
+ R = Random(42)
13
+
14
+ with open("./corefud_concat_v1_0_langid.train.json", 'r') as df:
15
+ raw = json.load(df)
16
+
17
+ # calculate type of each class; then, we will select the one
18
+ # which has the LOWEST counts as the sample rate
19
+ lang_counts = defaultdict(int)
20
+ for i in raw:
21
+ lang_counts[i["lang"]] += 1
22
+
23
+ min_lang_count = min(lang_counts.values())
24
+
25
+ # sample 20% of the smallest amount for test set
26
+ # this will look like an absurdly small number, but
27
+ # remember this is DOCUMENTS not TOKENS or UTTERANCES
28
+ # so its actually decent
29
+ # also its per language
30
+ test_set_size = int(0.1*min_lang_count)
31
+
32
+ # sampling input by language
33
+ raw_by_language = defaultdict(list)
34
+ for i in raw:
35
+ raw_by_language[i["lang"]].append(i)
36
+ languages = list(set(raw_by_language.keys()))
37
+
38
+ train_set = []
39
+ test_set = []
40
+ for i in languages:
41
+ length = list(range(len(raw_by_language[i])))
42
+ choices = R.sample(length, test_set_size)
43
+
44
+ for indx,i in enumerate(raw_by_language[i]):
45
+ if indx in choices:
46
+ test_set.append(i)
47
+ else:
48
+ train_set.append(i)
49
+
50
+ with open("./corefud_concat_v1_0_langid-bal.train.json", 'w') as df:
51
+ json.dump(train_set, df, indent=2)
52
+
53
+ with open("./corefud_concat_v1_0_langid-bal.test.json", 'w') as df:
54
+ json.dump(test_set, df, indent=2)
55
+
56
+
57
+
58
+ # raw_by_language["en"]
59
+
60
+