bowphs committed
Commit 6cd9428 · verified · 1 Parent(s): 9cbeb98

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. stanza/stanza/models/constituency_parser.py +881 -0
  2. stanza/stanza/models/lemmatizer.py +313 -0
  3. stanza/stanza/pipeline/_constants.py +13 -0
  4. stanza/stanza/pipeline/external/spacy.py +74 -0
  5. stanza/stanza/pipeline/ner_processor.py +143 -0
  6. stanza/stanza/resources/print_charlm_depparse.py +22 -0
  7. stanza/stanza/server/dependency_converter.py +101 -0
  8. stanza/stanza/tests/classifiers/test_constituency_classifier.py +128 -0
  9. stanza/stanza/tests/common/__init__.py +0 -0
  10. stanza/stanza/tests/common/test_chuliu_edmonds.py +36 -0
  11. stanza/stanza/tests/common/test_confusion.py +81 -0
  12. stanza/stanza/tests/common/test_constant.py +67 -0
  13. stanza/stanza/tests/common/test_data_conversion.py +520 -0
  14. stanza/stanza/tests/common/test_foundation_cache.py +36 -0
  15. stanza/stanza/tests/common/test_pretrain.py +139 -0
  16. stanza/stanza/tests/common/test_utils.py +194 -0
  17. stanza/stanza/tests/constituency/__init__.py +0 -0
  18. stanza/stanza/tests/constituency/test_convert_arboretum.py +235 -0
  19. stanza/stanza/tests/constituency/test_ensemble.py +110 -0
  20. stanza/stanza/tests/constituency/test_in_order_compound_oracle.py +93 -0
  21. stanza/stanza/tests/constituency/test_parse_transitions.py +486 -0
  22. stanza/stanza/tests/constituency/test_parse_tree.py +369 -0
  23. stanza/stanza/tests/constituency/test_positional_encoding.py +45 -0
  24. stanza/stanza/tests/constituency/test_selftrain_vi_quad.py +23 -0
  25. stanza/stanza/tests/constituency/test_utils.py +68 -0
  26. stanza/stanza/tests/data/example_french.json +22 -0
  27. stanza/stanza/tests/data/test.dat +0 -0
  28. stanza/stanza/tests/data/tiny_emb.csv +4 -0
  29. stanza/stanza/tests/datasets/__init__.py +0 -0
  30. stanza/stanza/tests/datasets/ner/__init__.py +0 -0
  31. stanza/stanza/tests/datasets/ner/test_prepare_ner_file.py +77 -0
  32. stanza/stanza/tests/datasets/ner/test_utils.py +34 -0
  33. stanza/stanza/tests/lemma/test_data.py +106 -0
  34. stanza/stanza/tests/lemma/test_lemma_trainer.py +154 -0
  35. stanza/stanza/tests/lemma_classifier/test_data_preparation.py +256 -0
  36. stanza/stanza/tests/mwt/test_character_classifier.py +92 -0
  37. stanza/stanza/tests/mwt/test_english_corner_cases.py +88 -0
  38. stanza/stanza/tests/ner/test_bsf_2_iob.py +93 -0
  39. stanza/stanza/tests/ner/test_convert_amt.py +104 -0
  40. stanza/stanza/tests/ner/test_convert_starlang_ner.py +23 -0
  41. stanza/stanza/tests/ner/test_from_conllu.py +30 -0
  42. stanza/stanza/tests/ner/test_ner_utils.py +129 -0
  43. stanza/stanza/tests/pipeline/__init__.py +0 -0
  44. stanza/stanza/tests/pipeline/test_arabic_pipeline.py +27 -0
  45. stanza/stanza/tests/pipeline/test_core.py +248 -0
  46. stanza/stanza/tests/pipeline/test_depparse.py +87 -0
  47. stanza/stanza/tests/pipeline/test_english_pipeline.py +279 -0
  48. stanza/stanza/tests/pipeline/test_french_pipeline.py +353 -0
  49. stanza/stanza/tests/pipeline/test_lemmatizer.py +135 -0
  50. stanza/stanza/tests/pipeline/test_pipeline_constituency_processor.py +61 -0
stanza/stanza/models/constituency_parser.py ADDED
@@ -0,0 +1,881 @@
+ """A command line interface to a shift-reduce constituency parser.
+
+ This follows the work of
+ Recurrent neural network grammars by Dyer et al.
+ In-Order Transition-based Constituent Parsing by Liu & Zhang
+
+ The general outline is:
+
+ Train a model by taking a list of trees, converting them to
+ transition sequences, and learning a model which can predict the
+ next transition given a current state.
+ Then, at inference time, repeatedly predict the next transition until parsing is complete.
+
+ The "transitions" are variations on shift/reduce as per an
+ intro-to-compilers class. The idea is that you can treat all of the
+ words in a sentence as a buffer of tokens, then either "shift" them to
+ represent a new constituent, or "reduce" one or more constituents to
+ form a new constituent.
+
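The shift and reduce operations described above can be sketched in a few lines of Python. This is a toy illustration with made-up helper names, not the parser's actual transition or State classes:

```python
# Toy shift/reduce illustration (hypothetical helpers, not Stanza's classes).
# "shift" moves a word from the buffer onto the stack as a leaf;
# "reduce" pops the top k items and wraps them in a new labeled constituent.

def shift(stack, buffer):
    stack.append(buffer.pop(0))

def reduce(stack, label, num_children):
    children = stack[-num_children:]
    del stack[-num_children:]
    stack.append((label, *children))

stack, buffer = [], ["the", "cat", "sleeps"]
shift(stack, buffer)    # stack: ['the']
shift(stack, buffer)    # stack: ['the', 'cat']
reduce(stack, "NP", 2)  # stack: [('NP', 'the', 'cat')]
shift(stack, buffer)
reduce(stack, "VP", 1)
reduce(stack, "S", 2)
assert stack == [("S", ("NP", "the", "cat"), ("VP", "sleeps"))]
```

A trained classifier replaces the hard-coded sequence above: at each step it looks at the stack and buffer and predicts which transition to apply next.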
+ In order to make the runtime a more competitive speed, effort is taken
+ to batch the transitions and apply multiple transitions at once. At
+ train time, batches are grouped together by length, and at inference
+ time, new trees are added to the batch as previous trees in the batch
+ finish their inference.
+
+ There are a few minor differences in the model:
+ - The word input is a bi-lstm, not a uni-lstm.
+   This gave a small increase in accuracy.
+ - The combination of several constituents into one constituent is done
+   via a single bi-lstm rather than two separate lstms. This increases
+   speed without a noticeable effect on accuracy.
+ - In fact, an even better (in terms of final model accuracy) method
+   is to combine the constituents with torch.max, believe it or not.
+   See lstm_model.py for more details.
+ - Initializing the embeddings with smaller values than the pytorch default.
+   For example, on a ja_alt dataset, scores went from 0.8980 to 0.8985
+   at 200 iterations averaged over 5 trials.
+ - Training with AdaDelta first, then AdamW or madgrad later improves
+   results quite a bit. See --multistage.
+
+ A couple experiments which have been tried with little noticeable impact:
+ - Combining constituents using the method in the paper (only a trained
+   vector at the start instead of both ends) did not affect results
+   and is a little slower.
+ - Using multiple layers of LSTM hidden state for the input to the final
+   classification layers didn't help.
+ - Initializing Linear layers with He initialization and a positive bias
+   (to avoid dead connections) had no noticeable effect on accuracy:
+     0.8396 on it_turin with the original initialization
+     0.8401 and 0.8427 on two runs with updated initialization
+   (so maybe a small improvement...)
+ - Initializing LSTM layers with different gates was slightly worse:
+     forget gates of 1.0
+     forget gates of 1.0, input gates of -1.0
+ - Replacing the LSTMs that make up the Transition and Constituent
+   LSTMs with Dynamic Skip LSTMs made no difference, but was slower.
+ - Highway LSTMs also made no difference.
+ - Putting labels on the shift transitions (the word or the tag shifted)
+   or putting labels on the close transitions didn't help.
+ - Building larger constituents from the output of the constituent LSTM
+   instead of the children constituents hurts scores.
+   For example, an experiment on ja_alt went from 0.8985 to 0.8964
+   when built that way.
+ - The initial transition scheme implemented was TOP_DOWN. We tried
+   a compound unary option, since this worked so well in the CoreNLP
+   constituency parser. Unfortunately, this is far less effective
+   than IN_ORDER. Both specialized unary matrices and reusing the
+   n-ary constituency combination fell short. On the ja_alt dataset:
+     IN_ORDER, max combination method: 0.8985
+     TOP_DOWN_UNARY, specialized matrices: 0.8501
+     TOP_DOWN_UNARY, max combination method: 0.8508
+ - Adding multiple layers of MLP to combine inputs for words made
+   no difference in the scores.
+   Tried both before the LSTM and after.
+   A simple single layer tensor multiply after the LSTM works well.
+   Replacing that with a two layer MLP on the English PTB
+   with roberta-base causes a notable drop in scores.
+   The first experiment didn't use the fancy Linear weight init,
+   but adding that barely made a difference.
+   260 training iterations on en_wsj dev, roberta-base
+   model as of bb983fd5e912f6706ad484bf819486971742c3d1:
+     two layer MLP: 0.9409
+     two layer MLP, init weights: 0.9413
+     single layer: 0.9467
+ - There is code to rebuild models with a new structure in lstm_model.py.
+   As part of this, we tried to randomly reinitialize the transitions
+   if the transition embedding had gone to 0, which often happens.
+   This didn't help at all.
+ - We tried something akin to attention with just the query vector
+   over the bert embeddings as a way to mix them, but that did not
+   improve scores.
+   Example, with a self.bert_layer_mix of size bert_dim x 1:
+     mixed_bert_embeddings = []
+     for feature in bert_embeddings:
+         weighted_feature = self.bert_layer_mix(feature.transpose(1, 2))
+         weighted_feature = torch.softmax(weighted_feature, dim=1)
+         weighted_feature = torch.matmul(feature, weighted_feature).squeeze(2)
+         mixed_bert_embeddings.append(weighted_feature)
+     bert_embeddings = mixed_bert_embeddings
+   It seems just finetuning the transformer is already enough
+   (in general, no need to mix layers at all when finetuning bert embeddings).
+
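The layer-mixing experiment above amounts to a learned softmax over per-layer scores, followed by a weighted sum of the layers. A dependency-free sketch of the same arithmetic, using plain Python lists and toy dimensions in place of torch tensors (all names here are illustrative):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def mix_layers(layers, mix_weights):
    # layers: one feature vector per transformer layer, for a single token.
    # Score each layer with a learned vector, softmax the scores, then
    # return the softmax-weighted combination of the layer vectors.
    scores = [sum(w * f for w, f in zip(mix_weights, layer)) for layer in layers]
    attn = softmax(scores)
    dim = len(layers[0])
    return [sum(a * layer[i] for a, layer in zip(attn, layers)) for i in range(dim)]

layers = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # three layers, dim 2
mixed = mix_layers(layers, mix_weights=[1.0, -1.0])
assert len(mixed) == 2
```

Because the weights sum to 1, the result is a convex combination of the layer vectors, which is the behavior the torch snippet achieves with `softmax` followed by `matmul`.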
+
+ The code breakdown is as follows:
+
+ this file: main interface for training or evaluating models
+ constituency/trainer.py: contains the training & evaluation code
+ constituency/ensemble.py: evaluation code specifically for letting multiple models
+   vote on the correct next transition. a modest improvement.
+ constituency/evaluate_treebanks.py: specifically to evaluate multiple parsed treebanks
+   against a gold. in particular, reports whether the theoretical best from those
+   parsed treebanks is an improvement (e.g. the k-best score as reported by CoreNLP)
+
+ constituency/parse_tree.py: a data structure for representing a parse tree and utility methods
+ constituency/tree_reader.py: a module which can read trees from a string or input file
+
+ constituency/tree_stack.py: a linked list which can branch in
+   different directions, which will be useful when implementing beam
+   search or a dynamic oracle
+ constituency/lstm_tree_stack.py: an LSTM over the elements of a TreeStack
+ constituency/transformer_tree_stack.py: attempts to run attention over the nodes
+   of a tree_stack. not as effective as the lstm_tree_stack in the initial experiments.
+   perhaps it could be refined to work better, though
+
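The branching behavior described for tree_stack.py can be sketched as a persistent linked list, where pushing returns a new node and leaves the old stack untouched. This is an illustrative reimplementation, not the actual TreeStack code:

```python
# Minimal persistent stack: push() returns a new node, so several
# branches (e.g. beam search hypotheses) can share a common tail.

class Node:
    def __init__(self, value, parent=None):
        self.value = value
        self.parent = parent

    def push(self, value):
        # Does not mutate self; both the old and new stacks remain valid.
        return Node(value, self)

    def to_list(self):
        out, node = [], self
        while node is not None:
            out.append(node.value)
            node = node.parent
        return out[::-1]

base = Node("S").push("NP")
branch_a = base.push("VP")   # two branches share the ["S", "NP"] tail
branch_b = base.push("PP")
assert branch_a.to_list() == ["S", "NP", "VP"]
assert branch_b.to_list() == ["S", "NP", "PP"]
```

Sharing the tail keeps memory linear in the number of distinct pushes rather than in the number of hypotheses, which is what makes beam search or oracle exploration cheap.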
+ constituency/parse_transitions.py: transitions and a State data structure to store them
+ constituency/transition_sequence.py: turns ParseTree objects into
+   the transition sequences needed to make them
+
+ constituency/base_model.py: operates on the transitions to turn them into constituents,
+   eventually forming one final parse tree composed of all of the constituents
+ constituency/lstm_model.py: adds LSTM features to the constituents to predict what the
+   correct transition to make is, allowing for predictions on previously unseen text
+
+ constituency/retagging.py: a couple utility methods specifically for retagging
+ constituency/utils.py: a couple utility methods
+
+ constituency/dynamic_oracle.py: a dynamic oracle which currently
+   only operates for the inorder transition sequence.
+   uses deterministic rules to redo the correct action sequence when
+   the parser makes an error.
+
+ constituency/partitioned_transformer.py: implementation of a transformer for self-attention.
+   presumably this should help, but we have yet to find a model structure where
+   this makes the scores go up.
+ constituency/label_attention.py: an even fancier form of transformer based on labeled attention:
+   https://arxiv.org/abs/1911.03875
+ constituency/positional_encoding.py: so far, just the sinusoidal is here.
+   a trained encoding is in partitioned_transformer.py.
+   this should probably be refactored to common, especially if used elsewhere.
+
+ stanza/pipeline/constituency_processor.py: interface between this model and the Pipeline
+
+ stanza/utils/datasets/constituency: various scripts and tools for processing constituency datasets
+
+ Some alternate optimizer methods:
+   adabelief: https://github.com/juntang-zhuang/Adabelief-Optimizer
+   madgrad: https://github.com/facebookresearch/madgrad
+
+ """
+
+ import argparse
+ import logging
+ import os
+ import re
+
+ import torch
+
+ import stanza
+ from stanza.models.common import constant
+ from stanza.models.common import utils
+ from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
+ from stanza.models.constituency import parser_training
+ from stanza.models.constituency import retagging
+ from stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory
+ from stanza.models.constituency.parse_transitions import TransitionScheme
+ from stanza.models.constituency.text_processing import load_model_parse_text
+ from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, NONLINEARITY, add_predict_output_args, postprocess_predict_output_args
+ from stanza.resources.common import DEFAULT_MODEL_DIR
+
+ logger = logging.getLogger('stanza')
+ tlogger = logging.getLogger('stanza.constituency.trainer')
+
+ def build_argparse():
+     """
+     Adds the arguments for building the con parser
+
+     For the most part, defaults are set to cross-validated values, at least for WSJ
+     """
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--data_dir', type=str, default='data/constituency', help='Directory of constituency data.')
+
+     parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors')
+     parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')
+     parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
+     parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
+
+     parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
+     parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
+
+     # BERT helps a lot and actually doesn't slow things down too much
+     # for VI, for example, use vinai/phobert-base
+     parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
+     parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
+     parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer")
+     parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding')
+
+     # BERT finetuning (or any transformer finetuning)
+     # also helps quite a lot.
+     # Experimentally, finetuning all of the layers is the most effective
+     # On the id_icon dataset with the indolem transformer
+     # In this experiment, we trained for 150 iterations with AdaDelta,
+     # with the learning rate 0.01,
+     # then trained for another 150 with madgrad and no finetuning
+     #  1 layer    0.880753 (152)
+     #  2 layers   0.880453 (174)
+     #  3 layers   0.881774 (163)
+     #  4 layers   0.886915 (194)
+     #  5 layers   0.892064 (299)
+     #  6 layers   0.891825 (224)
+     #  7 layers   0.894373 (173)
+     #  8 layers   0.894505 (233)
+     #  9 layers   0.896676 (269)
+     # 10 layers   0.897525 (269)
+     # 11 layers   0.897348 (211)
+     # 12 layers   0.898729 (270)
+     # everything  0.898855 (252)
+     # so the trend is clear that more finetuning is better
+     #
+     # We found that finetuning works very well on the AdaDelta portion
+     # of a multistage training, but less well on a madgrad second
+     # stage. The issue was that we literally could not set the
+     # learning rate low enough because madgrad used epsilon in the LR:
+     # https://github.com/facebookresearch/madgrad/issues/16
+     #
+     # Possible values of the AdaDelta learning rate on the id_icon dataset
+     # In this experiment, we finetuned the entire transformer 150
+     # iterations on AdaDelta, then trained with madgrad for another
+     # 150 with no finetuning
+     # 0.0005: 0.89122 (155)
+     # 0.001:  0.889807 (241)
+     # 0.002:  0.894874 (202)
+     # 0.005:  0.896327 (270)
+     # 0.006:  0.898989 (246)
+     # 0.007:  0.896712 (167)
+     # 0.008:  0.900136 (237)
+     # 0.009:  0.898597 (169)
+     # 0.01:   0.898665 (251)
+     # 0.012:  0.89661 (274)
+     # 0.014:  0.899149 (283)
+     # 0.016:  0.896314 (230)
+     # 0.018:  0.897753 (257)
+     # 0.02:   0.893665 (256)
+     # 0.05:   0.849274 (159)
+     # 0.1:    0.850633 (183)
+     # 0.2:    0.847332 (176)
+     #
+     # The peak is somewhere around 0.008 to 0.014, with the further
+     # observation that at the 150 iteration mark, 0.009 was winning:
+     # 0.007: 0.894589 (33)
+     # 0.008: 0.894777 (53)
+     # 0.009: 0.896466 (56)
+     # 0.01:  0.895557 (71)
+     # 0.012: 0.893479 (45)
+     # 0.014: 0.89468 (116)
+     # 0.016: 0.893053 (128)
+     # 0.018: 0.893086 (48)
+     #
+     # Another option is to train for a few iterations with no
+     # finetuning, then begin finetuning. However, that was not
+     # beneficial at all.
+     # Start iteration on id_icon, same setup as above:
+     #  1: 0.898855 (252)
+     #  5: 0.897885 (217)
+     # 10: 0.895367 (215)
+     # 25: 0.896781 (193)
+     # 50: 0.895216 (193)
+     # Using adamw instead of madgrad:
+     #  1: 0.900594 (226)
+     #  5: 0.898153 (267)
+     # 10: 0.898756 (271)
+     # 25: 0.896867 (256)
+     # 50: 0.895025 (220)
+     #
+     # With the observation that very low learning rate is currently
+     # not working for madgrad, we tried to parameter sweep LR for
+     # AdamW, and got the following, using a first stage LR of 0.009:
+     # 0.0:     0.899706 (290)
+     # 0.00005: 0.899631 (176)
+     # 0.0001:  0.899851 (233)
+     # 0.0002:  0.898601 (207)
+     # 0.0003:  0.899258 (252)
+     # 0.0004:  0.90033 (187)
+     # 0.0005:  0.899091 (183)
+     # 0.001:   0.899791 (268)
+     # 0.002:   0.899453 (196)
+     # 0.003:   0.897029 (173)
+     # 0.004:   0.899566 (290)
+     # 0.005:   0.899285 (289)
+     # 0.01:    0.898938 (233)
+     # 0.02:    0.898983 (248)
+     # 0.03:    0.898571 (247)
+     # 0.04:    0.898466 (180)
+     # 0.05:    0.897448 (214)
+     # It should be noted that in the 0.0001 range, the epoch to epoch
+     # change of the Bert weights was almost negligible. Weights would
+     # change in the 5th or 6th decimal place, if at all.
+     #
+     # The conclusion of all these experiments is that, if we are using
+     # bert_finetuning, the best approach is probably a stage1 learning
+     # rate of 0.009 or so and a second stage optimizer of adamw with
+     # no LR or a very low LR. This behavior is what happens with the
+     # --stage1_bert_finetune flag
+     parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
+     parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
+     parser.add_argument('--bert_finetune_layers', default=None, type=int, help='Only finetune this many layers from the transformer')
+     parser.add_argument('--bert_finetune_begin_epoch', default=None, type=int, help='Which epoch to start finetuning the transformer')
+     parser.add_argument('--bert_finetune_end_epoch', default=None, type=int, help='Which epoch to stop finetuning the transformer')
+     parser.add_argument('--bert_learning_rate', default=0.009, type=float, help='Scale the learning rate for transformer finetuning by this much')
+     parser.add_argument('--stage1_bert_learning_rate', default=None, type=float, help="Scale the learning rate for transformer finetuning by this much only during an AdaDelta warmup")
+     parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much')
+     parser.add_argument('--stage1_bert_finetune', default=None, action='store_true', help="Finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune")
+     parser.add_argument('--no_stage1_bert_finetune', dest='stage1_bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune")
+
+     add_peft_args(parser)
+
+     parser.add_argument('--tag_embedding_dim', type=int, default=20, help="Embedding size for a tag. 0 turns off the feature")
+     # Smaller values also seem to work
+     # For example, after 700 iterations:
+     #  32: 0.9174
+     #  50: 0.9183
+     #  72: 0.9176
+     # 100: 0.9185
+     # not a huge difference regardless
+     # (these numbers were without retagging)
+     parser.add_argument('--delta_embedding_dim', type=int, default=100, help="Embedding size for a delta embedding")
+
+     parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
+     parser.add_argument('--no_train_remove_duplicates', default=True, action='store_false', dest="train_remove_duplicates", help="Do/don't remove duplicates from the training file. Could be useful for intentionally reweighting some trees")
+     parser.add_argument('--silver_file', type=str, default=None, help='Secondary training file.')
+     parser.add_argument('--silver_remove_duplicates', default=False, action='store_true', help="Do/don't remove duplicates from the silver training file. Could be useful for intentionally reweighting some trees")
+     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
+     # TODO: possibly refactor --tokenized_file / --tokenized_dir from here & ensemble
+     parser.add_argument('--xml_tree_file', type=str, default=None, help='Input file of VLSP formatted trees for parsing with parse_text.')
+     parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
+     parser.add_argument('--tokenized_dir', type=str, default=None, help='Input directory of tokenized text for parsing with parse_text.')
+     parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer'])
+     parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one')
+     add_predict_output_args(parser)
+
+     parser.add_argument('--lang', type=str, help='Language')
+     parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
+
+     parser.add_argument('--transition_embedding_dim', type=int, default=20, help="Embedding size for a transition")
+     parser.add_argument('--transition_hidden_size', type=int, default=20, help="Embedding size for transition stack")
+     parser.add_argument('--transition_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()],
+                         help='How to track transitions over a parse. {}'.format(", ".join(x.name for x in StackHistory)))
+     parser.add_argument('--transition_heads', default=4, type=int, help="How many heads to use in MHA *if* the transition_stack is Attention")
+
+     parser.add_argument('--constituent_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()],
+                         help='How to track constituents over a parse. {}'.format(", ".join(x.name for x in StackHistory)))
+     parser.add_argument('--constituent_heads', default=8, type=int, help="How many heads to use in MHA *if* the constituent_stack is Attention")
+
+     # larger was more effective, up to a point
+     # substantially smaller, such as 128,
+     # is fine if bert & charlm are not available
+     parser.add_argument('--hidden_size', type=int, default=512, help="Size of the output layers for constituency stack and word queue")
+
+     parser.add_argument('--epochs', type=int, default=400)
+     parser.add_argument('--epoch_size', type=int, default=5000, help="Runs this many trees in an 'epoch' instead of going through the training dataset exactly once. Set to 0 to do the whole training set")
+     parser.add_argument('--silver_epoch_size', type=int, default=None, help="Runs this many trees in a silver 'epoch'. If not set, will match --epoch_size")
+
+     # AdaDelta warmup for the conparser. Motivation: AdaDelta results in
+     # higher scores overall, but learns 0s for the weights of the pattn and
+     # lattn layers. AdamW learns weights for pattn, and the models are more
+     # accurate than models trained without pattn using AdamW, but the
+     # models score lower overall than the AdaDelta models.
+     #
+     # This improves that by first running AdaDelta, then switching.
+     #
+     # Now, if --multistage is set, run AdaDelta for half the epochs with no
+     # pattn or lattn. Then start the specified optimizer for the rest of
+     # the time with the full model. If pattn and lattn are both present,
+     # the model is 1/2 no attn, 1/4 pattn, 1/4 pattn and lattn
+     #
+     # Improvement on the WSJ dev set can be seen from 94.8 to 95.3
+     # when 4 layers of pattn are trained this way.
+     # More experiments to follow.
+     parser.add_argument('--multistage', default=True, action='store_true', help='1/2 epochs with adadelta no pattn or lattn, 1/4 with chosen optim and no lattn, 1/4 full model')
+     parser.add_argument('--no_multistage', dest='multistage', action='store_false', help="don't do the multistage learning")
+
+     # 1 seems to be the most effective, but we should cross-validate
+     parser.add_argument('--oracle_initial_epoch', type=int, default=1, help="Epoch where we start using the dynamic oracle to let the parser keep going with wrong decisions")
+     parser.add_argument('--oracle_frequency', type=float, default=0.8, help="How often to use the oracle vs how often to force the correct transition")
+     parser.add_argument('--oracle_forced_errors', type=float, default=0.001, help="Occasionally have the model randomly walk through the state space to try to learn how to recover")
+     parser.add_argument('--oracle_level', type=int, default=None, help='Restrict oracle transitions to this level or lower. 0 means off. None means use all oracle transitions.')
+     parser.add_argument('--additional_oracle_levels', type=str, default=None, help='Add some additional experimental oracle transitions. Basically for A/B testing transitions we expect to be bad.')
+     parser.add_argument('--deactivated_oracle_levels', type=str, default=None, help='Temporarily turn off a default oracle level. Basically for A/B testing transitions we expect to be bad.')
+
+     # 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ
+     # earlier version of the model (less accurate overall) had the following results with adadelta:
+     #  30: 0.9085
+     #  50: 0.9070
+     #  75: 0.9010
+     # 150: 0.8985
+     # as another data point, running a newer version with better constituency lstm behavior had:
+     # 30: 0.9111
+     # 50: 0.9094
+     # checking smaller batch sizes to see how this works, at 135 epochs, the values are
+     # 10: 0.8919
+     # 20: 0.9072
+     # 30: 0.9121
+     # obviously these experiments aren't the complete story, but it
+     # looks like 30 trees per batch is the best value for WSJ
+     # note that these numbers are for adadelta and might not apply
+     # to other optimizers
+     # eval batch should generally be faster the bigger the batch,
+     # up to a point, as it allows for more batching of the LSTM
+     # operations and the prediction step
+     parser.add_argument('--train_batch_size', type=int, default=30, help='How many trees to train before taking an optimizer step')
+     parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
+
+     parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.')
+     parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{finetune}_constituency.pt", help="File name to save the model")
+     parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. Mostly for testing")
+     parser.add_argument('--save_each_start', type=int, default=None, help="When to start saving each model")
+     parser.add_argument('--save_each_frequency', type=int, default=1, help="How frequently to save each model")
+     parser.add_argument('--no_save_each_optimizer', dest='save_each_optimizer', default=True, action='store_false', help="Don't save the optimizer when saving 'each' model")
+
+     parser.add_argument('--seed', type=int, default=1234)
+
+     parser.add_argument('--no_check_valid_states', default=True, action='store_false', dest='check_valid_states', help="Don't check the constituents or transitions in the dev set when starting a new parser. Warning: the parser will never guess unknown constituents")
+     parser.add_argument('--no_strict_check_constituents', default=True, action='store_false', dest='strict_check_constituents', help="Don't check the constituents between the train & dev set. May result in untrainable transitions")
+     utils.add_device_args(parser)
+
+     # Numbers are on a VLSP dataset, before adding attn or other improvements
+     # baseline is an 80.6 model that occurs when trained using adadelta, lr 1.0
+     #
+     # adabelief 0.1:  fails horribly
+     # 0.02:     converges very low scores
+     # 0.01:     very slow learning
+     # 0.002:    almost decent
+     # 0.001:    close, but about 1 f1 low on IT
+     # 0.0005:   79.71
+     # 0.0002:   80.11
+     # 0.0001:   79.85
+     # 0.00005:  80.40
+     # 0.00002:  80.02
+     # 0.00001:  78.95
+     #
+     # madgrad 0.005: fails horribly
+     # 0.001:    low scores
+     # 0.0005:   still somewhat low
+     # 0.0002:   close, but about 1 f1 low on IT
+     # 0.0001:   80.04
+     # 0.00005:  79.91
+     # 0.00002:  80.15
+     # 0.00001:  80.44
+     # 0.000005: 80.34
+     # 0.000002: 80.39
+     #
+     # adamw experiment on a TR dataset (not necessarily the best test case)
+     # note that at that time, the expected best for adadelta was 0.816
+     #
+     # 0.00005 - 0.7925
+     # 0.0001  - 0.7889
+     # 0.0002  - 0.8110
+     # 0.00025 - 0.8108
+     # 0.0003  - 0.8050
+     # 0.0005  - 0.8076
+     # 0.001   - 0.8069
+
474
+ # Numbers on the VLSP Dataset, with --multistage and default learning rates and adabelief optimizer
475
+ # Gelu: 82.32
476
+ # Mish: 81.95
477
+ # ELU: 81.73
478
+ # Hardshrink: 0.3
479
+ # Hardsigmoid: 79.03
480
+ # Hardtanh: 81.44
481
+ # Hardswish: 81.67
482
+ # Logsigmoid: 80.91
483
+ # Prelu: 80.95 (terminated early)
484
+ # Relu6: 81.91
485
+ # RReLU: 77.00
486
+ # Selu: 81.17
487
+ # Celu: 81.43
488
+ # Silu: 81.90
489
+ # Softplus: 80.94
490
+ # Softshrink: 0.3
491
+ # Softsign: 81.63
492
+ # Softshrink: 13.74
493
+ #
494
+ # Tests with no_charlm, --multistage
495
+ # Gelu
496
+ # 0.00002 0.819746
497
+ # 0.00005 0.818
498
+ # 0.0001 0.818566
499
+ # 0.0002 0.819111
500
+ # 0.001 0.815609
501
+ #
502
+ # Mish
503
+ # 0.00002 0.816898
504
+ # 0.00005 0.821085
505
+ # 0.0001 0.817821
506
+ # 0.0002 0.818806
507
+ # 0.001 0.816494
508
+ #
509
+ # Relu
510
+ # 0.00002 0.818402
511
+ # 0.00005 0.819019
512
+ # 0.0001 0.821625
513
+ # 0.0002 0.820633
514
+ # 0.001 0.814315
515
+ #
516
+ # Relu6
517
+ # 0.00002 0.819719
518
+ # 0.00005 0.819871
519
+ # 0.0001 0.819018
520
+ # 0.0002 0.819506
521
+ # 0.001 0.819018
522
+
523
+ parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate for the optimizer. Reasonable values are 1.0 for adadelta or 0.001 for SGD. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_RATES))
524
+ parser.add_argument('--learning_eps', default=None, type=float, help='eps value to use in the optimizer. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_EPS))
525
+ parser.add_argument('--learning_momentum', default=None, type=float, help='Momentum. None uses a default for the given optimizer: {}'.format(DEFAULT_MOMENTUM))
526
+ # weight decay values other than adadelta have not been thoroughly tested.
527
+ # When using adadelta, weight_decay of 0.01 to 0.001 had the best results.
528
+ # 0.1 was very clearly too high. 0.0001 might have been okay.
529
+ # Running a series of 5x experiments on a VI dataset:
530
+ # 0.030: 0.8167018
531
+ # 0.025: 0.81659
532
+ # 0.020: 0.81722
533
+ # 0.015: 0.81721
534
+ # 0.010: 0.81474348
535
+ # 0.005: 0.81503
536
+ parser.add_argument('--learning_weight_decay', default=None, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer')
537
+ parser.add_argument('--learning_rho', default=DEFAULT_LEARNING_RHO, type=float, help='Rho parameter in Adadelta')
538
+ # A few experiments on beta2 didn't show much benefit from changing it
539
+ # On an experiment with training WSJ with default parameters
540
+ # AdaDelta for 200 iterations, then training AdamW for 200 more,
541
+ # 0.999, 0.997, 0.995 all wound up with 0.9588
542
+ # values lower than 0.995 all had a slight dropoff
543
+ parser.add_argument('--learning_beta2', default=0.999, type=float, help='Beta2 argument for AdamW')
544
+ parser.add_argument('--optim', default=None, help='Optimizer type: SGD, AdamW, Adadelta, AdaBelief, Madgrad')
545
+
546
+ parser.add_argument('--stage1_learning_rate', default=None, type=float, help='Learning rate to use in the first stage of --multistage. None means use default: {}'.format(DEFAULT_LEARNING_RATES['adadelta']))
547
+
548
+ parser.add_argument('--learning_rate_warmup', default=0, type=int, help="Number of epochs to ramp up learning rate from 0 to full. Set to 0 to always use the chosen learning rate. Currently not functional, as it didn't do anything")
549
+
550
+ parser.add_argument('--learning_rate_factor', default=0.6, type=float, help='Factor by which to decrease the learning rate when plateaued')
551
+ parser.add_argument('--learning_rate_patience', default=5, type=int, help='Plateau learning rate patience')
552
+ parser.add_argument('--learning_rate_cooldown', default=10, type=int, help='Plateau learning rate cooldown')
553
+ parser.add_argument('--learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum')
554
+ parser.add_argument('--stage1_learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum (stage 1)')
555
+
556
+ parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
557
+ parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')
558
+
559
+ # Large Margin is from Large Margin In Softmax Cross-Entropy Loss
560
+ # it did not help on an Italian VIT test
561
+ # scores went from 0.8252 to 0.8248
562
+ parser.add_argument('--loss', default='cross', help='cross, large_margin, or focal. Focal requires `pip install focal_loss_torch`')
563
+ parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')
564
+
565
+ # turn off dropout for word_dropout, predict_dropout, and lstm_input_dropout
566
+ # this mechanism doesn't actually turn off lstm_layer_dropout (yet)
567
+ # but that is set to a default of 0 anyway
568
+ # this is reusing the idea presented in
569
+ # https://arxiv.org/pdf/2303.01500v2
570
+ # "Dropout Reduces Underfitting"
571
+ # Zhuang Liu, Zhiqiu Xu, Joseph Jin, Zhiqiang Shen, Trevor Darrell
572
+ # Unfortunately, this does not consistently help results
573
+ # Average of 5 models w/ transformer, dev / test
574
+ # id_icon - improves a little
575
+ # baseline 0.8823 0.8904
576
+ # early_dropout 40 0.8835 0.8919
577
+ # ja_alt - worsens a little
578
+ # baseline 0.9308 0.9355
579
+ # early_dropout 40 0.9287 0.9345
580
+ # vi_vlsp23 - worsens a little
581
+ # baseline 0.8262 0.8290
582
+ # early_dropout 40 0.8255 0.8286
583
+ # We keep this as an available option for further experiments, if needed
584
+ parser.add_argument('--early_dropout', default=-1, type=int, help='When to turn off dropout')
585
+ # When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations:
586
+ # 0.0: 0.9085
587
+ # 0.2: 0.9165
588
+ # 0.4: 0.9162
589
+ # 0.5: 0.9123
590
+ # Letting 0.2 and 0.4 run for longer, along with 0.3 as another
591
+ # trial, continued to give extremely similar results over time.
592
+ # No attempt has been made to test the different dropouts separately...
593
+ parser.add_argument('--word_dropout', default=0.2, type=float, help='Dropout on the word embedding')
594
+ parser.add_argument('--predict_dropout', default=0.2, type=float, help='Dropout on the final prediction layer')
595
+ # lstm_dropout has not been fully tested yet
596
+ # one experiment after 200 iterations (after retagging, so scores are lower than some other experiments):
597
+ # 0.0: 0.9093
598
+ # 0.1: 0.9094
599
+ # 0.2: 0.9094
600
+ # 0.3: 0.9076
601
+ # 0.4: 0.9077
602
+ parser.add_argument('--lstm_layer_dropout', default=0.0, type=float, help='Dropout in the LSTM layers')
603
+ # one not very conclusive experiment (not long enough) came up with these numbers after ~200 iterations
604
+ # 0.0 0.9091
605
+ # 0.1 0.9095
606
+ # 0.2 0.9118
607
+ # 0.3 0.9123
608
+ # 0.4 0.9080
609
+ parser.add_argument('--lstm_input_dropout', default=0.2, type=float, help='Dropout on the input to an LSTM')
610
+
611
+ parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()],
612
+ help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme)))
613
+
614
+ parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed')
615
+
616
+ # combining dummy and open node embeddings might be a slight improvement
617
+ # for example, after 550 iterations, one experiment had
618
+ # True: 0.9154
619
+ # False: 0.9150
620
+ # another (with a different structure) had 850 iterations
621
+ # True: 0.9155
622
+ # False: 0.9149
623
+ parser.add_argument('--combined_dummy_embedding', default=True, action='store_true', help="Use the same embedding for dummy nodes and the vectors used when combining constituents")
624
+ parser.add_argument('--no_combined_dummy_embedding', dest='combined_dummy_embedding', action='store_false', help="Don't use the same embedding for dummy nodes and the vectors used when combining constituents")
625
+
626
+ # relu gave at least 1 F1 improvement over tanh in various experiments
627
+ # relu & gelu seem roughly the same, but relu is clearly faster.
628
+ # relu, 496 iterations: 0.9176
629
+ # gelu, 467 iterations: 0.9181
630
+ # after the same clock time on the same hardware. The two had been
631
+ # trading places in terms of accuracy over those ~500 iterations.
632
+ # leaky_relu was not an improvement - a full run on WSJ led to 0.9181 f1 instead of 0.919
633
+ # See constituency/utils.py for more extensive comments on nonlinearity options
634
+ parser.add_argument('--nonlinearity', default='relu', choices=NONLINEARITY.keys(), help='Nonlinearity to use in the model. relu is a noticeable improvement over tanh')
635
+ # In one experiment on an Italian dataset, VIT, we got the following:
636
+ # 0.8254 with relu as the nonlinearity (10 trials)
637
+ # 0.8265 with maxout, k = 2 (15)
638
+ # 0.8253 with maxout, k = 3 (5)
639
+ # The speed in terms of trees/second might be slightly slower with maxout.
640
+ # 51.4 it/s on a Titan Xp with maxout 2 and 51.9 it/s with relu
641
+ # It might also be worth running some experiments with bigger
642
+ # output layers to see if that makes up for the difference in score.
643
+ parser.add_argument('--maxout_k', default=None, type=int, help="Use maxout layers instead of a nonlinearity for the output layers")
644
+
645
+ parser.add_argument('--use_silver_words', default=True, dest='use_silver_words', action='store_true', help="Train word vectors for words which only appear in the silver dataset")
646
+ parser.add_argument('--no_use_silver_words', default=True, dest='use_silver_words', action='store_false', help="Don't train word vectors for words which only appear in the silver dataset")
647
+ parser.add_argument('--rare_word_unknown_frequency', default=0.02, type=float, help='How often to replace a rare word with UNK when training')
648
+ parser.add_argument('--rare_word_threshold', default=0.02, type=float, help='How many words to consider as rare words as a fraction of the dataset')
649
+ parser.add_argument('--tag_unknown_frequency', default=0.001, type=float, help='How often to replace a tag with UNK when training')
650
+
651
+ parser.add_argument('--num_lstm_layers', default=2, type=int, help='How many layers to use in the LSTMs')
652
+ parser.add_argument('--num_tree_lstm_layers', default=None, type=int, help='How many layers to use in the TREE_LSTMs, if used. This also increases the width of the word outputs to match the tree lstm inputs. Default 2 if TREE_LSTM or TREE_LSTM_CX, 1 otherwise')
653
+ parser.add_argument('--num_output_layers', default=3, type=int, help='How many layers to use at the prediction level')
654
+
655
+ parser.add_argument('--sentence_boundary_vectors', default=SentenceBoundary.EVERYTHING, type=lambda x: SentenceBoundary[x.upper()],
656
+ help='Vectors to learn at the start & end of sentences. {}'.format(", ".join(x.name for x in SentenceBoundary)))
657
+ parser.add_argument('--constituency_composition', default=ConstituencyComposition.MAX, type=lambda x: ConstituencyComposition[x.upper()],
658
+ help='How to build a new composition from its children. {}'.format(", ".join(x.name for x in ConstituencyComposition)))
659
+ parser.add_argument('--reduce_heads', default=8, type=int, help='Number of attn heads to use when reducing children into a parent tree (constituency_composition == attn)')
660
+ parser.add_argument('--reduce_position', default=None, type=int, help="Dimension of position vector to use when reducing children. None means 1/4 hidden_size, 0 means don't use (constituency_composition == key | untied_key)")
661
+
662
+ parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn. One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw')
663
+ parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')
664
+ parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
665
+ parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
666
+ parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file')
667
+ parser.add_argument('--load_package', type=str, default=None, help='Download an existing stanza package & use this for tests, finetuning, etc')
668
+
669
+ retagging.add_retag_args(parser)
670
+
671
+ # Partitioned Attention
672
+ parser.add_argument('--pattn_d_model', default=1024, type=int, help='Partitioned attention model dimensionality')
673
+ parser.add_argument('--pattn_morpho_emb_dropout', default=0.2, type=float, help='Dropout rate for morphological features obtained from pretrained model')
674
+ parser.add_argument('--pattn_encoder_max_len', default=512, type=int, help='Max length that can be put into the transformer attention layer')
675
+ parser.add_argument('--pattn_num_heads', default=8, type=int, help='Partitioned attention model number of attention heads')
676
+ parser.add_argument('--pattn_d_kv', default=64, type=int, help='Size of the query and key vector')
677
+ parser.add_argument('--pattn_d_ff', default=2048, type=int, help='Size of the intermediate vectors in the feed-forward sublayer')
678
+ parser.add_argument('--pattn_relu_dropout', default=0.1, type=float, help='ReLU dropout probability in feed-forward sublayer')
679
+ parser.add_argument('--pattn_residual_dropout', default=0.2, type=float, help='Residual dropout probability for all residual connections')
680
+ parser.add_argument('--pattn_attention_dropout', default=0.2, type=float, help='Attention dropout probability')
681
+ parser.add_argument('--pattn_num_layers', default=0, type=int, help='Number of layers for the Partitioned Attention. Currently turned off')
682
+ parser.add_argument('--pattn_bias', default=False, action='store_true', help='Whether or not to learn an additive bias')
683
+ # Results seem relatively similar with learned position embeddings or sin/cos position embeddings
684
+ parser.add_argument('--pattn_timing', default='sin', choices=['learned', 'sin'], help='Use a learned embedding or a sin embedding')
685
+
686
+ # Label Attention
687
+ parser.add_argument('--lattn_d_input_proj', default=None, type=int, help='If set, project the non-positional inputs down to this size before proceeding.')
688
+ parser.add_argument('--lattn_d_kv', default=64, type=int, help='Dimension of the key/query vector')
689
+ parser.add_argument('--lattn_d_proj', default=64, type=int, help='Dimension of the output vector from each label attention head')
690
+ parser.add_argument('--lattn_resdrop', default=True, action='store_true', help='Whether or not to use Residual Dropout')
691
+ parser.add_argument('--lattn_pwff', default=True, action='store_true', help='Whether or not to use a Position-wise Feed-forward Layer')
692
+ parser.add_argument('--lattn_q_as_matrix', default=False, action='store_true', help='Use a matrix for the queries in Label Attention instead of learned query vectors. Default False, i.e. use learned query vectors')
693
+ parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned')
694
+ parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')
695
+ parser.add_argument('--lattn_combine_as_self', default=False, action='store_true', help='Combine the label attention heads as in self-attention instead of concatenating them. Default False, i.e. concatenate')
696
+ # currently unused - always assume 1/2 of pattn
697
+ #parser.add_argument('--lattn_d_positional', default=512, type=int, help='Dimension for the positional embedding')
698
+ parser.add_argument('--lattn_d_l', default=32, type=int, help='Number of labels')
699
+ parser.add_argument('--lattn_attention_dropout', default=0.2, type=float, help='Dropout for attention layer')
700
+ parser.add_argument('--lattn_d_ff', default=2048, type=int, help='Dimension of the Feed-forward layer')
701
+ parser.add_argument('--lattn_relu_dropout', default=0.2, type=float, help='Relu dropout for the label attention')
702
+ parser.add_argument('--lattn_residual_dropout', default=0.2, type=float, help='Residual dropout for the label attention')
703
+ parser.add_argument('--lattn_combined_input', default=True, action='store_true', help='Combine all inputs for the lattn, not just the pattn')
704
+ parser.add_argument('--use_lattn', default=False, action='store_true', help='Use the lattn layers - currently turned off')
705
+ parser.add_argument('--no_use_lattn', dest='use_lattn', action='store_false', help='Use the lattn layers - currently turned off')
706
+ parser.add_argument('--no_lattn_combined_input', dest='lattn_combined_input', action='store_false', help="Use only the pattn output as input to the lattn, instead of combining all inputs")
707
+
708
+ parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training. A very noisy option')
709
+ parser.add_argument('--log_shapes', default=False, action='store_true', help='Log the parameters shapes at the beginning')
710
+ parser.add_argument('--watch_regex', default=None, help='regex to describe which weights and biases to output, if any')
711
+
712
+ parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
713
+ parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
714
+ parser.add_argument('--wandb_norm_regex', default=None, help='Log on wandb the norm of any tensor whose name matches this regex. Might get cluttered?')
715
+
716
+ return parser
717
+
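Many of the boolean options above come in `--foo` / `--no_foo` pairs that write to a single `dest`. A minimal, self-contained sketch of that argparse idiom (the flag name mirrors `--combined_dummy_embedding`; nothing else from the parser above is assumed):

```python
import argparse

parser = argparse.ArgumentParser()
# Both flags share one dest; the default lives on the positive flag,
# and whichever flag appears last on the command line wins.
parser.add_argument('--combined_dummy_embedding', dest='combined_dummy_embedding',
                    default=True, action='store_true',
                    help="Use the same embedding for dummy nodes and combined constituents")
parser.add_argument('--no_combined_dummy_embedding', dest='combined_dummy_embedding',
                    action='store_false',
                    help="Use separate embeddings for dummy nodes and combined constituents")

print(parser.parse_args([]).combined_dummy_embedding)                                 # True
print(parser.parse_args(['--no_combined_dummy_embedding']).combined_dummy_embedding)  # False
```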
718
+ def build_model_filename(args):
719
+ embedding = utils.embedding_name(args)
720
+ maybe_finetune = "finetuned" if args['bert_finetune'] or args['stage1_bert_finetune'] else ""
721
+ transformer_finetune_begin = "%d" % args['bert_finetune_begin_epoch'] if args['bert_finetune_begin_epoch'] is not None else ""
722
+ model_save_file = args['save_name'].format(shorthand=args['shorthand'],
723
+ oracle_level=args['oracle_level'],
724
+ embedding=embedding,
725
+ finetune=maybe_finetune,
726
+ transformer_finetune_begin=transformer_finetune_begin,
727
+ transition_scheme=args['transition_scheme'].name.lower().replace("_", ""),
728
+ tscheme=args['transition_scheme'].short_name,
729
+ trans_layers=args['bert_hidden_layers'],
730
+ seed=args['seed'])
731
+ model_save_file = re.sub("_+", "_", model_save_file)
732
+ logger.info("Expanded save_name: %s", model_save_file)
733
+
734
+ model_dir = os.path.split(model_save_file)[0]
735
+ if model_dir != args['save_dir']:
736
+ model_save_file = os.path.join(args['save_dir'], model_save_file)
737
+ return model_save_file
738
+
739
+ def parse_args(args=None):
740
+ parser = build_argparse()
741
+
742
+ args = parser.parse_args(args=args)
743
+ resolve_peft_args(args, logger, check_bert_finetune=False)
744
+ if not args.lang and args.shorthand and len(args.shorthand.split("_", maxsplit=1)) == 2:
745
+ args.lang = args.shorthand.split("_")[0]
746
+
747
+ if args.stage1_bert_learning_rate is None:
748
+ args.stage1_bert_learning_rate = args.bert_learning_rate
749
+
750
+ if args.optim is None and args.mode == 'train':
751
+ if not args.multistage:
752
+ # this seemed to work the best when not doing multistage
753
+ args.optim = "adadelta"
754
+ if args.use_peft and not args.bert_finetune:
755
+ logger.info("--use_peft set. setting --bert_finetune as well")
756
+ args.bert_finetune = True
757
+ elif args.bert_finetune or args.stage1_bert_finetune:
758
+ logger.info("Multistage training is set, optimizer is not chosen, and bert finetuning is active. Will use AdamW as the second stage optimizer.")
759
+ args.optim = "adamw"
760
+ else:
761
+ # if MADGRAD exists, use it
762
+ # otherwise, adamw
763
+ try:
764
+ import madgrad
765
+ args.optim = "madgrad"
766
+ logger.info("Multistage training is set, optimizer is not chosen, and MADGRAD is available. Will use MADGRAD as the second stage optimizer.")
767
+ except ModuleNotFoundError as e:
768
+ logger.warning("Multistage training is set. Best models are with MADGRAD, but it is not installed. Will use AdamW for the second stage optimizer. Consider installing MADGRAD")
769
+ args.optim = "adamw"
770
+
771
+ if args.mode == 'train':
772
+ if args.learning_rate is None:
773
+ args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim.lower(), None)
774
+ if args.learning_eps is None:
775
+ args.learning_eps = DEFAULT_LEARNING_EPS.get(args.optim.lower(), None)
776
+ if args.learning_momentum is None:
777
+ args.learning_momentum = DEFAULT_MOMENTUM.get(args.optim.lower(), None)
778
+ if args.learning_weight_decay is None:
779
+ args.learning_weight_decay = DEFAULT_WEIGHT_DECAY.get(args.optim.lower(), None)
780
+
781
+ if args.stage1_learning_rate is None:
782
+ args.stage1_learning_rate = DEFAULT_LEARNING_RATES["adadelta"]
783
+ if args.stage1_bert_finetune is None:
784
+ args.stage1_bert_finetune = args.bert_finetune
785
+
786
+ if args.learning_rate_min_lr is None:
787
+ args.learning_rate_min_lr = args.learning_rate * 0.02
788
+ if args.stage1_learning_rate_min_lr is None:
789
+ args.stage1_learning_rate_min_lr = args.stage1_learning_rate * 0.02
790
+
791
+ if args.reduce_position is None:
792
+ args.reduce_position = args.hidden_size // 4
793
+
794
+ if args.num_tree_lstm_layers is None:
795
+ if args.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
796
+ args.num_tree_lstm_layers = 2
797
+ else:
798
+ args.num_tree_lstm_layers = 1
799
+
800
+ if args.wandb_name or args.wandb_norm_regex:
801
+ args.wandb = True
802
+
803
+ args = vars(args)
804
+
805
+ retagging.postprocess_args(args)
806
+ postprocess_predict_output_args(args)
807
+
808
+ model_save_file = build_model_filename(args)
809
+ args['save_name'] = model_save_file
810
+
811
+ if args['save_each_name']:
812
+ model_save_each_file = os.path.join(args['save_dir'], args['save_each_name'])
813
+ model_save_each_file = utils.build_save_each_filename(model_save_each_file)
814
+ args['save_each_name'] = model_save_each_file
815
+ else:
816
+ # in the event that there is a start epoch setting,
817
+ # this will make a reasonable default for the path
818
+ pieces = os.path.splitext(args['save_name'])
819
+ model_save_each_file = pieces[0] + "_%04d" + pieces[1]
820
+ args['save_each_name'] = model_save_each_file
821
+
822
+ if args['checkpoint']:
823
+ args['checkpoint_save_name'] = utils.checkpoint_name(args['save_dir'], model_save_file, args['checkpoint_save_name'])
824
+
825
+ return args
826
+
827
+ def main(args=None):
828
+ """
829
+ Main function for the constituency parser
830
+
831
+ Processes args, calls the appropriate function for the chosen --mode
832
+ """
833
+ args = parse_args(args=args)
834
+
835
+ utils.set_random_seed(args['seed'])
836
+
837
+ logger.info("Running constituency parser in %s mode", args['mode'])
838
+ logger.debug("Using device: %s", args['device'])
839
+
840
+ model_load_file = args['save_name']
841
+ if args['load_name']:
842
+ if os.path.exists(args['load_name']):
843
+ model_load_file = args['load_name']
844
+ else:
845
+ model_load_file = os.path.join(args['save_dir'], args['load_name'])
846
+ elif args['load_package']:
847
+ if args['lang'] is None:
848
+ lang_pieces = args['load_package'].split("_", maxsplit=1)
849
+ try:
850
+ lang = constant.lang_to_langcode(lang_pieces[0])
851
+ except ValueError as e:
852
+ raise ValueError("--lang not specified, and the start of the --load_package name, %s, is not a known language. Please check the values of those parameters" % args['load_package']) from e
853
+ args['lang'] = lang
854
+ args['load_package'] = lang_pieces[1]
855
+ stanza.download(args['lang'], processors="constituency", package={"constituency": args['load_package']})
856
+ model_load_file = os.path.join(DEFAULT_MODEL_DIR, args['lang'], 'constituency', args['load_package'] + ".pt")
857
+ if not os.path.exists(model_load_file):
858
+ raise FileNotFoundError("Expected the downloaded model file for language %s package %s to be in %s, but there is nothing there. Perhaps the package name doesn't exist?" % (args['lang'], args['load_package'], model_load_file))
859
+ else:
860
+ logger.info("Model for language %s package %s is in %s", args['lang'], args['load_package'], model_load_file)
861
+
862
+ # TODO: when loading a saved model, we should default to whatever
863
+ # is in the model file for --retag_method, not the default for the language
864
+ if args['mode'] == 'train':
865
+ if tlogger.level == logging.NOTSET:
866
+ tlogger.setLevel(logging.DEBUG)
867
+ tlogger.debug("Set trainer logging level to DEBUG")
868
+
869
+ retag_pipeline = retagging.build_retag_pipeline(args)
870
+
871
+ if args['mode'] == 'train':
872
+ parser_training.train(args, model_load_file, retag_pipeline)
873
+ elif args['mode'] == 'predict':
874
+ parser_training.evaluate(args, model_load_file, retag_pipeline)
875
+ elif args['mode'] == 'parse_text':
876
+ load_model_parse_text(args, model_load_file, retag_pipeline)
877
+ elif args['mode'] == 'remove_optimizer':
878
+ parser_training.remove_optimizer(args, args['save_name'], model_load_file)
879
+
880
+ if __name__ == '__main__':
881
+ main()
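The enum-valued options in the parser above (`--transition_scheme`, `--sentence_boundary_vectors`, `--constituency_composition`) all rely on the same case-insensitive lookup idiom. A self-contained sketch with a stand-in enum (the real `TransitionScheme` lives in the constituency package and is not reproduced here):

```python
import argparse
from enum import Enum

# Stand-in for TransitionScheme; only the argparse lookup idiom is illustrated
class Scheme(Enum):
    TOP_DOWN = 1
    IN_ORDER = 2

parser = argparse.ArgumentParser()
# type= converts the command-line string, so "top_down" -> Scheme.TOP_DOWN;
# a non-string default bypasses the converter entirely
parser.add_argument('--transition_scheme', default=Scheme.IN_ORDER,
                    type=lambda x: Scheme[x.upper()],
                    help='Transition scheme to use. {}'.format(", ".join(x.name for x in Scheme)))

args = parser.parse_args(['--transition_scheme', 'top_down'])
print(args.transition_scheme.name)  # TOP_DOWN
```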
stanza/stanza/models/lemmatizer.py ADDED
@@ -0,0 +1,313 @@
1
+ """
2
+ Entry point for training and evaluating a lemmatizer.
3
+
4
+ This lemmatizer combines a neural sequence-to-sequence architecture with an `edit` classifier
5
+ and two dictionaries to produce robust lemmas from word forms.
6
+ For details, please refer to the paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
7
+ """
8
+
9
+ import logging
10
+ import sys
11
+ import os
12
+ import shutil
13
+ import time
14
+ from datetime import datetime
15
+ import argparse
16
+ import numpy as np
17
+ import random
18
+ import torch
19
+ from torch import nn, optim
20
+
21
+ from stanza.models.lemma.data import DataLoader
22
+ from stanza.models.lemma.vocab import Vocab
23
+ from stanza.models.lemma.trainer import Trainer
24
+ from stanza.models.lemma import scorer, edit
25
+ from stanza.models.common import utils
26
+ import stanza.models.common.seq2seq_constant as constant
27
+ from stanza.models.common.doc import *
28
+ from stanza.utils.conll import CoNLL
29
+ from stanza.models import _training_logging
30
+
31
+ logger = logging.getLogger('stanza')
32
+
33
+ def build_argparse():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')
36
+ parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
37
+ parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
38
+ parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
39
+ parser.add_argument('--gold_file', type=str, default=None, help='Gold CoNLL-U file for evaluation.')
40
+
41
+ parser.add_argument('--mode', default='train', choices=['train', 'predict'])
42
+ parser.add_argument('--shorthand', type=str, help='Shorthand for the dataset to use. lang_dataset')
43
+
44
+ parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. By default use ensemble.')
45
+ parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based lemmatizer.')
46
+
47
+ parser.add_argument('--hidden_dim', type=int, default=200)
48
+ parser.add_argument('--emb_dim', type=int, default=50)
49
+ parser.add_argument('--num_layers', type=int, default=1)
50
+ parser.add_argument('--emb_dropout', type=float, default=0.5)
51
+ parser.add_argument('--dropout', type=float, default=0.5)
52
+ parser.add_argument('--max_dec_len', type=int, default=50)
53
+ parser.add_argument('--beam_size', type=int, default=1)
54
+
55
+ parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
56
+ parser.add_argument('--pos_dim', type=int, default=50)
57
+ parser.add_argument('--pos_dropout', type=float, default=0.5)
58
+    parser.add_argument('--no_edit', dest='edit', action='store_false', help='Do not use edit classifier in lemmatization. By default use an edit classifier.')
+    parser.add_argument('--num_edit', type=int, default=len(edit.EDIT_TO_ID))
+    parser.add_argument('--alpha', type=float, default=1.0)
+    parser.add_argument('--no_pos', dest='pos', action='store_false', help='Do not use UPOS in lemmatization. By default UPOS is used.')
+    parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in lemmatization. By default copy mechanism is used to improve generalization.')
+
+    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
+    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
+    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
+    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
+
+    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
+    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
+    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
+    parser.add_argument('--lr_decay', type=float, default=0.9)
+    parser.add_argument('--decay_epoch', type=int, default=30, help="Decay the lr starting from this epoch.")
+    parser.add_argument('--num_epoch', type=int, default=60)
+    parser.add_argument('--batch_size', type=int, default=50)
+    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
+    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
+    parser.add_argument('--save_dir', type=str, default='saved_models/lemma', help='Root dir for saving models.')
+    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_lemmatizer.pt", help="File name to save the model")
+
+    parser.add_argument('--caseless', default=False, action='store_true', help='Lowercase everything first before processing. This will happen automatically if 100%% of the data is caseless')
+
+    parser.add_argument('--seed', type=int, default=1234)
+    utils.add_device_args(parser)
+
+    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
+    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
+    return parser
+
+def parse_args(args=None):
+    parser = build_argparse()
+    args = parser.parse_args(args=args)
+
+    if args.wandb_name:
+        args.wandb = True
+
+    args = vars(args)
+    # when building the vocab, we keep track of the original language name
+    lang = args['shorthand'].split("_")[0] if args['shorthand'] else ""
+    args['lang'] = lang
+    return args
+
+def main(args=None):
+    args = parse_args(args=args)
+
+    utils.set_random_seed(args['seed'])
+
+    logger.info("Running lemmatizer in {} mode".format(args['mode']))
+
+    if args['mode'] == 'train':
+        train(args)
+    else:
+        evaluate(args)
+
+def all_lowercase(doc):
+    for sentence in doc.sentences:
+        for word in sentence.words:
+            if word.text.lower() != word.text:
+                return False
+    return True
+
+def build_model_filename(args):
+    embedding = "nocharlm"
+    if args['charlm'] and args['charlm_forward_file']:
+        embedding = "charlm"
+    model_file = args['save_name'].format(shorthand=args['shorthand'],
+                                          embedding=embedding)
+    model_dir = os.path.split(model_file)[0]
+    if not model_dir.startswith(args['save_dir']):
+        model_file = os.path.join(args['save_dir'], model_file)
+    return model_file
+
+def train(args):
+    # load data
+    logger.info("[Loading data with batch size {}...]".format(args['batch_size']))
+    train_doc = CoNLL.conll2doc(input_file=args['train_file'])
+    train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)
+    vocab = train_batch.vocab
+    args['vocab_size'] = vocab['char'].size
+    args['pos_vocab_size'] = vocab['pos'].size
+    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
+    dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
+
+    utils.ensure_dir(args['save_dir'])
+    model_file = build_model_filename(args)
+    logger.info("Using full savename: %s", model_file)
+
+    # pred and gold path
+    system_pred_file = args['output_file']
+    gold_file = args['gold_file']
+
+    utils.print_config(args)
+
+    # skip training if the language does not have training or dev data
+    if len(train_batch) == 0 or len(dev_batch) == 0:
+        logger.warning("[Skip training because no training data available...]")
+        return
+
+    if not args['caseless'] and all_lowercase(train_doc):
+        logger.info("Building a caseless model, as all of the training data is caseless")
+        args['caseless'] = True
+
+    # start training
+    # train a dictionary-based lemmatizer
+    logger.info("Building lemmatizer in %s", model_file)
+    trainer = Trainer(args=args, vocab=vocab, device=args['device'])
+    logger.info("[Training dictionary-based lemmatizer...]")
+    trainer.train_dict(train_batch.raw_data())
+    logger.info("Evaluating on dev set...")
+    dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS]))
+    dev_batch.doc.set([LEMMA], dev_preds)
+    CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
+    _, _, dev_f = scorer.score(system_pred_file, gold_file)
+    logger.info("Dev F1 = {:.2f}".format(dev_f * 100))
+
+    if args.get('dict_only', False):
+        # save dictionaries
+        trainer.save(model_file)
+    else:
+        if args['wandb']:
+            import wandb
+            wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_lemmatizer" % args['shorthand']
+            wandb.init(name=wandb_name, config=args)
+            wandb.run.define_metric('train_loss', summary='min')
+            wandb.run.define_metric('dev_score', summary='max')
+
+        # train a seq2seq model
+        logger.info("[Training seq2seq-based lemmatizer...]")
+        global_step = 0
+        max_steps = len(train_batch) * args['num_epoch']
+        dev_score_history = []
+        best_dev_preds = []
+        current_lr = args['lr']
+        global_start_time = time.time()
+        format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
+
+        # start training
+        for epoch in range(1, args['num_epoch']+1):
+            train_loss = 0
+            for i, batch in enumerate(train_batch):
+                start_time = time.time()
+                global_step += 1
+                loss = trainer.update(batch, eval=False) # update step
+                train_loss += loss
+                if global_step % args['log_step'] == 0:
+                    duration = time.time() - start_time
+                    logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
+                                                  max_steps, epoch, args['num_epoch'], loss, duration, current_lr))
+
+            # eval on dev
+            logger.info("Evaluating on dev set...")
+            dev_preds = []
+            dev_edits = []
+            for i, batch in enumerate(dev_batch):
+                preds, edits = trainer.predict(batch, args['beam_size'])
+                dev_preds += preds
+                if edits is not None:
+                    dev_edits += edits
+            dev_preds = trainer.postprocess(dev_batch.doc.get([TEXT]), dev_preds, edits=dev_edits)
+
+            # try ensembling with dict if necessary
+            if args.get('ensemble_dict', False):
+                logger.info("[Ensembling dict with seq2seq model...]")
+                dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds)
+            dev_batch.doc.set([LEMMA], dev_preds)
+            CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
+            _, _, dev_score = scorer.score(system_pred_file, gold_file)
+
+            train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
+            logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))
+
+            if args['wandb']:
+                wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
+
+            # save best model
+            if epoch == 1 or dev_score > max(dev_score_history):
+                trainer.save(model_file)
+                logger.info("new best model saved.")
+                best_dev_preds = dev_preds
+
+            # lr schedule
+            if epoch > args['decay_epoch'] and dev_score <= dev_score_history[-1] and \
+                    args['optim'] in ['sgd', 'adagrad']:
+                current_lr *= args['lr_decay']
+                trainer.update_lr(current_lr)
+
+            dev_score_history += [dev_score]
+            logger.info("")
+
+        logger.info("Training ended with {} epochs.".format(epoch))
+
+        if args['wandb']:
+            wandb.finish()
+
+        best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1
+        logger.info("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch))
+
+def evaluate(args):
+    # file paths
+    system_pred_file = args['output_file']
+    gold_file = args['gold_file']
+    model_file = build_model_filename(args)
+
+    # load model
+    trainer = Trainer(model_file=model_file, device=args['device'], args=args)
+    loaded_args, vocab = trainer.args, trainer.vocab
+
+    for k in args:
+        if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand']:
+            loaded_args[k] = args[k]
+
+    # load data
+    logger.info("Loading data with batch size {}...".format(args['batch_size']))
+    doc = CoNLL.conll2doc(input_file=args['eval_file'])
+    batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)
+
+    # skip eval if dev data does not exist
+    if len(batch) == 0:
+        logger.warning("Skip evaluation because no dev data is available...\nLemma score:\n{} ".format(args['shorthand']))
+        return
+
+    dict_preds = trainer.predict_dict(batch.doc.get([TEXT, UPOS]))
+
+    if loaded_args.get('dict_only', False):
+        preds = dict_preds
+    else:
+        logger.info("Running the seq2seq model...")
+        preds = []
+        edits = []
+        for i, b in enumerate(batch):
+            ps, es = trainer.predict(b, args['beam_size'])
+            preds += ps
+            if es is not None:
+                edits += es
+        preds = trainer.postprocess(batch.doc.get([TEXT]), preds, edits=edits)
+
+        if loaded_args.get('ensemble_dict', False):
+            logger.info("[Ensembling dict with seq2seq lemmatizer...]")
+            preds = trainer.ensemble(batch.doc.get([TEXT, UPOS]), preds)
+
+    if trainer.has_contextual_lemmatizers():
+        preds = trainer.update_contextual_preds(batch.doc, preds)
+
+    # write to file and score
+    batch.doc.set([LEMMA], preds)
+    CoNLL.write_doc2conll(batch.doc, system_pred_file)
+    if gold_file is not None:
+        _, _, score = scorer.score(system_pred_file, gold_file)
+
+        logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100))
+
+if __name__ == '__main__':
+    main()
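The learning-rate schedule in `train` above only kicks in for `sgd` and `adagrad`: once the epoch passes `--decay_epoch` and the dev score fails to beat the previous epoch's, the lr is multiplied by `--lr_decay`. A standalone replay of that arithmetic (the function name and the sample scores are illustrative, not from the source):

```python
def decayed_lr(lr, lr_decay, decay_epoch, optim, dev_scores):
    """Replay the lemmatizer's lr schedule over a list of per-epoch dev scores."""
    current_lr = lr
    for epoch in range(2, len(dev_scores) + 1):
        # compare this epoch's score to the previous epoch's, as train() does
        if epoch > decay_epoch and dev_scores[epoch - 1] <= dev_scores[epoch - 2] \
                and optim in ('sgd', 'adagrad'):
            current_lr *= lr_decay
    return current_lr

# two non-improving epochs after decay_epoch=2 -> two decays: 1e-3 * 0.9 * 0.9
print(decayed_lr(1e-3, 0.9, 2, 'sgd', [0.5, 0.6, 0.6, 0.55]))
```

With `adam` (the default `--optim`) the same call returns the initial lr unchanged, since the decay branch is never taken.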
stanza/stanza/pipeline/_constants.py ADDED
@@ -0,0 +1,13 @@
+""" Module defining constants """
+
+# string constants for processor names
+LANGID = 'langid'
+TOKENIZE = 'tokenize'
+MWT = 'mwt'
+POS = 'pos'
+LEMMA = 'lemma'
+DEPPARSE = 'depparse'
+NER = 'ner'
+SENTIMENT = 'sentiment'
+CONSTITUENCY = 'constituency'
+COREF = 'coref'
stanza/stanza/pipeline/external/spacy.py ADDED
@@ -0,0 +1,74 @@
+"""
+Processors related to spaCy in the pipeline.
+"""
+
+from stanza.models.common import doc
+from stanza.pipeline._constants import TOKENIZE
+from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
+
+def check_spacy():
+    """
+    Import necessary components from spaCy to perform tokenization.
+    """
+    try:
+        import spacy
+    except ImportError:
+        raise ImportError(
+            "spaCy is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions."
+        )
+    return True
+
+@register_processor_variant(TOKENIZE, 'spacy')
+class SpacyTokenizer(ProcessorVariant):
+    def __init__(self, config):
+        """ Construct a spaCy-based tokenizer by loading the spaCy pipeline.
+        """
+        if config['lang'] != 'en':
+            raise Exception("spaCy tokenizer is currently only allowed in English pipeline.")
+
+        try:
+            import spacy
+            from spacy.lang.en import English
+        except ImportError:
+            raise ImportError(
+                "spaCy 2.0+ is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions."
+            )
+
+        # Create a Tokenizer with the default settings for English
+        # including punctuation rules and exceptions
+        self.nlp = English()
+        # by default spacy uses dependency parser to do ssplit
+        # we need to add a sentencizer for fast rule-based ssplit
+        if spacy.__version__.startswith("2."):
+            self.nlp.add_pipe(self.nlp.create_pipe("sentencizer"))
+        else:
+            self.nlp.add_pipe("sentencizer")
+        self.no_ssplit = config.get('no_ssplit', False)
+
+    def process(self, document):
+        """ Tokenize a document with the spaCy tokenizer and wrap the results into a Doc object.
+        """
+        if isinstance(document, doc.Document):
+            text = document.text
+        else:
+            text = document
+        if not isinstance(text, str):
+            raise Exception("Must supply a string or Stanza Document object to the spaCy tokenizer.")
+        spacy_doc = self.nlp(text)
+
+        sentences = []
+        for sent in spacy_doc.sents:
+            tokens = []
+            for tok in sent:
+                token_entry = {
+                    doc.TEXT: tok.text,
+                    doc.MISC: f"{doc.START_CHAR}={tok.idx}|{doc.END_CHAR}={tok.idx+len(tok.text)}"
+                }
+                tokens.append(token_entry)
+            sentences.append(tokens)
+
+        # if no_ssplit is set, flatten all the sentences into one sentence
+        if self.no_ssplit:
+            sentences = [[t for s in sentences for t in s]]
+
+        return doc.Document(sentences, text)
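The MISC string built in `process` records each token's character offsets in the `key=value|key=value` form stanza uses, with `doc.START_CHAR` and `doc.END_CHAR` being the strings `start_char` and `end_char`. A minimal standalone sketch of that bookkeeping, independent of spaCy and stanza (the helper name is illustrative):

```python
# Standalone sketch of the token-offset bookkeeping done in SpacyTokenizer.process:
# each token records its text plus start/end character offsets in a MISC string.
START_CHAR = 'start_char'   # mirrors stanza.models.common.doc.START_CHAR
END_CHAR = 'end_char'       # mirrors stanza.models.common.doc.END_CHAR

def token_entry(text, idx):
    """Build the same {text, misc} dict as the spaCy variant, given a token and its offset."""
    return {'text': text,
            'misc': f"{START_CHAR}={idx}|{END_CHAR}={idx + len(text)}"}

print(token_entry("antennae", 12)['misc'])  # -> start_char=12|end_char=20
```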
stanza/stanza/pipeline/ner_processor.py ADDED
@@ -0,0 +1,143 @@
+"""
+Processor for performing named entity tagging.
+"""
+
+import torch
+
+import logging
+
+from stanza.models.common import doc
+from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError
+from stanza.models.common.utils import unsort
+from stanza.models.ner.data import DataLoader
+from stanza.models.ner.trainer import Trainer
+from stanza.models.ner.utils import merge_tags
+from stanza.pipeline._constants import *
+from stanza.pipeline.processor import UDProcessor, register_processor
+
+logger = logging.getLogger('stanza')
+
+@register_processor(name=NER)
+class NERProcessor(UDProcessor):
+
+    # set of processor requirements this processor fulfills
+    PROVIDES_DEFAULT = set([NER])
+    # set of processor requirements for this processor
+    REQUIRES_DEFAULT = set([TOKENIZE])
+
+    def _get_dependencies(self, config, dep_name):
+        dependencies = config.get(dep_name, None)
+        if dependencies is not None:
+            dependencies = dependencies.split(";")
+            dependencies = [x if x else None for x in dependencies]
+        else:
+            dependencies = [x.get(dep_name) for x in config.get('dependencies', [])]
+        return dependencies
+
+    def _set_up_model(self, config, pipeline, device):
+        # set up trainer
+        model_paths = config.get('model_path')
+        if isinstance(model_paths, str):
+            model_paths = model_paths.split(";")
+
+        charlm_forward_files = self._get_dependencies(config, 'forward_charlm_path')
+        charlm_backward_files = self._get_dependencies(config, 'backward_charlm_path')
+        pretrain_files = self._get_dependencies(config, 'pretrain_path')
+
+        # allow predict_tagset to be specified as an int
+        # (which only applies to the first model)
+        # or as a string ";" separated list of ints
+        self._predict_tagset = {}
+        predict_tagset = config.get('predict_tagset', None)
+        if predict_tagset:
+            if isinstance(predict_tagset, int):
+                self._predict_tagset[0] = predict_tagset
+            else:
+                predict_tagset = predict_tagset.split(";")
+                for piece_idx, piece in enumerate(predict_tagset):
+                    if piece:
+                        self._predict_tagset[piece_idx] = int(piece)
+
+        self.trainers = []
+        for (model_path, pretrain_path, charlm_forward, charlm_backward) in zip(model_paths, pretrain_files, charlm_forward_files, charlm_backward_files):
+            logger.debug("Loading %s with pretrain %s, forward charlm %s, backward charlm %s", model_path, pretrain_path, charlm_forward, charlm_backward)
+            pretrain = pipeline.foundation_cache.load_pretrain(pretrain_path) if pretrain_path else None
+            args = {'charlm_forward_file': charlm_forward,
+                    'charlm_backward_file': charlm_backward}
+
+            predict_tagset = self._predict_tagset.get(len(self.trainers), None)
+            if predict_tagset is not None:
+                args['predict_tagset'] = predict_tagset
+
+            try:
+                trainer = Trainer(args=args, model_file=model_path, pretrain=pretrain, device=device, foundation_cache=pipeline.foundation_cache)
+            except ForwardCharlmNotFoundError as e:
+                raise ForwardCharlmNotFoundError("Could not find the forward charlm %s. Please specify the correct path with ner_forward_charlm_path" % e.filename, e.filename) from None
+            except BackwardCharlmNotFoundError as e:
+                raise BackwardCharlmNotFoundError("Could not find the backward charlm %s. Please specify the correct path with ner_backward_charlm_path" % e.filename, e.filename) from None
+            self.trainers.append(trainer)
+
+        self._trainer = self.trainers[0]
+        self.model_paths = model_paths
+
+    def _set_up_final_config(self, config):
+        """ Finalize the configurations for this processor, based off of values from a UD model. """
+        # set configurations from loaded model
+        if len(self.trainers) == 0:
+            raise RuntimeError("Somehow there are no models loaded!")
+        self._vocab = self.trainers[0].vocab
+        self.configs = []
+        for trainer in self.trainers:
+            loaded_args = trainer.args
+            # filter out unneeded args from model
+            loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
+            loaded_args.update(config)
+            self.configs.append(loaded_args)
+        self._config = self.configs[0]
+
+    def __str__(self):
+        return "NERProcessor(%s)" % ";".join(self.model_paths)
+
+    def mark_inactive(self):
+        """ Drop memory intensive resources if keeping this processor around for reasons other than running it. """
+        super().mark_inactive()
+        self.trainers = None
+
+    def process(self, document):
+        with torch.no_grad():
+            all_preds = []
+            for trainer, config in zip(self.trainers, self.configs):
+                # set up an eval-only data loader and skip tag preprocessing
+                batch = DataLoader(document, config['batch_size'], config, vocab=trainer.vocab, evaluation=True, preprocess_tags=False, bert_tokenizer=trainer.model.bert_tokenizer)
+                preds = []
+                for i, b in enumerate(batch):
+                    preds += trainer.predict(b)
+                all_preds.append(preds)
+            # for each sentence, gather a list of predictions
+            # merge those predictions into a single list
+            # earlier models will have precedence
+            preds = [merge_tags(*x) for x in zip(*all_preds)]
+            batch.doc.set([doc.NER], [y for x in preds for y in x], to_token=True)
+            batch.doc.set([doc.MULTI_NER], [tuple(y) for x in zip(*all_preds) for y in zip(*x)], to_token=True)
+            # collect entities into document attribute
+            total = len(batch.doc.build_ents())
+            logger.debug(f'{total} entities found in document.')
+            return batch.doc
+
+    def bulk_process(self, docs):
+        """
+        NER processor has a collation step after running inference
+        """
+        docs = super().bulk_process(docs)
+        for doc in docs:
+            doc.build_ents()
+        return docs
+
+    def get_known_tags(self, model_idx=0):
+        """
+        Return the tags known by this model
+
+        Removes the S-, B-, etc, and does not include O
+        Specify model_idx if the processor has more than one model
+        """
+        return self.trainers[model_idx].get_known_tags()
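The comment in `process` notes that when several NER models run, "earlier models will have precedence" in the merge. The real `merge_tags` is span-aware, but the precedence idea can be illustrated with a simple token-wise sketch (this helper is hypothetical, not stanza's implementation):

```python
# Illustrative only -- not stanza's merge_tags, which merges at the entity-span level.
# Shows the precedence idea: for each token, keep the first model's
# prediction unless it is 'O' and a later model found an entity.
def merge_token_tags(*tag_lists):
    merged = []
    for token_tags in zip(*tag_lists):
        tag = 'O'
        for candidate in token_tags:
            if candidate != 'O':
                tag = candidate
                break
        merged.append(tag)
    return merged

print(merge_token_tags(['B-PER', 'O', 'O'], ['O', 'O', 'B-ORG']))
```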
stanza/stanza/resources/print_charlm_depparse.py ADDED
@@ -0,0 +1,22 @@
+"""
+A small utility script to output which depparse models use charlm
+
+(It should skip en_genia, en_craft, but currently doesn't)
+
+Not frequently useful, but seems like the kind of thing that might get used a couple times
+"""
+
+from stanza.resources.common import load_resources_json
+from stanza.resources.default_packages import default_charlms, depparse_charlms
+
+def list_depparse():
+    charlm_langs = list(default_charlms.keys())
+    resources = load_resources_json()
+
+    models = ["%s_%s" % (lang, model) for lang in charlm_langs for model in resources[lang].get("depparse", {})
+              if lang not in depparse_charlms or model not in depparse_charlms[lang] or depparse_charlms[lang][model] is not None]
+    return models
+
+if __name__ == "__main__":
+    models = list_depparse()
+    print(" ".join(models))
stanza/stanza/server/dependency_converter.py ADDED
@@ -0,0 +1,101 @@
+"""
+A converter from constituency trees to dependency trees using CoreNLP's UniversalEnglish converter.
+
+ONLY works on English.
+"""
+
+import stanza
+from stanza.protobuf import DependencyConverterRequest, DependencyConverterResponse
+from stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext
+
+CONVERTER_JAVA = "edu.stanford.nlp.trees.ProcessDependencyConverterRequest"
+
+def send_converter_request(request, classpath=None):
+    return send_request(request, DependencyConverterResponse, CONVERTER_JAVA, classpath=classpath)
+
+def build_request(doc):
+    """
+    Request format is simple: one tree per sentence in the document
+    """
+    request = DependencyConverterRequest()
+    for sentence in doc.sentences:
+        request.trees.append(build_tree(sentence.constituency, None))
+    return request
+
+def process_doc(doc, classpath=None):
+    """
+    Convert the constituency trees in the document,
+    then attach the resulting dependencies to the sentences
+    """
+    request = build_request(doc)
+    response = send_converter_request(request, classpath=classpath)
+    attach_dependencies(doc, response)
+
+def attach_dependencies(doc, response):
+    if len(doc.sentences) != len(response.conversions):
+        raise ValueError("Sent %d sentences but got back %d conversions" % (len(doc.sentences), len(response.conversions)))
+    for sent_idx, (sentence, conversion) in enumerate(zip(doc.sentences, response.conversions)):
+        graph = conversion.graph
+
+        # The deterministic conversion should have an equal number of words and one fewer edge
+        # ... the root is represented by a word with no parent
+        if len(sentence.words) != len(graph.node):
+            raise ValueError("Sentence %d of the conversion should have %d words but got back %d nodes in the graph" % (sent_idx, len(sentence.words), len(graph.node)))
+        if len(sentence.words) != len(graph.edge) + 1:
+            raise ValueError("Sentence %d of the conversion should have %d edges (one per word, plus the root) but got back %d edges in the graph" % (sent_idx, len(sentence.words) - 1, len(graph.edge)))
+
+        expected_nodes = set(range(1, len(sentence.words) + 1))
+        targets = set()
+        for edge in graph.edge:
+            if edge.target in targets:
+                raise ValueError("Found two parents of %d in sentence %d" % (edge.target, sent_idx))
+            targets.add(edge.target)
+            # -1 since the words are 0 indexed in the sentence,
+            # but we count dependencies from 1
+            sentence.words[edge.target-1].head = edge.source
+            sentence.words[edge.target-1].deprel = edge.dep
+        roots = expected_nodes - targets
+        assert len(roots) == 1
+        for root in roots:
+            sentence.words[root-1].head = 0
+            sentence.words[root-1].deprel = "root"
+        sentence.build_dependencies()
+
+
+class DependencyConverter(JavaProtobufContext):
+    """
+    Context window for the dependency converter
+
+    This is a context window which keeps a process open. Should allow
+    for multiple requests without launching new java processes each time.
+    """
+    def __init__(self, classpath=None):
+        super(DependencyConverter, self).__init__(classpath, DependencyConverterResponse, CONVERTER_JAVA)
+
+    def process(self, doc):
+        """
+        Converts a constituency tree to dependency trees for each of the sentences in the document
+        """
+        request = build_request(doc)
+        response = self.process_request(request)
+        attach_dependencies(doc, response)
+        return doc
+
+def main():
+    nlp = stanza.Pipeline('en',
+                          processors='tokenize,pos,constituency')
+
+    doc = nlp('I like blue antennae.')
+    print("{:C}".format(doc))
+    process_doc(doc, classpath="$CLASSPATH")
+    print("{:C}".format(doc))
+
+    doc = nlp('And I cannot lie.')
+    print("{:C}".format(doc))
+    with DependencyConverter(classpath="$CLASSPATH") as converter:
+        converter.process(doc)
+    print("{:C}".format(doc))
+
+
+if __name__ == '__main__':
+    main()
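The root-recovery step in `attach_dependencies` relies on a simple invariant: with one edge per non-root word, exactly one 1-indexed word id never appears as an edge target, and that word becomes the root. A self-contained sketch of that logic (the function name and the example edges are illustrative):

```python
# Sketch of the root-recovery step in attach_dependencies: every word index that
# never appears as an edge target has no parent, and the converter expects
# exactly one such index, which becomes the root (head 0).
def find_root(num_words, edges):
    """edges: list of (source, target) pairs with 1-indexed word ids."""
    expected_nodes = set(range(1, num_words + 1))
    targets = {target for _, target in edges}
    roots = expected_nodes - targets
    assert len(roots) == 1, "conversion must produce a single root"
    return roots.pop()

# "I like blue antennae": 'like' (2) heads 'I' and 'antennae'; 'antennae' (4) heads 'blue'
print(find_root(4, [(2, 1), (2, 4), (4, 3)]))  # -> 2
```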
stanza/stanza/tests/classifiers/test_constituency_classifier.py ADDED
@@ -0,0 +1,128 @@
+import os
+
+import pytest
+
+import stanza
+import stanza.models.classifier as classifier
+import stanza.models.classifiers.data as data
+from stanza.models.classifiers.trainer import Trainer
+from stanza.tests import TEST_MODELS_DIR
+from stanza.tests.classifiers.test_classifier import fake_embeddings
+from stanza.tests.classifiers.test_data import train_file_with_trees, dev_file_with_trees
+from stanza.models.common import utils
+from stanza.tests.constituency.test_trainer import build_trainer, TREEBANK
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+class TestConstituencyClassifier:
+    @pytest.fixture(scope="class")
+    def constituency_model(self, fake_embeddings, tmp_path_factory):
+        args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
+        trainer = build_trainer(str(fake_embeddings), *args, treebank=TREEBANK)
+
+        trainer_pt = str(tmp_path_factory.mktemp("constituency") / "constituency.pt")
+        trainer.save(trainer_pt, save_optimizer=False)
+        return trainer_pt
+
+    def build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):
+        """
+        Build a Constituency Classifier model to be used by one of the later tests
+        """
+        save_dir = str(tmp_path / "classifier")
+        save_name = "model.pt"
+        args = ["--save_dir", save_dir,
+                "--save_name", save_name,
+                "--model_type", "constituency",
+                "--constituency_model", constituency_model,
+                "--wordvec_pretrain_file", str(fake_embeddings),
+                "--fc_shapes", "20,10",
+                "--train_file", str(train_file_with_trees),
+                "--dev_file", str(dev_file_with_trees),
+                "--max_epochs", "2",
+                "--batch_size", "60"]
+        if extra_args is not None:
+            args = args + extra_args
+        args = classifier.parse_args(args)
+        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
+        trainer = Trainer.build_new_model(args, train_set)
+        return trainer, train_set, args
+
+    def run_training(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):
+        """
+        Iterate a couple times over a model
+        """
+        trainer, train_set, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args)
+        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)
+        labels = data.dataset_labels(train_set)
+
+        save_filename = os.path.join(args.save_dir, args.save_name)
+        checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name)
+        classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels)
+        return trainer, train_set, args
+
+    def test_build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
+        """
+        Test that building a basic constituency-based model works
+        """
+        self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
+
+    def test_save_load(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
+        """
+        Test that a constituency model can save & load
+        """
+        trainer, _, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
+
+        save_filename = os.path.join(args.save_dir, args.save_name)
+        trainer.save(save_filename)
+
+        args.load_name = args.save_name
+        trainer = Trainer.load(args.load_name, args)
+
+    def test_train_basic(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
+        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
+
+    def test_train_pipeline(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
+        """
+        Test that writing out a temp model, then loading it in the pipeline is a thing that works
+        """
+        trainer, _, args = self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
+        save_filename = os.path.join(args.save_dir, args.save_name)
+        assert os.path.exists(save_filename)
+        assert os.path.exists(args.constituency_model)
+
+        pipeline_args = {"lang": "en",
+                         "download_method": None,
+                         "model_dir": TEST_MODELS_DIR,
+                         "processors": "tokenize,pos,constituency,sentiment",
+                         "tokenize_pretokenized": True,
+                         "constituency_model_path": args.constituency_model,
+                         "constituency_pretrain_path": args.wordvec_pretrain_file,
+                         "constituency_backward_charlm_path": None,
+                         "constituency_forward_charlm_path": None,
+                         "sentiment_model_path": save_filename,
+                         "sentiment_pretrain_path": args.wordvec_pretrain_file,
+                         "sentiment_backward_charlm_path": None,
+                         "sentiment_forward_charlm_path": None}
+        pipeline = stanza.Pipeline(**pipeline_args)
+        doc = pipeline("This is a test")
+        # since the model is random, we have no expectations for what the result actually is
+        assert doc.sentences[0].sentiment is not None
+
+
+    def test_train_all_words(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
+        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_all_words'])
+
+        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_all_words'])
+
+    def test_train_top_layer(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
+        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_top_layer'])
+
+        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_top_layer'])
+
+    def test_train_attn(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
+        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_node_attn', '--no_constituency_all_words'])
+
+        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_node_attn', '--constituency_all_words'])
+
+        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_node_attn'])
stanza/stanza/tests/common/__init__.py ADDED
File without changes
stanza/stanza/tests/common/test_chuliu_edmonds.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ Test some use cases of the chuliu_edmonds algorithm
+
+ (currently just the tarjan implementation)
+ """
+
+ import numpy as np
+ import pytest
+
+ from stanza.models.common.chuliu_edmonds import tarjan
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ def test_tarjan_basic():
+     simple = np.array([0, 4, 4, 4, 0])
+     result = tarjan(simple)
+     assert result == []
+
+     simple = np.array([0, 2, 0, 4, 2, 2])
+     result = tarjan(simple)
+     assert result == []
+
+ def test_tarjan_cycle():
+     cycle_graph = np.array([0, 3, 1, 2])
+     result = tarjan(cycle_graph)
+     expected = np.array([False, True, True, True])
+     assert len(result) == 1
+     np.testing.assert_array_equal(result[0], expected)
+
+     cycle_graph = np.array([0, 3, 1, 2, 5, 6, 4])
+     result = tarjan(cycle_graph)
+     assert len(result) == 2
+     expected = [np.array([False, True, True, True, False, False, False]),
+                 np.array([False, False, False, False, True, True, True])]
+     for r, e in zip(result, expected):
+         np.testing.assert_array_equal(r, e)
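The tests above pin down `tarjan`'s contract on a head array: an empty list when the graph is already a tree, and one node mask per cycle otherwise. As a hedged illustration only (this is not stanza's implementation), a minimal pure-Python cycle finder over the same head-array encoding behaves the same way on these inputs:

```python
def find_cycles(heads):
    # heads[i] is the parent of node i; node 0 is the artificial root
    # (heads[0] == 0). Returns one set of node indices per cycle.
    cycles = []
    seen = set()
    for start in range(1, len(heads)):
        if start in seen:
            continue
        path = []
        on_path = {}
        node = start
        # walk toward the root until we hit the root, an already-visited
        # node, or a node on the current path (which closes a cycle)
        while node != 0 and node not in seen and node not in on_path:
            on_path[node] = len(path)
            path.append(node)
            node = heads[node]
        if node in on_path:
            cycles.append(set(path[on_path[node]:]))
        seen.update(path)
    return cycles

print(find_cycles([0, 4, 4, 4, 0]))      # -> []
print(find_cycles([0, 3, 1, 2]))         # -> [{1, 2, 3}]
print(find_cycles([0, 3, 1, 2, 5, 6, 4]))  # -> [{1, 2, 3}, {4, 5, 6}]
```

The boolean masks returned by `tarjan` are just another encoding of these node sets, one mask per cycle.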
stanza/stanza/tests/common/test_confusion.py ADDED
@@ -0,0 +1,81 @@
+ """
+ Test a couple of simple confusion matrices and output formats
+ """
+
+ from collections import defaultdict
+ import pytest
+
+ from stanza.utils.confusion import format_confusion, confusion_to_f1, confusion_to_macro_f1, confusion_to_weighted_f1
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ @pytest.fixture
+ def simple_confusion():
+     confusion = defaultdict(lambda: defaultdict(int))
+     confusion["B-ORG"]["B-ORG"] = 1
+     confusion["B-ORG"]["B-PER"] = 1
+     confusion["E-ORG"]["E-ORG"] = 1
+     confusion["E-ORG"]["E-PER"] = 1
+     confusion["O"]["O"] = 4
+     return confusion
+
+ @pytest.fixture
+ def short_confusion():
+     """
+     Same thing, but with a short name. This should not be sorted by entity type
+     """
+     confusion = defaultdict(lambda: defaultdict(int))
+     confusion["A"]["B-ORG"] = 1
+     confusion["B-ORG"]["B-PER"] = 1
+     confusion["E-ORG"]["E-ORG"] = 1
+     confusion["E-ORG"]["E-PER"] = 1
+     confusion["O"]["O"] = 4
+     return confusion
+
+ EXPECTED_SIMPLE_OUTPUT = """
+ t\\p O B-ORG E-ORG B-PER E-PER
+ O 4 0 0 0 0
+ B-ORG 0 1 0 1 0
+ E-ORG 0 0 1 0 1
+ B-PER 0 0 0 0 0
+ E-PER 0 0 0 0 0
+ """[1:-1] # don't want to strip
+
+ EXPECTED_SHORT_OUTPUT = """
+ t\\p O A B-ORG B-PER E-ORG E-PER
+ O 4 0 0 0 0 0
+ A 0 0 1 0 0 0
+ B-ORG 0 0 0 1 0 0
+ B-PER 0 0 0 0 0 0
+ E-ORG 0 0 0 0 1 1
+ E-PER 0 0 0 0 0 0
+ """[1:-1]
+
+ EXPECTED_HIDE_BLANK_SHORT_OUTPUT = """
+ t\\p O B-ORG E-ORG B-PER E-PER
+ O 4 0 0 0 0
+ A 0 1 0 0 0
+ B-ORG 0 0 0 1 0
+ E-ORG 0 0 1 0 1
+ """[1:-1]
+
+ def test_simple_output(simple_confusion):
+     assert EXPECTED_SIMPLE_OUTPUT == format_confusion(simple_confusion)
+
+ def test_short_output(short_confusion):
+     assert EXPECTED_SHORT_OUTPUT == format_confusion(short_confusion)
+
+ def test_hide_blank_short_output(short_confusion):
+     assert EXPECTED_HIDE_BLANK_SHORT_OUTPUT == format_confusion(short_confusion, hide_blank=True)
+
+ def test_macro_f1(simple_confusion, short_confusion):
+     assert confusion_to_macro_f1(simple_confusion) == pytest.approx(0.466666666666)
+     assert confusion_to_macro_f1(short_confusion) == pytest.approx(0.277777777777)
+
+ def test_weighted_f1(simple_confusion, short_confusion):
+     assert confusion_to_weighted_f1(simple_confusion) == pytest.approx(0.83333333)
+     assert confusion_to_weighted_f1(short_confusion) == pytest.approx(0.66666666)
+
+     assert confusion_to_weighted_f1(simple_confusion, exclude=["O"]) == pytest.approx(0.66666666)
+     assert confusion_to_weighted_f1(short_confusion, exclude=["O"]) == pytest.approx(0.33333333)
+
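Where the `0.466666666666` expectation for `simple_confusion` comes from can be checked by hand: B-ORG and E-ORG each score F1 = 2/3 (one hit, one miss, no false positives), O scores 1.0, and the two PER labels appear only as predictions, so they score 0; the macro average over five labels is 7/15. A hedged sketch (not stanza's implementation) of macro F1 over the same gold-by-predicted matrix:

```python
def macro_f1(confusion):
    # confusion[gold][pred] = count; macro F1 is the unweighted mean
    # of per-label F1 over every label seen as gold or as prediction
    labels = set(confusion) | {p for row in confusion.values() for p in row}
    scores = []
    for label in labels:
        tp = confusion.get(label, {}).get(label, 0)
        fn = sum(v for p, v in confusion.get(label, {}).items() if p != label)
        fp = sum(row.get(label, 0) for g, row in confusion.items() if g != label)
        scores.append(2 * tp / (2 * tp + fp + fn) if tp else 0.0)
    return sum(scores) / len(scores)

confusion = {"B-ORG": {"B-ORG": 1, "B-PER": 1},
             "E-ORG": {"E-ORG": 1, "E-PER": 1},
             "O": {"O": 4}}
print(round(macro_f1(confusion), 4))  # -> 0.4667
```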
stanza/stanza/tests/common/test_constant.py ADDED
@@ -0,0 +1,67 @@
+ """
+ Test the conversion to lcodes and splitting of dataset names
+ """
+
+ import tempfile
+
+ import pytest
+
+ import stanza
+ from stanza.models.common.constant import treebank_to_short_name, lang_to_langcode, is_right_to_left, two_to_three_letters, langlower2lcode
+ from stanza.tests import *
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ def test_treebank():
+     """
+     Test the entire treebank name conversion
+     """
+     # conversion of a UD_ name
+     assert "hi_hdtb" == treebank_to_short_name("UD_Hindi-HDTB")
+     # conversion of names without UD
+     assert "hi_fire2013" == treebank_to_short_name("Hindi-fire2013")
+     assert "hi_fire2013" == treebank_to_short_name("Hindi-Fire2013")
+     assert "hi_fire2013" == treebank_to_short_name("Hindi-FIRE2013")
+     # already short names are generally preserved
+     assert "hi_fire2013" == treebank_to_short_name("hi-fire2013")
+     assert "hi_fire2013" == treebank_to_short_name("hi_fire2013")
+     # a special case
+     assert "zh-hant_pud" == treebank_to_short_name("UD_Chinese-PUD")
+     # a special case already converted once
+     assert "zh-hant_pud" == treebank_to_short_name("zh-hant_pud")
+     assert "zh-hant_pud" == treebank_to_short_name("zh-hant-pud")
+     assert "zh-hans_gsdsimp" == treebank_to_short_name("zh-hans_gsdsimp")
+
+     assert "wo_masakhane" == treebank_to_short_name("wo_masakhane")
+     assert "wo_masakhane" == treebank_to_short_name("wol_masakhane")
+     assert "wo_masakhane" == treebank_to_short_name("Wol_masakhane")
+     assert "wo_masakhane" == treebank_to_short_name("wolof_masakhane")
+     assert "wo_masakhane" == treebank_to_short_name("Wolof_masakhane")
+
+ def test_lang_to_langcode():
+     assert "hi" == lang_to_langcode("Hindi")
+     assert "hi" == lang_to_langcode("HINDI")
+     assert "hi" == lang_to_langcode("hindi")
+     assert "hi" == lang_to_langcode("HI")
+     assert "hi" == lang_to_langcode("hi")
+
+ def test_right_to_left():
+     assert is_right_to_left("ar")
+     assert is_right_to_left("Arabic")
+
+     assert not is_right_to_left("en")
+     assert not is_right_to_left("English")
+
+ def test_two_to_three():
+     assert lang_to_langcode("Wolof") == "wo"
+     assert lang_to_langcode("wol") == "wo"
+
+     assert "wo" in two_to_three_letters
+     assert two_to_three_letters["wo"] == "wol"
+
+ def test_langlower():
+     assert lang_to_langcode("WOLOF") == "wo"
+     assert lang_to_langcode("nOrWeGiAn") == "nb"
+
+     assert "soj" == langlower2lcode["soi"]
+     assert "soj" == langlower2lcode["sohi"]
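The core of the conversion these tests exercise is: strip an optional `UD_` prefix, split the `Language-Dataset` pair, and map the language name to a short code. A hedged toy sketch of that split (hypothetical helper and language map, not stanza's `treebank_to_short_name`, which also handles underscore-form names and special cases like Chinese):

```python
def short_name(treebank, lang_map):
    # strip an optional "UD_" prefix, split on the first hyphen,
    # then lowercase and map the language name to its short code
    if treebank.startswith("UD_"):
        treebank = treebank[3:]
    lang, _, dataset = treebank.partition("-")
    lcode = lang_map.get(lang.lower(), lang.lower())
    return "{}_{}".format(lcode, dataset.lower())

LANG_MAP = {"hindi": "hi", "wolof": "wo", "wol": "wo"}
print(short_name("UD_Hindi-HDTB", LANG_MAP))   # -> hi_hdtb
print(short_name("Hindi-FIRE2013", LANG_MAP))  # -> hi_fire2013
```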
stanza/stanza/tests/common/test_data_conversion.py ADDED
@@ -0,0 +1,520 @@
+ """
+ Basic tests of the data conversion
+ """
+
+ import io
+ import pytest
+ import tempfile
+ from zipfile import ZipFile
+
+ import stanza
+ from stanza.utils.conll import CoNLL
+ from stanza.models.common.doc import Document
+ from stanza.tests import *
+
+ pytestmark = pytest.mark.pipeline
+
+ # data for testing
+ CONLL = [[['1', 'Nous', 'il', 'PRON', '_', 'Number=Plur|Person=1|PronType=Prs', '3', 'nsubj', '_', 'start_char=0|end_char=4'],
+           ['2', 'avons', 'avoir', 'AUX', '_', 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', '3', 'aux:tense', '_', 'start_char=5|end_char=10'],
+           ['3', 'atteint', 'atteindre', 'VERB', '_', 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', '0', 'root', '_', 'start_char=11|end_char=18'],
+           ['4', 'la', 'le', 'DET', '_', 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', '5', 'det', '_', 'start_char=19|end_char=21'],
+           ['5', 'fin', 'fin', 'NOUN', '_', 'Gender=Fem|Number=Sing', '3', 'obj', '_', 'start_char=22|end_char=25'],
+           ['6-7', 'du', '_', '_', '_', '_', '_', '_', '_', 'start_char=26|end_char=28'],
+           ['6', 'de', 'de', 'ADP', '_', '_', '8', 'case', '_', '_'],
+           ['7', 'le', 'le', 'DET', '_', 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', '8', 'det', '_', '_'],
+           ['8', 'sentier', 'sentier', 'NOUN', '_', 'Gender=Masc|Number=Sing', '5', 'nmod', '_', 'start_char=29|end_char=36'],
+           ['9', '.', '.', 'PUNCT', '_', '_', '3', 'punct', '_', 'start_char=36|end_char=37']]]
+
+
+ DICT = [[{'id': (1,), 'text': 'Nous', 'lemma': 'il', 'upos': 'PRON', 'feats': 'Number=Plur|Person=1|PronType=Prs', 'head': 3, 'deprel': 'nsubj', 'misc': 'start_char=0|end_char=4'},
+          {'id': (2,), 'text': 'avons', 'lemma': 'avoir', 'upos': 'AUX', 'feats': 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', 'head': 3, 'deprel': 'aux:tense', 'misc': 'start_char=5|end_char=10'},
+          {'id': (3,), 'text': 'atteint', 'lemma': 'atteindre', 'upos': 'VERB', 'feats': 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', 'head': 0, 'deprel': 'root', 'misc': 'start_char=11|end_char=18'},
+          {'id': (4,), 'text': 'la', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', 'head': 5, 'deprel': 'det', 'misc': 'start_char=19|end_char=21'},
+          {'id': (5,), 'text': 'fin', 'lemma': 'fin', 'upos': 'NOUN', 'feats': 'Gender=Fem|Number=Sing', 'head': 3, 'deprel': 'obj', 'misc': 'start_char=22|end_char=25'},
+          {'id': (6, 7), 'text': 'du', 'misc': 'start_char=26|end_char=28'},
+          {'id': (6,), 'text': 'de', 'lemma': 'de', 'upos': 'ADP', 'head': 8, 'deprel': 'case'},
+          {'id': (7,), 'text': 'le', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', 'head': 8, 'deprel': 'det'},
+          {'id': (8,), 'text': 'sentier', 'lemma': 'sentier', 'upos': 'NOUN', 'feats': 'Gender=Masc|Number=Sing', 'head': 5, 'deprel': 'nmod', 'misc': 'start_char=29|end_char=36'},
+          {'id': (9,), 'text': '.', 'lemma': '.', 'upos': 'PUNCT', 'head': 3, 'deprel': 'punct', 'misc': 'start_char=36|end_char=37'}]]
+
+ def test_conll_to_dict():
+     dicts, empty = CoNLL.convert_conll(CONLL)
+     assert dicts == DICT
+     assert len(dicts) == len(empty)
+     assert all(len(x) == 0 for x in empty)
+
+ def test_dict_to_conll():
+     document = Document(DICT)
+     # :c = no comments
+     conll = [[sentence.split("\t") for sentence in doc.split("\n")] for doc in "{:c}".format(document).split("\n\n")]
+     assert conll == CONLL
+
+ def test_dict_to_doc_and_doc_to_dict():
+     """
+     Test the conversion from raw dict to Document and back
+
+     This code path will first turn the start_char|end_char misc values into start_char & end_char fields on the Document
+     Converting that Document to a dict will then have separate fields for each of those
+     Finally, the conversion from that dict to a list of conll entries should convert that back to misc
+     """
+     document = Document(DICT)
+     dicts = document.to_dict()
+     document = Document(dicts)
+     conll = [[sentence.split("\t") for sentence in doc.split("\n")] for doc in "{:c}".format(document).split("\n\n")]
+     assert conll == CONLL
+
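The DICT form above encodes every token id as a tuple: `(3,)` for an ordinary word, `(6, 7)` for a multi-word token range, and `(5, 1)` for an empty node such as `5.1`. A hedged sketch (hypothetical helper, not stanza's actual parser) of how a raw CoNLL-U ID column maps onto that convention:

```python
def parse_id(field):
    # "6-7" -> (6, 7): a multi-word token covering words 6 through 7
    # "5.1" -> (5, 1): an empty (enhanced-dependency) node after word 5
    # "3"   -> (3,):   an ordinary word
    if "-" in field:
        start, end = field.split("-")
        return (int(start), int(end))
    if "." in field:
        word, sub = field.split(".")
        return (int(word), int(sub))
    return (int(field),)

print(parse_id("6-7"))  # -> (6, 7)
print(parse_id("3"))    # -> (3,)
```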
+ # sample is two sentences long so that the tests check multiple sentences
+ RUSSIAN_SAMPLE = """
+ # sent_id = yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253
+ # genre = review
+ # text = Как- то слишком мало цветов получают актёры после спектакля.
+ 1 Как как-то ADV _ Degree=Pos|PronType=Ind 7 advmod _ SpaceAfter=No
+ 2 - - PUNCT _ _ 3 punct _ _
+ 3 то то PART _ _ 1 list _ deprel=list:goeswith
+ 4 слишком слишком ADV _ Degree=Pos 5 advmod _ _
+ 5 мало мало ADV _ Degree=Pos 6 advmod _ _
+ 6 цветов цветок NOUN _ Animacy=Inan|Case=Gen|Gender=Masc|Number=Plur 7 obj _ _
+ 7 получают получать VERB _ Aspect=Imp|Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ _
+ 8 актёры актер NOUN _ Animacy=Anim|Case=Nom|Gender=Masc|Number=Plur 7 nsubj _ _
+ 9 после после ADP _ _ 10 case _ _
+ 10 спектакля спектакль NOUN _ Animacy=Inan|Case=Gen|Gender=Masc|Number=Sing 7 obl _ SpaceAfter=No
+ 11 . . PUNCT _ _ 7 punct _ _
+
+ # sent_id = 4
+ # genre = social
+ # text = В женщине важна верность, а не красота.
+ 1 В в ADP _ _ 2 case _ _
+ 2 женщине женщина NOUN _ Animacy=Anim|Case=Loc|Gender=Fem|Number=Sing 3 obl _ _
+ 3 важна важный ADJ _ Degree=Pos|Gender=Fem|Number=Sing|Variant=Short 0 root _ _
+ 4 верность верность NOUN _ Animacy=Inan|Case=Nom|Gender=Fem|Number=Sing 3 nsubj _ SpaceAfter=No
+ 5 , , PUNCT _ _ 8 punct _ _
+ 6 а а CCONJ _ _ 8 cc _ _
+ 7 не не PART _ Polarity=Neg 8 advmod _ _
+ 8 красота красота NOUN _ Animacy=Inan|Case=Nom|Gender=Fem|Number=Sing 4 conj _ SpaceAfter=No
+ 9 . . PUNCT _ _ 3 punct _ _
+ """.strip()
+
+ RUSSIAN_TEXT = ["Как- то слишком мало цветов получают актёры после спектакля.", "В женщине важна верность, а не красота."]
+ RUSSIAN_IDS = ["yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253", "4"]
+
+ def check_russian_doc(doc):
+     """
+     Refactored the test for the Russian doc so we can use it to test various file methods
+     """
+     lines = RUSSIAN_SAMPLE.split("\n")
+     assert len(doc.sentences) == 2
+     assert lines[0] == doc.sentences[0].comments[0]
+     assert lines[1] == doc.sentences[0].comments[1]
+     assert lines[2] == doc.sentences[0].comments[2]
+     for sent_idx, (expected_text, expected_id, sentence) in enumerate(zip(RUSSIAN_TEXT, RUSSIAN_IDS, doc.sentences)):
+         assert expected_text == sentence.text
+         assert expected_id == sentence.sent_id
+         assert sent_idx == sentence.index
+         assert len(sentence.comments) == 3
+         assert not sentence.has_enhanced_dependencies()
+
+     sentences = "{:C}".format(doc)
+     sentences = sentences.split("\n\n")
+     assert len(sentences) == 2
+
+     sentence = sentences[0].split("\n")
+     assert len(sentence) == 14
+     assert lines[0] == sentence[0]
+     assert lines[1] == sentence[1]
+     assert lines[2] == sentence[2]
+
+     # assert that the weird deprel=list:goeswith was properly handled
+     assert doc.sentences[0].words[2].head == 1
+     assert doc.sentences[0].words[2].deprel == "list:goeswith"
+
+ def test_write_russian_doc(tmp_path):
+     """
+     Specifically test the write_doc2conll method
+     """
+     filename = tmp_path / "russian.conll"
+     doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)
+     check_russian_doc(doc)
+     CoNLL.write_doc2conll(doc, filename)
+
+     with open(filename, encoding="utf-8") as fin:
+         text = fin.read()
+
+     # the conll docs have to end with \n\n
+     assert text.endswith("\n\n")
+
+     # but to compare against the original, strip off the whitespace
+     text = text.strip()
+
+     # we skip the first sentence because the "deprel=list:goeswith" is weird
+     # note that the deprel itself is checked in check_russian_doc
+     text = text[text.find("# sent_id = 4"):]
+     sample = RUSSIAN_SAMPLE[RUSSIAN_SAMPLE.find("# sent_id = 4"):]
+     assert text == sample
+
+     doc2 = CoNLL.conll2doc(filename)
+     check_russian_doc(doc2)
+
+ # random sentence from EN_Pronouns
+ ENGLISH_SAMPLE = """
+ # newdoc
+ # sent_id = 1
+ # text = It is hers.
+ # previous = Which person owns this?
+ # comment = copular subject
+ 1 It it PRON PRP Number=Sing|Person=3|PronType=Prs 3 nsubj _ _
+ 2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _
+ 3 hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
+ 4 . . PUNCT . _ 3 punct _ _
+ """.strip()
+
+ def test_write_to_io():
+     doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE)
+     output = io.StringIO()
+     CoNLL.write_doc2conll(doc, output)
+     output_value = output.getvalue()
+     assert output_value.endswith("\n\n")
+     assert output_value.strip() == ENGLISH_SAMPLE
+
+ def test_write_doc2conll_append(tmp_path):
+     doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE)
+     filename = tmp_path / "english.conll"
+     CoNLL.write_doc2conll(doc, filename)
+     CoNLL.write_doc2conll(doc, filename, mode="a")
+
+     with open(filename) as fin:
+         text = fin.read()
+     expected = ENGLISH_SAMPLE + "\n\n" + ENGLISH_SAMPLE + "\n\n"
+     assert text == expected
+
+ def test_doc_with_comments():
+     """
+     Test that a doc with comments gets converted back with comments
+     """
+     doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)
+     check_russian_doc(doc)
+
+ def test_unusual_misc():
+     """
+     The above RUSSIAN_SAMPLE resulted in a blank misc field in one particular implementation of the conll code
+     (the below test would fail)
+     """
+     doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)
+     sentences = "{:C}".format(doc).split("\n\n")
+     assert len(sentences) == 2
+     sentence = sentences[0].split("\n")
+     assert len(sentence) == 14
+
+     for word in sentence:
+         pieces = word.split("\t")
+         assert len(pieces) == 1 or len(pieces) == 10
+         if len(pieces) == 10:
+             assert all(piece for piece in pieces)
+
+ def test_file():
+     """
+     Test loading a doc from a file
+     """
+     with tempfile.TemporaryDirectory() as tempdir:
+         filename = os.path.join(tempdir, "russian.conll")
+         with open(filename, "w", encoding="utf-8") as fout:
+             fout.write(RUSSIAN_SAMPLE)
+         doc = CoNLL.conll2doc(input_file=filename)
+         check_russian_doc(doc)
+
+ def test_zip_file():
+     """
+     Test loading a doc from a zip file
+     """
+     with tempfile.TemporaryDirectory() as tempdir:
+         zip_file = os.path.join(tempdir, "russian.zip")
+         filename = "russian.conll"
+         with ZipFile(zip_file, "w") as zout:
+             with zout.open(filename, "w") as fout:
+                 fout.write(RUSSIAN_SAMPLE.encode())
+
+         doc = CoNLL.conll2doc(input_file=filename, zip_file=zip_file)
+         check_russian_doc(doc)
+
+ SIMPLE_NER = """
+ # text = Teferi's best friend is Karn
+ # sent_id = 0
+ 1 Teferi _ _ _ _ 0 _ _ start_char=0|end_char=6|ner=S-PERSON
+ 2 's _ _ _ _ 1 _ _ start_char=6|end_char=8|ner=O
+ 3 best _ _ _ _ 2 _ _ start_char=9|end_char=13|ner=O
+ 4 friend _ _ _ _ 3 _ _ start_char=14|end_char=20|ner=O
+ 5 is _ _ _ _ 4 _ _ start_char=21|end_char=23|ner=O
+ 6 Karn _ _ _ _ 5 _ _ start_char=24|end_char=28|ner=S-PERSON
+ """.strip()
+
+ def test_simple_ner_conversion():
+     """
+     Test that tokens get properly created with NER tags
+     """
+     doc = CoNLL.conll2doc(input_str=SIMPLE_NER)
+     assert len(doc.sentences) == 1
+     sentence = doc.sentences[0]
+     assert len(sentence.tokens) == 6
+     EXPECTED_NER = ["S-PERSON", "O", "O", "O", "O", "S-PERSON"]
+     for token, ner in zip(sentence.tokens, EXPECTED_NER):
+         assert token.ner == ner
+         # check that the ner, start_char, end_char fields were not put on the token's misc
+         # those should all be set as specific fields on the token
+         assert not token.misc
+         assert len(token.words) == 1
+         # they should also not reach the word's misc field
+         assert not token.words[0].misc
+
+     conll = "{:C}".format(doc)
+     assert conll == SIMPLE_NER
+
+ MWT_NER = """
+ # text = This makes John's headache worse
+ # sent_id = 0
+ 1 This _ _ _ _ 0 _ _ start_char=0|end_char=4|ner=O
+ 2 makes _ _ _ _ 1 _ _ start_char=5|end_char=10|ner=O
+ 3-4 John's _ _ _ _ _ _ _ start_char=11|end_char=17|ner=S-PERSON
+ 3 John _ _ _ _ 2 _ _ _
+ 4 's _ _ _ _ 3 _ _ _
+ 5 headache _ _ _ _ 4 _ _ start_char=18|end_char=26|ner=O
+ 6 worse _ _ _ _ 5 _ _ start_char=27|end_char=32|ner=O
+ """.strip()
+
+ def test_mwt_ner_conversion():
+     """
+     Test that tokens including MWT get properly created with NER tags
+
+     Note that this kind of thing happens with the EWT tokenizer for English, for example
+     """
+     doc = CoNLL.conll2doc(input_str=MWT_NER)
+     assert len(doc.sentences) == 1
+     sentence = doc.sentences[0]
+     assert len(sentence.tokens) == 5
+     assert not sentence.has_enhanced_dependencies()
+     EXPECTED_NER = ["O", "O", "S-PERSON", "O", "O"]
+     EXPECTED_WORDS = [1, 1, 2, 1, 1]
+     for token, ner, expected_words in zip(sentence.tokens, EXPECTED_NER, EXPECTED_WORDS):
+         assert token.ner == ner
+         # check that the ner, start_char, end_char fields were not put on the token's misc
+         # those should all be set as specific fields on the token
+         assert not token.misc
+         assert len(token.words) == expected_words
+         # they should also not reach the word's misc field
+         assert not token.words[0].misc
+
+     conll = "{:C}".format(doc)
+     assert conll == MWT_NER
+
+
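The NER tests assert BIOES-style tags (`S-PERSON`, `O`, ...) on each token. As a hedged aside, not stanza's code, decoding such a tag sequence into entity spans is a small mechanical step:

```python
def bioes_to_spans(tags):
    # Collapse a BIOES tag sequence into (type, start, end) spans,
    # end exclusive: S-X is a single-token span, B-X..E-X a longer one.
    spans = []
    start, etype = None, None
    for i, tag in enumerate(tags):
        if tag == "O":
            start, etype = None, None
            continue
        marker, _, label = tag.partition("-")
        if marker == "S":
            spans.append((label, i, i + 1))
            start, etype = None, None
        elif marker == "B":
            start, etype = i, label
        elif marker == "E" and etype == label:
            spans.append((label, start, i + 1))
            start, etype = None, None
    return spans

print(bioes_to_spans(["S-PERSON", "O", "O", "O", "O", "S-PERSON"]))
# -> [('PERSON', 0, 1), ('PERSON', 5, 6)]
```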
+ # A random sentence from et_ewt-ud-train.conllu
+ # which we use to test the deps conversion for multiple deps
+ ESTONIAN_DEPS = """
+ # newpar
+ # sent_id = aia_foorum_37
+ # text = Sestpeale ei mõistagi neid, kes koduaias sortidega tegelevad.
+ 1 Sestpeale sest_peale ADV D _ 3 advmod 3:advmod _
+ 2 ei ei AUX V Polarity=Neg 3 aux 3:aux _
+ 3 mõistagi mõistma VERB V Connegative=Yes|Mood=Ind|Tense=Pres|VerbForm=Fin|Voice=Act 0 root 0:root _
+ 4 neid tema PRON P Case=Par|Number=Plur|Person=3|PronType=Prs 3 obj 3:obj|9:nsubj SpaceAfter=No
+ 5 , , PUNCT Z _ 9 punct 9:punct _
+ 6 kes kes PRON P Case=Nom|Number=Plur|PronType=Int,Rel 9 nsubj 4:ref _
+ 7 koduaias kodu_aed NOUN S Case=Ine|Number=Sing 9 obl 9:obl _
+ 8 sortidega sort NOUN S Case=Com|Number=Plur 9 obl 9:obl _
+ 9 tegelevad tegelema VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 4 acl:relcl 4:acl SpaceAfter=No
+ 10 . . PUNCT Z _ 3 punct 3:punct _
+ """.strip()
+
+ def test_deps_conversion():
+     doc = CoNLL.conll2doc(input_str=ESTONIAN_DEPS)
+     assert len(doc.sentences) == 1
+     sentence = doc.sentences[0]
+     assert len(sentence.tokens) == 10
+     assert sentence.has_enhanced_dependencies()
+
+     word = doc.sentences[0].words[3]
+     assert word.deps == "3:obj|9:nsubj"
+
+     conll = "{:C}".format(doc)
+     assert conll == ESTONIAN_DEPS
+
+ ESTONIAN_EMPTY_DEPS = """
+ # sent_id = ewtb2_000035_15
+ # text = Ja paari aasta pärast rôômalt maasikatele ...
+ 1 Ja ja CCONJ J _ 3 cc 5.1:cc _
+ 2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
+ 3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
+ 4 pärast pärast ADP K AdpType=Post 3 case 3:case _
+ 5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
+ 5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
+ 6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes
+ 7 ... ... PUNCT Z _ 3 punct 5.1:punct _
+ """.strip()
+
+ ESTONIAN_EMPTY_END_DEPS = """
+ # sent_id = ewtb2_000035_15
+ # text = Ja paari aasta pärast rôômalt maasikatele ...
+ 1 Ja ja CCONJ J _ 3 cc 5.1:cc _
+ 2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
+ 3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
+ 4 pärast pärast ADP K AdpType=Post 3 case 3:case _
+ 5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
+ 5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
+ """.strip()
+
+ def test_empty_deps_conversion():
+     """
+     Check that we can read and then output a sentence with empty dependencies
+     """
+     check_empty_deps_conversion(ESTONIAN_EMPTY_DEPS, 7)
+
+ def test_empty_deps_at_end_conversion():
+     """
+     The empty deps conversion should also work if the empty dep is at the end
+     """
+     check_empty_deps_conversion(ESTONIAN_EMPTY_END_DEPS, 5)
+
+ def check_empty_deps_conversion(input_str, expected_words):
+     doc = CoNLL.conll2doc(input_str=input_str, ignore_gapping=False)
+     assert len(doc.sentences) == 1
+     assert len(doc.sentences[0].tokens) == expected_words
+     assert len(doc.sentences[0].words) == expected_words
+     assert len(doc.sentences[0].empty_words) == 1
+
+     sentence = doc.sentences[0]
+     conll = "{:C}".format(doc)
+     assert conll == input_str
+
+     sentence_dict = doc.sentences[0].to_dict()
+     assert len(sentence_dict) == expected_words + 1
+     # currently this is true for both of the examples we run
+     assert sentence_dict[5]['id'] == (5, 1)
+
+     # redo the above checks to make sure
+     # there are no weird bugs in the accessors
+     assert len(doc.sentences) == 1
+     assert len(doc.sentences[0].tokens) == expected_words
+     assert len(doc.sentences[0].words) == expected_words
+     assert len(doc.sentences[0].empty_words) == 1
+
+
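The enhanced-dependency tests check raw `deps` strings like `"3:obj|9:nsubj"`. As a hedged sketch (hypothetical helper, not stanza's parser), the DEPS column splits into (head, deprel) pairs; heads are kept as strings because they may be decimal ids of empty nodes like `5.1`, and deprels themselves may contain colons (`acl:relcl`), so only the first colon separates head from relation:

```python
def parse_deps(field):
    # "3:obj|9:nsubj" -> [("3", "obj"), ("9", "nsubj")]
    # "_" means no enhanced dependencies at all
    if field == "_":
        return []
    return [tuple(item.split(":", 1)) for item in field.split("|")]

print(parse_deps("3:obj|9:nsubj"))  # -> [('3', 'obj'), ('9', 'nsubj')]
print(parse_deps("5.1:advmod"))     # -> [('5.1', 'advmod')]
```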
+ ESTONIAN_DOC_ID = """
401
+ # doc_id = this_is_a_doc
402
+ # sent_id = ewtb2_000035_15
403
+ # text = Ja paari aasta pärast rôômalt maasikatele ...
404
+ 1 Ja ja CCONJ J _ 3 cc 5.1:cc _
405
+ 2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
406
+ 3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
407
+ 4 pärast pärast ADP K AdpType=Post 3 case 3:case _
408
+ 5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
409
+ 5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
410
+ 6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes
411
+ 7 ... ... PUNCT Z _ 3 punct 5.1:punct _
412
+ """.strip()
413
+
414
+ def test_read_doc_id():
415
+ doc = CoNLL.conll2doc(input_str=ESTONIAN_DOC_ID, ignore_gapping=False)
416
+ assert "{:C}".format(doc) == ESTONIAN_DOC_ID
417
+ assert doc.sentences[0].doc_id == 'this_is_a_doc'
418
+
419
+ SIMPLE_DEPENDENCY_INDEX_ERROR = """
420
+ # text = Teferi's best friend is Karn
421
+ # sent_id = 0
422
+ # notes = this sentence has a dependency index outside the sentence. it should throw an IndexError
423
+ 1 Teferi _ _ _ _ 0 root _ start_char=0|end_char=6|ner=S-PERSON
424
+ 2 's _ _ _ _ 1 dep _ start_char=6|end_char=8|ner=O
425
+ 3 best _ _ _ _ 2 dep _ start_char=9|end_char=13|ner=O
426
+ 4 friend _ _ _ _ 3 dep _ start_char=14|end_char=20|ner=O
427
+ 5 is _ _ _ _ 4 dep _ start_char=21|end_char=23|ner=O
428
+ 6 Karn _ _ _ _ 8 dep _ start_char=24|end_char=28|ner=S-PERSON
429
+ """.strip()
430
+
431
+ def test_read_dependency_errors():
432
+ with pytest.raises(IndexError):
433
+ doc = CoNLL.conll2doc(input_str=SIMPLE_DEPENDENCY_INDEX_ERROR)
434
+
435
+ MULTIPLE_DOC_IDS = """
436
+ # doc_id = doc_1
437
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0020
438
+ # text = His mother was also killed in the attack.
439
+ 1 His his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 2 nmod:poss 2:nmod:poss _
440
+ 2 mother mother NOUN NN Number=Sing 5 nsubj:pass 5:nsubj:pass _
441
+ 3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 5 aux:pass 5:aux:pass _
442
+ 4 also also ADV RB _ 5 advmod 5:advmod _
443
+ 5 killed kill VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
444
+ 6 in in ADP IN _ 8 case 8:case _
445
+ 7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _
446
+ 8 attack attack NOUN NN Number=Sing 5 obl 5:obl:in SpaceAfter=No
447
+ 9 . . PUNCT . _ 5 punct 5:punct _
448
+
449
+ # doc_id = doc_1
450
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0028
451
+ # text = This item is a small one and easily missed.
452
+ 1 This this DET DT Number=Sing|PronType=Dem 2 det 2:det _
453
+ 2 item item NOUN NN Number=Sing 6 nsubj 6:nsubj|9:nsubj:pass _
454
+ 3 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
455
+ 4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
456
+ 5 small small ADJ JJ Degree=Pos 6 amod 6:amod _
457
+ 6 one one NOUN NN Number=Sing 0 root 0:root _
458
+ 7 and and CCONJ CC _ 9 cc 9:cc _
459
+ 8 easily easily ADV RB _ 9 advmod 9:advmod _
460
+ 9 missed miss VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 6 conj 6:conj:and SpaceAfter=No
461
+ 10 . . PUNCT . _ 6 punct 6:punct _
462
+
463
+ # doc_id = doc_2
464
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0029
465
+ # text = But in my view it is highly significant.
466
+ 1 But but CCONJ CC _ 8 cc 8:cc _
467
+ 2 in in ADP IN _ 4 case 4:case _
468
+ 3 my my PRON PRP$ Case=Gen|Number=Sing|Person=1|Poss=Yes|PronType=Prs 4 nmod:poss 4:nmod:poss _
469
+ 4 view view NOUN NN Number=Sing 8 obl 8:obl:in _
470
+ 5 it it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 8 nsubj 8:nsubj _
471
+ 6 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 8 cop 8:cop _
472
+ 7 highly highly ADV RB _ 8 advmod 8:advmod _
473
+ 8 significant significant ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No
474
+ 9 . . PUNCT . _ 8 punct 8:punct _
475
+
476
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0040
+ # text = The trial begins again Nov.28.
+ 1 The the DET DT Definite=Def|PronType=Art 2 det 2:det _
+ 2 trial trial NOUN NN Number=Sing 3 nsubj 3:nsubj _
+ 3 begins begin VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
+ 4 again again ADV RB _ 3 advmod 3:advmod _
+ 5 Nov. November PROPN NNP Abbr=Yes|Number=Sing 3 obl:tmod 3:obl:tmod SpaceAfter=No
+ 6 28 28 NUM CD NumForm=Digit|NumType=Card 5 nummod 5:nummod SpaceAfter=No
+ 7 . . PUNCT . _ 3 punct 3:punct _
+
+ """.lstrip()
+
+ def test_read_multiple_doc_ids():
+     docs = CoNLL.conll2multi_docs(input_str=MULTIPLE_DOC_IDS)
+     assert len(docs) == 2
+     assert len(docs[0].sentences) == 2
+     assert len(docs[1].sentences) == 2
+
+     # remove the first doc_id comment
+     text = "\n".join(MULTIPLE_DOC_IDS.split("\n")[1:])
+     docs = CoNLL.conll2multi_docs(input_str=text)
+     assert len(docs) == 3
+     assert len(docs[0].sentences) == 1
+     assert len(docs[1].sentences) == 1
+     assert len(docs[2].sentences) == 2
+
+ ENGLISH_TEST_SENTENCE = """
+ # text = This is a test
+ # sent_id = 0
+ 1 This this PRON DT Number=Sing|PronType=Dem 4 nsubj _ start_char=0|end_char=4
+ 2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ start_char=5|end_char=7
+ 3 a a DET DT Definite=Ind|PronType=Art 4 det _ start_char=8|end_char=9
+ 4 test test NOUN NN Number=Sing 0 root _ start_char=10|end_char=14|SpaceAfter=No
+ """.lstrip()
+
+ def test_convert_dict():
+     doc = CoNLL.conll2doc(input_str=ENGLISH_TEST_SENTENCE)
+     converted = CoNLL.convert_dict(doc.to_dict())
+
+     expected = [[['1', 'This', 'this', 'PRON', 'DT', 'Number=Sing|PronType=Dem', '4', 'nsubj', '_', 'start_char=0|end_char=4'],
+                  ['2', 'is', 'be', 'AUX', 'VBZ', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', '4', 'cop', '_', 'start_char=5|end_char=7'],
+                  ['3', 'a', 'a', 'DET', 'DT', 'Definite=Ind|PronType=Art', '4', 'det', '_', 'start_char=8|end_char=9'],
+                  ['4', 'test', 'test', 'NOUN', 'NN', 'Number=Sing', '0', 'root', '_', 'SpaceAfter=No|start_char=10|end_char=14']]]
+
+     assert converted == expected
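The tests above round-trip CoNLL-U text through stanza's `CoNLL` reader. As a minimal sketch of the format they exercise (sentences separated by blank lines, `#` comment lines, one token per row), the parsing can be illustrated with the standard library alone; `parse_conllu` is a hypothetical helper for illustration, not stanza's CoNLL reader:

```python
# Minimal sketch of the CoNLL-U shape the tests above exercise.
# Illustration only -- not stanza's CoNLL.conll2doc.
CONLLU = """# text = This is a test
1\tThis\tthis\tPRON\tDT\t_\t4\tnsubj\t_\t_
2\tis\tbe\tAUX\tVBZ\t_\t4\tcop\t_\t_
3\ta\ta\tDET\tDT\t_\t4\tdet\t_\t_
4\ttest\ttest\tNOUN\tNN\t_\t0\troot\t_\t_
"""

def parse_conllu(text):
    # blank lines separate sentences; "#" lines are comments; token rows
    # are tab-separated columns (ID, FORM, LEMMA, UPOS, XPOS, ...)
    sentences = []
    for block in text.strip().split("\n\n"):
        rows = [line.split("\t") for line in block.split("\n")
                if line and not line.startswith("#")]
        sentences.append(rows)
    return sentences

sentences = parse_conllu(CONLLU)
```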
stanza/stanza/tests/common/test_foundation_cache.py ADDED
@@ -0,0 +1,36 @@
+ import glob
+ import os
+ import shutil
+ import tempfile
+
+ import pytest
+
+ import stanza
+ from stanza.models.common.foundation_cache import FoundationCache, load_charlm
+ from stanza.tests import TEST_MODELS_DIR
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ def test_charlm_cache():
+     models_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "*")
+     models = glob.glob(models_path)
+     # we expect at least one English model downloaded for the tests
+     assert len(models) >= 1
+     model_file = models[0]
+
+     cache = FoundationCache()
+     with tempfile.TemporaryDirectory(dir=".") as test_dir:
+         temp_file = os.path.join(test_dir, "charlm.pt")
+         shutil.copy2(model_file, temp_file)
+         # this will work
+         model = load_charlm(temp_file)
+
+         # this will save the model
+         model = cache.load_charlm(temp_file)
+
+     # this should no longer work
+     with pytest.raises(FileNotFoundError):
+         model = load_charlm(temp_file)
+
+     # it should remember the cached version
+     model = cache.load_charlm(temp_file)
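The test above depends on the cache keeping a loaded model in memory keyed by filename, so a second lookup succeeds even after the underlying file is gone. That caching pattern can be sketched with the standard library; `FileCache` here is a hypothetical stand-in, not stanza's `FoundationCache`:

```python
import os
import tempfile

# Hypothetical stand-in for the caching pattern test_charlm_cache relies on:
# load once from disk, then always return the in-memory copy for that path.
class FileCache:
    def __init__(self):
        self._models = {}

    def load(self, filename):
        if filename not in self._models:
            with open(filename) as fin:
                self._models[filename] = fin.read()
        return self._models[filename]

cache = FileCache()
with tempfile.TemporaryDirectory() as tempdir:
    path = os.path.join(tempdir, "model.pt")
    with open(path, "w") as fout:
        fout.write("weights")
    first = cache.load(path)

# the file is deleted with the temp dir, but the cached copy survives
assert not os.path.exists(path)
assert cache.load(path) == first == "weights"
```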
stanza/stanza/tests/common/test_pretrain.py ADDED
@@ -0,0 +1,139 @@
+ import os
+ import tempfile
+
+ import pytest
+ import numpy as np
+ import torch
+
+ from stanza.models.common import pretrain
+ from stanza.models.common.vocab import UNK_ID
+ from stanza.tests import *
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ def check_vocab(vocab):
+     # 4 base vectors, plus the 3 vectors actually present in the file
+     assert len(vocab) == 7
+     assert 'unban' in vocab
+     assert 'mox' in vocab
+     assert 'opal' in vocab
+
+ def check_embedding(emb, unk=False):
+     expected = np.array([[ 0.,  0.,  0.,  0.,],
+                          [ 0.,  0.,  0.,  0.,],
+                          [ 0.,  0.,  0.,  0.,],
+                          [ 0.,  0.,  0.,  0.,],
+                          [ 1.,  2.,  3.,  4.,],
+                          [ 5.,  6.,  7.,  8.,],
+                          [ 9., 10., 11., 12.,]])
+     if unk:
+         expected[UNK_ID] = -1
+     np.testing.assert_allclose(emb, expected)
+
+ def check_pretrain(pt):
+     check_vocab(pt.vocab)
+     check_embedding(pt.emb)
+
+ def test_text_pretrain():
+     pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.txt', save_to_file=False)
+     check_pretrain(pt)
+
+ def test_xz_pretrain():
+     pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False)
+     check_pretrain(pt)
+
+ def test_gz_pretrain():
+     pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.gz', save_to_file=False)
+     check_pretrain(pt)
+
+ def test_zip_pretrain():
+     pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.zip', save_to_file=False)
+     check_pretrain(pt)
+
+ def test_csv_pretrain():
+     pt = pretrain.Pretrain(csv_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.csv', save_to_file=False)
+     check_pretrain(pt)
+
+ def test_resave_pretrain():
+     """
+     Test saving a pretrain and then loading from the existing file
+     """
+     test_pt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".pt", delete=False)
+     try:
+         test_pt_file.close()
+         # note that this tests the ability to save a pretrain and the
+         # ability to fall back when the existing pretrain isn't working
+         pt = pretrain.Pretrain(filename=test_pt_file.name,
+                                vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz')
+         check_pretrain(pt)
+
+         pt2 = pretrain.Pretrain(filename=test_pt_file.name,
+                                 vec_filename=f'unban_mox_opal')
+         check_pretrain(pt2)
+
+         pt3 = torch.load(test_pt_file.name, weights_only=True)
+         check_embedding(pt3['emb'])
+     finally:
+         os.unlink(test_pt_file.name)
+
+ SPACE_PRETRAIN="""
+ 3 4
+ unban mox 1 2 3 4
+ opal 5 6 7 8
+ foo 9 10 11 12
+ """.strip()
+
+ def test_whitespace():
+     """
+     Test reading a pretrain with an ascii space in it
+
+     The vocab word with a space in it should have the correct number
+     of dimensions read, with the space converted to nbsp
+     """
+     test_txt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".txt", delete=False)
+     try:
+         test_txt_file.write(SPACE_PRETRAIN.encode())
+         test_txt_file.close()
+
+         pt = pretrain.Pretrain(vec_filename=test_txt_file.name, save_to_file=False)
+         check_embedding(pt.emb)
+         assert "unban\xa0mox" in pt.vocab
+         # this one also works because of the normalize_unit in vocab.py
+         assert "unban mox" in pt.vocab
+     finally:
+         os.unlink(test_txt_file.name)
+
+ NO_HEADER_PRETRAIN="""
+ unban 1 2 3 4
+ mox 5 6 7 8
+ opal 9 10 11 12
+ """.strip()
+
+ def test_no_header():
+     """
+     Check loading a pretrain with no rows,cols header
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir:
+         filename = os.path.join(tmpdir, "tiny.txt")
+         with open(filename, "w", encoding="utf-8") as fout:
+             fout.write(NO_HEADER_PRETRAIN)
+         pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False)
+         check_embedding(pt.emb)
+
+ UNK_PRETRAIN="""
+ unban 1 2 3 4
+ mox 5 6 7 8
+ opal 9 10 11 12
+ <unk> -1 -1 -1 -1
+ """.strip()
+
+ def test_unk_pretrain():
+     """
+     Check loading a pretrain with <unk> at the end, like GloVe does
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir:
+         filename = os.path.join(tmpdir, "tiny.txt")
+         with open(filename, "w", encoding="utf-8") as fout:
+             fout.write(UNK_PRETRAIN)
+         pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False)
+         check_embedding(pt.emb, unk=True)
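`test_whitespace` above checks that a vector line with more whitespace-separated fields than dimensions is read as a single multiword key, with the internal space converted to a non-breaking space. Assuming the dimension is known, that rule can be sketched as follows; `parse_vector_line` is my own illustration, not stanza's pretrain reader:

```python
# Sketch of the rule test_whitespace exercises: with `dim` known, the last
# `dim` fields are the vector and everything before them is the word, with
# ascii spaces in the word replaced by non-breaking spaces (\xa0).
# Illustration only -- not stanza's pretrain loader.
def parse_vector_line(line, dim):
    fields = line.split()
    word = "\xa0".join(fields[:-dim])
    values = [float(v) for v in fields[-dim:]]
    return word, values

word, values = parse_vector_line("unban mox 1 2 3 4", 4)
```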
stanza/stanza/tests/common/test_utils.py ADDED
@@ -0,0 +1,194 @@
+ import lzma
+ import os
+ import tempfile
+
+ import pytest
+
+ import stanza
+ import stanza.models.common.utils as utils
+ from stanza.tests import *
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ def test_wordvec_not_found():
+     """
+     get_wordvec_file should fail if neither word2vec nor fasttext exists
+     """
+     with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
+         with pytest.raises(FileNotFoundError):
+             utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
+
+
+ def test_word2vec_xz():
+     """
+     Test searching for word2vec and xz files
+     """
+     with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
+         # make a fake directory for English word vectors
+         word2vec_dir = os.path.join(temp_dir, 'word2vec', 'English')
+         os.makedirs(word2vec_dir)
+
+         # make a fake English word vector file
+         fake_file = os.path.join(word2vec_dir, 'en.vectors.xz')
+         fout = open(fake_file, 'w')
+         fout.close()
+
+         # get_wordvec_file should now find this fake file
+         filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
+         assert filename == fake_file
+
+ def test_fasttext_txt():
+     """
+     Test searching for fasttext and txt files
+     """
+     with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
+         # make a fake directory for English word vectors
+         fasttext_dir = os.path.join(temp_dir, 'fasttext', 'English')
+         os.makedirs(fasttext_dir)
+
+         # make a fake English word vector file
+         fake_file = os.path.join(fasttext_dir, 'en.vectors.txt')
+         fout = open(fake_file, 'w')
+         fout.close()
+
+         # get_wordvec_file should now find this fake file
+         filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
+         assert filename == fake_file
+
+ def test_wordvec_type():
+     """
+     If we supply our own wordvec type, get_wordvec_file should find that
+     """
+     with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
+         # make a fake directory for English word vectors
+         google_dir = os.path.join(temp_dir, 'google', 'English')
+         os.makedirs(google_dir)
+
+         # make a fake English word vector file
+         fake_file = os.path.join(google_dir, 'en.vectors.txt')
+         fout = open(fake_file, 'w')
+         fout.close()
+
+         # get_wordvec_file should now find this fake file
+         filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo', wordvec_type='google')
+         assert filename == fake_file
+
+         # this file won't be found using the normal defaults
+         with pytest.raises(FileNotFoundError):
+             utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
+
+ def test_sort_with_indices():
+     data = [[1, 2, 3], [4, 5], [6]]
+     ordered, orig_idx = utils.sort_with_indices(data, key=len)
+     assert ordered == ([6], [4, 5], [1, 2, 3])
+     assert orig_idx == (2, 1, 0)
+
+     unsorted = utils.unsort(ordered, orig_idx)
+     assert data == unsorted
+
+ def test_empty_sort_with_indices():
+     ordered, orig_idx = utils.sort_with_indices([])
+     assert len(ordered) == 0
+     assert len(orig_idx) == 0
+
+     unsorted = utils.unsort(ordered, orig_idx)
+     assert [] == unsorted
+
+
+ def test_split_into_batches():
+     data = []
+     for i in range(5):
+         data.append(["Unban", "mox", "opal", str(i)])
+
+     data.append(["Do", "n't", "ban", "Urza", "'s", "Saga", "that", "card", "is", "great"])
+     data.append(["Ban", "Ragavan"])
+
+     # small batches will put one element in each interval
+     batches = utils.split_into_batches(data, 5)
+     assert batches == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
+
+     # this one has a batch interrupted in the middle by a large element
+     batches = utils.split_into_batches(data, 8)
+     assert batches == [(0, 2), (2, 4), (4, 5), (5, 6), (6, 7)]
+
+     # this one has the large element at the start of its own batch
+     batches = utils.split_into_batches(data[1:], 8)
+     assert batches == [(0, 2), (2, 4), (4, 5), (5, 6)]
+
+     # overloading the test!  assert that the key & reverse is working
+     ordered, orig_idx = utils.sort_with_indices(data, key=len, reverse=True)
+     assert [len(x) for x in ordered] == [10, 4, 4, 4, 4, 4, 2]
+
+     # this has the large element at the start
+     batches = utils.split_into_batches(ordered, 8)
+     assert batches == [(0, 1), (1, 3), (3, 5), (5, 7)]
+
+     # double check that unsort is working as expected
+     assert data == utils.unsort(ordered, orig_idx)
+
+
+ def test_find_missing_tags():
+     assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC"]) == []
+     assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC", "ORG"]) == ['ORG']
+     assert utils.find_missing_tags([["O", "PER"], ["O", "LOC"]], [["O", "PER"], ["LOC", "ORG"]]) == ['ORG']
+
+
+ def test_open_read_text():
+     """
+     test that we can read either .xz or regular txt
+     """
+     TEXT = "this is a test"
+     with tempfile.TemporaryDirectory() as tempdir:
+         # test text file
+         filename = os.path.join(tempdir, "foo.txt")
+         with open(filename, "w") as fout:
+             fout.write(TEXT)
+         with utils.open_read_text(filename) as fin:
+             in_text = fin.read()
+             assert TEXT == in_text
+
+         assert fin.closed
+
+         # the context should close the file when we throw an exception!
+         try:
+             with utils.open_read_text(filename) as finex:
+                 assert not finex.closed
+                 raise ValueError("unban mox opal!")
+         except ValueError:
+             pass
+         assert finex.closed
+
+         # test xz file
+         filename = os.path.join(tempdir, "foo.txt.xz")
+         with lzma.open(filename, "wt") as fout:
+             fout.write(TEXT)
+         with utils.open_read_text(filename) as finxz:
+             in_text = finxz.read()
+             assert TEXT == in_text
+
+         assert finxz.closed
+
+         # the context should close the file when we throw an exception!
+         try:
+             with utils.open_read_text(filename) as finexxz:
+                 assert not finexxz.closed
+                 raise ValueError("unban mox opal!")
+         except ValueError:
+             pass
+         assert finexxz.closed
+
+
+ def test_checkpoint_name():
+     """
+     Test some expected results for the checkpoint names
+     """
+     # use os.path.split so that the test is agnostic of file separator on Linux or Windows
+     checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm.pt", None)
+     assert os.path.split(checkpoint) == ("saved_models", "kk_oscar_forward_charlm_checkpoint.pt")
+
+     checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm", None)
+     assert os.path.split(checkpoint) == ("saved_models", "kk_oscar_forward_charlm_checkpoint")
+
+     checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm", "othername.pt")
+     assert os.path.split(checkpoint) == ("saved_models", "othername.pt")
+
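The sort/unsort tests above rest on one invariant: sorting with remembered original indices, then unsorting, restores the input order exactly. A minimal sketch of that pair (my own versions, not stanza's `utils.sort_with_indices` / `utils.unsort`) looks like this:

```python
# Minimal versions of the sort/unsort pair the tests above exercise.
# Illustrations only -- not stanza's implementations.
def sort_with_indices(data, key=None, reverse=False):
    # returns (ordered items, original index of each ordered item)
    if not data:
        return (), ()
    sort_key = (lambda pair: key(pair[1])) if key else (lambda pair: pair[1])
    pairs = sorted(enumerate(data), key=sort_key, reverse=reverse)
    orig_idx, ordered = zip(*pairs)
    return ordered, orig_idx

def unsort(ordered, orig_idx):
    # place each item back at its remembered original position
    result = [None] * len(ordered)
    for item, idx in zip(ordered, orig_idx):
        result[idx] = item
    return result

data = [[1, 2, 3], [4, 5], [6]]
ordered, orig_idx = sort_with_indices(data, key=len)
```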
stanza/stanza/tests/constituency/__init__.py ADDED
File without changes
stanza/stanza/tests/constituency/test_convert_arboretum.py ADDED
@@ -0,0 +1,235 @@
+ """
+ Test a couple different classes of trees to check the output of the Arboretum conversion
+
+ Note that the text has been removed
+ """
+
+ import os
+ import tempfile
+
+ import pytest
+
+ from stanza.server import tsurgeon
+ from stanza.tests import TEST_WORKING_DIR
+ from stanza.utils.datasets.constituency import convert_arboretum
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+
+ PROJ_EXAMPLE="""
+ <s id="s2" ref="AACBPIGY" source="id=AACBPIGY" forest="1/1" text="A B C D E F G H.">
+   <graph root="s2_500">
+     <terminals>
+       <t id="s2_1" word="A" lemma="A" pos="prop" morph="NOM" extra="PROP:A compound brand"/>
+       <t id="s2_2" word="B" lemma="B" pos="v-fin" morph="PR AKT" extra="mv"/>
+       <t id="s2_3" word="C" lemma="C" pos="pron-pers" morph="2S ACC" extra="--"/>
+       <t id="s2_4" word="D" lemma="D" pos="adj" morph="UTR S IDF NOM" extra="F:u+afhængig"/>
+       <t id="s2_5" word="E" lemma="E" pos="prp" morph="--" extra="--"/>
+       <t id="s2_6" word="F" lemma="F" pos="art" morph="NEU S DEF" extra="--"/>
+       <t id="s2_7" word="G" lemma="G" pos="adj" morph="nG S DEF NOM" extra="--"/>
+       <t id="s2_8" word="H" lemma="H" pos="n" morph="NEU S IDF NOM" extra="N:lys+net"/>
+       <t id="s2_9" word="." lemma="--" pos="pu" morph="--" extra="--"/>
+     </terminals>
+
+     <nonterminals>
+       <nt id="s2_500" cat="s">
+         <edge label="STA" idref="s2_501"/>
+       </nt>
+       <nt id="s2_501" cat="fcl">
+         <edge label="S" idref="s2_1"/>
+         <edge label="P" idref="s2_2"/>
+         <edge label="Od" idref="s2_3"/>
+         <edge label="Co" idref="s2_502"/>
+         <edge label="PU" idref="s2_9"/>
+       </nt>
+       <nt id="s2_502" cat="adjp">
+         <edge label="H" idref="s2_4"/>
+         <edge label="DA" idref="s2_503"/>
+       </nt>
+       <nt id="s2_503" cat="pp">
+         <edge label="H" idref="s2_5"/>
+         <edge label="DP" idref="s2_504"/>
+       </nt>
+       <nt id="s2_504" cat="np">
+         <edge label="DN" idref="s2_6"/>
+         <edge label="DN" idref="s2_7"/>
+         <edge label="H" idref="s2_8"/>
+       </nt>
+     </nonterminals>
+   </graph>
+ </s>
+ """
+
+ NOT_FIX_NONPROJ_EXAMPLE="""
+ <s id="s322" ref="EDGBITSZ" source="id=EDGBITSZ" forest="1/2" text="A B C D E, F G H I J.">
+   <graph root="s322_500">
+     <terminals>
+       <t id="s322_1" word="A" lemma="A" pos="prop" morph="NOM" extra="hum fem"/>
+       <t id="s322_2" word="B" lemma="B" pos="v-fin" morph="PR AKT" extra="mv"/>
+       <t id="s322_3" word="C" lemma="C" pos="pron-dem" morph="UTR S" extra="dem"/>
+       <t id="s322_4" word="D" lemma="D" pos="n" morph="UTR S IDF NOM" extra="--"/>
+       <t id="s322_5" word="E" lemma="E" pos="adv" morph="--" extra="--"/>
+       <t id="s322_6" word="," lemma="--" pos="pu" morph="--" extra="--"/>
+       <t id="s322_7" word="F" lemma="F" pos="pron-rel" morph="--" extra="rel"/>
+       <t id="s322_8" word="G" lemma="G" pos="prop" morph="NOM" extra="hum"/>
+       <t id="s322_9" word="H" lemma="H" pos="v-fin" morph="IMPF AKT" extra="mv"/>
+       <t id="s322_10" word="I" lemma="I" pos="prp" morph="--" extra="--"/>
+       <t id="s322_11" word="J" lemma="J" pos="n" morph="UTR S DEF NOM" extra="F:ur+premiere"/>
+       <t id="s322_12" word="." lemma="--" pos="pu" morph="--" extra="--"/>
+     </terminals>
+
+     <nonterminals>
+       <nt id="s322_500" cat="s">
+         <edge label="STA" idref="s322_501"/>
+       </nt>
+       <nt id="s322_501" cat="fcl">
+         <edge label="S" idref="s322_1"/>
+         <edge label="P" idref="s322_2"/>
+         <edge label="Od" idref="s322_502"/>
+         <edge label="Vpart" idref="s322_5"/>
+         <edge label="PU" idref="s322_6"/>
+         <edge label="PU" idref="s322_12"/>
+       </nt>
+       <nt id="s322_502" cat="np">
+         <edge label="DN" idref="s322_3"/>
+         <edge label="H" idref="s322_4"/>
+         <edge label="DN" idref="s322_503"/>
+       </nt>
+       <nt id="s322_503" cat="fcl">
+         <edge label="Od" idref="s322_7"/>
+         <edge label="S" idref="s322_8"/>
+         <edge label="P" idref="s322_9"/>
+         <edge label="Ao" idref="s322_504"/>
+       </nt>
+       <nt id="s322_504" cat="pp">
+         <edge label="H" idref="s322_10"/>
+         <edge label="DP" idref="s322_11"/>
+       </nt>
+     </nonterminals>
+   </graph>
+ </s>
+ """
+
+
+ NONPROJ_EXAMPLE="""
+ <s id="s9" ref="AATCNKQZ" source="id=AATCNKQZ" forest="1/1" text="A B C D E F G H I.">
+   <graph root="s9_500">
+     <terminals>
+       <t id="s9_1" word="A" lemma="A" pos="adv" morph="--" extra="--"/>
+       <t id="s9_2" word="B" lemma="B" pos="adv" morph="--" extra="--"/>
+       <t id="s9_3" word="C" lemma="C" pos="v-fin" morph="IMPF AKT" extra="aux"/>
+       <t id="s9_4" word="D" lemma="D" pos="prop" morph="NOM" extra="hum"/>
+       <t id="s9_5" word="E" lemma="E" pos="adv" morph="--" extra="--"/>
+       <t id="s9_6" word="F" lemma="F" pos="v-pcp2" morph="PAS" extra="mv"/>
+       <t id="s9_7" word="G" lemma="G" pos="prp" morph="--" extra="--"/>
+       <t id="s9_8" word="H" lemma="H" pos="num" morph="--" extra="card"/>
+       <t id="s9_9" word="I" lemma="I" pos="n" morph="UTR P IDF NOM" extra="N:patrulje+vogn"/>
+       <t id="s9_10" word="." lemma="--" pos="pu" morph="--" extra="--"/>
+     </terminals>
+
+     <nonterminals>
+       <nt id="s9_500" cat="s">
+         <edge label="STA" idref="s9_501"/>
+       </nt>
+       <nt id="s9_501" cat="fcl">
+         <edge label="fA" idref="s9_502"/>
+         <edge label="P" idref="s9_503"/>
+         <edge label="S" idref="s9_4"/>
+         <edge label="fA" idref="s9_5"/>
+         <edge label="fA" idref="s9_504"/>
+         <edge label="PU" idref="s9_10"/>
+       </nt>
+       <nt id="s9_502" cat="advp">
+         <edge label="DA" idref="s9_1"/>
+         <edge label="H" idref="s9_2"/>
+       </nt>
+       <nt id="s9_503" cat="vp">
+         <edge label="Vaux" idref="s9_3"/>
+         <edge label="Vm" idref="s9_6"/>
+       </nt>
+       <nt id="s9_504" cat="pp">
+         <edge label="H" idref="s9_7"/>
+         <edge label="DP" idref="s9_505"/>
+       </nt>
+       <nt id="s9_505" cat="np">
+         <edge label="DN" idref="s9_8"/>
+         <edge label="H" idref="s9_9"/>
+       </nt>
+     </nonterminals>
+   </graph>
+ </s>
+ """
+
+ def test_projective_example():
+     """
+     Test reading a basic tree, along with some further manipulations from the conversion program
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
+         test_name = os.path.join(tempdir, "proj.xml")
+         with open(test_name, "w", encoding="utf-8") as fout:
+             fout.write(PROJ_EXAMPLE)
+         sentences = convert_arboretum.read_xml_file(test_name)
+         assert len(sentences) == 1
+
+         tree, words = convert_arboretum.process_tree(sentences[0])
+         expected_tree = "(s (fcl (prop s2_1) (v-fin s2_2) (pron-pers s2_3) (adjp (adj s2_4) (pp (prp s2_5) (np (art s2_6) (adj s2_7) (n s2_8)))) (pu s2_9)))"
+         assert str(tree) == expected_tree
+         assert [w.word for w in words.values()] == ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', '.']
+         assert not convert_arboretum.word_sequence_missing_words(tree)
+         with tsurgeon.Tsurgeon() as tsurgeon_processor:
+             assert tree == convert_arboretum.check_words(tree, tsurgeon_processor)
+
+         # check that the words can be replaced as expected
+         replaced_tree = convert_arboretum.replace_words(tree, words)
+         expected_tree = "(s (fcl (prop A) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
+         assert str(replaced_tree) == expected_tree
+         assert convert_arboretum.split_underscores(replaced_tree) == replaced_tree
+
+         # fake a word which should be split
+         words['s2_1'] = words['s2_1']._replace(word='foo_bar')
+         replaced_tree = convert_arboretum.replace_words(tree, words)
+         expected_tree = "(s (fcl (prop foo_bar) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
+         assert str(replaced_tree) == expected_tree
+         expected_tree = "(s (fcl (np (prop foo) (prop bar)) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
+         assert str(convert_arboretum.split_underscores(replaced_tree)) == expected_tree
+
+
+ def test_not_fix_example():
+     """
+     Test that a non-projective tree which we don't have a heuristic for quietly fails
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
+         test_name = os.path.join(tempdir, "nofix.xml")
+         with open(test_name, "w", encoding="utf-8") as fout:
+             fout.write(NOT_FIX_NONPROJ_EXAMPLE)
+         sentences = convert_arboretum.read_xml_file(test_name)
+         assert len(sentences) == 1
+
+         tree, words = convert_arboretum.process_tree(sentences[0])
+         assert not convert_arboretum.word_sequence_missing_words(tree)
+         with tsurgeon.Tsurgeon() as tsurgeon_processor:
+             assert convert_arboretum.check_words(tree, tsurgeon_processor) is None
+
+
+ def test_fix_proj_example():
+     """
+     Test that a non-projective tree can be rearranged as expected
+
+     Note that there are several other classes of non-proj tree we could test as well...
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
+         test_name = os.path.join(tempdir, "fix.xml")
+         with open(test_name, "w", encoding="utf-8") as fout:
+             fout.write(NONPROJ_EXAMPLE)
+         sentences = convert_arboretum.read_xml_file(test_name)
+         assert len(sentences) == 1
+
+         tree, words = convert_arboretum.process_tree(sentences[0])
+         assert not convert_arboretum.word_sequence_missing_words(tree)
+         # the 4 and 5 are moved inside the 3-6 node
+         expected_orig = "(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (v-pcp2 s9_6)) (prop s9_4) (adv s9_5) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))"
+         expected_proj = "(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (prop s9_4) (adv s9_5) (v-pcp2 s9_6)) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))"
+         assert str(tree) == expected_orig
+         with tsurgeon.Tsurgeon() as tsurgeon_processor:
+             assert str(convert_arboretum.check_words(tree, tsurgeon_processor)) == expected_proj
+
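The Arboretum examples above are TIGER-style XML: each `<s>` holds `<terminals>` with `<t>` word nodes and `<nonterminals>` with `<nt>`/`<edge>` nodes. A minimal sketch of pulling the terminals out of one sentence with the standard library follows; it is an illustration only, not `convert_arboretum.read_xml_file`:

```python
import xml.etree.ElementTree as ET

# Minimal sketch of reading the <t> terminals from one TIGER-style <s>
# element, like the fixtures above.  Illustration only.
SNIPPET = """
<s id="s1">
  <graph root="s1_500">
    <terminals>
      <t id="s1_1" word="A" lemma="A" pos="prop"/>
      <t id="s1_2" word="B" lemma="B" pos="v-fin"/>
      <t id="s1_3" word="." lemma="--" pos="pu"/>
    </terminals>
  </graph>
</s>
"""

sentence = ET.fromstring(SNIPPET.strip())
# map terminal id -> (word, pos), preserving the attribute values as-is
terminals = {t.get("id"): (t.get("word"), t.get("pos"))
             for t in sentence.iter("t")}
```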
stanza/stanza/tests/constituency/test_ensemble.py ADDED
@@ -0,0 +1,110 @@
+ """
+ Add a simple test of the Ensemble's inference path
+
+ This just reuses one model several times - that should still check the main loop, at least
+ """
+
+ import pytest
+
+ from stanza import Pipeline
+ from stanza.models.constituency import text_processing
+ from stanza.models.constituency import tree_reader
+ from stanza.models.constituency.ensemble import Ensemble, EnsembleTrainer
+ from stanza.models.constituency.text_processing import parse_tokenized_sentences
+
+ from stanza.tests import TEST_MODELS_DIR
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+
+ @pytest.fixture(scope="module")
+ def pipeline():
+     return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos, constituency", tokenize_pretokenized=True)
+
+ @pytest.fixture(scope="module")
+ def saved_ensemble(tmp_path_factory, pipeline):
+     tmp_path = tmp_path_factory.mktemp("ensemble")
+
+     # test the ensemble by reusing the same parser multiple times
+     con_processor = pipeline.processors["constituency"]
+     model = con_processor._model
+     args = dict(model.args)
+     foundation_cache = pipeline.foundation_cache
+
+     model_path = con_processor._config['model_path']
+     # reuse the same model 3 times just to make sure the code paths are working
+     filenames = [model_path, model_path, model_path]
+
+     ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache)
+     save_path = tmp_path / "ensemble.pt"
+
+     ensemble.save(save_path)
+     return ensemble, save_path, args, foundation_cache
+
+ def check_basic_predictions(trees):
+     predictions = [x.predictions for x in trees]
+     assert len(predictions) == 2
+     assert all(len(x) == 1 for x in predictions)
+     trees = [x[0].tree for x in predictions]
+     result = ["{}".format(tree) for tree in trees]
+     expected = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
+                 "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
+     assert result == expected
+
+ def test_ensemble_inference(pipeline):
+     # test the ensemble by reusing the same parser multiple times
+     con_processor = pipeline.processors["constituency"]
+     model = con_processor._model
+     args = dict(model.args)
+     foundation_cache = pipeline.foundation_cache
+
+     model_path = con_processor._config['model_path']
+     # reuse the same model 3 times just to make sure the code paths are working
+     filenames = [model_path, model_path, model_path]
+
+     ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache)
+     ensemble = ensemble.model
+     sentences = [["This", "is", "a", "test"], ["This", "is", "another", "test"]]
+     trees = parse_tokenized_sentences(args, ensemble, [pipeline], sentences)
+     check_basic_predictions(trees)
+
+ def test_ensemble_save(saved_ensemble):
+     """
+     Depending on the saved_ensemble fixture should be enough to ensure
+     that the ensemble was correctly saved
+
+     (loading is tested separately)
+     """
+
+ def test_ensemble_save_load(pipeline, saved_ensemble):
+     _, save_path, args, foundation_cache = saved_ensemble
+     ensemble = EnsembleTrainer.load(save_path, args, foundation_cache=foundation_cache)
+     sentences = [["This", "is", "a", "test"], ["This", "is", "another", "test"]]
+     trees = parse_tokenized_sentences(args, ensemble.model, [pipeline], sentences)
+     check_basic_predictions(trees)
+
+ def test_parse_text(tmp_path, pipeline, saved_ensemble):
+     _, model_path, args, foundation_cache = saved_ensemble
+
+     raw_file = str(tmp_path / "test_input.txt")
+     with open(raw_file, "w") as fout:
+         fout.write("This is a test\nThis is another test\n")
+     output_file = str(tmp_path / "test_output.txt")
+
+     args = dict(args)
+     args['tokenized_file'] = raw_file
+     args['predict_file'] = output_file
+
+     text_processing.load_model_parse_text(args, model_path, [pipeline])
+     trees = tree_reader.read_treebank(output_file)
+     trees = ["{}".format(x) for x in trees]
+     expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
+                       "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
+     assert trees == expected_trees
+
+ def test_pipeline(saved_ensemble):
+     _, model_path, _, foundation_cache = saved_ensemble
+     nlp = Pipeline("en", processors="tokenize,pos,constituency", constituency_model_path=str(model_path), foundation_cache=foundation_cache, download_method=None)
+     doc = nlp("This is a test")
+     tree = "{}".format(doc.sentences[0].constituency)
+     assert tree == "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))"
stanza/stanza/tests/constituency/test_in_order_compound_oracle.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import pytest
+
+ from stanza.models.constituency import in_order_compound_oracle
+ from stanza.models.constituency import tree_reader
+ from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme
+ from stanza.models.constituency.transition_sequence import build_treebank
+
+ from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ # A sample tree from PTB with a triple unary transition (at a location other than root)
+ # Here we test the incorrect closing of various brackets
+ TRIPLE_UNARY_START_TREE = """
+ ( (S
+     (PRN
+       (S
+         (NP-SBJ (-NONE- *) )
+         (VP (VB See) )))
+     (, ,)
+     (NP-SBJ
+       (NP (DT the) (JJ other) (NN rule) )
+       (PP (IN of)
+         (NP (NN thumb) ))
+       (PP (IN about)
+         (NP (NN ballooning) )))))
+ """
+
+ TREES = [TRIPLE_UNARY_START_TREE]
+ TREEBANK = "\n".join(TREES)
+
+ ROOT_LABELS = ["ROOT"]
+
+ @pytest.fixture(scope="module")
+ def trees():
+     trees = tree_reader.read_trees(TREEBANK)
+     trees = [t.prune_none().simplify_labels() for t in trees]
+     assert len(trees) == len(TREES)
+
+     return trees
+
+ @pytest.fixture(scope="module")
+ def gold_sequences(trees):
+     gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)
+     return gold_sequences
+
+ def get_repairs(gold_sequence, wrong_transition, repair_fn):
+     """
+     Use the repair function and the wrong transition to iterate over the gold sequence
+
+     Returns a list of possible repairs, one for each position in the sequence
+     Repairs are tuples, (idx, seq)
+     """
+     repairs = [(idx, repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None))
+                for idx, gold_transition in enumerate(gold_sequence)]
+     repairs = [x for x in repairs if x[1] is not None]
+     return repairs
+
+ def test_fix_shift_close():
+     trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE)
+     trees = [t.prune_none().simplify_labels() for t in trees]
+     assert len(trees) == 1
+     tree = trees[0]
+
+     gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)
+
+     # there are three places in this tree where a long bracket (more than 2 subtrees)
+     # could theoretically be closed and then reopened
+     repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_shift_close_error)
+     assert len(repairs) == 3
+
+     expected_trees = ["(ROOT (S (S (PRN (S (VP (VB See)))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
+                       "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other)) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
+                       "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)))) (PP (IN about) (NP (NN ballooning))))))"]
+
+     for repair, expected in zip(repairs, expected_trees):
+         repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
+         assert str(repaired_tree) == expected
+
+ def test_fix_open_close():
+     trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE)
+     trees = [t.prune_none().simplify_labels() for t in trees]
+     assert len(trees) == 1
+     tree = trees[0]
+
+     gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)
+
+     repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_open_close_error)
+     print("------------------")
+     for repair in repairs:
+         print(repair)
+         repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
+         print("{:P}".format(repaired_tree))
stanza/stanza/tests/constituency/test_parse_transitions.py ADDED
@@ -0,0 +1,486 @@
+ import pytest
+
+ from stanza.models.constituency import parse_transitions
+ from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT
+ from stanza.models.constituency.parse_transitions import TransitionScheme, Shift, CloseConstituent, OpenConstituent
+ from stanza.tests import *
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+
+ def build_initial_state(model, num_states=1):
+     words = ["Unban", "Mox", "Opal"]
+     tags = ["VB", "NNP", "NNP"]
+     sentences = [list(zip(words, tags)) for _ in range(num_states)]
+
+     states = model.initial_state_from_words(sentences)
+     assert len(states) == num_states
+     assert all(state.num_transitions == 0 for state in states)
+     return states
+
+ def test_initial_state(model=None):
+     if model is None:
+         model = SimpleModel()
+     states = build_initial_state(model)
+     assert len(states) == 1
+     state = states[0]
+
+     assert state.sentence_length == 3
+     assert state.num_opens == 0
+     # each stack has a sentinel value at the end
+     assert len(state.word_queue) == 5
+     assert len(state.constituents) == 1
+     assert len(state.transitions) == 1
+     assert state.word_position == 0
+
+ def test_shift(model=None):
+     if model is None:
+         model = SimpleModel()
+     state = build_initial_state(model)[0]
+
+     open_transition = parse_transitions.OpenConstituent("ROOT")
+     state = open_transition.apply(state, model)
+     open_transition = parse_transitions.OpenConstituent("S")
+     state = open_transition.apply(state, model)
+     shift = parse_transitions.Shift()
+     assert shift.is_legal(state, model)
+     assert len(state.word_queue) == 5
+     assert state.word_position == 0
+
+     state = shift.apply(state, model)
+     assert len(state.word_queue) == 5
+     # 4 because of the dummy created by the opens
+     assert len(state.constituents) == 4
+     assert len(state.transitions) == 4
+     assert shift.is_legal(state, model)
+     assert state.word_position == 1
+     assert not state.empty_word_queue()
+
+     state = shift.apply(state, model)
+     assert len(state.word_queue) == 5
+     assert len(state.constituents) == 5
+     assert len(state.transitions) == 5
+     assert shift.is_legal(state, model)
+     assert state.word_position == 2
+     assert not state.empty_word_queue()
+
+     state = shift.apply(state, model)
+     assert len(state.word_queue) == 5
+     assert len(state.constituents) == 6
+     assert len(state.transitions) == 6
+     assert not shift.is_legal(state, model)
+     assert state.word_position == 3
+     assert state.empty_word_queue()
+
+     constituents = state.constituents
+     assert model.get_top_constituent(constituents).children[0].label == 'Opal'
+     constituents = constituents.pop()
+     assert model.get_top_constituent(constituents).children[0].label == 'Mox'
+     constituents = constituents.pop()
+     assert model.get_top_constituent(constituents).children[0].label == 'Unban'
+
+ def test_initial_unary(model=None):
+     # it doesn't make sense to start with a CompoundUnary
+     if model is None:
+         model = SimpleModel()
+
+     state = build_initial_state(model)[0]
+     unary = parse_transitions.CompoundUnary('ROOT', 'VP')
+     assert unary.label == ('ROOT', 'VP',)
+     assert not unary.is_legal(state, model)
+     unary = parse_transitions.CompoundUnary('VP')
+     assert unary.label == ('VP',)
+     assert not unary.is_legal(state, model)
+
+
+ def test_unary(model=None):
+     if model is None:
+         model = SimpleModel()
+     state = build_initial_state(model)[0]
+
+     shift = parse_transitions.Shift()
+     state = shift.apply(state, model)
+
+     # this is technically the wrong parse but we're being lazy
+     unary = parse_transitions.CompoundUnary('S', 'VP')
+     assert unary.is_legal(state, model)
+     state = unary.apply(state, model)
+     assert not unary.is_legal(state, model)
+
+     tree = model.get_top_constituent(state.constituents)
+     assert tree.label == 'S'
+     assert len(tree.children) == 1
+     tree = tree.children[0]
+     assert tree.label == 'VP'
+     assert len(tree.children) == 1
+     tree = tree.children[0]
+     assert tree.label == 'VB'
+     assert tree.is_preterminal()
+
+ def test_unary_requires_root(model=None):
+     if model is None:
+         model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY)
+     state = build_initial_state(model)[0]
+
+     open_transition = parse_transitions.OpenConstituent("S")
+     assert open_transition.is_legal(state, model)
+     state = open_transition.apply(state, model)
+
+     shift = parse_transitions.Shift()
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert not shift.is_legal(state, model)
+
+     close_transition = parse_transitions.CloseConstituent()
+     assert close_transition.is_legal(state, model)
+     state = close_transition.apply(state, model)
+     assert not open_transition.is_legal(state, model)
+     assert not close_transition.is_legal(state, model)
+
+     np_unary = parse_transitions.CompoundUnary("NP")
+     assert not np_unary.is_legal(state, model)
+     root_unary = parse_transitions.CompoundUnary("ROOT")
+     assert root_unary.is_legal(state, model)
+     assert not state.finished(model)
+     state = root_unary.apply(state, model)
+     assert not root_unary.is_legal(state, model)
+
+     assert state.finished(model)
+
+ def test_open(model=None):
+     if model is None:
+         model = SimpleModel()
+     state = build_initial_state(model)[0]
+
+     shift = parse_transitions.Shift()
+     state = shift.apply(state, model)
+     state = shift.apply(state, model)
+     assert state.num_opens == 0
+
+     open_transition = parse_transitions.OpenConstituent("VP")
+     assert open_transition.is_legal(state, model)
+     state = open_transition.apply(state, model)
+     assert open_transition.is_legal(state, model)
+     assert state.num_opens == 1
+
+     # check that it is illegal if there are too many opens already
+     for i in range(20):
+         state = open_transition.apply(state, model)
+     assert not open_transition.is_legal(state, model)
+     assert state.num_opens == 21
+
+     # check that it is illegal if the state is out of words
+     state = build_initial_state(model)[0]
+     state = shift.apply(state, model)
+     state = shift.apply(state, model)
+     state = shift.apply(state, model)
+     assert not open_transition.is_legal(state, model)
+
+ def test_compound_open(model=None):
+     if model is None:
+         model = SimpleModel()
+     state = build_initial_state(model)[0]
+
+     open_transition = parse_transitions.OpenConstituent("ROOT", "S")
+     assert open_transition.is_legal(state, model)
+     shift = parse_transitions.Shift()
+     close_transition = parse_transitions.CloseConstituent()
+
+     state = open_transition.apply(state, model)
+     state = shift.apply(state, model)
+     state = shift.apply(state, model)
+     state = shift.apply(state, model)
+     state = close_transition.apply(state, model)
+
+     tree = model.get_top_constituent(state.constituents)
+     assert tree.label == 'ROOT'
+     assert len(tree.children) == 1
+     tree = tree.children[0]
+     assert tree.label == 'S'
+     assert len(tree.children) == 3
+     assert tree.children[0].children[0].label == 'Unban'
+     assert tree.children[1].children[0].label == 'Mox'
+     assert tree.children[2].children[0].label == 'Opal'
+
+ def test_in_order_open(model=None):
+     if model is None:
+         model = SimpleModel(TransitionScheme.IN_ORDER)
+     state = build_initial_state(model)[0]
+
+     shift = parse_transitions.Shift()
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert not shift.is_legal(state, model)
+
+     open_vp = parse_transitions.OpenConstituent("VP")
+     assert open_vp.is_legal(state, model)
+     state = open_vp.apply(state, model)
+     assert not open_vp.is_legal(state, model)
+
+     close_trans = parse_transitions.CloseConstituent()
+     assert close_trans.is_legal(state, model)
+     state = close_trans.apply(state, model)
+
+     open_s = parse_transitions.OpenConstituent("S")
+     assert open_s.is_legal(state, model)
+     state = open_s.apply(state, model)
+     assert not open_vp.is_legal(state, model)
+
+     # check that root transitions won't happen in the middle of a parse
+     open_root = parse_transitions.OpenConstituent("ROOT")
+     assert not open_root.is_legal(state, model)
+
+     # build (NP (NNP Mox) (NNP Opal))
+     open_np = parse_transitions.OpenConstituent("NP")
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert open_np.is_legal(state, model)
+     # make sure root can't happen in places where an arbitrary open is legal
+     assert not open_root.is_legal(state, model)
+     state = open_np.apply(state, model)
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert close_trans.is_legal(state, model)
+     state = close_trans.apply(state, model)
+
+     assert close_trans.is_legal(state, model)
+     state = close_trans.apply(state, model)
+
+     assert open_root.is_legal(state, model)
+     state = open_root.apply(state, model)
+
+ def test_too_many_unaries_close():
+     """
+     This tests rejecting Close at the start of a sequence after too many unary transitions
+
+     The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence
+     """
+     model = SimpleModel(TransitionScheme.IN_ORDER)
+     state = build_initial_state(model)[0]
+
+     shift = parse_transitions.Shift()
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+
+     open_np = parse_transitions.OpenConstituent("NP")
+     close_trans = parse_transitions.CloseConstituent()
+     for _ in range(UNARY_LIMIT):
+         assert open_np.is_legal(state, model)
+         state = open_np.apply(state, model)
+
+         assert close_trans.is_legal(state, model)
+         state = close_trans.apply(state, model)
+
+     assert open_np.is_legal(state, model)
+     state = open_np.apply(state, model)
+     assert not close_trans.is_legal(state, model)
+
+ def test_too_many_unaries_open():
+     """
+     This tests rejecting Open in the middle of a sequence after too many unary transitions
+
+     The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence
+     """
+     model = SimpleModel(TransitionScheme.IN_ORDER)
+     state = build_initial_state(model)[0]
+
+     shift = parse_transitions.Shift()
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+
+     open_np = parse_transitions.OpenConstituent("NP")
+     close_trans = parse_transitions.CloseConstituent()
+
+     assert open_np.is_legal(state, model)
+     state = open_np.apply(state, model)
+     assert not open_np.is_legal(state, model)
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+
+     for _ in range(UNARY_LIMIT):
+         assert open_np.is_legal(state, model)
+         state = open_np.apply(state, model)
+
+         assert close_trans.is_legal(state, model)
+         state = close_trans.apply(state, model)
+
+     assert not open_np.is_legal(state, model)
+
+ def test_close(model=None):
+     if model is None:
+         model = SimpleModel()
+
+     # this one actually tests an entire subtree building
+     state = build_initial_state(model)[0]
+
+     open_transition_vp = parse_transitions.OpenConstituent("VP")
+     assert open_transition_vp.is_legal(state, model)
+     state = open_transition_vp.apply(state, model)
+     assert state.num_opens == 1
+
+     shift = parse_transitions.Shift()
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+
+     open_transition_np = parse_transitions.OpenConstituent("NP")
+     assert open_transition_np.is_legal(state, model)
+     state = open_transition_np.apply(state, model)
+     assert state.num_opens == 2
+
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert shift.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert not shift.is_legal(state, model)
+     assert state.num_opens == 2
+     # now should have "mox", "opal" on the constituents
+
+     close_transition = parse_transitions.CloseConstituent()
+     assert close_transition.is_legal(state, model)
+     state = close_transition.apply(state, model)
+     assert state.num_opens == 1
+     assert close_transition.is_legal(state, model)
+     state = close_transition.apply(state, model)
+     assert state.num_opens == 0
+     assert not close_transition.is_legal(state, model)
+
+     tree = model.get_top_constituent(state.constituents)
+     assert tree.label == 'VP'
+     assert len(tree.children) == 2
+     tree = tree.children[1]
+     assert tree.label == 'NP'
+     assert len(tree.children) == 2
+     assert tree.children[0].is_preterminal()
+     assert tree.children[1].is_preterminal()
+     assert tree.children[0].children[0].label == 'Mox'
+     assert tree.children[1].children[0].label == 'Opal'
+
+     # extra one for None at the start of the TreeStack
+     assert len(state.constituents) == 2
+
+     assert state.all_transitions(model) == [open_transition_vp, shift, open_transition_np, shift, shift, close_transition, close_transition]
+
+ def test_in_order_compound_finalize(model=None):
+     """
+     Test the Finalize transition is only legal at the end of a sequence
+     """
+     if model is None:
+         model = SimpleModel(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
+
+     state = build_initial_state(model)[0]
+
+     finalize = parse_transitions.Finalize("ROOT")
+
+     shift = parse_transitions.Shift()
+     assert shift.is_legal(state, model)
+     assert not finalize.is_legal(state, model)
+     state = shift.apply(state, model)
+
+     open_transition = parse_transitions.OpenConstituent("NP")
+     assert open_transition.is_legal(state, model)
+     assert not finalize.is_legal(state, model)
+     state = open_transition.apply(state, model)
+     assert state.num_opens == 1
+
+     assert shift.is_legal(state, model)
+     assert not finalize.is_legal(state, model)
+     state = shift.apply(state, model)
+     assert shift.is_legal(state, model)
+     assert not finalize.is_legal(state, model)
+     state = shift.apply(state, model)
+
+     close_transition = parse_transitions.CloseConstituent()
+     assert close_transition.is_legal(state, model)
+     state = close_transition.apply(state, model)
+     assert state.num_opens == 0
+     assert not close_transition.is_legal(state, model)
+     assert finalize.is_legal(state, model)
+
+     state = finalize.apply(state, model)
+     assert not finalize.is_legal(state, model)
+     tree = model.get_top_constituent(state.constituents)
+     assert tree.label == 'ROOT'
+
+ def test_hashes():
+     transitions = set()
+
+     shift = parse_transitions.Shift()
+     assert shift not in transitions
+     transitions.add(shift)
+     assert shift in transitions
+     shift = parse_transitions.Shift()
+     assert shift in transitions
+
+     for i in range(5):
+         transitions.add(shift)
+     assert len(transitions) == 1
+
+     unary = parse_transitions.CompoundUnary("asdf")
+     assert unary not in transitions
+     transitions.add(unary)
+     assert unary in transitions
+
+     unary = parse_transitions.CompoundUnary("asdf", "zzzz")
+     assert unary not in transitions
+     transitions.add(unary)
+     transitions.add(unary)
+     transitions.add(unary)
+     unary = parse_transitions.CompoundUnary("asdf", "zzzz")
+     assert unary in transitions
+
+     oc = parse_transitions.OpenConstituent("asdf")
+     assert oc not in transitions
+     transitions.add(oc)
+     assert oc in transitions
+     transitions.add(oc)
+     transitions.add(oc)
+     assert len(transitions) == 4
+     assert parse_transitions.OpenConstituent("asdf") in transitions
+
+     cc = parse_transitions.CloseConstituent()
+     assert cc not in transitions
+     transitions.add(cc)
+     transitions.add(cc)
+     transitions.add(cc)
+     assert cc in transitions
+     cc = parse_transitions.CloseConstituent()
+     assert cc in transitions
+     assert len(transitions) == 5
+
+
+ def test_sort():
+     expected = []
+
+     expected.append(parse_transitions.Shift())
+     expected.append(parse_transitions.CloseConstituent())
+     expected.append(parse_transitions.CompoundUnary("NP"))
+     expected.append(parse_transitions.CompoundUnary("NP", "VP"))
+     expected.append(parse_transitions.OpenConstituent("mox"))
+     expected.append(parse_transitions.OpenConstituent("opal"))
+     expected.append(parse_transitions.OpenConstituent("unban"))
+
+     transitions = set(expected)
+     transitions = sorted(transitions)
+     assert transitions == expected
+
+ def test_check_transitions():
+     """
+     Test that check_transitions passes or fails a couple simple, small test cases
+     """
+     transitions = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")}
+
+     other = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")}
+     parse_transitions.check_transitions(transitions, other, "test")
+
+     # This will get a pass because it is a unary made out of existing unaries
+     other = {Shift(), CloseConstituent(), OpenConstituent("NP", "VP")}
+     parse_transitions.check_transitions(transitions, other, "test")
+
+     # This should fail
+     with pytest.raises(RuntimeError):
+         other = {Shift(), CloseConstituent(), OpenConstituent("NP", "ZP")}
+         parse_transitions.check_transitions(transitions, other, "test")
stanza/stanza/tests/constituency/test_parse_tree.py ADDED
@@ -0,0 +1,369 @@
1
+ import pytest
2
+
3
+ from stanza.models.constituency.parse_tree import Tree
4
+ from stanza.models.constituency import tree_reader
5
+
6
+ from stanza.tests import *
7
+
8
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
9
+
10
+ def test_leaf_preterminal():
11
+ foo = Tree(label="foo")
12
+ assert foo.is_leaf()
13
+ assert not foo.is_preterminal()
14
+ assert len(foo.children) == 0
15
+ assert str(foo) == 'foo'
16
+
17
+ bar = Tree(label="bar", children=foo)
18
+ assert not bar.is_leaf()
19
+ assert bar.is_preterminal()
20
+ assert len(bar.children) == 1
21
+ assert str(bar) == "(bar foo)"
22
+
23
+ baz = Tree(label="baz", children=[bar])
24
+ assert not baz.is_leaf()
25
+ assert not baz.is_preterminal()
26
+ assert len(baz.children) == 1
27
+ assert str(baz) == "(baz (bar foo))"
28
+
29
+
30
+ def test_yield_preterminals():
31
+ text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
32
+ trees = tree_reader.read_trees(text)
33
+
34
+ preterminals = list(trees[0].yield_preterminals())
35
+ assert len(preterminals) == 3
36
+ assert str(preterminals) == "[(VB Unban), (NNP Mox), (NNP Opal)]"
37
+
38
+ def test_depth():
39
+ text = "(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
40
+ trees = tree_reader.read_trees(text)
41
+ assert trees[0].depth() == 0
42
+ assert trees[1].depth() == 4
43
+
44
+ def test_unique_labels():
45
+ """
46
+ Test getting the unique labels from a tree
47
+
48
+ Assumes tree_reader works, which should be fine since it is tested elsewhere
49
+ """
50
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
51
+
52
+ trees = tree_reader.read_trees(text)
53
+
54
+ labels = Tree.get_unique_constituent_labels(trees)
55
+ expected = ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP']
56
+ assert labels == expected
57
+
58
+ def test_unique_tags():
59
+ """
60
+ Test getting the unique tags from a tree
61
+ """
62
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
63
+
64
+ trees = tree_reader.read_trees(text)
65
+
66
+ tags = Tree.get_unique_tags(trees)
67
+ expected = ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP']
68
+ assert tags == expected
69
+
70
+
71
+ def test_unique_words():
72
+ """
73
+ Test getting the unique words from a tree
74
+ """
75
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
76
+
77
+ trees = tree_reader.read_trees(text)
78
+
79
+ words = Tree.get_unique_words(trees)
80
+ expected = ['?', 'Who', 'in', 'seat', 'sits', 'this']
81
+ assert words == expected
82
+
83
+ def test_rare_words():
84
+ """
85
+ Test getting the unique words from a tree
86
+ """
87
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
88
+
89
+ trees = tree_reader.read_trees(text)
90
+
91
+ words = Tree.get_rare_words(trees, 0.5)
92
+ expected = ['Who', 'in', 'sits']
93
+ assert words == expected
94
+
95
+ def test_common_words():
96
+ """
97
+ Test getting the unique words from a tree
98
+ """
99
+ text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"
100
+
101
+ trees = tree_reader.read_trees(text)
102
+
103
+ words = Tree.get_common_words(trees, 3)
104
+ expected = ['?', 'seat', 'this']
105
+ assert words == expected
106
+
107
+ def test_root_labels():
108
+ text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
109
+ trees = tree_reader.read_trees(text)
110
+ assert ["ROOT"] == Tree.get_root_labels(trees)
111
+
112
+ text=("( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" +
113
+ "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" +
114
+ "( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))")
115
+ trees = tree_reader.read_trees(text)
116
+ assert ["ROOT"] == Tree.get_root_labels(trees)
117
+
118
+ text="(FOO) (BAR)"
119
+ trees = tree_reader.read_trees(text)
120
+ assert ["BAR", "FOO"] == Tree.get_root_labels(trees)
121
+
122
+ def test_prune_none():
123
+ text=["((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (-NONE- in) (NP (DT this) (NN seat))))) (. ?)))", # test one dead node
124
+ "((SBARQ (WHNP (-NONE- Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))", # test recursive dead nodes
125
+ "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (-NONE- this) (-NONE- seat))))) (. ?)))"] # test all children dead
126
+ expected=["(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (NP (DT this) (NN seat))))) (. ?)))",
127
+ "(ROOT (SBARQ (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))",
128
+ "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"]
129
+
130
+ for t, e in zip(text, expected):
131
+ trees = tree_reader.read_trees(t)
132
+ assert len(trees) == 1
133
+ tree = trees[0].prune_none()
134
+ assert e == str(tree)
135
+
136
+ def test_simplify_labels():
137
+ text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))"
138
+ expected = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))"
139
+ trees = tree_reader.read_trees(text)
140
+ trees = [t.simplify_labels() for t in trees]
141
+ assert len(trees) == 1
142
+ assert expected == str(trees[0])
143
+
144
+ def test_remap_constituent_labels():
145
+ text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
146
+ expected="(ROOT (FOO (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
147
+
148
+ label_map = { "SBARQ": "FOO" }
149
+ trees = tree_reader.read_trees(text)
150
+ trees = [t.remap_constituent_labels(label_map) for t in trees]
151
+ assert len(trees) == 1
152
+ assert expected == str(trees[0])
153
+
154
+ def test_remap_constituent_words():
155
+ text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
156
+ expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))"
157
+
158
+ word_map = { "Who": "unban", "sits": "mox", "in": "opal" }
159
+ trees = tree_reader.read_trees(text)
160
+ trees = [t.remap_words(word_map) for t in trees]
161
+ assert len(trees) == 1
162
+ assert expected == str(trees[0])
163
+
164
+ def test_replace_words():
165
+ text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
166
+ expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))"
167
+ new_words = ["unban", "mox", "opal", "?"]
168
+
169
+ trees = tree_reader.read_trees(text)
170
+ assert len(trees) == 1
171
+ tree = trees[0]
172
+ new_tree = tree.replace_words(new_words)
173
+ assert expected == str(new_tree)
174
+
175
+
176
+ def test_compound_constituents():
177
+ # TODO: add skinny trees like this to the various transition tests
178
+ text="((VP (VB Unban)))"
179
+ trees = tree_reader.read_trees(text)
180
+ assert Tree.get_compound_constituents(trees) == [('ROOT', 'VP')]
181
+
182
+ text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
183
+ trees = tree_reader.read_trees(text)
184
+ assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP',)]
185
+
186
+ text="((VP (VB Unban))) (ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
187
+ trees = tree_reader.read_trees(text)
188
+ assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP',)]
+
+ def test_equals():
+ """
+ Check one tree from the actual dataset for ==
+
+ when built with compound Open, this didn't work because of a silly bug
+ """
+ text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))"
+
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ tree = trees[0]
+
+ assert tree == tree
+
+ trees2 = tree_reader.read_trees(text)
+ tree2 = trees2[0]
+
+ assert tree is not tree2
+ assert tree == tree2
+
+
+ # This tree was causing the model to barf on CTB7,
+ # although it turns out the problem was just the
+ # depth of the unary, not the list
+ CHINESE_LONG_LIST_TREE = """
+ (ROOT
+ (IP
+ (NP (NNP 证券法))
+ (VP
+ (PP
+ (IN 对)
+ (NP
+ (DNP
+ (NP
+ (NP (NNP 中国))
+ (NP
+ (NN 证券)
+ (NN 市场)))
+ (DEC 的))
+ (NP (NN 运作))))
+ (, ,)
+ (PP
+ (PP
+ (IN 从)
+ (NP
+ (NP (NN 股票))
+ (NP (VV 发行) (EC 、) (VV 交易))))
+ (, ,)
+ (PP
+ (VV 到)
+ (NP
+ (NP (NN 上市) (NN 公司) (NN 收购))
+ (EC 、)
+ (NP (NN 证券) (NN 交易所))
+ (EC 、)
+ (NP (NN 证券) (NN 公司))
+ (EC 、)
+ (NP (NN 登记) (NN 结算) (NN 机构))
+ (EC 、)
+ (NP (NN 交易) (NN 服务) (NN 机构))
+ (EC 、)
+ (NP (NN 证券业) (NN 协会))
+ (EC 、)
+ (NP (NN 证券) (NN 监督) (NN 管理) (NN 机构))
+ (CC 和)
+ (NP
+ (DNP
+ (NP (CP (CP (IP (VP (VV 违法))))))
+ (DEC 的))
+ (NP (NN 法律) (NN 责任))))))
+ (ADVP (RB 都))
+ (VP
+ (VV 作)
+ (AS 了)
+ (NP
+ (ADJP (JJ 详细))
+ (NP (NN 规定)))))
+ (. 。)))
+ """
+
+ WEIRD_UNARY = """
+ (DNP
+ (NP (CP (CP (IP (VP (ASDF
+ (NP (NN 上市) (NN 公司) (NN 收购))
+ (EC 、)
+ (NP (NN 证券) (NN 交易所))
+ (EC 、)
+ (NP (NN 证券) (NN 公司))
+ (EC 、)
+ (NP (NN 登记) (NN 结算) (NN 机构))
+ (EC 、)
+ (NP (NN 交易) (NN 服务) (NN 机构))
+ (EC 、)
+ (NP (NN 证券业) (NN 协会))
+ (EC 、)
+ (NP (NN 证券) (NN 监督) (NN 管理) (NN 机构))))))))
+ (DEC 的))
+ """
+
+
+ def test_count_unaries():
+ trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)
+ assert len(trees) == 1
+ assert trees[0].count_unary_depth() == 5
+
+ trees = tree_reader.read_trees(WEIRD_UNARY)
+ assert len(trees) == 1
+ assert trees[0].count_unary_depth() == 5
+
+ def test_str_bracket_labels():
+ text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
+ expected = "(_ROOT (_S (_VP (_VB Unban )_VB )_VP (_NP (_NNP Mox )_NNP (_NNP Opal )_NNP )_NP )_S )_ROOT"
+
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ assert "{:L}".format(trees[0]) == expected
+
+ def test_all_leaves_are_preterminals():
+ text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ assert trees[0].all_leaves_are_preterminals()
+
+ text = "((S (VP (VB Unban)) (NP (Mox) (NNP Opal))))"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ assert not trees[0].all_leaves_are_preterminals()
+
+ def test_latex():
+ """
+ Test the latex format for trees
+ """
+ expected = "\\Tree [.S [.NP Jennifer ] [.VP has [.NP nice antennae ] ] ]"
+ tree = "(ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ nice) (NNS antennae)))))"
+ tree = tree_reader.read_trees(tree)[0]
+ text = "{:T}".format(tree)
+ assert text == expected
+
+ def test_pretty_print():
+ """
+ Pretty print a couple trees - newlines & indentation
+ """
+ text = "(ROOT (S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal)))) (ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric)))))))"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 2
+
+ expected = """(ROOT
+ (S
+ (VP (VB Unban))
+ (NP (NNP Mox) (NNP Opal))))
+ """
+
+ assert "{:P}".format(trees[0]) == expected
+
+ expected = """(ROOT
+ (S
+ (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission))
+ (VP
+ (VBD authorized)
+ (NP
+ (NP
+ (DT an)
+ (ADJP (CD 11.5))
+ (NN %)
+ (NN rate)
+ (NN increase))
+ (PP
+ (IN at)
+ (NP (NNP Tucson) (NNP Electric)))))))
+ """
+ assert "{:P}".format(trees[1]) == expected
+
+ assert text == "{:O} {:O}".format(*trees)
+
+ def test_reverse():
+ text = "(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB lick) (NP (NP (NNP Jennifer) (POS 's)) (NNS antennae))))))))"
+ trees = tree_reader.read_trees(text)
+ assert len(trees) == 1
+ reversed_tree = trees[0].reverse()
+ assert str(reversed_tree) == "(ROOT (S (VP (S (VP (VP (NP (NNS antennae) (NP (POS 's) (NNP Jennifer))) (VB lick)) (TO to))) (VBP want)) (NP (PRP I))))"
stanza/stanza/tests/constituency/test_positional_encoding.py ADDED
@@ -0,0 +1,45 @@
+ import pytest
+
+ import torch
+
+ from stanza import Pipeline
+ from stanza.models.constituency.positional_encoding import SinusoidalEncoding, AddSinusoidalEncoding
+
+ from stanza.tests import *
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+
+ def test_positional_encoding():
+ encoding = SinusoidalEncoding(model_dim=10, max_len=6)
+ foo = encoding(torch.tensor([5]))
+ assert foo.shape == (1, 10)
+ # TODO: check the values
+
+ def test_resize():
+ encoding = SinusoidalEncoding(model_dim=10, max_len=3)
+ foo = encoding(torch.tensor([5]))
+ assert foo.shape == (1, 10)
+
+
+ def test_arange():
+ encoding = SinusoidalEncoding(model_dim=10, max_len=2)
+ foo = encoding(torch.arange(4))
+ assert foo.shape == (4, 10)
+ assert encoding.max_len() == 4
+
+ def test_add():
+ encoding = AddSinusoidalEncoding(d_model=10, max_len=4)
+ x = torch.zeros(1, 4, 10)
+ y = encoding(x)
+
+ r = torch.randn(1, 4, 10)
+ r2 = encoding(r)
+
+ assert torch.allclose(r2 - r, y, atol=1e-07)
+
+ r = torch.randn(2, 4, 10)
+ r2 = encoding(r)
+
+ assert torch.allclose(r2[0] - r[0], y, atol=1e-07)
+ assert torch.allclose(r2[1] - r[1], y, atol=1e-07)
stanza/stanza/tests/constituency/test_selftrain_vi_quad.py ADDED
@@ -0,0 +1,23 @@
+ """
+ Test some of the methods in the vi_quad dataset
+
+ Uses a small section of the dataset as a test
+ """
+
+ import pytest
+
+ from stanza.utils.datasets.constituency import selftrain_vi_quad
+
+ from stanza.tests import *
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ SAMPLE_TEXT = """
+ {"version": "1.1", "data": [{"title": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng", "paragraphs": [{"qas": [{"question": "T\u00ean g\u1ecdi n\u00e0o \u0111\u01b0\u1ee3c Ph\u1ea1m V\u0103n \u0110\u1ed3ng s\u1eed d\u1ee5ng khi l\u00e0m Ph\u00f3 ch\u1ee7 nhi\u1ec7m c\u01a1 quan Bi\u1ec7n s\u1ef1 x\u1ee9 t\u1ea1i Qu\u1ebf L\u00e2m?", "answers": [{"answer_start": 507, "text": "L\u00e2m B\u00e1 Ki\u1ec7t"}], "id": "uit_01__05272_0_1"}, {"question": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng gi\u1eef ch\u1ee9c v\u1ee5 g\u00ec trong b\u1ed9 m\u00e1y Nh\u00e0 n\u01b0\u1edbc C\u1ed9ng h\u00f2a X\u00e3 h\u1ed9i ch\u1ee7 ngh\u0129a Vi\u1ec7t Nam?", "answers": [{"answer_start": 60, "text": "Th\u1ee7 t\u01b0\u1edbng"}], "id": "uit_01__05272_0_2"}, {"question": "Giai \u0111o\u1ea1n n\u0103m 1955-1976, Ph\u1ea1m V\u0103n \u0110\u1ed3ng n\u1eafm gi\u1eef ch\u1ee9c v\u1ee5 g\u00ec?", "answers": [{"answer_start": 245, "text": "Th\u1ee7 t\u01b0\u1edbng Ch\u00ednh ph\u1ee7 Vi\u1ec7t Nam D\u00e2n ch\u1ee7 C\u1ed9ng h\u00f2a"}], "id": "uit_01__05272_0_3"}], "context": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng (1 th\u00e1ng 3 n\u0103m 1906 \u2013 29 th\u00e1ng 4 n\u0103m 2000) l\u00e0 Th\u1ee7 t\u01b0\u1edbng \u0111\u1ea7u ti\u00ean c\u1ee7a n\u01b0\u1edbc C\u1ed9ng h\u00f2a X\u00e3 h\u1ed9i ch\u1ee7 ngh\u0129a Vi\u1ec7t Nam t\u1eeb n\u0103m 1976 (t\u1eeb n\u0103m 1981 g\u1ecdi l\u00e0 Ch\u1ee7 t\u1ecbch H\u1ed9i \u0111\u1ed3ng B\u1ed9 tr\u01b0\u1edfng) cho \u0111\u1ebfn khi ngh\u1ec9 h\u01b0u n\u0103m 1987. Tr\u01b0\u1edbc \u0111\u00f3 \u00f4ng t\u1eebng gi\u1eef ch\u1ee9c v\u1ee5 Th\u1ee7 t\u01b0\u1edbng Ch\u00ednh ph\u1ee7 Vi\u1ec7t Nam D\u00e2n ch\u1ee7 C\u1ed9ng h\u00f2a t\u1eeb n\u0103m 1955 \u0111\u1ebfn n\u0103m 1976. \u00d4ng l\u00e0 v\u1ecb Th\u1ee7 t\u01b0\u1edbng Vi\u1ec7t Nam t\u1ea1i v\u1ecb l\u00e2u nh\u1ea5t (1955\u20131987). \u00d4ng l\u00e0 h\u1ecdc tr\u00f2, c\u1ed9ng s\u1ef1 c\u1ee7a Ch\u1ee7 t\u1ecbch H\u1ed3 Ch\u00ed Minh. \u00d4ng c\u00f3 t\u00ean g\u1ecdi th\u00e2n m\u1eadt l\u00e0 T\u00f4, \u0111\u00e2y t\u1eebng l\u00e0 b\u00ed danh c\u1ee7a \u00f4ng. \u00d4ng c\u00f2n c\u00f3 t\u00ean g\u1ecdi l\u00e0 L\u00e2m B\u00e1 Ki\u1ec7t khi l\u00e0m Ph\u00f3 ch\u1ee7 nhi\u1ec7m c\u01a1 quan Bi\u1ec7n s\u1ef1 x\u1ee9 t\u1ea1i Qu\u1ebf L\u00e2m (Ch\u1ee7 nhi\u1ec7m l\u00e0 H\u1ed3 H\u1ecdc L\u00e3m)."}, {"qas": [{"question": "S\u1ef1 ki\u1ec7n quan tr\u1ecdng n\u00e0o \u0111\u00e3 di\u1ec5n ra v\u00e0o ng\u00e0y 20/7/1954?", "answers": [{"answer_start": 364, "text": "b\u1ea3n Hi\u1ec7p \u0111\u1ecbnh \u0111\u00ecnh ch\u1ec9 chi\u1ebfn s\u1ef1 \u1edf Vi\u1ec7t Nam, Campuchia v\u00e0 L\u00e0o \u0111\u00e3 \u0111\u01b0\u1ee3c k\u00fd k\u1ebft th\u1eeba nh\u1eadn t\u00f4n tr\u1ecdng \u0111\u1ed9c l\u1eadp, ch\u1ee7 quy\u1ec1n, c\u1ee7a n\u01b0\u1edbc Vi\u1ec7t Nam, L\u00e0o v\u00e0 Campuchia"}], "id": "uit_01__05272_1_1"}, {"question": "Ch\u1ee9c v\u1ee5 m\u00e0 Ph\u1ea1m V\u0103n \u0110\u1ed3ng \u0111\u1ea3m nhi\u1ec7m t\u1ea1i H\u1ed9i ngh\u1ecb Gen\u00e8ve v\u1ec1 \u0110\u00f4ng D\u01b0\u01a1ng?", "answers": [{"answer_start": 33, "text": "Tr\u01b0\u1edfng ph\u00e1i \u0111o\u00e0n Ch\u00ednh ph\u1ee7"}], "id": "uit_01__05272_1_2"}, {"question": "H\u1ed9i ngh\u1ecb Gen\u00e8ve v\u1ec1 \u0110\u00f4ng D\u01b0\u01a1ng c\u00f3 t\u00ednh ch\u1ea5t nh\u01b0 th\u1ebf n\u00e0o?", "answers": [{"answer_start": 262, "text": "r\u1ea5t c\u0103ng th\u1eb3ng v\u00e0 ph\u1ee9c t\u1ea1p"}], "id": "uit_01__05272_1_3"}]}]}]}
+ """
+
+ EXPECTED = ['Tên gọi nào được Phạm Văn Đồng sử dụng khi làm Phó chủ nhiệm cơ quan Biện sự xứ tại Quế Lâm?', 'Phạm Văn Đồng giữ chức vụ gì trong bộ máy Nhà nước Cộng hòa Xã hội chủ nghĩa Việt Nam?', 'Giai đoạn năm 1955-1976, Phạm Văn Đồng nắm giữ chức vụ gì?', 'Sự kiện quan trọng nào đã diễn ra vào ngày 20/7/1954?', 'Chức vụ mà Phạm Văn Đồng đảm nhiệm tại Hội nghị Genève về Đông Dương?', 'Hội nghị Genève về Đông Dương có tính chất như thế nào?']
+
+ def test_read_file():
+ results = selftrain_vi_quad.parse_quad(SAMPLE_TEXT)
+ assert results == EXPECTED
stanza/stanza/tests/constituency/test_utils.py ADDED
@@ -0,0 +1,68 @@
+ import pytest
+
+ from stanza import Pipeline
+ from stanza.models.constituency import tree_reader
+ from stanza.models.constituency import utils
+
+ from stanza.tests import *
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+
+ @pytest.fixture(scope="module")
+ def pipeline():
+ return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
+
+
+
+ def test_xpos_retag(pipeline):
+ """
+ Test using the English tagger that trees will be correctly retagged by read_trees using xpos
+ """
+ text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))"
+ expected = "((S (VP (VB Find)) (NP (NNP Mox) (NNP Opal)))) ((S (NP (NNP Ragavan)) (VP (VBZ steals) (NP (JJ important) (NNS cards)))))"
+
+ trees = tree_reader.read_trees(text)
+
+ new_trees = utils.retag_trees(trees, [pipeline], xpos=True)
+ assert new_trees == tree_reader.read_trees(expected)
+
+
+
+ def test_upos_retag(pipeline):
+ """
+ Test using the English tagger that trees will be correctly retagged by read_trees using upos
+ """
+ text = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))"
+ expected = "((S (VP (VERB Find)) (NP (PROPN Mox) (PROPN Opal)))) ((S (NP (PROPN Ragavan)) (VP (VERB steals) (NP (ADJ important) (NOUN cards)))))"
+
+ trees = tree_reader.read_trees(text)
+
+ new_trees = utils.retag_trees(trees, [pipeline], xpos=False)
+ assert new_trees == tree_reader.read_trees(expected)
+
+
+ def test_replace_tags():
+ """
+ Test the underlying replace_tags method
+
+ Also tests that the method throws exceptions when it is supposed to
+ """
+ text = "((S (VP (X Find)) (NP (X Mox) (X Opal))))"
+ expected = "((S (VP (A Find)) (NP (B Mox) (C Opal))))"
+
+ trees = tree_reader.read_trees(text)
+
+ new_tags = ["A", "B", "C"]
+ new_tree = trees[0].replace_tags(new_tags)
+
+ assert new_tree == tree_reader.read_trees(expected)[0]
+
+ with pytest.raises(ValueError):
+ new_tags = ["A", "B"]
+ new_tree = trees[0].replace_tags(new_tags)
+
+ with pytest.raises(ValueError):
+ new_tags = ["A", "B", "C", "D"]
+ new_tree = trees[0].replace_tags(new_tags)
+
stanza/stanza/tests/data/example_french.json ADDED
@@ -0,0 +1,22 @@
+ {"sentences":
+ [{"index": 0,
+ "tokens": [
+ {"index": 1, "word": "Cette", "originalText": "Cette", "characterOffsetBegin": 0, "characterOffsetEnd": 5, "pos": "DET", "before": "", "after": " "},
+ {"index": 2, "word": "enquête", "originalText": "enquête", "characterOffsetBegin": 6, "characterOffsetEnd": 13, "pos": "NOUN", "before": " ", "after": " "},
+ {"index": 3, "word": "préliminaire", "originalText": "préliminaire", "characterOffsetBegin": 14, "characterOffsetEnd": 26, "pos": "ADJ", "before": " ", "after": " "},
+ {"index": 4, "word": "fait", "originalText": "fait", "characterOffsetBegin": 27, "characterOffsetEnd": 31, "pos": "VERB", "before": " ", "after": " "},
+ {"index": 5, "word": "suite", "originalText": "suite", "characterOffsetBegin": 32, "characterOffsetEnd": 37, "pos": "NOUN", "before": " ", "after": " "},
+ {"index": 6, "word": "à", "originalText": "à", "characterOffsetBegin": 38, "characterOffsetEnd": 41, "pos": "ADP", "before": " ", "after": " "},
+ {"index": 7, "word": "les", "originalText": "les", "characterOffsetBegin": 38, "characterOffsetEnd": 41, "pos": "DET", "before": " ", "after": " "},
+ {"index": 8, "word": "révélations", "originalText": "révélations", "characterOffsetBegin": 42, "characterOffsetEnd": 53, "pos": "NOUN", "before": " ", "after": " "},
+ {"index": 9, "word": "de", "originalText": "de", "characterOffsetBegin": 54, "characterOffsetEnd": 56, "pos": "ADP", "before": " ", "after": " "},
+ {"index": 10, "word": "l’", "originalText": "l’", "characterOffsetBegin": 57, "characterOffsetEnd": 59, "pos": "NOUN", "before": " ", "after": ""},
+ {"index": 11, "word": "hebdomadaire", "originalText": "hebdomadaire", "characterOffsetBegin": 59, "characterOffsetEnd": 71, "pos": "ADJ", "before": "", "after": " "},
+ {"index": 12, "word": "quelques", "originalText": "quelques", "characterOffsetBegin": 72, "characterOffsetEnd": 80, "pos": "DET", "before": " ", "after": " "},
+ {"index": 13, "word": "jours", "originalText": "jours", "characterOffsetBegin": 81, "characterOffsetEnd": 86, "pos": "NOUN", "before": " ", "after": " "},
+ {"index": 14, "word": "plus", "originalText": "plus", "characterOffsetBegin": 87, "characterOffsetEnd": 91, "pos": "ADV", "before": " ", "after": " "},
+ {"index": 15, "word": "tôt", "originalText": "tôt", "characterOffsetBegin": 92, "characterOffsetEnd": 95, "pos": "ADV", "before": " ", "after": ""},
+ {"index": 16, "word": ".", "originalText": ".", "characterOffsetBegin": 95, "characterOffsetEnd": 96, "pos": "PUNCT", "before": "", "after": ""}
+ ]}
+ ]
+ }
stanza/stanza/tests/data/test.dat ADDED
Binary file (4.24 kB).
 
stanza/stanza/tests/data/tiny_emb.csv ADDED
@@ -0,0 +1,4 @@
+ 3 4
+ unban,1,2,3,4
+ mox,5,6,7,8
+ opal,9,10,11,12
stanza/stanza/tests/datasets/__init__.py ADDED
File without changes
stanza/stanza/tests/datasets/ner/__init__.py ADDED
File without changes
stanza/stanza/tests/datasets/ner/test_prepare_ner_file.py ADDED
@@ -0,0 +1,77 @@
+ """
+ Test some simple conversions of NER bio files
+ """
+
+ import pytest
+
+ import json
+
+ from stanza.models.common.doc import Document
+ from stanza.utils.datasets.ner.prepare_ner_file import process_dataset
+
+ BIO_1 = """
+ Jennifer B-PERSON
+ Sh'reyan I-PERSON
+ has O
+ lovely O
+ antennae O
+ """.strip()
+
+ BIO_2 = """
+ but O
+ I O
+ don't O
+ like O
+ the O
+ way O
+ Jennifer B-PERSON
+ treated O
+ Beckett B-PERSON
+ on O
+ the O
+ Cerritos B-LOCATION
+ """.strip()
+
+ def check_json_file(doc, raw_text, expected_sentences, expected_tokens):
+ raw_sentences = raw_text.strip().split("\n\n")
+ assert len(raw_sentences) == expected_sentences
+ if isinstance(expected_tokens, int):
+ expected_tokens = [expected_tokens]
+ for raw_sentence, expected_len in zip(raw_sentences, expected_tokens):
+ assert len(raw_sentence.strip().split("\n")) == expected_len
+
+ assert len(doc.sentences) == expected_sentences
+ for sentence, expected_len in zip(doc.sentences, expected_tokens):
+ assert len(sentence.tokens) == expected_len
+ for sentence, raw_sentence in zip(doc.sentences, raw_sentences):
+ for token, line in zip(sentence.tokens, raw_sentence.strip().split("\n")):
+ word, tag = line.strip().split()
+ assert token.text == word
+ assert token.ner == tag
+
+ def write_and_convert(tmp_path, raw_text):
+ bio_file = tmp_path / "test.bio"
+ with open(bio_file, "w", encoding="utf-8") as fout:
+ fout.write(raw_text)
+
+ json_file = tmp_path / "json.bio"
+ process_dataset(bio_file, json_file)
+
+ with open(json_file) as fin:
+ doc = Document(json.load(fin))
+
+ return doc
+
+ def run_test(tmp_path, raw_text, expected_sentences, expected_tokens):
+ doc = write_and_convert(tmp_path, raw_text)
+ check_json_file(doc, raw_text, expected_sentences, expected_tokens)
+
+ def test_simple(tmp_path):
+ run_test(tmp_path, BIO_1, 1, 5)
+
+ def test_ner_at_end(tmp_path):
+ run_test(tmp_path, BIO_2, 1, 12)
+
+ def test_two_sentences(tmp_path):
+ raw_text = BIO_1 + "\n\n" + BIO_2
+ run_test(tmp_path, raw_text, 2, [5, 12])
stanza/stanza/tests/datasets/ner/test_utils.py ADDED
@@ -0,0 +1,34 @@
+ """
+ Test the utils file of the NER dataset processing
+ """
+
+ import pytest
+
+ from stanza.utils.datasets.ner.utils import list_doc_entities
+ from stanza.tests.datasets.ner.test_prepare_ner_file import BIO_1, BIO_2, write_and_convert
+
+ def test_list_doc_entities(tmp_path):
+ """
+ Test the function which lists all of the entities in a doc
+ """
+ doc = write_and_convert(tmp_path, BIO_1)
+ entities = list_doc_entities(doc)
+ expected = [(('Jennifer', "Sh'reyan"), 'PERSON')]
+ assert expected == entities
+
+ doc = write_and_convert(tmp_path, BIO_2)
+ entities = list_doc_entities(doc)
+ expected = [(('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')]
+ assert expected == entities
+
+ doc = write_and_convert(tmp_path, "\n\n".join([BIO_1, BIO_2]))
+ entities = list_doc_entities(doc)
+ expected = [(('Jennifer', "Sh'reyan"), 'PERSON'), (('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')]
+ assert expected == entities
+
+ doc = write_and_convert(tmp_path, "\n\n".join([BIO_1, BIO_1, BIO_2]))
+ entities = list_doc_entities(doc)
+ expected = [(('Jennifer', "Sh'reyan"), 'PERSON'), (('Jennifer', "Sh'reyan"), 'PERSON'), (('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')]
+ assert expected == entities
+
+
stanza/stanza/tests/lemma/test_data.py ADDED
@@ -0,0 +1,106 @@
+ """
+ Test a couple basic data functions, such as processing a doc for its lemmas
+ """
+
+ import pytest
+
+ from stanza.models.common.doc import Document
+ from stanza.models.lemma.data import DataLoader
+ from stanza.utils.conll import CoNLL
+
+ TRAIN_DATA = """
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
+ # text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
+ 1	DPA	DPA	PROPN	NNP	Number=Sing	0	root	0:root	SpaceAfter=No
+ 2	:	:	PUNCT	:	_	1	punct	1:punct	_
+ 3	Iraqi	Iraqi	ADJ	JJ	Degree=Pos	4	amod	4:amod	_
+ 4	authorities	authority	NOUN	NNS	Number=Plur	5	nsubj	5:nsubj	_
+ 5	announced	announce	VERB	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	1	parataxis	1:parataxis	_
+ 6	that	that	SCONJ	IN	_	9	mark	9:mark	_
+ 7	they	they	PRON	PRP	Case=Nom|Number=Plur|Person=3|PronType=Prs	9	nsubj	9:nsubj	_
+ 8	had	have	AUX	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	9	aux	9:aux	_
+ 9	busted	bust	VERB	VBN	Tense=Past|VerbForm=Part	5	ccomp	5:ccomp	_
+ 10	up	up	ADP	RP	_	9	compound:prt	9:compound:prt	_
+ 11	3	3	NUM	CD	NumForm=Digit|NumType=Card	13	nummod	13:nummod	_
+ 12	terrorist	terrorist	ADJ	JJ	Degree=Pos	13	amod	13:amod	_
+ 13	cells	cell	NOUN	NNS	Number=Plur	9	obj	9:obj	_
+ 14	operating	operate	VERB	VBG	VerbForm=Ger	13	acl	13:acl	_
+ 15	in	in	ADP	IN	_	16	case	16:case	_
+ 16	Baghdad	Baghdad	PROPN	NNP	Number=Sing	14	obl	14:obl:in	SpaceAfter=No
+ 17	.	.	PUNCT	.	_	1	punct	1:punct	_
+
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
+ # text = Two of them were being run by 2 officials of the Ministry of the Interior!
+ 1	Two	two	NUM	CD	NumForm=Word|NumType=Card	6	nsubj:pass	6:nsubj:pass	_
+ 2	of	of	ADP	IN	_	3	case	3:case	_
+ 3	them	they	PRON	PRP	Case=Acc|Number=Plur|Person=3|PronType=Prs	1	nmod	1:nmod:of	_
+ 4	were	be	AUX	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	6	aux	6:aux	_
+ 5	being	be	AUX	VBG	VerbForm=Ger	6	aux:pass	6:aux:pass	_
+ 6	run	run	VERB	VBN	Tense=Past|VerbForm=Part|Voice=Pass	0	root	0:root	_
+ 7	by	by	ADP	IN	_	9	case	9:case	_
+ 8	2	2	NUM	CD	NumForm=Digit|NumType=Card	9	nummod	9:nummod	_
+ 9	officials	official	NOUN	NNS	Number=Plur	6	obl	6:obl:by	_
+ 10	of	of	ADP	IN	_	12	case	12:case	_
+ 11	the	the	DET	DT	Definite=Def|PronType=Art	12	det	12:det	_
+ 12	Ministry	Ministry	PROPN	NNP	Number=Sing	9	nmod	9:nmod:of	_
+ 13	of	of	ADP	IN	_	15	case	15:case	_
+ 14	the	the	DET	DT	Definite=Def|PronType=Art	15	det	15:det	_
+ 15	Interior	Interior	PROPN	NNP	Number=Sing	12	nmod	12:nmod:of	SpaceAfter=No
+ 16	!	!	PUNCT	.	_	6	punct	6:punct	_
+
+ """.lstrip()
+
+ GOESWITH_DATA = """
+ # sent_id = email-enronsent27_01-0041
+ # newpar id = email-enronsent27_01-p0005
+ # text = Ken Rice@ENRON COMMUNICATIONS
+ 1	Ken	kenrice@enroncommunications	X	GW	Typo=Yes	0	root	0:root	_
+ 2	Rice@ENRON	_	X	GW	_	1	goeswith	1:goeswith	_
+ 3	COMMUNICATIONS	_	X	ADD	_	1	goeswith	1:goeswith	_
+
+ """.lstrip()
+
+ CORRECT_FORM_DATA = """
+ # sent_id = weblog-blogspot.com_healingiraq_20040409053012_ENG_20040409_053012-0019
+ # text = They are targetting ambulances
+ 1	They	they	PRON	PRP	Case=Nom|Number=Plur|Person=3|PronType=Prs	3	nsubj	3:nsubj	_
+ 2	are	be	AUX	VBP	Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin	3	aux	3:aux	_
+ 3	targetting	target	VERB	VBG	Tense=Pres|Typo=Yes|VerbForm=Part	0	root	0:root	CorrectForm=targeting
+ 4	ambulances	ambulance	NOUN	NNS	Number=Plur	3	obj	3:obj	SpaceAfter=No
+ """
+
+
+ def test_load_document():
+ train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)
+ data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
+ assert len(data) == 33 # meticulously counted by hand
+ assert all(len(x) == 3 for x in data)
+
+ data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
+ assert len(data) == 33
+ assert all(len(x) == 3 for x in data)
+
+ def test_load_goeswith():
+ raw_data = TRAIN_DATA + GOESWITH_DATA
+ train_doc = CoNLL.conll2doc(input_str=raw_data)
+ data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
+ assert len(data) == 36 # will be the same as in test_load_document with three additional words
+ assert all(len(x) == 3 for x in data)
+
+ data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
+ assert len(data) == 33 # will be the same as in test_load_document, but with the trailing 3 GOESWITH removed
+ assert all(len(x) == 3 for x in data)
+
+ def test_correct_form():
+ raw_data = TRAIN_DATA + CORRECT_FORM_DATA
+ train_doc = CoNLL.conll2doc(input_str=raw_data)
+ data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
+ assert len(data) == 37
+ # the 'targeting' correction should not be applied if evaluation=True
+ # when evaluation=False, then the CorrectForms will be applied
+ assert not any(x[0] == 'targeting' for x in data)
+
+ data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
+ assert len(data) == 38 # the same, but with an extra row so the model learns both 'targetting' and 'targeting'
+ assert any(x[0] == 'targeting' for x in data)
+ assert any(x[0] == 'targetting' for x in data)
stanza/stanza/tests/lemma/test_lemma_trainer.py ADDED
@@ -0,0 +1,154 @@
+ """
+ Test a couple basic functions - load & save an existing model
+ """
+
+ import pytest
+
+ import glob
+ import os
+ import tempfile
+
+ import torch
+
+ from stanza.models import lemmatizer
+ from stanza.models.lemma import trainer
+ from stanza.tests import *
+ from stanza.utils.training.common import choose_lemma_charlm, build_charlm_args
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ @pytest.fixture(scope="module")
+ def english_model():
+ models_path = os.path.join(TEST_MODELS_DIR, "en", "lemma", "*")
+ models = glob.glob(models_path)
+ # we expect at least one English model downloaded for the tests
+ assert len(models) >= 1
+ model_file = models[0]
+ return trainer.Trainer(model_file=model_file)
+
+ def test_load_model(english_model):
+ """
+ Does nothing, just tests that loading works
+ """
+
+ def test_save_load_model(english_model):
+ """
+ Load, save, and load again
+ """
+ with tempfile.TemporaryDirectory() as tempdir:
+ save_file = os.path.join(tempdir, "resaved", "lemma.pt")
+ english_model.save(save_file)
+ reloaded = trainer.Trainer(model_file=save_file)
+
+ TRAIN_DATA = """
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
+ # text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
+ 1	DPA	DPA	PROPN	NNP	Number=Sing	0	root	0:root	SpaceAfter=No
+ 2	:	:	PUNCT	:	_	1	punct	1:punct	_
+ 3	Iraqi	Iraqi	ADJ	JJ	Degree=Pos	4	amod	4:amod	_
+ 4	authorities	authority	NOUN	NNS	Number=Plur	5	nsubj	5:nsubj	_
+ 5	announced	announce	VERB	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	1	parataxis	1:parataxis	_
+ 6	that	that	SCONJ	IN	_	9	mark	9:mark	_
+ 7	they	they	PRON	PRP	Case=Nom|Number=Plur|Person=3|PronType=Prs	9	nsubj	9:nsubj	_
+ 8	had	have	AUX	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	9	aux	9:aux	_
+ 9	busted	bust	VERB	VBN	Tense=Past|VerbForm=Part	5	ccomp	5:ccomp	_
+ 10	up	up	ADP	RP	_	9	compound:prt	9:compound:prt	_
+ 11	3	3	NUM	CD	NumForm=Digit|NumType=Card	13	nummod	13:nummod	_
+ 12	terrorist	terrorist	ADJ	JJ	Degree=Pos	13	amod	13:amod	_
+ 13	cells	cell	NOUN	NNS	Number=Plur	9	obj	9:obj	_
+ 14	operating	operate	VERB	VBG	VerbForm=Ger	13	acl	13:acl	_
+ 15	in	in	ADP	IN	_	16	case	16:case	_
+ 16	Baghdad	Baghdad	PROPN	NNP	Number=Sing	14	obl	14:obl:in	SpaceAfter=No
+ 17	.	.	PUNCT	.	_	1	punct	1:punct	_
+
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
+ # text = Two of them were being run by 2 officials of the Ministry of the Interior!
+ 1	Two	two	NUM	CD	NumForm=Word|NumType=Card	6	nsubj:pass	6:nsubj:pass	_
+ 2	of	of	ADP	IN	_	3	case	3:case	_
+ 3	them	they	PRON	PRP	Case=Acc|Number=Plur|Person=3|PronType=Prs	1	nmod	1:nmod:of	_
+ 4	were	be	AUX	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	6	aux	6:aux	_
+ 5	being	be	AUX	VBG	VerbForm=Ger	6	aux:pass	6:aux:pass	_
+ 6	run	run	VERB	VBN	Tense=Past|VerbForm=Part|Voice=Pass	0	root	0:root	_
+ 7	by	by	ADP	IN	_	9	case	9:case	_
+ 8	2	2	NUM	CD	NumForm=Digit|NumType=Card	9	nummod	9:nummod	_
+ 9	officials	official	NOUN	NNS	Number=Plur	6	obl	6:obl:by	_
+ 10	of	of	ADP	IN	_	12	case	12:case	_
+ 11	the	the	DET	DT	Definite=Def|PronType=Art	12	det	12:det	_
+ 12	Ministry	Ministry	PROPN	NNP	Number=Sing	9	nmod	9:nmod:of	_
+ 13	of	of	ADP	IN	_	15	case	15:case	_
+ 14	the	the	DET	DT	Definite=Def|PronType=Art	15	det	15:det	_
+ 15	Interior	Interior	PROPN	NNP	Number=Sing	12	nmod	12:nmod:of	SpaceAfter=No
+ 16	!	!	PUNCT	.	_	6	punct	6:punct	_
+
+ """.lstrip()
+
+ DEV_DATA = """
+ 1	From	from	ADP	IN	_	3	case	3:case	_
+ 2	the	the	DET	DT	Definite=Def|PronType=Art	3	det	3:det	_
+ 3	AP	AP	PROPN	NNP	Number=Sing	4	obl	4:obl:from	_
+ 4	comes	come	VERB	VBZ	Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	0	root	0:root	_
+ 5	this	this	DET	DT	Number=Sing|PronType=Dem	6	det	6:det	_
+ 6	story	story	NOUN	NN	Number=Sing	4	nsubj	4:nsubj	_
+ 7	:	:	PUNCT	:	_	4	punct	4:punct	_
+
+ """.lstrip()
+
+ class TestLemmatizer:
+ @pytest.fixture(scope="class")
+ def charlm_args(self):
+ charlm = choose_lemma_charlm("en", "test", "default")
+ charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR)
+ return charlm_args
+
+
+ def run_training(self, tmp_path, train_text, dev_text, extra_args=None):
+ """
+ Run the training for a few iterations, load & return the model
+ """
+ pred_file = str(tmp_path / "pred.conllu")
+
+ save_name = "test_tagger.pt"
+ save_file = str(tmp_path / save_name)
+
+ train_file = str(tmp_path / "train.conllu")
+ with open(train_file, "w", encoding="utf-8") as fout:
+ fout.write(train_text)
+
+ dev_file = str(tmp_path / "dev.conllu")
+ with open(dev_file, "w", encoding="utf-8") as fout:
+ fout.write(dev_text)
+
+ args = ["--train_file", train_file,
+ "--eval_file", dev_file,
+ "--gold_file", dev_file,
+ "--output_file", pred_file,
+ "--num_epoch", "2",
+ "--log_step", "10",
+ "--save_dir", str(tmp_path),
+ "--save_name", save_name,
+ "--shorthand", "en_test"]
+ if extra_args is not None:
+ args = args + extra_args
+ lemmatizer.main(args)
+
+ assert os.path.exists(save_file)
+ saved_model = trainer.Trainer(model_file=save_file)
+ return saved_model
+
+ def test_basic_train(self, tmp_path):
+ """
+ Simple test of a few 'epochs' of lemmatizer training
+ """
+ self.run_training(tmp_path, TRAIN_DATA, DEV_DATA)
+
+ def test_charlm_train(self, tmp_path, charlm_args):
+ """
+ Simple test of a few 'epochs' of lemmatizer training
+ """
+ saved_model = self.run_training(tmp_path, TRAIN_DATA, DEV_DATA, extra_args=charlm_args)
+
+ # check that the charlm wasn't saved in here
+ args = saved_model.args
+ save_name = os.path.join(args['save_dir'], args['save_name'])
+ checkpoint = torch.load(save_name, lambda storage, loc: storage, weights_only=True)
+ assert not any(x.startswith("contextual_embedding") for x in checkpoint['model'].keys())
stanza/stanza/tests/lemma_classifier/test_data_preparation.py ADDED
@@ -0,0 +1,256 @@
+ import os
+
+ import pytest
+
+ import stanza.models.lemma_classifier.utils as utils
+ import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ EWT_ONE_SENTENCE = """
+ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002
+ # newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002
+ # text = Here's a Miami Herald interview
+ 1-2 Here's _ _ _ _ _ _ _ _
+ 1 Here here ADV RB PronType=Dem 0 root 0:root _
+ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _
+ 3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
+ 4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _
+ 5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _
+ 6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _
+ """.lstrip()
+
+
+ EWT_TRAIN_SENTENCES = """
+ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002
+ # newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002
+ # text = Here's a Miami Herald interview
+ 1-2 Here's _ _ _ _ _ _ _ _
+ 1 Here here ADV RB PronType=Dem 0 root 0:root _
+ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _
+ 3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
+ 4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _
+ 5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _
+ 6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _
+
+ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0027
+ # text = But Posada's nearly 80 years old
+ 1 But but CCONJ CC _ 7 cc 7:cc _
+ 2-3 Posada's _ _ _ _ _ _ _ _
+ 2 Posada Posada PROPN NNP Number=Sing 7 nsubj 7:nsubj _
+ 3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 cop 7:cop _
+ 4 nearly nearly ADV RB _ 5 advmod 5:advmod _
+ 5 80 80 NUM CD NumForm=Digit|NumType=Card 6 nummod 6:nummod _
+ 6 years year NOUN NNS Number=Plur 7 obl:npmod 7:obl:npmod _
+ 7 old old ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No
+
+ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0067
+ # newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0011
+ # text = Now that's a post I can relate to.
+ 1 Now now ADV RB _ 5 advmod 5:advmod _
+ 2-3 that's _ _ _ _ _ _ _ _
+ 2 that that PRON DT Number=Sing|PronType=Dem 5 nsubj 5:nsubj _
+ 3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
+ 4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
+ 5 post post NOUN NN Number=Sing 0 root 0:root _
+ 6 I I PRON PRP Case=Nom|Number=Sing|Person=1|PronType=Prs 8 nsubj 8:nsubj _
+ 7 can can AUX MD VerbForm=Fin 8 aux 8:aux _
+ 8 relate relate VERB VB VerbForm=Inf 5 acl:relcl 5:acl:relcl _
+ 9 to to ADP IN _ 8 obl 8:obl SpaceAfter=No
+ 10 . . PUNCT . _ 5 punct 5:punct _
+
+ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0073
+ # newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0012
+ # text = hey that's a great blog
+ 1 hey hey INTJ UH _ 6 discourse 6:discourse _
+ 2-3 that's _ _ _ _ _ _ _ _
+ 2 that that PRON DT Number=Sing|PronType=Dem 6 nsubj 6:nsubj _
+ 3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
+ 4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
+ 5 great great ADJ JJ Degree=Pos 6 amod 6:amod _
+ 6 blog blog NOUN NN Number=Sing 0 root 0:root SpaceAfter=No
+
+ # sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0089
+ # text = And It's Not Hard To Do
+ 1 And and CCONJ CC _ 5 cc 5:cc _
+ 2-3 It's _ _ _ _ _ _ _ _
+ 2 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 expl 5:expl _
+ 3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
+ 4 Not not PART RB _ 5 advmod 5:advmod _
+ 5 Hard hard ADJ JJ Degree=Pos 0 root 0:root _
+ 6 To to PART TO _ 7 mark 7:mark _
+ 7 Do do VERB VB VerbForm=Inf 5 csubj 5:csubj SpaceAfter=No
+
+ # sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0029
+ # text = Meanwhile, a decision's been reached
+ 1 Meanwhile meanwhile ADV RB _ 7 advmod 7:advmod SpaceAfter=No
+ 2 , , PUNCT , _ 1 punct 1:punct _
+ 3 a a DET DT Definite=Ind|PronType=Art 4 det 4:det _
+ 4-5 decision's _ _ _ _ _ _ _ _
+ 4 decision decision NOUN NN Number=Sing 7 nsubj:pass 7:nsubj:pass _
+ 5 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux 7:aux _
+ 6 been be AUX VBN Tense=Past|VerbForm=Part 7 aux:pass 7:aux:pass _
+ 7 reached reach VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
+
+ # sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0138
+ # text = It's become a guardian of morality
+ 1-2 It's _ _ _ _ _ _ _ _
+ 1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|5:nsubj:xsubj _
+ 2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
+ 3 become become VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
+ 4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
+ 5 guardian guardian NOUN NN Number=Sing 3 xcomp 3:xcomp _
+ 6 of of ADP IN _ 7 case 7:case _
+ 7 morality morality NOUN NN Number=Sing 5 nmod 5:nmod:of _
+
+ # sent_id = email-enronsent15_01-0018
+ # text = It's got its own bathroom and tv
+ 1-2 It's _ _ _ _ _ _ _ _
+ 1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|13:nsubj _
+ 2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
+ 3 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
+ 4 its its PRON PRP$ Case=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs 6 nmod:poss 6:nmod:poss _
+ 5 own own ADJ JJ Degree=Pos 6 amod 6:amod _
+ 6 bathroom bathroom NOUN NN Number=Sing 3 obj 3:obj _
+ 7 and and CCONJ CC _ 8 cc 8:cc _
+ 8 tv TV NOUN NN Number=Sing 6 conj 3:obj|6:conj:and SpaceAfter=No
+
+ # sent_id = newsgroup-groups.google.com_alt.animals.cat_01ff709c4bf2c60c_ENG_20040418_040100-0022
+ # text = It's also got the website
+ 1-2 It's _ _ _ _ _ _ _ _
+ 1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _
+ 2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _
+ 3 also also ADV RB _ 4 advmod 4:advmod _
+ 4 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
+ 5 the the DET DT Definite=Def|PronType=Art 6 det 6:det _
+ 6 website website NOUN NN Number=Sing 4 obj 4:obj|12:obl _
+ """.lstrip()
+
+
+ # from the train set, actually
+ EWT_DEV_SENTENCES = """
+ # sent_id = answers-20111108104724AAuBUR7_ans-0044
+ # text = He's only exhibited weight loss and some muscle atrophy
+ 1-2 He's _ _ _ _ _ _ _ _
+ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _
+ 2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _
+ 3 only only ADV RB _ 4 advmod 4:advmod _
+ 4 exhibited exhibit VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
+ 5 weight weight NOUN NN Number=Sing 6 compound 6:compound _
+ 6 loss loss NOUN NN Number=Sing 4 obj 4:obj _
+ 7 and and CCONJ CC _ 10 cc 10:cc _
+ 8 some some DET DT PronType=Ind 10 det 10:det _
+ 9 muscle muscle NOUN NN Number=Sing 10 compound 10:compound _
+ 10 atrophy atrophy NOUN NN Number=Sing 6 conj 4:obj|6:conj:and SpaceAfter=No
+
+ # sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0097
+ # text = It's a good thing too.
+ 1-2 It's _ _ _ _ _ _ _ _
+ 1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _
+ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
+ 3 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
+ 4 good good ADJ JJ Degree=Pos 5 amod 5:amod _
+ 5 thing thing NOUN NN Number=Sing 0 root 0:root _
+ 6 too too ADV RB _ 5 advmod 5:advmod SpaceAfter=No
+ 7 . . PUNCT . _ 5 punct 5:punct _
+ """.lstrip()
+
+ # from the train set, actually
+ EWT_TEST_SENTENCES = """
+ # sent_id = reviews-162422-0015
+ # text = He said he's had a long and bad day.
+ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 2 nsubj 2:nsubj _
+ 2 said say VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _
+ 3-4 he's _ _ _ _ _ _ _ _
+ 3 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _
+ 4 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux 5:aux _
+ 5 had have VERB VBN Tense=Past|VerbForm=Part 2 ccomp 2:ccomp _
+ 6 a a DET DT Definite=Ind|PronType=Art 10 det 10:det _
+ 7 long long ADJ JJ Degree=Pos 10 amod 10:amod _
+ 8 and and CCONJ CC _ 9 cc 9:cc _
+ 9 bad bad ADJ JJ Degree=Pos 7 conj 7:conj:and|10:amod _
+ 10 day day NOUN NN Number=Sing 5 obj 5:obj SpaceAfter=No
+ 11 . . PUNCT . _ 2 punct 2:punct _
+
+ # sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0100
+ # text = What's a few dead soldiers
+ 1-2 What's _ _ _ _ _ _ _ _
+ 1 What what PRON WP PronType=Int 6 nsubj 6:nsubj _
+ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
+ 3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
+ 4 few few ADJ JJ Degree=Pos 6 amod 6:amod _
+ 5 dead dead ADJ JJ Degree=Pos 6 amod 6:amod _
+ 6 soldiers soldier NOUN NNS Number=Plur 0 root 0:root _
+ """
+
+ def write_test_dataset(tmp_path, texts, datasets):
+     ud_path = tmp_path / "ud"
+     input_path = ud_path / "UD_English-EWT"
+     output_path = tmp_path / "data" / "lemma_classifier"
+
+     os.makedirs(input_path, exist_ok=True)
+
+     for text, dataset in zip(texts, datasets):
+         sample_file = input_path / ("en_ewt-ud-%s.conllu" % dataset)
+         with open(sample_file, "w", encoding="utf-8") as fout:
+             fout.write(text)
+
+     paths = {"UDBASE": ud_path,
+              "LEMMA_CLASSIFIER_DATA_DIR": output_path}
+
+     return paths
+
+ def write_english_test_dataset(tmp_path):
+     texts = (EWT_TRAIN_SENTENCES, EWT_DEV_SENTENCES, EWT_TEST_SENTENCES)
+     datasets = prepare_lemma_classifier.SECTIONS
+     return write_test_dataset(tmp_path, texts, datasets)
+
+ def convert_english_dataset(tmp_path):
+     paths = write_english_test_dataset(tmp_path)
+     converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have")
+     assert len(converted_files) == 3
+
+     return converted_files
+
+ def test_convert_one_sentence(tmp_path):
+     texts = [EWT_ONE_SENTENCE]
+     datasets = ["train"]
+     paths = write_test_dataset(tmp_path, texts, datasets)
+
+     converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have", ["train"])
+     assert len(converted_files) == 1
+
+     dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)
+
+     assert len(dataset) == 1
+     assert dataset.label_decoder == {'be': 0}
+     id_to_upos = {y: x for x, y in dataset.upos_to_id.items()}
+
+     for text_batches, _, upos_batches, _ in dataset:
+         assert text_batches == [['Here', "'s", 'a', 'Miami', 'Herald', 'interview']]
+         upos = [id_to_upos[x] for x in upos_batches[0]]
+         assert upos == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN']
+
+ def test_convert_dataset(tmp_path):
+     converted_files = convert_english_dataset(tmp_path)
+
+     dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)
+
+     assert len(dataset) == 1
+     label_decoder = dataset.label_decoder
+     assert len(label_decoder) == 2
+     assert "be" in label_decoder
+     assert "have" in label_decoder
+     for text_batches, _, _, _ in dataset:
+         assert len(text_batches) == 9
+
+     dataset = utils.Dataset(converted_files[1], get_counts=True, batch_size=10, shuffle=False)
+     assert len(dataset) == 1
+     for text_batches, _, _, _ in dataset:
+         assert len(text_batches) == 2
+
+     dataset = utils.Dataset(converted_files[2], get_counts=True, batch_size=10, shuffle=False)
+     assert len(dataset) == 1
+     for text_batches, _, _, _ in dataset:
+         assert len(text_batches) == 2
+
stanza/stanza/tests/mwt/test_character_classifier.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import pytest
+
+ from stanza.models import mwt_expander
+ from stanza.models.mwt.character_classifier import CharacterClassifier
+ from stanza.models.mwt.data import DataLoader
+ from stanza.models.mwt.trainer import Trainer
+ from stanza.utils.conll import CoNLL
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ ENG_TRAIN = """
+ # text = Elena's motorcycle tour
+ 1-2 Elena's _ _ _ _ _ _ _ _
+ 1 Elena Elena PROPN NNP Number=Sing 4 nmod:poss 4:nmod:poss _
+ 2 's 's PART POS _ 1 case 1:case _
+ 3 motorcycle motorcycle NOUN NN Number=Sing 4 compound 4:compound _
+ 4 tour tour NOUN NN Number=Sing 0 root 0:root _
+
+
+ # text = women's reproductive health
+ 1-2 women's _ _ _ _ _ _ _ _
+ 1 women woman NOUN NNS Number=Plur 4 nmod:poss 4:nmod:poss _
+ 2 's 's PART POS _ 1 case 1:case _
+ 3 reproductive reproductive ADJ JJ Degree=Pos 4 amod 4:amod _
+ 4 health health NOUN NN Number=Sing 0 root 0:root SpaceAfter=No
+
+
+ # text = The Chernobyl Children's Project
+ 1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _
+ 2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _
+ 3-4 Children's _ _ _ _ _ _ _ _
+ 3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _
+ 4 's 's PART POS _ 3 case 3:case _
+ 5 Project Project PROPN NNP Number=Sing 0 root 0:root _
+
+ """.lstrip()
+
+ ENG_DEV = """
+ # text = The Chernobyl Children's Project
+ 1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _
+ 2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _
+ 3-4 Children's _ _ _ _ _ _ _ _
+ 3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _
+ 4 's 's PART POS _ 3 case 3:case _
+ 5 Project Project PROPN NNP Number=Sing 0 root 0:root _
+
+ """.lstrip()
+
+ def test_train(tmp_path):
+     test_train = str(os.path.join(tmp_path, "en_test.train.conllu"))
+     with open(test_train, "w") as fout:
+         fout.write(ENG_TRAIN)
+
+     test_dev = str(os.path.join(tmp_path, "en_test.dev.conllu"))
+     with open(test_dev, "w") as fout:
+         fout.write(ENG_DEV)
+
+     test_output = str(os.path.join(tmp_path, "en_test.dev.pred.conllu"))
+     model_name = "en_test_mwt.pt"
+
+     args = [
+         "--data_dir", str(tmp_path),
+         "--train_file", test_train,
+         "--eval_file", test_dev,
+         "--gold_file", test_dev,
+         "--lang", "en",
+         "--shorthand", "en_test",
+         "--output_file", test_output,
+         "--save_dir", str(tmp_path),
+         "--save_name", model_name,
+         "--num_epoch", "10",
+     ]
+
+     mwt_expander.main(args=args)
+
+     model = Trainer(model_file=os.path.join(tmp_path, model_name))
+     assert model.model is not None
+     assert isinstance(model.model, CharacterClassifier)
+
+     doc = CoNLL.conll2doc(input_str=ENG_DEV)
+     dataloader = DataLoader(doc, 10, model.args, vocab=model.vocab, evaluation=True, expand_unk_vocab=True)
+     preds = []
+     for i, batch in enumerate(dataloader.to_loader()):
+         assert i == 0  # there should only be one batch
+         preds += model.predict(batch, never_decode_unk=True, vocab=dataloader.vocab)
+     assert len(preds) == 1
+     # it is possible to make a version of the test where this happens almost every time
+     # for example, running for 100 epochs makes the model succeed 30 times in a row
+     # (never saw a failure)
+     # but the one time that failure happened, it would be really annoying
+     #assert preds[0] == "Children 's"
stanza/stanza/tests/mwt/test_english_corner_cases.py ADDED
@@ -0,0 +1,88 @@
+ """
+ Test a couple English MWT corner cases which might be more widely applicable to other MWT languages
+
+ - unknown English character doesn't result in bizarre splits
+ - Casing or CASING doesn't get lost in the dictionary lookup
+
+ In the English UD datasets, the MWT are composed exactly of the
+ subwords, so the MWT model should be chopping up the input text rather
+ than generating new text.
+
+ Furthermore, SHE'S and She's should be split "SHE 'S" and "She 's" respectively
+ """
+
+ import pytest
+ import stanza
+
+ from stanza.tests import TEST_MODELS_DIR
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ def test_mwt_unknown_char():
+     pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
+
+     mwt_trainer = pipeline.processors['mwt']._trainer
+
+     assert mwt_trainer.args['force_exact_pieces']
+
+     # find a letter 'i' which isn't in the training data
+     # the MWT model should still recognize a possessive containing this letter
+     assert "i" in mwt_trainer.vocab
+     for letter in "ĩîíìī":
+         if letter not in mwt_trainer.vocab:
+             break
+     else:
+         raise AssertionError("Need to update the MWT test - all of the non-standard letters 'i' are now in the MWT vocab")
+
+     word = "Jenn" + letter + "fer"
+     possessive = word + "'s"
+     text = "I wanna lick " + possessive + " antennae"
+     doc = pipeline(text)
+     assert doc.sentences[0].tokens[1].text == 'wanna'
+     assert len(doc.sentences[0].tokens[1].words) == 2
+     assert "".join(x.text for x in doc.sentences[0].tokens[1].words) == 'wanna'
+
+     assert doc.sentences[0].tokens[3].text == possessive
+     assert len(doc.sentences[0].tokens[3].words) == 2
+     assert "".join(x.text for x in doc.sentences[0].tokens[3].words) == possessive
+
+
+ def test_english_mwt_casing():
+     """
+     Test that for a word where the lowercase split is known, the correct casing is still used
+
+     Once upon a time, the logic used in the MWT expander would split
+     SHE'S -> she 's
+
+     which is a very surprising tokenization to people expecting
+     the original text in the output document
+     """
+     pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
+
+     mwt_trainer = pipeline.processors['mwt']._trainer
+     for i in range(1, 20):
+         # many test cases follow this pattern for some reason,
+         # so we should proactively look for a test case which hasn't
+         # made its way into the MWT dictionary
+         unknown_name = "jennife" + "r" * i + "'s"
+         if unknown_name not in mwt_trainer.expansion_dict and unknown_name.upper() not in mwt_trainer.expansion_dict:
+             unknown_name = unknown_name.upper()
+             break
+     else:
+         raise AssertionError("Need a new heuristic for the unknown word in the English MWT!")
+
+     # this SHOULD show up in the expansion dict
+     assert "she's" in mwt_trainer.expansion_dict, "Expected |she's| to be in the English MWT expansion dict... perhaps find a different test case"
+
+     text = [x.text for x in pipeline("JENNIFER HAS NICE ANTENNAE").sentences[0].words]
+     assert text == ['JENNIFER', 'HAS', 'NICE', 'ANTENNAE']
+
+     text = [x.text for x in pipeline(unknown_name + " GOT NICE ANTENNAE").sentences[0].words]
+     assert text == [unknown_name[:-2], "'S", 'GOT', 'NICE', 'ANTENNAE']
+
+     text = [x.text for x in pipeline("SHE'S GOT NICE ANTENNAE").sentences[0].words]
+     assert text == ['SHE', "'S", 'GOT', 'NICE', 'ANTENNAE']
+
+     text = [x.text for x in pipeline("She's GOT NICE ANTENNAE").sentences[0].words]
+     assert text == ['She', "'s", 'GOT', 'NICE', 'ANTENNAE']
+
stanza/stanza/tests/ner/test_bsf_2_iob.py ADDED
@@ -0,0 +1,93 @@
+ """
+ Tests the conversion code for the lang_uk NER dataset
+ """
+
+ import unittest
+ from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo
+
+ import pytest
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ class TestBsf2Iob(unittest.TestCase):
+
+     def test_1line_follow_markup_iob(self):
+         data = 'тележурналіст Василь .'
+         bsf_markup = 'T1 PERS 14 20 Василь'
+         expected = '''тележурналіст O
+ Василь B-PERS
+ . O'''
+         self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
+
+     def test_1line_2tok_markup_iob(self):
+         data = 'тележурналіст Василь Нагірний .'
+         bsf_markup = 'T1 PERS 14 29 Василь Нагірний'
+         expected = '''тележурналіст O
+ Василь B-PERS
+ Нагірний I-PERS
+ . O'''
+         self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
+
+     def test_1line_Long_tok_markup_iob(self):
+         data = 'А в музеї Гуцульщини і Покуття можна '
+         bsf_markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття'
+         expected = '''А O
+ в O
+ музеї B-ORG
+ Гуцульщини I-ORG
+ і I-ORG
+ Покуття I-ORG
+ можна O'''
+         self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
+
+     def test_2line_2tok_markup_iob(self):
+         data = '''тележурналіст Василь Нагірний .
+ В івано-франківському видавництві «Лілея НВ» вийшла друком'''
+         bsf_markup = '''T1 PERS 14 29 Василь Нагірний
+ T2 ORG 67 75 Лілея НВ'''
+         expected = '''тележурналіст O
+ Василь B-PERS
+ Нагірний I-PERS
+ . O
+
+
+ В O
+ івано-франківському O
+ видавництві O
+ « O
+ Лілея B-ORG
+ НВ I-ORG
+ » O
+ вийшла O
+ друком O'''
+         self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
+
+     def test_all_multiline_iob(self):
+         data = '''його книжечка «А .
+ Kubler .
+ Світло і тіні маестро» .
+ Причому'''
+         bsf_markup = '''T4 MISC 15 49 А .
+ Kubler .
+ Світло і тіні маестро
+ '''
+         expected = '''його O
+ книжечка O
+ « O
+ А B-MISC
+ . I-MISC
+ Kubler I-MISC
+ . I-MISC
+ Світло I-MISC
+ і I-MISC
+ тіні I-MISC
+ маестро I-MISC
+ » O
+ . O
+
+
+ Причому O'''
+         self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
+
+
+ if __name__ == '__main__':
+     unittest.main()
stanza/stanza/tests/ner/test_convert_amt.py ADDED
@@ -0,0 +1,104 @@
+ """
+ Test some of the functions used for converting an AMT json to a Stanza json
+ """
+
+
+ import os
+
+ import pytest
+
+ import stanza
+ from stanza.utils.datasets.ner import convert_amt
+
+ from stanza.tests import TEST_MODELS_DIR
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ TEXT = "Jennifer Sh'reyan has lovely antennae."
+
+ def fake_label(label, start_char, end_char):
+     return {'label': label,
+             'startOffset': start_char,
+             'endOffset': end_char}
+
+ LABELS = [
+     fake_label('Person', 0, 8),
+     fake_label('Person', 9, 17),
+     fake_label('Person', 0, 17),
+     fake_label('Andorian', 0, 8),
+     fake_label('Appendage', 29, 37),
+     fake_label('Person', 1, 8),
+     fake_label('Person', 0, 7),
+     fake_label('Person', 0, 9),
+     fake_label('Appendage', 29, 38),
+ ]
+
+ def fake_labels(*indices):
+     return [LABELS[x] for x in indices]
+
+ def fake_docs(*indices):
+     return [(TEXT, fake_labels(*indices))]
+
+ def test_remove_nesting():
+     """
+     Test a few orders on nested items to make sure the desired results are coming back
+     """
+     # this should be unchanged
+     result = convert_amt.remove_nesting(fake_docs(0, 1))
+     assert result == fake_docs(0, 1)
+
+     # this should be returned sorted
+     result = convert_amt.remove_nesting(fake_docs(0, 4, 1))
+     assert result == fake_docs(0, 1, 4)
+
+     # this should just have one copy
+     result = convert_amt.remove_nesting(fake_docs(0, 0))
+     assert result == fake_docs(0)
+
+     # outer one preferred
+     result = convert_amt.remove_nesting(fake_docs(0, 2))
+     assert result == fake_docs(2)
+     result = convert_amt.remove_nesting(fake_docs(1, 2))
+     assert result == fake_docs(2)
+     result = convert_amt.remove_nesting(fake_docs(5, 2))
+     assert result == fake_docs(2)
+     # order doesn't matter
+     result = convert_amt.remove_nesting(fake_docs(0, 4, 2))
+     assert result == fake_docs(2, 4)
+     result = convert_amt.remove_nesting(fake_docs(2, 4, 0))
+     assert result == fake_docs(2, 4)
+
+     # first one preferred
+     result = convert_amt.remove_nesting(fake_docs(0, 3))
+     assert result == fake_docs(0)
+     result = convert_amt.remove_nesting(fake_docs(3, 0))
+     assert result == fake_docs(3)
+
+ def test_process_doc():
+     nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
+
+     def check_results(doc, *expected):
+         ner = [x[1] for x in doc[0]]
+         assert ner == list(expected)
+
+     # test a standard case of all the values lining up
+     doc = convert_amt.process_doc(TEXT, fake_labels(2, 4), nlp)
+     check_results(doc, "B-Person", "I-Person", "O", "O", "B-Appendage", "O")
+
+     # test a slightly wrong start index
+     doc = convert_amt.process_doc(TEXT, fake_labels(5, 1, 4), nlp)
+     check_results(doc, "B-Person", "B-Person", "O", "O", "B-Appendage", "O")
+
+     # test a slightly wrong end index
+     doc = convert_amt.process_doc(TEXT, fake_labels(6, 1, 4), nlp)
+     check_results(doc, "B-Person", "B-Person", "O", "O", "B-Appendage", "O")
+
+     # test a slightly wronger end index
+     doc = convert_amt.process_doc(TEXT, fake_labels(7, 4), nlp)
+     check_results(doc, "B-Person", "O", "O", "O", "B-Appendage", "O")
+
+     # test a period at the end of a text - should not be captured
+     doc = convert_amt.process_doc(TEXT, fake_labels(7, 8), nlp)
+     check_results(doc, "B-Person", "O", "O", "O", "B-Appendage", "O")
+
+
stanza/stanza/tests/ner/test_convert_starlang_ner.py ADDED
@@ -0,0 +1,23 @@
+ """
+ Test a couple different classes of trees to check the output of the Starlang conversion for NER
+ """
+
+ import os
+ import tempfile
+
+ import pytest
+
+ from stanza.utils.datasets.ner import convert_starlang_ner
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ TREE="( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )"
+
+ def test_read_tree():
+     """
+     Test a basic tree read
+     """
+     sentence = convert_starlang_ner.read_tree(TREE)
+     expected = [('Bayan', 'PERSON'), ('Haag', 'PERSON'), ('Elianti', 'O'), ('çalar', 'O'), ('.', 'O')]
+     assert sentence == expected
+
stanza/stanza/tests/ner/test_from_conllu.py ADDED
@@ -0,0 +1,30 @@
+ import pytest
+
+ from stanza import Pipeline
+ from stanza.utils.conll import CoNLL
+ from stanza.tests import TEST_MODELS_DIR
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ def test_from_conllu():
+     """
+     If the doc does not have the entire text available, make sure it still safely processes the text
+
+     Test case supplied from user - see issue #1428
+     """
+     pipe = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,ner", download_method=None)
+     doc = pipe("In February, I traveled to Seattle. Dr. Pritchett gave me a new hip")
+     ents = [x.text for x in doc.ents]
+     # the default NER model ought to find these three
+     assert ents == ['February', 'Seattle', 'Pritchett']
+
+     doc_conllu = "{:C}\n\n".format(doc)
+     doc = CoNLL.conll2doc(input_str=doc_conllu)
+     pipe = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,ner", tokenize_pretokenized=True, download_method=None)
+     pipe(doc)
+     ents = [x.text for x in doc.ents]
+     # this should still work when processed from a CoNLLu document
+     # the bug previously caused a crash because the text to construct
+     # the entities was not available, since the Document wouldn't have
+     # the entire document text available
+     assert ents == ['February', 'Seattle', 'Pritchett']
stanza/stanza/tests/ner/test_ner_utils.py ADDED
@@ -0,0 +1,129 @@
+ import pytest
+
+ from stanza.tests import *
+
+ from stanza.models.common.vocab import EMPTY
+ from stanza.models.ner import utils
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ WORDS = [["Unban", "Mox", "Opal"], ["Ragavan", "is", "red"], ["Urza", "Lord", "High", "Artificer", "goes", "infinite", "with", "Thopter", "Sword"]]
+ BIO_TAGS = [["O", "B-ART", "I-ART"], ["B-MONKEY", "O", "B-COLOR"], ["B-PER", "I-PER", "I-PER", "I-PER", "O", "O", "O", "B-WEAPON", "B-WEAPON"]]
+ BIO_U_TAGS = [["O", "B_ART", "I_ART"], ["B_MONKEY", "O", "B_COLOR"], ["B_PER", "I_PER", "I_PER", "I_PER", "O", "O", "O", "B_WEAPON", "B_WEAPON"]]
+ BIOES_TAGS = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "S-WEAPON", "S-WEAPON"]]
+ # note the problem with not using BIO tags - the consecutive tags for thopter/sword get treated as one item
+ BASIC_TAGS = [["O", "ART", "ART"], ["MONKEY", "O", "COLOR"], ["PER", "PER", "PER", "PER", "O", "O", "O", "WEAPON", "WEAPON"]]
+ BASIC_BIOES = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "B-WEAPON", "E-WEAPON"]]
+ ALT_BIO = [["O", "B-MANA", "I-MANA"], ["B-CRE", "O", "O"], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
+ ALT_BIOES = [["O", "B-MANA", "E-MANA"], ["S-CRE", "O", "O"], ["B-CRE", "I-CRE", "I-CRE", "E-CRE", "O", "O", "O", "S-ART", "S-ART"]]
+ NONE_BIO = [["O", "B-MANA", "I-MANA"], [None, None, None], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
+ NONE_BIOES = [["O", "B-MANA", "E-MANA"], [None, None, None], ["B-CRE", "I-CRE", "I-CRE", "E-CRE", "O", "O", "O", "S-ART", "S-ART"]]
+ EMPTY_BIO = [["O", "B-MANA", "I-MANA"], [EMPTY, EMPTY, EMPTY], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
+
+ def test_normalize_empty_tags():
+     sentences = [[(word[0], (word[1],)) for word in zip(*sentence)] for sentence in zip(WORDS, NONE_BIO)]
+     new_sentences = utils.normalize_empty_tags(sentences)
+     expected = [[(word[0], (word[1],)) for word in zip(*sentence)] for sentence in zip(WORDS, EMPTY_BIO)]
+     assert new_sentences == expected
+
+ def check_reprocessed_tags(words, input_tags, expected_tags):
+     sentences = [list(zip(x, y)) for x, y in zip(words, input_tags)]
+     retagged = utils.process_tags(sentences=sentences, scheme="bioes")
+     # process_tags selectively returns tuples or strings based on the input
+     # so we don't need to fiddle with the expected output format here
+     expected_retagged = [list(zip(x, y)) for x, y in zip(words, expected_tags)]
+     assert retagged == expected_retagged
+
+ def test_process_tags_bio():
+     check_reprocessed_tags(WORDS, BIO_TAGS, BIOES_TAGS)
+     # check that the alternate version is correct as well
+     # that way we can independently check the two layer version
+     check_reprocessed_tags(WORDS, ALT_BIO, ALT_BIOES)
+
+ def test_process_tags_with_none():
+     # if there is a block of tags with None in them, the Nones should be skipped over
+     check_reprocessed_tags(WORDS, NONE_BIO, NONE_BIOES)
+
+ def merge_tags(*tags):
+     merged_tags = [[tuple(x) for x in zip(*sentences)]  # combine tags such as ("O", "O"), ("B-ART", "B-MANA"), ...
+                    for sentences in zip(*tags)]         # ... for each set of sentences
+     return merged_tags
+
+ def test_combined_tags_bio():
+     bio_tags = merge_tags(BIO_TAGS, ALT_BIO)
+     expected = merge_tags(BIOES_TAGS, ALT_BIOES)
+     check_reprocessed_tags(WORDS, bio_tags, expected)
+
+ def test_combined_tags_mixed():
+     bio_tags = merge_tags(BIO_TAGS, ALT_BIOES)
+     expected = merge_tags(BIOES_TAGS, ALT_BIOES)
+     check_reprocessed_tags(WORDS, bio_tags, expected)
+
+ def test_process_tags_basic():
+     check_reprocessed_tags(WORDS, BASIC_TAGS, BASIC_BIOES)
+
+ def test_process_tags_bioes():
+     """
+     This one should not change, naturally
+     """
+     check_reprocessed_tags(WORDS, BIOES_TAGS, BIOES_TAGS)
+     check_reprocessed_tags(WORDS, BASIC_BIOES, BASIC_BIOES)
+
+ def run_flattened(fn, tags):
+     return fn([x for y in tags for x in y])
+
+ def test_check_bio():
+     assert utils.is_bio_scheme([x for y in BIO_TAGS for x in y])
+     assert not utils.is_bio_scheme([x for y in BIOES_TAGS for x in y])
+     assert not utils.is_bio_scheme([x for y in BASIC_TAGS for x in y])
+     assert not utils.is_bio_scheme([x for y in BASIC_BIOES for x in y])
+
+ def test_check_basic():
+     assert not utils.is_basic_scheme([x for y in BIO_TAGS for x in y])
+     assert not utils.is_basic_scheme([x for y in BIOES_TAGS for x in y])
+     assert utils.is_basic_scheme([x for y in BASIC_TAGS for x in y])
+     assert not utils.is_basic_scheme([x for y in BASIC_BIOES for x in y])
+
+ def test_underscores():
+     """
+     Check that the methods work if the inputs are underscores instead of dashes
+     """
+     assert not utils.is_basic_scheme([x for y in BIO_U_TAGS for x in y])
+     check_reprocessed_tags(WORDS, BIO_U_TAGS, BIOES_TAGS)
+
+ def test_merge_tags():
+     """
+     Check a few versions of the tag sequence merging
+     """
+     seq1 = [ "O", "O", "O", "B-FOO", "E-FOO", "O"]
+     seq2 = [ "S-FOO", "O", "B-FOO", "E-FOO", "O", "O"]
+     seq3 = [ "B-FOO", "E-FOO", "B-FOO", "E-FOO", "O", "O"]
+     seq_err = [ "O", "B-FOO", "O", "B-FOO", "E-FOO", "O"]
+     seq_err2 = [ "O", "B-FOO", "O", "B-FOO", "B-FOO", "O"]
+     seq_err3 = [ "O", "B-FOO", "O", "B-FOO", "I-FOO", "O"]
+     seq_err4 = [ "O", "B-FOO", "O", "B-FOO", "I-FOO", "I-FOO"]
+
+     result = utils.merge_tags(seq1, seq2)
+     expected = [ "S-FOO", "O", "O", "B-FOO", "E-FOO", "O"]
+     assert result == expected
+
+     result = utils.merge_tags(seq2, seq1)
+     expected = [ "S-FOO", "O", "B-FOO", "E-FOO", "O", "O"]
+     assert result == expected
+
+     result = utils.merge_tags(seq1, seq3)
+     expected = [ "B-FOO", "E-FOO", "O", "B-FOO", "E-FOO", "O"]
+     assert result == expected
+
+     with pytest.raises(ValueError):
+         result = utils.merge_tags(seq1, seq_err)
+
+     with pytest.raises(ValueError):
+         result = utils.merge_tags(seq1, seq_err2)
+
+     with pytest.raises(ValueError):
+         result = utils.merge_tags(seq1, seq_err3)
+
+     with pytest.raises(ValueError):
+         result = utils.merge_tags(seq1, seq_err4)
+
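The BIO-to-BIOES retagging that the tests above exercise can be sketched as follows. This is a minimal, self-contained illustration of the scheme conversion, not the actual `stanza.models.ner.utils.process_tags` implementation:

```python
def bio_to_bioes(tags):
    """Convert one BIO tag sequence to BIOES (illustrative sketch).

    A B- tag that ends its entity becomes S-; an I- tag that ends its
    entity becomes E-; tags followed by a matching I- are unchanged.
    """
    bioes = []
    for i, tag in enumerate(tags):
        if tag == "O":
            bioes.append(tag)
            continue
        prefix, label = tag.split("-", 1)
        next_tag = tags[i + 1] if i + 1 < len(tags) else "O"
        continues = next_tag == "I-" + label
        if prefix == "B":
            bioes.append(("B-" if continues else "S-") + label)
        else:  # an I- tag
            bioes.append(("I-" if continues else "E-") + label)
    return bioes
```

Note how two consecutive B- tags of the same label (the Thopter/Sword case above) stay separate entities, each becoming S-, whereas the basic scheme without B-/I- prefixes cannot make that distinction.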
stanza/stanza/tests/pipeline/__init__.py ADDED
File without changes
stanza/stanza/tests/pipeline/test_arabic_pipeline.py ADDED
@@ -0,0 +1,27 @@
+ """
+ Small test of loading the Arabic pipeline
+
+ The main goal is to check that nothing goes wrong with RtL languages,
+ but incidentally this would have caught a bug where the xpos tags
+ were split into individual pieces instead of reassembled as expected
+ """
+
+ import pytest
+ import stanza
+
+ from stanza.tests import TEST_MODELS_DIR
+
+ pytestmark = pytest.mark.pipeline
+
+ def test_arabic_pos_pipeline():
+     pipe = stanza.Pipeline(**{'processors': 'tokenize,pos', 'dir': TEST_MODELS_DIR, 'download_method': None, 'lang': 'ar'})
+     text = "ولم يتم اعتقال احد بحسب المتحدث باسم الشرطة."
+
+     doc = pipe(text)
+     # the first token translates to "and not", seems common enough
+     # that we should be able to rely on it having a stable MWT and tag
+
+     assert len(doc.sentences) == 1
+     assert doc.sentences[0].tokens[0].text == "ولم"
+     assert doc.sentences[0].words[0].xpos == "C---------"
+     assert doc.sentences[0].words[1].xpos == "F---------"
stanza/stanza/tests/pipeline/test_core.py ADDED
@@ -0,0 +1,248 @@
+ import pytest
+ import shutil
+ import tempfile
+
+ import stanza
+
+ from stanza.tests import *
+
+ from stanza.pipeline import core
+ from stanza.resources.common import get_md5, load_resources_json
+
+ pytestmark = pytest.mark.pipeline
+
+ def test_pretagged():
+     """
+     Test that the pipeline does or doesn't build if pos is left out and pretagged is specified
+     """
+     nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,pos,lemma,depparse")
+     with pytest.raises(core.PipelineRequirementsException):
+         nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,lemma,depparse")
+     nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,lemma,depparse", depparse_pretagged=True)
+     nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,lemma,depparse", pretagged=True)
+     # test that the module specific flag overrides the general flag
+     nlp = stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors="tokenize,lemma,depparse", depparse_pretagged=True, pretagged=False)
+
+ def test_download_missing_ner_model():
+     """
+     Test that the pipeline will automatically download missing models
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
+         stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined", verbose=False)
+         pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize,ner", package={"ner": "ontonotes_charlm"})
+
+         assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
+         en_dir = os.path.join(test_dir, 'en')
+         en_dir_listing = sorted(os.listdir(en_dir))
+         assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize']
+         assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt']
+
+
+ def test_download_missing_resources():
+     """
+     Test that the pipeline will automatically download missing models
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
+         pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize,ner", package={"tokenize": "combined", "ner": "ontonotes_charlm"})
+
+         assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
+         en_dir = os.path.join(test_dir, 'en')
+         en_dir_listing = sorted(os.listdir(en_dir))
+         assert en_dir_listing == ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize']
+         assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt']
+
+
+ def test_download_resources_overwrites():
+     """
+     Test that the DOWNLOAD_RESOURCES method overwrites an existing resources.json
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
+         pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
+
+         assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
+         resources_path = os.path.join(test_dir, 'resources.json')
+         mod_time = os.path.getmtime(resources_path)
+
+         pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
+         new_mod_time = os.path.getmtime(resources_path)
+         assert mod_time != new_mod_time
+
+ def test_reuse_resources_overwrites():
+     """
+     Test that the REUSE_RESOURCES method does *not* overwrite an existing resources.json
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
+         pipe = stanza.Pipeline("en",
+                                download_method=core.DownloadMethod.REUSE_RESOURCES,
+                                model_dir=test_dir,
+                                processors="tokenize",
+                                package={"tokenize": "combined"})
+
+         assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
+         resources_path = os.path.join(test_dir, 'resources.json')
+         mod_time = os.path.getmtime(resources_path)
+
+         pipe = stanza.Pipeline("en",
+                                download_method=core.DownloadMethod.REUSE_RESOURCES,
+                                model_dir=test_dir,
+                                processors="tokenize",
+                                package={"tokenize": "combined"})
+         new_mod_time = os.path.getmtime(resources_path)
+         assert mod_time == new_mod_time
+
+
+ def test_download_not_repeated():
+     """
+     Test that a model is only downloaded once if it already matches the expected model from the resources file
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
+         stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined")
+
+         assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
+         en_dir = os.path.join(test_dir, 'en')
+         en_dir_listing = sorted(os.listdir(en_dir))
+         assert en_dir_listing == ['mwt', 'tokenize']
+         tokenize_path = os.path.join(en_dir, "tokenize", "combined.pt")
+         mod_time = os.path.getmtime(tokenize_path)
+
+         pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
+         assert os.path.getmtime(tokenize_path) == mod_time
+
+ def test_download_none():
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
+         stanza.download("it", model_dir=test_dir, processors="tokenize", package="combined")
+         stanza.download("it", model_dir=test_dir, processors="tokenize", package="vit")
+
+         it_dir = os.path.join(test_dir, 'it')
+         it_dir_listing = sorted(os.listdir(it_dir))
+         assert sorted(it_dir_listing) == ['mwt', 'tokenize']
+         combined_path = os.path.join(it_dir, "tokenize", "combined.pt")
+         vit_path = os.path.join(it_dir, "tokenize", "vit.pt")
+
+         assert os.path.exists(combined_path)
+         assert os.path.exists(vit_path)
+
+         combined_md5 = get_md5(combined_path)
+         vit_md5 = get_md5(vit_path)
+         # check that the models are different
+         # otherwise the test is not testing anything
+         assert combined_md5 != vit_md5
+
+         shutil.copyfile(vit_path, combined_path)
+         assert get_md5(combined_path) == vit_md5
+
+         pipe = stanza.Pipeline("it", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}, download_method=None)
+         assert get_md5(combined_path) == vit_md5
+
+         pipe = stanza.Pipeline("it", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
+         assert get_md5(combined_path) != vit_md5
+
+
+ def check_download_method_updates(download_method):
+     """
+     Run a single test of creating a pipeline with a given download_method, checking that the model is updated
+     """
+     with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
+         stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined")
+
+         assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
+         en_dir = os.path.join(test_dir, 'en')
+         en_dir_listing = sorted(os.listdir(en_dir))
+         assert en_dir_listing == ['mwt', 'tokenize']
+         tokenize_path = os.path.join(en_dir, "tokenize", "combined.pt")
+
+         with open(tokenize_path, "w") as fout:
+             fout.write("Unban mox opal!")
+         mod_time = os.path.getmtime(tokenize_path)
+
+         pipe = stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}, download_method=download_method)
+         assert os.path.getmtime(tokenize_path) != mod_time
+
+ def test_download_fixed():
+     """
+     Test that a model is fixed if the existing model doesn't match the md5sum
+     """
+     for download_method in (core.DownloadMethod.REUSE_RESOURCES, core.DownloadMethod.DOWNLOAD_RESOURCES):
+         check_download_method_updates(download_method)
+
+ def test_download_strings():
+     """
+     Same as the test of the download_method, but tests that the pipeline works for string download_method
+     """
+     for download_method in ("reuse_resources", "download_resources"):
+         check_download_method_updates(download_method)
+
+ def test_limited_pipeline():
+     """
+     Test loading a pipeline, but then only using a couple processors
+     """
+     pipe = stanza.Pipeline(processors="tokenize,pos,lemma,depparse,ner", dir=TEST_MODELS_DIR)
+     doc = pipe("John Bauer works at Stanford")
+     assert all(word.upos is not None for sentence in doc.sentences for word in sentence.words)
+     assert all(token.ner is not None for sentence in doc.sentences for token in sentence.tokens)
+
+     doc = pipe("John Bauer works at Stanford", processors=["tokenize","pos"])
+     assert all(word.upos is not None for sentence in doc.sentences for word in sentence.words)
+     assert not any(token.ner is not None for sentence in doc.sentences for token in sentence.tokens)
+
+     doc = pipe("John Bauer works at Stanford", processors="tokenize")
+     assert not any(word.upos is not None for sentence in doc.sentences for word in sentence.words)
+     assert not any(token.ner is not None for sentence in doc.sentences for token in sentence.tokens)
+
+     doc = pipe("John Bauer works at Stanford", processors="tokenize,ner")
+     assert not any(word.upos is not None for sentence in doc.sentences for word in sentence.words)
+     assert all(token.ner is not None for sentence in doc.sentences for token in sentence.tokens)
+
+     with pytest.raises(ValueError):
+         # this should fail
+         doc = pipe("John Bauer works at Stanford", processors="tokenize,depparse")
+
+ @pytest.fixture(scope="module")
+ def unknown_language_name():
+     resources = load_resources_json(model_dir=TEST_MODELS_DIR)
+     name = "en"
+     while name in resources:
+         name = name + "z"
+     assert name != "en"
+     return name
+
+ def test_empty_unknown_language(unknown_language_name):
+     """
+     Check that there is an error for trying to load an unknown language
+     """
+     with pytest.raises(ValueError):
+         pipe = stanza.Pipeline(unknown_language_name, download_method=None)
+
+ def test_unknown_language_tokenizer(unknown_language_name):
+     """
+     Test that loading tokenize works for an unknown language
+     """
+     base_pipe = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
+     # even if we one day add MWT to English, the tokenizer by itself should still work
+     tokenize_processor = base_pipe.processors["tokenize"]
+
+     pipe = stanza.Pipeline(unknown_language_name,
+                            processors="tokenize",
+                            allow_unknown_language=True,
+                            tokenize_model_path=tokenize_processor.config['model_path'],
+                            download_method=None)
+     doc = pipe("This is a test")
+     words = [x.text for x in doc.sentences[0].words]
+     assert words == ['This', 'is', 'a', 'test']
+
+
+ def test_unknown_language_mwt(unknown_language_name):
+     """
+     Test that loading tokenize & mwt works for an unknown language
+     """
+     base_pipe = stanza.Pipeline("fr", dir=TEST_MODELS_DIR, processors="tokenize,mwt", download_method=None)
+     assert len(base_pipe.processors) == 2
+     tokenize_processor = base_pipe.processors["tokenize"]
+     mwt_processor = base_pipe.processors["mwt"]
+
+     pipe = stanza.Pipeline(unknown_language_name,
+                            processors="tokenize,mwt",
+                            allow_unknown_language=True,
+                            tokenize_model_path=tokenize_processor.config['model_path'],
+                            mwt_model_path=mwt_processor.config['model_path'],
+                            download_method=None)
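The download tests above all hinge on the same freshness check: a model file is re-fetched only when it is missing or its md5 no longer matches the checksum recorded in resources.json. A minimal sketch of that idea (illustrative only, not stanza's actual `resources.common` implementation) could look like:

```python
import hashlib

def needs_download(path, expected_md5):
    """Return True if the file at path is absent or its md5 differs
    from the expected checksum (so it should be (re)downloaded)."""
    try:
        with open(path, "rb") as fin:
            return hashlib.md5(fin.read()).hexdigest() != expected_md5
    except FileNotFoundError:
        return True
```

Under this scheme, overwriting a model file with garbage (as `check_download_method_updates` does with "Unban mox opal!") changes its md5, so the next pipeline construction triggers a fresh download.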
stanza/stanza/tests/pipeline/test_depparse.py ADDED
@@ -0,0 +1,87 @@
+ """
+ Basic tests of the depparse processor boolean flags
+ """
+ import pytest
+
+ import stanza
+ from stanza.pipeline.core import PipelineRequirementsException
+ from stanza.utils.conll import CoNLL
+ from stanza.tests import *
+
+ pytestmark = pytest.mark.pipeline
+
+ # data for testing
+ EN_DOC = "Barack Obama was born in Hawaii. He was elected president in 2008. Obama attended Harvard."
+
+ EN_DOC_CONLLU_PRETAGGED = """
+ 1 Barack Barack PROPN NNP Number=Sing 0 _ _ _
+ 2 Obama Obama PROPN NNP Number=Sing 1 _ _ _
+ 3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 2 _ _ _
+ 4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 3 _ _ _
+ 5 in in ADP IN _ 4 _ _ _
+ 6 Hawaii Hawaii PROPN NNP Number=Sing 5 _ _ _
+ 7 . . PUNCT . _ 6 _ _ _
+
+ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 0 _ _ _
+ 2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 _ _ _
+ 3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 2 _ _ _
+ 4 president president PROPN NNP Number=Sing 3 _ _ _
+ 5 in in ADP IN _ 4 _ _ _
+ 6 2008 2008 NUM CD NumType=Card 5 _ _ _
+ 7 . . PUNCT . _ 6 _ _ _
+
+ 1 Obama Obama PROPN NNP Number=Sing 0 _ _ _
+ 2 attended attend VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 1 _ _ _
+ 3 Harvard Harvard PROPN NNP Number=Sing 2 _ _ _
+ 4 . . PUNCT . _ 3 _ _ _
+
+
+ """.lstrip()
+
+ EN_DOC_DEPENDENCY_PARSES_GOLD = """
+ ('Barack', 4, 'nsubj:pass')
+ ('Obama', 1, 'flat')
+ ('was', 4, 'aux:pass')
+ ('born', 0, 'root')
+ ('in', 6, 'case')
+ ('Hawaii', 4, 'obl')
+ ('.', 4, 'punct')
+
+ ('He', 3, 'nsubj:pass')
+ ('was', 3, 'aux:pass')
+ ('elected', 0, 'root')
+ ('president', 3, 'xcomp')
+ ('in', 6, 'case')
+ ('2008', 3, 'obl')
+ ('.', 3, 'punct')
+
+ ('Obama', 2, 'nsubj')
+ ('attended', 0, 'root')
+ ('Harvard', 2, 'obj')
+ ('.', 2, 'punct')
+ """.strip()
+
+ @pytest.fixture(scope="module")
+ def en_depparse_pipeline():
+     nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,pos,lemma,depparse')
+     return nlp
+
+ def test_depparse(en_depparse_pipeline):
+     doc = en_depparse_pipeline(EN_DOC)
+     assert EN_DOC_DEPENDENCY_PARSES_GOLD == '\n\n'.join([sent.dependencies_string() for sent in doc.sentences])
+
+
+ def test_depparse_with_pretagged_doc():
+     nlp = stanza.Pipeline(**{'processors': 'depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en',
+                              'depparse_pretagged': True})
+
+     doc = CoNLL.conll2doc(input_str=EN_DOC_CONLLU_PRETAGGED)
+     processed_doc = nlp(doc)
+
+     assert EN_DOC_DEPENDENCY_PARSES_GOLD == '\n\n'.join(
+         [sent.dependencies_string() for sent in processed_doc.sentences])
+
+
+ def test_raises_requirements_exception_if_pretagged_not_passed():
+     with pytest.raises(PipelineRequirementsException):
+         stanza.Pipeline(**{'processors': 'depparse', 'dir': TEST_MODELS_DIR, 'lang': 'en'})
stanza/stanza/tests/pipeline/test_english_pipeline.py ADDED
@@ -0,0 +1,279 @@
1
+ """
2
+ Basic testing of the English pipeline
3
+ """
4
+
5
+ import pytest
6
+ import stanza
7
+ from stanza.utils.conll import CoNLL
8
+ from stanza.models.common.doc import Document
9
+
10
+ from stanza.tests import *
11
+ from stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu
12
+
13
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
14
+
15
+ # data for testing
16
+ EN_DOC = "Barack Obama was born in Hawaii. He was elected president in 2008. Obama attended Harvard."
17
+
18
+ EN_DOCS = ["Barack Obama was born in Hawaii.", "He was elected president in 2008.", "Obama attended Harvard."]
19
+
20
+ EN_DOC_TOKENS_GOLD = """
21
+ <Token id=1;words=[<Word id=1;text=Barack;lemma=Barack;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=nsubj:pass>]>
22
+ <Token id=2;words=[<Word id=2;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=1;deprel=flat>]>
23
+ <Token id=3;words=[<Word id=3;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=4;deprel=aux:pass>]>
24
+ <Token id=4;words=[<Word id=4;text=born;lemma=bear;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>]>
25
+ <Token id=5;words=[<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>]>
26
+ <Token id=6;words=[<Word id=6;text=Hawaii;lemma=Hawaii;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=obl>]>
27
+ <Token id=7;words=[<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=4;deprel=punct>]>
28
+
29
+ <Token id=1;words=[<Word id=1;text=He;lemma=he;upos=PRON;xpos=PRP;feats=Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs;head=3;deprel=nsubj:pass>]>
30
+ <Token id=2;words=[<Word id=2;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=3;deprel=aux:pass>]>
31
+ <Token id=3;words=[<Word id=3;text=elected;lemma=elect;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>]>
32
+ <Token id=4;words=[<Word id=4;text=president;lemma=president;upos=NOUN;xpos=NN;feats=Number=Sing;head=3;deprel=xcomp>]>
33
+ <Token id=5;words=[<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>]>
34
+ <Token id=6;words=[<Word id=6;text=2008;lemma=2008;upos=NUM;xpos=CD;feats=NumForm=Digit|NumType=Card;head=3;deprel=obl>]>
35
+ <Token id=7;words=[<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=3;deprel=punct>]>
36
+
37
+ <Token id=1;words=[<Word id=1;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=nsubj>]>
38
+ <Token id=2;words=[<Word id=2;text=attended;lemma=attend;upos=VERB;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=0;deprel=root>]>
39
+ <Token id=3;words=[<Word id=3;text=Harvard;lemma=Harvard;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=obj>]>
40
+ <Token id=4;words=[<Word id=4;text=.;lemma=.;upos=PUNCT;xpos=.;head=2;deprel=punct>]>
41
+ """.strip()
42
+
43
+ EN_DOC_WORDS_GOLD = """
44
+ <Word id=1;text=Barack;lemma=Barack;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=nsubj:pass>
45
+ <Word id=2;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=1;deprel=flat>
46
+ <Word id=3;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=4;deprel=aux:pass>
47
+ <Word id=4;text=born;lemma=bear;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>
48
+ <Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>
49
+ <Word id=6;text=Hawaii;lemma=Hawaii;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=obl>
50
+ <Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=4;deprel=punct>
51
+
52
+ <Word id=1;text=He;lemma=he;upos=PRON;xpos=PRP;feats=Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs;head=3;deprel=nsubj:pass>
53
+ <Word id=2;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=3;deprel=aux:pass>
54
+ <Word id=3;text=elected;lemma=elect;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>
55
+ <Word id=4;text=president;lemma=president;upos=NOUN;xpos=NN;feats=Number=Sing;head=3;deprel=xcomp>
56
+ <Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>
57
+ <Word id=6;text=2008;lemma=2008;upos=NUM;xpos=CD;feats=NumForm=Digit|NumType=Card;head=3;deprel=obl>
58
+ <Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=3;deprel=punct>
59
+
60
+ <Word id=1;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=nsubj>
61
+ <Word id=2;text=attended;lemma=attend;upos=VERB;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=0;deprel=root>
62
+ <Word id=3;text=Harvard;lemma=Harvard;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=obj>
63
+ <Word id=4;text=.;lemma=.;upos=PUNCT;xpos=.;head=2;deprel=punct>
64
+ """.strip()
65
+
66
+ EN_DOC_DEPENDENCY_PARSES_GOLD = """
67
+ ('Barack', 4, 'nsubj:pass')
68
+ ('Obama', 1, 'flat')
69
+ ('was', 4, 'aux:pass')
70
+ ('born', 0, 'root')
71
+ ('in', 6, 'case')
72
+ ('Hawaii', 4, 'obl')
73
+ ('.', 4, 'punct')
74
+
75
+ ('He', 3, 'nsubj:pass')
76
+ ('was', 3, 'aux:pass')
77
+ ('elected', 0, 'root')
78
+ ('president', 3, 'xcomp')
79
+ ('in', 6, 'case')
80
+ ('2008', 3, 'obl')
81
+ ('.', 3, 'punct')
82
+
83
+ ('Obama', 2, 'nsubj')
84
+ ('attended', 0, 'root')
85
+ ('Harvard', 2, 'obj')
86
+ ('.', 2, 'punct')
87
+ """.strip()
88
+
89
+ EN_DOC_CONLLU_GOLD = """
90
+ # text = Barack Obama was born in Hawaii.
91
+ # sent_id = 0
92
+ # constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. .)))
93
+ # sentiment = 1
94
+ 1 Barack Barack PROPN NNP Number=Sing 4 nsubj:pass _ start_char=0|end_char=6|ner=B-PERSON
95
+ 2 Obama Obama PROPN NNP Number=Sing 1 flat _ start_char=7|end_char=12|ner=E-PERSON
96
+ 3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 4 aux:pass _ start_char=13|end_char=16|ner=O
97
+ 4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=17|end_char=21|ner=O
98
+ 5 in in ADP IN _ 6 case _ start_char=22|end_char=24|ner=O
99
+ 6 Hawaii Hawaii PROPN NNP Number=Sing 4 obl _ start_char=25|end_char=31|ner=S-GPE|SpaceAfter=No
100
+ 7 . . PUNCT . _ 4 punct _ start_char=31|end_char=32|ner=O|SpacesAfter=\\s\\s
101
+
102
+ # text = He was elected president in 2008.
103
+ # sent_id = 1
104
+ # constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. .)))
105
+ # sentiment = 1
106
+ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=34|end_char=36|ner=O
107
+ 2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=37|end_char=40|ner=O
108
+ 3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=41|end_char=48|ner=O
109
+ 4 president president NOUN NN Number=Sing 3 xcomp _ start_char=49|end_char=58|ner=O
110
+ 5 in in ADP IN _ 6 case _ start_char=59|end_char=61|ner=O
111
+ 6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ start_char=62|end_char=66|ner=S-DATE|SpaceAfter=No
112
+ 7 . . PUNCT . _ 3 punct _ start_char=66|end_char=67|ner=O|SpacesAfter=\\s\\s
113
+
+ # text = Obama attended Harvard.
+ # sent_id = 2
+ # constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .)))
+ # sentiment = 1
+ 1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=69|end_char=74|ner=S-PERSON
+ 2 attended attend VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root _ start_char=75|end_char=83|ner=O
+ 3 Harvard Harvard PROPN NNP Number=Sing 2 obj _ start_char=84|end_char=91|ner=S-ORG|SpaceAfter=No
+ 4 . . PUNCT . _ 2 punct _ start_char=91|end_char=92|ner=O|SpaceAfter=No
+ """.strip()
+
+ EN_DOC_CONLLU_GOLD_MULTIDOC = """
+ # text = Barack Obama was born in Hawaii.
+ # sent_id = 0
+ # constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. .)))
+ # sentiment = 1
+ 1 Barack Barack PROPN NNP Number=Sing 4 nsubj:pass _ start_char=0|end_char=6|ner=B-PERSON
+ 2 Obama Obama PROPN NNP Number=Sing 1 flat _ start_char=7|end_char=12|ner=E-PERSON
+ 3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 4 aux:pass _ start_char=13|end_char=16|ner=O
+ 4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=17|end_char=21|ner=O
+ 5 in in ADP IN _ 6 case _ start_char=22|end_char=24|ner=O
+ 6 Hawaii Hawaii PROPN NNP Number=Sing 4 obl _ start_char=25|end_char=31|ner=S-GPE|SpaceAfter=No
+ 7 . . PUNCT . _ 4 punct _ start_char=31|end_char=32|ner=O|SpaceAfter=No
+
+ # text = He was elected president in 2008.
+ # sent_id = 1
+ # constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. .)))
+ # sentiment = 1
+ 1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=0|end_char=2|ner=O
+ 2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=3|end_char=6|ner=O
+ 3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=7|end_char=14|ner=O
+ 4 president president NOUN NN Number=Sing 3 xcomp _ start_char=15|end_char=24|ner=O
+ 5 in in ADP IN _ 6 case _ start_char=25|end_char=27|ner=O
+ 6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ start_char=28|end_char=32|ner=S-DATE|SpaceAfter=No
+ 7 . . PUNCT . _ 3 punct _ start_char=32|end_char=33|ner=O|SpaceAfter=No
+
+ # text = Obama attended Harvard.
+ # sent_id = 2
+ # constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .)))
+ # sentiment = 1
+ 1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=0|end_char=5|ner=S-PERSON
+ 2 attended attend VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root _ start_char=6|end_char=14|ner=O
+ 3 Harvard Harvard PROPN NNP Number=Sing 2 obj _ start_char=15|end_char=22|ner=S-ORG|SpaceAfter=No
+ 4 . . PUNCT . _ 2 punct _ start_char=22|end_char=23|ner=O|SpaceAfter=No
+ """.strip()
+
+ class TestEnglishPipeline:
+     @pytest.fixture(scope="class")
+     def pipeline(self):
+         return stanza.Pipeline(dir=TEST_MODELS_DIR)
+
+     @pytest.fixture(scope="class")
+     def processed_doc(self, pipeline):
+         """ Document created by running full English pipeline on a few sentences """
+         return pipeline(EN_DOC)
+
+     def test_text(self, processed_doc):
+         assert processed_doc.text == EN_DOC
+
+     def test_conllu(self, processed_doc):
+         assert "{:C}".format(processed_doc) == EN_DOC_CONLLU_GOLD
+
+     def test_tokens(self, processed_doc):
+         assert "\n\n".join([sent.tokens_string() for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD
+
+     def test_words(self, processed_doc):
+         assert "\n\n".join([sent.words_string() for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD
+
+     def test_dependency_parse(self, processed_doc):
+         assert "\n\n".join([sent.dependencies_string() for sent in processed_doc.sentences]) == \
+             EN_DOC_DEPENDENCY_PARSES_GOLD
+
+     def test_empty(self, pipeline):
+         # make sure that various models handle the degenerate empty case
+         pipeline("")
+         pipeline("--")
+
+     def test_bulk_process(self, pipeline):
+         """ Double check that the bulk_process method in Pipeline converts documents as expected """
+         # it should process strings
+         processed = pipeline.bulk_process(EN_DOCS)
+         assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC
+
+         # it should pass Documents through successfully
+         docs = [Document([], text=t) for t in EN_DOCS]
+         processed = pipeline.bulk_process(docs)
+         assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC
+
+     def test_empty_bulk_process(self, pipeline):
+         """ Previously we had a bug where an empty document list would cause a crash """
+         processed = pipeline.bulk_process([])
+         assert processed == []
+
+     def test_stream(self, pipeline):
+         """ Test the streaming interface to the Pipeline """
+         # Test all of the documents in one batch
+         # (the default batch size is significantly more than |EN_DOCS|)
+         processed = [doc for doc in pipeline.stream(EN_DOCS)]
+         assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC
+
+         # It should also work on an iterator rather than an iterable
+         processed = [doc for doc in pipeline.stream(iter(EN_DOCS))]
+         assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC
+
+         # Stream one at a time
+         processed = [doc for doc in pipeline.stream(EN_DOCS, batch_size=1)]
+         processed = ["{:C}".format(doc) for doc in processed]
+         assert "\n\n".join(processed) == EN_DOC_CONLLU_GOLD_MULTIDOC
+
+     @pytest.fixture(scope="class")
+     def processed_multidoc(self, pipeline):
+         """ Documents created by running the full English pipeline on a few short texts """
+         docs = [Document([], text=t) for t in EN_DOCS]
+         return pipeline(docs)
+
+     def test_conllu_multidoc(self, processed_multidoc):
+         assert "\n\n".join(["{:C}".format(doc) for doc in processed_multidoc]) == EN_DOC_CONLLU_GOLD_MULTIDOC
+
+     def test_tokens_multidoc(self, processed_multidoc):
+         assert "\n\n".join([sent.tokens_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD
+
+     def test_words_multidoc(self, processed_multidoc):
+         assert "\n\n".join([sent.words_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD
+
+     def test_sentence_indices_multidoc(self, processed_multidoc):
+         sentences = [sent for doc in processed_multidoc for sent in doc.sentences]
+         for sent_idx, sentence in enumerate(sentences):
+             assert sent_idx == sentence.index
+
+     def test_dependency_parse_multidoc(self, processed_multidoc):
+         assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == \
+             EN_DOC_DEPENDENCY_PARSES_GOLD
+
+     @pytest.fixture(scope="class")
+     def processed_multidoc_variant(self):
+         """ Documents created by running the pipeline with the spacy tokenizer on a few short texts """
+         docs = [Document([], text=t) for t in EN_DOCS]
+         nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors={'tokenize': 'spacy'})
+         return nlp(docs)
+
+     def test_dependency_parse_multidoc_variant(self, processed_multidoc_variant):
+         assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc_variant for sent in processed_doc.sentences]) == \
+             EN_DOC_DEPENDENCY_PARSES_GOLD
+
+     def test_constituency_parser(self):
+         nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency")
+         doc = nlp("This is a test")
+         assert str(doc.sentences[0].constituency) == '(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))'
+
+     def test_on_gpu(self, pipeline):
+         """
+         The default pipeline should have all the models on the GPU
+         """
+         check_on_gpu(pipeline)
+
+     def test_on_cpu(self):
+         """
+         Create a pipeline on the CPU, check that all the models are on the CPU
+         """
+         pipeline = stanza.Pipeline("en", dir=TEST_MODELS_DIR, use_gpu=False)
+         check_on_cpu(pipeline)
stanza/stanza/tests/pipeline/test_french_pipeline.py ADDED
@@ -0,0 +1,353 @@
+ """
+ Basic testing of French pipeline
+
+ The benefit of this test is to verify that the bulk processing works
+ for languages with MWT in them
+ """
+
+ import pytest
+ import stanza
+ from stanza.models.common.doc import Document
+
+ from stanza.tests import *
+ from stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu
+
+ pytestmark = pytest.mark.pipeline
+
+ FR_MWT_SENTENCE = "Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de " \
+                   "l'Industrie et du Numérique."
+
+ EXPECTED_RESULT = """
+ [
+   [
+     {
+       "id": 1,
+       "text": "Alors",
+       "lemma": "alors",
+       "upos": "ADV",
+       "head": 3,
+       "deprel": "advmod",
+       "start_char": 0,
+       "end_char": 5
+     },
+     {
+       "id": 2,
+       "text": "encore",
+       "lemma": "encore",
+       "upos": "ADV",
+       "head": 3,
+       "deprel": "advmod",
+       "start_char": 6,
+       "end_char": 12
+     },
+     {
+       "id": 3,
+       "text": "inconnu",
+       "lemma": "inconnu",
+       "upos": "ADJ",
+       "feats": "Gender=Masc|Number=Sing",
+       "head": 11,
+       "deprel": "advcl",
+       "start_char": 13,
+       "end_char": 20
+     },
+     {
+       "id": [
+         4,
+         5
+       ],
+       "text": "du",
+       "start_char": 21,
+       "end_char": 23
+     },
+     {
+       "id": 4,
+       "text": "de",
+       "lemma": "de",
+       "upos": "ADP",
+       "head": 7,
+       "deprel": "case"
+     },
+     {
+       "id": 5,
+       "text": "le",
+       "lemma": "le",
+       "upos": "DET",
+       "feats": "Definite=Def|Gender=Masc|Number=Sing|PronType=Art",
+       "head": 7,
+       "deprel": "det"
+     },
+     {
+       "id": 6,
+       "text": "grand",
+       "lemma": "grand",
+       "upos": "ADJ",
+       "feats": "Gender=Masc|Number=Sing",
+       "head": 7,
+       "deprel": "amod",
+       "start_char": 24,
+       "end_char": 29
+     },
+     {
+       "id": 7,
+       "text": "public",
+       "lemma": "public",
+       "upos": "NOUN",
+       "feats": "Gender=Masc|Number=Sing",
+       "head": 3,
+       "deprel": "obl:arg",
+       "start_char": 30,
+       "end_char": 36,
+       "misc": "SpaceAfter=No"
+     },
+     {
+       "id": 8,
+       "text": ",",
+       "lemma": ",",
+       "upos": "PUNCT",
+       "head": 3,
+       "deprel": "punct",
+       "start_char": 36,
+       "end_char": 37
+     },
+     {
+       "id": 9,
+       "text": "Emmanuel",
+       "lemma": "Emmanuel",
+       "upos": "PROPN",
+       "head": 11,
+       "deprel": "nsubj",
+       "start_char": 38,
+       "end_char": 46
+     },
+     {
+       "id": 10,
+       "text": "Macron",
+       "lemma": "Macron",
+       "upos": "PROPN",
+       "head": 9,
+       "deprel": "flat:name",
+       "start_char": 47,
+       "end_char": 53
+     },
+     {
+       "id": 11,
+       "text": "devient",
+       "lemma": "devenir",
+       "upos": "VERB",
+       "feats": "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
+       "head": 0,
+       "deprel": "root",
+       "start_char": 54,
+       "end_char": 61
+     },
+     {
+       "id": 12,
+       "text": "en",
+       "lemma": "en",
+       "upos": "ADP",
+       "head": 13,
+       "deprel": "case",
+       "start_char": 62,
+       "end_char": 64
+     },
+     {
+       "id": 13,
+       "text": "2014",
+       "lemma": "2014",
+       "upos": "NUM",
+       "feats": "Number=Plur",
+       "head": 11,
+       "deprel": "obl:mod",
+       "start_char": 65,
+       "end_char": 69
+     },
+     {
+       "id": 14,
+       "text": "ministre",
+       "lemma": "ministre",
+       "upos": "NOUN",
+       "feats": "Gender=Masc|Number=Sing",
+       "head": 11,
+       "deprel": "xcomp",
+       "start_char": 70,
+       "end_char": 78
+     },
+     {
+       "id": 15,
+       "text": "de",
+       "lemma": "de",
+       "upos": "ADP",
+       "head": 17,
+       "deprel": "case",
+       "start_char": 79,
+       "end_char": 81
+     },
+     {
+       "id": 16,
+       "text": "l'",
+       "lemma": "le",
+       "upos": "DET",
+       "feats": "Definite=Def|Number=Sing|PronType=Art",
+       "head": 17,
+       "deprel": "det",
+       "start_char": 82,
+       "end_char": 84,
+       "misc": "SpaceAfter=No"
+     },
+     {
+       "id": 17,
+       "text": "Économie",
+       "lemma": "économie",
+       "upos": "NOUN",
+       "feats": "Gender=Fem|Number=Sing",
+       "head": 14,
+       "deprel": "nmod",
+       "start_char": 84,
+       "end_char": 92,
+       "misc": "SpaceAfter=No"
+     },
+     {
+       "id": 18,
+       "text": ",",
+       "lemma": ",",
+       "upos": "PUNCT",
+       "head": 21,
+       "deprel": "punct",
+       "start_char": 92,
+       "end_char": 93
+     },
+     {
+       "id": 19,
+       "text": "de",
+       "lemma": "de",
+       "upos": "ADP",
+       "head": 21,
+       "deprel": "case",
+       "start_char": 94,
+       "end_char": 96
+     },
+     {
+       "id": 20,
+       "text": "l'",
+       "lemma": "le",
+       "upos": "DET",
+       "feats": "Definite=Def|Number=Sing|PronType=Art",
+       "head": 21,
+       "deprel": "det",
+       "start_char": 97,
+       "end_char": 99,
+       "misc": "SpaceAfter=No"
+     },
+     {
+       "id": 21,
+       "text": "Industrie",
+       "lemma": "industrie",
+       "upos": "NOUN",
+       "feats": "Gender=Fem|Number=Sing",
+       "head": 17,
+       "deprel": "conj",
+       "start_char": 99,
+       "end_char": 108
+     },
+     {
+       "id": 22,
+       "text": "et",
+       "lemma": "et",
+       "upos": "CCONJ",
+       "head": 25,
+       "deprel": "cc",
+       "start_char": 109,
+       "end_char": 111
+     },
+     {
+       "id": [
+         23,
+         24
+       ],
+       "text": "du",
+       "start_char": 112,
+       "end_char": 114
+     },
+     {
+       "id": 23,
+       "text": "de",
+       "lemma": "de",
+       "upos": "ADP",
+       "head": 25,
+       "deprel": "case"
+     },
+     {
+       "id": 24,
+       "text": "le",
+       "lemma": "le",
+       "upos": "DET",
+       "feats": "Definite=Def|Gender=Masc|Number=Sing|PronType=Art",
+       "head": 25,
+       "deprel": "det"
+     },
+     {
+       "id": 25,
+       "text": "Numérique",
+       "lemma": "numérique",
+       "upos": "NOUN",
+       "feats": "Gender=Masc|Number=Sing",
+       "head": 17,
+       "deprel": "conj",
+       "start_char": 115,
+       "end_char": 124,
+       "misc": "SpaceAfter=No"
+     },
+     {
+       "id": 26,
+       "text": ".",
+       "lemma": ".",
+       "upos": "PUNCT",
+       "head": 11,
+       "deprel": "punct",
+       "start_char": 124,
+       "end_char": 125,
+       "misc": "SpaceAfter=No"
+     }
+   ]
+ ]
+ """
+
+ class TestFrenchPipeline:
+     @pytest.fixture(scope="class")
+     def pipeline(self):
+         """ Create a pipeline with French models """
+         pipeline = stanza.Pipeline(processors='tokenize,mwt,pos,lemma,depparse', dir=TEST_MODELS_DIR, lang='fr')
+         return pipeline
+
+     def test_single(self, pipeline):
+         doc = pipeline(FR_MWT_SENTENCE)
+         compare_ignoring_whitespace(str(doc), EXPECTED_RESULT)
+
+     def test_bulk(self, pipeline):
+         NUM_DOCS = 10
+         raw_text = [FR_MWT_SENTENCE] * NUM_DOCS
+         raw_doc = [Document([], text=doccontent) for doccontent in raw_text]
+
+         result = pipeline(raw_doc)
+
+         assert len(result) == NUM_DOCS
+         for doc in result:
+             compare_ignoring_whitespace(str(doc), EXPECTED_RESULT)
+             assert len(doc.sentences) == 1
+             assert doc.num_words == 26
+             assert doc.num_tokens == 24
+
+     def test_on_gpu(self, pipeline):
+         """
+         The default pipeline should have all the models on the GPU
+         """
+         check_on_gpu(pipeline)
+
+     def test_on_cpu(self):
+         """
+         Create a pipeline on the CPU, check that all the models are on the CPU
+         """
+         pipeline = stanza.Pipeline("fr", dir=TEST_MODELS_DIR, use_gpu=False)
+         check_on_cpu(pipeline)
stanza/stanza/tests/pipeline/test_lemmatizer.py ADDED
@@ -0,0 +1,135 @@
+ """
+ Basic testing of lemmatization
+ """
+
+ import pytest
+ import stanza
+
+ from stanza.tests import *
+ from stanza.models.common.doc import TEXT, UPOS, LEMMA
+
+ pytestmark = pytest.mark.pipeline
+
+ EN_DOC = "Joe Smith was born in California."
+
+ EN_DOC_IDENTITY_GOLD = """
+ Joe Joe
+ Smith Smith
+ was was
+ born born
+ in in
+ California California
+ . .
+ """.strip()
+
+ EN_DOC_LEMMATIZER_MODEL_GOLD = """
+ Joe Joe
+ Smith Smith
+ was be
+ born bear
+ in in
+ California California
+ . .
+ """.strip()
+
+
+ def test_identity_lemmatizer():
+     nlp = stanza.Pipeline(**{'processors': 'tokenize,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_use_identity': True}, download_method=None)
+     doc = nlp(EN_DOC)
+     word_lemma_pairs = []
+     for w in doc.iter_words():
+         word_lemma_pairs += [f"{w.text} {w.lemma}"]
+     assert EN_DOC_IDENTITY_GOLD == "\n".join(word_lemma_pairs)
+
+ def test_full_lemmatizer():
+     nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, download_method=None)
+     doc = nlp(EN_DOC)
+     word_lemma_pairs = []
+     for w in doc.iter_words():
+         word_lemma_pairs += [f"{w.text} {w.lemma}"]
+     assert EN_DOC_LEMMATIZER_MODEL_GOLD == "\n".join(word_lemma_pairs)
+
+ def find_unknown_word(lemmatizer, base):
+     for i in range(10):
+         base = base + "z"
+         if base not in lemmatizer.word_dict and all(x[0] != base for x in lemmatizer.composite_dict.keys()):
+             return base
+     raise RuntimeError("Could not construct a word unknown to the lemmatizer from base %s" % base)
+
+ def test_store_results():
+     nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, lemma_store_results=True, download_method=None)
+     lemmatizer = nlp.processors["lemma"]._trainer
+
+     az = find_unknown_word(lemmatizer, "a")
+     bz = find_unknown_word(lemmatizer, "b")
+     cz = find_unknown_word(lemmatizer, "c")
+
+     # try sentences with the order long, short
+     doc = nlp("I found an " + az + " in my " + bz + ". It was a " + cz)
+     stuff = doc.get([TEXT, UPOS, LEMMA])
+     assert len(stuff) == 12
+     assert stuff[3][0] == az
+     assert stuff[6][0] == bz
+     assert stuff[11][0] == cz
+
+     assert lemmatizer.composite_dict[(az, stuff[3][1])] == stuff[3][2]
+     assert lemmatizer.composite_dict[(bz, stuff[6][1])] == stuff[6][2]
+     assert lemmatizer.composite_dict[(cz, stuff[11][1])] == stuff[11][2]
+
+     doc2 = nlp("I found an " + az + " in my " + bz + ". It was a " + cz)
+     stuff2 = doc2.get([TEXT, UPOS, LEMMA])
+
+     assert stuff == stuff2
+
+     dz = find_unknown_word(lemmatizer, "d")
+     ez = find_unknown_word(lemmatizer, "e")
+     fz = find_unknown_word(lemmatizer, "f")
+
+     # try sentences with the order short, long
+     doc = nlp("It was a " + dz + ". I found an " + ez + " in my " + fz)
+     stuff = doc.get([TEXT, UPOS, LEMMA])
+     assert len(stuff) == 12
+     assert stuff[3][0] == dz
+     assert stuff[8][0] == ez
+     assert stuff[11][0] == fz
+
+     assert lemmatizer.composite_dict[(dz, stuff[3][1])] == stuff[3][2]
+     assert lemmatizer.composite_dict[(ez, stuff[8][1])] == stuff[8][2]
+     assert lemmatizer.composite_dict[(fz, stuff[11][1])] == stuff[11][2]
+
+     doc2 = nlp("It was a " + dz + ". I found an " + ez + " in my " + fz)
+     stuff2 = doc2.get([TEXT, UPOS, LEMMA])
+
+     assert stuff == stuff2
+
+     assert az not in lemmatizer.word_dict
+
+ def test_caseless_lemmatizer():
+     """
+     Test that setting the lemmatizer as caseless at Pipeline time lowercases the text
+     """
+     nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None)
+     # the capital letter here could throw off the lemmatizer,
+     # although the current English model *does* lowercase the word
+     doc = nlp("Here is an Excerpt")
+     assert doc.sentences[0].words[-1].lemma == 'excerpt'
+
+     nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None, lemma_caseless=True)
+     # with the model set to lowercasing, the word will be treated as if it were 'excerpt'
+     doc = nlp("Here is an Excerpt")
+     assert doc.sentences[0].words[-1].lemma == 'Excerpt'
+
+ def test_latin_caseless_lemmatizer():
+     """
+     Test the Latin caseless lemmatizer
+     """
+     nlp = stanza.Pipeline('la', package='ittb', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None)
+     lemmatizer = nlp.processors['lemma']
+     assert lemmatizer.config['caseless']
+
+     doc = nlp("Quod Erat Demonstrandum")
+     expected_lemmas = "qui sum demonstro".split()
+     assert len(doc.sentences) == 1
+     assert len(doc.sentences[0].words) == 3
+     for word, expected in zip(doc.sentences[0].words, expected_lemmas):
+         assert word.lemma == expected
stanza/stanza/tests/pipeline/test_pipeline_constituency_processor.py ADDED
@@ -0,0 +1,61 @@
+ import gc
+ import pytest
+ import stanza
+ from stanza.models.common.foundation_cache import FoundationCache
+
+ from stanza.tests import *
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ # data for testing
+ TEST_TEXT = "This is a test. Another sentence. Are these sorted?"
+
+ TEST_TOKENS = [["This", "is", "a", "test", "."], ["Another", "sentence", "."], ["Are", "these", "sorted", "?"]]
+
+ @pytest.fixture(scope="module")
+ def foundation_cache():
+     gc.collect()
+     return FoundationCache()
+
+ def check_results(doc):
+     assert len(doc.sentences) == len(TEST_TOKENS)
+     for sentence, expected in zip(doc.sentences, TEST_TOKENS):
+         assert sentence.constituency.leaf_labels() == expected
+
+ def test_sorted_big_batch(foundation_cache):
+     pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None)
+     doc = pipe(TEST_TEXT)
+     check_results(doc)
+
+ def test_comments(foundation_cache):
+     """
+     Test that the pipeline is creating constituency comments
+     """
+     pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None)
+     doc = pipe(TEST_TEXT)
+     check_results(doc)
+     for sentence in doc.sentences:
+         assert any(x.startswith("# constituency = ") for x in sentence.comments)
+     doc.sentences[0].constituency = "asdf"
+     assert "# constituency = asdf" in doc.sentences[0].comments
+     for sentence in doc.sentences:
+         assert len([x for x in sentence.comments if x.startswith("# constituency")]) == 1
+
+ def test_illegal_batch_size(foundation_cache):
+     # an unparseable batch size is fine as long as the constituency processor is not used...
+     stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos", constituency_batch_size="zzz", foundation_cache=foundation_cache, download_method=None)
+     # ...but should raise an error when the constituency processor needs it
+     with pytest.raises(ValueError):
+         stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size="zzz", foundation_cache=foundation_cache, download_method=None)
+
+ def test_sorted_one_batch(foundation_cache):
+     pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size=1, foundation_cache=foundation_cache, download_method=None)
+     doc = pipe(TEST_TEXT)
+     check_results(doc)
+
+ def test_sorted_two_batch(foundation_cache):
+     pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size=2, foundation_cache=foundation_cache, download_method=None)
+     doc = pipe(TEST_TEXT)
+     check_results(doc)
+
+ def test_get_constituents(foundation_cache):
+     pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None)
+     assert "SBAR" in pipe.processors["constituency"].get_constituents()