Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- stanza/stanza/models/constituency_parser.py +881 -0
- stanza/stanza/models/lemmatizer.py +313 -0
- stanza/stanza/pipeline/_constants.py +13 -0
- stanza/stanza/pipeline/external/spacy.py +74 -0
- stanza/stanza/pipeline/ner_processor.py +143 -0
- stanza/stanza/resources/print_charlm_depparse.py +22 -0
- stanza/stanza/server/dependency_converter.py +101 -0
- stanza/stanza/tests/classifiers/test_constituency_classifier.py +128 -0
- stanza/stanza/tests/common/__init__.py +0 -0
- stanza/stanza/tests/common/test_chuliu_edmonds.py +36 -0
- stanza/stanza/tests/common/test_confusion.py +81 -0
- stanza/stanza/tests/common/test_constant.py +67 -0
- stanza/stanza/tests/common/test_data_conversion.py +520 -0
- stanza/stanza/tests/common/test_foundation_cache.py +36 -0
- stanza/stanza/tests/common/test_pretrain.py +139 -0
- stanza/stanza/tests/common/test_utils.py +194 -0
- stanza/stanza/tests/constituency/__init__.py +0 -0
- stanza/stanza/tests/constituency/test_convert_arboretum.py +235 -0
- stanza/stanza/tests/constituency/test_ensemble.py +110 -0
- stanza/stanza/tests/constituency/test_in_order_compound_oracle.py +93 -0
- stanza/stanza/tests/constituency/test_parse_transitions.py +486 -0
- stanza/stanza/tests/constituency/test_parse_tree.py +369 -0
- stanza/stanza/tests/constituency/test_positional_encoding.py +45 -0
- stanza/stanza/tests/constituency/test_selftrain_vi_quad.py +23 -0
- stanza/stanza/tests/constituency/test_utils.py +68 -0
- stanza/stanza/tests/data/example_french.json +22 -0
- stanza/stanza/tests/data/test.dat +0 -0
- stanza/stanza/tests/data/tiny_emb.csv +4 -0
- stanza/stanza/tests/datasets/__init__.py +0 -0
- stanza/stanza/tests/datasets/ner/__init__.py +0 -0
- stanza/stanza/tests/datasets/ner/test_prepare_ner_file.py +77 -0
- stanza/stanza/tests/datasets/ner/test_utils.py +34 -0
- stanza/stanza/tests/lemma/test_data.py +106 -0
- stanza/stanza/tests/lemma/test_lemma_trainer.py +154 -0
- stanza/stanza/tests/lemma_classifier/test_data_preparation.py +256 -0
- stanza/stanza/tests/mwt/test_character_classifier.py +92 -0
- stanza/stanza/tests/mwt/test_english_corner_cases.py +88 -0
- stanza/stanza/tests/ner/test_bsf_2_iob.py +93 -0
- stanza/stanza/tests/ner/test_convert_amt.py +104 -0
- stanza/stanza/tests/ner/test_convert_starlang_ner.py +23 -0
- stanza/stanza/tests/ner/test_from_conllu.py +30 -0
- stanza/stanza/tests/ner/test_ner_utils.py +129 -0
- stanza/stanza/tests/pipeline/__init__.py +0 -0
- stanza/stanza/tests/pipeline/test_arabic_pipeline.py +27 -0
- stanza/stanza/tests/pipeline/test_core.py +248 -0
- stanza/stanza/tests/pipeline/test_depparse.py +87 -0
- stanza/stanza/tests/pipeline/test_english_pipeline.py +279 -0
- stanza/stanza/tests/pipeline/test_french_pipeline.py +353 -0
- stanza/stanza/tests/pipeline/test_lemmatizer.py +135 -0
- stanza/stanza/tests/pipeline/test_pipeline_constituency_processor.py +61 -0
stanza/stanza/models/constituency_parser.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A command line interface to a shift reduce constituency parser.
|
| 2 |
+
|
| 3 |
+
This follows the work of
|
| 4 |
+
Recurrent neural network grammars by Dyer et al
|
| 5 |
+
In-Order Transition-based Constituent Parsing by Liu & Zhang
|
| 6 |
+
|
| 7 |
+
The general outline is:
|
| 8 |
+
|
| 9 |
+
Train a model by taking a list of trees, converting them to
|
| 10 |
+
transition sequences, and learning a model which can predict the
|
| 11 |
+
next transition given a current state
|
| 12 |
+
Then, at inference time, repeatedly predict the next transition until parsing is complete
|
| 13 |
+
|
| 14 |
+
The "transitions" are variations on shift/reduce as per an
|
| 15 |
+
intro-to-compilers class. The idea is that you can treat all of the
|
| 16 |
+
words in a sentence as a buffer of tokens, then either "shift" them to
|
| 17 |
+
represent a new constituent, or "reduce" one or more constituents to
|
| 18 |
+
form a new constituent.
|
| 19 |
+
|
| 20 |
+
In order to make the runtime a more competitive speed, effort is taken
|
| 21 |
+
to batch the transitions and apply multiple transitions at once. At
|
| 22 |
+
train time, batches are groups together by length, and at inference
|
| 23 |
+
time, new trees are added to the batch as previous trees on the batch
|
| 24 |
+
finish their inference.
|
| 25 |
+
|
| 26 |
+
There are a few minor differences in the model:
|
| 27 |
+
- The word input is a bi-lstm, not a uni-lstm.
|
| 28 |
+
This gave a small increase in accuracy.
|
| 29 |
+
- The combination of several constituents into one constituent is done
|
| 30 |
+
via a single bi-lstm rather than two separate lstms. This increases
|
| 31 |
+
speed without a noticeable effect on accuracy.
|
| 32 |
+
- In fact, an even better (in terms of final model accuracy) method
|
| 33 |
+
is to combine the constituents with torch.max, believe it or not
|
| 34 |
+
See lstm_model.py for more details
|
| 35 |
+
- Initializing the embeddings with smaller values than pytorch default
|
| 36 |
+
For example, on a ja_alt dataset, scores went from 0.8980 to 0.8985
|
| 37 |
+
at 200 iterations averaged over 5 trials
|
| 38 |
+
- Training with AdaDelta first, then AdamW or madgrad later improves
|
| 39 |
+
results quite a bit. See --multistage
|
| 40 |
+
|
| 41 |
+
A couple experiments which have been tried with little noticeable impact:
|
| 42 |
+
- Combining constituents using the method in the paper (only a trained
|
| 43 |
+
vector at the start instead of both ends) did not affect results
|
| 44 |
+
and is a little slower
|
| 45 |
+
- Using multiple layers of LSTM hidden state for the input to the final
|
| 46 |
+
classification layers didn't help
|
| 47 |
+
- Initializing Linear layers with He initialization and a positive bias
|
| 48 |
+
(to avoid dead connections) had no noticeable effect on accuracy
|
| 49 |
+
0.8396 on it_turin with the original initialization
|
| 50 |
+
0.8401 and 0.8427 on two runs with updated initialization
|
| 51 |
+
(so maybe a small improvement...)
|
| 52 |
+
- Initializing LSTM layers with different gates was slightly worse:
|
| 53 |
+
forget gates of 1.0
|
| 54 |
+
forget gates of 1.0, input gates of -1.0
|
| 55 |
+
- Replacing the LSTMs that make up the Transition and Constituent
|
| 56 |
+
LSTMs with Dynamic Skip LSTMs made no difference, but was slower
|
| 57 |
+
- Highway LSTMs also made no difference
|
| 58 |
+
- Putting labels on the shift transitions (the word or the tag shifted)
|
| 59 |
+
or putting labels on the close transitions didn't help
|
| 60 |
+
- Building larger constituents from the output of the constituent LSTM
|
| 61 |
+
instead of the children constituents hurts scores
|
| 62 |
+
For example, an experiment on ja_alt went from 0.8985 to 0.8964
|
| 63 |
+
when built that way
|
| 64 |
+
- The initial transition scheme implemented was TOP_DOWN. We tried
|
| 65 |
+
a compound unary option, since this worked so well in the CoreNLP
|
| 66 |
+
constituency parser. Unfortunately, this is far less effective
|
| 67 |
+
than IN_ORDER. Both specialized unary matrices and reusing the
|
| 68 |
+
n-ary constituency combination fell short. On the ja_alt dataset:
|
| 69 |
+
IN_ORDER, max combination method: 0.8985
|
| 70 |
+
TOP_DOWN_UNARY, specialized matrices: 0.8501
|
| 71 |
+
TOP_DOWN_UNARY, max combination method: 0.8508
|
| 72 |
+
- Adding multiple layers of MLP to combine inputs for words made
|
| 73 |
+
no difference in the scores
|
| 74 |
+
Tried both before the LSTM and after
|
| 75 |
+
A simple single layer tensor multiply after the LSTM works well.
|
| 76 |
+
Replacing that with a two layer MLP on the English PTB
|
| 77 |
+
with roberta-base causes a notable drop in scores
|
| 78 |
+
First experiment didn't use the fancy Linear weight init,
|
| 79 |
+
but adding that barely made a difference
|
| 80 |
+
260 training iterations on en_wsj dev, roberta-base
|
| 81 |
+
model as of bb983fd5e912f6706ad484bf819486971742c3d1
|
| 82 |
+
two layer MLP: 0.9409
|
| 83 |
+
two layer MLP, init weights: 0.9413
|
| 84 |
+
single layer: 0.9467
|
| 85 |
+
- There is code to rebuild models with a new structure in lstm_model.py
|
| 86 |
+
As part of this, we tried to randomly reinitialize the transitions
|
| 87 |
+
if the transition embedding had gone to 0, which often happens
|
| 88 |
+
This didn't help at all
|
| 89 |
+
- We tried something akin to attention with just the query vector
|
| 90 |
+
over the bert embeddings as a way to mix them, but that did not
|
| 91 |
+
improve scores.
|
| 92 |
+
Example, with a self.bert_layer_mix of size bert_dim x 1:
|
| 93 |
+
mixed_bert_embeddings = []
|
| 94 |
+
for feature in bert_embeddings:
|
| 95 |
+
weighted_feature = self.bert_layer_mix(feature.transpose(1, 2))
|
| 96 |
+
weighted_feature = torch.softmax(weighted_feature, dim=1)
|
| 97 |
+
weighted_feature = torch.matmul(feature, weighted_feature).squeeze(2)
|
| 98 |
+
mixed_bert_embeddings.append(weighted_feature)
|
| 99 |
+
bert_embeddings = mixed_bert_embeddings
|
| 100 |
+
It seems just finetuning the transformer is already enough
|
| 101 |
+
(in general, no need to mix layers at all when finetuning bert embeddings)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
The code breakdown is as follows:
|
| 105 |
+
|
| 106 |
+
this file: main interface for training or evaluating models
|
| 107 |
+
constituency/trainer.py: contains the training & evaluation code
|
| 108 |
+
constituency/ensemble.py: evaluation code specifically for letting multiple models
|
| 109 |
+
vote on the correct next transition. a modest improvement.
|
| 110 |
+
constituency/evaluate_treebanks.py: specifically to evaluate multiple parsed treebanks
|
| 111 |
+
against a gold. in particular, reports whether the theoretical best from those
|
| 112 |
+
parsed treebanks is an improvement (eg, the k-best score as reported by CoreNLP)
|
| 113 |
+
|
| 114 |
+
constituency/parse_tree.py: a data structure for representing a parse tree and utility methods
|
| 115 |
+
constituency/tree_reader.py: a module which can read trees from a string or input file
|
| 116 |
+
|
| 117 |
+
constituency/tree_stack.py: a linked list which can branch in
|
| 118 |
+
different directions, which will be useful when implementing beam
|
| 119 |
+
search or a dynamic oracle
|
| 120 |
+
constituency/lstm_tree_stack.py: an LSTM over the elements of a TreeStack
|
| 121 |
+
constituency/transformer_tree_stack.py: attempts to run attention over the nodes
|
| 122 |
+
of a tree_stack. not as effective as the lstm_tree_stack in the initial experiments.
|
| 123 |
+
perhaps it could be refined to work better, though
|
| 124 |
+
|
| 125 |
+
constituency/parse_transitions.py: transitions and a State data structure to store them
|
| 126 |
+
constituency/transition_sequence.py: turns ParseTree objects into
|
| 127 |
+
the transition sequences needed to make them
|
| 128 |
+
|
| 129 |
+
constituency/base_model.py: operates on the transitions to turn them in to constituents,
|
| 130 |
+
eventually forming one final parse tree composed of all of the constituents
|
| 131 |
+
constituency/lstm_model.py: adds LSTM features to the constituents to predict what the
|
| 132 |
+
correct transition to make is, allowing for predictions on previously unseen text
|
| 133 |
+
|
| 134 |
+
constituency/retagging.py: a couple utility methods specifically for retagging
|
| 135 |
+
constituency/utils.py: a couple utility methods
|
| 136 |
+
|
| 137 |
+
constituency/dyanmic_oracle.py: a dynamic oracle which currently
|
| 138 |
+
only operates for the inorder transition sequence.
|
| 139 |
+
uses deterministic rules to redo the correct action sequence when
|
| 140 |
+
the parser makes an error.
|
| 141 |
+
|
| 142 |
+
constituency/partitioned_transformer.py: implementation of a transformer for self-attention.
|
| 143 |
+
presumably this should help, but we have yet to find a model structure where
|
| 144 |
+
this makes the scores go up.
|
| 145 |
+
constituency/label_attention.py: an even fancier form of transformer based on labeled attention:
|
| 146 |
+
https://arxiv.org/abs/1911.03875
|
| 147 |
+
constituency/positional_encoding.py: so far, just the sinusoidal is here.
|
| 148 |
+
a trained encoding is in partitioned_transformer.py.
|
| 149 |
+
this should probably be refactored to common, especially if used elsewhere.
|
| 150 |
+
|
| 151 |
+
stanza/pipeline/constituency_processor.py: interface between this model and the Pipeline
|
| 152 |
+
|
| 153 |
+
stanza/utils/datasets/constituency: various scripts and tools for processing constituency datasets
|
| 154 |
+
|
| 155 |
+
Some alternate optimizer methods:
|
| 156 |
+
adabelief: https://github.com/juntang-zhuang/Adabelief-Optimizer
|
| 157 |
+
madgrad: https://github.com/facebookresearch/madgrad
|
| 158 |
+
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
import argparse
|
| 162 |
+
import logging
|
| 163 |
+
import os
|
| 164 |
+
import re
|
| 165 |
+
|
| 166 |
+
import torch
|
| 167 |
+
|
| 168 |
+
import stanza
|
| 169 |
+
from stanza.models.common import constant
|
| 170 |
+
from stanza.models.common import utils
|
| 171 |
+
from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
|
| 172 |
+
from stanza.models.constituency import parser_training
|
| 173 |
+
from stanza.models.constituency import retagging
|
| 174 |
+
from stanza.models.constituency.lstm_model import ConstituencyComposition, SentenceBoundary, StackHistory
|
| 175 |
+
from stanza.models.constituency.parse_transitions import TransitionScheme
|
| 176 |
+
from stanza.models.constituency.text_processing import load_model_parse_text
|
| 177 |
+
from stanza.models.constituency.utils import DEFAULT_LEARNING_EPS, DEFAULT_LEARNING_RATES, DEFAULT_MOMENTUM, DEFAULT_LEARNING_RHO, DEFAULT_WEIGHT_DECAY, NONLINEARITY, add_predict_output_args, postprocess_predict_output_args
|
| 178 |
+
from stanza.resources.common import DEFAULT_MODEL_DIR
|
| 179 |
+
|
| 180 |
+
logger = logging.getLogger('stanza')
|
| 181 |
+
tlogger = logging.getLogger('stanza.constituency.trainer')
|
| 182 |
+
|
| 183 |
+
def build_argparse():
|
| 184 |
+
"""
|
| 185 |
+
Adds the arguments for building the con parser
|
| 186 |
+
|
| 187 |
+
For the most part, defaults are set to cross-validated values, at least for WSJ
|
| 188 |
+
"""
|
| 189 |
+
parser = argparse.ArgumentParser()
|
| 190 |
+
|
| 191 |
+
parser.add_argument('--data_dir', type=str, default='data/constituency', help='Directory of constituency data.')
|
| 192 |
+
|
| 193 |
+
parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors')
|
| 194 |
+
parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')
|
| 195 |
+
parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
|
| 196 |
+
parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
|
| 197 |
+
|
| 198 |
+
parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
|
| 199 |
+
parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
|
| 200 |
+
|
| 201 |
+
# BERT helps a lot and actually doesn't slow things down too much
|
| 202 |
+
# for VI, for example, use vinai/phobert-base
|
| 203 |
+
parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
|
| 204 |
+
parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
|
| 205 |
+
parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer")
|
| 206 |
+
parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use layers 2,3,4 of the Bert embedding')
|
| 207 |
+
|
| 208 |
+
# BERT finetuning (or any transformer finetuning)
|
| 209 |
+
# also helps quite a lot.
|
| 210 |
+
# Experimentally, finetuning all of the layers is the most effective
|
| 211 |
+
# On the id_icon dataset with the indolem transformer
|
| 212 |
+
# In this experiment, we trained for 150 iterations with AdaDelta,
|
| 213 |
+
# with the learning rate 0.01,
|
| 214 |
+
# then trained for another 150 with madgrad and no finetuning
|
| 215 |
+
# 1 layer 0.880753 (152)
|
| 216 |
+
# 2 layers 0.880453 (174)
|
| 217 |
+
# 3 layers 0.881774 (163)
|
| 218 |
+
# 4 layers 0.886915 (194)
|
| 219 |
+
# 5 layers 0.892064 (299)
|
| 220 |
+
# 6 layers 0.891825 (224)
|
| 221 |
+
# 7 layers 0.894373 (173)
|
| 222 |
+
# 8 layers 0.894505 (233)
|
| 223 |
+
# 9 layers 0.896676 (269)
|
| 224 |
+
# 10 layers 0.897525 (269)
|
| 225 |
+
# 11 layers 0.897348 (211)
|
| 226 |
+
# 12 layers 0.898729 (270)
|
| 227 |
+
# everything 0.898855 (252)
|
| 228 |
+
# so the trend is clear that more finetuning is better
|
| 229 |
+
#
|
| 230 |
+
# We found that finetuning works very well on the AdaDelta portion
|
| 231 |
+
# of a multistage training, but less well on a madgrad second
|
| 232 |
+
# stage. The issue was that we literally could not set the
|
| 233 |
+
# learning rate low enough because madgrad used epsilon in the LR:
|
| 234 |
+
# https://github.com/facebookresearch/madgrad/issues/16
|
| 235 |
+
#
|
| 236 |
+
# Possible values of the AdaDelta learning rate on the id_icon dataset
|
| 237 |
+
# In this experiment, we finetuned the entire transformer 150
|
| 238 |
+
# iterations on AdaDelta, then trained with madgrad for another
|
| 239 |
+
# 150 with no finetuning
|
| 240 |
+
# 0.0005: 0.89122 (155)
|
| 241 |
+
# 0.001: 0.889807 (241)
|
| 242 |
+
# 0.002: 0.894874 (202)
|
| 243 |
+
# 0.005: 0.896327 (270)
|
| 244 |
+
# 0.006: 0.898989 (246)
|
| 245 |
+
# 0.007: 0.896712 (167)
|
| 246 |
+
# 0.008: 0.900136 (237)
|
| 247 |
+
# 0.009: 0.898597 (169)
|
| 248 |
+
# 0.01: 0.898665 (251)
|
| 249 |
+
# 0.012: 0.89661 (274)
|
| 250 |
+
# 0.014: 0.899149 (283)
|
| 251 |
+
# 0.016: 0.896314 (230)
|
| 252 |
+
# 0.018: 0.897753 (257)
|
| 253 |
+
# 0.02: 0.893665 (256)
|
| 254 |
+
# 0.05: 0.849274 (159)
|
| 255 |
+
# 0.1: 0.850633 (183)
|
| 256 |
+
# 0.2: 0.847332 (176)
|
| 257 |
+
#
|
| 258 |
+
# The peak is somewhere around 0.008 to 0.014, with the further
|
| 259 |
+
# observation that at the 150 iteration mark, 0.009 was winning:
|
| 260 |
+
# 0.007: 0.894589 (33)
|
| 261 |
+
# 0.008: 0.894777 (53)
|
| 262 |
+
# 0.009: 0.896466 (56)
|
| 263 |
+
# 0.01: 0.895557 (71)
|
| 264 |
+
# 0.012: 0.893479 (45)
|
| 265 |
+
# 0.014: 0.89468 (116)
|
| 266 |
+
# 0.016: 0.893053 (128)
|
| 267 |
+
# 0.018: 0.893086 (48)
|
| 268 |
+
#
|
| 269 |
+
# Another option is to train for a few iterations with no
|
| 270 |
+
# finetuning, then begin finetuning. However, that was not
|
| 271 |
+
# beneficial at all.
|
| 272 |
+
# Start iteration on id_icon, same setup as above:
|
| 273 |
+
# 1: 0.898855 (252)
|
| 274 |
+
# 5: 0.897885 (217)
|
| 275 |
+
# 10: 0.895367 (215)
|
| 276 |
+
# 25: 0.896781 (193)
|
| 277 |
+
# 50: 0.895216 (193)
|
| 278 |
+
# Using adamw instead of madgrad:
|
| 279 |
+
# 1: 0.900594 (226)
|
| 280 |
+
# 5: 0.898153 (267)
|
| 281 |
+
# 10: 0.898756 (271)
|
| 282 |
+
# 25: 0.896867 (256)
|
| 283 |
+
# 50: 0.895025 (220)
|
| 284 |
+
#
|
| 285 |
+
#
|
| 286 |
+
# With the observation that very low learning rate is currently
|
| 287 |
+
# not working for madgrad, we tried to parameter sweep LR for
|
| 288 |
+
# AdamW, and got the following, using a first stage LR of 0.009:
|
| 289 |
+
# 0.0: 0.899706 (290)
|
| 290 |
+
# 0.00005: 0.899631 (176)
|
| 291 |
+
# 0.0001: 0.899851 (233)
|
| 292 |
+
# 0.0002: 0.898601 (207)
|
| 293 |
+
# 0.0003: 0.899258 (252)
|
| 294 |
+
# 0.0004: 0.90033 (187)
|
| 295 |
+
# 0.0005: 0.899091 (183)
|
| 296 |
+
# 0.001: 0.899791 (268)
|
| 297 |
+
# 0.002: 0.899453 (196)
|
| 298 |
+
# 0.003: 0.897029 (173)
|
| 299 |
+
# 0.004: 0.899566 (290)
|
| 300 |
+
# 0.005: 0.899285 (289)
|
| 301 |
+
# 0.01: 0.898938 (233)
|
| 302 |
+
# 0.02: 0.898983 (248)
|
| 303 |
+
# 0.03: 0.898571 (247)
|
| 304 |
+
# 0.04: 0.898466 (180)
|
| 305 |
+
# 0.05: 0.897448 (214)
|
| 306 |
+
# It should be noted that in the 0.0001 range, the epoch to epoch
|
| 307 |
+
# change of the Bert weights was almost negligible. Weights would
|
| 308 |
+
# change in the 5th or 6th decimal place, if at all.
|
| 309 |
+
#
|
| 310 |
+
# The conclusion of all these experiments is that, if we are using
|
| 311 |
+
# bert_finetuning, the best approach is probably a stage1 learning
|
| 312 |
+
# rate of 0.009 or so and a second stage optimizer of adamw with
|
| 313 |
+
# no LR or a very low LR. This behavior is what happens with the
|
| 314 |
+
# --stage1_bert_finetune flag
|
| 315 |
+
parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
|
| 316 |
+
parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
|
| 317 |
+
parser.add_argument('--bert_finetune_layers', default=None, type=int, help='Only finetune this many layers from the transformer')
|
| 318 |
+
parser.add_argument('--bert_finetune_begin_epoch', default=None, type=int, help='Which epoch to start finetuning the transformer')
|
| 319 |
+
parser.add_argument('--bert_finetune_end_epoch', default=None, type=int, help='Which epoch to stop finetuning the transformer')
|
| 320 |
+
parser.add_argument('--bert_learning_rate', default=0.009, type=float, help='Scale the learning rate for transformer finetuning by this much')
|
| 321 |
+
parser.add_argument('--stage1_bert_learning_rate', default=None, type=float, help="Scale the learning rate for transformer finetuning by this much only during an AdaDelta warmup")
|
| 322 |
+
parser.add_argument('--bert_weight_decay', default=0.0001, type=float, help='Scale the weight decay for transformer finetuning by this much')
|
| 323 |
+
parser.add_argument('--stage1_bert_finetune', default=None, action='store_true', help="Finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune")
|
| 324 |
+
parser.add_argument('--no_stage1_bert_finetune', dest='stage1_bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer) during an AdaDelta warmup, even if the second half doesn't use bert_finetune")
|
| 325 |
+
|
| 326 |
+
add_peft_args(parser)
|
| 327 |
+
|
| 328 |
+
parser.add_argument('--tag_embedding_dim', type=int, default=20, help="Embedding size for a tag. 0 turns off the feature")
|
| 329 |
+
# Smaller values also seem to work
|
| 330 |
+
# For example, after 700 iterations:
|
| 331 |
+
# 32: 0.9174
|
| 332 |
+
# 50: 0.9183
|
| 333 |
+
# 72: 0.9176
|
| 334 |
+
# 100: 0.9185
|
| 335 |
+
# not a huge difference regardless
|
| 336 |
+
# (these numbers were without retagging)
|
| 337 |
+
parser.add_argument('--delta_embedding_dim', type=int, default=100, help="Embedding size for a delta embedding")
|
| 338 |
+
|
| 339 |
+
parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
|
| 340 |
+
parser.add_argument('--no_train_remove_duplicates', default=True, action='store_false', dest="train_remove_duplicates", help="Do/don't remove duplicates from the training file. Could be useful for intentionally reweighting some trees")
|
| 341 |
+
parser.add_argument('--silver_file', type=str, default=None, help='Secondary training file.')
|
| 342 |
+
parser.add_argument('--silver_remove_duplicates', default=False, action='store_true', help="Do/don't remove duplicates from the silver training file. Could be useful for intentionally reweighting some trees")
|
| 343 |
+
parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
|
| 344 |
+
# TODO: possibly refactor --tokenized_file / --tokenized_dir from here & ensemble
|
| 345 |
+
parser.add_argument('--xml_tree_file', type=str, default=None, help='Input file of VLSP formatted trees for parsing with parse_text.')
|
| 346 |
+
parser.add_argument('--tokenized_file', type=str, default=None, help='Input file of tokenized text for parsing with parse_text.')
|
| 347 |
+
parser.add_argument('--tokenized_dir', type=str, default=None, help='Input directory of tokenized text for parsing with parse_text.')
|
| 348 |
+
parser.add_argument('--mode', default='train', choices=['train', 'parse_text', 'predict', 'remove_optimizer'])
|
| 349 |
+
parser.add_argument('--num_generate', type=int, default=0, help='When running a dev set, how many sentences to generate beyond the greedy one')
|
| 350 |
+
add_predict_output_args(parser)
|
| 351 |
+
|
| 352 |
+
parser.add_argument('--lang', type=str, help='Language')
|
| 353 |
+
parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
|
| 354 |
+
|
| 355 |
+
parser.add_argument('--transition_embedding_dim', type=int, default=20, help="Embedding size for a transition")
|
| 356 |
+
parser.add_argument('--transition_hidden_size', type=int, default=20, help="Embedding size for transition stack")
|
| 357 |
+
parser.add_argument('--transition_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()],
|
| 358 |
+
help='How to track transitions over a parse. {}'.format(", ".join(x.name for x in StackHistory)))
|
| 359 |
+
parser.add_argument('--transition_heads', default=4, type=int, help="How many heads to use in MHA *if* the transition_stack is Attention")
|
| 360 |
+
|
| 361 |
+
parser.add_argument('--constituent_stack', default=StackHistory.LSTM, type=lambda x: StackHistory[x.upper()],
|
| 362 |
+
help='How to track transitions over a parse. {}'.format(", ".join(x.name for x in StackHistory)))
|
| 363 |
+
parser.add_argument('--constituent_heads', default=8, type=int, help="How many heads to use in MHA *if* the transition_stack is Attention")
|
| 364 |
+
|
| 365 |
+
# larger was more effective, up to a point
|
| 366 |
+
# substantially smaller, such as 128,
|
| 367 |
+
# is fine if bert & charlm are not available
|
| 368 |
+
parser.add_argument('--hidden_size', type=int, default=512, help="Size of the output layers for constituency stack and word queue")
|
| 369 |
+
|
| 370 |
+
parser.add_argument('--epochs', type=int, default=400)
|
| 371 |
+
parser.add_argument('--epoch_size', type=int, default=5000, help="Runs this many trees in an 'epoch' instead of going through the training dataset exactly once. Set to 0 to do the whole training set")
|
| 372 |
+
parser.add_argument('--silver_epoch_size', type=int, default=None, help="Runs this many trees in a silver 'epoch'. If not set, will match --epoch_size")
|
| 373 |
+
|
| 374 |
+
# AdaDelta warmup for the conparser. Motivation: AdaDelta results in
|
| 375 |
+
# higher scores overall, but learns 0s for the weights of the pattn and
|
| 376 |
+
# lattn layers. AdamW learns weights for pattn, and the models are more
|
| 377 |
+
# accurate than models trained without pattn using AdamW, but the models
|
| 378 |
+
# are lower scores overall than the AdaDelta models.
|
| 379 |
+
#
|
| 380 |
+
# This improves that by first running AdaDelta, then switching.
|
| 381 |
+
#
|
| 382 |
+
# Now, if --multistage is set, run AdaDelta for half the epochs with no
|
| 383 |
+
# pattn or lattn. Then start the specified optimizer for the rest of
|
| 384 |
+
# the time with the full model. If pattn and lattn are both present,
|
| 385 |
+
# the model is 1/2 no attn, 1/4 pattn, 1/4 pattn and lattn
|
| 386 |
+
#
|
| 387 |
+
# Improvement on the WSJ dev set can be seen from 94.8 to 95.3
|
| 388 |
+
# when 4 layers of pattn are trained this way.
|
| 389 |
+
# More experiments to follow.
|
| 390 |
+
parser.add_argument('--multistage', default=True, action='store_true', help='1/2 epochs with adadelta no pattn or lattn, 1/4 with chosen optim and no lattn, 1/4 full model')
|
| 391 |
+
parser.add_argument('--no_multistage', dest='multistage', action='store_false', help="don't do the multistage learning")
|
| 392 |
+
|
| 393 |
+
# 1 seems to be the most effective, but we should cross-validate
|
| 394 |
+
parser.add_argument('--oracle_initial_epoch', type=int, default=1, help="Epoch where we start using the dynamic oracle to let the parser keep going with wrong decisions")
|
| 395 |
+
parser.add_argument('--oracle_frequency', type=float, default=0.8, help="How often to use the oracle vs how often to force the correct transition")
|
| 396 |
+
parser.add_argument('--oracle_forced_errors', type=float, default=0.001, help="Occasionally have the model randomly walk through the state space to try to learn how to recover")
|
| 397 |
+
parser.add_argument('--oracle_level', type=int, default=None, help='Restrict oracle transitions to this level or lower. 0 means off. None means use all oracle transitions.')
|
| 398 |
+
parser.add_argument('--additional_oracle_levels', type=str, default=None, help='Add some additional experimental oracle transitions. Basically for A/B testing transitions we expect to be bad.')
|
| 399 |
+
parser.add_argument('--deactivated_oracle_levels', type=str, default=None, help='Temporarily turn off a default oracle level. Basically for A/B testing transitions we expect to be bad.')
|
| 400 |
+
|
| 401 |
+
# 30 is slightly slower than 50, for example, but seems to train a bit better on WSJ
|
| 402 |
+
# earlier version of the model (less accurate overall) had the following results with adadelta:
|
| 403 |
+
# 30: 0.9085
|
| 404 |
+
# 50: 0.9070
|
| 405 |
+
# 75: 0.9010
|
| 406 |
+
# 150: 0.8985
|
| 407 |
+
# as another data point, running a newer version with better constituency lstm behavior had:
|
| 408 |
+
# 30: 0.9111
|
| 409 |
+
# 50: 0.9094
|
| 410 |
+
# checking smaller batch sizes to see how this works, at 135 epochs, the values are
|
| 411 |
+
# 10: 0.8919
|
| 412 |
+
# 20: 0.9072
|
| 413 |
+
# 30: 0.9121
|
| 414 |
+
# obviously these experiments aren't the complete story, but it
|
| 415 |
+
# looks like 30 trees per batch is the best value for WSJ
|
| 416 |
+
# note that these numbers are for adadelta and might not apply
|
| 417 |
+
# to other optimizers
|
| 418 |
+
# eval batch should generally be faster the bigger the batch,
|
| 419 |
+
# up to a point, as it allows for more batching of the LSTM
|
| 420 |
+
# operations and the prediction step
|
| 421 |
+
parser.add_argument('--train_batch_size', type=int, default=30, help='How many trees to train before taking an optimizer step')
|
| 422 |
+
parser.add_argument('--eval_batch_size', type=int, default=50, help='How many trees to batch when running eval')
|
| 423 |
+
|
| 424 |
+
parser.add_argument('--save_dir', type=str, default='saved_models/constituency', help='Root dir for saving models.')
|
| 425 |
+
parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{finetune}_constituency.pt", help="File name to save the model")
|
| 426 |
+
parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. Mostly for testing")
|
| 427 |
+
parser.add_argument('--save_each_start', type=int, default=None, help="When to start saving each model")
|
| 428 |
+
parser.add_argument('--save_each_frequency', type=int, default=1, help="How frequently to save each model")
|
| 429 |
+
parser.add_argument('--no_save_each_optimizer', dest='save_each_optimizer', default=True, action='store_false', help="Don't save the optimizer when saving 'each' model")
|
| 430 |
+
|
| 431 |
+
parser.add_argument('--seed', type=int, default=1234)
|
| 432 |
+
|
| 433 |
+
parser.add_argument('--no_check_valid_states', default=True, action='store_false', dest='check_valid_states', help="Don't check the constituents or transitions in the dev set when starting a new parser. Warning: the parser will never guess unknown constituents")
|
| 434 |
+
parser.add_argument('--no_strict_check_constituents', default=True, action='store_false', dest='strict_check_constituents', help="Don't check the constituents between the train & dev set. May result in untrainable transitions")
|
| 435 |
+
utils.add_device_args(parser)
|
| 436 |
+
|
| 437 |
+
# Numbers are on a VLSP dataset, before adding attn or other improvements
|
| 438 |
+
# baseline is an 80.6 model that occurs when trained using adadelta, lr 1.0
|
| 439 |
+
#
|
| 440 |
+
# adabelief 0.1: fails horribly
|
| 441 |
+
# 0.02: converges very low scores
|
| 442 |
+
# 0.01: very slow learning
|
| 443 |
+
# 0.002: almost decent
|
| 444 |
+
# 0.001: close, but about 1 f1 low on IT
|
| 445 |
+
# 0.0005: 79.71
|
| 446 |
+
# 0.0002: 80.11
|
| 447 |
+
# 0.0001: 79.85
|
| 448 |
+
# 0.00005: 80.40
|
| 449 |
+
# 0.00002: 80.02
|
| 450 |
+
# 0.00001: 78.95
|
| 451 |
+
#
|
| 452 |
+
# madgrad 0.005: fails horribly
|
| 453 |
+
# 0.001: low scores
|
| 454 |
+
# 0.0005: still somewhat low
|
| 455 |
+
# 0.0002: close, but about 1 f1 low on IT
|
| 456 |
+
# 0.0001: 80.04
|
| 457 |
+
# 0.00005: 79.91
|
| 458 |
+
# 0.00002: 80.15
|
| 459 |
+
# 0.00001: 80.44
|
| 460 |
+
# 0.000005: 80.34
|
| 461 |
+
# 0.000002: 80.39
|
| 462 |
+
#
|
| 463 |
+
# adamw experiment on a TR dataset (not necessarily the best test case)
|
| 464 |
+
# note that at that time, the expected best for adadelta was 0.816
|
| 465 |
+
#
|
| 466 |
+
# 0.00005 - 0.7925
|
| 467 |
+
# 0.0001 - 0.7889
|
| 468 |
+
# 0.0002 - 0.8110
|
| 469 |
+
# 0.00025 - 0.8108
|
| 470 |
+
# 0.0003 - 0.8050
|
| 471 |
+
# 0.0005 - 0.8076
|
| 472 |
+
# 0.001 - 0.8069
|
| 473 |
+
|
| 474 |
+
# Numbers on the VLSP Dataset, with --multistage and default learning rates and adabelief optimizer
|
| 475 |
+
# Gelu: 82.32
|
| 476 |
+
# Mish: 81.95
|
| 477 |
+
# ELU: 81.73
|
| 478 |
+
# Hardshrink: 0.3
|
| 479 |
+
# Hardsigmoid: 79.03
|
| 480 |
+
# Hardtanh: 81.44
|
| 481 |
+
# Hardswish: 81.67
|
| 482 |
+
# Logsigmoid: 80.91
|
| 483 |
+
# Prelu: 80.95 (terminated early)
|
| 484 |
+
# Relu6: 81.91
|
| 485 |
+
# RReLU: 77.00
|
| 486 |
+
# Selu: 81.17
|
| 487 |
+
# Celu: 81.43
|
| 488 |
+
# Silu: 81.90
|
| 489 |
+
# Softplus: 80.94
|
| 490 |
+
# Softshrink: 0.3
|
| 491 |
+
# Softsign: 81.63
|
| 492 |
+
# Softshrink: 13.74
|
| 493 |
+
#
|
| 494 |
+
# Tests with no_charlm, --multistage
|
| 495 |
+
# Gelu
|
| 496 |
+
# 0.00002 0.819746
|
| 497 |
+
# 0.00005 0.818
|
| 498 |
+
# 0.0001 0.818566
|
| 499 |
+
# 0.0002 0.819111
|
| 500 |
+
# 0.001 0.815609
|
| 501 |
+
#
|
| 502 |
+
# Mish
|
| 503 |
+
# 0.00002 0.816898
|
| 504 |
+
# 0.00005 0.821085
|
| 505 |
+
# 0.0001 0.817821
|
| 506 |
+
# 0.0002 0.818806
|
| 507 |
+
# 0.001 0.816494
|
| 508 |
+
#
|
| 509 |
+
# Relu
|
| 510 |
+
# 0.00002 0.818402
|
| 511 |
+
# 0.00005 0.819019
|
| 512 |
+
# 0.0001 0.821625
|
| 513 |
+
# 0.0002 0.820633
|
| 514 |
+
# 0.001 0.814315
|
| 515 |
+
#
|
| 516 |
+
# Relu6
|
| 517 |
+
# 0.00002 0.819719
|
| 518 |
+
# 0.00005 0.819871
|
| 519 |
+
# 0.0001 0.819018
|
| 520 |
+
# 0.0002 0.819506
|
| 521 |
+
# 0.001 0.819018
|
| 522 |
+
|
| 523 |
+
parser.add_argument('--learning_rate', default=None, type=float, help='Learning rate for the optimizer. Reasonable values are 1.0 for adadelta or 0.001 for SGD. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_RATES))
|
| 524 |
+
parser.add_argument('--learning_eps', default=None, type=float, help='eps value to use in the optimizer. None uses a default for the given optimizer: {}'.format(DEFAULT_LEARNING_EPS))
|
| 525 |
+
parser.add_argument('--learning_momentum', default=None, type=float, help='Momentum. None uses a default for the given optimizer: {}'.format(DEFAULT_MOMENTUM))
|
| 526 |
+
# weight decay values other than adadelta have not been thoroughly tested.
|
| 527 |
+
# When using adadelta, weight_decay of 0.01 to 0.001 had the best results.
|
| 528 |
+
# 0.1 was very clearly too high. 0.0001 might have been okay.
|
| 529 |
+
# Running a series of 5x experiments on a VI dataset:
|
| 530 |
+
# 0.030: 0.8167018
|
| 531 |
+
# 0.025: 0.81659
|
| 532 |
+
# 0.020: 0.81722
|
| 533 |
+
# 0.015: 0.81721
|
| 534 |
+
# 0.010: 0.81474348
|
| 535 |
+
# 0.005: 0.81503
|
| 536 |
+
parser.add_argument('--learning_weight_decay', default=None, type=float, help='Weight decay (eg, l2 reg) to use in the optimizer')
|
| 537 |
+
parser.add_argument('--learning_rho', default=DEFAULT_LEARNING_RHO, type=float, help='Rho parameter in Adadelta')
|
| 538 |
+
# A few experiments on beta2 didn't show much benefit from changing it
|
| 539 |
+
# On an experiment with training WSJ with default parameters
|
| 540 |
+
# AdaDelta for 200 iterations, then training AdamW for 200 more,
|
| 541 |
+
# 0.999, 0.997, 0.995 all wound up with 0.9588
|
| 542 |
+
# values lower than 0.995 all had a slight dropoff
|
| 543 |
+
parser.add_argument('--learning_beta2', default=0.999, type=float, help='Beta2 argument for AdamW')
|
| 544 |
+
parser.add_argument('--optim', default=None, help='Optimizer type: SGD, AdamW, Adadelta, AdaBelief, Madgrad')
|
| 545 |
+
|
| 546 |
+
parser.add_argument('--stage1_learning_rate', default=None, type=float, help='Learning rate to use in the first stage of --multistage. None means use default: {}'.format(DEFAULT_LEARNING_RATES['adadelta']))
|
| 547 |
+
|
| 548 |
+
parser.add_argument('--learning_rate_warmup', default=0, type=int, help="Number of epochs to ramp up learning rate from 0 to full. Set to 0 to always use the chosen learning rate. Currently not functional, as it didn't do anything")
|
| 549 |
+
|
| 550 |
+
parser.add_argument('--learning_rate_factor', default=0.6, type=float, help='Plateau learning rate decreate when plateaued')
|
| 551 |
+
parser.add_argument('--learning_rate_patience', default=5, type=int, help='Plateau learning rate patience')
|
| 552 |
+
parser.add_argument('--learning_rate_cooldown', default=10, type=int, help='Plateau learning rate cooldown')
|
| 553 |
+
parser.add_argument('--learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum')
|
| 554 |
+
parser.add_argument('--stage1_learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum (stage 1)')
|
| 555 |
+
|
| 556 |
+
parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
|
| 557 |
+
parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')
|
| 558 |
+
|
| 559 |
+
# Large Margin is from Large Margin In Softmax Cross-Entropy Loss
|
| 560 |
+
# it did not help on an Italian VIT test
|
| 561 |
+
# scores went from 0.8252 to 0.8248
|
| 562 |
+
parser.add_argument('--loss', default='cross', help='cross, large_margin, or focal. Focal requires `pip install focal_loss_torch`')
|
| 563 |
+
parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')
|
| 564 |
+
|
| 565 |
+
# turn off dropout for word_dropout, predict_dropout, and lstm_input_dropout
|
| 566 |
+
# this mechanism doesn't actually turn off lstm_layer_dropout (yet)
|
| 567 |
+
# but that is set to a default of 0 anyway
|
| 568 |
+
# this is reusing the idea presented in
|
| 569 |
+
# https://arxiv.org/pdf/2303.01500v2
|
| 570 |
+
# "Dropout Reduces Underfitting"
|
| 571 |
+
# Zhuang Liu, Zhiqiu Xu, Joseph Jin, Zhiqiang Shen, Trevor Darrell
|
| 572 |
+
# Unfortunately, this does not consistently help results
|
| 573 |
+
# Averaged of 5 models w/ transformer, dev / test
|
| 574 |
+
# id_icon - improves a little
|
| 575 |
+
# baseline 0.8823 0.8904
|
| 576 |
+
# early_dropout 40 0.8835 0.8919
|
| 577 |
+
# ja_alt - worsens a little
|
| 578 |
+
# baseline 0.9308 0.9355
|
| 579 |
+
# early_dropout 40 0.9287 0.9345
|
| 580 |
+
# vi_vlsp23 - worsens a little
|
| 581 |
+
# baseline 0.8262 0.8290
|
| 582 |
+
# early_dropout 40 0.8255 0.8286
|
| 583 |
+
# We keep this as an available option for further experiments, if needed
|
| 584 |
+
parser.add_argument('--early_dropout', default=-1, type=int, help='When to turn off dropout')
|
| 585 |
+
# When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations:
|
| 586 |
+
# 0.0: 0.9085
|
| 587 |
+
# 0.2: 0.9165
|
| 588 |
+
# 0.4: 0.9162
|
| 589 |
+
# 0.5: 0.9123
|
| 590 |
+
# Letting 0.2 and 0.4 run for longer, along with 0.3 as another
|
| 591 |
+
# trial, continued to give extremely similar results over time.
|
| 592 |
+
# No attempt has been made to test the different dropouts separately...
|
| 593 |
+
parser.add_argument('--word_dropout', default=0.2, type=float, help='Dropout on the word embedding')
|
| 594 |
+
parser.add_argument('--predict_dropout', default=0.2, type=float, help='Dropout on the final prediction layer')
|
| 595 |
+
# lstm_dropout has not been fully tested yet
|
| 596 |
+
# one experiment after 200 iterations (after retagging, so scores are lower than some other experiments):
|
| 597 |
+
# 0.0: 0.9093
|
| 598 |
+
# 0.1: 0.9094
|
| 599 |
+
# 0.2: 0.9094
|
| 600 |
+
# 0.3: 0.9076
|
| 601 |
+
# 0.4: 0.9077
|
| 602 |
+
parser.add_argument('--lstm_layer_dropout', default=0.0, type=float, help='Dropout in the LSTM layers')
|
| 603 |
+
# one not very conclusive experiment (not long enough) came up with these numbers after ~200 iterations
|
| 604 |
+
# 0.0 0.9091
|
| 605 |
+
# 0.1 0.9095
|
| 606 |
+
# 0.2 0.9118
|
| 607 |
+
# 0.3 0.9123
|
| 608 |
+
# 0.4 0.9080
|
| 609 |
+
parser.add_argument('--lstm_input_dropout', default=0.2, type=float, help='Dropout on the input to an LSTM')
|
| 610 |
+
|
| 611 |
+
parser.add_argument('--transition_scheme', default=TransitionScheme.IN_ORDER, type=lambda x: TransitionScheme[x.upper()],
|
| 612 |
+
help='Transition scheme to use. {}'.format(", ".join(x.name for x in TransitionScheme)))
|
| 613 |
+
|
| 614 |
+
parser.add_argument('--reversed', default=False, action='store_true', help='Do the transition sequence reversed')
|
| 615 |
+
|
| 616 |
+
# combining dummy and open node embeddings might be a slight improvement
|
| 617 |
+
# for example, after 550 iterations, one experiment had
|
| 618 |
+
# True: 0.9154
|
| 619 |
+
# False: 0.9150
|
| 620 |
+
# another (with a different structure) had 850 iterations
|
| 621 |
+
# True: 0.9155
|
| 622 |
+
# False: 0.9149
|
| 623 |
+
parser.add_argument('--combined_dummy_embedding', default=True, action='store_true', help="Use the same embedding for dummy nodes and the vectors used when combining constituents")
|
| 624 |
+
parser.add_argument('--no_combined_dummy_embedding', dest='combined_dummy_embedding', action='store_false', help="Don't use the same embedding for dummy nodes and the vectors used when combining constituents")
|
| 625 |
+
|
| 626 |
+
# relu gave at least 1 F1 improvement over tanh in various experiments
|
| 627 |
+
# relu & gelu seem roughly the same, but relu is clearly faster.
|
| 628 |
+
# relu, 496 iterations: 0.9176
|
| 629 |
+
# gelu, 467 iterations: 0.9181
|
| 630 |
+
# after the same clock time on the same hardware. the two had been
|
| 631 |
+
# trading places in terms of accuracy over those ~500 iterations.
|
| 632 |
+
# leaky_relu was not an improvement - a full run on WSJ led to 0.9181 f1 instead of 0.919
|
| 633 |
+
# See constituency/utils.py for more extensive comments on nonlinearity options
|
| 634 |
+
parser.add_argument('--nonlinearity', default='relu', choices=NONLINEARITY.keys(), help='Nonlinearity to use in the model. relu is a noticeable improvement over tanh')
|
| 635 |
+
# In one experiment on an Italian dataset, VIT, we got the following:
|
| 636 |
+
# 0.8254 with relu as the nonlinearity (10 trials)
|
| 637 |
+
# 0.8265 with maxout, k = 2 (15)
|
| 638 |
+
# 0.8253 with maxout, k = 3 (5)
|
| 639 |
+
# The speed in terms of trees/second might be slightly slower with maxout.
|
| 640 |
+
# 51.4 it/s on a Titan Xp with maxout 2 and 51.9 it/s with relu
|
| 641 |
+
# It might also be worth running some experiments with bigger
|
| 642 |
+
# output layers to see if that makes up for the difference in score.
|
| 643 |
+
parser.add_argument('--maxout_k', default=None, type=int, help="Use maxout layers instead of a nonlinearity for the output layers")
|
| 644 |
+
|
| 645 |
+
parser.add_argument('--use_silver_words', default=True, dest='use_silver_words', action='store_true', help="Train/don't train word vectors for words only in the silver dataset")
|
| 646 |
+
parser.add_argument('--no_use_silver_words', default=True, dest='use_silver_words', action='store_false', help="Train/don't train word vectors for words only in the silver dataset")
|
| 647 |
+
parser.add_argument('--rare_word_unknown_frequency', default=0.02, type=float, help='How often to replace a rare word with UNK when training')
|
| 648 |
+
parser.add_argument('--rare_word_threshold', default=0.02, type=float, help='How many words to consider as rare words as a fraction of the dataset')
|
| 649 |
+
parser.add_argument('--tag_unknown_frequency', default=0.001, type=float, help='How often to replace a tag with UNK when training')
|
| 650 |
+
|
| 651 |
+
parser.add_argument('--num_lstm_layers', default=2, type=int, help='How many layers to use in the LSTMs')
|
| 652 |
+
parser.add_argument('--num_tree_lstm_layers', default=None, type=int, help='How many layers to use in the TREE_LSTMs, if used. This also increases the width of the word outputs to match the tree lstm inputs. Default 2 if TREE_LSTM or TREE_LSTM_CX, 1 otherwise')
|
| 653 |
+
parser.add_argument('--num_output_layers', default=3, type=int, help='How many layers to use at the prediction level')
|
| 654 |
+
|
| 655 |
+
parser.add_argument('--sentence_boundary_vectors', default=SentenceBoundary.EVERYTHING, type=lambda x: SentenceBoundary[x.upper()],
|
| 656 |
+
help='Vectors to learn at the start & end of sentences. {}'.format(", ".join(x.name for x in SentenceBoundary)))
|
| 657 |
+
parser.add_argument('--constituency_composition', default=ConstituencyComposition.MAX, type=lambda x: ConstituencyComposition[x.upper()],
|
| 658 |
+
help='How to build a new composition from its children. {}'.format(", ".join(x.name for x in ConstituencyComposition)))
|
| 659 |
+
parser.add_argument('--reduce_heads', default=8, type=int, help='Number of attn heads to use when reducing children into a parent tree (constituency_composition == attn)')
|
| 660 |
+
parser.add_argument('--reduce_position', default=None, type=int, help="Dimension of position vector to use when reducing children. None means 1/4 hidden_size, 0 means don't use (constituency_composition == key | untied_key)")
|
| 661 |
+
|
| 662 |
+
parser.add_argument('--relearn_structure', action='store_true', help='Starting from an existing checkpoint, add or remove pattn / lattn. One thing that works well is to train an initial model using adadelta with no pattn, then add pattn with adamw')
|
| 663 |
+
parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `load_name` path')
|
| 664 |
+
parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
|
| 665 |
+
parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
|
| 666 |
+
parser.add_argument('--load_name', type=str, default=None, help='Model to load when finetuning, evaluating, or manipulating an existing file')
|
| 667 |
+
parser.add_argument('--load_package', type=str, default=None, help='Download an existing stanza package & use this for tests, finetuning, etc')
|
| 668 |
+
|
| 669 |
+
retagging.add_retag_args(parser)
|
| 670 |
+
|
| 671 |
+
# Partitioned Attention
|
| 672 |
+
parser.add_argument('--pattn_d_model', default=1024, type=int, help='Partitioned attention model dimensionality')
|
| 673 |
+
parser.add_argument('--pattn_morpho_emb_dropout', default=0.2, type=float, help='Dropout rate for morphological features obtained from pretrained model')
|
| 674 |
+
parser.add_argument('--pattn_encoder_max_len', default=512, type=int, help='Max length that can be put into the transformer attention layer')
|
| 675 |
+
parser.add_argument('--pattn_num_heads', default=8, type=int, help='Partitioned attention model number of attention heads')
|
| 676 |
+
parser.add_argument('--pattn_d_kv', default=64, type=int, help='Size of the query and key vector')
|
| 677 |
+
parser.add_argument('--pattn_d_ff', default=2048, type=int, help='Size of the intermediate vectors in the feed-forward sublayer')
|
| 678 |
+
parser.add_argument('--pattn_relu_dropout', default=0.1, type=float, help='ReLU dropout probability in feed-forward sublayer')
|
| 679 |
+
parser.add_argument('--pattn_residual_dropout', default=0.2, type=float, help='Residual dropout probability for all residual connections')
|
| 680 |
+
parser.add_argument('--pattn_attention_dropout', default=0.2, type=float, help='Attention dropout probability')
|
| 681 |
+
parser.add_argument('--pattn_num_layers', default=0, type=int, help='Number of layers for the Partitioned Attention. Currently turned off')
|
| 682 |
+
parser.add_argument('--pattn_bias', default=False, action='store_true', help='Whether or not to learn an additive bias')
|
| 683 |
+
# Results seem relatively similar with learned position embeddings or sin/cos position embeddings
|
| 684 |
+
parser.add_argument('--pattn_timing', default='sin', choices=['learned', 'sin'], help='Use a learned embedding or a sin embedding')
|
| 685 |
+
|
| 686 |
+
# Label Attention
|
| 687 |
+
parser.add_argument('--lattn_d_input_proj', default=None, type=int, help='If set, project the non-positional inputs down to this size before proceeding.')
|
| 688 |
+
parser.add_argument('--lattn_d_kv', default=64, type=int, help='Dimension of the key/query vector')
|
| 689 |
+
parser.add_argument('--lattn_d_proj', default=64, type=int, help='Dimension of the output vector from each label attention head')
|
| 690 |
+
parser.add_argument('--lattn_resdrop', default=True, action='store_true', help='Whether or not to use Residual Dropout')
|
| 691 |
+
parser.add_argument('--lattn_pwff', default=True, action='store_true', help='Whether or not to use a Position-wise Feed-forward Layer')
|
| 692 |
+
parser.add_argument('--lattn_q_as_matrix', default=False, action='store_true', help='Whether or not Label Attention uses learned query vectors. False means it does')
|
| 693 |
+
parser.add_argument('--lattn_partitioned', default=True, action='store_true', help='Whether or not it is partitioned')
|
| 694 |
+
parser.add_argument('--no_lattn_partitioned', default=True, action='store_false', dest='lattn_partitioned', help='Whether or not it is partitioned')
|
| 695 |
+
parser.add_argument('--lattn_combine_as_self', default=False, action='store_true', help='Whether or not the layer uses concatenation. False means it does')
|
| 696 |
+
# currently unused - always assume 1/2 of pattn
|
| 697 |
+
#parser.add_argument('--lattn_d_positional', default=512, type=int, help='Dimension for the positional embedding')
|
| 698 |
+
parser.add_argument('--lattn_d_l', default=32, type=int, help='Number of labels')
|
| 699 |
+
parser.add_argument('--lattn_attention_dropout', default=0.2, type=float, help='Dropout for attention layer')
|
| 700 |
+
parser.add_argument('--lattn_d_ff', default=2048, type=int, help='Dimension of the Feed-forward layer')
|
| 701 |
+
parser.add_argument('--lattn_relu_dropout', default=0.2, type=float, help='Relu dropout for the label attention')
|
| 702 |
+
parser.add_argument('--lattn_residual_dropout', default=0.2, type=float, help='Residual dropout for the label attention')
|
| 703 |
+
parser.add_argument('--lattn_combined_input', default=True, action='store_true', help='Combine all inputs for the lattn, not just the pattn')
|
| 704 |
+
parser.add_argument('--use_lattn', default=False, action='store_true', help='Use the lattn layers - currently turned off')
|
| 705 |
+
parser.add_argument('--no_use_lattn', dest='use_lattn', action='store_false', help='Use the lattn layers - currently turned off')
|
| 706 |
+
parser.add_argument('--no_lattn_combined_input', dest='lattn_combined_input', action='store_false', help="Don't combine all inputs for the lattn, not just the pattn")
|
| 707 |
+
|
| 708 |
+
parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training. A very noisy option')
|
| 709 |
+
parser.add_argument('--log_shapes', default=False, action='store_true', help='Log the parameters shapes at the beginning')
|
| 710 |
+
parser.add_argument('--watch_regex', default=None, help='regex to describe which weights and biases to output, if any')
|
| 711 |
+
|
| 712 |
+
parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
|
| 713 |
+
parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
|
| 714 |
+
parser.add_argument('--wandb_norm_regex', default=None, help='Log on wandb any tensor whose norm matches this matrix. Might get cluttered?')
|
| 715 |
+
|
| 716 |
+
return parser
|
| 717 |
+
|
| 718 |
+
def build_model_filename(args):
    """Expand the constituency parser's save_name template into a full path.

    Substitutes the known template fields (shorthand, embedding, finetune
    markers, transition scheme, etc) into args['save_name'], collapses any
    leftover underscore runs, and prefixes args['save_dir'] when the name
    does not already live there.
    """
    if args['bert_finetune'] or args['stage1_bert_finetune']:
        maybe_finetune = "finetuned"
    else:
        maybe_finetune = ""
    if args['bert_finetune_begin_epoch'] is not None:
        transformer_finetune_begin = "%d" % args['bert_finetune_begin_epoch']
    else:
        transformer_finetune_begin = ""
    template_fields = {
        'shorthand': args['shorthand'],
        'oracle_level': args['oracle_level'],
        'embedding': utils.embedding_name(args),
        'finetune': maybe_finetune,
        'transformer_finetune_begin': transformer_finetune_begin,
        'transition_scheme': args['transition_scheme'].name.lower().replace("_", ""),
        'tscheme': args['transition_scheme'].short_name,
        'trans_layers': args['bert_hidden_layers'],
        'seed': args['seed'],
    }
    model_save_file = args['save_name'].format(**template_fields)
    # fields that expand to "" can leave consecutive underscores behind
    model_save_file = re.sub("_+", "_", model_save_file)
    logger.info("Expanded save_name: %s", model_save_file)

    # place the file under save_dir unless the expanded name already is
    if os.path.split(model_save_file)[0] != args['save_dir']:
        model_save_file = os.path.join(args['save_dir'], model_save_file)
    return model_save_file
|
| 738 |
+
|
| 739 |
+
def parse_args(args=None):
    """
    Parse and post-process command line arguments for the constituency parser.

    Beyond raw argparse parsing, this fills in defaults that depend on other
    options (optimizer choice, learning rates, tree-lstm depth) and expands
    the save/checkpoint filenames.  Returns the arguments as a dict.
    """
    parser = build_argparse()

    args = parser.parse_args(args=args)
    resolve_peft_args(args, logger, check_bert_finetune=False)
    # derive the language code from a lang_dataset style shorthand if not given
    if not args.lang and args.shorthand and len(args.shorthand.split("_", maxsplit=1)) == 2:
        args.lang = args.shorthand.split("_")[0]

    # stage 1 bert LR falls back to the general bert LR
    if args.stage1_bert_learning_rate is None:
        args.stage1_bert_learning_rate = args.bert_learning_rate

    # pick an optimizer when training and none was requested
    if args.optim is None and args.mode == 'train':
        if not args.multistage:
            # this seemed to work the best when not doing multistage
            args.optim = "adadelta"
            if args.use_peft and not args.bert_finetune:
                logger.info("--use_peft set. setting --bert_finetune as well")
                args.bert_finetune = True
        elif args.bert_finetune or args.stage1_bert_finetune:
            logger.info("Multistage training is set, optimizer is not chosen, and bert finetuning is active. Will use AdamW as the second stage optimizer.")
            args.optim = "adamw"
        else:
            # if MADGRAD exists, use it
            # otherwise, adamw
            try:
                import madgrad
                args.optim = "madgrad"
                logger.info("Multistage training is set, optimizer is not chosen, and MADGRAD is available. Will use MADGRAD as the second stage optimizer.")
            except ModuleNotFoundError as e:
                logger.warning("Multistage training is set. Best models are with MADGRAD, but it is not installed. Will use AdamW for the second stage optimizer. Consider installing MADGRAD")
                args.optim = "adamw"

    if args.mode == 'train':
        # per-optimizer defaults for the various hyperparameters
        if args.learning_rate is None:
            args.learning_rate = DEFAULT_LEARNING_RATES.get(args.optim.lower(), None)
        if args.learning_eps is None:
            args.learning_eps = DEFAULT_LEARNING_EPS.get(args.optim.lower(), None)
        if args.learning_momentum is None:
            args.learning_momentum = DEFAULT_MOMENTUM.get(args.optim.lower(), None)
        if args.learning_weight_decay is None:
            args.learning_weight_decay = DEFAULT_WEIGHT_DECAY.get(args.optim.lower(), None)

        # stage 1 of multistage training always defaults to adadelta settings
        if args.stage1_learning_rate is None:
            args.stage1_learning_rate = DEFAULT_LEARNING_RATES["adadelta"]
        if args.stage1_bert_finetune is None:
            args.stage1_bert_finetune = args.bert_finetune

        # floor for LR decay schedules: 2% of the starting rate
        if args.learning_rate_min_lr is None:
            args.learning_rate_min_lr = args.learning_rate * 0.02
        if args.stage1_learning_rate_min_lr is None:
            args.stage1_learning_rate_min_lr = args.stage1_learning_rate * 0.02

    # default position vector width for key-style reductions
    if args.reduce_position is None:
        args.reduce_position = args.hidden_size // 4

    # deeper tree lstm only when the composition actually uses a tree lstm
    if args.num_tree_lstm_layers is None:
        if args.constituency_composition in (ConstituencyComposition.TREE_LSTM, ConstituencyComposition.TREE_LSTM_CX):
            args.num_tree_lstm_layers = 2
        else:
            args.num_tree_lstm_layers = 1

    # asking for any wandb detail implies wandb logging
    if args.wandb_name or args.wandb_norm_regex:
        args.wandb = True

    args = vars(args)

    retagging.postprocess_args(args)
    postprocess_predict_output_args(args)

    # expand the save_name template into the actual model path
    model_save_file = build_model_filename(args)
    args['save_name'] = model_save_file

    if args['save_each_name']:
        model_save_each_file = os.path.join(args['save_dir'], args['save_each_name'])
        model_save_each_file = utils.build_save_each_filename(model_save_each_file)
        args['save_each_name'] = model_save_each_file
    else:
        # in the event that there is a start epoch setting,
        # this will make a reasonable default for the path
        pieces = os.path.splitext(args['save_name'])
        model_save_each_file = pieces[0] + "_%04d" + pieces[1]
        args['save_each_name'] = model_save_each_file

    if args['checkpoint']:
        args['checkpoint_save_name'] = utils.checkpoint_name(args['save_dir'], model_save_file, args['checkpoint_save_name'])

    return args
|
| 826 |
+
|
| 827 |
+
def main(args=None):
    """
    Main function for building con parser

    Processes args, calls the appropriate function for the chosen --mode
    """
    args = parse_args(args=args)

    utils.set_random_seed(args['seed'])

    logger.info("Running constituency parser in %s mode", args['mode'])
    logger.debug("Using device: %s", args['device'])

    # figure out which model file to load: an explicit --load_name,
    # a downloaded --load_package, or the default save_name
    model_load_file = args['save_name']
    if args['load_name']:
        if os.path.exists(args['load_name']):
            model_load_file = args['load_name']
        else:
            model_load_file = os.path.join(args['save_dir'], args['load_name'])
    elif args['load_package']:
        if args['lang'] is None:
            # try to recover the language from a lang_package style name
            lang_pieces = args['load_package'].split("_", maxsplit=1)
            try:
                lang = constant.lang_to_langcode(lang_pieces[0])
            except ValueError as e:
                raise ValueError("--lang not specified, and the start of the --load_package name, %s, is not a known language. Please check the values of those parameters" % args['load_package']) from e
            args['lang'] = lang
            args['load_package'] = lang_pieces[1]
        stanza.download(args['lang'], processors="constituency", package={"constituency": args['load_package']})
        model_load_file = os.path.join(DEFAULT_MODEL_DIR, args['lang'], 'constituency', args['load_package'] + ".pt")
        if not os.path.exists(model_load_file):
            raise FileNotFoundError("Expected the downloaded model file for language %s package %s to be in %s, but there is nothing there. Perhaps the package name doesn't exist?" % (args['lang'], args['load_package'], model_load_file))
        else:
            logger.info("Model for language %s package %s is in %s", args['lang'], args['load_package'], model_load_file)

    # TODO: when loading a saved model, we should default to whatever
    # is in the model file for --retag_method, not the default for the language
    if args['mode'] == 'train':
        # make training progress visible unless the caller configured the logger
        if tlogger.level == logging.NOTSET:
            tlogger.setLevel(logging.DEBUG)
            tlogger.debug("Set trainer logging level to DEBUG")

    retag_pipeline = retagging.build_retag_pipeline(args)

    # dispatch on the requested mode
    if args['mode'] == 'train':
        parser_training.train(args, model_load_file, retag_pipeline)
    elif args['mode'] == 'predict':
        parser_training.evaluate(args, model_load_file, retag_pipeline)
    elif args['mode'] == 'parse_text':
        load_model_parse_text(args, model_load_file, retag_pipeline)
    elif args['mode'] == 'remove_optimizer':
        parser_training.remove_optimizer(args, args['save_name'], model_load_file)
|
| 879 |
+
|
| 880 |
+
# standard script entry point
if __name__ == '__main__':
    main()
|
stanza/stanza/models/lemmatizer.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Entry point for training and evaluating a lemmatizer.
|
| 3 |
+
|
| 4 |
+
This lemmatizer combines a neural sequence-to-sequence architecture with an `edit` classifier
|
| 5 |
+
and two dictionaries to produce robust lemmas from word forms.
|
| 6 |
+
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
import shutil
|
| 13 |
+
import time
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
import argparse
|
| 16 |
+
import numpy as np
|
| 17 |
+
import random
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn, optim
|
| 20 |
+
|
| 21 |
+
from stanza.models.lemma.data import DataLoader
|
| 22 |
+
from stanza.models.lemma.vocab import Vocab
|
| 23 |
+
from stanza.models.lemma.trainer import Trainer
|
| 24 |
+
from stanza.models.lemma import scorer, edit
|
| 25 |
+
from stanza.models.common import utils
|
| 26 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 27 |
+
from stanza.models.common.doc import *
|
| 28 |
+
from stanza.utils.conll import CoNLL
|
| 29 |
+
from stanza.models import _training_logging
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger('stanza')
|
| 32 |
+
|
| 33 |
+
def build_argparse():
    """Build the ArgumentParser for the lemmatizer train/predict script.

    Returns the parser so callers (and tests) can parse their own argv.
    """
    parser = argparse.ArgumentParser()
    # data paths
    parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')
    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
    # fixed: help text previously duplicated --output_file's description
    parser.add_argument('--gold_file', type=str, default=None, help='Gold CoNLL-U file to score predictions against.')

    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--shorthand', type=str, help='Shorthand for the dataset to use. lang_dataset')

    # ensembling / dictionary options
    parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. By default use ensemble.')
    parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based lemmatizer.')

    # seq2seq model architecture
    parser.add_argument('--hidden_dim', type=int, default=200)
    parser.add_argument('--emb_dim', type=int, default=50)
    parser.add_argument('--num_layers', type=int, default=1)
    parser.add_argument('--emb_dropout', type=float, default=0.5)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--max_dec_len', type=int, default=50)
    parser.add_argument('--beam_size', type=int, default=1)

    parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
    parser.add_argument('--pos_dim', type=int, default=50)
    parser.add_argument('--pos_dropout', type=float, default=0.5)
    parser.add_argument('--no_edit', dest='edit', action='store_false', help='Do not use edit classifier in lemmatization. By default use an edit classifier.')
    parser.add_argument('--num_edit', type=int, default=len(edit.EDIT_TO_ID))
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--no_pos', dest='pos', action='store_false', help='Do not use UPOS in lemmatization. By default UPOS is used.')
    parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in lemmatization. By default copy mechanism is used to improve generalization.')

    # optional pretrained character language model
    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")

    # optimization
    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--lr_decay', type=float, default=0.9)
    parser.add_argument('--decay_epoch', type=int, default=30, help="Decay the lr starting from this epoch.")
    parser.add_argument('--num_epoch', type=int, default=60)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
    parser.add_argument('--save_dir', type=str, default='saved_models/lemma', help='Root dir for saving models.')
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_lemmatizer.pt", help="File name to save the model")

    parser.add_argument('--caseless', default=False, action='store_true', help='Lowercase everything first before processing. This will happen automatically if 100%% of the data is caseless')

    parser.add_argument('--seed', type=int, default=1234)
    utils.add_device_args(parser)

    # experiment tracking
    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
    return parser
|
| 89 |
+
|
| 90 |
+
def parse_args(args=None):
    """Parse command line arguments for the lemmatizer and return them as a dict."""
    parsed = build_argparse().parse_args(args=args)

    # requesting a named wandb session implies wandb logging
    if parsed.wandb_name:
        parsed.wandb = True

    arg_map = vars(parsed)
    # when building the vocab, we keep track of the original language name
    if arg_map['shorthand']:
        arg_map['lang'] = arg_map['shorthand'].split("_")[0]
    else:
        arg_map['lang'] = ""
    return arg_map
|
| 102 |
+
|
| 103 |
+
def main(args=None):
    """Entry point: parse arguments, seed RNGs, then dispatch to train or evaluate."""
    parsed = parse_args(args=args)

    utils.set_random_seed(parsed['seed'])

    logger.info("Running lemmatizer in {} mode".format(parsed['mode']))

    # any mode other than train is treated as prediction/evaluation
    if parsed['mode'] == 'train':
        train(parsed)
    else:
        evaluate(parsed)
|
| 114 |
+
|
| 115 |
+
def all_lowercase(doc):
    """Return True iff every word text in the document is already lowercase."""
    return all(word.text.lower() == word.text
               for sentence in doc.sentences
               for word in sentence.words)
|
| 121 |
+
|
| 122 |
+
def build_model_filename(args):
    """Compose the lemmatizer save path from the shorthand and embedding type."""
    # charlm embeddings are only active when a forward charlm file was given
    uses_charlm = args['charlm'] and args['charlm_forward_file']
    embedding = "charlm" if uses_charlm else "nocharlm"
    model_file = args['save_name'].format(shorthand=args['shorthand'],
                                          embedding=embedding)
    # prefix save_dir unless the formatted name already lives under it
    if not os.path.split(model_file)[0].startswith(args['save_dir']):
        model_file = os.path.join(args['save_dir'], model_file)
    return model_file
|
| 132 |
+
|
| 133 |
+
def train(args):
    """Train the lemmatizer: first a dictionary lemmatizer, then (unless
    --dict_only) a seq2seq model, saving the best model by dev score.

    args: the fully-parsed argument dict from parse_args()
    """
    # load data
    logger.info("[Loading data with batch size {}...]".format(args['batch_size']))
    train_doc = CoNLL.conll2doc(input_file=args['train_file'])
    train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)
    # vocab is built from the training data; dev reuses it
    vocab = train_batch.vocab
    args['vocab_size'] = vocab['char'].size
    args['pos_vocab_size'] = vocab['pos'].size
    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
    dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)

    utils.ensure_dir(args['save_dir'])
    model_file = build_model_filename(args)
    logger.info("Using full savename: %s", model_file)

    # pred and gold path
    system_pred_file = args['output_file']
    gold_file = args['gold_file']

    utils.print_config(args)

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        logger.warning("[Skip training because no training data available...]")
        return

    # auto-enable caseless mode if the data never uses uppercase
    if not args['caseless'] and all_lowercase(train_doc):
        logger.info("Building a caseless model, as all of the training data is caseless")
        args['caseless'] = True

    # start training
    # train a dictionary-based lemmatizer
    logger.info("Building lemmatizer in %s", model_file)
    trainer = Trainer(args=args, vocab=vocab, device=args['device'])
    logger.info("[Training dictionary-based lemmatizer...]")
    trainer.train_dict(train_batch.raw_data())
    logger.info("Evaluating on dev set...")
    # score the dictionary-only baseline by writing dev predictions to disk
    dev_preds = trainer.predict_dict(dev_batch.doc.get([TEXT, UPOS]))
    dev_batch.doc.set([LEMMA], dev_preds)
    CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
    _, _, dev_f = scorer.score(system_pred_file, gold_file)
    logger.info("Dev F1 = {:.2f}".format(dev_f * 100))

    if args.get('dict_only', False):
        # save dictionaries
        trainer.save(model_file)
    else:
        if args['wandb']:
            import wandb
            wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_lemmatizer" % args['shorthand']
            wandb.init(name=wandb_name, config=args)
            wandb.run.define_metric('train_loss', summary='min')
            wandb.run.define_metric('dev_score', summary='max')

        # train a seq2seq model
        logger.info("[Training seq2seq-based lemmatizer...]")
        global_step = 0
        max_steps = len(train_batch) * args['num_epoch']
        dev_score_history = []
        best_dev_preds = []
        current_lr = args['lr']
        global_start_time = time.time()
        format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

        # start training
        for epoch in range(1, args['num_epoch']+1):
            train_loss = 0
            for i, batch in enumerate(train_batch):
                start_time = time.time()
                global_step += 1
                loss = trainer.update(batch, eval=False) # update step
                train_loss += loss
                if global_step % args['log_step'] == 0:
                    duration = time.time() - start_time
                    logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                            max_steps, epoch, args['num_epoch'], loss, duration, current_lr))

            # eval on dev
            logger.info("Evaluating on dev set...")
            dev_preds = []
            dev_edits = []
            for i, batch in enumerate(dev_batch):
                preds, edits = trainer.predict(batch, args['beam_size'])
                dev_preds += preds
                if edits is not None:
                    dev_edits += edits
            dev_preds = trainer.postprocess(dev_batch.doc.get([TEXT]), dev_preds, edits=dev_edits)

            # try ensembling with dict if necessary
            if args.get('ensemble_dict', False):
                logger.info("[Ensembling dict with seq2seq model...]")
                dev_preds = trainer.ensemble(dev_batch.doc.get([TEXT, UPOS]), dev_preds)
            dev_batch.doc.set([LEMMA], dev_preds)
            CoNLL.write_doc2conll(dev_batch.doc, system_pred_file)
            _, _, dev_score = scorer.score(system_pred_file, gold_file)

            train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
            logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))

            if args['wandb']:
                wandb.log({'train_loss': train_loss, 'dev_score': dev_score})

            # save best model
            if epoch == 1 or dev_score > max(dev_score_history):
                trainer.save(model_file)
                logger.info("new best model saved.")
                best_dev_preds = dev_preds

            # lr schedule
            # NOTE(review): decay only fires for sgd/adagrad and only when the
            # score did not improve over the previous epoch
            if epoch > args['decay_epoch'] and dev_score <= dev_score_history[-1] and \
                    args['optim'] in ['sgd', 'adagrad']:
                current_lr *= args['lr_decay']
                trainer.update_lr(current_lr)

            dev_score_history += [dev_score]
            logger.info("")

        logger.info("Training ended with {} epochs.".format(epoch))

        if args['wandb']:
            wandb.finish()

        best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1
        logger.info("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch))
|
| 257 |
+
|
| 258 |
+
def evaluate(args):
    """Evaluate a saved lemmatizer on args['eval_file'] and report its score.

    Loads the model located via build_model_filename, lemmatizes the eval
    data (dictionary-only or seq2seq, with optional dict ensembling and
    contextual lemmatizers), writes predictions to args['output_file'],
    and scores them against args['gold_file'] when a gold file is given.
    """
    # output / scoring file paths
    system_pred_file = args['output_file']
    gold_file = args['gold_file']
    model_file = build_model_filename(args)

    # load the saved model
    trainer = Trainer(model_file=model_file, device=args['device'], args=args)
    loaded_args, vocab = trainer.args, trainer.vocab

    # paths and the shorthand always come from the current invocation,
    # never from the values saved alongside the model
    for key in args:
        if key.endswith('_dir') or key.endswith('_file') or key in ['shorthand']:
            loaded_args[key] = args[key]

    # load the evaluation data
    logger.info("Loading data with batch size {}...".format(args['batch_size']))
    doc = CoNLL.conll2doc(input_file=args['eval_file'])
    batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)

    # nothing to score when the eval set is empty
    if len(batch) == 0:
        logger.warning("Skip evaluation because no dev data is available...\nLemma score:\n{} ".format(args['shorthand']))
        return

    dict_preds = trainer.predict_dict(batch.doc.get([TEXT, UPOS]))

    if loaded_args.get('dict_only', False):
        preds = dict_preds
    else:
        logger.info("Running the seq2seq model...")
        preds = []
        edits = []
        for chunk in batch:
            chunk_preds, chunk_edits = trainer.predict(chunk, args['beam_size'])
            preds += chunk_preds
            if chunk_edits is not None:
                edits += chunk_edits
        preds = trainer.postprocess(batch.doc.get([TEXT]), preds, edits=edits)

        if loaded_args.get('ensemble_dict', False):
            logger.info("[Ensembling dict with seq2seq lemmatizer...]")
            preds = trainer.ensemble(batch.doc.get([TEXT, UPOS]), preds)

    if trainer.has_contextual_lemmatizers():
        preds = trainer.update_contextual_preds(batch.doc, preds)

    # write predictions to file, then score them if a gold file exists
    batch.doc.set([LEMMA], preds)
    CoNLL.write_doc2conll(batch.doc, system_pred_file)
    # NOTE(review): scoring and the final log both require gold_file;
    # `score` is only defined inside this branch
    if gold_file is not None:
        _, _, score = scorer.score(system_pred_file, gold_file)

        logger.info("Finished evaluation\nLemma score:\n{} {:.2f}".format(args['shorthand'], score*100))
| 312 |
+
# Script entry point: only run main() when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
|
stanza/stanza/pipeline/_constants.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Module defining constants """
|
| 2 |
+
|
| 3 |
+
# string constants for processor names
|
| 4 |
+
LANGID = 'langid'
|
| 5 |
+
TOKENIZE = 'tokenize'
|
| 6 |
+
MWT = 'mwt'
|
| 7 |
+
POS = 'pos'
|
| 8 |
+
LEMMA = 'lemma'
|
| 9 |
+
DEPPARSE = 'depparse'
|
| 10 |
+
NER = 'ner'
|
| 11 |
+
SENTIMENT = 'sentiment'
|
| 12 |
+
CONSTITUENCY = 'constituency'
|
| 13 |
+
COREF = 'coref'
|
stanza/stanza/pipeline/external/spacy.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Processors related to spaCy in the pipeline.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from stanza.models.common import doc
|
| 6 |
+
from stanza.pipeline._constants import TOKENIZE
|
| 7 |
+
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant
|
| 8 |
+
|
| 9 |
+
def check_spacy():
    """Verify that spaCy is importable; raise a descriptive ImportError if not.

    Returns True when the import succeeds.
    """
    try:
        import spacy  # noqa: F401 -- imported only to confirm availability
    except ImportError:
        raise ImportError(
            "spaCy is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions."
        )
    return True
|
| 20 |
+
|
| 21 |
+
@register_processor_variant(TOKENIZE, 'spacy')
class SpacyTokenizer(ProcessorVariant):
    """Tokenizer variant backed by the spaCy English pipeline (English only)."""

    def __init__(self, config):
        """Build the spaCy tokenizer from a pipeline config dict.

        Raises if the pipeline language is not English or spaCy is missing.
        """
        if config['lang'] != 'en':
            raise Exception("spaCy tokenizer is currently only allowed in English pipeline.")

        try:
            import spacy
            from spacy.lang.en import English
        except ImportError:
            raise ImportError(
                "spaCy 2.0+ is used but not installed on your machine. Go to https://spacy.io/usage for installation instructions."
            )

        # Default English tokenizer, with its punctuation rules and exceptions
        self.nlp = English()
        # spaCy normally sentence-splits with its dependency parser; attach a
        # rule-based sentencizer instead for speed.  The add_pipe API changed
        # between spaCy 2.x and 3.x.
        if spacy.__version__.startswith("2."):
            self.nlp.add_pipe(self.nlp.create_pipe("sentencizer"))
        else:
            self.nlp.add_pipe("sentencizer")
        # when True, all tokens are collapsed into one sentence at process time
        self.no_ssplit = config.get('no_ssplit', False)

    def process(self, document):
        """Tokenize text (or a Document's text) with spaCy; return a stanza Doc."""
        text = document.text if isinstance(document, doc.Document) else document
        if not isinstance(text, str):
            raise Exception("Must supply a string or Stanza Document object to the spaCy tokenizer.")
        spacy_doc = self.nlp(text)

        # each sentence becomes a list of {TEXT, MISC} token dicts, where MISC
        # records the character offsets of the token in the original text
        sentences = [
            [
                {
                    doc.TEXT: token.text,
                    doc.MISC: f"{doc.START_CHAR}={token.idx}|{doc.END_CHAR}={token.idx+len(token.text)}"
                }
                for token in spacy_sentence
            ]
            for spacy_sentence in spacy_doc.sents
        ]

        # optionally merge everything into a single sentence
        if self.no_ssplit:
            sentences = [[entry for sentence in sentences for entry in sentence]]

        return doc.Document(sentences, text)
|
stanza/stanza/pipeline/ner_processor.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Processor for performing named entity tagging.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
from stanza.models.common import doc
|
| 10 |
+
from stanza.models.common.exceptions import ForwardCharlmNotFoundError, BackwardCharlmNotFoundError
|
| 11 |
+
from stanza.models.common.utils import unsort
|
| 12 |
+
from stanza.models.ner.data import DataLoader
|
| 13 |
+
from stanza.models.ner.trainer import Trainer
|
| 14 |
+
from stanza.models.ner.utils import merge_tags
|
| 15 |
+
from stanza.pipeline._constants import *
|
| 16 |
+
from stanza.pipeline.processor import UDProcessor, register_processor
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger('stanza')
|
| 19 |
+
|
| 20 |
+
@register_processor(name=NER)
class NERProcessor(UDProcessor):
    """Pipeline processor that runs one or more NER models and merges their tags."""

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([NER])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE])

    def _get_dependencies(self, config, dep_name):
        """Collect per-model dependency paths for dep_name.

        Accepts either a ";"-separated string (empty pieces become None) or a
        list of per-model dicts under config['dependencies'].
        """
        raw = config.get(dep_name, None)
        if raw is not None:
            return [piece if piece else None for piece in raw.split(";")]
        return [entry.get(dep_name) for entry in config.get('dependencies', [])]

    def _set_up_model(self, config, pipeline, device):
        """Load every configured NER model into self.trainers."""
        model_paths = config.get('model_path')
        if isinstance(model_paths, str):
            model_paths = model_paths.split(";")

        charlm_forward_files = self._get_dependencies(config, 'forward_charlm_path')
        charlm_backward_files = self._get_dependencies(config, 'backward_charlm_path')
        pretrain_files = self._get_dependencies(config, 'pretrain_path')

        # predict_tagset may be a single int (applies to the first model only)
        # or a ";"-separated list of ints, one per model
        self._predict_tagset = {}
        predict_tagset = config.get('predict_tagset', None)
        if predict_tagset:
            if isinstance(predict_tagset, int):
                self._predict_tagset[0] = predict_tagset
            else:
                for piece_idx, piece in enumerate(predict_tagset.split(";")):
                    if piece:
                        self._predict_tagset[piece_idx] = int(piece)

        self.trainers = []
        for (model_path, pretrain_path, charlm_forward, charlm_backward) in zip(model_paths, pretrain_files, charlm_forward_files, charlm_backward_files):
            logger.debug("Loading %s with pretrain %s, forward charlm %s, backward charlm %s", model_path, pretrain_path, charlm_forward, charlm_backward)
            pretrain = pipeline.foundation_cache.load_pretrain(pretrain_path) if pretrain_path else None
            args = {'charlm_forward_file': charlm_forward,
                    'charlm_backward_file': charlm_backward}

            # len(self.trainers) is the index of the model being loaded
            predict_tagset = self._predict_tagset.get(len(self.trainers), None)
            if predict_tagset is not None:
                args['predict_tagset'] = predict_tagset

            try:
                trainer = Trainer(args=args, model_file=model_path, pretrain=pretrain, device=device, foundation_cache=pipeline.foundation_cache)
            except ForwardCharlmNotFoundError as e:
                # re-raise with a hint about the config option that fixes it
                raise ForwardCharlmNotFoundError("Could not find the forward charlm %s. Please specify the correct path with ner_forward_charlm_path" % e.filename, e.filename) from None
            except BackwardCharlmNotFoundError as e:
                raise BackwardCharlmNotFoundError("Could not find the backward charlm %s. Please specify the correct path with ner_backward_charlm_path" % e.filename, e.filename) from None
            self.trainers.append(trainer)

        self._trainer = self.trainers[0]
        self.model_paths = model_paths

    def _set_up_final_config(self, config):
        """ Finalize the configurations for this processor, based off of values from a UD model. """
        # set configurations from loaded model
        if len(self.trainers) == 0:
            raise RuntimeError("Somehow there are no models loaded!")
        self._vocab = self.trainers[0].vocab
        self.configs = []
        for trainer in self.trainers:
            # filter out unneeded args from the model, then let the user
            # supplied config override the saved values
            merged = {k: v for k, v in trainer.args.items() if not UDProcessor.filter_out_option(k)}
            merged.update(config)
            self.configs.append(merged)
        self._config = self.configs[0]

    def __str__(self):
        return "NERProcessor(%s)" % ";".join(self.model_paths)

    def mark_inactive(self):
        """ Drop memory intensive resources if keeping this processor around for reasons other than running it. """
        super().mark_inactive()
        self.trainers = None

    def process(self, document):
        """Tag a document with every loaded model and merge the predictions."""
        with torch.no_grad():
            all_preds = []
            for trainer, config in zip(self.trainers, self.configs):
                # eval-only loader; tags in the doc are not preprocessed
                batch = DataLoader(document, config['batch_size'], config, vocab=trainer.vocab, evaluation=True, preprocess_tags=False, bert_tokenizer=trainer.model.bert_tokenizer)
                preds = []
                for i, b in enumerate(batch):
                    preds += trainer.predict(b)
                all_preds.append(preds)
            # for each sentence, merge the per-model predictions into one tag
            # sequence; earlier models take precedence
            preds = [merge_tags(*x) for x in zip(*all_preds)]
            batch.doc.set([doc.NER], [y for x in preds for y in x], to_token=True)
            # also record the raw per-model tags for every token
            batch.doc.set([doc.MULTI_NER], [tuple(y) for x in zip(*all_preds) for y in zip(*x)], to_token=True)
            # collect entities into document attribute
            total = len(batch.doc.build_ents())
            logger.debug(f'{total} entities found in document.')
        return batch.doc

    def bulk_process(self, docs):
        """
        NER processor has a collation step after running inference
        """
        docs = super().bulk_process(docs)
        for doc in docs:
            doc.build_ents()
        return docs

    def get_known_tags(self, model_idx=0):
        """
        Return the tags known by this model

        Removes the S-, B-, etc, and does not include O
        Specify model_idx if the processor has more than one model
        """
        return self.trainers[model_idx].get_known_tags()
|
stanza/stanza/resources/print_charlm_depparse.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A small utility script to output which depparse models use charlm
|
| 3 |
+
|
| 4 |
+
(It should skip en_genia, en_craft, but currently doesn't)
|
| 5 |
+
|
| 6 |
+
Not frequently useful, but seems like the kind of thing that might get used a couple times
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from stanza.resources.common import load_resources_json
|
| 10 |
+
from stanza.resources.default_packages import default_charlms, depparse_charlms
|
| 11 |
+
|
| 12 |
+
def list_depparse():
    """Return "lang_model" names for depparse models that use a charlm.

    A model is included when its language has a default charlm and the
    depparse_charlms table does not explicitly map it to None.
    """
    charlm_langs = list(default_charlms.keys())
    resources = load_resources_json()

    models = []
    for lang in charlm_langs:
        overrides = depparse_charlms.get(lang)
        for model in resources[lang].get("depparse", {}):
            # excluded only when explicitly mapped to None in the overrides
            if overrides is not None and model in overrides and overrides[model] is None:
                continue
            models.append("%s_%s" % (lang, model))
    return models
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
models = list_depparse()
|
| 22 |
+
print(" ".join(models))
|
stanza/stanza/server/dependency_converter.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A converter from constituency trees to dependency trees using CoreNLP's UniversalEnglish converter.
|
| 3 |
+
|
| 4 |
+
ONLY works on English.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import stanza
|
| 8 |
+
from stanza.protobuf import DependencyConverterRequest, DependencyConverterResponse
|
| 9 |
+
from stanza.server.java_protobuf_requests import send_request, build_tree, JavaProtobufContext
|
| 10 |
+
|
| 11 |
+
CONVERTER_JAVA = "edu.stanford.nlp.trees.ProcessDependencyConverterRequest"
|
| 12 |
+
|
| 13 |
+
def send_converter_request(request, classpath=None):
    """Send a DependencyConverterRequest to the Java converter; return its response."""
    return send_request(request, DependencyConverterResponse, CONVERTER_JAVA,
                        classpath=classpath)
|
| 15 |
+
|
| 16 |
+
def build_request(doc):
    """
    Request format is simple: one tree per sentence in the document
    """
    request = DependencyConverterRequest()
    trees = request.trees
    for sentence in doc.sentences:
        trees.append(build_tree(sentence.constituency, None))
    return request
|
| 24 |
+
|
| 25 |
+
def process_doc(doc, classpath=None):
    """Convert the document's constituency trees to dependencies in one shot.

    Builds a request, sends it to the Java converter, then attaches the
    resulting dependency graphs to the document's sentences in place.
    """
    response = send_converter_request(build_request(doc), classpath=classpath)
    attach_dependencies(doc, response)
|
| 33 |
+
|
| 34 |
+
def attach_dependencies(doc, response):
    """Attach converted dependency graphs from `response` to `doc`'s sentences.

    For each sentence, sets word.head / word.deprel from the graph edges,
    marks the single parentless node as the root, and rebuilds the
    sentence's dependency list.

    Raises ValueError when the response does not line up with the document:
    mismatched sentence/node/edge counts, a node with two parents, or a
    sentence without exactly one root.
    """
    if len(doc.sentences) != len(response.conversions):
        raise ValueError("Sent %d sentences but got back %d conversions" % (len(doc.sentences), len(response.conversions)))
    for sent_idx, (sentence, conversion) in enumerate(zip(doc.sentences, response.conversions)):
        graph = conversion.graph

        # The deterministic conversion should have an equal number of words and one fewer edge
        # ... the root is represented by a word with no parent
        if len(sentence.words) != len(graph.node):
            raise ValueError("Sentence %d of the conversion should have %d words but got back %d nodes in the graph" % (sent_idx, len(sentence.words), len(graph.node)))
        if len(sentence.words) != len(graph.edge) + 1:
            raise ValueError("Sentence %d of the conversion should have %d edges (one per word, plus the root) but got back %d edges in the graph" % (sent_idx, len(sentence.words) - 1, len(graph.edge)))

        expected_nodes = set(range(1, len(sentence.words) + 1))
        targets = set()
        for edge in graph.edge:
            if edge.target in targets:
                raise ValueError("Found two parents of %d in sentence %d" % (edge.target, sent_idx))
            targets.add(edge.target)
            # -1 since the words are 0 indexed in the sentence,
            # but we count dependencies from 1
            sentence.words[edge.target-1].head = edge.source
            sentence.words[edge.target-1].deprel = edge.dep
        roots = expected_nodes - targets
        # previously an `assert`, which is silently stripped under `python -O`;
        # a malformed response must always be reported, so raise instead
        if len(roots) != 1:
            raise ValueError("Expected exactly one root in sentence %d of the conversion, found %d" % (sent_idx, len(roots)))
        for root in roots:
            sentence.words[root-1].head = 0
            sentence.words[root-1].deprel = "root"
        sentence.build_dependencies()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class DependencyConverter(JavaProtobufContext):
    """
    Context window for the dependency converter

    This is a context window which keeps a process open. Should allow
    for multiple requests without launching new java processes each time.
    """
    def __init__(self, classpath=None):
        super(DependencyConverter, self).__init__(classpath, DependencyConverterResponse, CONVERTER_JAVA)

    def process(self, doc):
        """
        Converts a constituency tree to dependency trees for each of the sentences in the document
        """
        response = self.process_request(build_request(doc))
        attach_dependencies(doc, response)
        return doc
|
| 83 |
+
|
| 84 |
+
def main():
    """Demo: parse two sentences, then convert their trees to dependencies.

    Shows both the one-shot process_doc path and the reusable
    DependencyConverter context-manager path.  Requires CoreNLP on
    $CLASSPATH.
    """
    nlp = stanza.Pipeline('en',
                          processors='tokenize,pos,constituency')

    # one-shot conversion: a fresh java process per call
    doc = nlp('I like blue antennae.')
    print("{:C}".format(doc))
    process_doc(doc, classpath="$CLASSPATH")
    print("{:C}".format(doc))

    # persistent converter: one java process reused across calls
    doc = nlp('And I cannot lie.')
    print("{:C}".format(doc))
    with DependencyConverter(classpath="$CLASSPATH") as converter:
        converter.process(doc)
    print("{:C}".format(doc))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Script entry point: run the demo only when executed directly.
if __name__ == '__main__':
    main()
|
stanza/stanza/tests/classifiers/test_constituency_classifier.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
import stanza
|
| 6 |
+
import stanza.models.classifier as classifier
|
| 7 |
+
import stanza.models.classifiers.data as data
|
| 8 |
+
from stanza.models.classifiers.trainer import Trainer
|
| 9 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 10 |
+
from stanza.tests.classifiers.test_classifier import fake_embeddings
|
| 11 |
+
from stanza.tests.classifiers.test_data import train_file_with_trees, dev_file_with_trees
|
| 12 |
+
from stanza.models.common import utils
|
| 13 |
+
from stanza.tests.constituency.test_trainer import build_trainer, TREEBANK
|
| 14 |
+
|
| 15 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 16 |
+
|
| 17 |
+
class TestConstituencyClassifier:
    """Tests for training and running the constituency-based sentiment classifier."""

    @pytest.fixture(scope="class")
    def constituency_model(self, fake_embeddings, tmp_path_factory):
        """Build a tiny constituency parser once per class; return its saved path."""
        args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
        trainer = build_trainer(str(fake_embeddings), *args, treebank=TREEBANK)

        trainer_pt = str(tmp_path_factory.mktemp("constituency") / "constituency.pt")
        trainer.save(trainer_pt, save_optimizer=False)
        return trainer_pt

    def build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):
        """
        Build a Constituency Classifier model to be used by one of the later tests
        """
        save_dir = str(tmp_path / "classifier")
        save_name = "model.pt"
        args = ["--save_dir", save_dir,
                "--save_name", save_name,
                "--model_type", "constituency",
                "--constituency_model", constituency_model,
                "--wordvec_pretrain_file", str(fake_embeddings),
                "--fc_shapes", "20,10",
                "--train_file", str(train_file_with_trees),
                "--dev_file", str(dev_file_with_trees),
                "--max_epochs", "2",
                "--batch_size", "60"]
        if extra_args is not None:
            args = args + extra_args
        args = classifier.parse_args(args)
        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
        trainer = Trainer.build_new_model(args, train_set)
        return trainer, train_set, args

    def run_training(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):
        """
        Iterate a couple times over a model
        """
        trainer, train_set, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args)
        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)
        labels = data.dataset_labels(train_set)

        save_filename = os.path.join(args.save_dir, args.save_name)
        checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name)
        classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels)
        return trainer, train_set, args

    def test_build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Test that building a basic constituency-based model works
        """
        self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)

    def test_save_load(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Test that a constituency model can save & load
        """
        trainer, _, args = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)

        save_filename = os.path.join(args.save_dir, args.save_name)
        trainer.save(save_filename)

        args.load_name = args.save_name
        trainer = Trainer.load(args.load_name, args)

    def test_train_basic(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        # a plain training run with the default settings
        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)

    def test_train_pipeline(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Test that writing out a temp model, then loading it in the pipeline is a thing that works
        """
        trainer, _, args = self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
        save_filename = os.path.join(args.save_dir, args.save_name)
        assert os.path.exists(save_filename)
        assert os.path.exists(args.constituency_model)

        pipeline_args = {"lang": "en",
                         "download_method": None,
                         "model_dir": TEST_MODELS_DIR,
                         "processors": "tokenize,pos,constituency,sentiment",
                         "tokenize_pretokenized": True,
                         "constituency_model_path": args.constituency_model,
                         "constituency_pretrain_path": args.wordvec_pretrain_file,
                         "constituency_backward_charlm_path": None,
                         "constituency_forward_charlm_path": None,
                         "sentiment_model_path": save_filename,
                         "sentiment_pretrain_path": args.wordvec_pretrain_file,
                         "sentiment_backward_charlm_path": None,
                         "sentiment_forward_charlm_path": None}
        pipeline = stanza.Pipeline(**pipeline_args)
        doc = pipeline("This is a test")
        # since the model is random, we have no expectations for what the result actually is
        assert doc.sentences[0].sentiment is not None


    def test_train_all_words(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        # exercise both settings of the all-words flag
        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_all_words'])

        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_all_words'])

    def test_train_top_layer(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        # exercise both settings of the top-layer flag
        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_top_layer'])

        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_top_layer'])

    def test_train_attn(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        # node attention combined with each all-words setting, plus attention off
        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_node_attn', '--no_constituency_all_words'])

        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--constituency_node_attn', '--constituency_all_words'])

        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, ['--no_constituency_node_attn'])
|
stanza/stanza/tests/common/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/common/test_chuliu_edmonds.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test some use cases of the chuliu_edmonds algorithm
|
| 3 |
+
|
| 4 |
+
(currently just the tarjan implementation)
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from stanza.models.common.chuliu_edmonds import tarjan
|
| 11 |
+
|
| 12 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 13 |
+
|
| 14 |
+
def test_tarjan_basic():
    """Graphs with no cycles should produce no strongly connected components."""
    for heads in (np.array([0, 4, 4, 4, 0]),
                  np.array([0, 2, 0, 4, 2, 2])):
        assert tarjan(heads) == []
| 22 |
+
|
| 23 |
+
def test_tarjan_cycle():
    """Cycles should be reported as boolean masks, one per cycle found."""
    # single three-node cycle: 1 -> 3 -> 2 -> 1
    single = tarjan(np.array([0, 3, 1, 2]))
    assert len(single) == 1
    np.testing.assert_array_equal(single[0], np.array([False, True, True, True]))

    # two disjoint cycles in the same graph
    double = tarjan(np.array([0, 3, 1, 2, 5, 6, 4]))
    assert len(double) == 2
    masks = [np.array([False, True, True, True, False, False, False]),
             np.array([False, False, False, False, True, True, True])]
    for found, wanted in zip(double, masks):
        np.testing.assert_array_equal(found, wanted)
|
stanza/stanza/tests/common/test_confusion.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a couple simple confusion matrices and output formats
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from stanza.utils.confusion import format_confusion, confusion_to_f1, confusion_to_macro_f1, confusion_to_weighted_f1
|
| 9 |
+
|
| 10 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 11 |
+
|
| 12 |
+
@pytest.fixture
def simple_confusion():
    """A small gold-vs-predicted confusion matrix: two half-right entity tags plus O."""
    matrix = defaultdict(lambda: defaultdict(int))
    for gold, pred in (("B-ORG", "B-ORG"),
                       ("B-ORG", "B-PER"),
                       ("E-ORG", "E-ORG"),
                       ("E-ORG", "E-PER")):
        matrix[gold][pred] = 1
    matrix["O"]["O"] = 4
    return matrix
|
| 22 |
+
@pytest.fixture
def short_confusion():
    """
    Same thing, but with a short name. This should not be sorted by entity type
    """
    matrix = defaultdict(lambda: defaultdict(int))
    for gold, pred in (("A", "B-ORG"),
                       ("B-ORG", "B-PER"),
                       ("E-ORG", "E-ORG"),
                       ("E-ORG", "E-PER")):
        matrix[gold][pred] = 1
    matrix["O"]["O"] = 4
    return matrix
|
| 35 |
+
EXPECTED_SIMPLE_OUTPUT = """
|
| 36 |
+
t\\p O B-ORG E-ORG B-PER E-PER
|
| 37 |
+
O 4 0 0 0 0
|
| 38 |
+
B-ORG 0 1 0 1 0
|
| 39 |
+
E-ORG 0 0 1 0 1
|
| 40 |
+
B-PER 0 0 0 0 0
|
| 41 |
+
E-PER 0 0 0 0 0
|
| 42 |
+
"""[1:-1] # don't want to strip
|
| 43 |
+
|
| 44 |
+
EXPECTED_SHORT_OUTPUT = """
|
| 45 |
+
t\\p O A B-ORG B-PER E-ORG E-PER
|
| 46 |
+
O 4 0 0 0 0 0
|
| 47 |
+
A 0 0 1 0 0 0
|
| 48 |
+
B-ORG 0 0 0 1 0 0
|
| 49 |
+
B-PER 0 0 0 0 0 0
|
| 50 |
+
E-ORG 0 0 0 0 1 1
|
| 51 |
+
E-PER 0 0 0 0 0 0
|
| 52 |
+
"""[1:-1]
|
| 53 |
+
|
| 54 |
+
EXPECTED_HIDE_BLANK_SHORT_OUTPUT = """
|
| 55 |
+
t\\p O B-ORG E-ORG B-PER E-PER
|
| 56 |
+
O 4 0 0 0 0
|
| 57 |
+
A 0 1 0 0 0
|
| 58 |
+
B-ORG 0 0 0 1 0
|
| 59 |
+
E-ORG 0 0 1 0 1
|
| 60 |
+
"""[1:-1]
|
| 61 |
+
|
| 62 |
+
def test_simple_output(simple_confusion):
    """Check the default text rendering of a confusion matrix."""
    assert EXPECTED_SIMPLE_OUTPUT == format_confusion(simple_confusion)
|
| 65 |
+
def test_short_output(short_confusion):
    """A matrix with a short (non-entity) label should not be sorted by entity type."""
    assert EXPECTED_SHORT_OUTPUT == format_confusion(short_confusion)
|
| 68 |
+
def test_hide_blank_short_output(short_confusion):
    """hide_blank=True should drop rows that are entirely zero."""
    assert EXPECTED_HIDE_BLANK_SHORT_OUTPUT == format_confusion(short_confusion, hide_blank=True)
|
| 71 |
+
def test_macro_f1(simple_confusion, short_confusion):
    """Check macro F1 on both sample matrices."""
    assert confusion_to_macro_f1(simple_confusion) == pytest.approx(0.466666666666)
    assert confusion_to_macro_f1(short_confusion) == pytest.approx(0.277777777777)
|
| 75 |
+
def test_weighted_f1(simple_confusion, short_confusion):
    """Check weighted F1, both with and without excluding the O tag."""
    assert confusion_to_weighted_f1(simple_confusion) == pytest.approx(0.83333333)
    assert confusion_to_weighted_f1(short_confusion) == pytest.approx(0.66666666)

    # excluding O removes the easiest tag, so the scores drop
    assert confusion_to_weighted_f1(simple_confusion, exclude=["O"]) == pytest.approx(0.66666666)
    assert confusion_to_weighted_f1(short_confusion, exclude=["O"]) == pytest.approx(0.33333333)
|
stanza/stanza/tests/common/test_constant.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test the conversion to lcodes and splitting of dataset names
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import tempfile
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
import stanza
|
| 10 |
+
from stanza.models.common.constant import treebank_to_short_name, lang_to_langcode, is_right_to_left, two_to_three_letters, langlower2lcode
|
| 11 |
+
from stanza.tests import *
|
| 12 |
+
|
| 13 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 14 |
+
|
| 15 |
+
def test_treebank():
    """
    Test the entire treebank name conversion
    """
    # conversion of a UD_ name
    assert "hi_hdtb" == treebank_to_short_name("UD_Hindi-HDTB")
    # conversion of names without UD
    assert "hi_fire2013" == treebank_to_short_name("Hindi-fire2013")
    assert "hi_fire2013" == treebank_to_short_name("Hindi-Fire2013")
    assert "hi_fire2013" == treebank_to_short_name("Hindi-FIRE2013")
    # already short names are generally preserved
    assert "hi_fire2013" == treebank_to_short_name("hi-fire2013")
    assert "hi_fire2013" == treebank_to_short_name("hi_fire2013")
    # a special case
    assert "zh-hant_pud" == treebank_to_short_name("UD_Chinese-PUD")
    # a special case already converted once
    assert "zh-hant_pud" == treebank_to_short_name("zh-hant_pud")
    assert "zh-hant_pud" == treebank_to_short_name("zh-hant-pud")
    assert "zh-hans_gsdsimp" == treebank_to_short_name("zh-hans_gsdsimp")

    # three letter codes and full language names collapse to the two letter lcode
    assert "wo_masakhane" == treebank_to_short_name("wo_masakhane")
    assert "wo_masakhane" == treebank_to_short_name("wol_masakhane")
    assert "wo_masakhane" == treebank_to_short_name("Wol_masakhane")
    assert "wo_masakhane" == treebank_to_short_name("wolof_masakhane")
    assert "wo_masakhane" == treebank_to_short_name("Wolof_masakhane")
|
| 41 |
+
def test_lang_to_langcode():
    """Any casing of a language name or its code should map to the lcode."""
    assert "hi" == lang_to_langcode("Hindi")
    assert "hi" == lang_to_langcode("HINDI")
    assert "hi" == lang_to_langcode("hindi")
    assert "hi" == lang_to_langcode("HI")
    assert "hi" == lang_to_langcode("hi")
|
| 48 |
+
def test_right_to_left():
    """is_right_to_left accepts either lcodes or full language names."""
    assert is_right_to_left("ar")
    assert is_right_to_left("Arabic")

    assert not is_right_to_left("en")
    assert not is_right_to_left("English")
|
| 55 |
+
def test_two_to_three():
    """Three letter codes convert to two letter lcodes, and the reverse map exists."""
    assert lang_to_langcode("Wolof") == "wo"
    assert lang_to_langcode("wol") == "wo"

    assert "wo" in two_to_three_letters
    assert two_to_three_letters["wo"] == "wol"
|
| 62 |
+
def test_langlower():
    """Mixed-case names are matched case-insensitively via langlower2lcode."""
    assert lang_to_langcode("WOLOF") == "wo"
    assert lang_to_langcode("nOrWeGiAn") == "nb"

    assert "soj" == langlower2lcode["soi"]
    assert "soj" == langlower2lcode["sohi"]
stanza/stanza/tests/common/test_data_conversion.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic tests of the data conversion
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import io
|
| 6 |
+
import pytest
|
| 7 |
+
import tempfile
|
| 8 |
+
from zipfile import ZipFile
|
| 9 |
+
|
| 10 |
+
import stanza
|
| 11 |
+
from stanza.utils.conll import CoNLL
|
| 12 |
+
from stanza.models.common.doc import Document
|
| 13 |
+
from stanza.tests import *
|
| 14 |
+
|
| 15 |
+
pytestmark = pytest.mark.pipeline
|
| 16 |
+
|
| 17 |
+
# data for testing
|
| 18 |
+
CONLL = [[['1', 'Nous', 'il', 'PRON', '_', 'Number=Plur|Person=1|PronType=Prs', '3', 'nsubj', '_', 'start_char=0|end_char=4'],
|
| 19 |
+
['2', 'avons', 'avoir', 'AUX', '_', 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', '3', 'aux:tense', '_', 'start_char=5|end_char=10'],
|
| 20 |
+
['3', 'atteint', 'atteindre', 'VERB', '_', 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', '0', 'root', '_', 'start_char=11|end_char=18'],
|
| 21 |
+
['4', 'la', 'le', 'DET', '_', 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', '5', 'det', '_', 'start_char=19|end_char=21'],
|
| 22 |
+
['5', 'fin', 'fin', 'NOUN', '_', 'Gender=Fem|Number=Sing', '3', 'obj', '_', 'start_char=22|end_char=25'],
|
| 23 |
+
['6-7', 'du', '_', '_', '_', '_', '_', '_', '_', 'start_char=26|end_char=28'],
|
| 24 |
+
['6', 'de', 'de', 'ADP', '_', '_', '8', 'case', '_', '_'],
|
| 25 |
+
['7', 'le', 'le', 'DET', '_', 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', '8', 'det', '_', '_'],
|
| 26 |
+
['8', 'sentier', 'sentier', 'NOUN', '_', 'Gender=Masc|Number=Sing', '5', 'nmod', '_', 'start_char=29|end_char=36'],
|
| 27 |
+
['9', '.', '.', 'PUNCT', '_', '_', '3', 'punct', '_', 'start_char=36|end_char=37']]]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
DICT = [[{'id': (1,), 'text': 'Nous', 'lemma': 'il', 'upos': 'PRON', 'feats': 'Number=Plur|Person=1|PronType=Prs', 'head': 3, 'deprel': 'nsubj', 'misc': 'start_char=0|end_char=4'},
|
| 31 |
+
{'id': (2,), 'text': 'avons', 'lemma': 'avoir', 'upos': 'AUX', 'feats': 'Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin', 'head': 3, 'deprel': 'aux:tense', 'misc': 'start_char=5|end_char=10'},
|
| 32 |
+
{'id': (3,), 'text': 'atteint', 'lemma': 'atteindre', 'upos': 'VERB', 'feats': 'Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part', 'head': 0, 'deprel': 'root', 'misc': 'start_char=11|end_char=18'},
|
| 33 |
+
{'id': (4,), 'text': 'la', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Fem|Number=Sing|PronType=Art', 'head': 5, 'deprel': 'det', 'misc': 'start_char=19|end_char=21'},
|
| 34 |
+
{'id': (5,), 'text': 'fin', 'lemma': 'fin', 'upos': 'NOUN', 'feats': 'Gender=Fem|Number=Sing', 'head': 3, 'deprel': 'obj', 'misc': 'start_char=22|end_char=25'},
|
| 35 |
+
{'id': (6, 7), 'text': 'du', 'misc': 'start_char=26|end_char=28'},
|
| 36 |
+
{'id': (6,), 'text': 'de', 'lemma': 'de', 'upos': 'ADP', 'head': 8, 'deprel': 'case'},
|
| 37 |
+
{'id': (7,), 'text': 'le', 'lemma': 'le', 'upos': 'DET', 'feats': 'Definite=Def|Gender=Masc|Number=Sing|PronType=Art', 'head': 8, 'deprel': 'det'},
|
| 38 |
+
{'id': (8,), 'text': 'sentier', 'lemma': 'sentier', 'upos': 'NOUN', 'feats': 'Gender=Masc|Number=Sing', 'head': 5, 'deprel': 'nmod', 'misc': 'start_char=29|end_char=36'},
|
| 39 |
+
{'id': (9,), 'text': '.', 'lemma': '.', 'upos': 'PUNCT', 'head': 3, 'deprel': 'punct', 'misc': 'start_char=36|end_char=37'}]]
|
| 40 |
+
|
| 41 |
+
def test_conll_to_dict():
    """Converting the CONLL field lists should yield DICT with no empty words."""
    dicts, empty = CoNLL.convert_conll(CONLL)
    assert dicts == DICT
    # one (empty) list of empty words per sentence
    assert len(dicts) == len(empty)
    assert all(len(x) == 0 for x in empty)
|
| 47 |
+
def test_dict_to_conll():
    """Render a Document back into CoNLL fields and compare to the reference data."""
    document = Document(DICT)
    # :c = no comments
    rendered = "{:c}".format(document)
    conll = []
    for block in rendered.split("\n\n"):
        conll.append([line.split("\t") for line in block.split("\n")])
    assert conll == CONLL
+
|
| 53 |
+
def test_dict_to_doc_and_doc_to_dict():
    """
    Test the conversion from raw dict to Document and back

    This code path will first turn start_char|end_char into start_char & end_char fields in the Document
    That version to a dict will have separate fields for each of those
    Finally, the conversion from that dict to a list of conll entries should convert that back to misc
    """
    document = Document(DICT)
    dicts = document.to_dict()
    document = Document(dicts)
    # :c format spec = conll rendering without comments
    conll = [[sentence.split("\t") for sentence in doc.split("\n")] for doc in "{:c}".format(document).split("\n\n")]
    assert conll == CONLL
+
|
| 67 |
+
# sample is two sentences long so that the tests check multiple sentences
|
| 68 |
+
RUSSIAN_SAMPLE="""
|
| 69 |
+
# sent_id = yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253
|
| 70 |
+
# genre = review
|
| 71 |
+
# text = Как- то слишком мало цветов получают актёры после спектакля.
|
| 72 |
+
1 Как как-то ADV _ Degree=Pos|PronType=Ind 7 advmod _ SpaceAfter=No
|
| 73 |
+
2 - - PUNCT _ _ 3 punct _ _
|
| 74 |
+
3 то то PART _ _ 1 list _ deprel=list:goeswith
|
| 75 |
+
4 слишком слишком ADV _ Degree=Pos 5 advmod _ _
|
| 76 |
+
5 мало мало ADV _ Degree=Pos 6 advmod _ _
|
| 77 |
+
6 цветов цветок NOUN _ Animacy=Inan|Case=Gen|Gender=Masc|Number=Plur 7 obj _ _
|
| 78 |
+
7 получают получать VERB _ Aspect=Imp|Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 0 root _ _
|
| 79 |
+
8 актёры актер NOUN _ Animacy=Anim|Case=Nom|Gender=Masc|Number=Plur 7 nsubj _ _
|
| 80 |
+
9 после после ADP _ _ 10 case _ _
|
| 81 |
+
10 спектакля спектакль NOUN _ Animacy=Inan|Case=Gen|Gender=Masc|Number=Sing 7 obl _ SpaceAfter=No
|
| 82 |
+
11 . . PUNCT _ _ 7 punct _ _
|
| 83 |
+
|
| 84 |
+
# sent_id = 4
|
| 85 |
+
# genre = social
|
| 86 |
+
# text = В женщине важна верность, а не красота.
|
| 87 |
+
1 В в ADP _ _ 2 case _ _
|
| 88 |
+
2 женщине женщина NOUN _ Animacy=Anim|Case=Loc|Gender=Fem|Number=Sing 3 obl _ _
|
| 89 |
+
3 важна важный ADJ _ Degree=Pos|Gender=Fem|Number=Sing|Variant=Short 0 root _ _
|
| 90 |
+
4 верность верность NOUN _ Animacy=Inan|Case=Nom|Gender=Fem|Number=Sing 3 nsubj _ SpaceAfter=No
|
| 91 |
+
5 , , PUNCT _ _ 8 punct _ _
|
| 92 |
+
6 а а CCONJ _ _ 8 cc _ _
|
| 93 |
+
7 не не PART _ Polarity=Neg 8 advmod _ _
|
| 94 |
+
8 красота красота NOUN _ Animacy=Inan|Case=Nom|Gender=Fem|Number=Sing 4 conj _ SpaceAfter=No
|
| 95 |
+
9 . . PUNCT _ _ 3 punct _ _
|
| 96 |
+
""".strip()
|
| 97 |
+
|
| 98 |
+
RUSSIAN_TEXT = ["Как- то слишком мало цветов получают актёры после спектакля.", "В женщине важна верность, а не красота."]
|
| 99 |
+
RUSSIAN_IDS = ["yandex.reviews-f-8xh5zqnmwak3t6p68y4rhwd4e0-1969-9253", "4"]
|
| 100 |
+
|
| 101 |
+
def check_russian_doc(doc):
    """
    Refactored the test for the Russian doc so we can use it to test various file methods

    Checks sentence count, comments, text, sent_id, index, and that the
    {:C} rendering (conll with comments) round trips the comment lines.
    """
    lines = RUSSIAN_SAMPLE.split("\n")
    assert len(doc.sentences) == 2
    # the first three lines of the sample are the first sentence's comments
    assert lines[0] == doc.sentences[0].comments[0]
    assert lines[1] == doc.sentences[0].comments[1]
    assert lines[2] == doc.sentences[0].comments[2]
    for sent_idx, (expected_text, expected_id, sentence) in enumerate(zip(RUSSIAN_TEXT, RUSSIAN_IDS, doc.sentences)):
        assert expected_text == sentence.text
        assert expected_id == sentence.sent_id
        assert sent_idx == sentence.index
        assert len(sentence.comments) == 3
        assert not sentence.has_enhanced_dependencies()

    # :C = conll rendering including comments
    sentences = "{:C}".format(doc)
    sentences = sentences.split("\n\n")
    assert len(sentences) == 2

    # 3 comment lines + 11 word lines
    sentence = sentences[0].split("\n")
    assert len(sentence) == 14
    assert lines[0] == sentence[0]
    assert lines[1] == sentence[1]
    assert lines[2] == sentence[2]

    # assert that the weird deprel=list:goeswith was properly handled
    assert doc.sentences[0].words[2].head == 1
    assert doc.sentences[0].words[2].deprel == "list:goeswith"
|
| 131 |
+
def test_write_russian_doc(tmp_path):
    """
    Specifically test the write_doc2conll method
    """
    filename = tmp_path / "russian.conll"
    doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)
    check_russian_doc(doc)
    CoNLL.write_doc2conll(doc, filename)

    with open(filename, encoding="utf-8") as fin:
        text = fin.read()

    # the conll docs have to end with \n\n
    assert text.endswith("\n\n")

    # but to compare against the original, strip off the whitespace
    text = text.strip()

    # we skip the first sentence because the "deprel=list:goeswith" is weird
    # note that the deprel itself is checked in check_russian_doc
    text = text[text.find("# sent_id = 4"):]
    sample = RUSSIAN_SAMPLE[RUSSIAN_SAMPLE.find("# sent_id = 4"):]
    assert text == sample

    # reading back what we wrote should produce an equivalent doc
    doc2 = CoNLL.conll2doc(filename)
    check_russian_doc(doc2)
|
| 158 |
+
# random sentence from EN_Pronouns
|
| 159 |
+
ENGLISH_SAMPLE = """
|
| 160 |
+
# newdoc
|
| 161 |
+
# sent_id = 1
|
| 162 |
+
# text = It is hers.
|
| 163 |
+
# previous = Which person owns this?
|
| 164 |
+
# comment = copular subject
|
| 165 |
+
1 It it PRON PRP Number=Sing|Person=3|PronType=Prs 3 nsubj _ _
|
| 166 |
+
2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _
|
| 167 |
+
3 hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
|
| 168 |
+
4 . . PUNCT . _ 3 punct _ _
|
| 169 |
+
""".strip()
|
| 170 |
+
|
| 171 |
+
def test_write_to_io():
    """write_doc2conll should accept an in-memory stream, not just a filename."""
    doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE)
    buffer = io.StringIO()
    CoNLL.write_doc2conll(doc, buffer)
    written = buffer.getvalue()
    # conll documents must terminate with a blank line
    assert written.endswith("\n\n")
    assert written.strip() == ENGLISH_SAMPLE
|
| 179 |
+
def test_write_doc2conll_append(tmp_path):
    """
    Writing the same doc twice with mode="a" should produce two copies.

    Each copy is terminated by a blank line, so the file is the sample
    twice, each followed by "\\n\\n".
    """
    doc = CoNLL.conll2doc(input_str=ENGLISH_SAMPLE)
    filename = tmp_path / "english.conll"
    CoNLL.write_doc2conll(doc, filename)
    CoNLL.write_doc2conll(doc, filename, mode="a")

    # explicit encoding for consistency with the other tests reading conll files
    with open(filename, encoding="utf-8") as fin:
        text = fin.read()
    expected = ENGLISH_SAMPLE + "\n\n" + ENGLISH_SAMPLE + "\n\n"
    assert text == expected
|
| 190 |
+
def test_doc_with_comments():
    """
    Test that a doc with comments gets converted back with comments
    """
    doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)
    check_russian_doc(doc)
|
| 197 |
+
def test_unusual_misc():
    """
    The above RUSSIAN_SAMPLE resulted in a blank misc field in one particular implementation of the conll code
    (the below test would fail)
    """
    doc = CoNLL.conll2doc(input_str=RUSSIAN_SAMPLE)
    blocks = "{:C}".format(doc).split("\n\n")
    assert len(blocks) == 2
    first_sentence = blocks[0].split("\n")
    assert len(first_sentence) == 14

    for line in first_sentence:
        fields = line.split("\t")
        # comment lines have a single field, word lines exactly ten
        assert len(fields) in (1, 10)
        if len(fields) == 10:
            # no field may be rendered as the empty string
            assert all(fields)
|
| 214 |
+
def test_file():
    """
    Test loading a doc from a file
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # NOTE(review): os is not imported directly above; presumably it comes
        # from "from stanza.tests import *" -- confirm
        filename = os.path.join(tempdir, "russian.conll")
        with open(filename, "w", encoding="utf-8") as fout:
            fout.write(RUSSIAN_SAMPLE)
        doc = CoNLL.conll2doc(input_file=filename)
        check_russian_doc(doc)
|
| 225 |
+
def test_zip_file():
    """
    Test loading a doc from a zip file
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # NOTE(review): os presumably comes from "from stanza.tests import *" -- confirm
        zip_file = os.path.join(tempdir, "russian.zip")
        filename = "russian.conll"
        with ZipFile(zip_file, "w") as zout:
            with zout.open(filename, "w") as fout:
                # ZipFile members are written as bytes
                fout.write(RUSSIAN_SAMPLE.encode())

        doc = CoNLL.conll2doc(input_file=filename, zip_file=zip_file)
        check_russian_doc(doc)
+
|
| 239 |
+
SIMPLE_NER = """
|
| 240 |
+
# text = Teferi's best friend is Karn
|
| 241 |
+
# sent_id = 0
|
| 242 |
+
1 Teferi _ _ _ _ 0 _ _ start_char=0|end_char=6|ner=S-PERSON
|
| 243 |
+
2 's _ _ _ _ 1 _ _ start_char=6|end_char=8|ner=O
|
| 244 |
+
3 best _ _ _ _ 2 _ _ start_char=9|end_char=13|ner=O
|
| 245 |
+
4 friend _ _ _ _ 3 _ _ start_char=14|end_char=20|ner=O
|
| 246 |
+
5 is _ _ _ _ 4 _ _ start_char=21|end_char=23|ner=O
|
| 247 |
+
6 Karn _ _ _ _ 5 _ _ start_char=24|end_char=28|ner=S-PERSON
|
| 248 |
+
""".strip()
|
| 249 |
+
|
| 250 |
+
def test_simple_ner_conversion():
    """
    Test that tokens get properly created with NER tags
    """
    doc = CoNLL.conll2doc(input_str=SIMPLE_NER)
    assert len(doc.sentences) == 1
    sentence = doc.sentences[0]
    assert len(sentence.tokens) == 6

    expected_tags = ["S-PERSON", "O", "O", "O", "O", "S-PERSON"]
    for token, expected_ner in zip(sentence.tokens, expected_tags):
        assert token.ner == expected_ner
        # ner / start_char / end_char must become real fields on the token,
        # not be left behind on the token's misc ...
        assert not token.misc
        assert len(token.words) == 1
        # ... and must not leak onto the word's misc either
        assert not token.words[0].misc

    # the document should round trip back to the original text
    assert "{:C}".format(doc) == SIMPLE_NER
|
| 271 |
+
MWT_NER = """
|
| 272 |
+
# text = This makes John's headache worse
|
| 273 |
+
# sent_id = 0
|
| 274 |
+
1 This _ _ _ _ 0 _ _ start_char=0|end_char=4|ner=O
|
| 275 |
+
2 makes _ _ _ _ 1 _ _ start_char=5|end_char=10|ner=O
|
| 276 |
+
3-4 John's _ _ _ _ _ _ _ start_char=11|end_char=17|ner=S-PERSON
|
| 277 |
+
3 John _ _ _ _ 2 _ _ _
|
| 278 |
+
4 's _ _ _ _ 3 _ _ _
|
| 279 |
+
5 headache _ _ _ _ 4 _ _ start_char=18|end_char=26|ner=O
|
| 280 |
+
6 worse _ _ _ _ 5 _ _ start_char=27|end_char=32|ner=O
|
| 281 |
+
""".strip()
|
| 282 |
+
|
| 283 |
+
def test_mwt_ner_conversion():
    """
    Test that tokens including MWT get properly created with NER tags

    Note that this kind of thing happens with the EWT tokenizer for English, for example
    """
    doc = CoNLL.conll2doc(input_str=MWT_NER)
    assert len(doc.sentences) == 1
    sentence = doc.sentences[0]
    # 5 tokens, one of which is an MWT covering two words
    assert len(sentence.tokens) == 5
    assert not sentence.has_enhanced_dependencies()
    EXPECTED_NER = ["O", "O", "S-PERSON", "O", "O"]
    EXPECTED_WORDS = [1, 1, 2, 1, 1]
    for token, ner, expected_words in zip(sentence.tokens, EXPECTED_NER, EXPECTED_WORDS):
        assert token.ner == ner
        # check that the ner, start_char, end_char fields were not put on the token's misc
        # those should all be set as specific fields on the token
        assert not token.misc
        assert len(token.words) == expected_words
        # they should also not reach the word's misc field
        assert not token.words[0].misc

    # the document should round trip back to the original text
    conll = "{:C}".format(doc)
    assert conll == MWT_NER
| 308 |
+
|
| 309 |
+
# A random sentence from et_ewt-ud-train.conllu
|
| 310 |
+
# which we use to test the deps conversion for multiple deps
|
| 311 |
+
ESTONIAN_DEPS = """
|
| 312 |
+
# newpar
|
| 313 |
+
# sent_id = aia_foorum_37
|
| 314 |
+
# text = Sestpeale ei mõistagi neid, kes koduaias sortidega tegelevad.
|
| 315 |
+
1 Sestpeale sest_peale ADV D _ 3 advmod 3:advmod _
|
| 316 |
+
2 ei ei AUX V Polarity=Neg 3 aux 3:aux _
|
| 317 |
+
3 mõistagi mõistma VERB V Connegative=Yes|Mood=Ind|Tense=Pres|VerbForm=Fin|Voice=Act 0 root 0:root _
|
| 318 |
+
4 neid tema PRON P Case=Par|Number=Plur|Person=3|PronType=Prs 3 obj 3:obj|9:nsubj SpaceAfter=No
|
| 319 |
+
5 , , PUNCT Z _ 9 punct 9:punct _
|
| 320 |
+
6 kes kes PRON P Case=Nom|Number=Plur|PronType=Int,Rel 9 nsubj 4:ref _
|
| 321 |
+
7 koduaias kodu_aed NOUN S Case=Ine|Number=Sing 9 obl 9:obl _
|
| 322 |
+
8 sortidega sort NOUN S Case=Com|Number=Plur 9 obl 9:obl _
|
| 323 |
+
9 tegelevad tegelema VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin|Voice=Act 4 acl:relcl 4:acl SpaceAfter=No
|
| 324 |
+
10 . . PUNCT Z _ 3 punct 3:punct _
|
| 325 |
+
""".strip()
|
| 326 |
+
|
| 327 |
+
def test_deps_conversion():
    """A sentence with enhanced deps should keep them through a round trip."""
    doc = CoNLL.conll2doc(input_str=ESTONIAN_DEPS)
    assert len(doc.sentences) == 1
    sentence = doc.sentences[0]
    assert len(sentence.tokens) == 10
    assert sentence.has_enhanced_dependencies()

    # word 4 ("neid") has two enhanced heads in the sample
    word = doc.sentences[0].words[3]
    assert word.deps == "3:obj|9:nsubj"

    conll = "{:C}".format(doc)
    assert conll == ESTONIAN_DEPS
|
| 340 |
+
ESTONIAN_EMPTY_DEPS = """
|
| 341 |
+
# sent_id = ewtb2_000035_15
|
| 342 |
+
# text = Ja paari aasta pärast rôômalt maasikatele ...
|
| 343 |
+
1 Ja ja CCONJ J _ 3 cc 5.1:cc _
|
| 344 |
+
2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
|
| 345 |
+
3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
|
| 346 |
+
4 pärast pärast ADP K AdpType=Post 3 case 3:case _
|
| 347 |
+
5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
|
| 348 |
+
5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
|
| 349 |
+
6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes
|
| 350 |
+
7 ... ... PUNCT Z _ 3 punct 5.1:punct _
|
| 351 |
+
""".strip()
|
| 352 |
+
|
| 353 |
+
ESTONIAN_EMPTY_END_DEPS = """
|
| 354 |
+
# sent_id = ewtb2_000035_15
|
| 355 |
+
# text = Ja paari aasta pärast rôômalt maasikatele ...
|
| 356 |
+
1 Ja ja CCONJ J _ 3 cc 5.1:cc _
|
| 357 |
+
2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
|
| 358 |
+
3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
|
| 359 |
+
4 pärast pärast ADP K AdpType=Post 3 case 3:case _
|
| 360 |
+
5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
|
| 361 |
+
5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
|
| 362 |
+
""".strip()
|
| 363 |
+
|
| 364 |
+
def test_empty_deps_conversion():
    """
    Check that we can read and then output a sentence with empty dependencies
    """
    check_empty_deps_conversion(ESTONIAN_EMPTY_DEPS, 7)
| 370 |
+
def test_empty_deps_at_end_conversion():
    """
    The empty deps conversion should also work if the empty dep is at the end
    """
    check_empty_deps_conversion(ESTONIAN_EMPTY_END_DEPS, 5)
|
| 376 |
+
def check_empty_deps_conversion(input_str, expected_words):
    """
    Round trip a doc with an empty (x.y) word and check its structure.

    input_str: conll text with exactly one sentence containing one empty word
    expected_words: how many regular (non-empty) words the sentence has
    """
    doc = CoNLL.conll2doc(input_str=input_str, ignore_gapping=False)
    assert len(doc.sentences) == 1
    assert len(doc.sentences[0].tokens) == expected_words
    assert len(doc.sentences[0].words) == expected_words
    assert len(doc.sentences[0].empty_words) == 1

    # the {:C} rendering should exactly reproduce the input
    conll = "{:C}".format(doc)
    assert conll == input_str

    # the empty word adds one entry to the dict form
    sentence_dict = doc.sentences[0].to_dict()
    assert len(sentence_dict) == expected_words + 1
    # currently this is true for both of the examples we run
    assert sentence_dict[5]['id'] == (5, 1)

    # redo the above checks to make sure
    # there are no weird bugs in the accessors
    assert len(doc.sentences) == 1
    assert len(doc.sentences[0].tokens) == expected_words
    assert len(doc.sentences[0].words) == expected_words
    assert len(doc.sentences[0].empty_words) == 1
|
| 399 |
+
|
| 400 |
+
ESTONIAN_DOC_ID = """
|
| 401 |
+
# doc_id = this_is_a_doc
|
| 402 |
+
# sent_id = ewtb2_000035_15
|
| 403 |
+
# text = Ja paari aasta pärast rôômalt maasikatele ...
|
| 404 |
+
1 Ja ja CCONJ J _ 3 cc 5.1:cc _
|
| 405 |
+
2 paari paar NUM N Case=Gen|Number=Sing|NumForm=Word|NumType=Card 3 nummod 3:nummod _
|
| 406 |
+
3 aasta aasta NOUN S Case=Gen|Number=Sing 0 root 5.1:obl _
|
| 407 |
+
4 pärast pärast ADP K AdpType=Post 3 case 3:case _
|
| 408 |
+
5 rôômalt rõõmsalt ADV D Typo=Yes 3 advmod 5.1:advmod Orphan=Yes|CorrectForm=rõõmsalt
|
| 409 |
+
5.1 panna panema VERB V VerbForm=Inf _ _ 0:root Empty=5.1
|
| 410 |
+
6 maasikatele maasikas NOUN S Case=All|Number=Plur 3 obl 5.1:obl Orphan=Yes
|
| 411 |
+
7 ... ... PUNCT Z _ 3 punct 5.1:punct _
|
| 412 |
+
""".strip()
|
| 413 |
+
|
| 414 |
+
def test_read_doc_id():
    """A # doc_id comment should round trip and be set on the sentence."""
    doc = CoNLL.conll2doc(input_str=ESTONIAN_DOC_ID, ignore_gapping=False)
    assert "{:C}".format(doc) == ESTONIAN_DOC_ID
    assert doc.sentences[0].doc_id == 'this_is_a_doc'
| 419 |
+
SIMPLE_DEPENDENCY_INDEX_ERROR = """
|
| 420 |
+
# text = Teferi's best friend is Karn
|
| 421 |
+
# sent_id = 0
|
| 422 |
+
# notes = this sentence has a dependency index outside the sentence. it should throw an IndexError
|
| 423 |
+
1 Teferi _ _ _ _ 0 root _ start_char=0|end_char=6|ner=S-PERSON
|
| 424 |
+
2 's _ _ _ _ 1 dep _ start_char=6|end_char=8|ner=O
|
| 425 |
+
3 best _ _ _ _ 2 dep _ start_char=9|end_char=13|ner=O
|
| 426 |
+
4 friend _ _ _ _ 3 dep _ start_char=14|end_char=20|ner=O
|
| 427 |
+
5 is _ _ _ _ 4 dep _ start_char=21|end_char=23|ner=O
|
| 428 |
+
6 Karn _ _ _ _ 8 dep _ start_char=24|end_char=28|ner=S-PERSON
|
| 429 |
+
""".strip()
|
| 430 |
+
|
| 431 |
+
def test_read_dependency_errors():
|
| 432 |
+
with pytest.raises(IndexError):
|
| 433 |
+
doc = CoNLL.conll2doc(input_str=SIMPLE_DEPENDENCY_INDEX_ERROR)
|
| 434 |
+
|
| 435 |
+
MULTIPLE_DOC_IDS = """
|
| 436 |
+
# doc_id = doc_1
|
| 437 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0020
|
| 438 |
+
# text = His mother was also killed in the attack.
|
| 439 |
+
1 His his PRON PRP$ Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs 2 nmod:poss 2:nmod:poss _
|
| 440 |
+
2 mother mother NOUN NN Number=Sing 5 nsubj:pass 5:nsubj:pass _
|
| 441 |
+
3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 5 aux:pass 5:aux:pass _
|
| 442 |
+
4 also also ADV RB _ 5 advmod 5:advmod _
|
| 443 |
+
5 killed kill VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
|
| 444 |
+
6 in in ADP IN _ 8 case 8:case _
|
| 445 |
+
7 the the DET DT Definite=Def|PronType=Art 8 det 8:det _
|
| 446 |
+
8 attack attack NOUN NN Number=Sing 5 obl 5:obl:in SpaceAfter=No
|
| 447 |
+
9 . . PUNCT . _ 5 punct 5:punct _
|
| 448 |
+
|
| 449 |
+
# doc_id = doc_1
|
| 450 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0028
|
| 451 |
+
# text = This item is a small one and easily missed.
|
| 452 |
+
1 This this DET DT Number=Sing|PronType=Dem 2 det 2:det _
|
| 453 |
+
2 item item NOUN NN Number=Sing 6 nsubj 6:nsubj|9:nsubj:pass _
|
| 454 |
+
3 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
|
| 455 |
+
4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
|
| 456 |
+
5 small small ADJ JJ Degree=Pos 6 amod 6:amod _
|
| 457 |
+
6 one one NOUN NN Number=Sing 0 root 0:root _
|
| 458 |
+
7 and and CCONJ CC _ 9 cc 9:cc _
|
| 459 |
+
8 easily easily ADV RB _ 9 advmod 9:advmod _
|
| 460 |
+
9 missed miss VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 6 conj 6:conj:and SpaceAfter=No
|
| 461 |
+
10 . . PUNCT . _ 6 punct 6:punct _
|
| 462 |
+
|
| 463 |
+
# doc_id = doc_2
|
| 464 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0029
|
| 465 |
+
# text = But in my view it is highly significant.
|
| 466 |
+
1 But but CCONJ CC _ 8 cc 8:cc _
|
| 467 |
+
2 in in ADP IN _ 4 case 4:case _
|
| 468 |
+
3 my my PRON PRP$ Case=Gen|Number=Sing|Person=1|Poss=Yes|PronType=Prs 4 nmod:poss 4:nmod:poss _
|
| 469 |
+
4 view view NOUN NN Number=Sing 8 obl 8:obl:in _
|
| 470 |
+
5 it it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 8 nsubj 8:nsubj _
|
| 471 |
+
6 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 8 cop 8:cop _
|
| 472 |
+
7 highly highly ADV RB _ 8 advmod 8:advmod _
|
| 473 |
+
8 significant significant ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No
|
| 474 |
+
9 . . PUNCT . _ 8 punct 8:punct _
|
| 475 |
+
|
| 476 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0040
|
| 477 |
+
# text = The trial begins again Nov.28.
|
| 478 |
+
1 The the DET DT Definite=Def|PronType=Art 2 det 2:det _
|
| 479 |
+
2 trial trial NOUN NN Number=Sing 3 nsubj 3:nsubj _
|
| 480 |
+
3 begins begin VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
|
| 481 |
+
4 again again ADV RB _ 3 advmod 3:advmod _
|
| 482 |
+
5 Nov. November PROPN NNP Abbr=Yes|Number=Sing 3 obl:tmod 3:obl:tmod SpaceAfter=No
|
| 483 |
+
6 28 28 NUM CD NumForm=Digit|NumType=Card 5 nummod 5:nummod SpaceAfter=No
|
| 484 |
+
7 . . PUNCT . _ 3 punct 3:punct _
|
| 485 |
+
|
| 486 |
+
""".lstrip()
|
| 487 |
+
|
| 488 |
+
def test_read_multiple_doc_ids():
|
| 489 |
+
docs = CoNLL.conll2multi_docs(input_str=MULTIPLE_DOC_IDS)
|
| 490 |
+
assert len(docs) == 2
|
| 491 |
+
assert len(docs[0].sentences) == 2
|
| 492 |
+
assert len(docs[1].sentences) == 2
|
| 493 |
+
|
| 494 |
+
# remove the first doc_id comment
|
| 495 |
+
text = "\n".join(MULTIPLE_DOC_IDS.split("\n")[1:])
|
| 496 |
+
docs = CoNLL.conll2multi_docs(input_str=text)
|
| 497 |
+
assert len(docs) == 3
|
| 498 |
+
assert len(docs[0].sentences) == 1
|
| 499 |
+
assert len(docs[1].sentences) == 1
|
| 500 |
+
assert len(docs[2].sentences) == 2
|
| 501 |
+
|
| 502 |
+
ENGLISH_TEST_SENTENCE = """
|
| 503 |
+
# text = This is a test
|
| 504 |
+
# sent_id = 0
|
| 505 |
+
1 This this PRON DT Number=Sing|PronType=Dem 4 nsubj _ start_char=0|end_char=4
|
| 506 |
+
2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ start_char=5|end_char=7
|
| 507 |
+
3 a a DET DT Definite=Ind|PronType=Art 4 det _ start_char=8|end_char=9
|
| 508 |
+
4 test test NOUN NN Number=Sing 0 root _ start_char=10|end_char=14|SpaceAfter=No
|
| 509 |
+
""".lstrip()
|
| 510 |
+
|
| 511 |
+
def test_convert_dict():
|
| 512 |
+
doc = CoNLL.conll2doc(input_str=ENGLISH_TEST_SENTENCE)
|
| 513 |
+
converted = CoNLL.convert_dict(doc.to_dict())
|
| 514 |
+
|
| 515 |
+
expected = [[['1', 'This', 'this', 'PRON', 'DT', 'Number=Sing|PronType=Dem', '4', 'nsubj', '_', 'start_char=0|end_char=4'],
|
| 516 |
+
['2', 'is', 'be', 'AUX', 'VBZ', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', '4', 'cop', '_', 'start_char=5|end_char=7'],
|
| 517 |
+
['3', 'a', 'a', 'DET', 'DT', 'Definite=Ind|PronType=Art', '4', 'det', '_', 'start_char=8|end_char=9'],
|
| 518 |
+
['4', 'test', 'test', 'NOUN', 'NN', 'Number=Sing', '0', 'root', '_', 'SpaceAfter=No|start_char=10|end_char=14']]]
|
| 519 |
+
|
| 520 |
+
assert converted == expected
|
stanza/stanza/tests/common/test_foundation_cache.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import tempfile
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
import stanza
|
| 9 |
+
from stanza.models.common.foundation_cache import FoundationCache, load_charlm
|
| 10 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 11 |
+
|
| 12 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 13 |
+
|
| 14 |
+
def test_charlm_cache():
|
| 15 |
+
models_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "*")
|
| 16 |
+
models = glob.glob(models_path)
|
| 17 |
+
# we expect at least one English model downloaded for the tests
|
| 18 |
+
assert len(models) >= 1
|
| 19 |
+
model_file = models[0]
|
| 20 |
+
|
| 21 |
+
cache = FoundationCache()
|
| 22 |
+
with tempfile.TemporaryDirectory(dir=".") as test_dir:
|
| 23 |
+
temp_file = os.path.join(test_dir, "charlm.pt")
|
| 24 |
+
shutil.copy2(model_file, temp_file)
|
| 25 |
+
# this will work
|
| 26 |
+
model = load_charlm(temp_file)
|
| 27 |
+
|
| 28 |
+
# this will save the model
|
| 29 |
+
model = cache.load_charlm(temp_file)
|
| 30 |
+
|
| 31 |
+
# this should no longer work
|
| 32 |
+
with pytest.raises(FileNotFoundError):
|
| 33 |
+
model = load_charlm(temp_file)
|
| 34 |
+
|
| 35 |
+
# it should remember the cached version
|
| 36 |
+
model = cache.load_charlm(temp_file)
|
stanza/stanza/tests/common/test_pretrain.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from stanza.models.common import pretrain
|
| 9 |
+
from stanza.models.common.vocab import UNK_ID
|
| 10 |
+
from stanza.tests import *
|
| 11 |
+
|
| 12 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 13 |
+
|
| 14 |
+
def check_vocab(vocab):
|
| 15 |
+
# 4 base vectors, plus the 3 vectors actually present in the file
|
| 16 |
+
assert len(vocab) == 7
|
| 17 |
+
assert 'unban' in vocab
|
| 18 |
+
assert 'mox' in vocab
|
| 19 |
+
assert 'opal' in vocab
|
| 20 |
+
|
| 21 |
+
def check_embedding(emb, unk=False):
|
| 22 |
+
expected = np.array([[ 0., 0., 0., 0.,],
|
| 23 |
+
[ 0., 0., 0., 0.,],
|
| 24 |
+
[ 0., 0., 0., 0.,],
|
| 25 |
+
[ 0., 0., 0., 0.,],
|
| 26 |
+
[ 1., 2., 3., 4.,],
|
| 27 |
+
[ 5., 6., 7., 8.,],
|
| 28 |
+
[ 9., 10., 11., 12.,]])
|
| 29 |
+
if unk:
|
| 30 |
+
expected[UNK_ID] = -1
|
| 31 |
+
np.testing.assert_allclose(emb, expected)
|
| 32 |
+
|
| 33 |
+
def check_pretrain(pt):
|
| 34 |
+
check_vocab(pt.vocab)
|
| 35 |
+
check_embedding(pt.emb)
|
| 36 |
+
|
| 37 |
+
def test_text_pretrain():
|
| 38 |
+
pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.txt', save_to_file=False)
|
| 39 |
+
check_pretrain(pt)
|
| 40 |
+
|
| 41 |
+
def test_xz_pretrain():
|
| 42 |
+
pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz', save_to_file=False)
|
| 43 |
+
check_pretrain(pt)
|
| 44 |
+
|
| 45 |
+
def test_gz_pretrain():
|
| 46 |
+
pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.gz', save_to_file=False)
|
| 47 |
+
check_pretrain(pt)
|
| 48 |
+
|
| 49 |
+
def test_zip_pretrain():
|
| 50 |
+
pt = pretrain.Pretrain(vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.zip', save_to_file=False)
|
| 51 |
+
check_pretrain(pt)
|
| 52 |
+
|
| 53 |
+
def test_csv_pretrain():
|
| 54 |
+
pt = pretrain.Pretrain(csv_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.csv', save_to_file=False)
|
| 55 |
+
check_pretrain(pt)
|
| 56 |
+
|
| 57 |
+
def test_resave_pretrain():
|
| 58 |
+
"""
|
| 59 |
+
Test saving a pretrain and then loading from the existing file
|
| 60 |
+
"""
|
| 61 |
+
test_pt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".pt", delete=False)
|
| 62 |
+
try:
|
| 63 |
+
test_pt_file.close()
|
| 64 |
+
# note that this tests the ability to save a pretrain and the
|
| 65 |
+
# ability to fall back when the existing pretrain isn't working
|
| 66 |
+
pt = pretrain.Pretrain(filename=test_pt_file.name,
|
| 67 |
+
vec_filename=f'{TEST_WORKING_DIR}/in/tiny_emb.xz')
|
| 68 |
+
check_pretrain(pt)
|
| 69 |
+
|
| 70 |
+
pt2 = pretrain.Pretrain(filename=test_pt_file.name,
|
| 71 |
+
vec_filename=f'unban_mox_opal')
|
| 72 |
+
check_pretrain(pt2)
|
| 73 |
+
|
| 74 |
+
pt3 = torch.load(test_pt_file.name, weights_only=True)
|
| 75 |
+
check_embedding(pt3['emb'])
|
| 76 |
+
finally:
|
| 77 |
+
os.unlink(test_pt_file.name)
|
| 78 |
+
|
| 79 |
+
SPACE_PRETRAIN="""
|
| 80 |
+
3 4
|
| 81 |
+
unban mox 1 2 3 4
|
| 82 |
+
opal 5 6 7 8
|
| 83 |
+
foo 9 10 11 12
|
| 84 |
+
""".strip()
|
| 85 |
+
|
| 86 |
+
def test_whitespace():
|
| 87 |
+
"""
|
| 88 |
+
Test reading a pretrain with an ascii space in it
|
| 89 |
+
|
| 90 |
+
The vocab word with a space in it should have the correct number
|
| 91 |
+
of dimensions read, with the space converted to nbsp
|
| 92 |
+
"""
|
| 93 |
+
test_txt_file = tempfile.NamedTemporaryFile(dir=f'{TEST_WORKING_DIR}/out', suffix=".txt", delete=False)
|
| 94 |
+
try:
|
| 95 |
+
test_txt_file.write(SPACE_PRETRAIN.encode())
|
| 96 |
+
test_txt_file.close()
|
| 97 |
+
|
| 98 |
+
pt = pretrain.Pretrain(vec_filename=test_txt_file.name, save_to_file=False)
|
| 99 |
+
check_embedding(pt.emb)
|
| 100 |
+
assert "unban\xa0mox" in pt.vocab
|
| 101 |
+
# this one also works because of the normalize_unit in vocab.py
|
| 102 |
+
assert "unban mox" in pt.vocab
|
| 103 |
+
finally:
|
| 104 |
+
os.unlink(test_txt_file.name)
|
| 105 |
+
|
| 106 |
+
NO_HEADER_PRETRAIN="""
|
| 107 |
+
unban 1 2 3 4
|
| 108 |
+
mox 5 6 7 8
|
| 109 |
+
opal 9 10 11 12
|
| 110 |
+
""".strip()
|
| 111 |
+
|
| 112 |
+
def test_no_header():
|
| 113 |
+
"""
|
| 114 |
+
Check loading a pretrain with no rows,cols header
|
| 115 |
+
"""
|
| 116 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir:
|
| 117 |
+
filename = os.path.join(tmpdir, "tiny.txt")
|
| 118 |
+
with open(filename, "w", encoding="utf-8") as fout:
|
| 119 |
+
fout.write(NO_HEADER_PRETRAIN)
|
| 120 |
+
pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False)
|
| 121 |
+
check_embedding(pt.emb)
|
| 122 |
+
|
| 123 |
+
UNK_PRETRAIN="""
|
| 124 |
+
unban 1 2 3 4
|
| 125 |
+
mox 5 6 7 8
|
| 126 |
+
opal 9 10 11 12
|
| 127 |
+
<unk> -1 -1 -1 -1
|
| 128 |
+
""".strip()
|
| 129 |
+
|
| 130 |
+
def test_no_header():
|
| 131 |
+
"""
|
| 132 |
+
Check loading a pretrain with <unk> at the end, like GloVe does
|
| 133 |
+
"""
|
| 134 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdir:
|
| 135 |
+
filename = os.path.join(tmpdir, "tiny.txt")
|
| 136 |
+
with open(filename, "w", encoding="utf-8") as fout:
|
| 137 |
+
fout.write(UNK_PRETRAIN)
|
| 138 |
+
pt = pretrain.Pretrain(vec_filename=filename, save_to_file=False)
|
| 139 |
+
check_embedding(pt.emb, unk=True)
|
stanza/stanza/tests/common/test_utils.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import lzma
|
| 2 |
+
import os
|
| 3 |
+
import tempfile
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
import stanza
|
| 8 |
+
import stanza.models.common.utils as utils
|
| 9 |
+
from stanza.tests import *
|
| 10 |
+
|
| 11 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 12 |
+
|
| 13 |
+
def test_wordvec_not_found():
|
| 14 |
+
"""
|
| 15 |
+
get_wordvec_file should fail if neither word2vec nor fasttext exists
|
| 16 |
+
"""
|
| 17 |
+
with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
|
| 18 |
+
with pytest.raises(FileNotFoundError):
|
| 19 |
+
utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_word2vec_xz():
|
| 23 |
+
"""
|
| 24 |
+
Test searching for word2vec and xz files
|
| 25 |
+
"""
|
| 26 |
+
with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
|
| 27 |
+
# make a fake directory for English word vectors
|
| 28 |
+
word2vec_dir = os.path.join(temp_dir, 'word2vec', 'English')
|
| 29 |
+
os.makedirs(word2vec_dir)
|
| 30 |
+
|
| 31 |
+
# make a fake English word vector file
|
| 32 |
+
fake_file = os.path.join(word2vec_dir, 'en.vectors.xz')
|
| 33 |
+
fout = open(fake_file, 'w')
|
| 34 |
+
fout.close()
|
| 35 |
+
|
| 36 |
+
# get_wordvec_file should now find this fake file
|
| 37 |
+
filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
|
| 38 |
+
assert filename == fake_file
|
| 39 |
+
|
| 40 |
+
def test_fasttext_txt():
|
| 41 |
+
"""
|
| 42 |
+
Test searching for fasttext and txt files
|
| 43 |
+
"""
|
| 44 |
+
with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
|
| 45 |
+
# make a fake directory for English word vectors
|
| 46 |
+
fasttext_dir = os.path.join(temp_dir, 'fasttext', 'English')
|
| 47 |
+
os.makedirs(fasttext_dir)
|
| 48 |
+
|
| 49 |
+
# make a fake English word vector file
|
| 50 |
+
fake_file = os.path.join(fasttext_dir, 'en.vectors.txt')
|
| 51 |
+
fout = open(fake_file, 'w')
|
| 52 |
+
fout.close()
|
| 53 |
+
|
| 54 |
+
# get_wordvec_file should now find this fake file
|
| 55 |
+
filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
|
| 56 |
+
assert filename == fake_file
|
| 57 |
+
|
| 58 |
+
def test_wordvec_type():
|
| 59 |
+
"""
|
| 60 |
+
If we supply our own wordvec type, get_wordvec_file should find that
|
| 61 |
+
"""
|
| 62 |
+
with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
|
| 63 |
+
# make a fake directory for English word vectors
|
| 64 |
+
google_dir = os.path.join(temp_dir, 'google', 'English')
|
| 65 |
+
os.makedirs(google_dir)
|
| 66 |
+
|
| 67 |
+
# make a fake English word vector file
|
| 68 |
+
fake_file = os.path.join(google_dir, 'en.vectors.txt')
|
| 69 |
+
fout = open(fake_file, 'w')
|
| 70 |
+
fout.close()
|
| 71 |
+
|
| 72 |
+
# get_wordvec_file should now find this fake file
|
| 73 |
+
filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo', wordvec_type='google')
|
| 74 |
+
assert filename == fake_file
|
| 75 |
+
|
| 76 |
+
# this file won't be found using the normal defaults
|
| 77 |
+
with pytest.raises(FileNotFoundError):
|
| 78 |
+
utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
|
| 79 |
+
|
| 80 |
+
def test_sort_with_indices():
|
| 81 |
+
data = [[1, 2, 3], [4, 5], [6]]
|
| 82 |
+
ordered, orig_idx = utils.sort_with_indices(data, key=len)
|
| 83 |
+
assert ordered == ([6], [4, 5], [1, 2, 3])
|
| 84 |
+
assert orig_idx == (2, 1, 0)
|
| 85 |
+
|
| 86 |
+
unsorted = utils.unsort(ordered, orig_idx)
|
| 87 |
+
assert data == unsorted
|
| 88 |
+
|
| 89 |
+
def test_empty_sort_with_indices():
|
| 90 |
+
ordered, orig_idx = utils.sort_with_indices([])
|
| 91 |
+
assert len(ordered) == 0
|
| 92 |
+
assert len(orig_idx) == 0
|
| 93 |
+
|
| 94 |
+
unsorted = utils.unsort(ordered, orig_idx)
|
| 95 |
+
assert [] == unsorted
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test_split_into_batches():
|
| 99 |
+
data = []
|
| 100 |
+
for i in range(5):
|
| 101 |
+
data.append(["Unban", "mox", "opal", str(i)])
|
| 102 |
+
|
| 103 |
+
data.append(["Do", "n't", "ban", "Urza", "'s", "Saga", "that", "card", "is", "great"])
|
| 104 |
+
data.append(["Ban", "Ragavan"])
|
| 105 |
+
|
| 106 |
+
# small batches will put one element in each interval
|
| 107 |
+
batches = utils.split_into_batches(data, 5)
|
| 108 |
+
assert batches == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
|
| 109 |
+
|
| 110 |
+
# this one has a batch interrupted in the middle by a large element
|
| 111 |
+
batches = utils.split_into_batches(data, 8)
|
| 112 |
+
assert batches == [(0, 2), (2, 4), (4, 5), (5, 6), (6, 7)]
|
| 113 |
+
|
| 114 |
+
# this one has the large element at the start of its own batch
|
| 115 |
+
batches = utils.split_into_batches(data[1:], 8)
|
| 116 |
+
assert batches == [(0, 2), (2, 4), (4, 5), (5, 6)]
|
| 117 |
+
|
| 118 |
+
# overloading the test! assert that the key & reverse is working
|
| 119 |
+
ordered, orig_idx = utils.sort_with_indices(data, key=len, reverse=True)
|
| 120 |
+
assert [len(x) for x in ordered] == [10, 4, 4, 4, 4, 4, 2]
|
| 121 |
+
|
| 122 |
+
# this has the large element at the start
|
| 123 |
+
batches = utils.split_into_batches(ordered, 8)
|
| 124 |
+
assert batches == [(0, 1), (1, 3), (3, 5), (5, 7)]
|
| 125 |
+
|
| 126 |
+
# double check that unsort is working as expected
|
| 127 |
+
assert data == utils.unsort(ordered, orig_idx)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def test_find_missing_tags():
|
| 131 |
+
assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC"]) == []
|
| 132 |
+
assert utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC", "ORG"]) == ['ORG']
|
| 133 |
+
assert utils.find_missing_tags([["O", "PER"], ["O", "LOC"]], [["O", "PER"], ["LOC", "ORG"]]) == ['ORG']
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def test_open_read_text():
|
| 137 |
+
"""
|
| 138 |
+
test that we can read either .xz or regular txt
|
| 139 |
+
"""
|
| 140 |
+
TEXT = "this is a test"
|
| 141 |
+
with tempfile.TemporaryDirectory() as tempdir:
|
| 142 |
+
# test text file
|
| 143 |
+
filename = os.path.join(tempdir, "foo.txt")
|
| 144 |
+
with open(filename, "w") as fout:
|
| 145 |
+
fout.write(TEXT)
|
| 146 |
+
with utils.open_read_text(filename) as fin:
|
| 147 |
+
in_text = fin.read()
|
| 148 |
+
assert TEXT == in_text
|
| 149 |
+
|
| 150 |
+
assert fin.closed
|
| 151 |
+
|
| 152 |
+
# the context should close the file when we throw an exception!
|
| 153 |
+
try:
|
| 154 |
+
with utils.open_read_text(filename) as finex:
|
| 155 |
+
assert not finex.closed
|
| 156 |
+
raise ValueError("unban mox opal!")
|
| 157 |
+
except ValueError:
|
| 158 |
+
pass
|
| 159 |
+
assert finex.closed
|
| 160 |
+
|
| 161 |
+
# test xz file
|
| 162 |
+
filename = os.path.join(tempdir, "foo.txt.xz")
|
| 163 |
+
with lzma.open(filename, "wt") as fout:
|
| 164 |
+
fout.write(TEXT)
|
| 165 |
+
with utils.open_read_text(filename) as finxz:
|
| 166 |
+
in_text = finxz.read()
|
| 167 |
+
assert TEXT == in_text
|
| 168 |
+
|
| 169 |
+
assert finxz.closed
|
| 170 |
+
|
| 171 |
+
# the context should close the file when we throw an exception!
|
| 172 |
+
try:
|
| 173 |
+
with utils.open_read_text(filename) as finexxz:
|
| 174 |
+
assert not finexxz.closed
|
| 175 |
+
raise ValueError("unban mox opal!")
|
| 176 |
+
except ValueError:
|
| 177 |
+
pass
|
| 178 |
+
assert finexxz.closed
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def test_checkpoint_name():
|
| 182 |
+
"""
|
| 183 |
+
Test some expected results for the checkpoint names
|
| 184 |
+
"""
|
| 185 |
+
# use os.path.split so that the test is agnostic of file separator on Linux or Windows
|
| 186 |
+
checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm.pt", None)
|
| 187 |
+
assert os.path.split(checkpoint) == ("saved_models", "kk_oscar_forward_charlm_checkpoint.pt")
|
| 188 |
+
|
| 189 |
+
checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm", None)
|
| 190 |
+
assert os.path.split(checkpoint) == ("saved_models", "kk_oscar_forward_charlm_checkpoint")
|
| 191 |
+
|
| 192 |
+
checkpoint = utils.checkpoint_name("saved_models", "kk_oscar_forward_charlm", "othername.pt")
|
| 193 |
+
assert os.path.split(checkpoint) == ("saved_models", "othername.pt")
|
| 194 |
+
|
stanza/stanza/tests/constituency/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/constituency/test_convert_arboretum.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a couple different classes of trees to check the output of the Arboretum conversion
|
| 3 |
+
|
| 4 |
+
Note that the text has been removed
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
|
| 12 |
+
from stanza.server import tsurgeon
|
| 13 |
+
from stanza.tests import TEST_WORKING_DIR
|
| 14 |
+
from stanza.utils.datasets.constituency import convert_arboretum
|
| 15 |
+
|
| 16 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
PROJ_EXAMPLE="""
|
| 20 |
+
<s id="s2" ref="AACBPIGY" source="id=AACBPIGY" forest="1/1" text="A B C D E F G H.">
|
| 21 |
+
<graph root="s2_500">
|
| 22 |
+
<terminals>
|
| 23 |
+
<t id="s2_1" word="A" lemma="A" pos="prop" morph="NOM" extra="PROP:A compound brand"/>
|
| 24 |
+
<t id="s2_2" word="B" lemma="B" pos="v-fin" morph="PR AKT" extra="mv"/>
|
| 25 |
+
<t id="s2_3" word="C" lemma="C" pos="pron-pers" morph="2S ACC" extra="--"/>
|
| 26 |
+
<t id="s2_4" word="D" lemma="D" pos="adj" morph="UTR S IDF NOM" extra="F:u+afhængig"/>
|
| 27 |
+
<t id="s2_5" word="E" lemma="E" pos="prp" morph="--" extra="--"/>
|
| 28 |
+
<t id="s2_6" word="F" lemma="F" pos="art" morph="NEU S DEF" extra="--"/>
|
| 29 |
+
<t id="s2_7" word="G" lemma="G" pos="adj" morph="nG S DEF NOM" extra="--"/>
|
| 30 |
+
<t id="s2_8" word="H" lemma="H" pos="n" morph="NEU S IDF NOM" extra="N:lys+net"/>
|
| 31 |
+
<t id="s2_9" word="." lemma="--" pos="pu" morph="--" extra="--"/>
|
| 32 |
+
</terminals>
|
| 33 |
+
|
| 34 |
+
<nonterminals>
|
| 35 |
+
<nt id="s2_500" cat="s">
|
| 36 |
+
<edge label="STA" idref="s2_501"/>
|
| 37 |
+
</nt>
|
| 38 |
+
<nt id="s2_501" cat="fcl">
|
| 39 |
+
<edge label="S" idref="s2_1"/>
|
| 40 |
+
<edge label="P" idref="s2_2"/>
|
| 41 |
+
<edge label="Od" idref="s2_3"/>
|
| 42 |
+
<edge label="Co" idref="s2_502"/>
|
| 43 |
+
<edge label="PU" idref="s2_9"/>
|
| 44 |
+
</nt>
|
| 45 |
+
<nt id="s2_502" cat="adjp">
|
| 46 |
+
<edge label="H" idref="s2_4"/>
|
| 47 |
+
<edge label="DA" idref="s2_503"/>
|
| 48 |
+
</nt>
|
| 49 |
+
<nt id="s2_503" cat="pp">
|
| 50 |
+
<edge label="H" idref="s2_5"/>
|
| 51 |
+
<edge label="DP" idref="s2_504"/>
|
| 52 |
+
</nt>
|
| 53 |
+
<nt id="s2_504" cat="np">
|
| 54 |
+
<edge label="DN" idref="s2_6"/>
|
| 55 |
+
<edge label="DN" idref="s2_7"/>
|
| 56 |
+
<edge label="H" idref="s2_8"/>
|
| 57 |
+
</nt>
|
| 58 |
+
</nonterminals>
|
| 59 |
+
</graph>
|
| 60 |
+
</s>
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
NOT_FIX_NONPROJ_EXAMPLE="""
|
| 64 |
+
<s id="s322" ref="EDGBITSZ" source="id=EDGBITSZ" forest="1/2" text="A B C D E, F G H I J.">
|
| 65 |
+
<graph root="s322_500">
|
| 66 |
+
<terminals>
|
| 67 |
+
<t id="s322_1" word="A" lemma="A" pos="prop" morph="NOM" extra="hum fem"/>
|
| 68 |
+
<t id="s322_2" word="B" lemma="B" pos="v-fin" morph="PR AKT" extra="mv"/>
|
| 69 |
+
<t id="s322_3" word="C" lemma="C" pos="pron-dem" morph="UTR S" extra="dem"/>
|
| 70 |
+
<t id="s322_4" word="D" lemma="D" pos="n" morph="UTR S IDF NOM" extra="--"/>
|
| 71 |
+
<t id="s322_5" word="E" lemma="E" pos="adv" morph="--" extra="--"/>
|
| 72 |
+
<t id="s322_6" word="," lemma="--" pos="pu" morph="--" extra="--"/>
|
| 73 |
+
<t id="s322_7" word="F" lemma="F" pos="pron-rel" morph="--" extra="rel"/>
|
| 74 |
+
<t id="s322_8" word="G" lemma="G" pos="prop" morph="NOM" extra="hum"/>
|
| 75 |
+
<t id="s322_9" word="H" lemma="H" pos="v-fin" morph="IMPF AKT" extra="mv"/>
|
| 76 |
+
<t id="s322_10" word="I" lemma="I" pos="prp" morph="--" extra="--"/>
|
| 77 |
+
<t id="s322_11" word="J" lemma="J" pos="n" morph="UTR S DEF NOM" extra="F:ur+premiere"/>
|
| 78 |
+
<t id="s322_12" word="." lemma="--" pos="pu" morph="--" extra="--"/>
|
| 79 |
+
</terminals>
|
| 80 |
+
|
| 81 |
+
<nonterminals>
|
| 82 |
+
<nt id="s322_500" cat="s">
|
| 83 |
+
<edge label="STA" idref="s322_501"/>
|
| 84 |
+
</nt>
|
| 85 |
+
<nt id="s322_501" cat="fcl">
|
| 86 |
+
<edge label="S" idref="s322_1"/>
|
| 87 |
+
<edge label="P" idref="s322_2"/>
|
| 88 |
+
<edge label="Od" idref="s322_502"/>
|
| 89 |
+
<edge label="Vpart" idref="s322_5"/>
|
| 90 |
+
<edge label="PU" idref="s322_6"/>
|
| 91 |
+
<edge label="PU" idref="s322_12"/>
|
| 92 |
+
</nt>
|
| 93 |
+
<nt id="s322_502" cat="np">
|
| 94 |
+
<edge label="DN" idref="s322_3"/>
|
| 95 |
+
<edge label="H" idref="s322_4"/>
|
| 96 |
+
<edge label="DN" idref="s322_503"/>
|
| 97 |
+
</nt>
|
| 98 |
+
<nt id="s322_503" cat="fcl">
|
| 99 |
+
<edge label="Od" idref="s322_7"/>
|
| 100 |
+
<edge label="S" idref="s322_8"/>
|
| 101 |
+
<edge label="P" idref="s322_9"/>
|
| 102 |
+
<edge label="Ao" idref="s322_504"/>
|
| 103 |
+
</nt>
|
| 104 |
+
<nt id="s322_504" cat="pp">
|
| 105 |
+
<edge label="H" idref="s322_10"/>
|
| 106 |
+
<edge label="DP" idref="s322_11"/>
|
| 107 |
+
</nt>
|
| 108 |
+
</nonterminals>
|
| 109 |
+
</graph>
|
| 110 |
+
</s>
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
NONPROJ_EXAMPLE="""
|
| 115 |
+
<s id="s9" ref="AATCNKQZ" source="id=AATCNKQZ" forest="1/1" text="A B C D E F G H I.">
|
| 116 |
+
<graph root="s9_500">
|
| 117 |
+
<terminals>
|
| 118 |
+
<t id="s9_1" word="A" lemma="A" pos="adv" morph="--" extra="--"/>
|
| 119 |
+
<t id="s9_2" word="B" lemma="B" pos="adv" morph="--" extra="--"/>
|
| 120 |
+
<t id="s9_3" word="C" lemma="C" pos="v-fin" morph="IMPF AKT" extra="aux"/>
|
| 121 |
+
<t id="s9_4" word="D" lemma="D" pos="prop" morph="NOM" extra="hum"/>
|
| 122 |
+
<t id="s9_5" word="E" lemma="E" pos="adv" morph="--" extra="--"/>
|
| 123 |
+
<t id="s9_6" word="F" lemma="F" pos="v-pcp2" morph="PAS" extra="mv"/>
|
| 124 |
+
<t id="s9_7" word="G" lemma="G" pos="prp" morph="--" extra="--"/>
|
| 125 |
+
<t id="s9_8" word="H" lemma="H" pos="num" morph="--" extra="card"/>
|
| 126 |
+
<t id="s9_9" word="I" lemma="I" pos="n" morph="UTR P IDF NOM" extra="N:patrulje+vogn"/>
|
| 127 |
+
<t id="s9_10" word="." lemma="--" pos="pu" morph="--" extra="--"/>
|
| 128 |
+
</terminals>
|
| 129 |
+
|
| 130 |
+
<nonterminals>
|
| 131 |
+
<nt id="s9_500" cat="s">
|
| 132 |
+
<edge label="STA" idref="s9_501"/>
|
| 133 |
+
</nt>
|
| 134 |
+
<nt id="s9_501" cat="fcl">
|
| 135 |
+
<edge label="fA" idref="s9_502"/>
|
| 136 |
+
<edge label="P" idref="s9_503"/>
|
| 137 |
+
<edge label="S" idref="s9_4"/>
|
| 138 |
+
<edge label="fA" idref="s9_5"/>
|
| 139 |
+
<edge label="fA" idref="s9_504"/>
|
| 140 |
+
<edge label="PU" idref="s9_10"/>
|
| 141 |
+
</nt>
|
| 142 |
+
<nt id="s9_502" cat="advp">
|
| 143 |
+
<edge label="DA" idref="s9_1"/>
|
| 144 |
+
<edge label="H" idref="s9_2"/>
|
| 145 |
+
</nt>
|
| 146 |
+
<nt id="s9_503" cat="vp">
|
| 147 |
+
<edge label="Vaux" idref="s9_3"/>
|
| 148 |
+
<edge label="Vm" idref="s9_6"/>
|
| 149 |
+
</nt>
|
| 150 |
+
<nt id="s9_504" cat="pp">
|
| 151 |
+
<edge label="H" idref="s9_7"/>
|
| 152 |
+
<edge label="DP" idref="s9_505"/>
|
| 153 |
+
</nt>
|
| 154 |
+
<nt id="s9_505" cat="np">
|
| 155 |
+
<edge label="DN" idref="s9_8"/>
|
| 156 |
+
<edge label="H" idref="s9_9"/>
|
| 157 |
+
</nt>
|
| 158 |
+
</nonterminals>
|
| 159 |
+
</graph>
|
| 160 |
+
</s>
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def test_projective_example():
|
| 164 |
+
"""
|
| 165 |
+
Test reading a basic tree, along with some further manipulations from the conversion program
|
| 166 |
+
"""
|
| 167 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
|
| 168 |
+
test_name = os.path.join(tempdir, "proj.xml")
|
| 169 |
+
with open(test_name, "w", encoding="utf-8") as fout:
|
| 170 |
+
fout.write(PROJ_EXAMPLE)
|
| 171 |
+
sentences = convert_arboretum.read_xml_file(test_name)
|
| 172 |
+
assert len(sentences) == 1
|
| 173 |
+
|
| 174 |
+
tree, words = convert_arboretum.process_tree(sentences[0])
|
| 175 |
+
expected_tree = "(s (fcl (prop s2_1) (v-fin s2_2) (pron-pers s2_3) (adjp (adj s2_4) (pp (prp s2_5) (np (art s2_6) (adj s2_7) (n s2_8)))) (pu s2_9)))"
|
| 176 |
+
assert str(tree) == expected_tree
|
| 177 |
+
assert [w.word for w in words.values()] == ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', '.']
|
| 178 |
+
assert not convert_arboretum.word_sequence_missing_words(tree)
|
| 179 |
+
with tsurgeon.Tsurgeon() as tsurgeon_processor:
|
| 180 |
+
assert tree == convert_arboretum.check_words(tree, tsurgeon_processor)
|
| 181 |
+
|
| 182 |
+
# check that the words can be replaced as expected
|
| 183 |
+
replaced_tree = convert_arboretum.replace_words(tree, words)
|
| 184 |
+
expected_tree = "(s (fcl (prop A) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
|
| 185 |
+
assert str(replaced_tree) == expected_tree
|
| 186 |
+
assert convert_arboretum.split_underscores(replaced_tree) == replaced_tree
|
| 187 |
+
|
| 188 |
+
# fake a word which should be split
|
| 189 |
+
words['s2_1'] = words['s2_1']._replace(word='foo_bar')
|
| 190 |
+
replaced_tree = convert_arboretum.replace_words(tree, words)
|
| 191 |
+
expected_tree = "(s (fcl (prop foo_bar) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
|
| 192 |
+
assert str(replaced_tree) == expected_tree
|
| 193 |
+
expected_tree = "(s (fcl (np (prop foo) (prop bar)) (v-fin B) (pron-pers C) (adjp (adj D) (pp (prp E) (np (art F) (adj G) (n H)))) (pu .)))"
|
| 194 |
+
assert str(convert_arboretum.split_underscores(replaced_tree)) == expected_tree
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def test_not_fix_example():
|
| 198 |
+
"""
|
| 199 |
+
Test that a non-projective tree which we don't have a heuristic for quietly fails
|
| 200 |
+
"""
|
| 201 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
|
| 202 |
+
test_name = os.path.join(tempdir, "nofix.xml")
|
| 203 |
+
with open(test_name, "w", encoding="utf-8") as fout:
|
| 204 |
+
fout.write(NOT_FIX_NONPROJ_EXAMPLE)
|
| 205 |
+
sentences = convert_arboretum.read_xml_file(test_name)
|
| 206 |
+
assert len(sentences) == 1
|
| 207 |
+
|
| 208 |
+
tree, words = convert_arboretum.process_tree(sentences[0])
|
| 209 |
+
assert not convert_arboretum.word_sequence_missing_words(tree)
|
| 210 |
+
with tsurgeon.Tsurgeon() as tsurgeon_processor:
|
| 211 |
+
assert convert_arboretum.check_words(tree, tsurgeon_processor) is None
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def test_fix_proj_example():
|
| 215 |
+
"""
|
| 216 |
+
Test that a non-projective tree can be rearranged as expected
|
| 217 |
+
|
| 218 |
+
Note that there are several other classes of non-proj tree we could test as well...
|
| 219 |
+
"""
|
| 220 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tempdir:
|
| 221 |
+
test_name = os.path.join(tempdir, "fix.xml")
|
| 222 |
+
with open(test_name, "w", encoding="utf-8") as fout:
|
| 223 |
+
fout.write(NONPROJ_EXAMPLE)
|
| 224 |
+
sentences = convert_arboretum.read_xml_file(test_name)
|
| 225 |
+
assert len(sentences) == 1
|
| 226 |
+
|
| 227 |
+
tree, words = convert_arboretum.process_tree(sentences[0])
|
| 228 |
+
assert not convert_arboretum.word_sequence_missing_words(tree)
|
| 229 |
+
# the 4 and 5 are moved inside the 3-6 node
|
| 230 |
+
expected_orig = "(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (v-pcp2 s9_6)) (prop s9_4) (adv s9_5) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))"
|
| 231 |
+
expected_proj = "(s (fcl (advp (adv s9_1) (adv s9_2)) (vp (v-fin s9_3) (prop s9_4) (adv s9_5) (v-pcp2 s9_6)) (pp (prp s9_7) (np (num s9_8) (n s9_9))) (pu s9_10)))"
|
| 232 |
+
assert str(tree) == expected_orig
|
| 233 |
+
with tsurgeon.Tsurgeon() as tsurgeon_processor:
|
| 234 |
+
assert str(convert_arboretum.check_words(tree, tsurgeon_processor)) == expected_proj
|
| 235 |
+
|
stanza/stanza/tests/constituency/test_ensemble.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Add a simple test of the Ensemble's inference path
|
| 3 |
+
|
| 4 |
+
This just reuses one model several times - that should still check the main loop, at least
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from stanza import Pipeline
|
| 10 |
+
from stanza.models.constituency import text_processing
|
| 11 |
+
from stanza.models.constituency import tree_reader
|
| 12 |
+
from stanza.models.constituency.ensemble import Ensemble, EnsembleTrainer
|
| 13 |
+
from stanza.models.constituency.text_processing import parse_tokenized_sentences
|
| 14 |
+
|
| 15 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 16 |
+
|
| 17 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.fixture(scope="module")
|
| 21 |
+
def pipeline():
|
| 22 |
+
return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos, constituency", tokenize_pretokenized=True)
|
| 23 |
+
|
| 24 |
+
@pytest.fixture(scope="module")
|
| 25 |
+
def saved_ensemble(tmp_path_factory, pipeline):
|
| 26 |
+
tmp_path = tmp_path_factory.mktemp("ensemble")
|
| 27 |
+
|
| 28 |
+
# test the ensemble by reusing the same parser multiple times
|
| 29 |
+
con_processor = pipeline.processors["constituency"]
|
| 30 |
+
model = con_processor._model
|
| 31 |
+
args = dict(model.args)
|
| 32 |
+
foundation_cache = pipeline.foundation_cache
|
| 33 |
+
|
| 34 |
+
model_path = con_processor._config['model_path']
|
| 35 |
+
# reuse the same model 3 times just to make sure the code paths are working
|
| 36 |
+
filenames = [model_path, model_path, model_path]
|
| 37 |
+
|
| 38 |
+
ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache)
|
| 39 |
+
save_path = tmp_path / "ensemble.pt"
|
| 40 |
+
|
| 41 |
+
ensemble.save(save_path)
|
| 42 |
+
return ensemble, save_path, args, foundation_cache
|
| 43 |
+
|
| 44 |
+
def check_basic_predictions(trees):
|
| 45 |
+
predictions = [x.predictions for x in trees]
|
| 46 |
+
assert len(predictions) == 2
|
| 47 |
+
assert all(len(x) == 1 for x in predictions)
|
| 48 |
+
trees = [x[0].tree for x in predictions]
|
| 49 |
+
result = ["{}".format(tree) for tree in trees]
|
| 50 |
+
expected = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
|
| 51 |
+
"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
|
| 52 |
+
assert result == expected
|
| 53 |
+
|
| 54 |
+
def test_ensemble_inference(pipeline):
|
| 55 |
+
# test the ensemble by reusing the same parser multiple times
|
| 56 |
+
con_processor = pipeline.processors["constituency"]
|
| 57 |
+
model = con_processor._model
|
| 58 |
+
args = dict(model.args)
|
| 59 |
+
foundation_cache = pipeline.foundation_cache
|
| 60 |
+
|
| 61 |
+
model_path = con_processor._config['model_path']
|
| 62 |
+
# reuse the same model 3 times just to make sure the code paths are working
|
| 63 |
+
filenames = [model_path, model_path, model_path]
|
| 64 |
+
|
| 65 |
+
ensemble = EnsembleTrainer.from_files(args, filenames, foundation_cache=foundation_cache)
|
| 66 |
+
ensemble = ensemble.model
|
| 67 |
+
sentences = [["This", "is", "a", "test"], ["This", "is", "another", "test"]]
|
| 68 |
+
trees = parse_tokenized_sentences(args, ensemble, [pipeline], sentences)
|
| 69 |
+
check_basic_predictions(trees)
|
| 70 |
+
|
| 71 |
+
def test_ensemble_save(saved_ensemble):
|
| 72 |
+
"""
|
| 73 |
+
Depending on the saved_ensemble fixture should be enough to ensure
|
| 74 |
+
that the ensemble was correctly saved
|
| 75 |
+
|
| 76 |
+
(loading is tested separately)
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def test_ensemble_save_load(pipeline, saved_ensemble):
|
| 80 |
+
_, save_path, args, foundation_cache = saved_ensemble
|
| 81 |
+
ensemble = EnsembleTrainer.load(save_path, args, foundation_cache=foundation_cache)
|
| 82 |
+
sentences = [["This", "is", "a", "test"], ["This", "is", "another", "test"]]
|
| 83 |
+
trees = parse_tokenized_sentences(args, ensemble.model, [pipeline], sentences)
|
| 84 |
+
check_basic_predictions(trees)
|
| 85 |
+
|
| 86 |
+
def test_parse_text(tmp_path, pipeline, saved_ensemble):
|
| 87 |
+
_, model_path, args, foundation_cache = saved_ensemble
|
| 88 |
+
|
| 89 |
+
raw_file = str(tmp_path / "test_input.txt")
|
| 90 |
+
with open(raw_file, "w") as fout:
|
| 91 |
+
fout.write("This is a test\nThis is another test\n")
|
| 92 |
+
output_file = str(tmp_path / "test_output.txt")
|
| 93 |
+
|
| 94 |
+
args = dict(args)
|
| 95 |
+
args['tokenized_file'] = raw_file
|
| 96 |
+
args['predict_file'] = output_file
|
| 97 |
+
|
| 98 |
+
text_processing.load_model_parse_text(args, model_path, [pipeline])
|
| 99 |
+
trees = tree_reader.read_treebank(output_file)
|
| 100 |
+
trees = ["{}".format(x) for x in trees]
|
| 101 |
+
expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
|
| 102 |
+
"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
|
| 103 |
+
assert trees == expected_trees
|
| 104 |
+
|
| 105 |
+
def test_pipeline(saved_ensemble):
|
| 106 |
+
_, model_path, _, foundation_cache = saved_ensemble
|
| 107 |
+
nlp = Pipeline("en", processors="tokenize,pos,constituency", constituency_model_path=str(model_path), foundation_cache=foundation_cache, download_method=None)
|
| 108 |
+
doc = nlp("This is a test")
|
| 109 |
+
tree = "{}".format(doc.sentences[0].constituency)
|
| 110 |
+
assert tree == "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))"
|
stanza/stanza/tests/constituency/test_in_order_compound_oracle.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency import in_order_compound_oracle
|
| 4 |
+
from stanza.models.constituency import tree_reader
|
| 5 |
+
from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme
|
| 6 |
+
from stanza.models.constituency.transition_sequence import build_treebank
|
| 7 |
+
|
| 8 |
+
from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
|
| 9 |
+
|
| 10 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 11 |
+
|
| 12 |
+
# A sample tree from PTB with a triple unary transition (at a location other than root)
|
| 13 |
+
# Here we test the incorrect closing of various brackets
|
| 14 |
+
TRIPLE_UNARY_START_TREE = """
|
| 15 |
+
( (S
|
| 16 |
+
(PRN
|
| 17 |
+
(S
|
| 18 |
+
(NP-SBJ (-NONE- *) )
|
| 19 |
+
(VP (VB See) )))
|
| 20 |
+
(, ,)
|
| 21 |
+
(NP-SBJ
|
| 22 |
+
(NP (DT the) (JJ other) (NN rule) )
|
| 23 |
+
(PP (IN of)
|
| 24 |
+
(NP (NN thumb) ))
|
| 25 |
+
(PP (IN about)
|
| 26 |
+
(NP (NN ballooning) )))))
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
TREES = [TRIPLE_UNARY_START_TREE]
|
| 30 |
+
TREEBANK = "\n".join(TREES)
|
| 31 |
+
|
| 32 |
+
ROOT_LABELS = ["ROOT"]
|
| 33 |
+
|
| 34 |
+
@pytest.fixture(scope="module")
|
| 35 |
+
def trees():
|
| 36 |
+
trees = tree_reader.read_trees(TREEBANK)
|
| 37 |
+
trees = [t.prune_none().simplify_labels() for t in trees]
|
| 38 |
+
assert len(trees) == len(TREES)
|
| 39 |
+
|
| 40 |
+
return trees
|
| 41 |
+
|
| 42 |
+
@pytest.fixture(scope="module")
|
| 43 |
+
def gold_sequences(trees):
|
| 44 |
+
gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)
|
| 45 |
+
return gold_sequences
|
| 46 |
+
|
| 47 |
+
def get_repairs(gold_sequence, wrong_transition, repair_fn):
|
| 48 |
+
"""
|
| 49 |
+
Use the repair function and the wrong transition to iterate over the gold sequence
|
| 50 |
+
|
| 51 |
+
Returns a list of possible repairs, one for each position in the sequence
|
| 52 |
+
Repairs are tuples, (idx, seq)
|
| 53 |
+
"""
|
| 54 |
+
repairs = [(idx, repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None))
|
| 55 |
+
for idx, gold_transition in enumerate(gold_sequence)]
|
| 56 |
+
repairs = [x for x in repairs if x[1] is not None]
|
| 57 |
+
return repairs
|
| 58 |
+
|
| 59 |
+
def test_fix_shift_close():
|
| 60 |
+
trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE)
|
| 61 |
+
trees = [t.prune_none().simplify_labels() for t in trees]
|
| 62 |
+
assert len(trees) == 1
|
| 63 |
+
tree = trees[0]
|
| 64 |
+
|
| 65 |
+
gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)
|
| 66 |
+
|
| 67 |
+
# there are three places in this tree where a long bracket (more than 2 subtrees)
|
| 68 |
+
# could theoretically be closed and then reopened
|
| 69 |
+
repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_shift_close_error)
|
| 70 |
+
assert len(repairs) == 3
|
| 71 |
+
|
| 72 |
+
expected_trees = ["(ROOT (S (S (PRN (S (VP (VB See)))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
|
| 73 |
+
"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other)) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
|
| 74 |
+
"(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)))) (PP (IN about) (NP (NN ballooning))))))"]
|
| 75 |
+
|
| 76 |
+
for repair, expected in zip(repairs, expected_trees):
|
| 77 |
+
repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
|
| 78 |
+
assert str(repaired_tree) == expected
|
| 79 |
+
|
| 80 |
+
def test_fix_open_close():
|
| 81 |
+
trees = tree_reader.read_trees(TRIPLE_UNARY_START_TREE)
|
| 82 |
+
trees = [t.prune_none().simplify_labels() for t in trees]
|
| 83 |
+
assert len(trees) == 1
|
| 84 |
+
tree = trees[0]
|
| 85 |
+
|
| 86 |
+
gold_sequences = build_treebank(trees, TransitionScheme.IN_ORDER_COMPOUND)
|
| 87 |
+
|
| 88 |
+
repairs = get_repairs(gold_sequences[0], CloseConstituent(), in_order_compound_oracle.fix_open_close_error)
|
| 89 |
+
print("------------------")
|
| 90 |
+
for repair in repairs:
|
| 91 |
+
print(repair)
|
| 92 |
+
repaired_tree = reconstruct_tree(tree, repair[1], transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)
|
| 93 |
+
print("{:P}".format(repaired_tree))
|
stanza/stanza/tests/constituency/test_parse_transitions.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency import parse_transitions
|
| 4 |
+
from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT
|
| 5 |
+
from stanza.models.constituency.parse_transitions import TransitionScheme, Shift, CloseConstituent, OpenConstituent
|
| 6 |
+
from stanza.tests import *
|
| 7 |
+
|
| 8 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def build_initial_state(model, num_states=1):
|
| 12 |
+
words = ["Unban", "Mox", "Opal"]
|
| 13 |
+
tags = ["VB", "NNP", "NNP"]
|
| 14 |
+
sentences = [list(zip(words, tags)) for _ in range(num_states)]
|
| 15 |
+
|
| 16 |
+
states = model.initial_state_from_words(sentences)
|
| 17 |
+
assert len(states) == num_states
|
| 18 |
+
assert all(state.num_transitions == 0 for state in states)
|
| 19 |
+
return states
|
| 20 |
+
|
| 21 |
+
def test_initial_state(model=None):
|
| 22 |
+
if model is None:
|
| 23 |
+
model = SimpleModel()
|
| 24 |
+
states = build_initial_state(model)
|
| 25 |
+
assert len(states) == 1
|
| 26 |
+
state = states[0]
|
| 27 |
+
|
| 28 |
+
assert state.sentence_length == 3
|
| 29 |
+
assert state.num_opens == 0
|
| 30 |
+
# each stack has a sentinel value at the end
|
| 31 |
+
assert len(state.word_queue) == 5
|
| 32 |
+
assert len(state.constituents) == 1
|
| 33 |
+
assert len(state.transitions) == 1
|
| 34 |
+
assert state.word_position == 0
|
| 35 |
+
|
| 36 |
+
def test_shift(model=None):
|
| 37 |
+
if model is None:
|
| 38 |
+
model = SimpleModel()
|
| 39 |
+
state = build_initial_state(model)[0]
|
| 40 |
+
|
| 41 |
+
open_transition = parse_transitions.OpenConstituent("ROOT")
|
| 42 |
+
state = open_transition.apply(state, model)
|
| 43 |
+
open_transition = parse_transitions.OpenConstituent("S")
|
| 44 |
+
state = open_transition.apply(state, model)
|
| 45 |
+
shift = parse_transitions.Shift()
|
| 46 |
+
assert shift.is_legal(state, model)
|
| 47 |
+
assert len(state.word_queue) == 5
|
| 48 |
+
assert state.word_position == 0
|
| 49 |
+
|
| 50 |
+
state = shift.apply(state, model)
|
| 51 |
+
assert len(state.word_queue) == 5
|
| 52 |
+
# 4 because of the dummy created by the opens
|
| 53 |
+
assert len(state.constituents) == 4
|
| 54 |
+
assert len(state.transitions) == 4
|
| 55 |
+
assert shift.is_legal(state, model)
|
| 56 |
+
assert state.word_position == 1
|
| 57 |
+
assert not state.empty_word_queue()
|
| 58 |
+
|
| 59 |
+
state = shift.apply(state, model)
|
| 60 |
+
assert len(state.word_queue) == 5
|
| 61 |
+
assert len(state.constituents) == 5
|
| 62 |
+
assert len(state.transitions) == 5
|
| 63 |
+
assert shift.is_legal(state, model)
|
| 64 |
+
assert state.word_position == 2
|
| 65 |
+
assert not state.empty_word_queue()
|
| 66 |
+
|
| 67 |
+
state = shift.apply(state, model)
|
| 68 |
+
assert len(state.word_queue) == 5
|
| 69 |
+
assert len(state.constituents) == 6
|
| 70 |
+
assert len(state.transitions) == 6
|
| 71 |
+
assert not shift.is_legal(state, model)
|
| 72 |
+
assert state.word_position == 3
|
| 73 |
+
assert state.empty_word_queue()
|
| 74 |
+
|
| 75 |
+
constituents = state.constituents
|
| 76 |
+
assert model.get_top_constituent(constituents).children[0].label == 'Opal'
|
| 77 |
+
constituents = constituents.pop()
|
| 78 |
+
assert model.get_top_constituent(constituents).children[0].label == 'Mox'
|
| 79 |
+
constituents = constituents.pop()
|
| 80 |
+
assert model.get_top_constituent(constituents).children[0].label == 'Unban'
|
| 81 |
+
|
| 82 |
+
def test_initial_unary(model=None):
|
| 83 |
+
# it doesn't make sense to start with a CompoundUnary
|
| 84 |
+
if model is None:
|
| 85 |
+
model = SimpleModel()
|
| 86 |
+
|
| 87 |
+
state = build_initial_state(model)[0]
|
| 88 |
+
unary = parse_transitions.CompoundUnary('ROOT', 'VP')
|
| 89 |
+
assert unary.label == ('ROOT', 'VP',)
|
| 90 |
+
assert not unary.is_legal(state, model)
|
| 91 |
+
unary = parse_transitions.CompoundUnary('VP')
|
| 92 |
+
assert unary.label == ('VP',)
|
| 93 |
+
assert not unary.is_legal(state, model)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_unary(model=None):
|
| 97 |
+
if model is None:
|
| 98 |
+
model = SimpleModel()
|
| 99 |
+
state = build_initial_state(model)[0]
|
| 100 |
+
|
| 101 |
+
shift = parse_transitions.Shift()
|
| 102 |
+
state = shift.apply(state, model)
|
| 103 |
+
|
| 104 |
+
# this is technically the wrong parse but we're being lazy
|
| 105 |
+
unary = parse_transitions.CompoundUnary('S', 'VP')
|
| 106 |
+
assert unary.is_legal(state, model)
|
| 107 |
+
state = unary.apply(state, model)
|
| 108 |
+
assert not unary.is_legal(state, model)
|
| 109 |
+
|
| 110 |
+
tree = model.get_top_constituent(state.constituents)
|
| 111 |
+
assert tree.label == 'S'
|
| 112 |
+
assert len(tree.children) == 1
|
| 113 |
+
tree = tree.children[0]
|
| 114 |
+
assert tree.label == 'VP'
|
| 115 |
+
assert len(tree.children) == 1
|
| 116 |
+
tree = tree.children[0]
|
| 117 |
+
assert tree.label == 'VB'
|
| 118 |
+
assert tree.is_preterminal()
|
| 119 |
+
|
| 120 |
+
def test_unary_requires_root(model=None):
|
| 121 |
+
if model is None:
|
| 122 |
+
model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY)
|
| 123 |
+
state = build_initial_state(model)[0]
|
| 124 |
+
|
| 125 |
+
open_transition = parse_transitions.OpenConstituent("S")
|
| 126 |
+
assert open_transition.is_legal(state, model)
|
| 127 |
+
state = open_transition.apply(state, model)
|
| 128 |
+
|
| 129 |
+
shift = parse_transitions.Shift()
|
| 130 |
+
assert shift.is_legal(state, model)
|
| 131 |
+
state = shift.apply(state, model)
|
| 132 |
+
assert shift.is_legal(state, model)
|
| 133 |
+
state = shift.apply(state, model)
|
| 134 |
+
assert shift.is_legal(state, model)
|
| 135 |
+
state = shift.apply(state, model)
|
| 136 |
+
assert not shift.is_legal(state, model)
|
| 137 |
+
|
| 138 |
+
close_transition = parse_transitions.CloseConstituent()
|
| 139 |
+
assert close_transition.is_legal(state, model)
|
| 140 |
+
state = close_transition.apply(state, model)
|
| 141 |
+
assert not open_transition.is_legal(state, model)
|
| 142 |
+
assert not close_transition.is_legal(state, model)
|
| 143 |
+
|
| 144 |
+
np_unary = parse_transitions.CompoundUnary("NP")
|
| 145 |
+
assert not np_unary.is_legal(state, model)
|
| 146 |
+
root_unary = parse_transitions.CompoundUnary("ROOT")
|
| 147 |
+
assert root_unary.is_legal(state, model)
|
| 148 |
+
assert not state.finished(model)
|
| 149 |
+
state = root_unary.apply(state, model)
|
| 150 |
+
assert not root_unary.is_legal(state, model)
|
| 151 |
+
|
| 152 |
+
assert state.finished(model)
|
| 153 |
+
|
| 154 |
+
def test_open(model=None):
|
| 155 |
+
if model is None:
|
| 156 |
+
model = SimpleModel()
|
| 157 |
+
state = build_initial_state(model)[0]
|
| 158 |
+
|
| 159 |
+
shift = parse_transitions.Shift()
|
| 160 |
+
state = shift.apply(state, model)
|
| 161 |
+
state = shift.apply(state, model)
|
| 162 |
+
assert state.num_opens == 0
|
| 163 |
+
|
| 164 |
+
open_transition = parse_transitions.OpenConstituent("VP")
|
| 165 |
+
assert open_transition.is_legal(state, model)
|
| 166 |
+
state = open_transition.apply(state, model)
|
| 167 |
+
assert open_transition.is_legal(state, model)
|
| 168 |
+
assert state.num_opens == 1
|
| 169 |
+
|
| 170 |
+
# check that it is illegal if there are too many opens already
|
| 171 |
+
for i in range(20):
|
| 172 |
+
state = open_transition.apply(state, model)
|
| 173 |
+
assert not open_transition.is_legal(state, model)
|
| 174 |
+
assert state.num_opens == 21
|
| 175 |
+
|
| 176 |
+
# check that it is illegal if the state is out of words
|
| 177 |
+
state = build_initial_state(model)[0]
|
| 178 |
+
state = shift.apply(state, model)
|
| 179 |
+
state = shift.apply(state, model)
|
| 180 |
+
state = shift.apply(state, model)
|
| 181 |
+
assert not open_transition.is_legal(state, model)
|
| 182 |
+
|
| 183 |
+
def test_compound_open(model=None):
|
| 184 |
+
if model is None:
|
| 185 |
+
model = SimpleModel()
|
| 186 |
+
state = build_initial_state(model)[0]
|
| 187 |
+
|
| 188 |
+
open_transition = parse_transitions.OpenConstituent("ROOT", "S")
|
| 189 |
+
assert open_transition.is_legal(state, model)
|
| 190 |
+
shift = parse_transitions.Shift()
|
| 191 |
+
close_transition = parse_transitions.CloseConstituent()
|
| 192 |
+
|
| 193 |
+
state = open_transition.apply(state, model)
|
| 194 |
+
state = shift.apply(state, model)
|
| 195 |
+
state = shift.apply(state, model)
|
| 196 |
+
state = shift.apply(state, model)
|
| 197 |
+
state = close_transition.apply(state, model)
|
| 198 |
+
|
| 199 |
+
tree = model.get_top_constituent(state.constituents)
|
| 200 |
+
assert tree.label == 'ROOT'
|
| 201 |
+
assert len(tree.children) == 1
|
| 202 |
+
tree = tree.children[0]
|
| 203 |
+
assert tree.label == 'S'
|
| 204 |
+
assert len(tree.children) == 3
|
| 205 |
+
assert tree.children[0].children[0].label == 'Unban'
|
| 206 |
+
assert tree.children[1].children[0].label == 'Mox'
|
| 207 |
+
assert tree.children[2].children[0].label == 'Opal'
|
| 208 |
+
|
| 209 |
+
def test_in_order_open(model=None):
|
| 210 |
+
if model is None:
|
| 211 |
+
model = SimpleModel(TransitionScheme.IN_ORDER)
|
| 212 |
+
state = build_initial_state(model)[0]
|
| 213 |
+
|
| 214 |
+
shift = parse_transitions.Shift()
|
| 215 |
+
assert shift.is_legal(state, model)
|
| 216 |
+
state = shift.apply(state, model)
|
| 217 |
+
assert not shift.is_legal(state, model)
|
| 218 |
+
|
| 219 |
+
open_vp = parse_transitions.OpenConstituent("VP")
|
| 220 |
+
assert open_vp.is_legal(state, model)
|
| 221 |
+
state = open_vp.apply(state, model)
|
| 222 |
+
assert not open_vp.is_legal(state, model)
|
| 223 |
+
|
| 224 |
+
close_trans = parse_transitions.CloseConstituent()
|
| 225 |
+
assert close_trans.is_legal(state, model)
|
| 226 |
+
state = close_trans.apply(state, model)
|
| 227 |
+
|
| 228 |
+
open_s = parse_transitions.OpenConstituent("S")
|
| 229 |
+
assert open_s.is_legal(state, model)
|
| 230 |
+
state = open_s.apply(state, model)
|
| 231 |
+
assert not open_vp.is_legal(state, model)
|
| 232 |
+
|
| 233 |
+
# check that root transitions won't happen in the middle of a parse
|
| 234 |
+
open_root = parse_transitions.OpenConstituent("ROOT")
|
| 235 |
+
assert not open_root.is_legal(state, model)
|
| 236 |
+
|
| 237 |
+
# build (NP (NNP Mox) (NNP Opal))
|
| 238 |
+
open_np = parse_transitions.OpenConstituent("NP")
|
| 239 |
+
assert shift.is_legal(state, model)
|
| 240 |
+
state = shift.apply(state, model)
|
| 241 |
+
assert open_np.is_legal(state, model)
|
| 242 |
+
# make sure root can't happen in places where an arbitrary open is legal
|
| 243 |
+
assert not open_root.is_legal(state, model)
|
| 244 |
+
state = open_np.apply(state, model)
|
| 245 |
+
assert shift.is_legal(state, model)
|
| 246 |
+
state = shift.apply(state, model)
|
| 247 |
+
assert close_trans.is_legal(state, model)
|
| 248 |
+
state = close_trans.apply(state, model)
|
| 249 |
+
|
| 250 |
+
assert close_trans.is_legal(state, model)
|
| 251 |
+
state = close_trans.apply(state, model)
|
| 252 |
+
|
| 253 |
+
assert open_root.is_legal(state, model)
|
| 254 |
+
state = open_root.apply(state, model)
|
| 255 |
+
|
| 256 |
+
def test_too_many_unaries_close():
|
| 257 |
+
"""
|
| 258 |
+
This tests rejecting Close at the start of a sequence after too many unary transitions
|
| 259 |
+
|
| 260 |
+
The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence
|
| 261 |
+
"""
|
| 262 |
+
model = SimpleModel(TransitionScheme.IN_ORDER)
|
| 263 |
+
state = build_initial_state(model)[0]
|
| 264 |
+
|
| 265 |
+
shift = parse_transitions.Shift()
|
| 266 |
+
assert shift.is_legal(state, model)
|
| 267 |
+
state = shift.apply(state, model)
|
| 268 |
+
|
| 269 |
+
open_np = parse_transitions.OpenConstituent("NP")
|
| 270 |
+
close_trans = parse_transitions.CloseConstituent()
|
| 271 |
+
for _ in range(UNARY_LIMIT):
|
| 272 |
+
assert open_np.is_legal(state, model)
|
| 273 |
+
state = open_np.apply(state, model)
|
| 274 |
+
|
| 275 |
+
assert close_trans.is_legal(state, model)
|
| 276 |
+
state = close_trans.apply(state, model)
|
| 277 |
+
|
| 278 |
+
assert open_np.is_legal(state, model)
|
| 279 |
+
state = open_np.apply(state, model)
|
| 280 |
+
assert not close_trans.is_legal(state, model)
|
| 281 |
+
|
| 282 |
+
def test_too_many_unaries_open():
|
| 283 |
+
"""
|
| 284 |
+
This tests rejecting Open in the middle of a sequence after too many unary transitions
|
| 285 |
+
|
| 286 |
+
The model should reject doing multiple "unaries" - eg, Open then Close - in an IN_ORDER sequence
|
| 287 |
+
"""
|
| 288 |
+
model = SimpleModel(TransitionScheme.IN_ORDER)
|
| 289 |
+
state = build_initial_state(model)[0]
|
| 290 |
+
|
| 291 |
+
shift = parse_transitions.Shift()
|
| 292 |
+
assert shift.is_legal(state, model)
|
| 293 |
+
state = shift.apply(state, model)
|
| 294 |
+
|
| 295 |
+
open_np = parse_transitions.OpenConstituent("NP")
|
| 296 |
+
close_trans = parse_transitions.CloseConstituent()
|
| 297 |
+
|
| 298 |
+
assert open_np.is_legal(state, model)
|
| 299 |
+
state = open_np.apply(state, model)
|
| 300 |
+
assert not open_np.is_legal(state, model)
|
| 301 |
+
assert shift.is_legal(state, model)
|
| 302 |
+
state = shift.apply(state, model)
|
| 303 |
+
|
| 304 |
+
for _ in range(UNARY_LIMIT):
|
| 305 |
+
assert open_np.is_legal(state, model)
|
| 306 |
+
state = open_np.apply(state, model)
|
| 307 |
+
|
| 308 |
+
assert close_trans.is_legal(state, model)
|
| 309 |
+
state = close_trans.apply(state, model)
|
| 310 |
+
|
| 311 |
+
assert not open_np.is_legal(state, model)
|
| 312 |
+
|
| 313 |
+
def test_close(model=None):
    """Build a full (VP mox (NP opal ...)) subtree and check the Close transition bookkeeping."""
    if model is None:
        model = SimpleModel()
    # this one actually tests an entire subtree building
    state = build_initial_state(model)[0]

    open_transition_vp = parse_transitions.OpenConstituent("VP")
    assert open_transition_vp.is_legal(state, model)
    state = open_transition_vp.apply(state, model)
    assert state.num_opens == 1

    shift = parse_transitions.Shift()
    assert shift.is_legal(state, model)
    state = shift.apply(state, model)

    open_transition_np = parse_transitions.OpenConstituent("NP")
    assert open_transition_np.is_legal(state, model)
    state = open_transition_np.apply(state, model)
    assert state.num_opens == 2

    # shift the two remaining words; a third shift has nothing left to consume
    assert shift.is_legal(state, model)
    state = shift.apply(state, model)
    assert shift.is_legal(state, model)
    state = shift.apply(state, model)
    assert not shift.is_legal(state, model)
    assert state.num_opens == 2
    # now should have "mox", "opal" on the constituents

    close_transition = parse_transitions.CloseConstituent()
    assert close_transition.is_legal(state, model)
    state = close_transition.apply(state, model)
    assert state.num_opens == 1
    assert close_transition.is_legal(state, model)
    state = close_transition.apply(state, model)
    assert state.num_opens == 0
    assert not close_transition.is_legal(state, model)

    # verify the shape of the finished subtree
    tree = model.get_top_constituent(state.constituents)
    assert tree.label == 'VP'
    assert len(tree.children) == 2
    tree = tree.children[1]
    assert tree.label == 'NP'
    assert len(tree.children) == 2
    assert tree.children[0].is_preterminal()
    assert tree.children[1].is_preterminal()
    assert tree.children[0].children[0].label == 'Mox'
    assert tree.children[1].children[0].label == 'Opal'

    # extra one for None at the start of the TreeStack
    assert len(state.constituents) == 2

    assert state.all_transitions(model) == [open_transition_vp, shift, open_transition_np, shift, shift, close_transition, close_transition]
|
| 366 |
+
|
| 367 |
+
def test_in_order_compound_finalize(model=None):
    """
    Test the Finalize transition is only legal at the end of a sequence
    """
    if model is None:
        model = SimpleModel(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)

    state = build_initial_state(model)[0]

    finalize = parse_transitions.Finalize("ROOT")
    shift = parse_transitions.Shift()

    # Finalize is illegal at every intermediate step...
    assert shift.is_legal(state, model)
    assert not finalize.is_legal(state, model)
    state = shift.apply(state, model)

    open_transition = parse_transitions.OpenConstituent("NP")
    assert open_transition.is_legal(state, model)
    assert not finalize.is_legal(state, model)
    state = open_transition.apply(state, model)
    assert state.num_opens == 1

    assert shift.is_legal(state, model)
    assert not finalize.is_legal(state, model)
    state = shift.apply(state, model)
    assert shift.is_legal(state, model)
    assert not finalize.is_legal(state, model)
    state = shift.apply(state, model)

    close_transition = parse_transitions.CloseConstituent()
    assert close_transition.is_legal(state, model)
    state = close_transition.apply(state, model)
    assert state.num_opens == 0
    assert not close_transition.is_legal(state, model)
    # ...and only becomes legal once every bracket is closed
    assert finalize.is_legal(state, model)

    state = finalize.apply(state, model)
    assert not finalize.is_legal(state, model)
    tree = model.get_top_constituent(state.constituents)
    assert tree.label == 'ROOT'
|
| 407 |
+
|
| 408 |
+
def test_hashes():
    """Equal transitions must hash equally so they deduplicate inside a set."""
    transitions = set()

    shift = parse_transitions.Shift()
    assert shift not in transitions
    transitions.add(shift)
    assert shift in transitions
    # a fresh but equal Shift is found via hash equality
    shift = parse_transitions.Shift()
    assert shift in transitions

    for _ in range(5):
        transitions.add(shift)
    assert len(transitions) == 1

    unary = parse_transitions.CompoundUnary("asdf")
    assert unary not in transitions
    transitions.add(unary)
    assert unary in transitions

    # a compound unary with a different label chain is a distinct element
    unary = parse_transitions.CompoundUnary("asdf", "zzzz")
    assert unary not in transitions
    transitions.add(unary)
    transitions.add(unary)
    transitions.add(unary)
    unary = parse_transitions.CompoundUnary("asdf", "zzzz")
    assert unary in transitions

    oc = parse_transitions.OpenConstituent("asdf")
    assert oc not in transitions
    transitions.add(oc)
    assert oc in transitions
    transitions.add(oc)
    transitions.add(oc)
    assert len(transitions) == 4
    assert parse_transitions.OpenConstituent("asdf") in transitions

    cc = parse_transitions.CloseConstituent()
    assert cc not in transitions
    transitions.add(cc)
    transitions.add(cc)
    transitions.add(cc)
    assert cc in transitions
    cc = parse_transitions.CloseConstituent()
    assert cc in transitions
    assert len(transitions) == 5
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def test_sort():
    """Sorting a set of transitions reproduces the canonical ordering."""
    expected = [
        parse_transitions.Shift(),
        parse_transitions.CloseConstituent(),
        parse_transitions.CompoundUnary("NP"),
        parse_transitions.CompoundUnary("NP", "VP"),
        parse_transitions.OpenConstituent("mox"),
        parse_transitions.OpenConstituent("opal"),
        parse_transitions.OpenConstituent("unban"),
    ]

    # round-trip through a set, then sort back into the expected order
    assert sorted(set(expected)) == expected
|
| 469 |
+
|
| 470 |
+
def test_check_transitions():
    """
    Test that check_transitions passes or fails a couple simple, small test cases
    """
    transitions = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")}

    # identical transition sets are always compatible
    other = {Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("VP")}
    parse_transitions.check_transitions(transitions, other, "test")

    # This will get a pass because it is a unary made out of existing unaries
    other = {Shift(), CloseConstituent(), OpenConstituent("NP", "VP")}
    parse_transitions.check_transitions(transitions, other, "test")

    # This should fail
    with pytest.raises(RuntimeError):
        other = {Shift(), CloseConstituent(), OpenConstituent("NP", "ZP")}
        parse_transitions.check_transitions(transitions, other, "test")
|
stanza/stanza/tests/constituency/test_parse_tree.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency.parse_tree import Tree
|
| 4 |
+
from stanza.models.constituency import tree_reader
|
| 5 |
+
|
| 6 |
+
from stanza.tests import *
|
| 7 |
+
|
| 8 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 9 |
+
|
| 10 |
+
def test_leaf_preterminal():
    """Leaf, preterminal, and internal nodes report their status and str() form correctly."""
    foo = Tree(label="foo")
    assert foo.is_leaf()
    assert not foo.is_preterminal()
    assert len(foo.children) == 0
    assert str(foo) == 'foo'

    # note: a single child may be passed without wrapping it in a list
    bar = Tree(label="bar", children=foo)
    assert not bar.is_leaf()
    assert bar.is_preterminal()
    assert len(bar.children) == 1
    assert str(bar) == "(bar foo)"

    baz = Tree(label="baz", children=[bar])
    assert not baz.is_leaf()
    assert not baz.is_preterminal()
    assert len(baz.children) == 1
    assert str(baz) == "(baz (bar foo))"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_yield_preterminals():
    """yield_preterminals walks the tree's preterminals left to right."""
    text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
    trees = tree_reader.read_trees(text)

    preterminals = list(trees[0].yield_preterminals())
    assert len(preterminals) == 3
    assert str(preterminals) == "[(VB Unban), (NNP Mox), (NNP Opal)]"
|
| 37 |
+
|
| 38 |
+
def test_depth():
    """A bare label has depth 0; depth counts edges down the longest path."""
    text = "(foo) ((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
    trees = tree_reader.read_trees(text)
    assert trees[0].depth() == 0
    assert trees[1].depth() == 4
|
| 43 |
+
|
| 44 |
+
def test_unique_labels():
    """
    Test getting the unique labels from a tree

    Assumes tree_reader works, which should be fine since it is tested elsewhere
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"

    trees = tree_reader.read_trees(text)

    # duplicates across the two identical trees collapse to one sorted list
    labels = Tree.get_unique_constituent_labels(trees)
    assert labels == ['NP', 'PP', 'ROOT', 'SBARQ', 'SQ', 'VP', 'WHNP']
|
| 57 |
+
|
| 58 |
+
def test_unique_tags():
    """
    Test getting the unique tags from a tree
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"

    trees = tree_reader.read_trees(text)

    tags = Tree.get_unique_tags(trees)
    assert tags == ['.', 'DT', 'IN', 'NN', 'VBZ', 'WP']
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def test_unique_words():
    """
    Test getting the unique words from a tree
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"

    trees = tree_reader.read_trees(text)

    words = Tree.get_unique_words(trees)
    assert words == ['?', 'Who', 'in', 'seat', 'sits', 'this']
|
| 82 |
+
|
| 83 |
+
def test_rare_words():
    """
    Test getting the rare words from a couple trees

    With a 0.5 threshold, only the words missing from the second tree qualify
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"

    trees = tree_reader.read_trees(text)

    words = Tree.get_rare_words(trees, 0.5)
    assert words == ['Who', 'in', 'sits']
|
| 94 |
+
|
| 95 |
+
def test_common_words():
    """
    Test getting the most common words from a couple trees
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?))) ((SBARQ (NP (DT this) (NN seat)) (. ?)))"

    trees = tree_reader.read_trees(text)

    # the three words which occur in both trees
    words = Tree.get_common_words(trees, 3)
    assert words == ['?', 'seat', 'this']
|
| 106 |
+
|
| 107 |
+
def test_root_labels():
    """get_root_labels returns the sorted set of distinct root labels; '' reads as ROOT."""
    text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)
    assert ["ROOT"] == Tree.get_root_labels(trees)

    # three copies of the same tree still produce a single root label
    text=("( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" * 3)
    trees = tree_reader.read_trees(text)
    assert ["ROOT"] == Tree.get_root_labels(trees)

    text="(FOO) (BAR)"
    trees = tree_reader.read_trees(text)
    assert ["BAR", "FOO"] == Tree.get_root_labels(trees)
|
| 121 |
+
|
| 122 |
+
def test_prune_none():
    """prune_none drops -NONE- nodes, then any constituent left with no children."""
    text=["((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (-NONE- in) (NP (DT this) (NN seat))))) (. ?)))", # test one dead node
          "((SBARQ (WHNP (-NONE- Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))", # test recursive dead nodes
          "((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (-NONE- this) (-NONE- seat))))) (. ?)))"] # test all children dead
    expected=["(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (NP (DT this) (NN seat))))) (. ?)))",
              "(ROOT (SBARQ (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))",
              "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"]

    for t, e in zip(text, expected):
        trees = tree_reader.read_trees(t)
        assert len(trees) == 1
        pruned = trees[0].prune_none()
        assert e == str(pruned)
|
| 135 |
+
|
| 136 |
+
def test_simplify_labels():
    """simplify_labels strips functional annotations (-FOO, #ASDF, =1) but keeps a bare '-' label."""
    text="( (SBARQ-FOO (WHNP-BAR (WP Who)) (SQ#ASDF (VP=1 (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))"
    expected = "(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (- -))))) (. ?)))"
    trees = [t.simplify_labels() for t in tree_reader.read_trees(text)]
    assert len(trees) == 1
    assert expected == str(trees[0])
|
| 143 |
+
|
| 144 |
+
def test_remap_constituent_labels():
    """remap_constituent_labels rewrites internal node labels via the given map."""
    text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    expected="(ROOT (FOO (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"

    label_map = { "SBARQ": "FOO" }
    trees = [t.remap_constituent_labels(label_map) for t in tree_reader.read_trees(text)]
    assert len(trees) == 1
    assert expected == str(trees[0])
|
| 153 |
+
|
| 154 |
+
def test_remap_constituent_words():
    """remap_words substitutes leaf words via the given map, leaving tags alone."""
    text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))"

    word_map = { "Who": "unban", "sits": "mox", "in": "opal" }
    trees = [t.remap_words(word_map) for t in tree_reader.read_trees(text)]
    assert len(trees) == 1
    assert expected == str(trees[0])
|
| 163 |
+
|
| 164 |
+
def test_replace_words():
    """replace_words swaps in a new word list positionally, one word per leaf."""
    text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    expected="(ROOT (SBARQ (WHNP (WP unban)) (SQ (VP (VBZ mox) (PP (IN opal)))) (. ?)))"
    new_words = ["unban", "mox", "opal", "?"]

    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    replaced = trees[0].replace_words(new_words)
    assert expected == str(replaced)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def test_compound_constituents():
    """get_compound_constituents collects the distinct unary chains, sorted."""
    # TODO: add skinny trees like this to the various transition tests
    text="((VP (VB Unban)))"
    trees = tree_reader.read_trees(text)
    assert Tree.get_compound_constituents(trees) == [('ROOT', 'VP')]

    text="(ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    trees = tree_reader.read_trees(text)
    assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('SQ', 'VP'), ('WHNP',)]

    # chains from multiple trees are merged into one sorted list
    text="((VP (VB Unban))) (ROOT (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in)))) (. ?)))"
    trees = tree_reader.read_trees(text)
    assert Tree.get_compound_constituents(trees) == [('PP',), ('ROOT', 'SBARQ'), ('ROOT', 'VP'), ('SQ', 'VP'), ('WHNP',)]
|
| 189 |
+
|
| 190 |
+
def test_equals():
    """
    Check one tree from the actual dataset for ==

    when built with compound Open, this didn't work because of a silly bug
    """
    text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))"

    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    tree = trees[0]

    # reflexive equality
    assert tree == tree

    # equality between two separately parsed copies of the same text
    tree2 = tree_reader.read_trees(text)[0]
    assert tree is not tree2
    assert tree == tree2
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# This tree was causing the model to barf on CTB7,
|
| 212 |
+
# although it turns out the problem was just the
|
| 213 |
+
# depth of the unary, not the list
|
| 214 |
+
CHINESE_LONG_LIST_TREE = """
|
| 215 |
+
(ROOT
|
| 216 |
+
(IP
|
| 217 |
+
(NP (NNP 证券法))
|
| 218 |
+
(VP
|
| 219 |
+
(PP
|
| 220 |
+
(IN 对)
|
| 221 |
+
(NP
|
| 222 |
+
(DNP
|
| 223 |
+
(NP
|
| 224 |
+
(NP (NNP 中国))
|
| 225 |
+
(NP
|
| 226 |
+
(NN 证券)
|
| 227 |
+
(NN 市场)))
|
| 228 |
+
(DEC 的))
|
| 229 |
+
(NP (NN 运作))))
|
| 230 |
+
(, ,)
|
| 231 |
+
(PP
|
| 232 |
+
(PP
|
| 233 |
+
(IN 从)
|
| 234 |
+
(NP
|
| 235 |
+
(NP (NN 股票))
|
| 236 |
+
(NP (VV 发行) (EC 、) (VV 交易))))
|
| 237 |
+
(, ,)
|
| 238 |
+
(PP
|
| 239 |
+
(VV 到)
|
| 240 |
+
(NP
|
| 241 |
+
(NP (NN 上市) (NN 公司) (NN 收购))
|
| 242 |
+
(EC 、)
|
| 243 |
+
(NP (NN 证券) (NN 交易所))
|
| 244 |
+
(EC 、)
|
| 245 |
+
(NP (NN 证券) (NN 公司))
|
| 246 |
+
(EC 、)
|
| 247 |
+
(NP (NN 登记) (NN 结算) (NN 机构))
|
| 248 |
+
(EC 、)
|
| 249 |
+
(NP (NN 交易) (NN 服务) (NN 机构))
|
| 250 |
+
(EC 、)
|
| 251 |
+
(NP (NN 证券业) (NN 协会))
|
| 252 |
+
(EC 、)
|
| 253 |
+
(NP (NN 证券) (NN 监督) (NN 管理) (NN 机构))
|
| 254 |
+
(CC 和)
|
| 255 |
+
(NP
|
| 256 |
+
(DNP
|
| 257 |
+
(NP (CP (CP (IP (VP (VV 违法))))))
|
| 258 |
+
(DEC 的))
|
| 259 |
+
(NP (NN 法律) (NN 责任))))))
|
| 260 |
+
(ADVP (RB 都))
|
| 261 |
+
(VP
|
| 262 |
+
(VV 作)
|
| 263 |
+
(AS 了)
|
| 264 |
+
(NP
|
| 265 |
+
(ADJP (JJ 详细))
|
| 266 |
+
(NP (NN 规定)))))
|
| 267 |
+
(. 。)))
|
| 268 |
+
"""
|
| 269 |
+
|
| 270 |
+
WEIRD_UNARY = """
|
| 271 |
+
(DNP
|
| 272 |
+
(NP (CP (CP (IP (VP (ASDF
|
| 273 |
+
(NP (NN 上市) (NN 公司) (NN 收购))
|
| 274 |
+
(EC 、)
|
| 275 |
+
(NP (NN 证券) (NN 交易所))
|
| 276 |
+
(EC 、)
|
| 277 |
+
(NP (NN 证券) (NN 公司))
|
| 278 |
+
(EC 、)
|
| 279 |
+
(NP (NN 登记) (NN 结算) (NN 机构))
|
| 280 |
+
(EC 、)
|
| 281 |
+
(NP (NN 交易) (NN 服务) (NN 机构))
|
| 282 |
+
(EC 、)
|
| 283 |
+
(NP (NN 证券业) (NN 协会))
|
| 284 |
+
(EC 、)
|
| 285 |
+
(NP (NN 证券) (NN 监督) (NN 管理) (NN 机构))))))))
|
| 286 |
+
(DEC 的))
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def test_count_unaries():
    """count_unary_depth reports the longest chain of unary productions in the tree."""
    for raw in (CHINESE_LONG_LIST_TREE, WEIRD_UNARY):
        trees = tree_reader.read_trees(raw)
        assert len(trees) == 1
        assert trees[0].count_unary_depth() == 5
|
| 298 |
+
|
| 299 |
+
def test_str_bracket_labels():
    """The :L format labels both the open and close brackets of every node."""
    text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
    expected = "(_ROOT (_S (_VP (_VB Unban )_VB )_VP (_NP (_NNP Mox )_NNP (_NNP Opal )_NNP )_NP )_S )_ROOT"

    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    assert "{:L}".format(trees[0]) == expected
|
| 306 |
+
|
| 307 |
+
def test_all_leaves_are_preterminals():
    """all_leaves_are_preterminals is False when any leaf lacks a tag above it."""
    text = "((S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal))))"
    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    assert trees[0].all_leaves_are_preterminals()

    # (Mox) is a bare leaf directly under NP, so the check must fail
    text = "((S (VP (VB Unban)) (NP (Mox) (NNP Opal))))"
    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    assert not trees[0].all_leaves_are_preterminals()
|
| 317 |
+
|
| 318 |
+
def test_latex():
    """
    Test the latex format for trees
    """
    expected = "\\Tree [.S [.NP Jennifer ] [.VP has [.NP nice antennae ] ] ]"
    tree = "(ROOT (S (NP (NNP Jennifer)) (VP (VBZ has) (NP (JJ nice) (NNS antennae)))))"
    tree = tree_reader.read_trees(tree)[0]
    assert "{:T}".format(tree) == expected
|
| 327 |
+
|
| 328 |
+
def test_pretty_print():
    """
    Pretty print a couple trees - newlines & indentation

    NOTE(review): the dump this was reconstructed from lost the interior
    indentation of the expected strings; 2 spaces per level is assumed - confirm.
    """
    text = "(ROOT (S (VP (VB Unban)) (NP (NNP Mox) (NNP Opal)))) (ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric)))))))"
    trees = tree_reader.read_trees(text)
    assert len(trees) == 2

    expected = """(ROOT
  (S
    (VP (VB Unban))
    (NP (NNP Mox) (NNP Opal))))
"""

    assert "{:P}".format(trees[0]) == expected

    expected = """(ROOT
  (S
    (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission))
    (VP
      (VBD authorized)
      (NP
        (NP
          (DT an)
          (ADJP (CD 11.5))
          (NN %)
          (NN rate)
          (NN increase))
        (PP
          (IN at)
          (NP (NNP Tucson) (NNP Electric)))))))
"""
    assert "{:P}".format(trees[1]) == expected

    # the :O format round-trips back to the single line form
    assert text == "{:O} {:O}".format(*trees)
|
| 363 |
+
|
| 364 |
+
def test_reverse():
    """reverse() mirrors the children at every level of the tree."""
    text = "(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB lick) (NP (NP (NNP Jennifer) (POS 's)) (NNS antennae))))))))"
    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    reversed_tree = trees[0].reverse()
    assert str(reversed_tree) == "(ROOT (S (VP (S (VP (VP (NP (NNS antennae) (NP (POS 's) (NNP Jennifer))) (VB lick)) (TO to))) (VBP want)) (NP (PRP I))))"
|
stanza/stanza/tests/constituency/test_positional_encoding.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from stanza import Pipeline
|
| 6 |
+
from stanza.models.constituency.positional_encoding import SinusoidalEncoding, AddSinusoidalEncoding
|
| 7 |
+
|
| 8 |
+
from stanza.tests import *
|
| 9 |
+
|
| 10 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_positional_encoding():
    """Encoding a single position returns one row of width model_dim."""
    encoding = SinusoidalEncoding(model_dim=10, max_len=6)
    result = encoding(torch.tensor([5]))
    assert result.shape == (1, 10)
    # TODO: check the values
|
| 18 |
+
|
| 19 |
+
def test_resize():
    """Asking for a position past max_len grows the table instead of failing."""
    encoding = SinusoidalEncoding(model_dim=10, max_len=3)
    result = encoding(torch.tensor([5]))
    assert result.shape == (1, 10)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_arange():
    """A batch of positions past max_len resizes the encoding to fit."""
    encoding = SinusoidalEncoding(model_dim=10, max_len=2)
    result = encoding(torch.arange(4))
    assert result.shape == (4, 10)
    assert encoding.max_len() == 4
|
| 30 |
+
|
| 31 |
+
def test_add():
    """AddSinusoidalEncoding adds the same positional offsets to every batch element."""
    encoding = AddSinusoidalEncoding(d_model=10, max_len=4)
    # encoding of zeros is the raw positional table
    baseline = encoding(torch.zeros(1, 4, 10))

    r = torch.randn(1, 4, 10)
    assert torch.allclose(encoding(r) - r, baseline, atol=1e-07)

    # the same offsets are applied independently to each item in a batch of 2
    r = torch.randn(2, 4, 10)
    r2 = encoding(r)
    assert torch.allclose(r2[0] - r[0], baseline, atol=1e-07)
    assert torch.allclose(r2[1] - r[1], baseline, atol=1e-07)
|
stanza/stanza/tests/constituency/test_selftrain_vi_quad.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test some of the methods in the vi_quad dataset
|
| 3 |
+
|
| 4 |
+
Uses a small section of the dataset as a test
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from stanza.utils.datasets.constituency import selftrain_vi_quad
|
| 10 |
+
|
| 11 |
+
from stanza.tests import *
|
| 12 |
+
|
| 13 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 14 |
+
|
| 15 |
+
SAMPLE_TEXT = """
|
| 16 |
+
{"version": "1.1", "data": [{"title": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng", "paragraphs": [{"qas": [{"question": "T\u00ean g\u1ecdi n\u00e0o \u0111\u01b0\u1ee3c Ph\u1ea1m V\u0103n \u0110\u1ed3ng s\u1eed d\u1ee5ng khi l\u00e0m Ph\u00f3 ch\u1ee7 nhi\u1ec7m c\u01a1 quan Bi\u1ec7n s\u1ef1 x\u1ee9 t\u1ea1i Qu\u1ebf L\u00e2m?", "answers": [{"answer_start": 507, "text": "L\u00e2m B\u00e1 Ki\u1ec7t"}], "id": "uit_01__05272_0_1"}, {"question": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng gi\u1eef ch\u1ee9c v\u1ee5 g\u00ec trong b\u1ed9 m\u00e1y Nh\u00e0 n\u01b0\u1edbc C\u1ed9ng h\u00f2a X\u00e3 h\u1ed9i ch\u1ee7 ngh\u0129a Vi\u1ec7t Nam?", "answers": [{"answer_start": 60, "text": "Th\u1ee7 t\u01b0\u1edbng"}], "id": "uit_01__05272_0_2"}, {"question": "Giai \u0111o\u1ea1n n\u0103m 1955-1976, Ph\u1ea1m V\u0103n \u0110\u1ed3ng n\u1eafm gi\u1eef ch\u1ee9c v\u1ee5 g\u00ec?", "answers": [{"answer_start": 245, "text": "Th\u1ee7 t\u01b0\u1edbng Ch\u00ednh ph\u1ee7 Vi\u1ec7t Nam D\u00e2n ch\u1ee7 C\u1ed9ng h\u00f2a"}], "id": "uit_01__05272_0_3"}], "context": "Ph\u1ea1m V\u0103n \u0110\u1ed3ng (1 th\u00e1ng 3 n\u0103m 1906 \u2013 29 th\u00e1ng 4 n\u0103m 2000) l\u00e0 Th\u1ee7 t\u01b0\u1edbng \u0111\u1ea7u ti\u00ean c\u1ee7a n\u01b0\u1edbc C\u1ed9ng h\u00f2a X\u00e3 h\u1ed9i ch\u1ee7 ngh\u0129a Vi\u1ec7t Nam t\u1eeb n\u0103m 1976 (t\u1eeb n\u0103m 1981 g\u1ecdi l\u00e0 Ch\u1ee7 t\u1ecbch H\u1ed9i \u0111\u1ed3ng B\u1ed9 tr\u01b0\u1edfng) cho \u0111\u1ebfn khi ngh\u1ec9 h\u01b0u n\u0103m 1987. Tr\u01b0\u1edbc \u0111\u00f3 \u00f4ng t\u1eebng gi\u1eef ch\u1ee9c v\u1ee5 Th\u1ee7 t\u01b0\u1edbng Ch\u00ednh ph\u1ee7 Vi\u1ec7t Nam D\u00e2n ch\u1ee7 C\u1ed9ng h\u00f2a t\u1eeb n\u0103m 1955 \u0111\u1ebfn n\u0103m 1976. \u00d4ng l\u00e0 v\u1ecb Th\u1ee7 t\u01b0\u1edbng Vi\u1ec7t Nam t\u1ea1i v\u1ecb l\u00e2u nh\u1ea5t (1955\u20131987). \u00d4ng l\u00e0 h\u1ecdc tr\u00f2, c\u1ed9ng s\u1ef1 c\u1ee7a Ch\u1ee7 t\u1ecbch H\u1ed3 Ch\u00ed Minh. 
\u00d4ng c\u00f3 t\u00ean g\u1ecdi th\u00e2n m\u1eadt l\u00e0 T\u00f4, \u0111\u00e2y t\u1eebng l\u00e0 b\u00ed danh c\u1ee7a \u00f4ng. \u00d4ng c\u00f2n c\u00f3 t\u00ean g\u1ecdi l\u00e0 L\u00e2m B\u00e1 Ki\u1ec7t khi l\u00e0m Ph\u00f3 ch\u1ee7 nhi\u1ec7m c\u01a1 quan Bi\u1ec7n s\u1ef1 x\u1ee9 t\u1ea1i Qu\u1ebf L\u00e2m (Ch\u1ee7 nhi\u1ec7m l\u00e0 H\u1ed3 H\u1ecdc L\u00e3m)."}, {"qas": [{"question": "S\u1ef1 ki\u1ec7n quan tr\u1ecdng n\u00e0o \u0111\u00e3 di\u1ec5n ra v\u00e0o ng\u00e0y 20/7/1954?", "answers": [{"answer_start": 364, "text": "b\u1ea3n Hi\u1ec7p \u0111\u1ecbnh \u0111\u00ecnh ch\u1ec9 chi\u1ebfn s\u1ef1 \u1edf Vi\u1ec7t Nam, Campuchia v\u00e0 L\u00e0o \u0111\u00e3 \u0111\u01b0\u1ee3c k\u00fd k\u1ebft th\u1eeba nh\u1eadn t\u00f4n tr\u1ecdng \u0111\u1ed9c l\u1eadp, ch\u1ee7 quy\u1ec1n, c\u1ee7a n\u01b0\u1edbc Vi\u1ec7t Nam, L\u00e0o v\u00e0 Campuchia"}], "id": "uit_01__05272_1_1"}, {"question": "Ch\u1ee9c v\u1ee5 m\u00e0 Ph\u1ea1m V\u0103n \u0110\u1ed3ng \u0111\u1ea3m nhi\u1ec7m t\u1ea1i H\u1ed9i ngh\u1ecb Gen\u00e8ve v\u1ec1 \u0110\u00f4ng D\u01b0\u01a1ng?", "answers": [{"answer_start": 33, "text": "Tr\u01b0\u1edfng ph\u00e1i \u0111o\u00e0n Ch\u00ednh ph\u1ee7"}], "id": "uit_01__05272_1_2"}, {"question": "H\u1ed9i ngh\u1ecb Gen\u00e8ve v\u1ec1 \u0110\u00f4ng D\u01b0\u01a1ng c\u00f3 t\u00ednh ch\u1ea5t nh\u01b0 th\u1ebf n\u00e0o?", "answers": [{"answer_start": 262, "text": "r\u1ea5t c\u0103ng th\u1eb3ng v\u00e0 ph\u1ee9c t\u1ea1p"}], "id": "uit_01__05272_1_3"}]}]}]}
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
# Questions expected to be extracted, in order, from SAMPLE_TEXT above.
EXPECTED = ['Tên gọi nào được Phạm Văn Đồng sử dụng khi làm Phó chủ nhiệm cơ quan Biện sự xứ tại Quế Lâm?', 'Phạm Văn Đồng giữ chức vụ gì trong bộ máy Nhà nước Cộng hòa Xã hội chủ nghĩa Việt Nam?', 'Giai đoạn năm 1955-1976, Phạm Văn Đồng nắm giữ chức vụ gì?', 'Sự kiện quan trọng nào đã diễn ra vào ngày 20/7/1954?', 'Chức vụ mà Phạm Văn Đồng đảm nhiệm tại Hội nghị Genève về Đông Dương?', 'Hội nghị Genève về Đông Dương có tính chất như thế nào?']

def test_read_file():
    """Check that parse_quad extracts exactly the questions in EXPECTED from SAMPLE_TEXT."""
    results = selftrain_vi_quad.parse_quad(SAMPLE_TEXT)
    assert results == EXPECTED
|
stanza/stanza/tests/constituency/test_utils.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from stanza import Pipeline
|
| 4 |
+
from stanza.models.constituency import tree_reader
|
| 5 |
+
from stanza.models.constituency import utils
|
| 6 |
+
|
| 7 |
+
from stanza.tests import *
|
| 8 |
+
|
| 9 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@pytest.fixture(scope="module")
def pipeline():
    """Module-scoped English tokenize+pos pipeline.

    Uses pretokenized input so the tagger's tokens line up one-to-one
    with the leaves of the constituency trees being retagged.
    """
    return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos", tokenize_pretokenized=True)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_xpos_retag(pipeline):
    """
    Test using the English tagger that trees will be correctly retagged by read_trees using xpos
    """
    source = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))"
    retagged = "((S (VP (VB Find)) (NP (NNP Mox) (NNP Opal)))) ((S (NP (NNP Ragavan)) (VP (VBZ steals) (NP (JJ important) (NNS cards)))))"

    original_trees = tree_reader.read_trees(source)

    # xpos=True should replace each placeholder X with the tagger's xpos tag
    result = utils.retag_trees(original_trees, [pipeline], xpos=True)
    assert result == tree_reader.read_trees(retagged)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_upos_retag(pipeline):
    """
    Test using the English tagger that trees will be correctly retagged by read_trees using upos
    """
    source = "((S (VP (X Find)) (NP (X Mox) (X Opal)))) ((S (NP (X Ragavan)) (VP (X steals) (NP (X important) (X cards)))))"
    retagged = "((S (VP (VERB Find)) (NP (PROPN Mox) (PROPN Opal)))) ((S (NP (PROPN Ragavan)) (VP (VERB steals) (NP (ADJ important) (NOUN cards)))))"

    original_trees = tree_reader.read_trees(source)

    # xpos=False means the universal (upos) tags are used instead
    result = utils.retag_trees(original_trees, [pipeline], xpos=False)
    assert result == tree_reader.read_trees(retagged)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_replace_tags():
    """
    Test the underlying replace_tags method

    Also tests that the method throws exceptions when it is supposed to
    """
    source = "((S (VP (X Find)) (NP (X Mox) (X Opal))))"
    retagged = "((S (VP (A Find)) (NP (B Mox) (C Opal))))"

    trees = tree_reader.read_trees(source)

    # exactly one tag per leaf: the replacement should succeed
    assert trees[0].replace_tags(["A", "B", "C"]) == tree_reader.read_trees(retagged)[0]

    # too few tags for the number of leaves
    with pytest.raises(ValueError):
        trees[0].replace_tags(["A", "B"])

    # too many tags for the number of leaves
    with pytest.raises(ValueError):
        trees[0].replace_tags(["A", "B", "C", "D"])
|
| 68 |
+
|
stanza/stanza/tests/data/example_french.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"sentences":
|
| 2 |
+
[{"index": 0,
|
| 3 |
+
"tokens": [
|
| 4 |
+
{"index": 1, "word": "Cette", "originalText": "Cette", "characterOffsetBegin": 0, "characterOffsetEnd": 5, "pos": "DET", "before": "", "after": " "},
|
| 5 |
+
{"index": 2, "word": "enquête", "originalText": "enquête", "characterOffsetBegin": 6, "characterOffsetEnd": 13, "pos": "NOUN", "before": " ", "after": " "},
|
| 6 |
+
{"index": 3, "word": "préliminaire", "originalText": "préliminaire", "characterOffsetBegin": 14, "characterOffsetEnd": 26, "pos": "ADJ", "before": " ", "after": " "},
|
| 7 |
+
{"index": 4, "word": "fait", "originalText": "fait", "characterOffsetBegin": 27, "characterOffsetEnd": 31, "pos": "VERB", "before": " ", "after": " "},
|
| 8 |
+
{"index": 5, "word": "suite", "originalText": "suite", "characterOffsetBegin": 32, "characterOffsetEnd": 37, "pos": "NOUN", "before": " ", "after": " "},
|
| 9 |
+
{"index": 6, "word": "à", "originalText": "à", "characterOffsetBegin": 38, "characterOffsetEnd": 41, "pos": "ADP", "before": " ", "after": " "},
|
| 10 |
+
{"index": 7, "word": "les", "originalText": "les", "characterOffsetBegin": 38, "characterOffsetEnd": 41, "pos": "DET", "before": " ", "after": " "},
|
| 11 |
+
{"index": 8, "word": "révélations", "originalText": "révélations", "characterOffsetBegin": 42, "characterOffsetEnd": 53, "pos": "NOUN", "before": " ", "after": " "},
|
| 12 |
+
{"index": 9, "word": "de", "originalText": "de", "characterOffsetBegin": 54, "characterOffsetEnd": 56, "pos": "ADP", "before": " ", "after": " "},
|
| 13 |
+
{"index": 10, "word": "l’", "originalText": "l’", "characterOffsetBegin": 57, "characterOffsetEnd": 59, "pos": "NOUN", "before": " ", "after": ""},
|
| 14 |
+
{"index": 11, "word": "hebdomadaire", "originalText": "hebdomadaire", "characterOffsetBegin": 59, "characterOffsetEnd": 71, "pos": "ADJ", "before": "", "after": " "},
|
| 15 |
+
{"index": 12, "word": "quelques", "originalText": "quelques", "characterOffsetBegin": 72, "characterOffsetEnd": 80, "pos": "DET", "before": " ", "after": " "},
|
| 16 |
+
{"index": 13, "word": "jours", "originalText": "jours", "characterOffsetBegin": 81, "characterOffsetEnd": 86, "pos": "NOUN", "before": " ", "after": " "},
|
| 17 |
+
{"index": 14, "word": "plus", "originalText": "plus", "characterOffsetBegin": 87, "characterOffsetEnd": 91, "pos": "ADV", "before": " ", "after": " "},
|
| 18 |
+
{"index": 15, "word": "tôt", "originalText": "tôt", "characterOffsetBegin": 92, "characterOffsetEnd": 95, "pos": "ADV", "before": " ", "after": ""},
|
| 19 |
+
{"index": 16, "word": ".", "originalText": ".", "characterOffsetBegin": 95, "characterOffsetEnd": 96, "pos": "PUNCT", "before": "", "after": ""}
|
| 20 |
+
]}
|
| 21 |
+
]
|
| 22 |
+
}
|
stanza/stanza/tests/data/test.dat
ADDED
|
Binary file (4.24 kB). View file
|
|
|
stanza/stanza/tests/data/tiny_emb.csv
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
3 4
|
| 2 |
+
unban,1,2,3,4
|
| 3 |
+
mox,5,6,7,8
|
| 4 |
+
opal,9,10,11,12
|
stanza/stanza/tests/datasets/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/datasets/ner/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/datasets/ner/test_prepare_ner_file.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test some simple conversions of NER bio files
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
from stanza.models.common.doc import Document
|
| 10 |
+
from stanza.utils.datasets.ner.prepare_ner_file import process_dataset
|
| 11 |
+
|
| 12 |
+
BIO_1 = """
|
| 13 |
+
Jennifer B-PERSON
|
| 14 |
+
Sh'reyan I-PERSON
|
| 15 |
+
has O
|
| 16 |
+
lovely O
|
| 17 |
+
antennae O
|
| 18 |
+
""".strip()
|
| 19 |
+
|
| 20 |
+
BIO_2 = """
|
| 21 |
+
but O
|
| 22 |
+
I O
|
| 23 |
+
don't O
|
| 24 |
+
like O
|
| 25 |
+
the O
|
| 26 |
+
way O
|
| 27 |
+
Jennifer B-PERSON
|
| 28 |
+
treated O
|
| 29 |
+
Beckett B-PERSON
|
| 30 |
+
on O
|
| 31 |
+
the O
|
| 32 |
+
Cerritos B-LOCATION
|
| 33 |
+
""".strip()
|
| 34 |
+
|
| 35 |
+
def check_json_file(doc, raw_text, expected_sentences, expected_tokens):
    """Verify that a converted Document matches the raw BIO text.

    Checks sentence count, per-sentence token counts, and that each
    token's text and ner tag agree with the corresponding BIO line.
    expected_tokens may be a single int (one sentence) or a list.
    """
    # allow a bare int for the common single-sentence case
    if isinstance(expected_tokens, int):
        expected_tokens = [expected_tokens]

    bio_sentences = raw_text.strip().split("\n\n")
    assert len(bio_sentences) == expected_sentences
    for bio_sentence, num_tokens in zip(bio_sentences, expected_tokens):
        assert len(bio_sentence.strip().split("\n")) == num_tokens

    assert len(doc.sentences) == expected_sentences
    for sentence, num_tokens in zip(doc.sentences, expected_tokens):
        assert len(sentence.tokens) == num_tokens
    for sentence, bio_sentence in zip(doc.sentences, bio_sentences):
        for token, line in zip(sentence.tokens, bio_sentence.strip().split("\n")):
            word, tag = line.strip().split()
            assert token.text == word
            assert token.ner == tag
|
| 51 |
+
|
| 52 |
+
def write_and_convert(tmp_path, raw_text):
    """Write raw BIO text to a temp file, convert it to json, and load the result.

    Returns the converted data as a Document.
    """
    bio_file = tmp_path / "test.bio"
    bio_file.write_text(raw_text, encoding="utf-8")

    json_file = tmp_path / "json.bio"
    process_dataset(bio_file, json_file)

    with open(json_file) as fin:
        converted = json.load(fin)
    return Document(converted)
|
| 64 |
+
|
| 65 |
+
def run_test(tmp_path, raw_text, expected_sentences, expected_tokens):
    """Convert raw_text to a Document, then validate it against the raw text."""
    converted = write_and_convert(tmp_path, raw_text)
    check_json_file(converted, raw_text, expected_sentences, expected_tokens)
|
| 68 |
+
|
| 69 |
+
def test_simple(tmp_path):
    """One sentence, five tokens, with a two-token entity at the start."""
    run_test(tmp_path, BIO_1, 1, 5)

def test_ner_at_end(tmp_path):
    """An entity on the final line of the file should not be dropped."""
    run_test(tmp_path, BIO_2, 1, 12)

def test_two_sentences(tmp_path):
    """Two sentences separated by a blank line are split correctly."""
    raw_text = BIO_1 + "\n\n" + BIO_2
    run_test(tmp_path, raw_text, 2, [5, 12])
|
stanza/stanza/tests/datasets/ner/test_utils.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test the utils file of the NER dataset processing
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from stanza.utils.datasets.ner.utils import list_doc_entities
|
| 8 |
+
from stanza.tests.datasets.ner.test_prepare_ner_file import BIO_1, BIO_2, write_and_convert
|
| 9 |
+
|
| 10 |
+
def test_list_doc_entities(tmp_path):
    """
    Test the function which lists all of the entities in a doc
    """
    one_person = (('Jennifer', "Sh'reyan"), 'PERSON')
    several = [(('Jennifer',), 'PERSON'), (('Beckett',), 'PERSON'), (('Cerritos',), 'LOCATION')]

    # (raw text, entities expected from the converted document)
    cases = [
        (BIO_1, [one_person]),
        (BIO_2, several),
        ("\n\n".join([BIO_1, BIO_2]), [one_person] + several),
        # duplicate sentences should produce duplicate entities
        ("\n\n".join([BIO_1, BIO_1, BIO_2]), [one_person, one_person] + several),
    ]

    for raw_text, expected in cases:
        doc = write_and_convert(tmp_path, raw_text)
        entities = list_doc_entities(doc)
        assert expected == entities
| 33 |
+
|
| 34 |
+
|
stanza/stanza/tests/lemma/test_data.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a couple basic data functions, such as processing a doc for its lemmas
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from stanza.models.common.doc import Document
|
| 8 |
+
from stanza.models.lemma.data import DataLoader
|
| 9 |
+
from stanza.utils.conll import CoNLL
|
| 10 |
+
|
| 11 |
+
TRAIN_DATA = """
|
| 12 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
|
| 13 |
+
# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
|
| 14 |
+
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
|
| 15 |
+
2 : : PUNCT : _ 1 punct 1:punct _
|
| 16 |
+
3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
|
| 17 |
+
4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
|
| 18 |
+
5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
|
| 19 |
+
6 that that SCONJ IN _ 9 mark 9:mark _
|
| 20 |
+
7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
|
| 21 |
+
8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
|
| 22 |
+
9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
|
| 23 |
+
10 up up ADP RP _ 9 compound:prt 9:compound:prt _
|
| 24 |
+
11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
|
| 25 |
+
12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
|
| 26 |
+
13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
|
| 27 |
+
14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
|
| 28 |
+
15 in in ADP IN _ 16 case 16:case _
|
| 29 |
+
16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
|
| 30 |
+
17 . . PUNCT . _ 1 punct 1:punct _
|
| 31 |
+
|
| 32 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
|
| 33 |
+
# text = Two of them were being run by 2 officials of the Ministry of the Interior!
|
| 34 |
+
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
|
| 35 |
+
2 of of ADP IN _ 3 case 3:case _
|
| 36 |
+
3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
|
| 37 |
+
4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
|
| 38 |
+
5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
|
| 39 |
+
6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
|
| 40 |
+
7 by by ADP IN _ 9 case 9:case _
|
| 41 |
+
8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
|
| 42 |
+
9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
|
| 43 |
+
10 of of ADP IN _ 12 case 12:case _
|
| 44 |
+
11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
|
| 45 |
+
12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
|
| 46 |
+
13 of of ADP IN _ 15 case 15:case _
|
| 47 |
+
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
|
| 48 |
+
15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
|
| 49 |
+
16 ! ! PUNCT . _ 6 punct 6:punct _
|
| 50 |
+
|
| 51 |
+
""".lstrip()
|
| 52 |
+
|
| 53 |
+
GOESWITH_DATA = """
|
| 54 |
+
# sent_id = email-enronsent27_01-0041
|
| 55 |
+
# newpar id = email-enronsent27_01-p0005
|
| 56 |
+
# text = Ken Rice@ENRON COMMUNICATIONS
|
| 57 |
+
1 Ken kenrice@enroncommunications X GW Typo=Yes 0 root 0:root _
|
| 58 |
+
2 Rice@ENRON _ X GW _ 1 goeswith 1:goeswith _
|
| 59 |
+
3 COMMUNICATIONS _ X ADD _ 1 goeswith 1:goeswith _
|
| 60 |
+
|
| 61 |
+
""".lstrip()
|
| 62 |
+
|
| 63 |
+
CORRECT_FORM_DATA = """
|
| 64 |
+
# sent_id = weblog-blogspot.com_healingiraq_20040409053012_ENG_20040409_053012-0019
|
| 65 |
+
# text = They are targetting ambulances
|
| 66 |
+
1 They they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 3 nsubj 3:nsubj _
|
| 67 |
+
2 are be AUX VBP Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
|
| 68 |
+
3 targetting target VERB VBG Tense=Pres|Typo=Yes|VerbForm=Part 0 root 0:root CorrectForm=targeting
|
| 69 |
+
4 ambulances ambulance NOUN NNS Number=Plur 3 obj 3:obj SpaceAfter=No
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def test_load_document():
    """Loading TRAIN_DATA yields the same 33 (word, pos, lemma) triples in eval and train mode."""
    train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)
    for evaluation in (True, False):
        data = DataLoader.load_doc(train_doc, caseless=False, evaluation=evaluation)
        assert len(data) == 33  # meticulously counted by hand
        assert all(len(entry) == 3 for entry in data)
|
| 82 |
+
|
| 83 |
+
def test_load_goeswith():
    """goeswith words are kept in eval mode but dropped when building training data."""
    raw_data = TRAIN_DATA + GOESWITH_DATA
    train_doc = CoNLL.conll2doc(input_str=raw_data)

    # evaluation keeps all words: test_load_document's 33 plus the 3 goeswith words
    data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
    assert len(data) == 36
    assert all(len(entry) == 3 for entry in data)

    # training drops the trailing 3 GOESWITH words
    data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
    assert len(data) == 33
    assert all(len(entry) == 3 for entry in data)
|
| 93 |
+
|
| 94 |
+
def test_correct_form():
    """CorrectForm annotations are only applied when building training data."""
    raw_data = TRAIN_DATA + CORRECT_FORM_DATA
    train_doc = CoNLL.conll2doc(input_str=raw_data)

    # the 'targeting' correction should not be applied if evaluation=True
    data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
    assert len(data) == 37
    assert all(entry[0] != 'targeting' for entry in data)

    # when evaluation=False the CorrectForm is applied, adding an extra row
    # so the model learns both 'targetting' and 'targeting'
    data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
    assert len(data) == 38
    words = [entry[0] for entry in data]
    assert 'targeting' in words
    assert 'targetting' in words
|
stanza/stanza/tests/lemma/test_lemma_trainer.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a couple basic functions - load & save an existing model
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
import glob
|
| 8 |
+
import os
|
| 9 |
+
import tempfile
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from stanza.models import lemmatizer
|
| 14 |
+
from stanza.models.lemma import trainer
|
| 15 |
+
from stanza.tests import *
|
| 16 |
+
from stanza.utils.training.common import choose_lemma_charlm, build_charlm_args
|
| 17 |
+
|
| 18 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 19 |
+
|
| 20 |
+
@pytest.fixture(scope="module")
def english_model():
    """Load the first English lemmatizer model found under TEST_MODELS_DIR."""
    models_path = os.path.join(TEST_MODELS_DIR, "en", "lemma", "*")
    models = glob.glob(models_path)
    # we expect at least one English model downloaded for the tests
    assert len(models) >= 1
    model_file = models[0]
    return trainer.Trainer(model_file=model_file)
|
| 28 |
+
|
| 29 |
+
def test_load_model(english_model):
    """
    Does nothing, just tests that loading works

    The work is all in the english_model fixture; reaching this body
    means the model file was found and deserialized successfully.
    """
+
|
| 34 |
+
def test_save_load_model(english_model):
    """
    Load, save, and load again
    """
    with tempfile.TemporaryDirectory() as save_dir:
        # a nested directory checks that save() creates missing parents
        resaved_path = os.path.join(save_dir, "resaved", "lemma.pt")
        english_model.save(resaved_path)
        trainer.Trainer(model_file=resaved_path)
|
| 42 |
+
|
| 43 |
+
TRAIN_DATA = """
|
| 44 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
|
| 45 |
+
# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
|
| 46 |
+
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
|
| 47 |
+
2 : : PUNCT : _ 1 punct 1:punct _
|
| 48 |
+
3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
|
| 49 |
+
4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
|
| 50 |
+
5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
|
| 51 |
+
6 that that SCONJ IN _ 9 mark 9:mark _
|
| 52 |
+
7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
|
| 53 |
+
8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
|
| 54 |
+
9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
|
| 55 |
+
10 up up ADP RP _ 9 compound:prt 9:compound:prt _
|
| 56 |
+
11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
|
| 57 |
+
12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
|
| 58 |
+
13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
|
| 59 |
+
14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
|
| 60 |
+
15 in in ADP IN _ 16 case 16:case _
|
| 61 |
+
16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
|
| 62 |
+
17 . . PUNCT . _ 1 punct 1:punct _
|
| 63 |
+
|
| 64 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
|
| 65 |
+
# text = Two of them were being run by 2 officials of the Ministry of the Interior!
|
| 66 |
+
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
|
| 67 |
+
2 of of ADP IN _ 3 case 3:case _
|
| 68 |
+
3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
|
| 69 |
+
4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
|
| 70 |
+
5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
|
| 71 |
+
6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
|
| 72 |
+
7 by by ADP IN _ 9 case 9:case _
|
| 73 |
+
8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
|
| 74 |
+
9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
|
| 75 |
+
10 of of ADP IN _ 12 case 12:case _
|
| 76 |
+
11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
|
| 77 |
+
12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
|
| 78 |
+
13 of of ADP IN _ 15 case 15:case _
|
| 79 |
+
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
|
| 80 |
+
15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
|
| 81 |
+
16 ! ! PUNCT . _ 6 punct 6:punct _
|
| 82 |
+
|
| 83 |
+
""".lstrip()
|
| 84 |
+
|
| 85 |
+
DEV_DATA = """
|
| 86 |
+
1 From from ADP IN _ 3 case 3:case _
|
| 87 |
+
2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _
|
| 88 |
+
3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _
|
| 89 |
+
4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
|
| 90 |
+
5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _
|
| 91 |
+
6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _
|
| 92 |
+
7 : : PUNCT : _ 4 punct 4:punct _
|
| 93 |
+
|
| 94 |
+
""".lstrip()
|
| 95 |
+
|
| 96 |
+
class TestLemmatizer:
    """End-to-end training tests for the lemmatizer, with and without a charlm."""

    @pytest.fixture(scope="class")
    def charlm_args(self):
        """Command line args pointing the lemmatizer at the test English charlm."""
        charlm = choose_lemma_charlm("en", "test", "default")
        charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR)
        return charlm_args


    def run_training(self, tmp_path, train_text, dev_text, extra_args=None):
        """
        Run the training for a few iterations, load & return the model
        """
        pred_file = str(tmp_path / "pred.conllu")

        save_name = "test_tagger.pt"
        save_file = str(tmp_path / save_name)

        train_file = str(tmp_path / "train.conllu")
        with open(train_file, "w", encoding="utf-8") as fout:
            fout.write(train_text)

        dev_file = str(tmp_path / "dev.conllu")
        with open(dev_file, "w", encoding="utf-8") as fout:
            fout.write(dev_text)

        # two epochs is enough to verify the training loop runs end to end;
        # the dev file doubles as the gold file for scoring
        args = ["--train_file", train_file,
                "--eval_file", dev_file,
                "--gold_file", dev_file,
                "--output_file", pred_file,
                "--num_epoch", "2",
                "--log_step", "10",
                "--save_dir", str(tmp_path),
                "--save_name", save_name,
                "--shorthand", "en_test"]
        if extra_args is not None:
            args = args + extra_args
        lemmatizer.main(args)

        assert os.path.exists(save_file)
        saved_model = trainer.Trainer(model_file=save_file)
        return saved_model

    def test_basic_train(self, tmp_path):
        """
        Simple test of a few 'epochs' of lemmatizer training
        """
        self.run_training(tmp_path, TRAIN_DATA, DEV_DATA)

    def test_charlm_train(self, tmp_path, charlm_args):
        """
        Simple test of a few 'epochs' of lemmatizer training
        """
        saved_model = self.run_training(tmp_path, TRAIN_DATA, DEV_DATA, extra_args=charlm_args)

        # check that the charlm wasn't saved in here
        args = saved_model.args
        save_name = os.path.join(args['save_dir'], args['save_name'])
        checkpoint = torch.load(save_name, lambda storage, loc: storage, weights_only=True)
        assert not any(x.startswith("contextual_embedding") for x in checkpoint['model'].keys())
|
stanza/stanza/tests/lemma_classifier/test_data_preparation.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
import stanza.models.lemma_classifier.utils as utils
|
| 6 |
+
import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier
|
| 7 |
+
|
| 8 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 9 |
+
|
| 10 |
+
EWT_ONE_SENTENCE = """
|
| 11 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002
|
| 12 |
+
# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002
|
| 13 |
+
# text = Here's a Miami Herald interview
|
| 14 |
+
1-2 Here's _ _ _ _ _ _ _ _
|
| 15 |
+
1 Here here ADV RB PronType=Dem 0 root 0:root _
|
| 16 |
+
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _
|
| 17 |
+
3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
|
| 18 |
+
4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _
|
| 19 |
+
5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _
|
| 20 |
+
6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _
|
| 21 |
+
""".lstrip()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
EWT_TRAIN_SENTENCES = """
|
| 25 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002
|
| 26 |
+
# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002
|
| 27 |
+
# text = Here's a Miami Herald interview
|
| 28 |
+
1-2 Here's _ _ _ _ _ _ _ _
|
| 29 |
+
1 Here here ADV RB PronType=Dem 0 root 0:root _
|
| 30 |
+
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _
|
| 31 |
+
3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
|
| 32 |
+
4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _
|
| 33 |
+
5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _
|
| 34 |
+
6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _
|
| 35 |
+
|
| 36 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0027
|
| 37 |
+
# text = But Posada's nearly 80 years old
|
| 38 |
+
1 But but CCONJ CC _ 7 cc 7:cc _
|
| 39 |
+
2-3 Posada's _ _ _ _ _ _ _ _
|
| 40 |
+
2 Posada Posada PROPN NNP Number=Sing 7 nsubj 7:nsubj _
|
| 41 |
+
3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 cop 7:cop _
|
| 42 |
+
4 nearly nearly ADV RB _ 5 advmod 5:advmod _
|
| 43 |
+
5 80 80 NUM CD NumForm=Digit|NumType=Card 6 nummod 6:nummod _
|
| 44 |
+
6 years year NOUN NNS Number=Plur 7 obl:npmod 7:obl:npmod _
|
| 45 |
+
7 old old ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No
|
| 46 |
+
|
| 47 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0067
|
| 48 |
+
# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0011
|
| 49 |
+
# text = Now that's a post I can relate to.
|
| 50 |
+
1 Now now ADV RB _ 5 advmod 5:advmod _
|
| 51 |
+
2-3 that's _ _ _ _ _ _ _ _
|
| 52 |
+
2 that that PRON DT Number=Sing|PronType=Dem 5 nsubj 5:nsubj _
|
| 53 |
+
3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
|
| 54 |
+
4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
|
| 55 |
+
5 post post NOUN NN Number=Sing 0 root 0:root _
|
| 56 |
+
6 I I PRON PRP Case=Nom|Number=Sing|Person=1|PronType=Prs 8 nsubj 8:nsubj _
|
| 57 |
+
7 can can AUX MD VerbForm=Fin 8 aux 8:aux _
|
| 58 |
+
8 relate relate VERB VB VerbForm=Inf 5 acl:relcl 5:acl:relcl _
|
| 59 |
+
9 to to ADP IN _ 8 obl 8:obl SpaceAfter=No
|
| 60 |
+
10 . . PUNCT . _ 5 punct 5:punct _
|
| 61 |
+
|
| 62 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0073
|
| 63 |
+
# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0012
|
| 64 |
+
# text = hey that's a great blog
|
| 65 |
+
1 hey hey INTJ UH _ 6 discourse 6:discourse _
|
| 66 |
+
2-3 that's _ _ _ _ _ _ _ _
|
| 67 |
+
2 that that PRON DT Number=Sing|PronType=Dem 6 nsubj 6:nsubj _
|
| 68 |
+
3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
|
| 69 |
+
4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
|
| 70 |
+
5 great great ADJ JJ Degree=Pos 6 amod 6:amod _
|
| 71 |
+
6 blog blog NOUN NN Number=Sing 0 root 0:root SpaceAfter=No
|
| 72 |
+
|
| 73 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0089
|
| 74 |
+
# text = And It's Not Hard To Do
|
| 75 |
+
1 And and CCONJ CC _ 5 cc 5:cc _
|
| 76 |
+
2-3 It's _ _ _ _ _ _ _ _
|
| 77 |
+
2 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 expl 5:expl _
|
| 78 |
+
3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
|
| 79 |
+
4 Not not PART RB _ 5 advmod 5:advmod _
|
| 80 |
+
5 Hard hard ADJ JJ Degree=Pos 0 root 0:root _
|
| 81 |
+
6 To to PART TO _ 7 mark 7:mark _
|
| 82 |
+
7 Do do VERB VB VerbForm=Inf 5 csubj 5:csubj SpaceAfter=No
|
| 83 |
+
|
| 84 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0029
|
| 85 |
+
# text = Meanwhile, a decision's been reached
|
| 86 |
+
1 Meanwhile meanwhile ADV RB _ 7 advmod 7:advmod SpaceAfter=No
|
| 87 |
+
2 , , PUNCT , _ 1 punct 1:punct _
|
| 88 |
+
3 a a DET DT Definite=Ind|PronType=Art 4 det 4:det _
|
| 89 |
+
4-5 decision's _ _ _ _ _ _ _ _
|
| 90 |
+
4 decision decision NOUN NN Number=Sing 7 nsubj:pass 7:nsubj:pass _
|
| 91 |
+
5 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux 7:aux _
|
| 92 |
+
6 been be AUX VBN Tense=Past|VerbForm=Part 7 aux:pass 7:aux:pass _
|
| 93 |
+
7 reached reach VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
|
| 94 |
+
|
| 95 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0138
|
| 96 |
+
# text = It's become a guardian of morality
|
| 97 |
+
1-2 It's _ _ _ _ _ _ _ _
|
| 98 |
+
1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|5:nsubj:xsubj _
|
| 99 |
+
2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
|
| 100 |
+
3 become become VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
|
| 101 |
+
4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
|
| 102 |
+
5 guardian guardian NOUN NN Number=Sing 3 xcomp 3:xcomp _
|
| 103 |
+
6 of of ADP IN _ 7 case 7:case _
|
| 104 |
+
7 morality morality NOUN NN Number=Sing 5 nmod 5:nmod:of _
|
| 105 |
+
|
| 106 |
+
# sent_id = email-enronsent15_01-0018
|
| 107 |
+
# text = It's got its own bathroom and tv
|
| 108 |
+
1-2 It's _ _ _ _ _ _ _ _
|
| 109 |
+
1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|13:nsubj _
|
| 110 |
+
2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _
|
| 111 |
+
3 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
|
| 112 |
+
4 its its PRON PRP$ Case=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs 6 nmod:poss 6:nmod:poss _
|
| 113 |
+
5 own own ADJ JJ Degree=Pos 6 amod 6:amod _
|
| 114 |
+
6 bathroom bathroom NOUN NN Number=Sing 3 obj 3:obj _
|
| 115 |
+
7 and and CCONJ CC _ 8 cc 8:cc _
|
| 116 |
+
8 tv TV NOUN NN Number=Sing 6 conj 3:obj|6:conj:and SpaceAfter=No
|
| 117 |
+
|
| 118 |
+
# sent_id = newsgroup-groups.google.com_alt.animals.cat_01ff709c4bf2c60c_ENG_20040418_040100-0022
|
| 119 |
+
# text = It's also got the website
|
| 120 |
+
1-2 It's _ _ _ _ _ _ _ _
|
| 121 |
+
1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _
|
| 122 |
+
2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _
|
| 123 |
+
3 also also ADV RB _ 4 advmod 4:advmod _
|
| 124 |
+
4 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
|
| 125 |
+
5 the the DET DT Definite=Def|PronType=Art 6 det 6:det _
|
| 126 |
+
6 website website NOUN NN Number=Sing 4 obj 4:obj|12:obl _
|
| 127 |
+
""".lstrip()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# from the train set, actually
|
| 131 |
+
EWT_DEV_SENTENCES = """
|
| 132 |
+
# sent_id = answers-20111108104724AAuBUR7_ans-0044
|
| 133 |
+
# text = He's only exhibited weight loss and some muscle atrophy
|
| 134 |
+
1-2 He's _ _ _ _ _ _ _ _
|
| 135 |
+
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _
|
| 136 |
+
2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _
|
| 137 |
+
3 only only ADV RB _ 4 advmod 4:advmod _
|
| 138 |
+
4 exhibited exhibit VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _
|
| 139 |
+
5 weight weight NOUN NN Number=Sing 6 compound 6:compound _
|
| 140 |
+
6 loss loss NOUN NN Number=Sing 4 obj 4:obj _
|
| 141 |
+
7 and and CCONJ CC _ 10 cc 10:cc _
|
| 142 |
+
8 some some DET DT PronType=Ind 10 det 10:det _
|
| 143 |
+
9 muscle muscle NOUN NN Number=Sing 10 compound 10:compound _
|
| 144 |
+
10 atrophy atrophy NOUN NN Number=Sing 6 conj 4:obj|6:conj:and SpaceAfter=No
|
| 145 |
+
|
| 146 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0097
|
| 147 |
+
# text = It's a good thing too.
|
| 148 |
+
1-2 It's _ _ _ _ _ _ _ _
|
| 149 |
+
1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _
|
| 150 |
+
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _
|
| 151 |
+
3 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _
|
| 152 |
+
4 good good ADJ JJ Degree=Pos 5 amod 5:amod _
|
| 153 |
+
5 thing thing NOUN NN Number=Sing 0 root 0:root _
|
| 154 |
+
6 too too ADV RB _ 5 advmod 5:advmod SpaceAfter=No
|
| 155 |
+
7 . . PUNCT . _ 5 punct 5:punct _
|
| 156 |
+
""".lstrip()
|
| 157 |
+
|
| 158 |
+
# from the train set, actually
|
| 159 |
+
EWT_TEST_SENTENCES = """
|
| 160 |
+
# sent_id = reviews-162422-0015
|
| 161 |
+
# text = He said he's had a long and bad day.
|
| 162 |
+
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 2 nsubj 2:nsubj _
|
| 163 |
+
2 said say VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _
|
| 164 |
+
3-4 he's _ _ _ _ _ _ _ _
|
| 165 |
+
3 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _
|
| 166 |
+
4 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux 5:aux _
|
| 167 |
+
5 had have VERB VBN Tense=Past|VerbForm=Part 2 ccomp 2:ccomp _
|
| 168 |
+
6 a a DET DT Definite=Ind|PronType=Art 10 det 10:det _
|
| 169 |
+
7 long long ADJ JJ Degree=Pos 10 amod 10:amod _
|
| 170 |
+
8 and and CCONJ CC _ 9 cc 9:cc _
|
| 171 |
+
9 bad bad ADJ JJ Degree=Pos 7 conj 7:conj:and|10:amod _
|
| 172 |
+
10 day day NOUN NN Number=Sing 5 obj 5:obj SpaceAfter=No
|
| 173 |
+
11 . . PUNCT . _ 2 punct 2:punct _
|
| 174 |
+
|
| 175 |
+
# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0100
|
| 176 |
+
# text = What's a few dead soldiers
|
| 177 |
+
1-2 What's _ _ _ _ _ _ _ _
|
| 178 |
+
1 What what PRON WP PronType=Int 6 nsubj 6:nsubj _
|
| 179 |
+
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
|
| 180 |
+
3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _
|
| 181 |
+
4 few few ADJ JJ Degree=Pos 6 amod 6:amod _
|
| 182 |
+
5 dead dead ADJ JJ Degree=Pos 6 amod 6:amod _
|
| 183 |
+
6 soldiers soldier NOUN NNS Number=Plur 0 root 0:root _
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
def write_test_dataset(tmp_path, texts, datasets):
|
| 187 |
+
ud_path = tmp_path / "ud"
|
| 188 |
+
input_path = ud_path / "UD_English-EWT"
|
| 189 |
+
output_path = tmp_path / "data" / "lemma_classifier"
|
| 190 |
+
|
| 191 |
+
os.makedirs(input_path, exist_ok=True)
|
| 192 |
+
|
| 193 |
+
for text, dataset in zip(texts, datasets):
|
| 194 |
+
sample_file = input_path / ("en_ewt-ud-%s.conllu" % dataset)
|
| 195 |
+
with open(sample_file, "w", encoding="utf-8") as fout:
|
| 196 |
+
fout.write(text)
|
| 197 |
+
|
| 198 |
+
paths = {"UDBASE": ud_path,
|
| 199 |
+
"LEMMA_CLASSIFIER_DATA_DIR": output_path}
|
| 200 |
+
|
| 201 |
+
return paths
|
| 202 |
+
|
| 203 |
+
def write_english_test_dataset(tmp_path):
|
| 204 |
+
texts = (EWT_TRAIN_SENTENCES, EWT_DEV_SENTENCES, EWT_TEST_SENTENCES)
|
| 205 |
+
datasets = prepare_lemma_classifier.SECTIONS
|
| 206 |
+
return write_test_dataset(tmp_path, texts, datasets)
|
| 207 |
+
|
| 208 |
+
def convert_english_dataset(tmp_path):
|
| 209 |
+
paths = write_english_test_dataset(tmp_path)
|
| 210 |
+
converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have")
|
| 211 |
+
assert len(converted_files) == 3
|
| 212 |
+
|
| 213 |
+
return converted_files
|
| 214 |
+
|
| 215 |
+
def test_convert_one_sentence(tmp_path):
|
| 216 |
+
texts = [EWT_ONE_SENTENCE]
|
| 217 |
+
datasets = ["train"]
|
| 218 |
+
paths = write_test_dataset(tmp_path, texts, datasets)
|
| 219 |
+
|
| 220 |
+
converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have", ["train"])
|
| 221 |
+
assert len(converted_files) == 1
|
| 222 |
+
|
| 223 |
+
dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)
|
| 224 |
+
|
| 225 |
+
assert len(dataset) == 1
|
| 226 |
+
assert dataset.label_decoder == {'be': 0}
|
| 227 |
+
id_to_upos = {y: x for x, y in dataset.upos_to_id.items()}
|
| 228 |
+
|
| 229 |
+
for text_batches, _, upos_batches, _ in dataset:
|
| 230 |
+
assert text_batches == [['Here', "'s", 'a', 'Miami', 'Herald', 'interview']]
|
| 231 |
+
upos = [id_to_upos[x] for x in upos_batches[0]]
|
| 232 |
+
assert upos == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN']
|
| 233 |
+
|
| 234 |
+
def test_convert_dataset(tmp_path):
|
| 235 |
+
converted_files = convert_english_dataset(tmp_path)
|
| 236 |
+
|
| 237 |
+
dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False)
|
| 238 |
+
|
| 239 |
+
assert len(dataset) == 1
|
| 240 |
+
label_decoder = dataset.label_decoder
|
| 241 |
+
assert len(label_decoder) == 2
|
| 242 |
+
assert "be" in label_decoder
|
| 243 |
+
assert "have" in label_decoder
|
| 244 |
+
for text_batches, _, _, _ in dataset:
|
| 245 |
+
assert len(text_batches) == 9
|
| 246 |
+
|
| 247 |
+
dataset = utils.Dataset(converted_files[1], get_counts=True, batch_size=10, shuffle=False)
|
| 248 |
+
assert len(dataset) == 1
|
| 249 |
+
for text_batches, _, _, _ in dataset:
|
| 250 |
+
assert len(text_batches) == 2
|
| 251 |
+
|
| 252 |
+
dataset = utils.Dataset(converted_files[2], get_counts=True, batch_size=10, shuffle=False)
|
| 253 |
+
assert len(dataset) == 1
|
| 254 |
+
for text_batches, _, _, _ in dataset:
|
| 255 |
+
assert len(text_batches) == 2
|
| 256 |
+
|
stanza/stanza/tests/mwt/test_character_classifier.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pytest
|
| 3 |
+
|
| 4 |
+
from stanza.models import mwt_expander
|
| 5 |
+
from stanza.models.mwt.character_classifier import CharacterClassifier
|
| 6 |
+
from stanza.models.mwt.data import DataLoader
|
| 7 |
+
from stanza.models.mwt.trainer import Trainer
|
| 8 |
+
from stanza.utils.conll import CoNLL
|
| 9 |
+
|
| 10 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 11 |
+
|
| 12 |
+
ENG_TRAIN = """
|
| 13 |
+
# text = Elena's motorcycle tour
|
| 14 |
+
1-2 Elena's _ _ _ _ _ _ _ _
|
| 15 |
+
1 Elena Elena PROPN NNP Number=Sing 4 nmod:poss 4:nmod:poss _
|
| 16 |
+
2 's 's PART POS _ 1 case 1:case _
|
| 17 |
+
3 motorcycle motorcycle NOUN NN Number=Sing 4 compound 4:compound _
|
| 18 |
+
4 tour tour NOUN NN Number=Sing 0 root 0:root _
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# text = women's reproductive health
|
| 22 |
+
1-2 women's _ _ _ _ _ _ _ _
|
| 23 |
+
1 women woman NOUN NNS Number=Plur 4 nmod:poss 4:nmod:poss _
|
| 24 |
+
2 's 's PART POS _ 1 case 1:case _
|
| 25 |
+
3 reproductive reproductive ADJ JJ Degree=Pos 4 amod 4:amod _
|
| 26 |
+
4 health health NOUN NN Number=Sing 0 root 0:root SpaceAfter=No
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# text = The Chernobyl Children's Project
|
| 30 |
+
1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _
|
| 31 |
+
2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _
|
| 32 |
+
3-4 Children's _ _ _ _ _ _ _ _
|
| 33 |
+
3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _
|
| 34 |
+
4 's 's PART POS _ 3 case 3:case _
|
| 35 |
+
5 Project Project PROPN NNP Number=Sing 0 root 0:root _
|
| 36 |
+
|
| 37 |
+
""".lstrip()
|
| 38 |
+
|
| 39 |
+
ENG_DEV = """
|
| 40 |
+
# text = The Chernobyl Children's Project
|
| 41 |
+
1 The the DET DT Definite=Def|PronType=Art 3 det 3:det _
|
| 42 |
+
2 Chernobyl Chernobyl PROPN NNP Number=Sing 3 compound 3:compound _
|
| 43 |
+
3-4 Children's _ _ _ _ _ _ _ _
|
| 44 |
+
3 Children Children PROPN NNP Number=Sing 5 nmod:poss 5:nmod:poss _
|
| 45 |
+
4 's 's PART POS _ 3 case 3:case _
|
| 46 |
+
5 Project Project PROPN NNP Number=Sing 0 root 0:root _
|
| 47 |
+
|
| 48 |
+
""".lstrip()
|
| 49 |
+
|
| 50 |
+
def test_train(tmp_path):
|
| 51 |
+
test_train = str(os.path.join(tmp_path, "en_test.train.conllu"))
|
| 52 |
+
with open(test_train, "w") as fout:
|
| 53 |
+
fout.write(ENG_TRAIN)
|
| 54 |
+
|
| 55 |
+
test_dev = str(os.path.join(tmp_path, "en_test.dev.conllu"))
|
| 56 |
+
with open(test_dev, "w") as fout:
|
| 57 |
+
fout.write(ENG_DEV)
|
| 58 |
+
|
| 59 |
+
test_output = str(os.path.join(tmp_path, "en_test.dev.pred.conllu"))
|
| 60 |
+
model_name = "en_test_mwt.pt"
|
| 61 |
+
|
| 62 |
+
args = [
|
| 63 |
+
"--data_dir", str(tmp_path),
|
| 64 |
+
"--train_file", test_train,
|
| 65 |
+
"--eval_file", test_dev,
|
| 66 |
+
"--gold_file", test_dev,
|
| 67 |
+
"--lang", "en",
|
| 68 |
+
"--shorthand", "en_test",
|
| 69 |
+
"--output_file", test_output,
|
| 70 |
+
"--save_dir", str(tmp_path),
|
| 71 |
+
"--save_name", model_name,
|
| 72 |
+
"--num_epoch", "10",
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
mwt_expander.main(args=args)
|
| 76 |
+
|
| 77 |
+
model = Trainer(model_file=os.path.join(tmp_path, model_name))
|
| 78 |
+
assert model.model is not None
|
| 79 |
+
assert isinstance(model.model, CharacterClassifier)
|
| 80 |
+
|
| 81 |
+
doc = CoNLL.conll2doc(input_str=ENG_DEV)
|
| 82 |
+
dataloader = DataLoader(doc, 10, model.args, vocab=model.vocab, evaluation=True, expand_unk_vocab=True)
|
| 83 |
+
preds = []
|
| 84 |
+
for i, batch in enumerate(dataloader.to_loader()):
|
| 85 |
+
assert i == 0 # there should only be one batch
|
| 86 |
+
preds += model.predict(batch, never_decode_unk=True, vocab=dataloader.vocab)
|
| 87 |
+
assert len(preds) == 1
|
| 88 |
+
# it is possible to make a version of the test where this happens almost every time
|
| 89 |
+
# for example, running for 100 epochs makes the model succeed 30 times in a row
|
| 90 |
+
# (never saw a failure)
|
| 91 |
+
# but the one time that failure happened, it would be really annoying
|
| 92 |
+
#assert preds[0] == "Children 's"
|
stanza/stanza/tests/mwt/test_english_corner_cases.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a couple English MWT corner cases which might be more widely applicable to other MWT languages
|
| 3 |
+
|
| 4 |
+
- unknown English character doesn't result in bizarre splits
|
| 5 |
+
- Casing or CASING doesn't get lost in the dictionary lookup
|
| 6 |
+
|
| 7 |
+
In the English UD datasets, the MWT are composed exactly of the
|
| 8 |
+
subwords, so the MWT model should be chopping up the input text rather
|
| 9 |
+
than generating new text.
|
| 10 |
+
|
| 11 |
+
Furthermore, SHE'S and She's should be split "SHE 'S" and "She 's" respectively
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import pytest
|
| 15 |
+
import stanza
|
| 16 |
+
|
| 17 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 18 |
+
|
| 19 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 20 |
+
|
| 21 |
+
def test_mwt_unknown_char():
|
| 22 |
+
pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
|
| 23 |
+
|
| 24 |
+
mwt_trainer = pipeline.processors['mwt']._trainer
|
| 25 |
+
|
| 26 |
+
assert mwt_trainer.args['force_exact_pieces']
|
| 27 |
+
|
| 28 |
+
# find a letter 'i' which isn't in the training data
|
| 29 |
+
# the MWT model should still recognize a possessive containing this letter
|
| 30 |
+
assert "i" in mwt_trainer.vocab
|
| 31 |
+
for letter in "ĩîíìī":
|
| 32 |
+
if letter not in mwt_trainer.vocab:
|
| 33 |
+
break
|
| 34 |
+
else:
|
| 35 |
+
raise AssertionError("Need to update the MWT test - all of the non-standard letters 'i' are now in the MWT vocab")
|
| 36 |
+
|
| 37 |
+
word = "Jenn" + letter + "fer"
|
| 38 |
+
possessive = word + "'s"
|
| 39 |
+
text = "I wanna lick " + possessive + " antennae"
|
| 40 |
+
doc = pipeline(text)
|
| 41 |
+
assert doc.sentences[0].tokens[1].text == 'wanna'
|
| 42 |
+
assert len(doc.sentences[0].tokens[1].words) == 2
|
| 43 |
+
assert "".join(x.text for x in doc.sentences[0].tokens[1].words) == 'wanna'
|
| 44 |
+
|
| 45 |
+
assert doc.sentences[0].tokens[3].text == possessive
|
| 46 |
+
assert len(doc.sentences[0].tokens[3].words) == 2
|
| 47 |
+
assert "".join(x.text for x in doc.sentences[0].tokens[3].words) == possessive
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_english_mwt_casing():
|
| 51 |
+
"""
|
| 52 |
+
Test that for a word where the lowercase split is known, the correct casing is still used
|
| 53 |
+
|
| 54 |
+
Once upon a time, the logic used in the MWT expander would split
|
| 55 |
+
SHE'S -> she 's
|
| 56 |
+
|
| 57 |
+
which is a very surprising tokenization to people expecting
|
| 58 |
+
the original text in the output document
|
| 59 |
+
"""
|
| 60 |
+
pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
|
| 61 |
+
|
| 62 |
+
mwt_trainer = pipeline.processors['mwt']._trainer
|
| 63 |
+
for i in range(1, 20):
|
| 64 |
+
# many test cases follow this pattern for some reason,
|
| 65 |
+
# so we should proactively look for a test case which hasn't
|
| 66 |
+
# made its way into the MWT dictionary
|
| 67 |
+
unknown_name = "jennife" + "r" * i + "'s"
|
| 68 |
+
if unknown_name not in mwt_trainer.expansion_dict and unknown_name.upper() not in mwt_trainer.expansion_dict:
|
| 69 |
+
unknown_name = unknown_name.upper()
|
| 70 |
+
break
|
| 71 |
+
else:
|
| 72 |
+
raise AssertionError("Need a new heuristic for the unknown word in the English MWT!")
|
| 73 |
+
|
| 74 |
+
# this SHOULD show up in the expansion dict
|
| 75 |
+
assert "she's" in mwt_trainer.expansion_dict, "Expected |she's| to be in the English MWT expansion dict... perhaps find a different test case"
|
| 76 |
+
|
| 77 |
+
text = [x.text for x in pipeline("JENNIFER HAS NICE ANTENNAE").sentences[0].words]
|
| 78 |
+
assert text == ['JENNIFER', 'HAS', 'NICE', 'ANTENNAE']
|
| 79 |
+
|
| 80 |
+
text = [x.text for x in pipeline(unknown_name + " GOT NICE ANTENNAE").sentences[0].words]
|
| 81 |
+
assert text == [unknown_name[:-2], "'S", 'GOT', 'NICE', 'ANTENNAE']
|
| 82 |
+
|
| 83 |
+
text = [x.text for x in pipeline("SHE'S GOT NICE ANTENNAE").sentences[0].words]
|
| 84 |
+
assert text == ['SHE', "'S", 'GOT', 'NICE', 'ANTENNAE']
|
| 85 |
+
|
| 86 |
+
text = [x.text for x in pipeline("She's GOT NICE ANTENNAE").sentences[0].words]
|
| 87 |
+
assert text == ['She', "'s", 'GOT', 'NICE', 'ANTENNAE']
|
| 88 |
+
|
stanza/stanza/tests/ner/test_bsf_2_iob.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests the conversion code for the lang_uk NER dataset
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import unittest
|
| 6 |
+
from stanza.utils.datasets.ner.convert_bsf_to_beios import convert_bsf, parse_bsf, BsfInfo
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 10 |
+
|
| 11 |
+
class TestBsf2Iob(unittest.TestCase):
|
| 12 |
+
|
| 13 |
+
def test_1line_follow_markup_iob(self):
|
| 14 |
+
data = 'тележурналіст Василь .'
|
| 15 |
+
bsf_markup = 'T1 PERS 14 20 Василь'
|
| 16 |
+
expected = '''тележурналіст O
|
| 17 |
+
Василь B-PERS
|
| 18 |
+
. O'''
|
| 19 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
|
| 20 |
+
|
| 21 |
+
def test_1line_2tok_markup_iob(self):
|
| 22 |
+
data = 'тележурналіст Василь Нагірний .'
|
| 23 |
+
bsf_markup = 'T1 PERS 14 29 Василь Нагірний'
|
| 24 |
+
expected = '''тележурналіст O
|
| 25 |
+
Василь B-PERS
|
| 26 |
+
Нагірний I-PERS
|
| 27 |
+
. O'''
|
| 28 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
|
| 29 |
+
|
| 30 |
+
def test_1line_Long_tok_markup_iob(self):
|
| 31 |
+
data = 'А в музеї Гуцульщини і Покуття можна '
|
| 32 |
+
bsf_markup = 'T12 ORG 4 30 музеї Гуцульщини і Покуття'
|
| 33 |
+
expected = '''А O
|
| 34 |
+
в O
|
| 35 |
+
музеї B-ORG
|
| 36 |
+
Гуцульщини I-ORG
|
| 37 |
+
і I-ORG
|
| 38 |
+
Покуття I-ORG
|
| 39 |
+
можна O'''
|
| 40 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
|
| 41 |
+
|
| 42 |
+
def test_2line_2tok_markup_iob(self):
|
| 43 |
+
data = '''тележурналіст Василь Нагірний .
|
| 44 |
+
В івано-франківському видавництві «Лілея НВ» вийшла друком'''
|
| 45 |
+
bsf_markup = '''T1 PERS 14 29 Василь Нагірний
|
| 46 |
+
T2 ORG 67 75 Лілея НВ'''
|
| 47 |
+
expected = '''тележурналіст O
|
| 48 |
+
Василь B-PERS
|
| 49 |
+
Нагірний I-PERS
|
| 50 |
+
. O
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
В O
|
| 54 |
+
івано-франківському O
|
| 55 |
+
видавництві O
|
| 56 |
+
« O
|
| 57 |
+
Лілея B-ORG
|
| 58 |
+
НВ I-ORG
|
| 59 |
+
» O
|
| 60 |
+
вийшла O
|
| 61 |
+
друком O'''
|
| 62 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
|
| 63 |
+
|
| 64 |
+
def test_all_multiline_iob(self):
|
| 65 |
+
data = '''його книжечка «А .
|
| 66 |
+
Kubler .
|
| 67 |
+
Світло і тіні маестро» .
|
| 68 |
+
Причому'''
|
| 69 |
+
bsf_markup = '''T4 MISC 15 49 А .
|
| 70 |
+
Kubler .
|
| 71 |
+
Світло і тіні маестро
|
| 72 |
+
'''
|
| 73 |
+
expected = '''його O
|
| 74 |
+
книжечка O
|
| 75 |
+
« O
|
| 76 |
+
А B-MISC
|
| 77 |
+
. I-MISC
|
| 78 |
+
Kubler I-MISC
|
| 79 |
+
. I-MISC
|
| 80 |
+
Світло I-MISC
|
| 81 |
+
і I-MISC
|
| 82 |
+
тіні I-MISC
|
| 83 |
+
маестро I-MISC
|
| 84 |
+
» O
|
| 85 |
+
. O
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
Причому O'''
|
| 89 |
+
self.assertEqual(expected, convert_bsf(data, bsf_markup, converter='iob'))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == '__main__':
|
| 93 |
+
unittest.main()
|
stanza/stanza/tests/ner/test_convert_amt.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test some of the functions used for converting an AMT json to a Stanza json
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
import stanza
|
| 11 |
+
from stanza.utils.datasets.ner import convert_amt
|
| 12 |
+
|
| 13 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 14 |
+
|
| 15 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 16 |
+
|
| 17 |
+
TEXT = "Jennifer Sh'reyan has lovely antennae."
|
| 18 |
+
|
| 19 |
+
def fake_label(label, start_char, end_char):
|
| 20 |
+
return {'label': label,
|
| 21 |
+
'startOffset': start_char,
|
| 22 |
+
'endOffset': end_char}
|
| 23 |
+
|
| 24 |
+
LABELS = [
|
| 25 |
+
fake_label('Person', 0, 8),
|
| 26 |
+
fake_label('Person', 9, 17),
|
| 27 |
+
fake_label('Person', 0, 17),
|
| 28 |
+
fake_label('Andorian', 0, 8),
|
| 29 |
+
fake_label('Appendage', 29, 37),
|
| 30 |
+
fake_label('Person', 1, 8),
|
| 31 |
+
fake_label('Person', 0, 7),
|
| 32 |
+
fake_label('Person', 0, 9),
|
| 33 |
+
fake_label('Appendage', 29, 38),
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
def fake_labels(*indices):
|
| 37 |
+
return [LABELS[x] for x in indices]
|
| 38 |
+
|
| 39 |
+
def fake_docs(*indices):
|
| 40 |
+
return [(TEXT, fake_labels(*indices))]
|
| 41 |
+
|
| 42 |
+
def test_remove_nesting():
|
| 43 |
+
"""
|
| 44 |
+
Test a few orders on nested items to make sure the desired results are coming back
|
| 45 |
+
"""
|
| 46 |
+
# this should be unchanged
|
| 47 |
+
result = convert_amt.remove_nesting(fake_docs(0, 1))
|
| 48 |
+
assert result == fake_docs(0, 1)
|
| 49 |
+
|
| 50 |
+
# this should be returned sorted
|
| 51 |
+
result = convert_amt.remove_nesting(fake_docs(0, 4, 1))
|
| 52 |
+
assert result == fake_docs(0, 1, 4)
|
| 53 |
+
|
| 54 |
+
# this should just have one copy
|
| 55 |
+
result = convert_amt.remove_nesting(fake_docs(0, 0))
|
| 56 |
+
assert result == fake_docs(0)
|
| 57 |
+
|
| 58 |
+
# outer one preferred
|
| 59 |
+
result = convert_amt.remove_nesting(fake_docs(0, 2))
|
| 60 |
+
assert result == fake_docs(2)
|
| 61 |
+
result = convert_amt.remove_nesting(fake_docs(1, 2))
|
| 62 |
+
assert result == fake_docs(2)
|
| 63 |
+
result = convert_amt.remove_nesting(fake_docs(5, 2))
|
| 64 |
+
assert result == fake_docs(2)
|
| 65 |
+
# order doesn't matter
|
| 66 |
+
result = convert_amt.remove_nesting(fake_docs(0, 4, 2))
|
| 67 |
+
assert result == fake_docs(2, 4)
|
| 68 |
+
result = convert_amt.remove_nesting(fake_docs(2, 4, 0))
|
| 69 |
+
assert result == fake_docs(2, 4)
|
| 70 |
+
|
| 71 |
+
# first one preferred
|
| 72 |
+
result = convert_amt.remove_nesting(fake_docs(0, 3))
|
| 73 |
+
assert result == fake_docs(0)
|
| 74 |
+
result = convert_amt.remove_nesting(fake_docs(3, 0))
|
| 75 |
+
assert result == fake_docs(3)
|
| 76 |
+
|
| 77 |
+
def test_process_doc():
|
| 78 |
+
nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
|
| 79 |
+
|
| 80 |
+
def check_results(doc, *expected):
|
| 81 |
+
ner = [x[1] for x in doc[0]]
|
| 82 |
+
assert ner == list(expected)
|
| 83 |
+
|
| 84 |
+
# test a standard case of all the values lining up
|
| 85 |
+
doc = convert_amt.process_doc(TEXT, fake_labels(2, 4), nlp)
|
| 86 |
+
check_results(doc, "B-Person", "I-Person", "O", "O", "B-Appendage", "O")
|
| 87 |
+
|
| 88 |
+
# test a slightly wrong start index
|
| 89 |
+
doc = convert_amt.process_doc(TEXT, fake_labels(5, 1, 4), nlp)
|
| 90 |
+
check_results(doc, "B-Person", "B-Person", "O", "O", "B-Appendage", "O")
|
| 91 |
+
|
| 92 |
+
# test a slightly wrong end index
|
| 93 |
+
doc = convert_amt.process_doc(TEXT, fake_labels(6, 1, 4), nlp)
|
| 94 |
+
check_results(doc, "B-Person", "B-Person", "O", "O", "B-Appendage", "O")
|
| 95 |
+
|
| 96 |
+
# test a slightly wronger end index
|
| 97 |
+
doc = convert_amt.process_doc(TEXT, fake_labels(7, 4), nlp)
|
| 98 |
+
check_results(doc, "B-Person", "O", "O", "O", "B-Appendage", "O")
|
| 99 |
+
|
| 100 |
+
# test a period at the end of a text - should not be captured
|
| 101 |
+
doc = convert_amt.process_doc(TEXT, fake_labels(7, 8), nlp)
|
| 102 |
+
check_results(doc, "B-Person", "O", "O", "O", "B-Appendage", "O")
|
| 103 |
+
|
| 104 |
+
|
stanza/stanza/tests/ner/test_convert_starlang_ner.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a couple different classes of trees to check the output of the Starlang conversion for NER
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from stanza.utils.datasets.ner import convert_starlang_ner
|
| 11 |
+
|
| 12 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 13 |
+
|
| 14 |
+
TREE="( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )"
|
| 15 |
+
|
| 16 |
+
def test_read_tree():
|
| 17 |
+
"""
|
| 18 |
+
Test a basic tree read
|
| 19 |
+
"""
|
| 20 |
+
sentence = convert_starlang_ner.read_tree(TREE)
|
| 21 |
+
expected = [('Bayan', 'PERSON'), ('Haag', 'PERSON'), ('Elianti', 'O'), ('çalar', 'O'), ('.', 'O')]
|
| 22 |
+
assert sentence == expected
|
| 23 |
+
|
stanza/stanza/tests/ner/test_from_conllu.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest

from stanza import Pipeline
from stanza.utils.conll import CoNLL
from stanza.tests import TEST_MODELS_DIR

pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

def test_from_conllu():
    """
    If the doc does not have the entire text available, make sure it still safely processes the text

    Test case supplied from user - see issue #1428
    """
    expected_ents = ['February', 'Seattle', 'Pritchett']

    nlp = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,ner", download_method=None)
    doc = nlp("In February, I traveled to Seattle. Dr. Pritchett gave me a new hip")
    # the default NER model ought to find these three
    assert [ent.text for ent in doc.ents] == expected_ents

    # round-trip through CoNLL-U, which does not carry the full document text
    doc = CoNLL.conll2doc(input_str="{:C}\n\n".format(doc))
    nlp = Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,ner", tokenize_pretokenized=True, download_method=None)
    nlp(doc)
    # this should still work when processed from a CoNLLu document;
    # the bug previously caused a crash because the text needed to
    # construct the entities was not available on the Document
    assert [ent.text for ent in doc.ents] == expected_ents
stanza/stanza/tests/ner/test_ner_utils.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest

from stanza.tests import *

from stanza.models.common.vocab import EMPTY
from stanza.models.ner import utils

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

# Parallel test data: three tokenized sentences plus tag sequences for the
# same sentences in several different tagging schemes.
WORDS = [["Unban", "Mox", "Opal"], ["Ragavan", "is", "red"], ["Urza", "Lord", "High", "Artificer", "goes", "infinite", "with", "Thopter", "Sword"]]
# standard BIO encoding of the entities in WORDS
BIO_TAGS = [["O", "B-ART", "I-ART"], ["B-MONKEY", "O", "B-COLOR"], ["B-PER", "I-PER", "I-PER", "I-PER", "O", "O", "O", "B-WEAPON", "B-WEAPON"]]
# the same BIO tags, but with underscores as the prefix separator
BIO_U_TAGS = [["O", "B_ART", "I_ART"], ["B_MONKEY", "O", "B_COLOR"], ["B_PER", "I_PER", "I_PER", "I_PER", "O", "O", "O", "B_WEAPON", "B_WEAPON"]]
# BIOES encoding equivalent to BIO_TAGS
BIOES_TAGS = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "S-WEAPON", "S-WEAPON"]]
# note the problem with not using BIO tags - the consecutive tags for thopter/sword get treated as one item
BASIC_TAGS = [["O", "ART", "ART"], ["MONKEY", "O", "COLOR"], [ "PER", "PER", "PER", "PER", "O", "O", "O", "WEAPON", "WEAPON"]]
# expected BIOES conversion of BASIC_TAGS (Thopter/Sword merge into one span)
BASIC_BIOES = [["O", "B-ART", "E-ART"], ["S-MONKEY", "O", "S-COLOR"], ["B-PER", "I-PER", "I-PER", "E-PER", "O", "O", "O", "B-WEAPON", "E-WEAPON"]]
# a second, independent BIO tag layer over the same WORDS
ALT_BIO = [["O", "B-MANA", "I-MANA"], ["B-CRE", "O", "O"], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
# BIOES equivalent of ALT_BIO
ALT_BIOES = [["O", "B-MANA", "E-MANA"], ["S-CRE", "O", "O"], ["B-CRE", "I-CRE", "I-CRE", "E-CRE", "O", "O", "O", "S-ART", "S-ART"]]
# like ALT_BIO, but with the middle sentence entirely untagged (None)
NONE_BIO = [["O", "B-MANA", "I-MANA"], [None, None, None], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
# BIOES equivalent of NONE_BIO, keeping the None sentence as-is
NONE_BIOES = [["O", "B-MANA", "E-MANA"], [None, None, None], ["B-CRE", "I-CRE", "I-CRE", "E-CRE", "O", "O", "O", "S-ART", "S-ART"]]
# NONE_BIO with the None tags replaced by the vocab EMPTY marker
EMPTY_BIO = [["O", "B-MANA", "I-MANA"], [EMPTY, EMPTY, EMPTY], ["B-CRE", "I-CRE", "I-CRE", "I-CRE", "O", "O", "O", "B-ART", "B-ART"]]
def test_normalize_empty_tags():
    """normalize_empty_tags should turn None tags into the EMPTY marker."""
    def pair_up(tag_lists):
        # build the [(word, (tag,)), ...] per-sentence structure the util expects
        return [[(word, (tag,)) for word, tag in zip(words, tags)]
                for words, tags in zip(WORDS, tag_lists)]

    assert utils.normalize_empty_tags(pair_up(NONE_BIO)) == pair_up(EMPTY_BIO)

def check_reprocessed_tags(words, input_tags, expected_tags):
    """Run process_tags with the bioes scheme and compare against expected_tags."""
    def to_sentences(tag_lists):
        return [list(zip(sent_words, sent_tags)) for sent_words, sent_tags in zip(words, tag_lists)]

    retagged = utils.process_tags(sentences=to_sentences(input_tags), scheme="bioes")
    # process_tags selectively returns tuples or strings based on the input
    # so we don't need to fiddle with the expected output format here
    assert retagged == to_sentences(expected_tags)
def test_process_tags_bio():
    """BIO inputs should convert cleanly to BIOES."""
    check_reprocessed_tags(WORDS, BIO_TAGS, BIOES_TAGS)
    # check that the alternate version is correct as well
    # that way we can independently check the two layer version
    check_reprocessed_tags(WORDS, ALT_BIO, ALT_BIOES)

def test_process_tags_with_none():
    # if there is a block of tags with None in them, the Nones should be skipped over
    check_reprocessed_tags(WORDS, NONE_BIO, NONE_BIOES)
def merge_tags(*tags):
    """Zip several parallel tag layers into one layer of per-token tuples.

    Each argument is a list of per-sentence tag lists of identical shape;
    the result holds, at every token position, a tuple such as
    ("O", "O"), ("B-ART", "B-MANA"), ...
    """
    combined = []
    for sentence_group in zip(*tags):  # one set of parallel sentences at a time
        combined.append([tuple(token_tags) for token_tags in zip(*sentence_group)])
    return combined
def test_combined_tags_bio():
    """Two stacked BIO layers should both be converted to BIOES."""
    stacked_input = merge_tags(BIO_TAGS, ALT_BIO)
    stacked_expected = merge_tags(BIOES_TAGS, ALT_BIOES)
    check_reprocessed_tags(WORDS, stacked_input, stacked_expected)

def test_combined_tags_mixed():
    """A BIO layer stacked with an already-BIOES layer should also convert."""
    stacked_input = merge_tags(BIO_TAGS, ALT_BIOES)
    stacked_expected = merge_tags(BIOES_TAGS, ALT_BIOES)
    check_reprocessed_tags(WORDS, stacked_input, stacked_expected)

def test_process_tags_basic():
    """Basic (prefix-free) tags should expand to BIOES spans."""
    check_reprocessed_tags(WORDS, BASIC_TAGS, BASIC_BIOES)

def test_process_tags_bioes():
    """
    This one should not change, naturally
    """
    check_reprocessed_tags(WORDS, BIOES_TAGS, BIOES_TAGS)
    check_reprocessed_tags(WORDS, BASIC_BIOES, BASIC_BIOES)
def run_flattened(fn, tags):
    """Apply fn to the flattened (single-list) form of the nested tags.

    Bug fix: the original comprehension had its clauses reversed,
    ``[x for x in y for y in tags]``, which raises NameError because
    ``y`` is unbound in the first clause.  The loops must run outer
    (``for y in tags``) before inner (``for x in y``).
    """
    return fn([x for y in tags for x in y])
def test_check_bio():
    """Only a genuine BIO tag stream should be detected as the bio scheme."""
    def flat(tag_lists):
        return [tag for sentence in tag_lists for tag in sentence]

    assert utils.is_bio_scheme(flat(BIO_TAGS))
    assert not utils.is_bio_scheme(flat(BIOES_TAGS))
    assert not utils.is_bio_scheme(flat(BASIC_TAGS))
    assert not utils.is_bio_scheme(flat(BASIC_BIOES))

def test_check_basic():
    """Only a prefix-free tag stream should be detected as the basic scheme."""
    def flat(tag_lists):
        return [tag for sentence in tag_lists for tag in sentence]

    assert not utils.is_basic_scheme(flat(BIO_TAGS))
    assert not utils.is_basic_scheme(flat(BIOES_TAGS))
    assert utils.is_basic_scheme(flat(BASIC_TAGS))
    assert not utils.is_basic_scheme(flat(BASIC_BIOES))

def test_underscores():
    """
    Check that the methods work if the inputs are underscores instead of dashes
    """
    assert not utils.is_basic_scheme([tag for sentence in BIO_U_TAGS for tag in sentence])
    check_reprocessed_tags(WORDS, BIO_U_TAGS, BIOES_TAGS)
def test_merge_tags():
    """
    Check a few versions of the tag sequence merging
    """
    seq1 = ["O", "O", "O", "B-FOO", "E-FOO", "O"]
    seq2 = ["S-FOO", "O", "B-FOO", "E-FOO", "O", "O"]
    seq3 = ["B-FOO", "E-FOO", "B-FOO", "E-FOO", "O", "O"]
    # sequences expected to raise ValueError when merged with seq1
    error_seqs = [
        ["O", "B-FOO", "O", "B-FOO", "E-FOO", "O"],
        ["O", "B-FOO", "O", "B-FOO", "B-FOO", "O"],
        ["O", "B-FOO", "O", "B-FOO", "I-FOO", "O"],
        ["O", "B-FOO", "O", "B-FOO", "I-FOO", "I-FOO"],
    ]

    assert utils.merge_tags(seq1, seq2) == ["S-FOO", "O", "O", "B-FOO", "E-FOO", "O"]
    assert utils.merge_tags(seq2, seq1) == ["S-FOO", "O", "B-FOO", "E-FOO", "O", "O"]
    assert utils.merge_tags(seq1, seq3) == ["B-FOO", "E-FOO", "O", "B-FOO", "E-FOO", "O"]

    for error_seq in error_seqs:
        with pytest.raises(ValueError):
            utils.merge_tags(seq1, error_seq)
stanza/stanza/tests/pipeline/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/pipeline/test_arabic_pipeline.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Small test of loading the Arabic pipeline
|
| 3 |
+
|
| 4 |
+
The main goal is to check that nothing goes wrong with RtL languages,
|
| 5 |
+
but incidentally this would have caught a bug where the xpos tags
|
| 6 |
+
were split into individual pieces instead of reassembled as expected
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import pytest
|
| 10 |
+
import stanza
|
| 11 |
+
|
| 12 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 13 |
+
|
| 14 |
+
pytestmark = pytest.mark.pipeline
|
| 15 |
+
|
| 16 |
+
def test_arabic_pos_pipeline():
|
| 17 |
+
pipe = stanza.Pipeline(**{'processors': 'tokenize,pos', 'dir': TEST_MODELS_DIR, 'download_method': None, 'lang': 'ar'})
|
| 18 |
+
text = "ولم يتم اعتقال احد بحسب المتحدث باسم الشرطة."
|
| 19 |
+
|
| 20 |
+
doc = pipe(text)
|
| 21 |
+
# the first token translates to "and not", seems common enough
|
| 22 |
+
# that we should be able to rely on it having a stable MWT and tag
|
| 23 |
+
|
| 24 |
+
assert len(doc.sentences) == 1
|
| 25 |
+
assert doc.sentences[0].tokens[0].text == "ولم"
|
| 26 |
+
assert doc.sentences[0].words[0].xpos == "C---------"
|
| 27 |
+
assert doc.sentences[0].words[1].xpos == "F---------"
|
stanza/stanza/tests/pipeline/test_core.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import shutil
|
| 3 |
+
import tempfile
|
| 4 |
+
|
| 5 |
+
import stanza
|
| 6 |
+
|
| 7 |
+
from stanza.tests import *
|
| 8 |
+
|
| 9 |
+
from stanza.pipeline import core
|
| 10 |
+
from stanza.resources.common import get_md5, load_resources_json
|
| 11 |
+
|
| 12 |
+
pytestmark = pytest.mark.pipeline
|
| 13 |
+
|
| 14 |
+
def test_pretagged():
    """
    Test that the pipeline does or doesn't build if pos is left out and pretagged is specified
    """
    def build(processors, **extra):
        return stanza.Pipeline(lang='en', dir=TEST_MODELS_DIR, processors=processors, **extra)

    build("tokenize,pos,lemma,depparse")
    # without pos, depparse cannot satisfy its requirements...
    with pytest.raises(core.PipelineRequirementsException):
        build("tokenize,lemma,depparse")
    # ...unless the caller promises the input is already tagged
    build("tokenize,lemma,depparse", depparse_pretagged=True)
    build("tokenize,lemma,depparse", pretagged=True)
    # test that the module specific flag overrides the general flag
    build("tokenize,lemma,depparse", depparse_pretagged=True, pretagged=False)
def _assert_ner_download_layout(test_dir):
    # shared check: expected directory layout after the ner model (and its
    # charlm/pretrain dependencies) have been downloaded
    assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
    en_dir = os.path.join(test_dir, 'en')
    assert sorted(os.listdir(en_dir)) == ['backward_charlm', 'forward_charlm', 'mwt', 'ner', 'pretrain', 'tokenize']
    assert os.listdir(os.path.join(en_dir, 'ner')) == ['ontonotes_charlm.pt']

def test_download_missing_ner_model():
    """
    Test that the pipeline will automatically download missing models
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined", verbose=False)
        stanza.Pipeline("en", model_dir=test_dir, processors="tokenize,ner", package={"ner": ("ontonotes_charlm")})
        _assert_ner_download_layout(test_dir)


def test_download_missing_resources():
    """
    Test that the pipeline will automatically download missing models
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.Pipeline("en", model_dir=test_dir, processors="tokenize,ner", package={"tokenize": "combined", "ner": "ontonotes_charlm"})
        _assert_ner_download_layout(test_dir)
def test_download_resources_overwrites():
    """
    Test that the DOWNLOAD_RESOURCES method overwrites an existing resources.json
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})

        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        resources_path = os.path.join(test_dir, 'resources.json')
        first_mtime = os.path.getmtime(resources_path)

        # the default download method rewrites resources.json on each build
        stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
        assert os.path.getmtime(resources_path) != first_mtime

def test_reuse_resources_overwrites():
    """
    Test that the REUSE_RESOURCES method does *not* overwrite an existing resources.json
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        kwargs = dict(download_method=core.DownloadMethod.REUSE_RESOURCES,
                      model_dir=test_dir,
                      processors="tokenize",
                      package={"tokenize": "combined"})
        stanza.Pipeline("en", **kwargs)

        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        resources_path = os.path.join(test_dir, 'resources.json')
        first_mtime = os.path.getmtime(resources_path)

        stanza.Pipeline("en", **kwargs)
        assert os.path.getmtime(resources_path) == first_mtime
def test_download_not_repeated():
    """
    Test that a model is only downloaded once if it already matches the expected model from the resources file
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined")

        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        en_dir = os.path.join(test_dir, 'en')
        assert sorted(os.listdir(en_dir)) == ['mwt', 'tokenize']
        tokenize_path = os.path.join(en_dir, "tokenize", "combined.pt")
        first_mtime = os.path.getmtime(tokenize_path)

        # building a pipeline should leave the already-correct model alone
        stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
        assert os.path.getmtime(tokenize_path) == first_mtime

def test_download_none():
    """download_method=None must not touch model files, even wrong ones."""
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.download("it", model_dir=test_dir, processors="tokenize", package="combined")
        stanza.download("it", model_dir=test_dir, processors="tokenize", package="vit")

        it_dir = os.path.join(test_dir, 'it')
        assert sorted(os.listdir(it_dir)) == ['mwt', 'tokenize']
        combined_path = os.path.join(it_dir, "tokenize", "combined.pt")
        vit_path = os.path.join(it_dir, "tokenize", "vit.pt")

        assert os.path.exists(combined_path)
        assert os.path.exists(vit_path)

        combined_md5 = get_md5(combined_path)
        vit_md5 = get_md5(vit_path)
        # check that the models are different
        # otherwise the test is not testing anything
        assert combined_md5 != vit_md5

        # clobber the combined model with the vit model
        shutil.copyfile(vit_path, combined_path)
        assert get_md5(combined_path) == vit_md5

        # with download_method=None, the wrong file is left in place
        stanza.Pipeline("it", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}, download_method=None)
        assert get_md5(combined_path) == vit_md5

        # the default download method repairs it
        stanza.Pipeline("it", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"})
        assert get_md5(combined_path) != vit_md5
def check_download_method_updates(download_method):
    """
    Run a single test of creating a pipeline with a given download_method, checking that the model is updated
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as test_dir:
        stanza.download("en", model_dir=test_dir, processors="tokenize", package="combined")

        assert sorted(os.listdir(test_dir)) == ['en', 'resources.json']
        en_dir = os.path.join(test_dir, 'en')
        assert sorted(os.listdir(en_dir)) == ['mwt', 'tokenize']
        tokenize_path = os.path.join(en_dir, "tokenize", "combined.pt")

        # clobber the model file so its md5 no longer matches the resources file
        with open(tokenize_path, "w") as fout:
            fout.write("Unban mox opal!")
        clobbered_mtime = os.path.getmtime(tokenize_path)

        stanza.Pipeline("en", model_dir=test_dir, processors="tokenize", package={"tokenize": "combined"}, download_method=download_method)
        assert os.path.getmtime(tokenize_path) != clobbered_mtime

def test_download_fixed():
    """
    Test that a model is fixed if the existing model doesn't match the md5sum
    """
    for method in (core.DownloadMethod.REUSE_RESOURCES, core.DownloadMethod.DOWNLOAD_RESOURCES):
        check_download_method_updates(method)

def test_download_strings():
    """
    Same as the test of the download_method, but tests that the pipeline works for string download_method
    """
    for method in ("reuse_resources", "download_resources"):
        check_download_method_updates(method)
def test_limited_pipeline():
    """
    Test loading a pipeline, but then only using a couple processors
    """
    pipe = stanza.Pipeline(processors="tokenize,pos,lemma,depparse,ner", dir=TEST_MODELS_DIR)

    def upos_flags(doc):
        return [word.upos is not None for sentence in doc.sentences for word in sentence.words]

    def ner_flags(doc):
        return [token.ner is not None for sentence in doc.sentences for token in sentence.tokens]

    # full pipeline: everything is annotated
    doc = pipe("John Bauer works at Stanford")
    assert all(upos_flags(doc))
    assert all(ner_flags(doc))

    # restricted to tokenize,pos: no ner annotations
    doc = pipe("John Bauer works at Stanford", processors=["tokenize","pos"])
    assert all(upos_flags(doc))
    assert not any(ner_flags(doc))

    # tokenize only: neither pos nor ner annotations
    doc = pipe("John Bauer works at Stanford", processors="tokenize")
    assert not any(upos_flags(doc))
    assert not any(ner_flags(doc))

    # tokenize,ner: ner but no pos
    doc = pipe("John Bauer works at Stanford", processors="tokenize,ner")
    assert not any(upos_flags(doc))
    assert all(ner_flags(doc))

    with pytest.raises(ValueError):
        # this should fail - depparse's requirements are not met
        doc = pipe("John Bauer works at Stanford", processors="tokenize,depparse")
@pytest.fixture(scope="module")
def unknown_language_name():
    """Build a language code ("enz", "enzz", ...) absent from the resources file."""
    resources = load_resources_json(model_dir=TEST_MODELS_DIR)
    candidate = "en"
    while candidate in resources:
        candidate = candidate + "z"
    assert candidate != "en"
    return candidate

def test_empty_unknown_language(unknown_language_name):
    """
    Check that there is an error for trying to load an unknown language
    """
    with pytest.raises(ValueError):
        stanza.Pipeline(unknown_language_name, download_method=None)

def test_unknown_language_tokenizer(unknown_language_name):
    """
    Test that loading tokenize works for an unknown language
    """
    base_pipe = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
    # even if we one day add MWT to English, the tokenizer by itself should still work
    tokenize_model_path = base_pipe.processors["tokenize"].config['model_path']

    unknown_pipe = stanza.Pipeline(unknown_language_name,
                                   processors="tokenize",
                                   allow_unknown_language=True,
                                   tokenize_model_path=tokenize_model_path,
                                   download_method=None)
    doc = unknown_pipe("This is a test")
    assert [word.text for word in doc.sentences[0].words] == ['This', 'is', 'a', 'test']


def test_unknown_language_mwt(unknown_language_name):
    """
    Test that loading tokenize & mwt works for an unknown language
    """
    base_pipe = stanza.Pipeline("fr", dir=TEST_MODELS_DIR, processors="tokenize,mwt", download_method=None)
    assert len(base_pipe.processors) == 2

    stanza.Pipeline(unknown_language_name,
                    processors="tokenize,mwt",
                    allow_unknown_language=True,
                    tokenize_model_path=base_pipe.processors["tokenize"].config['model_path'],
                    mwt_model_path=base_pipe.processors["mwt"].config['model_path'],
                    download_method=None)
stanza/stanza/tests/pipeline/test_depparse.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic tests of the depparse processor boolean flags
|
| 3 |
+
"""
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
import stanza
|
| 7 |
+
from stanza.pipeline.core import PipelineRequirementsException
|
| 8 |
+
from stanza.utils.conll import CoNLL
|
| 9 |
+
from stanza.tests import *
|
| 10 |
+
|
| 11 |
+
pytestmark = pytest.mark.pipeline
|
| 12 |
+
|
| 13 |
+
# data for testing
|
| 14 |
+
EN_DOC = "Barack Obama was born in Hawaii. He was elected president in 2008. Obama attended Harvard."
|
| 15 |
+
|
| 16 |
+
EN_DOC_CONLLU_PRETAGGED = """
|
| 17 |
+
1 Barack Barack PROPN NNP Number=Sing 0 _ _ _
|
| 18 |
+
2 Obama Obama PROPN NNP Number=Sing 1 _ _ _
|
| 19 |
+
3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 2 _ _ _
|
| 20 |
+
4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 3 _ _ _
|
| 21 |
+
5 in in ADP IN _ 4 _ _ _
|
| 22 |
+
6 Hawaii Hawaii PROPN NNP Number=Sing 5 _ _ _
|
| 23 |
+
7 . . PUNCT . _ 6 _ _ _
|
| 24 |
+
|
| 25 |
+
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 0 _ _ _
|
| 26 |
+
2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 _ _ _
|
| 27 |
+
3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 2 _ _ _
|
| 28 |
+
4 president president PROPN NNP Number=Sing 3 _ _ _
|
| 29 |
+
5 in in ADP IN _ 4 _ _ _
|
| 30 |
+
6 2008 2008 NUM CD NumType=Card 5 _ _ _
|
| 31 |
+
7 . . PUNCT . _ 6 _ _ _
|
| 32 |
+
|
| 33 |
+
1 Obama Obama PROPN NNP Number=Sing 0 _ _ _
|
| 34 |
+
2 attended attend VERB VBD Mood=Ind|Tense=Past|VerbForm=Fin 1 _ _ _
|
| 35 |
+
3 Harvard Harvard PROPN NNP Number=Sing 2 _ _ _
|
| 36 |
+
4 . . PUNCT . _ 3 _ _ _
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
""".lstrip()
|
| 40 |
+
|
| 41 |
+
EN_DOC_DEPENDENCY_PARSES_GOLD = """
|
| 42 |
+
('Barack', 4, 'nsubj:pass')
|
| 43 |
+
('Obama', 1, 'flat')
|
| 44 |
+
('was', 4, 'aux:pass')
|
| 45 |
+
('born', 0, 'root')
|
| 46 |
+
('in', 6, 'case')
|
| 47 |
+
('Hawaii', 4, 'obl')
|
| 48 |
+
('.', 4, 'punct')
|
| 49 |
+
|
| 50 |
+
('He', 3, 'nsubj:pass')
|
| 51 |
+
('was', 3, 'aux:pass')
|
| 52 |
+
('elected', 0, 'root')
|
| 53 |
+
('president', 3, 'xcomp')
|
| 54 |
+
('in', 6, 'case')
|
| 55 |
+
('2008', 3, 'obl')
|
| 56 |
+
('.', 3, 'punct')
|
| 57 |
+
|
| 58 |
+
('Obama', 2, 'nsubj')
|
| 59 |
+
('attended', 0, 'root')
|
| 60 |
+
('Harvard', 2, 'obj')
|
| 61 |
+
('.', 2, 'punct')
|
| 62 |
+
""".strip()
|
| 63 |
+
|
| 64 |
+
@pytest.fixture(scope="module")
def en_depparse_pipeline():
    """Module-scoped English pipeline ending in depparse."""
    return stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,pos,lemma,depparse')
def test_depparse(en_depparse_pipeline):
    """Parsing EN_DOC should reproduce the gold dependency parses."""
    doc = en_depparse_pipeline(EN_DOC)
    parses = '\n\n'.join(sentence.dependencies_string() for sentence in doc.sentences)
    assert parses == EN_DOC_DEPENDENCY_PARSES_GOLD
def test_depparse_with_pretagged_doc():
    """depparse alone should run on a document that already carries tags."""
    nlp = stanza.Pipeline(processors='depparse', dir=TEST_MODELS_DIR, lang='en',
                          depparse_pretagged=True)

    pretagged = CoNLL.conll2doc(input_str=EN_DOC_CONLLU_PRETAGGED)
    parsed = nlp(pretagged)

    parses = '\n\n'.join(sentence.dependencies_string() for sentence in parsed.sentences)
    assert parses == EN_DOC_DEPENDENCY_PARSES_GOLD
def test_raises_requirements_exception_if_pretagged_not_passed():
    """Without a pretagged flag, depparse alone cannot satisfy its requirements."""
    with pytest.raises(PipelineRequirementsException):
        stanza.Pipeline(processors='depparse', dir=TEST_MODELS_DIR, lang='en')
stanza/stanza/tests/pipeline/test_english_pipeline.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic testing of the English pipeline
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import stanza
|
| 7 |
+
from stanza.utils.conll import CoNLL
|
| 8 |
+
from stanza.models.common.doc import Document
|
| 9 |
+
|
| 10 |
+
from stanza.tests import *
|
| 11 |
+
from stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu
|
| 12 |
+
|
| 13 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 14 |
+
|
| 15 |
+
# data for testing
|
| 16 |
+
EN_DOC = "Barack Obama was born in Hawaii. He was elected president in 2008. Obama attended Harvard."
|
| 17 |
+
|
| 18 |
+
EN_DOCS = ["Barack Obama was born in Hawaii.", "He was elected president in 2008.", "Obama attended Harvard."]
|
| 19 |
+
|
| 20 |
+
EN_DOC_TOKENS_GOLD = """
|
| 21 |
+
<Token id=1;words=[<Word id=1;text=Barack;lemma=Barack;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=nsubj:pass>]>
|
| 22 |
+
<Token id=2;words=[<Word id=2;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=1;deprel=flat>]>
|
| 23 |
+
<Token id=3;words=[<Word id=3;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=4;deprel=aux:pass>]>
|
| 24 |
+
<Token id=4;words=[<Word id=4;text=born;lemma=bear;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>]>
|
| 25 |
+
<Token id=5;words=[<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>]>
|
| 26 |
+
<Token id=6;words=[<Word id=6;text=Hawaii;lemma=Hawaii;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=obl>]>
|
| 27 |
+
<Token id=7;words=[<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=4;deprel=punct>]>
|
| 28 |
+
|
| 29 |
+
<Token id=1;words=[<Word id=1;text=He;lemma=he;upos=PRON;xpos=PRP;feats=Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs;head=3;deprel=nsubj:pass>]>
|
| 30 |
+
<Token id=2;words=[<Word id=2;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=3;deprel=aux:pass>]>
|
| 31 |
+
<Token id=3;words=[<Word id=3;text=elected;lemma=elect;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>]>
|
| 32 |
+
<Token id=4;words=[<Word id=4;text=president;lemma=president;upos=NOUN;xpos=NN;feats=Number=Sing;head=3;deprel=xcomp>]>
|
| 33 |
+
<Token id=5;words=[<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>]>
|
| 34 |
+
<Token id=6;words=[<Word id=6;text=2008;lemma=2008;upos=NUM;xpos=CD;feats=NumForm=Digit|NumType=Card;head=3;deprel=obl>]>
|
| 35 |
+
<Token id=7;words=[<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=3;deprel=punct>]>
|
| 36 |
+
|
| 37 |
+
<Token id=1;words=[<Word id=1;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=nsubj>]>
|
| 38 |
+
<Token id=2;words=[<Word id=2;text=attended;lemma=attend;upos=VERB;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=0;deprel=root>]>
|
| 39 |
+
<Token id=3;words=[<Word id=3;text=Harvard;lemma=Harvard;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=obj>]>
|
| 40 |
+
<Token id=4;words=[<Word id=4;text=.;lemma=.;upos=PUNCT;xpos=.;head=2;deprel=punct>]>
|
| 41 |
+
""".strip()
|
| 42 |
+
|
| 43 |
+
EN_DOC_WORDS_GOLD = """
|
| 44 |
+
<Word id=1;text=Barack;lemma=Barack;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=nsubj:pass>
|
| 45 |
+
<Word id=2;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=1;deprel=flat>
|
| 46 |
+
<Word id=3;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=4;deprel=aux:pass>
|
| 47 |
+
<Word id=4;text=born;lemma=bear;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>
|
| 48 |
+
<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>
|
| 49 |
+
<Word id=6;text=Hawaii;lemma=Hawaii;upos=PROPN;xpos=NNP;feats=Number=Sing;head=4;deprel=obl>
|
| 50 |
+
<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=4;deprel=punct>
|
| 51 |
+
|
| 52 |
+
<Word id=1;text=He;lemma=he;upos=PRON;xpos=PRP;feats=Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs;head=3;deprel=nsubj:pass>
|
| 53 |
+
<Word id=2;text=was;lemma=be;upos=AUX;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=3;deprel=aux:pass>
|
| 54 |
+
<Word id=3;text=elected;lemma=elect;upos=VERB;xpos=VBN;feats=Tense=Past|VerbForm=Part|Voice=Pass;head=0;deprel=root>
|
| 55 |
+
<Word id=4;text=president;lemma=president;upos=NOUN;xpos=NN;feats=Number=Sing;head=3;deprel=xcomp>
|
| 56 |
+
<Word id=5;text=in;lemma=in;upos=ADP;xpos=IN;head=6;deprel=case>
|
| 57 |
+
<Word id=6;text=2008;lemma=2008;upos=NUM;xpos=CD;feats=NumForm=Digit|NumType=Card;head=3;deprel=obl>
|
| 58 |
+
<Word id=7;text=.;lemma=.;upos=PUNCT;xpos=.;head=3;deprel=punct>
|
| 59 |
+
|
| 60 |
+
<Word id=1;text=Obama;lemma=Obama;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=nsubj>
|
| 61 |
+
<Word id=2;text=attended;lemma=attend;upos=VERB;xpos=VBD;feats=Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin;head=0;deprel=root>
|
| 62 |
+
<Word id=3;text=Harvard;lemma=Harvard;upos=PROPN;xpos=NNP;feats=Number=Sing;head=2;deprel=obj>
|
| 63 |
+
<Word id=4;text=.;lemma=.;upos=PUNCT;xpos=.;head=2;deprel=punct>
|
| 64 |
+
""".strip()
|
| 65 |
+
|
| 66 |
+
EN_DOC_DEPENDENCY_PARSES_GOLD = """
|
| 67 |
+
('Barack', 4, 'nsubj:pass')
|
| 68 |
+
('Obama', 1, 'flat')
|
| 69 |
+
('was', 4, 'aux:pass')
|
| 70 |
+
('born', 0, 'root')
|
| 71 |
+
('in', 6, 'case')
|
| 72 |
+
('Hawaii', 4, 'obl')
|
| 73 |
+
('.', 4, 'punct')
|
| 74 |
+
|
| 75 |
+
('He', 3, 'nsubj:pass')
|
| 76 |
+
('was', 3, 'aux:pass')
|
| 77 |
+
('elected', 0, 'root')
|
| 78 |
+
('president', 3, 'xcomp')
|
| 79 |
+
('in', 6, 'case')
|
| 80 |
+
('2008', 3, 'obl')
|
| 81 |
+
('.', 3, 'punct')
|
| 82 |
+
|
| 83 |
+
('Obama', 2, 'nsubj')
|
| 84 |
+
('attended', 0, 'root')
|
| 85 |
+
('Harvard', 2, 'obj')
|
| 86 |
+
('.', 2, 'punct')
|
| 87 |
+
""".strip()
|
| 88 |
+
|
| 89 |
+
EN_DOC_CONLLU_GOLD = """
|
| 90 |
+
# text = Barack Obama was born in Hawaii.
|
| 91 |
+
# sent_id = 0
|
| 92 |
+
# constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. .)))
|
| 93 |
+
# sentiment = 1
|
| 94 |
+
1 Barack Barack PROPN NNP Number=Sing 4 nsubj:pass _ start_char=0|end_char=6|ner=B-PERSON
|
| 95 |
+
2 Obama Obama PROPN NNP Number=Sing 1 flat _ start_char=7|end_char=12|ner=E-PERSON
|
| 96 |
+
3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 4 aux:pass _ start_char=13|end_char=16|ner=O
|
| 97 |
+
4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=17|end_char=21|ner=O
|
| 98 |
+
5 in in ADP IN _ 6 case _ start_char=22|end_char=24|ner=O
|
| 99 |
+
6 Hawaii Hawaii PROPN NNP Number=Sing 4 obl _ start_char=25|end_char=31|ner=S-GPE|SpaceAfter=No
|
| 100 |
+
7 . . PUNCT . _ 4 punct _ start_char=31|end_char=32|ner=O|SpacesAfter=\\s\\s
|
| 101 |
+
|
| 102 |
+
# text = He was elected president in 2008.
|
| 103 |
+
# sent_id = 1
|
| 104 |
+
# constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. .)))
|
| 105 |
+
# sentiment = 1
|
| 106 |
+
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=34|end_char=36|ner=O
|
| 107 |
+
2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=37|end_char=40|ner=O
|
| 108 |
+
3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=41|end_char=48|ner=O
|
| 109 |
+
4 president president NOUN NN Number=Sing 3 xcomp _ start_char=49|end_char=58|ner=O
|
| 110 |
+
5 in in ADP IN _ 6 case _ start_char=59|end_char=61|ner=O
|
| 111 |
+
6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ start_char=62|end_char=66|ner=S-DATE|SpaceAfter=No
|
| 112 |
+
7 . . PUNCT . _ 3 punct _ start_char=66|end_char=67|ner=O|SpacesAfter=\\s\\s
|
| 113 |
+
|
| 114 |
+
# text = Obama attended Harvard.
|
| 115 |
+
# sent_id = 2
|
| 116 |
+
# constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .)))
|
| 117 |
+
# sentiment = 1
|
| 118 |
+
1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=69|end_char=74|ner=S-PERSON
|
| 119 |
+
2 attended attend VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root _ start_char=75|end_char=83|ner=O
|
| 120 |
+
3 Harvard Harvard PROPN NNP Number=Sing 2 obj _ start_char=84|end_char=91|ner=S-ORG|SpaceAfter=No
|
| 121 |
+
4 . . PUNCT . _ 2 punct _ start_char=91|end_char=92|ner=O|SpaceAfter=No
|
| 122 |
+
""".strip()
|
| 123 |
+
|
| 124 |
+
EN_DOC_CONLLU_GOLD_MULTIDOC = """
|
| 125 |
+
# text = Barack Obama was born in Hawaii.
|
| 126 |
+
# sent_id = 0
|
| 127 |
+
# constituency = (ROOT (S (NP (NNP Barack) (NNP Obama)) (VP (VBD was) (VP (VBN born) (PP (IN in) (NP (NNP Hawaii))))) (. .)))
|
| 128 |
+
# sentiment = 1
|
| 129 |
+
1 Barack Barack PROPN NNP Number=Sing 4 nsubj:pass _ start_char=0|end_char=6|ner=B-PERSON
|
| 130 |
+
2 Obama Obama PROPN NNP Number=Sing 1 flat _ start_char=7|end_char=12|ner=E-PERSON
|
| 131 |
+
3 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 4 aux:pass _ start_char=13|end_char=16|ner=O
|
| 132 |
+
4 born bear VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=17|end_char=21|ner=O
|
| 133 |
+
5 in in ADP IN _ 6 case _ start_char=22|end_char=24|ner=O
|
| 134 |
+
6 Hawaii Hawaii PROPN NNP Number=Sing 4 obl _ start_char=25|end_char=31|ner=S-GPE|SpaceAfter=No
|
| 135 |
+
7 . . PUNCT . _ 4 punct _ start_char=31|end_char=32|ner=O|SpaceAfter=No
|
| 136 |
+
|
| 137 |
+
# text = He was elected president in 2008.
|
| 138 |
+
# sent_id = 1
|
| 139 |
+
# constituency = (ROOT (S (NP (PRP He)) (VP (VBD was) (VP (VBN elected) (S (NP (NN president))) (PP (IN in) (NP (CD 2008))))) (. .)))
|
| 140 |
+
# sentiment = 1
|
| 141 |
+
1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 3 nsubj:pass _ start_char=0|end_char=2|ner=O
|
| 142 |
+
2 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 3 aux:pass _ start_char=3|end_char=6|ner=O
|
| 143 |
+
3 elected elect VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root _ start_char=7|end_char=14|ner=O
|
| 144 |
+
4 president president NOUN NN Number=Sing 3 xcomp _ start_char=15|end_char=24|ner=O
|
| 145 |
+
5 in in ADP IN _ 6 case _ start_char=25|end_char=27|ner=O
|
| 146 |
+
6 2008 2008 NUM CD NumForm=Digit|NumType=Card 3 obl _ start_char=28|end_char=32|ner=S-DATE|SpaceAfter=No
|
| 147 |
+
7 . . PUNCT . _ 3 punct _ start_char=32|end_char=33|ner=O|SpaceAfter=No
|
| 148 |
+
|
| 149 |
+
# text = Obama attended Harvard.
|
| 150 |
+
# sent_id = 2
|
| 151 |
+
# constituency = (ROOT (S (NP (NNP Obama)) (VP (VBD attended) (NP (NNP Harvard))) (. .)))
|
| 152 |
+
# sentiment = 1
|
| 153 |
+
1 Obama Obama PROPN NNP Number=Sing 2 nsubj _ start_char=0|end_char=5|ner=S-PERSON
|
| 154 |
+
2 attended attend VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root _ start_char=6|end_char=14|ner=O
|
| 155 |
+
3 Harvard Harvard PROPN NNP Number=Sing 2 obj _ start_char=15|end_char=22|ner=S-ORG|SpaceAfter=No
|
| 156 |
+
4 . . PUNCT . _ 2 punct _ start_char=22|end_char=23|ner=O|SpaceAfter=No
|
| 157 |
+
""".strip()
|
| 158 |
+
|
| 159 |
+
class TestEnglishPipeline:
|
| 160 |
+
@pytest.fixture(scope="class")
|
| 161 |
+
def pipeline(self):
|
| 162 |
+
return stanza.Pipeline(dir=TEST_MODELS_DIR)
|
| 163 |
+
|
| 164 |
+
@pytest.fixture(scope="class")
|
| 165 |
+
def processed_doc(self, pipeline):
|
| 166 |
+
""" Document created by running full English pipeline on a few sentences """
|
| 167 |
+
return pipeline(EN_DOC)
|
| 168 |
+
|
| 169 |
+
def test_text(self, processed_doc):
|
| 170 |
+
assert processed_doc.text == EN_DOC
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def test_conllu(self, processed_doc):
|
| 174 |
+
assert "{:C}".format(processed_doc) == EN_DOC_CONLLU_GOLD
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def test_tokens(self, processed_doc):
|
| 178 |
+
assert "\n\n".join([sent.tokens_string() for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def test_words(self, processed_doc):
|
| 182 |
+
assert "\n\n".join([sent.words_string() for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def test_dependency_parse(self, processed_doc):
|
| 186 |
+
assert "\n\n".join([sent.dependencies_string() for sent in processed_doc.sentences]) == \
|
| 187 |
+
EN_DOC_DEPENDENCY_PARSES_GOLD
|
| 188 |
+
|
| 189 |
+
def test_empty(self, pipeline):
|
| 190 |
+
# make sure that various models handle the degenerate empty case
|
| 191 |
+
pipeline("")
|
| 192 |
+
pipeline("--")
|
| 193 |
+
|
| 194 |
+
def test_bulk_process(self, pipeline):
|
| 195 |
+
""" Double check that the bulk_process method in Pipeline converts documents as expected """
|
| 196 |
+
# it should process strings
|
| 197 |
+
processed = pipeline.bulk_process(EN_DOCS)
|
| 198 |
+
assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC
|
| 199 |
+
|
| 200 |
+
# it should pass Documents through successfully
|
| 201 |
+
docs = [Document([], text=t) for t in EN_DOCS]
|
| 202 |
+
processed = pipeline.bulk_process(docs)
|
| 203 |
+
assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC
|
| 204 |
+
|
| 205 |
+
def test_empty_bulk_process(self, pipeline):
|
| 206 |
+
""" Previously we had a bug where an empty document list would cause a crash """
|
| 207 |
+
processed = pipeline.bulk_process([])
|
| 208 |
+
assert processed == []
|
| 209 |
+
|
| 210 |
+
def test_stream(self, pipeline):
|
| 211 |
+
""" Test the streaming interface to the Pipeline """
|
| 212 |
+
# Test all of the documents in one batch
|
| 213 |
+
# (the default batch size is significantly more than |EN_DOCS|)
|
| 214 |
+
processed = [doc for doc in pipeline.stream(EN_DOCS)]
|
| 215 |
+
assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC
|
| 216 |
+
|
| 217 |
+
# It should also work on an iterator rather than an iterable
|
| 218 |
+
processed = [doc for doc in pipeline.stream(iter(EN_DOCS))]
|
| 219 |
+
assert "\n\n".join(["{:C}".format(doc) for doc in processed]) == EN_DOC_CONLLU_GOLD_MULTIDOC
|
| 220 |
+
|
| 221 |
+
# Stream one at a time
|
| 222 |
+
processed = [doc for doc in pipeline.stream(EN_DOCS, batch_size=1)]
|
| 223 |
+
processed = ["{:C}".format(doc) for doc in processed]
|
| 224 |
+
assert "\n\n".join(processed) == EN_DOC_CONLLU_GOLD_MULTIDOC
|
| 225 |
+
|
| 226 |
+
@pytest.fixture(scope="class")
|
| 227 |
+
def processed_multidoc(self, pipeline):
|
| 228 |
+
""" Document created by running full English pipeline on a few sentences """
|
| 229 |
+
docs = [Document([], text=t) for t in EN_DOCS]
|
| 230 |
+
return pipeline(docs)
|
| 231 |
+
|
| 232 |
+
def test_conllu_multidoc(self, processed_multidoc):
|
| 233 |
+
assert "\n\n".join(["{:C}".format(doc) for doc in processed_multidoc]) == EN_DOC_CONLLU_GOLD_MULTIDOC
|
| 234 |
+
|
| 235 |
+
def test_tokens_multidoc(self, processed_multidoc):
|
| 236 |
+
assert "\n\n".join([sent.tokens_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_TOKENS_GOLD
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def test_words_multidoc(self, processed_multidoc):
|
| 240 |
+
assert "\n\n".join([sent.words_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == EN_DOC_WORDS_GOLD
|
| 241 |
+
|
| 242 |
+
def test_sentence_indices_multidoc(self, processed_multidoc):
|
| 243 |
+
sentences = [sent for doc in processed_multidoc for sent in doc.sentences]
|
| 244 |
+
for sent_idx, sentence in enumerate(sentences):
|
| 245 |
+
assert sent_idx == sentence.index
|
| 246 |
+
|
| 247 |
+
def test_dependency_parse_multidoc(self, processed_multidoc):
|
| 248 |
+
assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc for sent in processed_doc.sentences]) == \
|
| 249 |
+
EN_DOC_DEPENDENCY_PARSES_GOLD
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
@pytest.fixture(scope="class")
|
| 253 |
+
def processed_multidoc_variant(self):
|
| 254 |
+
""" Document created by running full English pipeline on a few sentences """
|
| 255 |
+
docs = [Document([], text=t) for t in EN_DOCS]
|
| 256 |
+
nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors={'tokenize': 'spacy'})
|
| 257 |
+
return nlp(docs)
|
| 258 |
+
|
| 259 |
+
def test_dependency_parse_multidoc_variant(self, processed_multidoc_variant):
|
| 260 |
+
assert "\n\n".join([sent.dependencies_string() for processed_doc in processed_multidoc_variant for sent in processed_doc.sentences]) == \
|
| 261 |
+
EN_DOC_DEPENDENCY_PARSES_GOLD
|
| 262 |
+
|
| 263 |
+
def test_constituency_parser(self):
|
| 264 |
+
nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency")
|
| 265 |
+
doc = nlp("This is a test")
|
| 266 |
+
assert str(doc.sentences[0].constituency) == '(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))'
|
| 267 |
+
|
| 268 |
+
def test_on_gpu(self, pipeline):
|
| 269 |
+
"""
|
| 270 |
+
The default pipeline should have all the models on the GPU
|
| 271 |
+
"""
|
| 272 |
+
check_on_gpu(pipeline)
|
| 273 |
+
|
| 274 |
+
def test_on_cpu(self):
|
| 275 |
+
"""
|
| 276 |
+
Create a pipeline on the CPU, check that all the models on CPU
|
| 277 |
+
"""
|
| 278 |
+
pipeline = stanza.Pipeline("en", dir=TEST_MODELS_DIR, use_gpu=False)
|
| 279 |
+
check_on_cpu(pipeline)
|
stanza/stanza/tests/pipeline/test_french_pipeline.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic testing of French pipeline
|
| 3 |
+
|
| 4 |
+
The benefit of this test is to verify that the bulk processing works
|
| 5 |
+
for languages with MWT in them
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import stanza
|
| 10 |
+
from stanza.models.common.doc import Document
|
| 11 |
+
|
| 12 |
+
from stanza.tests import *
|
| 13 |
+
from stanza.tests.pipeline.pipeline_device_tests import check_on_gpu, check_on_cpu
|
| 14 |
+
|
| 15 |
+
pytestmark = pytest.mark.pipeline
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
FR_MWT_SENTENCE = "Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de " \
|
| 19 |
+
"l'Industrie et du Numérique."
|
| 20 |
+
|
| 21 |
+
EXPECTED_RESULT = """
|
| 22 |
+
[
|
| 23 |
+
[
|
| 24 |
+
{
|
| 25 |
+
"id": 1,
|
| 26 |
+
"text": "Alors",
|
| 27 |
+
"lemma": "alors",
|
| 28 |
+
"upos": "ADV",
|
| 29 |
+
"head": 3,
|
| 30 |
+
"deprel": "advmod",
|
| 31 |
+
"start_char": 0,
|
| 32 |
+
"end_char": 5
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"id": 2,
|
| 36 |
+
"text": "encore",
|
| 37 |
+
"lemma": "encore",
|
| 38 |
+
"upos": "ADV",
|
| 39 |
+
"head": 3,
|
| 40 |
+
"deprel": "advmod",
|
| 41 |
+
"start_char": 6,
|
| 42 |
+
"end_char": 12
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"id": 3,
|
| 46 |
+
"text": "inconnu",
|
| 47 |
+
"lemma": "inconnu",
|
| 48 |
+
"upos": "ADJ",
|
| 49 |
+
"feats": "Gender=Masc|Number=Sing",
|
| 50 |
+
"head": 11,
|
| 51 |
+
"deprel": "advcl",
|
| 52 |
+
"start_char": 13,
|
| 53 |
+
"end_char": 20
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"id": [
|
| 57 |
+
4,
|
| 58 |
+
5
|
| 59 |
+
],
|
| 60 |
+
"text": "du",
|
| 61 |
+
"start_char": 21,
|
| 62 |
+
"end_char": 23
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"id": 4,
|
| 66 |
+
"text": "de",
|
| 67 |
+
"lemma": "de",
|
| 68 |
+
"upos": "ADP",
|
| 69 |
+
"head": 7,
|
| 70 |
+
"deprel": "case"
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"id": 5,
|
| 74 |
+
"text": "le",
|
| 75 |
+
"lemma": "le",
|
| 76 |
+
"upos": "DET",
|
| 77 |
+
"feats": "Definite=Def|Gender=Masc|Number=Sing|PronType=Art",
|
| 78 |
+
"head": 7,
|
| 79 |
+
"deprel": "det"
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"id": 6,
|
| 83 |
+
"text": "grand",
|
| 84 |
+
"lemma": "grand",
|
| 85 |
+
"upos": "ADJ",
|
| 86 |
+
"feats": "Gender=Masc|Number=Sing",
|
| 87 |
+
"head": 7,
|
| 88 |
+
"deprel": "amod",
|
| 89 |
+
"start_char": 24,
|
| 90 |
+
"end_char": 29
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"id": 7,
|
| 94 |
+
"text": "public",
|
| 95 |
+
"lemma": "public",
|
| 96 |
+
"upos": "NOUN",
|
| 97 |
+
"feats": "Gender=Masc|Number=Sing",
|
| 98 |
+
"head": 3,
|
| 99 |
+
"deprel": "obl:arg",
|
| 100 |
+
"start_char": 30,
|
| 101 |
+
"end_char": 36,
|
| 102 |
+
"misc": "SpaceAfter=No"
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"id": 8,
|
| 106 |
+
"text": ",",
|
| 107 |
+
"lemma": ",",
|
| 108 |
+
"upos": "PUNCT",
|
| 109 |
+
"head": 3,
|
| 110 |
+
"deprel": "punct",
|
| 111 |
+
"start_char": 36,
|
| 112 |
+
"end_char": 37
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"id": 9,
|
| 116 |
+
"text": "Emmanuel",
|
| 117 |
+
"lemma": "Emmanuel",
|
| 118 |
+
"upos": "PROPN",
|
| 119 |
+
"head": 11,
|
| 120 |
+
"deprel": "nsubj",
|
| 121 |
+
"start_char": 38,
|
| 122 |
+
"end_char": 46
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"id": 10,
|
| 126 |
+
"text": "Macron",
|
| 127 |
+
"lemma": "Macron",
|
| 128 |
+
"upos": "PROPN",
|
| 129 |
+
"head": 9,
|
| 130 |
+
"deprel": "flat:name",
|
| 131 |
+
"start_char": 47,
|
| 132 |
+
"end_char": 53
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"id": 11,
|
| 136 |
+
"text": "devient",
|
| 137 |
+
"lemma": "devenir",
|
| 138 |
+
"upos": "VERB",
|
| 139 |
+
"feats": "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
|
| 140 |
+
"head": 0,
|
| 141 |
+
"deprel": "root",
|
| 142 |
+
"start_char": 54,
|
| 143 |
+
"end_char": 61
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"id": 12,
|
| 147 |
+
"text": "en",
|
| 148 |
+
"lemma": "en",
|
| 149 |
+
"upos": "ADP",
|
| 150 |
+
"head": 13,
|
| 151 |
+
"deprel": "case",
|
| 152 |
+
"start_char": 62,
|
| 153 |
+
"end_char": 64
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"id": 13,
|
| 157 |
+
"text": "2014",
|
| 158 |
+
"lemma": "2014",
|
| 159 |
+
"upos": "NUM",
|
| 160 |
+
"feats": "Number=Plur",
|
| 161 |
+
"head": 11,
|
| 162 |
+
"deprel": "obl:mod",
|
| 163 |
+
"start_char": 65,
|
| 164 |
+
"end_char": 69
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"id": 14,
|
| 168 |
+
"text": "ministre",
|
| 169 |
+
"lemma": "ministre",
|
| 170 |
+
"upos": "NOUN",
|
| 171 |
+
"feats": "Gender=Masc|Number=Sing",
|
| 172 |
+
"head": 11,
|
| 173 |
+
"deprel": "xcomp",
|
| 174 |
+
"start_char": 70,
|
| 175 |
+
"end_char": 78
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"id": 15,
|
| 179 |
+
"text": "de",
|
| 180 |
+
"lemma": "de",
|
| 181 |
+
"upos": "ADP",
|
| 182 |
+
"head": 17,
|
| 183 |
+
"deprel": "case",
|
| 184 |
+
"start_char": 79,
|
| 185 |
+
"end_char": 81
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"id": 16,
|
| 189 |
+
"text": "l'",
|
| 190 |
+
"lemma": "le",
|
| 191 |
+
"upos": "DET",
|
| 192 |
+
"feats": "Definite=Def|Number=Sing|PronType=Art",
|
| 193 |
+
"head": 17,
|
| 194 |
+
"deprel": "det",
|
| 195 |
+
"start_char": 82,
|
| 196 |
+
"end_char": 84,
|
| 197 |
+
"misc": "SpaceAfter=No"
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"id": 17,
|
| 201 |
+
"text": "Économie",
|
| 202 |
+
"lemma": "économie",
|
| 203 |
+
"upos": "NOUN",
|
| 204 |
+
"feats": "Gender=Fem|Number=Sing",
|
| 205 |
+
"head": 14,
|
| 206 |
+
"deprel": "nmod",
|
| 207 |
+
"start_char": 84,
|
| 208 |
+
"end_char": 92,
|
| 209 |
+
"misc": "SpaceAfter=No"
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"id": 18,
|
| 213 |
+
"text": ",",
|
| 214 |
+
"lemma": ",",
|
| 215 |
+
"upos": "PUNCT",
|
| 216 |
+
"head": 21,
|
| 217 |
+
"deprel": "punct",
|
| 218 |
+
"start_char": 92,
|
| 219 |
+
"end_char": 93
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"id": 19,
|
| 223 |
+
"text": "de",
|
| 224 |
+
"lemma": "de",
|
| 225 |
+
"upos": "ADP",
|
| 226 |
+
"head": 21,
|
| 227 |
+
"deprel": "case",
|
| 228 |
+
"start_char": 94,
|
| 229 |
+
"end_char": 96
|
| 230 |
+
},
|
| 231 |
+
{
|
| 232 |
+
"id": 20,
|
| 233 |
+
"text": "l'",
|
| 234 |
+
"lemma": "le",
|
| 235 |
+
"upos": "DET",
|
| 236 |
+
"feats": "Definite=Def|Number=Sing|PronType=Art",
|
| 237 |
+
"head": 21,
|
| 238 |
+
"deprel": "det",
|
| 239 |
+
"start_char": 97,
|
| 240 |
+
"end_char": 99,
|
| 241 |
+
"misc": "SpaceAfter=No"
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"id": 21,
|
| 245 |
+
"text": "Industrie",
|
| 246 |
+
"lemma": "industrie",
|
| 247 |
+
"upos": "NOUN",
|
| 248 |
+
"feats": "Gender=Fem|Number=Sing",
|
| 249 |
+
"head": 17,
|
| 250 |
+
"deprel": "conj",
|
| 251 |
+
"start_char": 99,
|
| 252 |
+
"end_char": 108
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"id": 22,
|
| 256 |
+
"text": "et",
|
| 257 |
+
"lemma": "et",
|
| 258 |
+
"upos": "CCONJ",
|
| 259 |
+
"head": 25,
|
| 260 |
+
"deprel": "cc",
|
| 261 |
+
"start_char": 109,
|
| 262 |
+
"end_char": 111
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"id": [
|
| 266 |
+
23,
|
| 267 |
+
24
|
| 268 |
+
],
|
| 269 |
+
"text": "du",
|
| 270 |
+
"start_char": 112,
|
| 271 |
+
"end_char": 114
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"id": 23,
|
| 275 |
+
"text": "de",
|
| 276 |
+
"lemma": "de",
|
| 277 |
+
"upos": "ADP",
|
| 278 |
+
"head": 25,
|
| 279 |
+
"deprel": "case"
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"id": 24,
|
| 283 |
+
"text": "le",
|
| 284 |
+
"lemma": "le",
|
| 285 |
+
"upos": "DET",
|
| 286 |
+
"feats": "Definite=Def|Gender=Masc|Number=Sing|PronType=Art",
|
| 287 |
+
"head": 25,
|
| 288 |
+
"deprel": "det"
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"id": 25,
|
| 292 |
+
"text": "Numérique",
|
| 293 |
+
"lemma": "numérique",
|
| 294 |
+
"upos": "NOUN",
|
| 295 |
+
"feats": "Gender=Masc|Number=Sing",
|
| 296 |
+
"head": 17,
|
| 297 |
+
"deprel": "conj",
|
| 298 |
+
"start_char": 115,
|
| 299 |
+
"end_char": 124,
|
| 300 |
+
"misc": "SpaceAfter=No"
|
| 301 |
+
},
|
| 302 |
+
{
|
| 303 |
+
"id": 26,
|
| 304 |
+
"text": ".",
|
| 305 |
+
"lemma": ".",
|
| 306 |
+
"upos": "PUNCT",
|
| 307 |
+
"head": 11,
|
| 308 |
+
"deprel": "punct",
|
| 309 |
+
"start_char": 124,
|
| 310 |
+
"end_char": 125,
|
| 311 |
+
"misc": "SpaceAfter=No"
|
| 312 |
+
}
|
| 313 |
+
]
|
| 314 |
+
]
|
| 315 |
+
"""
|
| 316 |
+
|
| 317 |
+
class TestFrenchPipeline:
|
| 318 |
+
@pytest.fixture(scope="class")
|
| 319 |
+
def pipeline(self):
|
| 320 |
+
""" Create a pipeline with French models """
|
| 321 |
+
pipeline = stanza.Pipeline(processors='tokenize,mwt,pos,lemma,depparse', dir=TEST_MODELS_DIR, lang='fr')
|
| 322 |
+
return pipeline
|
| 323 |
+
|
| 324 |
+
def test_single(self, pipeline):
|
| 325 |
+
doc = pipeline(FR_MWT_SENTENCE)
|
| 326 |
+
compare_ignoring_whitespace(str(doc), EXPECTED_RESULT)
|
| 327 |
+
|
| 328 |
+
def test_bulk(self, pipeline):
|
| 329 |
+
NUM_DOCS = 10
|
| 330 |
+
raw_text = [FR_MWT_SENTENCE] * NUM_DOCS
|
| 331 |
+
raw_doc = [Document([], text=doccontent) for doccontent in raw_text]
|
| 332 |
+
|
| 333 |
+
result = pipeline(raw_doc)
|
| 334 |
+
|
| 335 |
+
assert len(result) == NUM_DOCS
|
| 336 |
+
for doc in result:
|
| 337 |
+
compare_ignoring_whitespace(str(doc), EXPECTED_RESULT)
|
| 338 |
+
assert len(doc.sentences) == 1
|
| 339 |
+
assert doc.num_words == 26
|
| 340 |
+
assert doc.num_tokens == 24
|
| 341 |
+
|
| 342 |
+
def test_on_gpu(self, pipeline):
|
| 343 |
+
"""
|
| 344 |
+
The default pipeline should have all the models on the GPU
|
| 345 |
+
"""
|
| 346 |
+
check_on_gpu(pipeline)
|
| 347 |
+
|
| 348 |
+
def test_on_cpu(self):
|
| 349 |
+
"""
|
| 350 |
+
Create a pipeline on the CPU, check that all the models on CPU
|
| 351 |
+
"""
|
| 352 |
+
pipeline = stanza.Pipeline("fr", dir=TEST_MODELS_DIR, use_gpu=False)
|
| 353 |
+
check_on_cpu(pipeline)
|
stanza/stanza/tests/pipeline/test_lemmatizer.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic testing of lemmatization
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import stanza
|
| 7 |
+
|
| 8 |
+
from stanza.tests import *
|
| 9 |
+
from stanza.models.common.doc import TEXT, UPOS, LEMMA
|
| 10 |
+
|
| 11 |
+
# mark every test in this module as a pipeline test for pytest selection
pytestmark = pytest.mark.pipeline

# short English input reused by the lemmatizer tests below
EN_DOC = "Joe Smith was born in California."

# expected "word lemma" pairs when lemma_use_identity is set:
# each word is its own lemma
EN_DOC_IDENTITY_GOLD = """
Joe Joe
Smith Smith
was was
born born
in in
California California
. .
""".strip()

# expected "word lemma" pairs from the trained English lemmatizer model
# (note "was -> be" and "born -> bear")
EN_DOC_LEMMATIZER_MODEL_GOLD = """
Joe Joe
Smith Smith
was be
born bear
in in
California California
. .
""".strip()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_identity_lemmatizer():
    """The identity lemmatizer should copy each word's text into its lemma."""
    nlp = stanza.Pipeline(**{'processors': 'tokenize,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'lemma_use_identity': True}, download_method=None)
    doc = nlp(EN_DOC)
    observed = "\n".join(f"{word.text} {word.lemma}" for word in doc.iter_words())
    assert observed == EN_DOC_IDENTITY_GOLD
|
| 43 |
+
|
| 44 |
+
def test_full_lemmatizer():
    """The trained English lemmatizer should produce the expected lemmas."""
    nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, download_method=None)
    doc = nlp(EN_DOC)
    pairs = [f"{word.text} {word.lemma}" for word in doc.iter_words()]
    assert "\n".join(pairs) == EN_DOC_LEMMATIZER_MODEL_GOLD
|
| 51 |
+
|
| 52 |
+
def find_unknown_word(lemmatizer, base):
    """Return a word unknown to the lemmatizer, built by appending 'z' to base.

    Tries base + "z", base + "zz", ... up to ten z's, returning the first
    candidate that appears neither in the lemmatizer's word_dict nor as the
    word half of any (word, pos) key in its composite_dict.

    Raises RuntimeError if all ten candidates are known, which should never
    happen with a real model.
    """
    for _ in range(10):
        base = base + "z"
        # unknown means: not a plain dictionary entry, and not the word part
        # of any (word, pos) composite dictionary key
        if base not in lemmatizer.word_dict and all(key[0] != base for key in lemmatizer.composite_dict):
            return base
    raise RuntimeError("Could not build a word unknown to the lemmatizer; last candidate tried was %s" % base)
|
| 58 |
+
|
| 59 |
+
def test_store_results():
    """Test that lemmas for unseen words are stored and reused.

    With lemma_store_results=True, the lemmatizer should cache predicted
    (word, upos) -> lemma entries in its composite_dict so that a second run
    over the same text gives identical results, without adding anything to
    the plain word_dict.
    """
    nlp = stanza.Pipeline(**{'processors': 'tokenize,pos,lemma', 'dir': TEST_MODELS_DIR, 'lang': 'en'}, lemma_store_results=True, download_method=None)
    lemmatizer = nlp.processors["lemma"]._trainer

    # three words guaranteed to be unknown to the model
    az = find_unknown_word(lemmatizer, "a")
    bz = find_unknown_word(lemmatizer, "b")
    cz = find_unknown_word(lemmatizer, "c")

    # try sentences with the order long, short
    doc = nlp("I found an " + az + " in my " + bz + ". It was a " + cz)
    stuff = doc.get([TEXT, UPOS, LEMMA])
    assert len(stuff) == 12
    assert stuff[3][0] == az
    assert stuff[6][0] == bz
    assert stuff[11][0] == cz

    # the predicted lemmas should now be cached in the composite dictionary
    assert lemmatizer.composite_dict[(az, stuff[3][1])] == stuff[3][2]
    assert lemmatizer.composite_dict[(bz, stuff[6][1])] == stuff[6][2]
    assert lemmatizer.composite_dict[(cz, stuff[11][1])] == stuff[11][2]

    # a second run over the same text should reuse the cached results
    doc2 = nlp("I found an " + az + " in my " + bz + ". It was a " + cz)
    stuff2 = doc2.get([TEXT, UPOS, LEMMA])

    assert stuff == stuff2

    dz = find_unknown_word(lemmatizer, "d")
    ez = find_unknown_word(lemmatizer, "e")
    fz = find_unknown_word(lemmatizer, "f")

    # try sentences with the order short, long
    # (the original comment said "long, short" again, but here the short
    # sentence comes first)
    doc = nlp("It was a " + dz + ". I found an " + ez + " in my " + fz)
    stuff = doc.get([TEXT, UPOS, LEMMA])
    assert len(stuff) == 12
    assert stuff[3][0] == dz
    assert stuff[8][0] == ez
    assert stuff[11][0] == fz

    assert lemmatizer.composite_dict[(dz, stuff[3][1])] == stuff[3][2]
    assert lemmatizer.composite_dict[(ez, stuff[8][1])] == stuff[8][2]
    assert lemmatizer.composite_dict[(fz, stuff[11][1])] == stuff[11][2]

    doc2 = nlp("It was a " + dz + ". I found an " + ez + " in my " + fz)
    stuff2 = doc2.get([TEXT, UPOS, LEMMA])

    assert stuff == stuff2

    # caching must go through composite_dict only, never word_dict
    assert az not in lemmatizer.word_dict
|
| 106 |
+
|
| 107 |
+
def test_caseless_lemmatizer():
    """
    Test that setting the lemmatizer as caseless at Pipeline time lowercases the text
    """
    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None)
    # with the default (cased) model, the unexpected capital letter could
    # throw off the lemmatizer, but the current English model happens to
    # lowercase the word anyway
    doc = nlp("Here is an Excerpt")
    assert doc.sentences[0].words[-1].lemma == 'excerpt'

    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None, lemma_caseless=True)
    # NOTE(review): the original comment here mentioned 'antennae',
    # apparently copied from a different test.  What this assert checks is
    # that the caseless model leaves "Excerpt" with its capital letter.
    doc = nlp("Here is an Excerpt")
    assert doc.sentences[0].words[-1].lemma == 'Excerpt'
|
| 121 |
+
|
| 122 |
+
def test_latin_caseless_lemmatizer():
    """
    Test the Latin caseless lemmatizer
    """
    nlp = stanza.Pipeline('la', package='ittb', processors='tokenize,pos,lemma', model_dir=TEST_MODELS_DIR, download_method=None)
    lemma_processor = nlp.processors['lemma']
    # the ittb package is expected to ship a caseless lemmatizer
    assert lemma_processor.config['caseless']

    doc = nlp("Quod Erat Demonstrandum")
    assert len(doc.sentences) == 1
    sentence = doc.sentences[0]
    assert len(sentence.words) == 3
    expected_lemmas = ["qui", "sum", "demonstro"]
    for word, expected in zip(sentence.words, expected_lemmas):
        assert word.lemma == expected
|
stanza/stanza/tests/pipeline/test_pipeline_constituency_processor.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import pytest
|
| 3 |
+
import stanza
|
| 4 |
+
from stanza.models.common.foundation_cache import FoundationCache
|
| 5 |
+
|
| 6 |
+
from stanza.tests import *
|
| 7 |
+
|
| 8 |
+
# every test here is a pipeline test and is light enough to run on travis
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

# data for testing
TEST_TEXT = "This is a test. Another sentence. Are these sorted?"

# expected leaf tokens, one list per sentence of TEST_TEXT
TEST_TOKENS = [["This", "is", "a", "test", "."], ["Another", "sentence", "."], ["Are", "these", "sorted", "?"]]
|
| 14 |
+
|
| 15 |
+
@pytest.fixture(scope="module")
def foundation_cache():
    """Module-scoped cache so the tests below can share loaded foundation models."""
    # free memory from previously loaded models before creating the shared cache
    gc.collect()
    return FoundationCache()
|
| 19 |
+
|
| 20 |
+
def check_results(doc):
    """Check that doc was parsed into the expected sentences and leaf tokens."""
    assert len(doc.sentences) == len(TEST_TOKENS)
    for idx, sentence in enumerate(doc.sentences):
        assert sentence.constituency.leaf_labels() == TEST_TOKENS[idx]
|
| 24 |
+
|
| 25 |
+
def test_sorted_big_batch(foundation_cache):
    """Parse all of TEST_TEXT in one go with the default batch size."""
    pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None)
    check_results(pipe(TEST_TEXT))
|
| 29 |
+
|
| 30 |
+
def test_comments(foundation_cache):
    """
    Test that the pipeline is creating constituency comments
    """
    pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None)
    doc = pipe(TEST_TEXT)
    check_results(doc)

    # every sentence should carry a "# constituency = ..." comment
    for sentence in doc.sentences:
        assert any(comment.startswith("# constituency = ") for comment in sentence.comments)

    # replacing the tree should update the comment rather than add a second one
    doc.sentences[0].constituency = "asdf"
    assert "# constituency = asdf" in doc.sentences[0].comments
    for sentence in doc.sentences:
        constituency_comments = [comment for comment in sentence.comments if comment.startswith("# constituency")]
        assert len(constituency_comments) == 1
|
| 43 |
+
|
| 44 |
+
def test_illegal_batch_size(foundation_cache):
    """A bad constituency_batch_size only matters once the constituency processor is loaded."""
    # without the constituency processor, the bogus batch size is ignored
    stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos", constituency_batch_size="zzz", foundation_cache=foundation_cache, download_method=None)
    # with the processor included, pipeline construction should fail
    with pytest.raises(ValueError):
        stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size="zzz", foundation_cache=foundation_cache, download_method=None)
|
| 48 |
+
|
| 49 |
+
def test_sorted_one_batch(foundation_cache):
    """Parsing one sentence per batch should still give correctly ordered results."""
    pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size=1, foundation_cache=foundation_cache, download_method=None)
    check_results(pipe(TEST_TEXT))
|
| 53 |
+
|
| 54 |
+
def test_sorted_two_batch(foundation_cache):
    """Parsing two sentences per batch should still give correctly ordered results."""
    pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", constituency_batch_size=2, foundation_cache=foundation_cache, download_method=None)
    check_results(pipe(TEST_TEXT))
|
| 58 |
+
|
| 59 |
+
def test_get_constituents(foundation_cache):
    """The constituency processor should expose its set of constituent labels."""
    # pass model_dir=TEST_MODELS_DIR for consistency with the other tests in
    # this module; the original omitted it and would look in the default
    # model directory, which may not exist in the test environment
    pipe = stanza.Pipeline("en", model_dir=TEST_MODELS_DIR, processors="tokenize,pos,constituency", foundation_cache=foundation_cache, download_method=None)
    assert "SBAR" in pipe.processors["constituency"].get_constituents()
|