bowphs committed on
Commit
bf62052
·
verified ·
1 Parent(s): 6cd9428

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. stanza/stanza/pipeline/demo/stanza-brat.css +74 -0
  2. stanza/stanza/pipeline/demo/stanza-parseviewer.js +215 -0
  3. stanza/stanza/pipeline/external/__init__.py +0 -0
  4. stanza/stanza/tests/classifiers/__init__.py +0 -0
  5. stanza/stanza/tests/classifiers/test_classifier.py +317 -0
  6. stanza/stanza/tests/classifiers/test_process_utils.py +83 -0
  7. stanza/stanza/tests/common/test_bert_embedding.py +33 -0
  8. stanza/stanza/tests/common/test_char_model.py +190 -0
  9. stanza/stanza/tests/common/test_common_data.py +32 -0
  10. stanza/stanza/tests/common/test_data_objects.py +60 -0
  11. stanza/stanza/tests/common/test_doc.py +174 -0
  12. stanza/stanza/tests/common/test_dropout.py +28 -0
  13. stanza/stanza/tests/common/test_short_name_to_treebank.py +14 -0
  14. stanza/stanza/tests/constituency/test_convert_it_vit.py +228 -0
  15. stanza/stanza/tests/constituency/test_convert_starlang.py +37 -0
  16. stanza/stanza/tests/constituency/test_in_order_oracle.py +522 -0
  17. stanza/stanza/tests/constituency/test_lstm_model.py +552 -0
  18. stanza/stanza/tests/constituency/test_text_processing.py +109 -0
  19. stanza/stanza/tests/constituency/test_top_down_oracle.py +443 -0
  20. stanza/stanza/tests/constituency/test_trainer.py +639 -0
  21. stanza/stanza/tests/constituency/test_transformer_tree_stack.py +195 -0
  22. stanza/stanza/tests/constituency/test_transition_sequence.py +156 -0
  23. stanza/stanza/tests/constituency/test_tree_reader.py +119 -0
  24. stanza/stanza/tests/constituency/test_vietnamese.py +121 -0
  25. stanza/stanza/tests/langid/test_langid.py +615 -0
  26. stanza/stanza/tests/lemma/__init__.py +0 -0
  27. stanza/stanza/tests/mwt/test_utils.py +59 -0
  28. stanza/stanza/tests/ner/__init__.py +0 -0
  29. stanza/stanza/tests/ner/test_combine_ner_datasets.py +39 -0
  30. stanza/stanza/tests/ner/test_models_ner_scorer.py +28 -0
  31. stanza/stanza/tests/ner/test_ner_tagger.py +94 -0
  32. stanza/stanza/tests/ner/test_ner_trainer.py +32 -0
  33. stanza/stanza/tests/ner/test_pay_amt_annotators.py +50 -0
  34. stanza/stanza/tests/ner/test_split_wikiner.py +202 -0
  35. stanza/stanza/tests/ner/test_suc3.py +91 -0
  36. stanza/stanza/tests/pipeline/test_decorators.py +127 -0
  37. stanza/stanza/tests/pipeline/test_pipeline_mwt_expander.py +123 -0
  38. stanza/stanza/tests/pos/__init__.py +0 -0
  39. stanza/stanza/tests/pos/test_tagger.py +315 -0
  40. stanza/stanza/tests/resources/__init__.py +0 -0
  41. stanza/stanza/tests/resources/test_default_packages.py +24 -0
  42. stanza/stanza/tests/resources/test_prepare_resources.py +30 -0
  43. stanza/stanza/tests/server/test_server_misc.py +115 -0
  44. stanza/stanza/utils/datasets/common.py +286 -0
  45. stanza/stanza/utils/datasets/conllu_to_text.pl +248 -0
  46. stanza/stanza/utils/datasets/prepare_lemma_classifier.py +144 -0
  47. stanza/stanza/utils/datasets/prepare_mwt_treebank.py +88 -0
  48. stanza/stanza/utils/datasets/prepare_pos_treebank.py +38 -0
  49. stanza/stanza/utils/datasets/random_split_conllu.py +59 -0
  50. stanza/stanza/utils/datasets/thai_syllable_dict_generator.py +53 -0
stanza/stanza/pipeline/demo/stanza-brat.css ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
/* Styles for the stanza pipeline demo page. */

.red {
    color: #990000;
}

#wrap {
    min-height: 100%;
    height: auto;
    margin: 0 auto -6ex;
    padding: 0 0 6ex;
}

.pattern_tab {
    margin: 1ex;
}

.pattern_brat {
    margin-top: 1ex;
}

.label {
    color: #777777;
    font-size: small;
}

.footer {
    bottom: 0;
    width: 100%;
    /* Set the fixed height of the footer here */
    height: 5ex;
    padding-top: 1ex;
    margin-top: 1ex;
    background-color: #f5f5f5;
}

.corenlp_error {
    margin-top: 2ex;
}

/* Styling for parse graph */
.node rect {
    stroke: #333;
    fill: #fff;
}

.parse-RULE rect {
    fill: #C0D9AF;
}

.parse-TERMINAL rect {
    stroke: #333;
    fill: #EEE8AA;
}

.node.highlighted {
    stroke: #ffff00;
}

.edgePath path {
    stroke: #333;
    fill: #333;
    stroke-width: 1.5px;
}

.parse-EDGE path {
    stroke: DarkGray;
    fill: DarkGray;
    stroke-width: 1.5px;
}

.logo {
    font-family: "Lato", "Gill Sans MT", "Gill Sans", "Helvetica", "Arial", sans-serif;
    font-style: italic;
}
stanza/stanza/pipeline/demo/stanza-parseviewer.js ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
//'use strict';

//d3 || require('d3');
//var dagreD3 = require('dagre-d3');
//var jquery = require('jquery');
//var $ = jquery;

// Widget that renders constituency parses as a dagre-d3 graph inside
// the DOM element selected by params.selector.
var ParseViewer = function (params) {
    // Container in which the scene template is displayed
    this.selector = params.selector;
    this.container = $(this.selector);
    this.fitToGraph = true;
    this.onClickNodeCallback = params.onClickNodeCallback;
    this.onHoverNodeCallback = params.onHoverNodeCallback;
    this.init();
    return this;
};

// Lower bounds for the drawing canvas, in pixels.
ParseViewer.MIN_WIDTH = 100;
ParseViewer.MIN_HEIGHT = 100;

ParseViewer.prototype.constructor = ParseViewer;

// Canvas width derived from the container, never below MIN_WIDTH.
ParseViewer.prototype.getAutoWidth = function () {
    var containerWidth = this.container.width();
    return Math.max(containerWidth, ParseViewer.MIN_WIDTH);
};

// Canvas height derived from the container (minus a small margin),
// never below MIN_HEIGHT.
ParseViewer.prototype.getAutoHeight = function () {
    var containerHeight = this.container.height() - 20;
    return Math.max(containerHeight, ParseViewer.MIN_HEIGHT);
};

// Creates the svg canvas and an empty controls div inside the container.
ParseViewer.prototype.init = function () {
    var w = this.getAutoWidth();
    var h = this.getAutoHeight();
    this.parseElem = d3.select(this.selector)
        .append('svg')
        .attr({'width': w, 'height': h})
        .style({'width': w, 'height': h});
    console.log(this.parseElem);
    // No graph yet; showParses() will populate these.
    this.graph = null;
    this.graphRendered = false;

    this.controls = $('<div class="text"></div>');
    this.container.append(this.controls);
};

// Converts an array of parse-tree roots into a dagreD3 graph.
var GraphBuilder = function (roots) {
    // Create the input graph
    this.graph = new dagreD3.graphlib.Graph()
        .setGraph({})
        .setDefaultEdgeLabel(function () { return {}; });
    this.visitIndex = 0;
    //console.log('building graph', roots);
    for (var r = 0; r < roots.length; r++) {
        this.build(roots[r]);
    }
};

// Depth-first walk over one tree: adds a graph node per tree node, an
// edge to its parent, and (for terminals) an extra leaf node carrying
// the surface text.
GraphBuilder.prototype.build = function (node) {
    console.log(node);
    // Record this node's visit order; the index doubles as the graph node id.
    node.visitIndex = ++this.visitIndex;

    var nodeData = node; // TODO: replace with semantic data
    this.graph.setNode(node.visitIndex, { label: node.label, class: 'parse-RULE', data: nodeData });
    if (node.parent) {
        this.graph.setEdge(node.parent.visitIndex, node.visitIndex, {
            class: 'parse-EDGE'
        });
    }

    if (node.isTerminal) {
        // Terminals get a second node for the token text, hung off the tag node.
        this.visitIndex++;
        this.graph.setNode(this.visitIndex, { label: node.text, class: 'parse-TERMINAL', data: nodeData });
        this.graph.setEdge(node.visitIndex, this.visitIndex, {
            class: 'parse-EDGE'
        });
    } else if (node.children) {
        for (var c = 0; c < node.children.length; c++) {
            this.build(node.children[c]);
        }
    }
};

// Repositions the rendered graph; when fitToGraph is set, the minimum
// size comes from the graph itself rather than from the caller.
ParseViewer.prototype.updateGraphPosition = function (svg, g, minWidth, minHeight) {
    if (this.fitToGraph) {
        minWidth = g.graph().width;
        minHeight = this.getAutoHeight();
    }
    adjustGraphPositioning(svg, g, minWidth, minHeight);
};

// Grows the svg to hold the rendered graph and centers it horizontally.
function adjustGraphPositioning(svg, g, minWidth, minHeight) {
    // Resize svg
    var width = Math.max(minWidth, g.graph().width);
    var height = Math.max(minHeight, g.graph().height + 40);
    svg.attr({'width': width, 'height': height});
    svg.style({'width': width, 'height': height});
    // Center the graph
    var group = svg.select('g');
    var xCenterOffset = (svg.attr('width') - g.graph().width) / 2;
    group.attr('transform', 'translate(' + xCenterOffset + ', 20)');
    svg.attr('height', g.graph().height + 40);
    svg.style('height', g.graph().height + 40);
}

// Draws graph g into the svg and wires up the node click/hover callbacks.
ParseViewer.prototype.renderGraph = function (svg, g, parse) {
    // Create the renderer
    var render = new dagreD3.render();
    // Run the renderer. This is what draws the final graph.
    var svgGroup = svg.select('g');
    render(svgGroup, g);

    var scope = this;
    var nodes = svgGroup.selectAll('g.node');
    nodes.on('click', function (d) {
        // d is the graph node id; look up the node record for its data.
        var node = g.node(d);
        if (scope.onClickNodeCallback) {
            scope.onClickNodeCallback(node.data);
        }
        console.log(g.node(d));
    });

    nodes.on('mouseover', function (d) {
        var node = g.node(d);
        if (scope.onHoverNodeCallback) {
            scope.onHoverNodeCallback(node.data);
        }
    });

    this.updateGraphPosition(svg, g, svg.attr('width'), svg.attr('height'));
    this.graphRendered = true;
};

// Convenience wrapper for showing a single parse tree.
ParseViewer.prototype.showParse = function (root) {
    this.showParses([root]);
};

// Builds a graph from the parse roots and renders it if the container
// is visible; otherwise rendering is deferred until onResize().
ParseViewer.prototype.showParses = function (roots) {
    // Take parse and create a graph
    var builder = new GraphBuilder(roots);
    var g = builder.graph;

    // Round the corners of the nodes
    g.nodes().forEach(function (v) {
        var node = g.node(v);
        node.rx = node.ry = 5;
    });

    var svg = this.parseElem;
    svg.selectAll('*').remove();
    svg.append('g');
    this.graph = g;
    this.parse = roots;
    if (!this.container.is(':visible')) {
        this.graphRendered = false;
    } else if (roots.length > 0) {
        this.renderGraph(svg, this.graph, this.parse);
    }
};

// Renders the parse trees of every sentence in an annotation that has one.
ParseViewer.prototype.showAnnotation = function (annotation) {
    var parses = [];
    for (var idx = 0; idx < annotation.sentences.length; idx++) {
        var sentence = annotation.sentences[idx];
        if (sentence && sentence.parseTree) {
            parses.push(sentence.parseTree);
        }
    }
    this.showParses(parses);
};

// Recomputes the canvas size on container resize; renders a deferred
// graph if one is pending, otherwise just recenters the drawn graph.
ParseViewer.prototype.onResize = function () {
    var w = this.getAutoWidth();
    var h = this.getAutoHeight();
    var svg = this.parseElem;

    // Center the graph
    var svgGroup = svg.select('g');
    if (!(svgGroup && this.graph)) {
        // Nothing rendered yet at all: just resize the empty canvas.
        svg.attr({'width': w, 'height': h});
        svg.style({'width': w, 'height': h});
    } else if (this.graphRendered) {
        this.updateGraphPosition(svg, this.graph, w, h);
    } else {
        // A graph was prepared while hidden; render it now.
        svg.attr({'width': w, 'height': h});
        svg.style({'width': w, 'height': h});
        this.renderGraph(svg, this.graph, this.parse);
    }
};

// Exports
//module.exports = ParseViewer;
stanza/stanza/pipeline/external/__init__.py ADDED
File without changes
stanza/stanza/tests/classifiers/__init__.py ADDED
File without changes
stanza/stanza/tests/classifiers/test_classifier.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+
4
+ import pytest
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ import stanza
10
+ import stanza.models.classifier as classifier
11
+ import stanza.models.classifiers.data as data
12
+ from stanza.models.classifiers.trainer import Trainer
13
+ from stanza.models.common import pretrain
14
+ from stanza.models.common import utils
15
+
16
+ from stanza.tests import TEST_MODELS_DIR
17
+ from stanza.tests.classifiers.test_data import train_file, dev_file, test_file, DATASET, SENTENCES
18
+
19
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
20
+
21
+ EMB_DIM = 5
22
+
@pytest.fixture(scope="module")
def fake_embeddings(tmp_path_factory):
    """
    Return a path to a small .pt embeddings file covering the words in SENTENCES.

    The alphabetically-last word is dropped from the vocabulary,
    presumably so the dataset contains at least one unknown word.
    """
    # could set np random seed here
    vocab = sorted({token.lower() for sentence in SENTENCES for token in sentence})
    vocab = vocab[:-1]
    embedding_dir = tmp_path_factory.mktemp("data")
    txt_path = embedding_dir / "embedding.txt"
    pt_path = embedding_dir / "embedding.pt"
    vectors = np.random.random((len(vocab), EMB_DIM))

    # word<TAB>v1<TAB>v2...<TAB>vN, one word per line
    with open(txt_path, "w", encoding="utf-8") as fout:
        for word, vec in zip(vocab, vectors):
            fout.write("%s\t%s\n" % (word, "\t".join(str(v) for v in vec)))

    # Pretrain converts the text file into the .pt form the models load
    pt = pretrain.Pretrain(str(pt_path), str(txt_path))
    pt.load()
    assert os.path.exists(pt_path)
    return pt_path
47
+
class TestClassifier:
    """Tests building, training, saving, and (peft-)finetuning the sentiment classifier."""

    def build_model(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None):
        """
        Build a model to be used by one of the later tests

        Returns (trainer, train_set, parsed_args).  If checkpoint_file is
        given, the trainer (with optimizer state) is restored from it
        instead of building a fresh model.
        """
        save_dir = str(tmp_path / "classifier")
        save_name = "model.pt"
        # deliberately tiny layer sizes and epoch counts to keep the tests fast
        args = ["--save_dir", save_dir,
                "--save_name", save_name,
                "--wordvec_pretrain_file", str(fake_embeddings),
                "--filter_channels", "20",
                "--fc_shapes", "20,10",
                "--train_file", str(train_file),
                "--dev_file", str(dev_file),
                "--max_epochs", "2",
                "--batch_size", "60"]
        if extra_args is not None:
            args = args + extra_args
        args = classifier.parse_args(args)
        train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
        if checkpoint_file:
            trainer = Trainer.load(checkpoint_file, args, load_optimizer=True)
        else:
            trainer = Trainer.build_new_model(args, train_set)
        return trainer, train_set, args

    def run_training(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None):
        """
        Iterate a couple times over a model

        Returns (trainer, save_filename, checkpoint_file) so later tests
        can inspect the saved files or resume from the checkpoint.
        """
        trainer, train_set, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args, checkpoint_file)
        dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)
        labels = data.dataset_labels(train_set)

        save_filename = os.path.join(args.save_dir, args.save_name)
        if checkpoint_file is None:
            checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name)
        classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels)
        return trainer, save_filename, checkpoint_file

    def test_build_model(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test that building a basic model works
        """
        self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])

    def test_save_load(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test that a basic model can save & load
        """
        trainer, _, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])

        save_filename = os.path.join(args.save_dir, args.save_name)
        trainer.save(save_filename)

        # load twice: once via the bare save name, once via the full path
        args.load_name = args.save_name
        trainer = Trainer.load(args.load_name, args)
        args.load_name = save_filename
        trainer = Trainer.load(args.load_name, args)

    def test_train_basic(self, tmp_path, fake_embeddings, train_file, dev_file):
        """Train a small model end to end with default settings plus a small bilstm."""
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])

    def test_train_bilstm(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test w/ and w/o bilstm variations of the classifier
        """
        args = ["--bilstm", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--no_bilstm"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

    def test_train_maxpool_width(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test various maxpool widths

        Also sets --filter_channels to a multiple of 2 but not of 3 for
        the test to make sure the math is done correctly on a non-divisible width
        """
        args = ["--maxpool_width", "1", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--maxpool_width", "2", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--maxpool_width", "3", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

    def test_train_conv_2d(self, tmp_path, fake_embeddings, train_file, dev_file):
        """Train with several --filter_sizes specs, including 2d filter shapes."""
        args = ["--filter_sizes", "(3,4,5)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--filter_sizes", "((3,2),)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

        args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
        self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)

    def test_train_filter_channels(self, tmp_path, fake_embeddings, train_file, dev_file):
        """Check the resulting fc_input_size for uniform vs per-filter channel counts."""
        args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--no_bilstm"]
        trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
        assert trainer.model.fc_input_size == 40

        args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "15,20", "--no_bilstm"]
        trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
        # 50 = 2x15 for the 2d conv (over 5 dim embeddings) + 20
        assert trainer.model.fc_input_size == 50

    def test_train_bert(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test on a tiny Bert WITHOUT finetuning, which hopefully does not take up too much disk space or memory
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model])
        assert os.path.exists(save_filename)
        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
        # check that the bert model wasn't saved as part of the classifier
        assert not saved_model['params']['config']['force_bert_saved']
        assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())

    def test_finetune_bert(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune"])
        assert os.path.exists(save_filename)
        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
        # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer
        assert saved_model['params']['config']['force_bert_saved']
        assert any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())

    def test_finetune_bert_layers(self, tmp_path, fake_embeddings, train_file, dev_file):
        """Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory, using 2 layers

        As an added bonus (or eager test), load the finished model and continue
        training from there. Then check that the initial model and
        the middle model are different, then that the middle model and
        final model are different

        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models"])
        assert os.path.exists(save_filename)

        save_path = os.path.split(save_filename)[0]

        # intermediate save files are globbed by epoch marker (E0000, E0002, ...)
        initial_model = glob.glob(os.path.join(save_path, "*E0000*"))
        assert len(initial_model) == 1
        initial_model = initial_model[0]
        initial_model = torch.load(initial_model, lambda storage, loc: storage, weights_only=True)

        second_model_file = glob.glob(os.path.join(save_path, "*E0002*"))
        assert len(second_model_file) == 1
        second_model_file = second_model_file[0]
        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)

        # finetuning must have changed at least one weight in each transformer layer
        for layer_idx in range(2):
            bert_names = [x for x in second_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." % layer_idx in x]
            assert len(bert_names) > 0
            assert all(x in initial_model['params']['model'] and x in second_model['params']['model'] for x in bert_names)
            assert not all(torch.allclose(initial_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names)

        # put some random marker in the file to look for later,
        # check the continued training didn't clobber the expected file
        assert "asdf" not in second_model
        second_model["asdf"] = 1234
        torch.save(second_model, second_model_file)

        trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file)

        second_model_file_redo = glob.glob(os.path.join(save_path, "*E0002*"))
        assert len(second_model_file_redo) == 1
        assert second_model_file == second_model_file_redo[0]
        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)
        assert "asdf" in second_model

        fifth_model_file = glob.glob(os.path.join(save_path, "*E0005*"))
        assert len(fifth_model_file) == 1

        # the resumed training must also have moved the transformer weights
        final_model = torch.load(fifth_model_file[0], lambda storage, loc: storage, weights_only=True)
        for layer_idx in range(2):
            bert_names = [x for x in final_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." % layer_idx in x]
            assert len(bert_names) > 0
            assert all(x in final_model['params']['model'] and x in second_model['params']['model'] for x in bert_names)
            assert not all(torch.allclose(final_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names)

    def test_finetune_peft(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test on a tiny Bert with PEFT finetuning
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler"])
        assert os.path.exists(save_filename)
        saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
        # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer, but only in peft form
        assert saved_model['params']['config']['bert_model'] == bert_model
        assert saved_model['params']['config']['force_bert_saved']
        assert saved_model['params']['config']['use_peft']

        assert not saved_model['params']['config']['has_charlm_forward']
        assert not saved_model['params']['config']['has_charlm_backward']

        assert len(saved_model['params']['bert_lora']) > 0
        assert any(x.find(".pooler.") >= 0 for x in saved_model['params']['bert_lora'])
        assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora'])
        assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())

        # The Pipeline should load and run a PEFT trained model,
        # although obviously we don't expect the results to do
        # anything correct
        pipeline = stanza.Pipeline("en", download_method=None, model_dir=TEST_MODELS_DIR, processors="tokenize,sentiment", sentiment_model_path=save_filename, sentiment_pretrain_path=str(fake_embeddings))
        doc = pipeline("This is a test")

    def test_finetune_peft_restart(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test that if we restart training on a peft model, the peft weights change
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models"])

        assert os.path.exists(save_file)
        saved_model = torch.load(save_file, lambda storage, loc: storage, weights_only=True)
        assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora'])


        trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file)

        save_path = os.path.split(save_file)[0]

        initial_model_file = glob.glob(os.path.join(save_path, "*E0000*"))
        assert len(initial_model_file) == 1
        initial_model_file = initial_model_file[0]
        initial_model = torch.load(initial_model_file, lambda storage, loc: storage, weights_only=True)

        second_model_file = glob.glob(os.path.join(save_path, "*E0002*"))
        assert len(second_model_file) == 1
        second_model_file = second_model_file[0]
        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)

        final_model_file = glob.glob(os.path.join(save_path, "*E0005*"))
        assert len(final_model_file) == 1
        final_model_file = final_model_file[0]
        final_model = torch.load(final_model_file, lambda storage, loc: storage, weights_only=True)

        # params in initial_model & second_model start with "base_model.model."
        # whereas params in final_model start directly with "encoder" or "pooler"
        initial_lora = initial_model['params']['bert_lora']
        second_lora = second_model['params']['bert_lora']
        final_lora = final_model['params']['bert_lora']
        for side in ("_A.", "_B."):
            for layer in (".0.", ".1."):
                # note: startswith(...) > 0 below is just a truthiness check
                initial_params = sorted([x for x in initial_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0])
                second_params = sorted([x for x in second_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0])
                final_params = sorted([x for x in final_lora if x.startswith("encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0])
                assert len(initial_params) > 0
                assert len(initial_params) == len(second_params)
                assert len(initial_params) == len(final_params)
                for x, y in zip(second_params, final_params):
                    assert x.endswith(y)
                    if side != "_A.": # the A tensors don't move very much, if at all
                        assert not torch.allclose(initial_lora.get(x), second_lora.get(x))
                        assert not torch.allclose(second_lora.get(x), final_lora.get(y))

stanza/stanza/tests/classifiers/test_process_utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A few tests of the utils module for the sentiment datasets
3
+ """
4
+
5
+ import os
6
+ import pytest
7
+
8
+ import stanza
9
+
10
+ from stanza.models.classifiers import data
11
+ from stanza.models.classifiers.data import SentimentDatum
12
+ from stanza.models.classifiers.utils import WVType
13
+ from stanza.utils.datasets.sentiment import process_utils
14
+
15
+ from stanza.tests import TEST_MODELS_DIR
16
+ from stanza.tests.classifiers.test_data import train_file, dev_file, test_file
17
+
18
+
def test_write_list(tmp_path, train_file):
    """
    A single list of items written to a file should round-trip through read_dataset
    """
    original = data.read_dataset(train_file, WVType.OTHER, 1)

    output_file = tmp_path / "foo.json"
    process_utils.write_list(output_file, original)

    reread = data.read_dataset(output_file, WVType.OTHER, 1)
    assert reread == original
30
+
def test_write_dataset(tmp_path, train_file, dev_file, test_file):
    """
    Writing all three splits of a dataset should create exactly the
    expected files, each round-tripping to its original contents
    """
    splits = [data.read_dataset(filename, WVType.OTHER, 1)
              for filename in (train_file, dev_file, test_file)]
    process_utils.write_dataset(splits, tmp_path, "en_test")

    expected_files = ['en_test.train.json', 'en_test.dev.json', 'en_test.test.json']
    assert sorted(os.listdir(tmp_path)) == sorted(expected_files)

    for filename, split in zip(expected_files, splits):
        assert data.read_dataset(tmp_path / filename, WVType.OTHER, 1) == split
45
+
def test_read_snippets(tmp_path):
    """
    Basic operation of read_snippets: one labeled snippet per TSV line
    """
    filename = tmp_path / "foo.csv"
    rows = ["FOO\tThis is a test\thappy\n",
            "FOO\tThis is a second sentence\tsad\n"]
    with open(filename, "w", encoding="utf-8") as fout:
        fout.writelines(rows)

    nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)

    mapping = {"happy": 0, "sad": 1}

    # column 2 holds the sentiment label, column 1 the text
    snippets = process_utils.read_snippets(filename, 2, 1, "en", mapping, nlp=nlp)
    assert len(snippets) == 2
    assert snippets == [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),
                        SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence'])]
63
+
def test_read_snippets_two_columns(tmp_path):
    """
    Multiple columns can be combined into a tuple key for the sentiment mapping
    """
    filename = tmp_path / "foo.csv"
    rows = ["FOO\tThis is a test\thappy\tfoo\n",
            "FOO\tThis is a second sentence\tsad\tbar\n",
            "FOO\tThis is a third sentence\tsad\tfoo\n"]
    with open(filename, "w", encoding="utf-8") as fout:
        fout.writelines(rows)

    nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)

    # the label is the pair of values from columns 2 and 3
    mapping = {("happy", "foo"): 0, ("sad", "bar"): 1, ("sad", "foo"): 2}

    snippets = process_utils.read_snippets(filename, (2,3), 1, "en", mapping, nlp=nlp)
    assert len(snippets) == 3
    assert snippets == [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),
                        SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence']),
                        SentimentDatum(sentiment=2, text=['This', 'is', 'a', 'third', 'sentence'])]
83
+
stanza/stanza/tests/common/test_bert_embedding.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
import torch

from stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

# a tiny model so the tests stay fast
BERT_MODEL = "hf-internal-testing/tiny-bert"

@pytest.fixture(scope="module")
def tiny_bert():
    """Load the tiny bert (model, tokenizer) once for the whole module"""
    return load_bert(BERT_MODEL)

def test_load_bert(tiny_bert):
    """
    Empty method that just tests loading the bert
    """
    model, tokenizer = tiny_bert

def test_run_bert(tiny_bert):
    """Run embedding extraction on one short sentence without crashing"""
    model, tokenizer = tiny_bert
    device = next(model.parameters()).device
    extract_bert_embeddings(BERT_MODEL, tokenizer, model, [["This", "is", "a", "test"]], device, True)

def test_run_bert_empty_word(tiny_bert):
    """An empty-string token should produce the same embeddings as a "-" token"""
    model, tokenizer = tiny_bert
    device = next(model.parameters()).device
    with_dash = extract_bert_embeddings(BERT_MODEL, tokenizer, model, [["This", "is", "-", "a", "test"]], device, True)
    with_empty = extract_bert_embeddings(BERT_MODEL, tokenizer, model, [["This", "is", "", "a", "test"]], device, True)

    # one sentence in, one embedding tensor out
    assert len(with_dash) == 1
    assert torch.allclose(with_dash[0], with_empty[0])
stanza/stanza/tests/common/test_char_model.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Currently tests a few configurations of files for creating a charlm vocab
3
+
4
+ Also has a skeleton test of loading & saving a charlm
5
+ """
6
+
7
+ from collections import Counter
8
+ import glob
9
+ import lzma
10
+ import os
11
+ import tempfile
12
+
13
+ import pytest
14
+
15
+ from stanza.models import charlm
16
+ from stanza.models.common import char_model
17
+ from stanza.tests import TEST_MODELS_DIR
18
+
19
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
20
+
21
+ fake_text_1 = """
22
+ Unban mox opal!
23
+ I hate watching Peppa Pig
24
+ """
25
+
26
+ fake_text_2 = """
27
+ This is plastic cheese
28
+ """
29
+
30
+ class TestCharModel:
31
+ def test_single_file_vocab(self):
32
+ with tempfile.TemporaryDirectory() as tempdir:
33
+ sample_file = os.path.join(tempdir, "text.txt")
34
+ with open(sample_file, "w", encoding="utf-8") as fout:
35
+ fout.write(fake_text_1)
36
+ vocab = char_model.build_charlm_vocab(sample_file)
37
+
38
+ for i in fake_text_1:
39
+ assert i in vocab
40
+ assert "Q" not in vocab
41
+
42
+ def test_single_file_xz_vocab(self):
43
+ with tempfile.TemporaryDirectory() as tempdir:
44
+ sample_file = os.path.join(tempdir, "text.txt.xz")
45
+ with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
46
+ fout.write(fake_text_1)
47
+ vocab = char_model.build_charlm_vocab(sample_file)
48
+
49
+ for i in fake_text_1:
50
+ assert i in vocab
51
+ assert "Q" not in vocab
52
+
53
+ def test_single_file_dir_vocab(self):
54
+ with tempfile.TemporaryDirectory() as tempdir:
55
+ sample_file = os.path.join(tempdir, "text.txt")
56
+ with open(sample_file, "w", encoding="utf-8") as fout:
57
+ fout.write(fake_text_1)
58
+ vocab = char_model.build_charlm_vocab(tempdir)
59
+
60
+ for i in fake_text_1:
61
+ assert i in vocab
62
+ assert "Q" not in vocab
63
+
64
+ def test_multiple_files_vocab(self):
65
+ with tempfile.TemporaryDirectory() as tempdir:
66
+ sample_file = os.path.join(tempdir, "t1.txt")
67
+ with open(sample_file, "w", encoding="utf-8") as fout:
68
+ fout.write(fake_text_1)
69
+ sample_file = os.path.join(tempdir, "t2.txt.xz")
70
+ with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
71
+ fout.write(fake_text_2)
72
+ vocab = char_model.build_charlm_vocab(tempdir)
73
+
74
+ for i in fake_text_1:
75
+ assert i in vocab
76
+ for i in fake_text_2:
77
+ assert i in vocab
78
+ assert "Q" not in vocab
79
+
80
+ def test_cutoff_vocab(self):
81
+ with tempfile.TemporaryDirectory() as tempdir:
82
+ sample_file = os.path.join(tempdir, "t1.txt")
83
+ with open(sample_file, "w", encoding="utf-8") as fout:
84
+ fout.write(fake_text_1)
85
+ sample_file = os.path.join(tempdir, "t2.txt.xz")
86
+ with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
87
+ fout.write(fake_text_2)
88
+
89
+ vocab = char_model.build_charlm_vocab(tempdir, cutoff=2)
90
+
91
+ counts = Counter(fake_text_1) + Counter(fake_text_2)
92
+ for letter, count in counts.most_common():
93
+ if count < 2:
94
+ assert letter not in vocab
95
+ else:
96
+ assert letter in vocab
97
+
98
+ def test_build_model(self):
99
+ """
100
+ Test the whole thing on a small dataset for an iteration or two
101
+ """
102
+ with tempfile.TemporaryDirectory() as tempdir:
103
+ eval_file = os.path.join(tempdir, "en_test.dev.txt")
104
+ with open(eval_file, "w", encoding="utf-8") as fout:
105
+ fout.write(fake_text_1)
106
+ train_file = os.path.join(tempdir, "en_test.train.txt")
107
+ with open(train_file, "w", encoding="utf-8") as fout:
108
+ for i in range(1000):
109
+ fout.write(fake_text_1)
110
+ fout.write("\n")
111
+ fout.write(fake_text_2)
112
+ fout.write("\n")
113
+ save_name = 'en_test.forward.pt'
114
+ vocab_save_name = 'en_text.vocab.pt'
115
+ checkpoint_save_name = 'en_text.checkpoint.pt'
116
+ args = ['--train_file', train_file,
117
+ '--eval_file', eval_file,
118
+ '--eval_steps', '0', # eval once per opoch
119
+ '--epochs', '2',
120
+ '--cutoff', '1',
121
+ '--batch_size', '%d' % len(fake_text_1),
122
+ '--shorthand', 'en_test',
123
+ '--save_dir', tempdir,
124
+ '--save_name', save_name,
125
+ '--vocab_save_name', vocab_save_name,
126
+ '--checkpoint_save_name', checkpoint_save_name]
127
+ args = charlm.parse_args(args)
128
+ charlm.train(args)
129
+
130
+ assert os.path.exists(os.path.join(tempdir, vocab_save_name))
131
+
132
+ # test that saving & loading of the model worked
133
+ assert os.path.exists(os.path.join(tempdir, save_name))
134
+ model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, save_name))
135
+
136
+ # test that saving & loading of the checkpoint worked
137
+ assert os.path.exists(os.path.join(tempdir, checkpoint_save_name))
138
+ model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, checkpoint_save_name))
139
+ trainer = char_model.CharacterLanguageModelTrainer.load(args, os.path.join(tempdir, checkpoint_save_name))
140
+
141
+ assert trainer.global_step > 0
142
+ assert trainer.epoch == 2
143
+
144
+ # quick test to verify this method works with a trained model
145
+ charlm.get_current_lr(trainer, args)
146
+
147
+ # test loading a vocab built by the training method...
148
+ vocab = charlm.load_char_vocab(os.path.join(tempdir, vocab_save_name))
149
+ trainer = char_model.CharacterLanguageModelTrainer.from_new_model(args, vocab)
150
+ # ... and test the get_current_lr for an untrained model as well
151
+ # this test is super "eager"
152
+ assert charlm.get_current_lr(trainer, args) == args['lr0']
153
+
154
+ @pytest.fixture(scope="class")
155
+ def english_forward(self):
156
+ # eg, stanza_test/models/en/forward_charlm/1billion.pt
157
+ models_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "*")
158
+ models = glob.glob(models_path)
159
+ # we expect at least one English model downloaded for the tests
160
+ assert len(models) >= 1
161
+ model_file = models[0]
162
+ return char_model.CharacterLanguageModel.load(model_file)
163
+
164
+ @pytest.fixture(scope="class")
165
+ def english_backward(self):
166
+ # eg, stanza_test/models/en/forward_charlm/1billion.pt
167
+ models_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "*")
168
+ models = glob.glob(models_path)
169
+ # we expect at least one English model downloaded for the tests
170
+ assert len(models) >= 1
171
+ model_file = models[0]
172
+ return char_model.CharacterLanguageModel.load(model_file)
173
+
174
+ def test_load_model(self, english_forward, english_backward):
175
+ """
176
+ Check that basic loading functions work
177
+ """
178
+ assert english_forward.is_forward_lm
179
+ assert not english_backward.is_forward_lm
180
+
181
+ def test_save_load_model(self, english_forward, english_backward):
182
+ """
183
+ Load, save, and load again
184
+ """
185
+ with tempfile.TemporaryDirectory() as tempdir:
186
+ for model in (english_forward, english_backward):
187
+ save_file = os.path.join(tempdir, "resaved", "charlm.pt")
188
+ model.save(save_file)
189
+ reloaded = char_model.CharacterLanguageModel.load(save_file)
190
+ assert model.is_forward_lm == reloaded.is_forward_lm
stanza/stanza/tests/common/test_common_data.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest
import stanza

from stanza.tests import *
from stanza.models.common.data import get_augment_ratio, augment_punct

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

def test_augment_ratio():
    """Check the augmentation ratio computation and its precondition assertion"""
    items = list(range(1, 11))
    needs_augment = lambda value: value >= 3
    augmentable = lambda value: value >= 4

    # two items (1 and 2) are already satisfactory, which covers a
    # desired ratio of 0.1, so no augmentation should be requested
    assert get_augment_ratio(items, needs_augment, augmentable, desired_ratio=0.1) == 0.0

    # swapping the two predicates violates the function's precondition
    with pytest.raises(AssertionError):
        get_augment_ratio(items, augmentable, needs_augment)

    # desired ratio 0.4: 2 items already fine and 7 eligible for
    # augmentation, so 2/7 of the eligible ones must be augmented
    assert get_augment_ratio(items, needs_augment, augmentable, desired_ratio=0.4) == pytest.approx(2 / 7)

def test_augment_punct():
    """A ratio of 1.0 strips the trailing period from every eligible sentence"""
    sentences = [["Simple", "test", "."]]
    ends_with_period = lambda words: words[-1] == "."
    augmented = augment_punct(sentences, 1.0, ends_with_period, ends_with_period)
    assert augmented == [["Simple", "test"]]
stanza/stanza/tests/common/test_data_objects.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic tests of the stanza data objects, especially the setter/getter routines
3
+ """
4
+ import pytest
5
+
6
+ import stanza
7
+ from stanza.models.common.doc import Document, Sentence, Word
8
+ from stanza.tests import *
9
+
10
+ pytestmark = pytest.mark.pipeline
11
+
12
+ # data for testing
13
+ EN_DOC = "This is a test document. Pretty cool!"
14
+
15
+ EN_DOC_UPOS_XPOS = (('PRON_DT', 'AUX_VBZ', 'DET_DT', 'NOUN_NN', 'NOUN_NN', 'PUNCT_.'), ('ADV_RB', 'ADJ_JJ', 'PUNCT_.'))
16
+
17
+ EN_DOC2 = "Chris Manning wrote a sentence. Then another."
18
+
19
+ @pytest.fixture(scope="module")
20
+ def nlp_pipeline():
21
+ nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en')
22
+ return nlp
23
+
24
+ def test_readonly(nlp_pipeline):
25
+ Document.add_property('some_property', 123)
26
+ doc = nlp_pipeline(EN_DOC)
27
+ assert doc.some_property == 123
28
+ with pytest.raises(ValueError):
29
+ doc.some_property = 456
30
+
31
+
32
+ def test_getter(nlp_pipeline):
33
+ Word.add_property('upos_xpos', getter=lambda self: f"{self.upos}_{self.xpos}")
34
+
35
+ doc = nlp_pipeline(EN_DOC)
36
+
37
+ assert EN_DOC_UPOS_XPOS == tuple(tuple(word.upos_xpos for word in sentence.words) for sentence in doc.sentences)
38
+
39
+ def test_setter_getter(nlp_pipeline):
40
+ int2str = {0: 'ok', 1: 'good', 2: 'bad'}
41
+ str2int = {'ok': 0, 'good': 1, 'bad': 2}
42
+ def setter(self, value):
43
+ self._classname = str2int[value]
44
+ Sentence.add_property('classname', getter=lambda self: int2str[self._classname] if self._classname is not None else None, setter=setter)
45
+
46
+ doc = nlp_pipeline(EN_DOC)
47
+ sentence = doc.sentences[0]
48
+ sentence.classname = 'good'
49
+ assert sentence._classname == 1
50
+
51
+ # don't try this at home
52
+ sentence._classname = 2
53
+ assert sentence.classname == 'bad'
54
+
55
+ def test_backpointer(nlp_pipeline):
56
+ doc = nlp_pipeline(EN_DOC2)
57
+ ent = doc.ents[0]
58
+ assert ent.sent is doc.sentences[0]
59
+ assert list(doc.iter_words())[0].sent is doc.sentences[0]
60
+ assert list(doc.iter_tokens())[-1].sent is doc.sentences[-1]
stanza/stanza/tests/common/test_doc.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest

import stanza
from stanza.tests import *
from stanza.models.common.doc import Document, ID, TEXT, NER, CONSTITUENCY, SENTIMENT

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

@pytest.fixture
def sentences_dict():
    """Two raw sentences in the dict format the Document constructor accepts"""
    return [[{ID: 1, TEXT: "unban"},
             {ID: 2, TEXT: "mox"},
             {ID: 3, TEXT: "opal"}],
            [{ID: 4, TEXT: "ban"},
             {ID: 5, TEXT: "Lurrus"}]]

@pytest.fixture
def doc(sentences_dict):
    """A fresh Document built from the raw sentences"""
    doc = Document(sentences_dict)
    return doc

def test_basic_values(doc, sentences_dict):
    """
    Test that sentences & token text are properly set when constructing a doc
    """
    assert len(doc.sentences) == len(sentences_dict)

    for sentence, raw_sentence in zip(doc.sentences, sentences_dict):
        assert sentence.doc == doc
        assert len(sentence.tokens) == len(raw_sentence)
        for token, raw_token in zip(sentence.tokens, raw_sentence):
            assert token.text == raw_token[TEXT]

def test_set_sentence(doc):
    """
    Test setting a field on the sentences themselves
    """
    doc.set(fields="sentiment",
            contents=["4", "0"],
            to_sentence=True)

    assert doc.sentences[0].sentiment == "4"
    assert doc.sentences[1].sentiment == "0"

def test_set_tokens(doc):
    """
    Test setting values on tokens
    """
    ner_contents = ["O", "ARTIFACT", "ARTIFACT", "O", "CAT"]
    doc.set(fields=NER,
            contents=ner_contents,
            to_token=True)

    result = doc.get(NER, from_token=True)
    assert result == ner_contents


def test_constituency_comment(doc):
    """
    Test that setting the constituency tree on a doc sets the constituency comment
    """
    for sentence in doc.sentences:
        assert len([x for x in sentence.comments if x.startswith("# constituency")]) == 0

    # currently nothing is checking that the items are actually trees
    trees = ["asdf", "zzzz"]
    doc.set(fields=CONSTITUENCY,
            contents=trees,
            to_sentence=True)

    for sentence, expected in zip(doc.sentences, trees):
        constituency_comments = [x for x in sentence.comments if x.startswith("# constituency")]
        assert len(constituency_comments) == 1
        assert constituency_comments[0].endswith(expected)

    # Test that if we replace the trees with an updated tree, the comment is also replaced
    trees = ["zzzz", "asdf"]
    doc.set(fields=CONSTITUENCY,
            contents=trees,
            to_sentence=True)

    for sentence, expected in zip(doc.sentences, trees):
        constituency_comments = [x for x in sentence.comments if x.startswith("# constituency")]
        assert len(constituency_comments) == 1
        assert constituency_comments[0].endswith(expected)

def test_sentiment_comment(doc):
    """
    Test that setting the sentiment on a doc sets the sentiment comment
    """
    for sentence in doc.sentences:
        assert len([x for x in sentence.comments if x.startswith("# sentiment")]) == 0

    # currently nothing is checking that the items are valid sentiment values
    sentiments = ["1", "2"]
    doc.set(fields=SENTIMENT,
            contents=sentiments,
            to_sentence=True)

    for sentence, expected in zip(doc.sentences, sentiments):
        sentiment_comments = [x for x in sentence.comments if x.startswith("# sentiment")]
        assert len(sentiment_comments) == 1
        assert sentiment_comments[0].endswith(expected)

    # Test that if we replace the sentiments with updated values, the comment is also replaced
    sentiments = ["3", "4"]
    doc.set(fields=SENTIMENT,
            contents=sentiments,
            to_sentence=True)

    for sentence, expected in zip(doc.sentences, sentiments):
        sentiment_comments = [x for x in sentence.comments if x.startswith("# sentiment")]
        assert len(sentiment_comments) == 1
        assert sentiment_comments[0].endswith(expected)

def test_sent_id_comment(doc):
    """
    Test that setting the sent_id on a sentence sets the sent_id comment
    """
    for sent_idx, sentence in enumerate(doc.sentences):
        assert len([x for x in sentence.comments if x.startswith("# sent_id")]) == 1
        assert sentence.sent_id == "%d" % sent_idx
    doc.sentences[0].sent_id = "foo"
    assert doc.sentences[0].sent_id == "foo"
    assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1
    assert "# sent_id = foo" in doc.sentences[0].comments

    doc.reindex_sentences(10)
    for sent_idx, sentence in enumerate(doc.sentences):
        assert sentence.sent_id == "%d" % (sent_idx + 10)
        # NOTE(review): this checks sentences[0] on every iteration —
        # possibly `sentence` was intended; confirm
        assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1
        assert "# sent_id = %d" % (sent_idx + 10) in sentence.comments

    # adding a sent_id comment replaces the existing one rather than duplicating it
    doc.sentences[0].add_comment("# sent_id = bar")
    assert doc.sentences[0].sent_id == "bar"
    assert "# sent_id = bar" in doc.sentences[0].comments
    assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1

def test_doc_id_comment(doc):
    """
    Test that setting the doc_id on a sentence sets the doc_id comment
    """
    assert doc.sentences[0].doc_id is None
    assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 0

    doc.sentences[0].doc_id = "foo"
    assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 1
    assert "# doc_id = foo" in doc.sentences[0].comments
    assert doc.sentences[0].doc_id == "foo"

    # adding a doc_id comment replaces the existing one rather than duplicating it
    doc.sentences[0].add_comment("# doc_id = bar")
    assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 1
    assert doc.sentences[0].doc_id == "bar"

@pytest.fixture(scope="module")
def pipeline():
    """Full default pipeline, built once for the whole module"""
    return stanza.Pipeline(dir=TEST_MODELS_DIR)

def test_serialized(pipeline):
    """
    Brief test of the serialized format

    Checks that NER entities are correctly set.
    Also checks that constituency & sentiment are set on the sentences.
    """
    text = "John Bauer works at Stanford"
    doc = pipeline(text)
    assert len(doc.ents) == 2
    serialized = doc.to_serialized()
    doc2 = Document.from_serialized(serialized)
    assert len(doc2.sentences) == 1
    assert len(doc2.ents) == 2
    assert doc.sentences[0].constituency == doc2.sentences[0].constituency
    assert doc.sentences[0].sentiment == doc2.sentences[0].sentiment
stanza/stanza/tests/common/test_dropout.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest

import torch

import stanza
from stanza.models.common.dropout import WordDropout

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

def test_word_dropout():
    """
    Check that WordDropout randomly zeroes entire rows of a tensor

    600 small rows keeps this fast while making a spurious failure
    (every row kept, or every row dropped) astronomically unlikely.
    """
    dropout = WordDropout(0.5)
    original = torch.randn(600, 4)
    result = dropout(original)
    # the one time any of this happens, it's going to be really confusing
    assert not torch.allclose(original, result)
    dropped_rows = 0
    for row_idx in range(original.shape[0]):
        row_kept = torch.allclose(result[row_idx], original[row_idx])
        row_zeroed = torch.sum(result[row_idx]) == 0.0
        # every row is either untouched or fully zeroed
        assert row_kept or row_zeroed
        if row_zeroed:
            dropped_rows += 1
    assert 0 < dropped_rows < original.shape[0]
stanza/stanza/tests/common/test_short_name_to_treebank.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest

import stanza
from stanza.models.common import short_name_to_treebank

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

def test_short_name():
    """A shorthand maps to the full UD treebank name"""
    result = short_name_to_treebank.short_name_to_treebank("en_ewt")
    assert result == "UD_English-EWT"

def test_canonical_name():
    """Canonicalization fixes capitalization, resolves shorthands, and leaves unknown names alone"""
    for query in ("UD_URDU-UDTB", "ur_udtb"):
        assert short_name_to_treebank.canonical_treebank_name(query) == "UD_Urdu-UDTB"
    assert short_name_to_treebank.canonical_treebank_name("Unban_Mox_Opal") == "Unban_Mox_Opal"
stanza/stanza/tests/constituency/test_convert_it_vit.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test a couple different classes of trees to check the output of the VIT conversion
3
+
4
+ A couple representative trees are included, but hopefully not enough
5
+ to be a problem in terms of our license.
6
+
7
+ One of the tests is currently disabled as it relies on tregex & tsurgeon features
8
+ not yet released
9
+ """
10
+
11
+ import io
12
+ import os
13
+ import tempfile
14
+
15
+ import pytest
16
+
17
+ from stanza.server import tsurgeon
18
+ from stanza.utils.conll import CoNLL
19
+ from stanza.utils.datasets.constituency import convert_it_vit
20
+
21
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
22
+
23
+ # just a sample! don't sue us please
24
+ CON_SAMPLE = """
25
+ #ID=sent_00002 cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]]
26
+
27
+ #ID=sent_00318 dirsp-[fc-[congf-tuttavia, f-[sn-[sq-[ind-qualche], n-problema], ir_infl-[vsupir-potrebbe, vcl-esserci], compc-[clit-ci, sp-[p-per, sn-[art-la, n-commissione, sa-[ag-esteri], f2-[sp-[part-alla, relob-cui, sn-[n-presidenza]], f-[ibar-[vc-è], compc-[sn-[n-candidato], sn-[art-l, n-esponente, spd-[pd-di, sn-[mw-Alleanza, npro-Nazionale]], sn-[mw-Mirko, nh-Tremaglia]]]]]]]]]], dirs-':', f3-[sn-[art-una, n-candidatura, sc-[q-più, sa-[ppas-subìta], sc-[ccong-che, sa-[ppas-gradita]], compt-[spda-[partda-dalla, sn-[mw-Lega, npro-Nord, punt-',', f2-[rel-che, fc-[congf-tuttavia, f-[ir_infl-[vsupir-dovrebbe, vit-rispettare], compt-[sn-[art-gli, n-accordi]]]]]]]]]], punto-.]]
28
+
29
+ #ID=sent_00589 f-[sn-[art-l, n-ottimismo, spd-[pd-di, sn-[nh-Kantor]]], ir_infl-[vsupir-potrebbe, congf-però, vcl-rivelarsi], compc-[sn-[in-ancora, art-una, nt-volta], sa-[ag-prematuro]], punto-.]
30
+ """
31
+
32
+ UD_SAMPLE = """
33
+ # sent_id = VIT-2
34
+ # text = Negli ultimi anni la dinamica dei polo di attrazione è stata sempre più caratterizzata dall'emergere di una crescente concorrenza che si è progressivamente spostata dalle singole imprese ai sistemi economici e territoriali, determinando l'esigenza di una riconsiderazione dei rapporti esistenti tra soggetti produttivi e ambiente in cui questi operano.
35
+ 1-2 Negli _ _ _ _ _ _ _ _
36
+ 1 In in ADP E _ 4 case _ _
37
+ 2 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 4 det _ _
38
+ 3 ultimi ultimo ADJ A Gender=Masc|Number=Plur 4 amod _ _
39
+ 4 anni anno NOUN S Gender=Masc|Number=Plur 16 obl _ _
40
+ 5 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 6 det _ _
41
+ 6 dinamica dinamica NOUN S Gender=Fem|Number=Sing 16 nsubj:pass _ _
42
+ 7-8 dei _ _ _ _ _ _ _ _
43
+ 7 di di ADP E _ 9 case _ _
44
+ 8 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 9 det _ _
45
+ 9 polo polo NOUN S Gender=Masc|Number=Sing 6 nmod _ _
46
+ 10 di di ADP E _ 11 case _ _
47
+ 11 attrazione attrazione NOUN S Gender=Fem|Number=Sing 9 nmod _ _
48
+ 12 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux _ _
49
+ 13 stata essere AUX VA Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 16 aux:pass _ _
50
+ 14 sempre sempre ADV B _ 15 advmod _ _
51
+ 15 più più ADV B _ 16 advmod _ _
52
+ 16 caratterizzata caratterizzare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 0 root _ _
53
+ 17-18 dall' _ _ _ _ _ _ _ SpaceAfter=No
54
+ 17 da da ADP E _ 19 case _ _
55
+ 18 l' il DET RD Definite=Def|Number=Sing|PronType=Art 19 det _ _
56
+ 19 emergere emergere NOUN S Gender=Masc|Number=Sing 16 obl _ _
57
+ 20 di di ADP E _ 23 case _ _
58
+ 21 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 23 det _ _
59
+ 22 crescente crescente ADJ A Number=Sing 23 amod _ _
60
+ 23 concorrenza concorrenza NOUN S Gender=Fem|Number=Sing 19 nmod _ _
61
+ 24 che che PRON PR PronType=Rel 28 nsubj _ _
62
+ 25 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 28 expl _ _
63
+ 26 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 28 aux _ _
64
+ 27 progressivamente progressivamente ADV B _ 28 advmod _ _
65
+ 28 spostata spostare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 23 acl:relcl _ _
66
+ 29-30 dalle _ _ _ _ _ _ _ _
67
+ 29 da da ADP E _ 32 case _ _
68
+ 30 le il DET RD Definite=Def|Gender=Fem|Number=Plur|PronType=Art 32 det _ _
69
+ 31 singole singolo ADJ A Gender=Fem|Number=Plur 32 amod _ _
70
+ 32 imprese impresa NOUN S Gender=Fem|Number=Plur 28 obl _ _
71
+ 33-34 ai _ _ _ _ _ _ _ _
72
+ 33 a a ADP E _ 35 case _ _
73
+ 34 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 35 det _ _
74
+ 35 sistemi sistema NOUN S Gender=Masc|Number=Plur 28 obl _ _
75
+ 36 economici economico ADJ A Gender=Masc|Number=Plur 35 amod _ _
76
+ 37 e e CCONJ CC _ 38 cc _ _
77
+ 38 territoriali territoriale ADJ A Number=Plur 36 conj _ SpaceAfter=No
78
+ 39 , , PUNCT FF _ 28 punct _ _
79
+ 40 determinando determinare VERB V VerbForm=Ger 28 advcl _ _
80
+ 41 l' il DET RD Definite=Def|Number=Sing|PronType=Art 42 det _ SpaceAfter=No
81
+ 42 esigenza esigenza NOUN S Gender=Fem|Number=Sing 40 obj _ _
82
+ 43 di di ADP E _ 45 case _ _
83
+ 44 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 45 det _ _
84
+ 45 riconsiderazione riconsiderazione NOUN S Gender=Fem|Number=Sing 42 nmod _ _
85
+ 46-47 dei _ _ _ _ _ _ _ _
86
+ 46 di di ADP E _ 48 case _ _
87
+ 47 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 48 det _ _
88
+ 48 rapporti rapporto NOUN S Gender=Masc|Number=Plur 45 nmod _ _
89
+ 49 esistenti esistente VERB V Number=Plur 48 acl _ _
90
+ 50 tra tra ADP E _ 51 case _ _
91
+ 51 soggetti soggetto NOUN S Gender=Masc|Number=Plur 49 obl _ _
92
+ 52 produttivi produttivo ADJ A Gender=Masc|Number=Plur 51 amod _ _
93
+ 53 e e CCONJ CC _ 54 cc _ _
94
+ 54 ambiente ambiente NOUN S Gender=Masc|Number=Sing 51 conj _ _
95
+ 55 in in ADP E _ 56 case _ _
96
+ 56 cui cui PRON PR PronType=Rel 58 obl _ _
97
+ 57 questi questo PRON PD Gender=Masc|Number=Plur|PronType=Dem 58 nsubj _ _
98
+ 58 operano operare VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 54 acl:relcl _ SpaceAfter=No
99
+ 59 . . PUNCT FS _ 16 punct _ _
100
+
101
+ # sent_id = VIT-318
102
+ # text = Tuttavia qualche problema potrebbe esserci per la commissione esteri alla cui presidenza è candidato l'esponente di Alleanza Nazionale Mirko Tremaglia: una candidatura più subìta che gradita dalla Lega Nord, che tuttavia dovrebbe rispettare gli accordi.
103
+ 1 Tuttavia tuttavia CCONJ CC _ 5 cc _ _
104
+ 2 qualche qualche DET DI Number=Sing|PronType=Ind 3 det _ _
105
+ 3 problema problema NOUN S Gender=Masc|Number=Sing 5 nsubj _ _
106
+ 4 potrebbe potere AUX VA Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux _ _
107
+ 5-6 esserci _ _ _ _ _ _ _ _
108
+ 5 esser essere VERB V VerbForm=Inf 0 root _ _
109
+ 6 ci ci PRON PC Clitic=Yes|Number=Plur|Person=1|PronType=Prs 5 expl _ _
110
+ 7 per per ADP E _ 9 case _ _
111
+ 8 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 9 det _ _
112
+ 9 commissione commissione NOUN S Gender=Fem|Number=Sing 5 obl _ _
113
+ 10 esteri estero ADJ A Gender=Masc|Number=Plur 9 amod _ _
114
+ 11-12 alla _ _ _ _ _ _ _ _
115
+ 11 a a ADP E _ 14 case _ _
116
+ 12 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 14 det _ _
117
+ 13 cui cui DET DR PronType=Rel 14 det:poss _ _
118
+ 14 presidenza presidenza NOUN S Gender=Fem|Number=Sing 16 obl _ _
119
+ 15 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux:pass _ _
120
+ 16 candidato candidare VERB V Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part 9 acl:relcl _ _
121
+ 17 l' il DET RD Definite=Def|Number=Sing|PronType=Art 18 det _ SpaceAfter=No
122
+ 18 esponente esponente NOUN S Number=Sing 16 nsubj:pass _ _
123
+ 19 di di ADP E _ 20 case _ _
124
+ 20 Alleanza Alleanza PROPN SP _ 18 nmod _ _
125
+ 21 Nazionale Nazionale PROPN SP _ 20 flat:name _ _
126
+ 22 Mirko Mirko PROPN SP _ 18 nmod _ _
127
+ 23 Tremaglia Tremaglia PROPN SP _ 22 flat:name _ SpaceAfter=No
128
+ 24 : : PUNCT FC _ 22 punct _ _
129
+ 25 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 26 det _ _
130
+ 26 candidatura candidatura NOUN S Gender=Fem|Number=Sing 22 appos _ _
131
+ 27 più più ADV B _ 28 advmod _ _
132
+ 28 subìta subire VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 26 advcl _ _
133
+ 29 che che CCONJ CC _ 30 cc _ _
134
+ 30 gradita gradito ADJ A Gender=Fem|Number=Sing 28 amod _ _
135
+ 31-32 dalla _ _ _ _ _ _ _ _
136
+ 31 da da ADP E _ 33 case _ _
137
+ 32 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 33 det _ _
138
+ 33 Lega Lega PROPN SP _ 28 obl:agent _ _
139
+ 34 Nord Nord PROPN SP _ 33 flat:name _ SpaceAfter=No
140
+ 35 , , PUNCT FC _ 33 punct _ _
141
+ 36 che che PRON PR PronType=Rel 39 nsubj _ _
142
+ 37 tuttavia tuttavia CCONJ CC _ 39 cc _ _
143
+ 38 dovrebbe dovere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 39 aux _ _
144
+ 39 rispettare rispettare VERB V VerbForm=Inf 33 acl:relcl _ _
145
+ 40 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 41 det _ _
146
+ 41 accordi accordio NOUN S Gender=Masc|Number=Plur 39 obj _ SpaceAfter=No
147
+ 42 . . PUNCT FS _ 5 punct _ _
148
+
149
+ # sent_id = VIT-591
150
+ # text = L'ottimismo di Kantor potrebbe però rivelarsi ancora una volta prematuro.
151
+ 1 L' il DET RD Definite=Def|Number=Sing|PronType=Art 2 det _ SpaceAfter=No
152
+ 2 ottimismo ottimismo NOUN S Gender=Masc|Number=Sing 7 nsubj _ _
153
+ 3 di di ADP E _ 4 case _ _
154
+ 4 Kantor Kantor PROPN SP _ 2 nmod _ _
155
+ 5 potrebbe potere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux _ _
156
+ 6 però però ADV B _ 7 advmod _ _
157
+ 7-8 rivelarsi _ _ _ _ _ _ _ _
158
+ 7 rivelar rivelare VERB V VerbForm=Inf 0 root _ _
159
+ 8 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 7 expl _ _
160
+ 9 ancora ancora ADV B _ 7 advmod _ _
161
+ 10 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 11 det _ _
162
+ 11 volta volta NOUN S Gender=Fem|Number=Sing 7 obl _ _
163
+ 12 prematuro prematuro ADJ A Gender=Masc|Number=Sing 7 xcomp _ SpaceAfter=No
164
+ 13 . . PUNCT FS _ 7 punct _ _
165
+ """
166
+
167
+
168
+ def test_process_mwts():
169
+ # dei appears multiple times
170
+ # the verb/pron esserci will be ignored
171
+ expected_mwts = {'Negli': ('In', 'gli'), 'dei': ('di', 'i'), "dall'": ('da', "l'"), 'dalle': ('da', 'le'), 'ai': ('a', 'i'), 'alla': ('a', 'la'), 'dalla': ('da', 'la')}
172
+
173
+ ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE)
174
+
175
+ mwts = convert_it_vit.get_mwt(ud_train_data)
176
+ assert expected_mwts == mwts
177
+
178
+ def test_raw_tree():
179
+ con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE))
180
+ expected_ids = ["#ID=sent_00002", "#ID=sent_00318", "#ID=sent_00589"]
181
+ expected_trees = ["(ROOT (cp (sp (part negli) (sn (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd dei) (sn (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda dall) (sn (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda dalle) (sn (sa (ag singole)) (n imprese))) (sp (part ai) (sn (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd dei) (sn (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))",
182
+ "(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part alla) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda dalla) (sn (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))",
183
+ "(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"]
184
+ assert len(con_sentences) == 3
185
+ for sentence, expected_id, expected_tree in zip(con_sentences, expected_ids, expected_trees):
186
+ assert sentence[0] == expected_id
187
+ tree = convert_it_vit.raw_tree(sentence[1])
188
+ assert str(tree) == expected_tree
189
+
190
+ def test_update_mwts():
191
+ con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE))
192
+ ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE)
193
+ mwt_map = convert_it_vit.get_mwt(ud_train_data)
194
+ expected_trees=["(ROOT (cp (sp (part In) (sn (art gli) (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd di) (sn (art i) (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda da) (sn (art l') (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda da) (sn (art le) (sa (ag singole)) (n imprese))) (sp (part a) (sn (art i) (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd di) (sn (art i) (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))",
195
+ "(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part a) (art la) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda da) (sn (art la) (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))",
196
+ "(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (clit si) (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"]
197
+ with tsurgeon.Tsurgeon() as tsurgeon_processor:
198
+ for con_sentence, ud_sentence, expected_tree in zip(con_sentences, ud_train_data.sentences, expected_trees):
199
+ con_tree = convert_it_vit.raw_tree(con_sentence[1])
200
+ updated_tree, _ = convert_it_vit.update_mwts_and_special_cases(con_tree, ud_sentence, mwt_map, tsurgeon_processor)
201
+ assert str(updated_tree) == expected_tree
202
+
203
+
204
+ CON_PERCENT_SAMPLE = """
205
+ ID#sent_00020 f-[sn-[art-il, n-tesoro], ibar-[vt-mette], compt-[sp-[part-sul, sn-[n-mercato]], sn-[art-il, num-51%, sp-[p-a, sn-[num-2, n-lire]], sp-[p-per, sn-[n-azione]]]], punto-.]
206
+ ID#sent_00022 dirsp-[f3-[sn-[art-le, n-novità]], dirs-':', f3-[coord-[sn-[n-voto, spd-[pd-di, sn-[n-lista]]], cong-e, sn-[n-tetto, sp-[part-agli, sn-[n-acquisti]], sv3-[vppt-limitato, comppas-[sp-[part-allo, sn-[num-0/5%]]]]]], punto-.]]
207
+ ID#sent_00517 dirsp-[fc-[f-[sn-[art-l, n-aumento, sa-[ag-mensile], spd-[pd-di, sn-[nt-aprile]]], ibar-[ause-è, vppc-stato], compc-[sq-[q-dell_, sn-[num-1/3%]], sp-[p-contro, sn-[art-lo, num-0/7/0/8%, spd-[partd-degli, sn-[sa-[ag-ultimi], num-due, sn-[nt-mesi]]]]]]]]]
208
+ ID#sent_01117 fc-[f-[sn-[art-La, sa-[ag-crescente], n-ripresa, spd-[partd-dei, sn-[n-beni, spd-[pd-di, sn-[n-consumo]]]]], ibar-[vin-deriva], savv-[avv-esclusivamente], compin-[spda-[partda-dal, sn-[n-miglioramento, f2-[spd-[pd-di, sn-[relob-cui]], f-[ibar-[ausa-hanno, vppin-beneficiato], compin-[sn-[n-beni, coord-[sa-[ag-durevoli, fp-[par-'(', sn-[num-plus4/5%], par-')']], cong-e, sa-[ag-semidurevoli, fp-[par-'(', sn-[num-plus1/5%], par-')']]]]]]]]]]], punt-',', fs-[cosu-mentre, f-[sn-[art-i, n-beni, sa-[neg-non, ag-durevoli], fp-[par-'(', sn-[num-min1%], par-')']], ibar-[vt-accusano], cong-ancora, compt-[sn-[art-un, sa-[ag-evidente], n-ritardo]]]], punto-.]
209
+ """
210
+
211
+ CON_PERCENT_LEAVES = [
212
+ ['il', 'tesoro', 'mette', 'sul', 'mercato', 'il', '51', '%%', 'a', '2', 'lire', 'per', 'azione', '.'],
213
+ ['le', 'novità', ':', 'voto', 'di', 'lista', 'e', 'tetto', 'agli', 'acquisti', 'limitato', 'allo', '0,5', '%%', '.'],
214
+ ['l', 'aumento', 'mensile', 'di', 'aprile', 'è', 'stato', "dell'", '1,3', '%%', 'contro', 'lo', '0/7,0/8', '%%', 'degli', 'ultimi', 'due', 'mesi'],
215
+ # the plus and min look bad, but they get cleaned up when merging with the UD version of the dataset
216
+ ['La', 'crescente', 'ripresa', 'dei', 'beni', 'di', 'consumo', 'deriva', 'esclusivamente', 'dal', 'miglioramento', 'di', 'cui', 'hanno', 'beneficiato', 'beni', 'durevoli', '(', 'plus4,5', '%%', ')', 'e', 'semidurevoli', '(', 'plus1,5', '%%', ')', ',', 'mentre', 'i', 'beni', 'non', 'durevoli', '(', 'min1', '%%', ')', 'accusano', 'ancora', 'un', 'evidente', 'ritardo', '.']
217
+ ]
218
+
219
+ def test_read_percent():
220
+ con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_PERCENT_SAMPLE))
221
+ assert len(con_sentences) == len(CON_PERCENT_LEAVES)
222
+ for (_, raw_tree), expected_leaves in zip(con_sentences, CON_PERCENT_LEAVES):
223
+ tree = convert_it_vit.raw_tree(raw_tree)
224
+ words = tree.leaf_labels()
225
+ if expected_leaves is None:
226
+ print(words)
227
+ else:
228
+ assert words == expected_leaves
stanza/stanza/tests/constituency/test_convert_starlang.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test a couple different classes of trees to check the output of the Starlang conversion
3
+ """
4
+
5
+ import os
6
+ import tempfile
7
+
8
+ import pytest
9
+
10
+ from stanza.utils.datasets.constituency import convert_starlang
11
+
12
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
13
+
14
+ TREE="( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )"
15
+
16
+ def test_read_tree():
17
+ """
18
+ Test a basic tree read
19
+ """
20
+ tree = convert_starlang.read_tree(TREE)
21
+ assert "(ROOT (S (NP (NP Bayan) (NP Haag)) (VP (NP Elianti) (VP çalar)) (. .)))" == str(tree)
22
+
23
+ def test_missing_word():
24
+ """
25
+ Test that an error is thrown if the word is missing
26
+ """
27
+ tree_text = TREE.replace("turkish=", "foo=")
28
+ with pytest.raises(ValueError):
29
+ tree = convert_starlang.read_tree(tree_text)
30
+
31
+ def test_bad_label():
32
+ """
33
+ Test that an unexpected label results in an error
34
+ """
35
+ tree_text = TREE.replace("(S", "(s")
36
+ with pytest.raises(ValueError):
37
+ tree = convert_starlang.read_tree(tree_text)
stanza/stanza/tests/constituency/test_in_order_oracle.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import pytest
3
+
4
+ from stanza.models.constituency import parse_transitions
5
+ from stanza.models.constituency import tree_reader
6
+ from stanza.models.constituency.base_model import SimpleModel
7
+ from stanza.models.constituency.in_order_oracle import *
8
+ from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme
9
+ from stanza.models.constituency.transition_sequence import build_treebank
10
+
11
+ from stanza.tests import *
12
+ from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
13
+
14
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
15
+
16
+ # A sample tree from PTB with a single unary transition (at a location other than root)
17
+ SINGLE_UNARY_TREE = """
18
+ ( (S
19
+ (NP-SBJ-1 (DT A) (NN record) (NN date) )
20
+ (VP (VBZ has) (RB n't)
21
+ (VP (VBN been)
22
+ (VP (VBN set)
23
+ (NP (-NONE- *-1) ))))
24
+ (. .) ))
25
+ """
26
+
27
+ # [Shift, OpenConstituent(('NP-SBJ-1',)), Shift, Shift, CloseConstituent, OpenConstituent(('S',)), Shift, OpenConstituent(('VP',)), Shift, Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)), CloseConstituent, CloseConstituent, CloseConstituent, CloseConstituent, Shift, CloseConstituent, OpenConstituent(('ROOT',)), CloseConstituent]
28
+
29
+ # A sample tree from PTB with a double unary transition (at a location other than root)
30
+ DOUBLE_UNARY_TREE = """
31
+ ( (S
32
+ (NP-SBJ
33
+ (NP (RB Not) (PDT all) (DT those) )
34
+ (SBAR
35
+ (WHNP-3 (WP who) )
36
+ (S
37
+ (NP-SBJ (-NONE- *T*-3) )
38
+ (VP (VBD wrote) ))))
39
+ (VP (VBP oppose)
40
+ (NP (DT the) (NNS changes) ))
41
+ (. .) ))
42
+ """
43
+
44
+ # A sample tree from PTB with a triple unary transition (at a location other than root)
45
+ # The triple unary is at the START of the next bracket, which affects how the
46
+ # dynamic oracle repairs the transition sequence
47
+ TRIPLE_UNARY_START_TREE = """
48
+ ( (S
49
+ (PRN
50
+ (S
51
+ (NP-SBJ (-NONE- *) )
52
+ (VP (VB See) )))
53
+ (, ,)
54
+ (NP-SBJ
55
+ (NP (DT the) (JJ other) (NN rule) )
56
+ (PP (IN of)
57
+ (NP (NN thumb) ))
58
+ (PP (IN about)
59
+ (NP (NN ballooning) )))))
60
+ """
61
+
62
+ # A sample tree from PTB with a triple unary transition (at a location other than root)
63
+ # The triple unary is at the END of the next bracket, which affects how the
64
+ # dynamic oracle repairs the transition sequence
65
+ TRIPLE_UNARY_END_TREE = """
66
+ ( (S
67
+ (NP (NNS optimists) )
68
+ (VP (VBP expect)
69
+ (S
70
+ (NP-SBJ-4 (NNP Hong) (NNP Kong) )
71
+ (VP (TO to)
72
+ (VP (VB hum)
73
+ (ADVP-CLR (RB along) )
74
+ (SBAR-MNR (RB as)
75
+ (S
76
+ (NP-SBJ (-NONE- *-4) )
77
+ (VP (-NONE- *?*)
78
+ (ADVP-TMP (IN before) ))))))))))
79
+ """
80
+
81
+ TREES = [SINGLE_UNARY_TREE, DOUBLE_UNARY_TREE, TRIPLE_UNARY_START_TREE, TRIPLE_UNARY_END_TREE]
82
+ TREEBANK = "\n".join(TREES)
83
+
84
+ NOUN_PHRASE_TREE = """
85
+ ( (NP
86
+ (NP (NNP Chicago) (POS 's))
87
+ (NNP Goodman)
88
+ (NNP Theatre)))
89
+ """
90
+
91
+ WIDE_NP_TREE = """
92
+ ( (S
93
+ (NP-SBJ (DT These) (NNS studies))
94
+ (VP (VBP demonstrate)
95
+ (SBAR (IN that)
96
+ (S
97
+ (NP-SBJ (NNS mice))
98
+ (VP (VBP are)
99
+ (NP-PRD
100
+ (NP (DT a)
101
+ (ADJP (JJ practical)
102
+ (CC and)
103
+ (JJ powerful))
104
+ (JJ experimental) (NN system))
105
+ (SBAR
106
+ (WHADVP-2 (-NONE- *0*))
107
+ (S
108
+ (NP-SBJ (-NONE- *PRO*))
109
+ (VP (TO to)
110
+ (VP (VB study)
111
+ (NP (DT the) (NN genetics)))))))))))))
112
+ """
113
+
114
+ WIDE_TREES = [NOUN_PHRASE_TREE, WIDE_NP_TREE]
115
+ WIDE_TREEBANK = "\n".join(WIDE_TREES)
116
+
117
+ ROOT_LABELS = ["ROOT"]
118
+
119
+ def get_repairs(gold_sequence, wrong_transition, repair_fn):
120
+ """
121
+ Use the repair function and the wrong transition to iterate over the gold sequence
122
+
123
+ Returns a list of possible repairs, one for each position in the sequence
124
+ Repairs are tuples, (idx, seq)
125
+ """
126
+ repairs = [(idx, repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None))
127
+ for idx, gold_transition in enumerate(gold_sequence)]
128
+ repairs = [x for x in repairs if x[1] is not None]
129
+ return repairs
130
+
131
+ @pytest.fixture(scope="module")
132
+ def unary_trees():
133
+ trees = tree_reader.read_trees(TREEBANK)
134
+ trees = [t.prune_none().simplify_labels() for t in trees]
135
+ assert len(trees) == len(TREES)
136
+
137
+ return trees
138
+
139
+ @pytest.fixture(scope="module")
140
+ def gold_sequences(unary_trees):
141
+ gold_sequences = build_treebank(unary_trees, TransitionScheme.IN_ORDER)
142
+ return gold_sequences
143
+
144
+ @pytest.fixture(scope="module")
145
+ def wide_trees():
146
+ trees = tree_reader.read_trees(WIDE_TREEBANK)
147
+ trees = [t.prune_none().simplify_labels() for t in trees]
148
+ assert len(trees) == len(WIDE_TREES)
149
+
150
+ return trees
151
+
152
+ def test_wrong_open_root(gold_sequences):
153
+ """
154
+ Test the results of the dynamic oracle on a few trees if the ROOT is mishandled.
155
+ """
156
+ wrong_transition = OpenConstituent("S")
157
+ gold_transition = OpenConstituent("ROOT")
158
+ close_transition = CloseConstituent()
159
+
160
+ for gold_sequence in gold_sequences:
161
+ # each of the sequences should be ended with ROOT, Close
162
+ assert gold_sequence[-2] == gold_transition
163
+
164
+ repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_root_error)
165
+ # there is only spot in the sequence with a ROOT, so there should
166
+ # be exactly one location which affords a S/ROOT replacement
167
+ assert len(repairs) == 1
168
+ repair = repairs[0]
169
+
170
+ # the repair should occur at the -2 position, which is where ROOT is
171
+ assert repair[0] == len(gold_sequence) - 2
172
+ # and the resulting list should have the wrong transition followed by a Close
173
+ # to give the model another chance to close the tree
174
+ expected = gold_sequence[:-2] + [wrong_transition, close_transition] + gold_sequence[-2:]
175
+ assert repair[1] == expected
176
+
177
+ def test_missed_unary(gold_sequences):
178
+ """
179
+ Test the repairs of an open/open error if it is effectively a skipped unary transition
180
+ """
181
+ wrong_transition = OpenConstituent("S")
182
+
183
+ repairs = get_repairs(gold_sequences[0], wrong_transition, fix_wrong_open_unary_chain)
184
+ assert len(repairs) == 0
185
+
186
+ # here we are simulating picking NT-S instead of NT-VP
187
+ # the DOUBLE_UNARY tree has one location where this is relevant, index 11
188
+ repairs = get_repairs(gold_sequences[1], wrong_transition, fix_wrong_open_unary_chain)
189
+ assert len(repairs) == 1
190
+ assert repairs[0][0] == 11
191
+ assert repairs[0][1] == gold_sequences[1][:11] + gold_sequences[1][13:]
192
+
193
+ # the TRIPLE_UNARY_START tree has two locations where this is relevant
194
+ # at index 1, the pattern goes (S (VP ...))
195
+ # so choosing S instead of VP means you can skip the VP and only miss that one bracket
196
+ # at index 5, the pattern goes (S (PRN (S (VP ...))) (...))
197
+ # note that this is capturing a unary transition into a larger constituent
198
+ # skipping the PRN is satisfactory
199
+ repairs = get_repairs(gold_sequences[2], wrong_transition, fix_wrong_open_unary_chain)
200
+ assert len(repairs) == 2
201
+ assert repairs[0][0] == 1
202
+ assert repairs[0][1] == gold_sequences[2][:1] + gold_sequences[2][3:]
203
+ assert repairs[1][0] == 5
204
+ assert repairs[1][1] == gold_sequences[2][:5] + gold_sequences[2][7:]
205
+
206
+ # The TRIPLE_UNARY_END tree has 2 sections of tree for a total of 3 locations
207
+ # where the repair might happen
208
+ # Surprisingly the unary transition at the very start can only be
209
+ # repaired by skipping it and using the outer S transition instead
210
+ # The second repair overall (first repair in the second location)
211
+ # should have a double skip to reach the S node
212
+ repairs = get_repairs(gold_sequences[3], wrong_transition, fix_wrong_open_unary_chain)
213
+ assert len(repairs) == 3
214
+ assert repairs[0][0] == 1
215
+ assert repairs[0][1] == gold_sequences[3][:1] + gold_sequences[3][3:]
216
+ assert repairs[1][0] == 21
217
+ assert repairs[1][1] == gold_sequences[3][:21] + gold_sequences[3][25:]
218
+ assert repairs[2][0] == 23
219
+ assert repairs[2][1] == gold_sequences[3][:23] + gold_sequences[3][25:]
220
+
221
+
222
+ def test_open_with_stuff(unary_trees, gold_sequences):
223
+ wrong_transition = OpenConstituent("S")
224
+ expected_trees = [
225
+ "(ROOT (S (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))",
226
+ "(ROOT (S (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))",
227
+ None,
228
+ "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))"
229
+ ]
230
+
231
+ for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees):
232
+ repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_stuff_unary)
233
+ if expected is None:
234
+ assert len(repairs) == 0
235
+ else:
236
+ assert len(repairs) == 1
237
+ result = reconstruct_tree(tree, repairs[0][1])
238
+ assert str(result) == expected
239
+
240
+ def test_general_open(gold_sequences):
241
+ wrong_transition = OpenConstituent("SBARQ")
242
+
243
+ for sequence in gold_sequences:
244
+ repairs = get_repairs(sequence, wrong_transition, fix_wrong_open_general)
245
+ assert len(repairs) == sum(isinstance(x, OpenConstituent) for x in sequence) - 1
246
+ for repair in repairs:
247
+ assert len(repair[1]) == len(sequence)
248
+ assert repair[1][repair[0]] == wrong_transition
249
+ assert repair[1][:repair[0]] == sequence[:repair[0]]
250
+ assert repair[1][repair[0]+1:] == sequence[repair[0]+1:]
251
+
252
+ def test_missed_unary(unary_trees, gold_sequences):
253
+ shift_transition = Shift()
254
+ close_transition = CloseConstituent()
255
+
256
+ expected_close_results = [
257
+ [(12, 2)],
258
+ [(11, 4), (13, 2)],
259
+ # (NP NN thumb) and (NP NN ballooning) are both candidates for this repair
260
+ [(18, 2), (24, 2)],
261
+ [(21, 6), (23, 4), (25, 2)],
262
+ ]
263
+
264
+ expected_shift_results = [
265
+ (),
266
+ (),
267
+ (),
268
+ # (ADVP-CLR (RB along)) is followed by a shift
269
+ [(16, 2)],
270
+ ]
271
+
272
+ for tree, sequence, expected_close, expected_shift in zip(unary_trees, gold_sequences, expected_close_results, expected_shift_results):
273
+ repairs = get_repairs(sequence, close_transition, fix_missed_unary)
274
+ assert len(repairs) == len(expected_close)
275
+ for repair, (expected_idx, expected_len) in zip(repairs, expected_close):
276
+ assert repair[0] == expected_idx
277
+ assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:]
278
+
279
+ repairs = get_repairs(sequence, shift_transition, fix_missed_unary)
280
+ assert len(repairs) == len(expected_shift)
281
+ for repair, (expected_idx, expected_len) in zip(repairs, expected_shift):
282
+ assert repair[0] == expected_idx
283
+ assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:]
284
+
285
+ def test_open_shift(unary_trees, gold_sequences):
286
+ shift_transition = Shift()
287
+
288
+ expected_repairs = [
289
+ [(7, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))"),
290
+ (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VBN been) (VP (VBN set))) (. .)))")],
291
+ [(7, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WP who) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
292
+ (9, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
293
+ (19, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose) (NP (DT the) (NNS changes)) (. .)))"),
294
+ (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (DT the) (NNS changes)) (. .)))")],
295
+ [(14, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"),
296
+ (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"),
297
+ (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about) (NP (NN ballooning)))))")],
298
+ [(5, "(ROOT (S (NP (NNS optimists)) (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
299
+ (10, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
300
+ (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
301
+ (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
302
+ (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (RB as) (S (VP (ADVP (IN before))))))))))")]
303
+ ]
304
+
305
+ for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs):
306
+ repairs = get_repairs(sequence, shift_transition, fix_open_shift)
307
+ assert len(repairs) == len(expected)
308
+ for repair, (idx, expected_tree) in zip(repairs, expected):
309
+ assert repair[0] == idx
310
+ result_tree = reconstruct_tree(tree, repair[1])
311
+ assert str(result_tree) == expected_tree
312
+
313
+
314
+ def test_open_close(unary_trees, gold_sequences):
315
+ close_transition = CloseConstituent()
316
+
317
+ expected_repairs = [
318
+ [(7, "(ROOT (S (S (NP (DT A) (NN record) (NN date)) (VBZ has)) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))"),
319
+ (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VP (VBZ has) (RB n't) (VBN been)) (VP (VBN set))) (. .)))")],
320
+ # missed the WHNP. The surrounding SBAR cannot be created, either
321
+ [(7, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
322
+ # missed the SBAR
323
+ (9, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who))) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
324
+ # missed the VP around "oppose the changes"
325
+ (19, "(ROOT (S (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose)) (NP (DT the) (NNS changes)) (. .)))"),
326
+ # missed the NP in "the changes", looks pretty bad tbh
327
+ (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VP (VBP oppose) (DT the)) (NNS changes)) (. .)))")],
328
+ [(14, "(ROOT (S (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule))) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"),
329
+ (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (IN of)) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"),
330
+ (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about)) (NP (NN ballooning)))))")],
331
+ [(5, "(ROOT (S (S (NP (NNS optimists)) (VBP expect)) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
332
+ (10, "(ROOT (S (NP (NNS optimists)) (VP (VP (VBP expect) (NP (NNP Hong) (NNP Kong))) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
333
+ (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (S (NP (NNP Hong) (NNP Kong)) (TO to)) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
334
+ (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (VP (TO to) (VB hum)) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
335
+ (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (ADVP (RB along)) (RB as)) (S (VP (ADVP (IN before))))))))))")]
336
+ ]
337
+
338
+ for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs):
339
+ repairs = get_repairs(sequence, close_transition, fix_open_close)
340
+
341
+ assert len(repairs) == len(expected)
342
+ for repair, (idx, expected_tree) in zip(repairs, expected):
343
+ assert repair[0] == idx
344
+ result_tree = reconstruct_tree(tree, repair[1])
345
+ assert str(result_tree) == expected_tree
346
+
347
+ def test_shift_close(unary_trees, gold_sequences):
348
+ """
349
+ Test the fix for a shift -> close
350
+
351
+ These errors can occur pretty much everywhere, and the fix is quite simple,
352
+ so we only test a few cases.
353
+ """
354
+
355
+ close_transition = CloseConstituent()
356
+
357
+ expected_tree = "(ROOT (S (NP (NP (DT A)) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))"
358
+
359
+ repairs = get_repairs(gold_sequences[0], close_transition, fix_shift_close)
360
+ assert len(repairs) == 7
361
+ result_tree = reconstruct_tree(unary_trees[0], repairs[0][1])
362
+ assert str(result_tree) == expected_tree
363
+
364
+ repairs = get_repairs(gold_sequences[1], close_transition, fix_shift_close)
365
+ assert len(repairs) == 8
366
+
367
+ repairs = get_repairs(gold_sequences[2], close_transition, fix_shift_close)
368
+ assert len(repairs) == 8
369
+
370
+ repairs = get_repairs(gold_sequences[3], close_transition, fix_shift_close)
371
+ assert len(repairs) == 9
372
+ for rep in repairs:
373
+ if rep[0] == 16:
374
+ # This one is special because it occurs as part of a unary
375
+ # in other words, it should go unary, shift
376
+ # and instead we are making it close where the unary should be
377
+ # ... the unary would create "(ADVP (RB along))"
378
+ expected_tree = "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))"
379
+ result_tree = reconstruct_tree(unary_trees[3], rep[1])
380
+ assert str(result_tree) == expected_tree
381
+ break
382
+ else:
383
+ raise AssertionError("Did not find an expected repair location")
384
+
385
+ def test_close_open_shift_nested(unary_trees, gold_sequences):
386
+ shift_transition = Shift()
387
+
388
+ expected_trees = [{},
389
+ {4: "(ROOT (S (NP (RB Not) (PDT all) (DT those) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"},
390
+ {4: "(ROOT (S (VP (VB See)) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
391
+ 13: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"},
392
+ {}]
393
+
394
+ for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees):
395
+ repairs = get_repairs(gold_sequence, shift_transition, fix_close_open_shift_nested)
396
+ assert len(repairs) == len(expected)
397
+ if len(expected) >= 1:
398
+ for repair in repairs:
399
+ assert repair[0] in expected.keys()
400
+ result_tree = reconstruct_tree(tree, repair[1])
401
+ assert str(result_tree) == expected[repair[0]]
402
+
403
def check_repairs(trees, gold_sequences, expected_trees, transition, repair_fn):
    """
    Check that repair_fn finds exactly the expected repairs for each tree.

    trees, gold_sequences, expected_trees are parallel lists.  Each
    expected entry is a dict mapping repair location -> expected repaired
    tree (as a string).  An expected entry of None switches that tree to
    a debug mode which prints the repairs instead of asserting — useful
    when collecting expected values for a new repair function.
    """
    for tree_idx, (gold_tree, gold_sequence, expected) in enumerate(zip(trees, gold_sequences, expected_trees)):
        repairs = get_repairs(gold_sequence, transition, repair_fn)
        if expected is not None:
            # each repair must occur at an expected location and rebuild
            # into exactly the expected tree
            assert len(repairs) == len(expected)
            for repair in repairs:
                assert repair[0] in expected
                result_tree = reconstruct_tree(gold_tree, repair[1])
                assert str(result_tree) == expected[repair[0]]
        else:
            # debug mode: dump the gold tree / sequence, then each repair
            print("---------------------")
            print("{:P}".format(gold_tree))
            print(gold_sequence)
            #print(repairs)
            for repair in repairs:
                print("---------------------")
                print(gold_sequence)
                print(repair[1])
                result_tree = reconstruct_tree(gold_tree, repair[1])
                print("{:P}".format(gold_tree))
                print("{:P}".format(result_tree))
                print(tree_idx)
                print(repair[0])
                print(result_tree)
427
+
428
def test_close_open_shift_unambiguous(unary_trees, gold_sequences):
    """
    Check repairs where a Shift replaces a Close/Open and the new bracket is unambiguous
    """
    shift_transition = Shift()

    # maps repair location -> expected repaired tree for each gold sequence
    expected_trees = [{},
                      {8: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who) (S (VP (VBD wrote)))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"},
                      {},
                      {2: "(ROOT (S (NP (NNS optimists) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))",
                       9: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"}]
    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_unambiguous_bracket)
437
+
438
def test_close_open_shift_ambiguous_early(unary_trees, gold_sequences):
    """
    Check repairs for close/open/shift with an ambiguous bracket, "early" attachment variant

    Compare against test_close_open_shift_ambiguous_late: the expected
    trees differ in where the ambiguous material is attached.
    """
    shift_transition = Shift()

    # maps repair location -> expected repaired tree for each gold sequence
    expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))))) (. .)))"},
                      {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes)))) (. .)))"},
                      {2: "(ROOT (S (PRN (S (VP (VB See) (, ,)))) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
                       6: "(ROOT (S (PRN (S (VP (VB See))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"},
                      {}]
    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_early)
447
+
448
def test_close_open_shift_ambiguous_late(unary_trees, gold_sequences):
    """
    Check repairs for close/open/shift with an ambiguous bracket, "late" attachment variant

    Compare against test_close_open_shift_ambiguous_early: the expected
    trees differ in where the ambiguous material is attached.
    """
    shift_transition = Shift()

    # maps repair location -> expected repaired tree for each gold sequence
    expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .))))"},
                      {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .))))"},
                      {2: "(ROOT (S (PRN (S (VP (VB See) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))))",
                       6: "(ROOT (S (PRN (S (VP (VB See))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))"},
                      {}]
    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_late)
457
+
458
+
459
def test_close_shift_shift(unary_trees, wide_trees):
    """
    Test that close -> shift works when there is a single block shifted after

    Includes a test specifically that there is no oracle action when there are two blocks after the missed close
    """
    shift_transition = Shift()

    # first four entries correspond to unary_trees, last two to wide_trees;
    # the empty dicts mean no repair is expected for those trees
    expected_trees = [{15: "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .))))"},
                      {24: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes)) (. .))))"},
                      {20: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning)))))))"},
                      {17: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"},
                      {},
                      {}]

    test_trees = unary_trees + wide_trees
    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)

    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_unambiguous)
478
+
479
+
480
def test_close_shift_shift_early(unary_trees, wide_trees):
    """
    Test that close -> shift works when there are multiple blocks shifted after

    Also checks that the single block case is skipped, so as to keep them separate when testing

    A tree with the expected property was specifically added for this test
    """
    shift_transition = Shift()

    test_trees = unary_trees + wide_trees
    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)

    # only the last tree (from wide_trees) has a repair; compare against
    # test_close_shift_shift_late, where (NN system) attaches differently
    expected_trees = [{},
                      {},
                      {},
                      {},
                      {},
                      {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental)) (NN system)) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}]

    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_early)
501
+
502
def test_close_shift_shift_late(unary_trees, wide_trees):
    """
    Test that close -> shift works when there are multiple blocks shifted after

    Also checks that the single block case is skipped, so as to keep them separate when testing

    A tree with the expected property was specifically added for this test
    """
    shift_transition = Shift()

    test_trees = unary_trees + wide_trees
    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)

    # only the last tree (from wide_trees) has a repair; compare against
    # test_close_shift_shift_early, where (NN system) attaches differently
    expected_trees = [{},
                      {},
                      {},
                      {},
                      {},
                      {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental) (NN system))) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}]

    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_late)
stanza/stanza/tests/constituency/test_lstm_model.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ from stanza.models.common import pretrain
7
+ from stanza.models.common.utils import set_random_seed
8
+ from stanza.models.constituency import parse_transitions
9
+ from stanza.tests import *
10
+ from stanza.tests.constituency import test_parse_transitions
11
+ from stanza.tests.constituency.test_trainer import build_trainer
12
+
13
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
14
+
15
@pytest.fixture(scope="module")
def pretrain_file():
    """Path to a tiny embedding file used to build the test models"""
    return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
18
+
19
def build_model(pretrain_file, *args):
    """
    Build a small constituency model for testing.

    By default, we turn off multistage, since that can turn off various
    other structures in the initial training.  Extra command line flags
    are appended after the defaults.
    """
    default_args = ['--no_multistage', '--pattn_num_layers', '4', '--pattn_d_model', '256', '--hidden_size', '128', '--use_lattn']
    trainer = build_trainer(pretrain_file, *(default_args + list(args)))
    return trainer.model
24
+
25
@pytest.fixture(scope="module")
def unary_model(pretrain_file):
    """A model built with the TOP_DOWN_UNARY transition scheme"""
    return build_model(pretrain_file, "--transition_scheme", "TOP_DOWN_UNARY")
28
+
29
def test_initial_state(unary_model):
    """Run the generic initial state test against a real LSTM model"""
    test_parse_transitions.test_initial_state(unary_model)
31
+
32
def test_shift(pretrain_file):
    """Run the generic shift test against a default model"""
    # TODO: might be good to include some tests specifically for shift
    # in the context of a model with unaries
    model = build_model(pretrain_file)
    test_parse_transitions.test_shift(model)
37
+
38
def test_unary(unary_model):
    """Run the generic unary test against a real LSTM model"""
    test_parse_transitions.test_unary(unary_model)
40
+
41
def test_unary_requires_root(unary_model):
    """Run the generic unary-requires-root test against a real LSTM model"""
    test_parse_transitions.test_unary_requires_root(unary_model)
43
+
44
def test_open(unary_model):
    """Run the generic open test against a real LSTM model"""
    test_parse_transitions.test_open(unary_model)
46
+
47
def test_compound_open(pretrain_file):
    """Run the generic compound open test against a TOP_DOWN_COMPOUND model"""
    model = build_model(pretrain_file, '--transition_scheme', "TOP_DOWN_COMPOUND")
    test_parse_transitions.test_compound_open(model)
50
+
51
def test_in_order_open(pretrain_file):
    """Run the generic in-order open test against an IN_ORDER model"""
    model = build_model(pretrain_file, '--transition_scheme', "IN_ORDER")
    test_parse_transitions.test_in_order_open(model)
54
+
55
def test_close(unary_model):
    """Run the generic close test against a real LSTM model"""
    test_parse_transitions.test_close(unary_model)
57
+
58
def run_forward_checks(model, num_states=1):
    """
    Run a couple small transitions and a forward pass on the given model

    Results are not checked in any way. This function allows for
    testing that building models with various options results in a
    functional model.

    num_states controls how many parallel parser states are batched
    through the model at once.
    """
    states = test_parse_transitions.build_initial_state(model, num_states)
    model(states)

    # shift one word in every state, then run the model again
    shift = parse_transitions.Shift()
    shifts = [shift for _ in range(num_states)]
    states = model.bulk_apply(states, shifts)
    model(states)

    # open an NP bracket ...
    open_transition = parse_transitions.OpenConstituent("NP")
    open_transitions = [open_transition for _ in range(num_states)]
    assert open_transition.is_legal(states[0], model)
    states = model.bulk_apply(states, open_transitions)
    assert states[0].num_opens == 1
    model(states)

    # ... shift two more words into it ...
    states = model.bulk_apply(states, shifts)
    model(states)
    states = model.bulk_apply(states, shifts)
    model(states)
    assert states[0].num_opens == 1
    # now should have "mox", "opal" on the constituents

    # ... and close the bracket, leaving no open constituents
    close_transition = parse_transitions.CloseConstituent()
    close_transitions = [close_transition for _ in range(num_states)]
    assert close_transition.is_legal(states[0], model)
    states = model.bulk_apply(states, close_transitions)
    assert states[0].num_opens == 0

    model(states)
95
+
96
def test_unary_forward(unary_model):
    """
    Checks that the forward pass doesn't crash when run after various operations

    Doesn't check the forward pass for making reasonable answers
    """
    run_forward_checks(unary_model)
103
+
104
def test_lstm_forward(pretrain_file):
    """Check the forward pass works with batch sizes of 1 and 2"""
    model = build_model(pretrain_file)
    run_forward_checks(model, num_states=1)
    run_forward_checks(model, num_states=2)
108
+
109
def test_lstm_layers(pretrain_file):
    """Models built with 1, 2, or 3 LSTM layers should all run forward"""
    for num_layers in ('1', '2', '3'):
        model = build_model(pretrain_file, '--num_lstm_layers', num_layers)
        run_forward_checks(model)
116
+
117
def test_multiple_output_forward(pretrain_file):
    """
    Test a couple different sizes of output layers
    """
    for num_output in ('1', '2', '3'):
        model = build_model(pretrain_file, '--num_output_layers', num_output, '--num_lstm_layers', '2')
        run_forward_checks(model)
129
+
130
def test_no_tag_embedding_forward(pretrain_file):
    """
    Test that the model continues to work if the tag embedding is turned on or off
    """
    for tag_dim in ('20', '0'):
        model = build_model(pretrain_file, '--tag_embedding_dim', tag_dim)
        run_forward_checks(model)
139
+
140
def test_forward_combined_dummy(pretrain_file):
    """
    Tests combined dummy and open node embeddings, both on and off
    """
    for dummy_flag in ('--combined_dummy_embedding', '--no_combined_dummy_embedding'):
        model = build_model(pretrain_file, dummy_flag)
        run_forward_checks(model)
149
+
150
def test_nonlinearity_init(pretrain_file):
    """
    Tests that different initialization methods of the nonlinearities result in valid tensors
    """
    for nonlinearity in ('relu', 'tanh', 'silu'):
        model = build_model(pretrain_file, '--nonlinearity', nonlinearity)
        run_forward_checks(model)
162
+
163
def test_forward_charlm(pretrain_file):
    """
    Tests loading and running a charlm

    Note that this doesn't test the results of the charlm itself,
    just that the model is shaped correctly
    """
    forward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "1billion.pt")
    backward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "1billion.pt")
    assert os.path.exists(forward_charlm_path), "Need to download en test models (or update path to the forward charlm)"
    assert os.path.exists(backward_charlm_path), "Need to download en test models (or update path to the backward charlm)"

    # check both with and without sentence boundary word vectors
    model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'none')
    run_forward_checks(model)

    model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'words')
    run_forward_checks(model)
180
+
181
def test_forward_bert(pretrain_file):
    """
    Test on a tiny Bert, which hopefully does not take up too much disk space or memory
    """
    model = build_model(pretrain_file, '--bert_model', "hf-internal-testing/tiny-bert")
    run_forward_checks(model)
189
+
190
+
191
def test_forward_xlnet(pretrain_file):
    """
    Test on a tiny xlnet, which hopefully does not take up too much disk space or memory

    The transformer is still passed via the --bert_model flag.
    """
    model = build_model(pretrain_file, '--bert_model', "hf-internal-testing/tiny-random-xlnet")
    run_forward_checks(model)
199
+
200
+
201
def test_forward_sentence_boundaries(pretrain_file):
    """
    Test start & stop boundary vectors
    """
    for boundary_scheme in ('everything', 'words', 'none'):
        model = build_model(pretrain_file, '--sentence_boundary_vectors', boundary_scheme)
        run_forward_checks(model)
213
+
214
def test_forward_constituency_composition(pretrain_file):
    """
    Test different constituency composition functions
    """
    compositions = ('bilstm', 'max', 'key', 'untied_key', 'untied_max',
                    'bilstm_max', 'tree_lstm', 'tree_lstm_cx', 'bigram', 'attn')
    for composition in compositions:
        model = build_model(pretrain_file, '--constituency_composition', composition)
        run_forward_checks(model, num_states=2)
247
+
248
def test_forward_key_position(pretrain_file):
    """
    Test KEY and UNTIED_KEY either with or without reduce_position
    """
    for composition in ('untied_key', 'key'):
        for reduce_position in ('0', '32'):
            model = build_model(pretrain_file, '--constituency_composition', composition, '--reduce_position', reduce_position)
            run_forward_checks(model, num_states=2)
263
+
264
+
265
def test_forward_attn_hidden_size(pretrain_file):
    """
    Test that when attn is used with hidden sizes not evenly divisible by reduce_heads, the model reconfigures the hidden_size
    """
    # with no explicit reduce_heads, the hidden_size should be rounded up
    # to a multiple of whatever reduce_heads the model chooses
    model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129')
    assert model.hidden_size >= 129
    assert model.hidden_size % model.reduce_heads == 0
    run_forward_checks(model, num_states=2)

    # with reduce_heads=10, 129 should be rounded up to 130
    model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129', '--reduce_heads', '10')
    assert model.hidden_size == 130
    assert model.reduce_heads == 10
    # NOTE(review): unlike the first configuration, this one is never run
    # through run_forward_checks - possibly an oversight, confirm
277
+
278
def test_forward_partitioned_attention(pretrain_file):
    """
    Test with & without partitioned attention layers
    """
    for num_heads, num_layers in (('8', '8'), ('0', '0')):
        model = build_model(pretrain_file, '--pattn_num_heads', num_heads, '--pattn_num_layers', num_layers)
        run_forward_checks(model)
287
+
288
def test_forward_labeled_attention(pretrain_file):
    """
    Test with & without labeled attention layers
    """
    # lattn enabled
    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16')
    run_forward_checks(model)

    # lattn disabled
    model = build_model(pretrain_file, '--lattn_d_proj', '0', '--lattn_d_l', '0')
    run_forward_checks(model)

    # lattn enabled with the combined input option
    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_combined_input')
    run_forward_checks(model)
300
+
301
def test_lattn_partitioned(pretrain_file):
    """Labeled attention should work both partitioned and unpartitioned"""
    for partition_flag in ('--lattn_partitioned', '--no_lattn_partitioned'):
        model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', partition_flag)
        run_forward_checks(model)
307
+
308
+
309
def test_lattn_projection(pretrain_file):
    """
    Test the labeled attention input projection options
    """
    with pytest.raises(ValueError):
        # this is too small
        model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '256', '--lattn_partitioned')
        run_forward_checks(model)

    model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned', '--lattn_d_input_proj', '256')
    run_forward_checks(model)

    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '768')
    run_forward_checks(model)

    # check that it works if we turn off the projection,
    # in case having it on becomes the default
    model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '0')
    run_forward_checks(model)
328
+
329
def test_forward_timing_choices(pretrain_file):
    """
    Test different timing / position encodings
    """
    for timing in ('sin', 'learned'):
        model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', timing)
        run_forward_checks(model)
338
+
339
def test_transition_stack(pretrain_file):
    """
    Test different transition stack types: lstm & attention
    """
    stack_options = (['--transition_stack', 'attn', '--transition_heads', '1'],
                     ['--transition_stack', 'attn', '--transition_heads', '4'],
                     ['--transition_stack', 'lstm'])
    for extra_args in stack_options:
        model = build_model(pretrain_file,
                            '--pattn_num_layers', '0', '--lattn_d_proj', '0',
                            *extra_args)
        run_forward_checks(model)
357
+
358
def test_constituent_stack(pretrain_file):
    """
    Test different constituent stack types: lstm & attention
    """
    stack_options = (['--constituent_stack', 'attn', '--constituent_heads', '1'],
                     ['--constituent_stack', 'attn', '--constituent_heads', '4'],
                     ['--constituent_stack', 'lstm'])
    for extra_args in stack_options:
        model = build_model(pretrain_file,
                            '--pattn_num_layers', '0', '--lattn_d_proj', '0',
                            *extra_args)
        run_forward_checks(model)
376
+
377
def test_different_transition_sizes(pretrain_file):
    """
    If the transition hidden size and embedding size are different, the model should still work
    """
    for boundary_scheme in ('everything', 'none'):
        for embedding_dim, hidden_size in (('10', '10'), ('20', '10'), ('10', '20')):
            model = build_model(pretrain_file,
                                '--pattn_num_layers', '0', '--lattn_d_proj', '0',
                                '--transition_embedding_dim', embedding_dim, '--transition_hidden_size', hidden_size,
                                '--sentence_boundary_vectors', boundary_scheme)
            run_forward_checks(model)
416
+
417
+
418
def test_lstm_tree_forward(pretrain_file):
    """
    Test the LSTM_TREE forward pass with 1, 2, and 3 tree LSTM layers
    """
    for num_layers in ('1', '2', '3'):
        model = build_model(pretrain_file, '--num_tree_lstm_layers', num_layers, '--constituency_composition', 'tree_lstm')
        run_forward_checks(model)
428
+
429
def test_lstm_tree_cx_forward(pretrain_file):
    """
    Test the LSTM_TREE_CX forward pass with 1, 2, and 3 tree LSTM layers
    """
    for num_layers in ('1', '2', '3'):
        model = build_model(pretrain_file, '--num_tree_lstm_layers', num_layers, '--constituency_composition', 'tree_lstm_cx')
        run_forward_checks(model)
439
+
440
def test_maxout(pretrain_file):
    """
    Test with and without maxout layers for output
    """
    model = build_model(pretrain_file, '--maxout_k', '0')
    run_forward_checks(model)
    # check the output size & implicitly check the type
    # to check for a particularly silly bug
    assert model.output_layers[-1].weight.shape[0] == len(model.transitions)

    # with maxout enabled, the final linear layer has k outputs
    # per transition
    model = build_model(pretrain_file, '--maxout_k', '2')
    run_forward_checks(model)
    assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 2

    model = build_model(pretrain_file, '--maxout_k', '3')
    run_forward_checks(model)
    assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 3
457
+
458
def check_structure_test(pretrain_file, args1, args2):
    """
    Test that the "copy" method copies the parameters from one model to another

    Also check that the copied models produce the same results

    args1 builds the source model, args2 the destination model; the two
    models are built with different random seeds so their parameters
    start out different.
    """
    set_random_seed(1000)
    other = build_model(pretrain_file, *args1)
    other.eval()

    set_random_seed(1001)
    model = build_model(pretrain_file, *args2)
    model.eval()

    # sanity check: before the copy, the parameters differ
    assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
    assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)

    model.copy_with_new_structure(other)

    # after the copy, the shared parameters must match
    assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
    assert torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
    # the norms will be the same, as the non-zero values are all the same
    assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0))

    # now, check that applying one transition to an initial state
    # results in the same values in the output states for both models
    # as the pattn layer inputs are 0, the output values should be equal
    shift = [parse_transitions.Shift()]
    model_states = test_parse_transitions.build_initial_state(model, 1)
    model_states = model.bulk_apply(model_states, shift)

    other_states = test_parse_transitions.build_initial_state(other, 1)
    other_states = other.bulk_apply(other_states, shift)

    # compare the word queue, transition stack, and constituent stack
    # of the two resulting states item by item
    for i, j in zip(other_states[0].word_queue, model_states[0].word_queue):
        assert torch.allclose(i.hx, j.hx, atol=1e-07)
    for i, j in zip(other_states[0].transitions, model_states[0].transitions):
        assert torch.allclose(i.lstm_hx, j.lstm_hx)
        assert torch.allclose(i.lstm_cx, j.lstm_cx)
    for i, j in zip(other_states[0].constituents, model_states[0].constituents):
        assert (i.value is None) == (j.value is None)
        if i.value is not None:
            assert torch.allclose(i.value.tree_hx, j.value.tree_hx, atol=1e-07)
        assert torch.allclose(i.lstm_hx, j.lstm_hx)
        assert torch.allclose(i.lstm_cx, j.lstm_cx)
503
+
504
def test_copy_with_new_structure_same(pretrain_file):
    """
    Test that copying the structure with no changes works as expected
    """
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'])
511
+
512
def test_copy_with_new_structure_untied(pretrain_file):
    """
    Test copying when the constituency composition changes from MAX to UNTIED_MAX
    """
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'MAX'],
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'UNTIED_MAX'])
519
+
520
def test_copy_with_new_structure_pattn(pretrain_file):
    """Test copying into a model which adds a pattn layer"""
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
                         ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
524
+
525
def test_copy_with_new_structure_both(pretrain_file):
    """Test copying into a model which adds both pattn and lattn layers"""
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
                         ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
529
+
530
def test_copy_with_new_structure_lattn(pretrain_file):
    """Test copying into a model which adds a lattn layer on top of pattn"""
    check_structure_test(pretrain_file,
                         ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'],
                         ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
534
+
535
def test_parse_tagged_words(pretrain_file):
    """
    Small test which doesn't check results, just execution

    Parses one tagged sentence with a freshly built model and verifies
    the preterminals in the result line up with the input words & tags.
    """
    model = build_model(pretrain_file)

    sentence = [("I", "PRP"), ("am", "VBZ"), ("Luffa", "NNP")]

    # we don't expect a useful tree out of a random model
    # so we don't check the result
    # just check that it works without crashing
    result = model.parse_tagged_words([sentence], 10)
    assert len(result) == 1
    pts = list(result[0].yield_preterminals())

    # check the count first so a truncated zip cannot hide a mismatch
    assert len(pts) == len(sentence)
    for word, pt in zip(sentence, pts):
        assert pt.children[0].label == word[0]
        assert pt.label == word[1]
stanza/stanza/tests/constituency/test_text_processing.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Run through the various text processing methods for using the parser on text files / directories
3
+
4
+ Uses a simple tree where the parser should always get it right, but things could potentially go wrong
5
+ """
6
+
7
+ import glob
8
+ import os
9
+ import pytest
10
+
11
+ from stanza import Pipeline
12
+
13
+ from stanza.models.constituency import text_processing
14
+ from stanza.models.constituency import tree_reader
15
+ from stanza.tests import TEST_MODELS_DIR
16
+
17
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
18
+
19
@pytest.fixture(scope="module")
def pipeline():
    """English tokenize/pos/constituency pipeline, shared across this module."""
    processors = "tokenize, pos, constituency"
    return Pipeline(dir=TEST_MODELS_DIR, lang="en",
                    processors=processors, tokenize_pretokenized=True)
22
+
23
def test_read_tokenized_file(tmp_path):
    """read_tokenized_file splits lines on whitespace; '_' in a token maps back to a space."""
    path = str(tmp_path / "test_input.txt")
    with open(path, "w") as fout:
        fout.write("This is a_small test\nLine two\n")

    text, ids = text_processing.read_tokenized_file(path)

    assert text == [['This', 'is', 'a small', 'test'], ['Line', 'two']]
    # no sentence ids in a plain text file
    assert ids == [None, None]
31
+
32
def test_parse_tokenized_sentences(pipeline):
    """Parsing one pretokenized sentence yields a single scored tree with the expected parse."""
    model = pipeline.processors["constituency"]._model

    sentences = [["This", "is", "a", "test"]]
    parsed = text_processing.parse_tokenized_sentences(model.args, model, [pipeline], sentences)

    predictions = [item.predictions for item in parsed]
    assert len(predictions) == 1
    assert len(predictions[0]) == 1

    expected = "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))"
    assert str(predictions[0][0].tree) == expected
46
+
47
def test_parse_text(tmp_path, pipeline):
    """parse_text reads a tokenized file and writes one tree per input line."""
    model = pipeline.processors["constituency"]._model

    raw_file = str(tmp_path / "test_input.txt")
    output_file = str(tmp_path / "test_output.txt")
    with open(raw_file, "w") as fout:
        fout.write("This is a test\nThis is another test\n")

    text_processing.parse_text(model.args, model, [pipeline],
                               tokenized_file=raw_file, predict_file=output_file)

    expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                      "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
    results = [str(tree) for tree in tree_reader.read_treebank(output_file)]
    assert results == expected_trees
63
+
64
def test_parse_dir(tmp_path, pipeline):
    """parse_dir parses every file in a directory, producing one output file per input file."""
    model = pipeline.processors["constituency"]._model

    input_dir = tmp_path / "input"
    os.makedirs(str(input_dir))
    output_dir = str(tmp_path / "output")

    # one sentence per input file, no trailing newline
    for idx, sentence in enumerate(["This is a test", "This is another test"], start=1):
        with open(str(input_dir / ("f%d.txt" % idx)), "w") as fout:
            fout.write(sentence)

    text_processing.parse_dir(model.args, model, [pipeline], str(input_dir), output_dir)

    expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                      "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
    output_files = sorted(glob.glob(os.path.join(output_dir, "*")))
    for output_file, expected in zip(output_files, expected_trees):
        trees = tree_reader.read_treebank(output_file)
        assert len(trees) == 1
        assert str(trees[0]) == expected
88
+
89
def test_load_model_parse_text(tmp_path, pipeline):
    """
    load_model_parse_text reloads the model from its saved path and parses a tokenized file.

    Renamed from test_parse_text: this module already defines a test with
    that name, so this duplicate definition shadowed the earlier test and
    pytest silently ran only this one (pyflakes F811).  Renaming restores
    both tests.
    """
    con_processor = pipeline.processors["constituency"]
    model = con_processor._model
    # copy so the processor's own args dict is not mutated
    args = dict(model.args)

    model_path = con_processor._config['model_path']

    raw_file = str(tmp_path / "test_input.txt")
    with open(raw_file, "w") as fout:
        fout.write("This is a test\nThis is another test\n")
    output_file = str(tmp_path / "test_output.txt")

    args['tokenized_file'] = raw_file
    args['predict_file'] = output_file

    text_processing.load_model_parse_text(args, model_path, [pipeline])
    trees = tree_reader.read_treebank(output_file)
    trees = ["{}".format(x) for x in trees]
    expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                      "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
    assert trees == expected_trees
stanza/stanza/tests/constituency/test_top_down_oracle.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from stanza.models.constituency.base_model import SimpleModel
4
+ from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, TransitionScheme
5
+ from stanza.models.constituency.top_down_oracle import *
6
+ from stanza.models.constituency.transition_sequence import build_sequence
7
+ from stanza.models.constituency.tree_reader import read_trees
8
+
9
+ from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
10
+
11
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
12
+
13
+ OPEN_SHIFT_EXAMPLE_TREE = """
14
+ ( (S
15
+ (NP (NNP Jennifer) (NNP Sh\'reyan))
16
+ (VP (VBZ has)
17
+ (NP (RB nice) (NNS antennae)))))
18
+ """
19
+
20
+ OPEN_SHIFT_PROBLEM_TREE = """
21
+ (ROOT (S (NP (NP (NP (DT The) (`` ``) (JJ Thin) (NNP Man) ('' '') (NN series)) (PP (IN of) (NP (NNS movies)))) (, ,) (CONJP (RB as) (RB well) (IN as)) (NP (JJ many) (NNS others)) (, ,)) (VP (VBD based) (NP (PRP$ their) (JJ entire) (JJ comedic) (NN appeal)) (PP (IN on) (NP (NP (DT the) (NN star) (NNS detectives) (POS ')) (JJ witty) (NNS quips) (CC and) (NNS puns))) (SBAR (IN as) (S (NP (NP (JJ other) (NNS characters)) (PP (IN in) (NP (DT the) (NNS movies)))) (VP (VBD were) (VP (VBN murdered)))))) (. .)))
22
+ """
23
+
24
+ ROOT_LABELS = ["ROOT"]
25
+
26
def get_single_repair(gold_sequence, wrong_transition, repair_fn, idx, *args, **kwargs):
    """Apply repair_fn as if wrong_transition had been predicted at position idx of gold_sequence."""
    gold_transition = gold_sequence[idx]
    return repair_fn(gold_transition, wrong_transition, gold_sequence, idx,
                     ROOT_LABELS, None, None, *args, **kwargs)
28
+
29
def build_state(model, tree, num_transitions):
    """
    Build a parser state by applying the first num_transitions gold transitions for tree.

    Each transition is checked for legality before it is applied.
    """
    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    states = model.initial_state_from_gold_trees([tree], [transitions])
    for idx, t in enumerate(transitions[:num_transitions]):
        # bug fix: the failure message previously referenced an undefined name
        # "sequence", which would raise NameError instead of the AssertionError
        assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, transitions)
        states = model.bulk_apply(states, [t])
    return states[0]
37
+
38
def test_fix_open_shift():
    """fix_one_open_shift drops the mistaken Open and its matching Close, early or late in the sequence."""
    (tree,) = read_trees(OPEN_SHIFT_EXAMPLE_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_orig = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    expected_fix_early = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    expected_fix_late = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]

    assert transitions == expected_orig

    # error early in the sequence
    assert get_single_repair(transitions, Shift(), fix_one_open_shift, 2) == expected_fix_early

    # error later in the sequence
    assert get_single_repair(transitions, Shift(), fix_one_open_shift, 8) == expected_fix_late
55
+
56
def test_fix_open_shift_observed_error():
    """
    Regression test for a previously observed error on this tree.

    Multiple Opens in a row must all be removed when a Shift happens:
    fix_one_open_shift gives up, but fix_multiple_open_shift repairs it.
    """
    (tree,) = read_trees(OPEN_SHIFT_PROBLEM_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    # the single-open repair cannot handle this tree...
    assert get_single_repair(transitions, Shift(), fix_one_open_shift, 2) is None

    # ...but the multiple-open repair can
    repaired = get_single_repair(transitions, Shift(), fix_multiple_open_shift, 2)

    # Can break the expected transitions down like this:
    # [OpenConstituent(('ROOT',)), OpenConstituent(('S',)),
    #  all gone: OpenConstituent(('NP',)), OpenConstituent(('NP',)), OpenConstituent(('NP',)),
    #  Shift, Shift, Shift, Shift, Shift, Shift,
    #  gone: CloseConstituent,
    #  OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), Shift, CloseConstituent, CloseConstituent,
    #  gone: CloseConstituent,
    #  Shift, OpenConstituent(('CONJP',)), Shift, Shift, Shift, CloseConstituent, OpenConstituent(('NP',)), Shift, Shift, CloseConstituent, Shift,
    #  gone: CloseConstituent,
    #  and then the rest of the sequence unchanged]
    expected_transitions = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), Shift(), Shift(), Shift(), Shift(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), Shift(), OpenConstituent('CONJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('SBAR'), Shift(), OpenConstituent('S'), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('VP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]

    assert repaired == expected_transitions
99
+
100
def test_open_open_ambiguous_unary_fix():
    """An ambiguous open/open error repaired by closing the wrong bracket as a unary."""
    (tree,) = read_trees(OPEN_SHIFT_EXAMPLE_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_orig = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    expected_fix = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_orig

    repaired = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_unary, 2)
    assert repaired == expected_fix
111
+
112
+
113
def test_open_open_ambiguous_later_fix():
    """An ambiguous open/open error repaired by closing the wrong bracket later in the sequence."""
    (tree,) = read_trees(OPEN_SHIFT_EXAMPLE_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_orig = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    expected_fix = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_orig

    repaired = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_later, 2)
    assert repaired == expected_fix
124
+
125
+
126
+ CLOSE_SHIFT_EXAMPLE_TREE = """
127
+ ( (NP (DT a)
128
+ (ADJP (NN stock) (HYPH -) (VBG picking))
129
+ (NN tool)))
130
+ """
131
+
132
+ # not intended to be a correct tree
133
+ CLOSE_SHIFT_DEEP_EXAMPLE_TREE = """
134
+ ( (NP (DT a)
135
+ (VP (ADJP (NN stock) (HYPH -) (VBG picking)))
136
+ (NN tool)))
137
+ """
138
+
139
+ # not intended to be a correct tree
140
+ CLOSE_SHIFT_OPEN_EXAMPLE_TREE = """
141
+ ( (NP (DT a)
142
+ (ADJP (NN stock) (HYPH -) (VBG picking))
143
+ (NP (NN tool))))
144
+ """
145
+
146
+ CLOSE_SHIFT_AMBIGUOUS_TREE = """
147
+ ( (NP (DT a)
148
+ (ADJP (NN stock) (HYPH -) (VBG picking))
149
+ (NN tool)
150
+ (NN foo)))
151
+ """
152
+
153
def test_fix_close_shift_ambiguous_immediate():
    """
    Repair a close/shift error on the ambiguous tree.

    NOTE(review): this test calls fix_close_shift_ambiguous_later, while
    test_fix_close_shift_ambiguous_later calls ..._immediate.  The test
    names and the repair functions they exercise look swapped - confirm
    which pairing is intended against top_down_oracle.
    """
    (tree,) = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    # the mistaken constituent stays open until the end of the phrase
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    repaired = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_later, 7)
    assert repaired == expected_update
167
+
168
def test_fix_close_shift_ambiguous_later():
    """
    Repair a close/shift error on the ambiguous tree, closing right after one shift.

    NOTE(review): this test calls fix_close_shift_ambiguous_immediate, while
    test_fix_close_shift_ambiguous_immediate calls ..._later - the names and
    calls appear swapped; confirm the intended pairing.
    """
    (tree,) = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    # the mistaken constituent is closed immediately after the next shift
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    repaired = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_immediate, 7)
    assert repaired == expected_update
180
+
181
def test_oracle_with_optional_level():
    """
    The oracle only applies the ambiguous close/shift repair when that repair
    type is enabled by name; otherwise it classifies the error without fixing it.

    NOTE(review): the model is built with TOP_DOWN_UNARY while the gold
    sequence uses TOP_DOWN - presumably intentional here, but worth confirming.
    """
    tree = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)[0]
    gold_sequence = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    assert transitions == gold_sequence

    model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY, root_labels=ROOT_LABELS)
    state = build_state(model, tree, 7)

    # with no repair types enabled, the error is recognized but not repaired
    oracle = TopDownOracle(ROOT_LABELS, 1, "", "")
    fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8],
                                         model=model,
                                         state=state)
    assert fix is RepairType.OTHER_CLOSE_SHIFT
    assert new_sequence is None

    # enabling the repair by name produces the fixed sequence
    oracle = TopDownOracle(ROOT_LABELS, 1, "CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR", "")
    fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8],
                                         model=model,
                                         state=state)
    assert fix is RepairType.CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR
    assert new_sequence == expected_update
205
+
206
+
207
def test_fix_close_shift():
    """fix_close_shift repairs the unambiguous tree and rejects the ambiguous one."""
    (tree,) = read_trees(CLOSE_SHIFT_EXAMPLE_TREE)
    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original
    assert get_single_repair(transitions, transitions[8], fix_close_shift, 7) == expected_update

    # the tree with two trailing shifts is ambiguous and must be rejected
    (tree,) = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    assert get_single_repair(transitions, transitions[8], fix_close_shift, 7) is None
232
+
233
def test_fix_close_shift_deeper_tree():
    """fix_close_shift also repairs a deeper tree, with or without counting opens."""
    (tree,) = read_trees(CLOSE_SHIFT_DEEP_EXAMPLE_TREE)
    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    # the result should be the same whether or not opens are counted
    for count_opens in (True, False):
        repaired = get_single_repair(transitions, transitions[10], fix_close_shift, 8, count_opens=count_opens)
        assert repaired == expected_update
250
+
251
def test_fix_close_shift_open_tree():
    """A close/shift error followed by an Open requires fix_close_shift_with_opens."""
    (tree,) = read_trees(CLOSE_SHIFT_OPEN_EXAMPLE_TREE)
    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

    # the basic repair cannot handle this tree shape
    assert get_single_repair(transitions, transitions[9], fix_close_shift, 7, count_opens=False) is None

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    repaired = get_single_repair(transitions, transitions[9], fix_close_shift_with_opens, 7)
    assert repaired == expected_update
270
+
271
+ CLOSE_OPEN_EXAMPLE_TREE = """
272
+ ( (VP (VBZ eat)
273
+ (NP (NN spaghetti))
274
+ (PP (IN with) (DT a) (NN fork))))
275
+ """
276
+
277
+ CLOSE_OPEN_DIFFERENT_LABEL_TREE = """
278
+ ( (VP (VBZ eat)
279
+ (NP (NN spaghetti))
280
+ (NP (DT a) (NN fork))))
281
+ """
282
+
283
+ CLOSE_OPEN_TWO_LABELS_TREE = """
284
+ ( (VP (VBZ eat)
285
+ (NP (NN spaghetti))
286
+ (PP (IN with) (DT a) (NN fork))
287
+ (PP (IN in) (DT a) (NN restaurant))))
288
+ """
289
+
290
def test_fix_close_open():
    """A Close predicted as the upcoming Open can be repaired by continuing the constituent."""
    (tree,) = read_trees(CLOSE_OPEN_EXAMPLE_TREE)
    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

    assert isinstance(transitions[5], CloseConstituent)
    assert transitions[6] == OpenConstituent("PP")

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    repaired = get_single_repair(transitions, transitions[6], fix_close_open_correct_open, 5)
    assert repaired == expected_update
307
+
308
def test_fix_close_open_invalid():
    """fix_close_open_correct_open refuses trees where the repair would be incorrect."""
    for tree_text in (CLOSE_OPEN_DIFFERENT_LABEL_TREE, CLOSE_OPEN_TWO_LABELS_TREE):
        (tree,) = read_trees(tree_text)

        transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

        assert isinstance(transitions[5], CloseConstituent)
        assert isinstance(transitions[6], OpenConstituent)

        assert get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open, 5) is None
321
+
322
def test_fix_close_open_ambiguous_immediate():
    """Repairing an ambiguous close/open with check_close=False attaches the next PP inside the NP."""
    (tree,) = read_trees(CLOSE_OPEN_TWO_LABELS_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    assert isinstance(transitions[5], CloseConstituent)
    assert isinstance(transitions[6], OpenConstituent)

    # sanity check: the gold sequence rebuilds the gold tree
    assert reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN) == tree

    repaired = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open, 5, check_close=False)
    reconstructed = reconstruct_tree(tree, repaired, transition_scheme=TransitionScheme.TOP_DOWN)

    expected = """
( (VP (VBZ eat)
      (NP (NN spaghetti)
        (PP (IN with) (DT a) (NN fork)))
      (PP (IN in) (DT a) (NN restaurant))))
"""
    expected = read_trees(expected)[0]
    assert reconstructed == expected
348
+
349
def test_fix_close_open_ambiguous_later():
    """Repairing an ambiguous close/open with the 'later' variant attaches both PPs inside the NP."""
    (tree,) = read_trees(CLOSE_OPEN_TWO_LABELS_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    assert isinstance(transitions[5], CloseConstituent)
    assert isinstance(transitions[6], OpenConstituent)

    # sanity check: the gold sequence rebuilds the gold tree
    assert reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN) == tree

    repaired = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open_ambiguous_later, 5, check_close=False)
    reconstructed = reconstruct_tree(tree, repaired, transition_scheme=TransitionScheme.TOP_DOWN)

    expected = """
( (VP (VBZ eat)
      (NP (NN spaghetti)
        (PP (IN with) (DT a) (NN fork))
        (PP (IN in) (DT a) (NN restaurant)))))
"""
    expected = read_trees(expected)[0]
    assert reconstructed == expected
375
+
376
+
377
+ SHIFT_CLOSE_EXAMPLES = [
378
+ ("((S (NP (DT an) (NML (NNP Oct) (CD 19)) (NN review))))", "((S (NP (DT an) (NML (NNP Oct) (CD 19))) (NN review)))", 8),
379
+ ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
380
+ "((S (NP (` `) (NP (DT The)) (NN Misanthrope) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", 6),
381
+ ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
382
+ "((S (NP (` `) (NP (DT The) (NN Misanthrope))) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre)))))", 8),
383
+ ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
384
+ "((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman)) (NNP Theatre)))))", 13),
385
+ ]
386
+
387
def test_shift_close():
    """For each example, turning the Shift at shift_position into a Close yields the expected tree."""
    for orig_text, expected_text, shift_position in SHIFT_CLOSE_EXAMPLES:
        (tree,) = read_trees(orig_text)

        transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
        if shift_position is None:
            # placeholder entry: dump the transitions so a position can be picked
            print(transitions)
            continue

        assert isinstance(transitions[shift_position], Shift)
        repaired = get_single_repair(transitions, CloseConstituent(), fix_shift_close, shift_position)
        reconstructed = reconstruct_tree(tree, repaired, transition_scheme=TransitionScheme.TOP_DOWN)
        if expected_text is None:
            # placeholder entry: print the repair result for manual inspection
            print(transitions)
            print(repaired)

            print("{:P}".format(reconstructed))
        else:
            expected_trees = read_trees(expected_text)
            assert len(expected_trees) == 1
            assert reconstructed == expected_trees[0]
412
+
413
def test_shift_open_ambiguous_unary():
    """A Shift predicted as an unknown Open is repaired by wrapping just that word as a unary."""
    (tree,) = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    repaired = get_single_repair(transitions, OpenConstituent("ZZ"), fix_shift_open_ambiguous_unary, 4)
    expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    assert repaired == expected_updated
428
+
429
def test_shift_open_ambiguous_later():
    """A Shift predicted as an unknown Open is repaired by closing the new bracket later."""
    (tree,) = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    repaired = get_single_repair(transitions, OpenConstituent("ZZ"), fix_shift_open_ambiguous_later, 4)
    expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    assert repaired == expected_updated
stanza/stanza/tests/constituency/test_trainer.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import logging
3
+ import pathlib
4
+ import tempfile
5
+
6
+ import pytest
7
+ import torch
8
+ from torch import nn
9
+ from torch import optim
10
+
11
+ from stanza import Pipeline
12
+
13
+ from stanza.models import constituency_parser
14
+ from stanza.models.common import pretrain
15
+ from stanza.models.common.bert_embedding import load_bert, load_tokenizer
16
+ from stanza.models.common.foundation_cache import FoundationCache
17
+ from stanza.models.common.utils import set_random_seed
18
+ from stanza.models.constituency import lstm_model
19
+ from stanza.models.constituency.parse_transitions import Transition
20
+ from stanza.models.constituency import parser_training
21
+ from stanza.models.constituency import trainer
22
+ from stanza.models.constituency import tree_reader
23
+ from stanza.tests import *
24
+
25
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
26
+
27
+ logger = logging.getLogger('stanza.constituency.trainer')
28
+ logger.setLevel(logging.WARNING)
29
+
30
+ TREEBANK = """
31
+ ( (S
32
+ (VP (VBG Enjoying)
33
+ (NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition)))
34
+ (. .)))
35
+
36
+ ( (NP
37
+ (VP (VBG Sitting)
38
+ (PP (IN in)
39
+ (NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station)))
40
+ (VP (VBG waiting)
41
+ (PP (IN for)
42
+ (NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train)))))
43
+ (. .)))
44
+
45
+ ( (S
46
+ (NP (PRP I))
47
+ (VP
48
+ (ADVP (RB really))
49
+ (VBP hate)
50
+ (NP (DT the) (NNP @MBTA)))))
51
+
52
+ ( (S
53
+ (S (VP (VB Seek)))
54
+ (CC and)
55
+ (S (NP (PRP ye))
56
+ (VP (MD shall)
57
+ (VP (VB find))))
58
+ (. .)))
59
+ """
60
+
61
+ def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK):
62
+ # TODO: build a fake embedding some other way?
63
+ train_trees = tree_reader.read_trees(treebank)
64
+ dev_trees = train_trees[-1:]
65
+ silver_trees = []
66
+
67
+ args = ['--wordvec_pretrain_file', wordvec_pretrain_file] + list(args)
68
+ args = constituency_parser.parse_args(args)
69
+
70
+ foundation_cache = FoundationCache()
71
+ # might be None, unless we're testing loading an existing model
72
+ model_load_name = args['load_name']
73
+
74
+ model, _, _, _ = parser_training.build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_name)
75
+ assert isinstance(model.model, lstm_model.LSTMModel)
76
+ return model
77
+
78
+ class TestTrainer:
79
+ @pytest.fixture(scope="class")
80
+ def wordvec_pretrain_file(self):
81
+ return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
82
+
83
+ @pytest.fixture(scope="class")
84
+ def tiny_random_xlnet(self, tmp_path_factory):
85
+ """
86
+ Download the tiny-random-xlnet model and make a concrete copy of it
87
+
88
+ The issue here is that the "random" nature of the original
89
+ makes it difficult or impossible to test that the values in
90
+ the transformer don't change during certain operations.
91
+ Saving a concrete instantiation of those random numbers makes
92
+ it so we can test there is no difference when training only a
93
+ subset of the layers, for example
94
+ """
95
+ xlnet_name = 'hf-internal-testing/tiny-random-xlnet'
96
+ xlnet_model, xlnet_tokenizer = load_bert(xlnet_name)
97
+ path = str(tmp_path_factory.mktemp('tiny-random-xlnet'))
98
+ xlnet_model.save_pretrained(path)
99
+ xlnet_tokenizer.save_pretrained(path)
100
+ return path
101
+
102
+ @pytest.fixture(scope="class")
103
+ def tiny_random_bart(self, tmp_path_factory):
104
+ """
105
+ Download the tiny-random-bart model and make a concrete copy of it
106
+
107
+ Issue is the same as with tiny_random_xlnet
108
+ """
109
+ bart_name = 'hf-internal-testing/tiny-random-bart'
110
+ bart_model, bart_tokenizer = load_bert(bart_name)
111
+ path = str(tmp_path_factory.mktemp('tiny-random-bart'))
112
+ bart_model.save_pretrained(path)
113
+ bart_tokenizer.save_pretrained(path)
114
+ return path
115
+
116
+ def test_initial_model(self, wordvec_pretrain_file):
117
+ """
118
+ does nothing, just tests that the construction went okay
119
+ """
120
+ args = ['--wordvec_pretrain_file', wordvec_pretrain_file]
121
+ build_trainer(wordvec_pretrain_file)
122
+
123
+
124
+ def test_save_load_model(self, wordvec_pretrain_file):
125
+ """
126
+ Just tests that saving and loading works without crashes.
127
+
128
+ Currently no test of the values themselves
129
+ (checks some fields to make sure they are regenerated correctly)
130
+ """
131
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
132
+ tr = build_trainer(wordvec_pretrain_file)
133
+ transitions = tr.model.transitions
134
+
135
+ # attempt saving
136
+ filename = os.path.join(tmpdirname, "parser.pt")
137
+ tr.save(filename)
138
+
139
+ assert os.path.exists(filename)
140
+
141
+ # load it back in
142
+ tr2 = tr.load(filename)
143
+ trans2 = tr2.model.transitions
144
+ assert(transitions == trans2)
145
+ assert all(isinstance(x, Transition) for x in trans2)
146
+
147
+ def test_relearn_structure(self, wordvec_pretrain_file):
148
+ """
149
+ Test that starting a trainer with --relearn_structure copies the old model
150
+ """
151
+
152
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
153
+ set_random_seed(1000)
154
+ args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
155
+ tr = build_trainer(wordvec_pretrain_file, *args)
156
+
157
+ # attempt saving
158
+ filename = os.path.join(tmpdirname, "parser.pt")
159
+ tr.save(filename)
160
+
161
+ set_random_seed(1001)
162
+ args = ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--relearn_structure', '--load_name', filename]
163
+ tr2 = build_trainer(wordvec_pretrain_file, *args)
164
+
165
+ assert torch.allclose(tr.model.delta_embedding.weight, tr2.model.delta_embedding.weight)
166
+ assert torch.allclose(tr.model.output_layers[0].weight, tr2.model.output_layers[0].weight)
167
+ # the norms will be the same, as the non-zero values are all the same
168
+ assert torch.allclose(torch.linalg.norm(tr.model.word_lstm.weight_ih_l0), torch.linalg.norm(tr2.model.word_lstm.weight_ih_l0))
169
+
170
+ def write_treebanks(self, tmpdirname):
171
+ train_treebank_file = os.path.join(tmpdirname, "train.mrg")
172
+ with open(train_treebank_file, 'w', encoding='utf-8') as fout:
173
+ fout.write(TREEBANK)
174
+ fout.write(TREEBANK)
175
+
176
+ eval_treebank_file = os.path.join(tmpdirname, "eval.mrg")
177
+ with open(eval_treebank_file, 'w', encoding='utf-8') as fout:
178
+ fout.write(TREEBANK)
179
+
180
+ return train_treebank_file, eval_treebank_file
181
+
182
+ def training_args(self, wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *additional_args):
183
+ # let's not make the model huge...
184
+ args = ['--pattn_num_layers', '0', '--pattn_d_model', '128', '--lattn_d_proj', '0', '--use_lattn', '--hidden_size', '20', '--delta_embedding_dim', '10',
185
+ '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname,
186
+ '--save_dir', tmpdirname, '--save_name', 'test.pt', '--save_each_start', '0', '--save_each_name', os.path.join(tmpdirname, 'each_%02d.pt'),
187
+ '--train_file', train_treebank_file, '--eval_file', eval_treebank_file,
188
+ '--epoch_size', '6', '--train_batch_size', '3',
189
+ '--shorthand', 'en_test']
190
+ args = args + list(additional_args)
191
+ args = constituency_parser.parse_args(args)
192
+ # just in case we change the defaults in the future
193
+ args['wandb'] = None
194
+ return args
195
+
196
+ def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False, exists_ok=False, foundation_cache=None):
197
+ """
198
+ Runs a test of the trainer for a few iterations.
199
+
200
+ Checks some basic properties of the saved model, but doesn't
201
+ check for the accuracy of the results
202
+ """
203
+ if extra_args is None:
204
+ extra_args = []
205
+ extra_args += ['--epochs', '%d' % num_epochs]
206
+
207
+ train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
208
+ if use_silver:
209
+ extra_args += ['--silver_file', str(eval_treebank_file)]
210
+ args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)
211
+
212
+ each_name = args['save_each_name']
213
+ if not exists_ok:
214
+ assert not os.path.exists(args['save_name'])
215
+ retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True, dir=TEST_MODELS_DIR, foundation_cache=foundation_cache)
216
+ trained_model = parser_training.train(args, None, [retag_pipeline])
217
+ # check that hooks are in the model if expected
218
+ for p in trained_model.model.parameters():
219
+ if p.requires_grad:
220
+ if args['grad_clipping'] is not None:
221
+ assert len(p._backward_hooks) == 1
222
+ else:
223
+ assert p._backward_hooks is None
224
+
225
+ # check that the model can be loaded back
226
+ assert os.path.exists(args['save_name'])
227
+ peft_name = trained_model.model.peft_name
228
+ tr = trainer.Trainer.load(args['save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
229
+ assert tr.optimizer is not None
230
+ assert tr.scheduler is not None
231
+ assert tr.epochs_trained >= 1
232
+ for p in tr.model.parameters():
233
+ if p.requires_grad:
234
+ assert p._backward_hooks is None
235
+
236
+ tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
237
+ assert tr.optimizer is not None
238
+ assert tr.scheduler is not None
239
+ assert tr.epochs_trained == num_epochs
240
+
241
+ for i in range(1, num_epochs+1):
242
+ model_name = each_name % i
243
+ assert os.path.exists(model_name)
244
+ tr = trainer.Trainer.load(model_name, load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
245
+ assert tr.epochs_trained == i
246
+ assert tr.batches_trained == (4 * i if use_silver else 2 * i)
247
+
248
+ return args, trained_model
249
+
250
+ def test_train(self, wordvec_pretrain_file):
251
+ """
252
+ Test the whole thing for a few iterations on the fake data
253
+ """
254
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
255
+ self.run_train_test(wordvec_pretrain_file, tmpdirname)
256
+
257
+ def test_early_dropout(self, wordvec_pretrain_file):
258
+ """
259
+ Test the whole thing for a few iterations on the fake data
260
+ """
261
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
262
+ args = ['--early_dropout', '3']
263
+ _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
264
+ model = model.model
265
+ dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
266
+ assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
267
+ for name, module in dropouts:
268
+ assert module.p == 0.0, "Dropout module %s was not set to 0 with early_dropout" % name
269
+
270
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
271
+ # test that when turned off, early_dropout doesn't happen
272
+ args = ['--early_dropout', '-1']
273
+ _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
274
+ model = model.model
275
+ dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
276
+ assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
277
+ if all(module.p == 0.0 for _, module in dropouts):
278
+ raise AssertionError("All dropouts were 0 after training even though early_dropout was set to -1")
279
+
280
+ def test_train_silver(self, wordvec_pretrain_file):
281
+ """
282
+ Test the whole thing for a few iterations on the fake data
283
+
284
+ This tests that it works if you give it a silver file
285
+ The check for the use of the silver data is that the
286
+ number of batches trained should go up
287
+ """
288
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
289
+ self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True)
290
+
291
+ def test_train_checkpoint(self, wordvec_pretrain_file):
292
+ """
293
+ Test the whole thing for a few iterations, then restart
294
+
295
+ This tests that the 5th iteration save file is not rewritten
296
+ and that the iterations continue to 10
297
+
298
+ TODO: could make it more robust by verifying that only 5 more
299
+ epochs are trained. Perhaps a "most recent epochs" could be
300
+ saved in the trainer
301
+ """
302
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
303
+ args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=False)
304
+ save_5 = args['save_each_name'] % 5
305
+ save_10 = args['save_each_name'] % 10
306
+ assert os.path.exists(save_5)
307
+ assert not os.path.exists(save_10)
308
+
309
+ save_5_stat = pathlib.Path(save_5).stat()
310
+
311
+ self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=10, use_silver=False, exists_ok=True)
312
+ assert os.path.exists(save_5)
313
+ assert os.path.exists(save_10)
314
+
315
+ assert pathlib.Path(save_5).stat().st_mtime == save_5_stat.st_mtime
316
+
317
+ def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
318
+ train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
319
+ args = ['--multistage', '--pattn_num_layers', '1']
320
+ if use_lattn:
321
+ args += ['--lattn_d_proj', '16']
322
+ if extra_args:
323
+ args += extra_args
324
+ args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)
325
+ each_name = os.path.join(args['save_dir'], 'each_%02d.pt')
326
+
327
+ word_input_sizes = defaultdict(list)
328
+ for i in range(1, 9):
329
+ model_name = each_name % i
330
+ assert os.path.exists(model_name)
331
+ tr = trainer.Trainer.load(model_name, load_optimizer=True)
332
+ assert tr.epochs_trained == i
333
+ word_input_sizes[tr.model.word_input_size].append(i)
334
+ if use_lattn:
335
+ # there should be three stages: no attn, pattn, pattn+lattn
336
+ assert len(word_input_sizes) == 3
337
+ word_input_keys = sorted(word_input_sizes.keys())
338
+ assert word_input_sizes[word_input_keys[0]] == [1, 2, 3]
339
+ assert word_input_sizes[word_input_keys[1]] == [4, 5]
340
+ assert word_input_sizes[word_input_keys[2]] == [6, 7, 8]
341
+ else:
342
+ # with no lattn, there are two stages: no attn, pattn
343
+ assert len(word_input_sizes) == 2
344
+ word_input_keys = sorted(word_input_sizes.keys())
345
+ assert word_input_sizes[word_input_keys[0]] == [1, 2, 3]
346
+ assert word_input_sizes[word_input_keys[1]] == [4, 5, 6, 7, 8]
347
+
348
+ def test_multistage_lattn(self, wordvec_pretrain_file):
349
+ """
350
+ Test a multistage training for a few iterations on the fake data
351
+
352
+ This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
353
+ """
354
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
355
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True)
356
+
357
+ def test_multistage_no_lattn(self, wordvec_pretrain_file):
358
+ """
359
+ Test a multistage training for a few iterations on the fake data
360
+
361
+ This should start with no pattn or lattn, then add pattn for the remaining epochs (no lattn is used)
362
+ """
363
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
364
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False)
365
+
366
+ def test_multistage_optimizer(self, wordvec_pretrain_file):
367
+ """
368
+ Test that the correct optimizers are built for a multistage training process
369
+ """
370
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
371
+ extra_args = ['--optim', 'adamw']
372
+ self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args)
373
+
374
+ # check that the optimizers which get rebuilt when loading
375
+ # the models are adadelta for the first half of the
376
+ # multistage, then adamw
377
+ each_name = os.path.join(tmpdirname, 'each_%02d.pt')
378
+ for i in range(1, 3):
379
+ model_name = each_name % i
380
+ tr = trainer.Trainer.load(model_name, load_optimizer=True)
381
+ assert tr.epochs_trained == i
382
+ assert isinstance(tr.optimizer, optim.Adadelta)
383
+ # double check that this is actually a valid test
384
+ assert not isinstance(tr.optimizer, optim.AdamW)
385
+
386
+ for i in range(4, 8):
387
+ model_name = each_name % i
388
+ tr = trainer.Trainer.load(model_name, load_optimizer=True)
389
+ assert tr.epochs_trained == i
390
+ assert isinstance(tr.optimizer, optim.AdamW)
391
+
392
+
393
+ def test_grad_clip_hooks(self, wordvec_pretrain_file):
394
+ """
395
+ Verify that grad clipping is not saved with the model, but is attached at training time
396
+ """
397
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
398
+ args = ['--grad_clipping', '25']
399
+ self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
400
+
401
+ def test_analyze_trees(self, wordvec_pretrain_file):
402
+ test_str = "(ROOT (S (NP (PRP I)) (VP (VBP wan) (S (VP (TO na) (VP (VB lick) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))))) (ROOT (S (NP (DT This) (NN interface)) (VP (VBZ sucks))))"
403
+
404
+ test_tree = tree_reader.read_trees(test_str)
405
+ assert len(test_tree) == 2
406
+
407
+ args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
408
+ tr = build_trainer(wordvec_pretrain_file, *args)
409
+
410
+ results = tr.model.analyze_trees(test_tree)
411
+ assert len(results) == 2
412
+ assert len(results[0].predictions) == 1
413
+ assert results[0].predictions[0].tree == test_tree[0]
414
+ assert results[0].state is not None
415
+ assert isinstance(results[0].state.score, torch.Tensor)
416
+ assert results[0].state.score.shape == torch.Size([])
417
+ assert len(results[0].constituents) == 9
418
+ assert results[0].constituents[-1].value == test_tree[0]
419
+ # the way the results are built, the next-to-last entry
420
+ # should be the thing just below the root
421
+ assert results[0].constituents[-2].value == test_tree[0].children[0]
422
+
423
+ assert len(results[1].predictions) == 1
424
+ assert results[1].predictions[0].tree == test_tree[1]
425
+ assert results[1].state is not None
426
+ assert isinstance(results[1].state.score, torch.Tensor)
427
+ assert results[1].state.score.shape == torch.Size([])
428
+ assert len(results[1].constituents) == 4
429
+ assert results[1].constituents[-1].value == test_tree[1]
430
+ assert results[1].constituents[-2].value == test_tree[1].children[0]
431
+
432
+ def bert_weights_allclose(self, bert_model, parser_model):
433
+ """
434
+ Return True if all bert weights are close, False otherwise
435
+ """
436
+ for name, parameter in bert_model.named_parameters():
437
+ other_name = "bert_model." + name
438
+ other_parameter = parser_model.model.get_parameter(other_name)
439
+ if not torch.allclose(parameter.cpu(), other_parameter.cpu()):
440
+ return False
441
+ return True
442
+
443
+ def frozen_transformer_test(self, wordvec_pretrain_file, transformer_name):
444
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
445
+ foundation_cache = FoundationCache()
446
+ args = ['--bert_model', transformer_name]
447
+ args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args, foundation_cache=foundation_cache)
448
+ bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
449
+ assert self.bert_weights_allclose(bert_model, trained_model)
450
+
451
+ checkpoint = torch.load(args['save_name'], lambda storage, loc: storage, weights_only=True)
452
+ params = checkpoint['params']
453
+ # check that the bert model wasn't saved in the model
454
+ assert all(not x.startswith("bert_model.") for x in params['model'].keys())
455
+ # make sure we're looking at the right thing
456
+ assert any(x.startswith("output_layers.") for x in params['model'].keys())
457
+
458
+ # check that the cached model is used as expected when loading a bert model
459
+ trained_model = trainer.Trainer.load(args['save_name'], foundation_cache=foundation_cache)
460
+ assert trained_model.model.bert_model is bert_model
461
+
462
+ def test_bert_frozen(self, wordvec_pretrain_file):
463
+ """
464
+ Check that the parameters of the bert model don't change when training a basic model
465
+ """
466
+ self.frozen_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')
467
+
468
+ def test_xlnet_frozen(self, wordvec_pretrain_file, tiny_random_xlnet):
469
+ """
470
+ Check that the parameters of an xlnet model don't change when training a basic model
471
+ """
472
+ self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
473
+
474
+ def test_bart_frozen(self, wordvec_pretrain_file, tiny_random_bart):
475
+ """
476
+ Check that the parameters of a bart model don't change when training a basic model
477
+ """
478
+ self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_bart)
479
+
480
+ def test_bert_finetune_one_epoch(self, wordvec_pretrain_file):
481
+ """
482
+ Check that the parameters of the bert model DO change over a single training step
483
+ """
484
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
485
+ transformer_name = 'hf-internal-testing/tiny-bert'
486
+ args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adadelta']
487
+ args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=1, extra_args=args)
488
+
489
+ # check that the weights are different
490
+ foundation_cache = FoundationCache()
491
+ bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
492
+ assert not self.bert_weights_allclose(bert_model, trained_model)
493
+
494
+ # double check that a new bert is created instead of using the FoundationCache when the bert has been trained
495
+ model_name = args['save_name']
496
+ assert os.path.exists(model_name)
497
+ no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name)
498
+ tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache)
499
+ assert tr.model.bert_model is not bert_model
500
+ assert not self.bert_weights_allclose(bert_model, tr)
501
+ assert self.bert_weights_allclose(trained_model.model.bert_model, tr)
502
+
503
+ new_save_name = os.path.join(tmpdirname, "test_resave_bert.pt")
504
+ assert not os.path.exists(new_save_name)
505
+ tr.save(new_save_name, save_optimizer=False)
506
+ tr2 = trainer.Trainer.load(new_save_name, args=no_finetune_args, foundation_cache=foundation_cache)
507
+ # check that the resaved model included its finetuned bert weights
508
+ assert tr2.model.bert_model is not bert_model
509
+ # the finetuned bert weights should also be scheduled for saving the next time as well
510
+ assert not tr2.model.is_unsaved_module("bert_model")
511
+
512
+ def finetune_transformer_test(self, wordvec_pretrain_file, transformer_name):
513
+ """
514
+ Check that the parameters of the transformer DO change when using bert_finetune
515
+ """
516
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
517
+ args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw']
518
+ args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
519
+
520
+ # check that the weights are different
521
+ foundation_cache = FoundationCache()
522
+ bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
523
+ assert not self.bert_weights_allclose(bert_model, trained_model)
524
+
525
+ # double check that a new bert is created instead of using the FoundationCache when the bert has been trained
526
+ no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name)
527
+ trained_model = trainer.Trainer.load(args['save_name'], args=no_finetune_args, foundation_cache=foundation_cache)
528
+ assert not trained_model.model.args['bert_finetune']
529
+ assert not trained_model.model.args['stage1_bert_finetune']
530
+ assert trained_model.model.bert_model is not bert_model
531
+
532
+ def test_bert_finetune(self, wordvec_pretrain_file):
533
+ """
534
+ Check that the parameters of a bert model DO change when using bert_finetune
535
+ """
536
+ self.finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')
537
+
538
+ def test_xlnet_finetune(self, wordvec_pretrain_file, tiny_random_xlnet):
539
+ """
540
+ Check that the parameters of an xlnet model DO change when using bert_finetune
541
+ """
542
+ self.finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
543
+
544
+ def test_stage1_bert_finetune(self, wordvec_pretrain_file):
545
+ """
546
+ Check that the parameters of the bert model DO change when using stage1_bert_finetune, but only for the first couple steps
547
+ """
548
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
549
+ bert_model_name = 'hf-internal-testing/tiny-bert'
550
+ args = ['--bert_model', bert_model_name, '--stage1_bert_finetune', '--optim', 'adamw']
551
+ # need to use num_epochs==6 so that epochs 1 and 2 are saved to be different
552
+ # a test of 5 or less means that sometimes it will reload the params
553
+ # at step 2 to get ready for the following iterations with adamw
554
+ args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
555
+
556
+ # check that the weights are different
557
+ foundation_cache = FoundationCache()
558
+ bert_model, bert_tokenizer = foundation_cache.load_bert(bert_model_name)
559
+ assert not self.bert_weights_allclose(bert_model, trained_model)
560
+
561
+ # double check that a new bert is created instead of using the FoundationCache when the bert has been trained
562
+ no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', bert_model_name, '--optim', 'adamw')
563
+ num_epochs = trained_model.model.args['epochs']
564
+ each_name = os.path.join(tmpdirname, 'each_%02d.pt')
565
+ for i in range(1, num_epochs+1):
566
+ model_name = each_name % i
567
+ assert os.path.exists(model_name)
568
+ tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache)
569
+ assert tr.model.bert_model is not bert_model
570
+ assert not self.bert_weights_allclose(bert_model, tr)
571
+ if i >= num_epochs // 2:
572
+ assert self.bert_weights_allclose(trained_model.model.bert_model, tr)
573
+
574
+ # verify that models 1 and 2 are saved to be different
575
+ model_name_1 = each_name % 1
576
+ model_name_2 = each_name % 2
577
+ tr_1 = trainer.Trainer.load(model_name_1, args=no_finetune_args, foundation_cache=foundation_cache)
578
+ tr_2 = trainer.Trainer.load(model_name_2, args=no_finetune_args, foundation_cache=foundation_cache)
579
+ assert not self.bert_weights_allclose(tr_1.model.bert_model, tr_2)
580
+
581
+
582
+ def one_layer_finetune_transformer_test(self, wordvec_pretrain_file, transformer_name):
583
+ """
584
+ Check that the parameters of the bert model DO change when using bert_finetune
585
+ """
586
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
587
+ args = ['--bert_model', transformer_name, '--bert_finetune', '--bert_finetune_layers', '1', '--optim', 'adamw', '--bert_finetune_layers', '1']
588
+ args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
589
+
590
+ # check that the weights of the last layer are different,
591
+ # but the weights of the earlier layers and
592
+ # non-transformer-layers are the same
593
+ foundation_cache = FoundationCache()
594
+ bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
595
+ assert bert_model.config.num_hidden_layers > 1
596
+ layer_name = "layer.%d." % (bert_model.config.num_hidden_layers - 1)
597
+ for name, parameter in bert_model.named_parameters():
598
+ other_name = "bert_model." + name
599
+ other_parameter = trained_model.model.get_parameter(other_name)
600
+ if layer_name in name:
601
+ if 'rel_attn.seg_embed' in name or 'rel_attn.r_s_bias' in name:
602
+ # not sure why this happens for xlnet, just roll with it
603
+ continue
604
+ assert not torch.allclose(parameter.cpu(), other_parameter.cpu())
605
+ else:
606
+ assert torch.allclose(parameter.cpu(), other_parameter.cpu())
607
+
608
+ def test_bert_finetune_one_layer(self, wordvec_pretrain_file):
609
+ self.one_layer_finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')
610
+
611
+ def test_xlnet_finetune_one_layer(self, wordvec_pretrain_file, tiny_random_xlnet):
612
+ self.one_layer_finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
613
+
614
+ def test_peft_finetune(self, tmp_path, wordvec_pretrain_file):
615
+ transformer_name = 'hf-internal-testing/tiny-bert'
616
+ args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw', '--use_peft']
617
+ args, trained_model = self.run_train_test(wordvec_pretrain_file, str(tmp_path), extra_args=args)
618
+
619
+ def test_peft_twostage_finetune(self, wordvec_pretrain_file):
620
+ with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
621
+ num_epochs = 6
622
+ transformer_name = 'hf-internal-testing/tiny-bert'
623
+ args = ['--bert_model', transformer_name, '--stage1_bert_finetune', '--optim', 'adamw', '--use_peft']
624
+ args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=num_epochs, extra_args=args)
625
+ for epoch in range(num_epochs):
626
+ filename_prev = args['save_each_name'] % epoch
627
+ filename_next = args['save_each_name'] % (epoch+1)
628
+ trainer_prev = trainer.Trainer.load(filename_prev, args=args, load_optimizer=False)
629
+ trainer_next = trainer.Trainer.load(filename_next, args=args, load_optimizer=False)
630
+
631
+ lora_names = [name for name, _ in trainer_prev.model.bert_model.named_parameters() if name.find("lora") >= 0]
632
+ if epoch < 2:
633
+ assert not any(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(),
634
+ trainer_next.model.bert_model.get_parameter(name).cpu())
635
+ for name in lora_names)
636
+ elif epoch > 2:
637
+ assert all(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(),
638
+ trainer_next.model.bert_model.get_parameter(name).cpu())
639
+ for name in lora_names)
stanza/stanza/tests/constituency/test_transformer_tree_stack.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ import torch
4
+
5
+ from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack
6
+
7
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
8
+
9
def test_initial_state():
    """
    Test that the initial state has the expected shapes
    """
    stack_module = TransformerTreeStack(3, 5, 0.0)
    state = stack_module.initial_state()
    assert len(state) == 1
    # one output vector, and key/value stacks with a single entry each
    assert state.value.output.shape == torch.Size([5])
    assert state.value.key_stack.shape == torch.Size([1, 5])
    assert state.value.value_stack.shape == torch.Size([1, 5])

def test_output():
    """
    Test that you can get an expected output shape from the TTS
    """
    stack_module = TransformerTreeStack(3, 5, 0.0)
    state = stack_module.initial_state()
    result = stack_module.output(state)
    assert result.shape == torch.Size([5])
    # the output accessor should return exactly the state's stored output
    assert torch.allclose(state.value.output, result)
29
+
30
def test_push_state_single():
    """
    Test that stacks are being updated correctly when using a single stack

    Values of the attention are not verified, though
    """
    tts = TransformerTreeStack(3, 5, 0.0)
    inputs = torch.randn(1, 3)
    stacks = tts.push_states([tts.initial_state()], ["A"], inputs)
    stacks = tts.push_states(stacks, ["B"], inputs)
    assert len(stacks) == 1
    top = stacks[0]
    assert len(top) == 3
    # top of stack is B, then A, then the initial sentinel node
    assert top.value.value == "B"
    assert top.pop().value.value == "A"
    assert top.pop().pop().value.value is None

def test_push_state_same_length():
    """
    Test that stacks are being updated correctly when using 3 stacks of the same length

    Values of the attention are not verified, though
    """
    tts = TransformerTreeStack(3, 5, 0.0)
    start = tts.initial_state()
    inputs = torch.randn(3, 3)
    stacks = [start, start, start]
    for label in ("A", "B", "C"):
        stacks = tts.push_states(stacks, [label] * 3, inputs)
    assert len(stacks) == 3
    for stack in stacks:
        assert len(stack) == 4
        assert stack.value.key_stack.shape == torch.Size([4, 5])
        assert stack.value.value_stack.shape == torch.Size([4, 5])
        # all three stacks got the same pushes: C on top of B on top of A
        assert stack.value.value == "C"
        assert stack.pop().value.value == "B"
        assert stack.pop().pop().value.value == "A"
        assert stack.pop().pop().pop().value.value is None
68
+
69
def test_push_state_different_length():
    """
    Test what happens if stacks of different lengths are passed in
    """
    tts = TransformerTreeStack(3, 5, 0.0)
    start = tts.initial_state()
    inputs = torch.randn(2, 3)
    # advance one stack by a single push so the two stacks differ in length
    longer = tts.push_states([start], ["A"], inputs[0:1, :])[0]
    stacks = tts.push_states([longer, start], ["B", "C"], inputs)

    assert len(stacks) == 2
    assert len(stacks[0]) == 3
    assert len(stacks[1]) == 2
    assert stacks[0].pop().value.value == 'A'
    assert stacks[0].value.value == 'B'
    assert stacks[1].value.value == 'C'

    # key stacks grow with the respective stack lengths
    assert stacks[0].value.key_stack.shape == torch.Size([3, 5])
    assert stacks[1].value.key_stack.shape == torch.Size([2, 5])
88
+
89
def test_mask():
    """
    Test that a mask prevents the softmax from picking up unwanted values
    """
    tts = TransformerTreeStack(3, 5, 0.0)

    base_value = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5]]])
    doubled_value = base_value * 2
    values = torch.cat([base_value, doubled_value], axis=1)
    # the same key twice, so unmasked attention weights must be equal
    shared_key = torch.randn(1, 1, 5)
    keys = torch.cat([shared_key, shared_key], axis=1)
    query = torch.randn(1, 5)

    # equal keys -> 50/50 attention -> the mean of the two values
    attended = tts.attention(keys, query, values)
    assert torch.allclose(attended, (base_value + doubled_value) / 2)

    # masking out the first entry leaves only the second
    mask = torch.zeros(1, 2, dtype=torch.bool)
    mask[0][0] = True
    assert torch.allclose(tts.attention(keys, query, values, mask), doubled_value)

    # masking out the second entry leaves only the first
    mask = torch.zeros(1, 2, dtype=torch.bool)
    mask[0][1] = True
    assert torch.allclose(tts.attention(keys, query, values, mask), base_value)
120
+
121
def test_position():
    """
    Test that nothing goes horribly wrong when position encodings are used

    Does not actually test the results of the encodings
    """
    tts = TransformerTreeStack(4, 5, 0.0, use_position=True)
    start = tts.initial_state()
    assert len(start) == 1
    assert start.value.output.shape == torch.Size([5])
    assert start.value.key_stack.shape == torch.Size([1, 5])
    assert start.value.value_stack.shape == torch.Size([1, 5])

    # push onto stacks of two different lengths to exercise the encodings
    inputs = torch.randn(2, 4)
    longer = tts.push_states([start], ["A"], inputs[0:1, :])[0]
    tts.push_states([longer, start], ["B", "C"], inputs)
138
+
139
def test_length_limit():
    """
    Test that the length limit drops nodes as the length limit is exceeded

    The logical stack keeps growing with every push, but the key/value
    stacks should stay capped (at 3 entries here, with length_limit=2).
    """
    ts = TransformerTreeStack(4, 5, 0.0, length_limit = 2)
    initial = ts.initial_state()
    assert len(initial) == 1
    assert initial.value.output.shape == torch.Size([5])
    assert initial.value.key_stack.shape == torch.Size([1, 5])
    assert initial.value.value_stack.shape == torch.Size([1, 5])

    data = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
    stacks = ts.push_states([initial], ["A"], data)

    # push three more nodes: the stack length grows 3 -> 4 -> 5,
    # but the key/value stacks never exceed 3 entries
    for expected_len, label in zip((3, 4, 5), ("B", "C", "D")):
        stacks = ts.push_states(stacks, [label], data)
        assert len(stacks) == 1
        assert len(stacks[0]) == expected_len
        assert stacks[0].value.key_stack.shape[0] == 3
        assert stacks[0].value.value_stack.shape[0] == 3
170
+
171
def test_two_heads():
    """
    Test that the module works when built with multiple attention heads

    Only shapes and stack bookkeeping are checked, not attention values.
    (The previous docstring was copy-pasted from test_length_limit.)
    """
    ts = TransformerTreeStack(4, 6, 0.0, num_heads = 2)
    initial = ts.initial_state()
    assert len(initial) == 1
    assert initial.value.output.shape == torch.Size([6])
    assert initial.value.key_stack.shape == torch.Size([1, 6])
    assert initial.value.value_stack.shape == torch.Size([1, 6])

    rand_input = torch.randn(2, 4)
    one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0]
    stacks = [one_step, initial]
    stacks = ts.push_states(stacks, ["B", "C"], rand_input)
    assert len(stacks) == 2
    assert len(stacks[0]) == 3
    assert len(stacks[1]) == 2
    assert stacks[0].pop().value.value == 'A'
    assert stacks[0].value.value == 'B'
    assert stacks[1].value.value == 'C'

    assert stacks[0].value.key_stack.shape == torch.Size([3, 6])
    assert stacks[1].value.key_stack.shape == torch.Size([2, 6])
195
+
stanza/stanza/tests/constituency/test_transition_sequence.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from stanza.models.constituency import parse_transitions
3
+ from stanza.models.constituency import transition_sequence
4
+ from stanza.models.constituency import tree_reader
5
+ from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT
6
+ from stanza.models.constituency.parse_transitions import *
7
+
8
+ from stanza.tests import *
9
+ from stanza.tests.constituency.test_parse_tree import CHINESE_LONG_LIST_TREE
10
+
11
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
12
+
13
def reconstruct_tree(tree, sequence, transition_scheme=TransitionScheme.IN_ORDER, unary_limit=UNARY_LIMIT, reverse=False):
    """
    Starting from a tree and a list of transitions, build the tree caused by the transitions

    tree: the gold tree, used only to initialize the parser state
    sequence: the transitions to apply; each must be legal when applied,
      or an AssertionError is raised
    reverse: build the sentence right-to-left, then un-reverse the result
    """
    model = SimpleModel(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse)
    states = model.initial_state_from_gold_trees([tree])
    assert len(states) == 1
    assert states[0].num_transitions == 0

    # TODO: could fold this into parse_sentences (similar to verify_transitions in trainer.py)
    for idx, t in enumerate(sequence):
        assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, sequence)
        states = model.bulk_apply(states, [t])

    result_tree = states[0].constituents.value
    if reverse:
        result_tree = result_tree.reverse()
    return result_tree
31
+
32
def check_reproduce_tree(transition_scheme):
    """
    Build a transition sequence for a simple tree, apply it to a fresh
    state, and check that the original tree is reproduced exactly.
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)

    model = SimpleModel(transition_scheme)
    transitions = transition_sequence.build_sequence(trees[0], transition_scheme)
    states = model.initial_state_from_gold_trees(trees)
    assert len(states) == 1
    state = states[0]
    assert state.num_transitions == 0

    for t in transitions:
        assert t.is_legal(state, model)
        state = t.apply(state, model)

    # one item for the final tree
    # one item for the sentinel at the end
    assert len(state.constituents) == 2
    # the transition sequence should put all of the words
    # from the buffer onto the tree
    # one spot left for the sentinel value
    assert len(state.word_queue) == 8
    assert state.sentence_length == 6
    assert state.word_position == state.sentence_length
    assert len(state.transitions) == len(transitions) + 1

    result_tree = state.constituents.value
    assert result_tree == trees[0]
60
+
61
def test_top_down_unary():
    """Round-trip a tree through the TOP_DOWN_UNARY transition scheme"""
    check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN_UNARY)

def test_top_down_no_unary():
    """Round-trip a tree through the plain TOP_DOWN transition scheme"""
    check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN)

def test_in_order():
    """Round-trip a tree through the IN_ORDER transition scheme"""
    check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER)

def test_in_order_compound():
    """Round-trip a tree through the IN_ORDER_COMPOUND transition scheme"""
    check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND)

def test_in_order_unary():
    """Round-trip a tree through the IN_ORDER_UNARY transition scheme"""
    check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_UNARY)
75
+
76
def test_all_transitions():
    """
    Check the sorted set of unique transitions extracted from a treebank

    (Removed an unused `model = SimpleModel()` local.)
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)
    transitions = transition_sequence.build_treebank(trees)

    expected = [Shift(), CloseConstituent(), CompoundUnary("ROOT"), CompoundUnary("SQ"), CompoundUnary("WHNP"), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("SBARQ"), OpenConstituent("VP")]
    assert transition_sequence.all_transitions(transitions) == expected
84
+
85
+
86
def test_all_transitions_no_unary():
    """
    Check the unique transitions when the scheme has no unary transitions

    (Removed an unused `model = SimpleModel()` local.)
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)
    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)

    expected = [Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("ROOT"), OpenConstituent("SBARQ"), OpenConstituent("SQ"), OpenConstituent("VP"), OpenConstituent("WHNP")]
    assert transition_sequence.all_transitions(transitions) == expected
94
+
95
def test_top_down_compound_unary():
    """
    Check that TOP_DOWN_COMPOUND transitions rebuild a complex tree exactly
    """
    text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))"

    trees = tree_reader.read_trees(text)
    assert len(trees) == 1

    model = SimpleModel()
    sequence = transition_sequence.build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN_COMPOUND)

    states = model.initial_state_from_gold_trees(trees)
    assert len(states) == 1
    state = states[0]

    # every transition must be legal, and the final constituent
    # must equal the gold tree
    for transition in sequence:
        assert transition.is_legal(state, model)
        state = transition.apply(state, model)

    assert model.get_top_constituent(state.constituents) == trees[0]
114
+
115
+
116
def test_chinese_tree():
    """
    Round-trip a tree with a very long flat list of children
    """
    trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)

    td_transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)
    assert reconstruct_tree(trees[0], td_transitions[0], transition_scheme=TransitionScheme.TOP_DOWN) == trees[0]

    io_transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER)
    # with the default unary limit the in-order reconstruction fails...
    with pytest.raises(AssertionError):
        reconstruct_tree(trees[0], io_transitions[0], transition_scheme=TransitionScheme.IN_ORDER)

    # ...but succeeds with a larger unary limit
    assert reconstruct_tree(trees[0], io_transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6) == trees[0]
129
+
130
+
131
def test_chinese_tree_reversed():
    """
    test that the reversed transitions also work

    Removed the `assert redone == trees[0]` statements that followed the
    reconstruct_tree calls inside pytest.raises blocks: the call raises,
    so those asserts were unreachable dead code.
    """
    trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)

    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN, reverse=True)
    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN, reverse=True)
    assert redone == trees[0]

    with pytest.raises(AssertionError):
        # turn off reverse - it should fail to rebuild the tree
        reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN)

    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER, reverse=True)
    with pytest.raises(AssertionError):
        reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, reverse=True)

    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6, reverse=True)
    assert redone == trees[0]

    with pytest.raises(AssertionError):
        # turn off reverse - it should fail to rebuild the tree
        reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)
stanza/stanza/tests/constituency/test_tree_reader.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from stanza.models.constituency import tree_reader
3
+ from stanza.models.constituency.tree_reader import MixedTreeError, UnclosedTreeError, UnlabeledTreeError
4
+
5
+ from stanza.tests import *
6
+
7
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
8
+
9
def test_simple():
    """
    Tests reading two simple trees from the same text
    """
    trees = tree_reader.read_trees("(VB Unban) (NNP Opal)")
    assert len(trees) == 2
    for tree, tag, word in zip(trees, ('VB', 'NNP'), ('Unban', 'Opal')):
        assert tree.is_preterminal()
        assert tree.label == tag
        assert tree.children[0].label == word

def test_newlines():
    """
    The same test should work if there are newlines
    """
    trees = tree_reader.read_trees("(VB Unban)\n\n(NNP Opal)")
    assert len(trees) == 2

def test_parens():
    """
    Parens should be escaped in the tree files and escaped when written
    """
    trees = tree_reader.read_trees("(-LRB- -LRB-) (-RRB- -RRB-)")
    assert len(trees) == 2

    for tree, tag, leaf in zip(trees, ('-LRB-', '-RRB-'), ('(', ')')):
        assert tree.label == tag
        # the leaf is unescaped in memory, re-escaped when formatted
        assert tree.children[0].label == leaf
        assert "{}".format(tree) == "(%s %s)" % (tag, tag)
46
+
47
def test_complicated():
    """
    A more complicated tree that should successfully read
    """
    text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    root = trees[0]
    assert not root.is_leaf()
    assert not root.is_preterminal()
    assert root.label == 'ROOT'
    assert len(root.children) == 1
    sbarq = root.children[0]
    assert sbarq.label == 'SBARQ'
    assert len(sbarq.children) == 3
    assert [child.label for child in sbarq.children] == ['WHNP', 'SQ', '.']
    # etc etc

def test_one_word():
    """
    Check that one node trees are correctly read

    probably not super relevant for the parsing use case
    """
    trees = tree_reader.read_trees("(FOO) (BAR)")
    assert len(trees) == 2
    for tree, label in zip(trees, ('FOO', 'BAR')):
        assert tree.is_leaf()
        assert tree.label == label
79
+
80
def test_missing_close_parens():
    """
    Test the unclosed error condition

    Uses pytest.raises instead of the previous try/raise-AssertionError
    pattern, which is both shorter and reports failures more clearly.
    """
    text = "(Foo) \n (Bar \n zzz"
    with pytest.raises(UnclosedTreeError) as excinfo:
        tree_reader.read_trees(text)
    assert excinfo.value.line_num == 1

def test_mixed_tree():
    """
    Test the mixed error condition
    """
    text = "(Foo) \n (Bar) \n (Unban (Mox) Opal)"
    with pytest.raises(MixedTreeError) as excinfo:
        tree_reader.read_trees(text)
    assert excinfo.value.line_num == 2

    # with broken_ok, the malformed treebank is read anyway
    trees = tree_reader.read_trees(text, broken_ok=True)
    assert len(trees) == 3

def test_unlabeled_tree():
    """
    Test the unlabeled error condition
    """
    text = "(ROOT ((Foo) (Bar)))"
    with pytest.raises(UnlabeledTreeError) as excinfo:
        tree_reader.read_trees(text)
    assert excinfo.value.line_num == 0

    # with broken_ok, the malformed tree is read anyway
    trees = tree_reader.read_trees(text, broken_ok=True)
    assert len(trees) == 1
118
+
119
+
stanza/stanza/tests/constituency/test_vietnamese.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A few tests for Vietnamese parsing, which has some difficulties related to spaces in words
3
+
4
+ Technically some other languages can have this, too, like that one French token
5
+ """
6
+
7
+ import os
8
+ import tempfile
9
+
10
+ import pytest
11
+
12
+ from stanza.models.common import pretrain
13
+ from stanza.models.constituency import tree_reader
14
+
15
+ from stanza.tests.constituency.test_trainer import build_trainer
16
+
17
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
18
+
19
+ VI_TREEBANK = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài Loan) (" ") (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))'
20
+
21
+ VI_TREEBANK_UNDERSCORE = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài_Loan) (" ") (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .)))'
22
+
23
+ VI_TREEBANK_SIMPLE = '(ROOT (S (NP (" ") (N Đảo) (Np Đài Loan) (" ") (PP (E ở) (NP (N đồng bằng) (NP (N sông) (Np Cửu Long))))) (. .)))'
24
+
25
+ VI_TREEBANK_PAREN = '(ROOT (S-TTL (NP (PUNCT -LRB-) (N-H Đảo) (Np Đài Loan) (PUNCT -RRB-) (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))'
26
+ VI_TREEBANK_VLSP = '<s>\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n</s>'
27
+ VI_TREEBANK_VLSP_50 = '<s id=50>\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n</s>'
28
+ VI_TREEBANK_VLSP_100 = '<s id=100>\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n</s>'
29
+
30
+ EXPECTED_LABELED_BRACKETS = '(_ROOT (_S (_NP (_" " )_" (_N Đảo )_N (_Np Đài_Loan )_Np (_" " )_" (_PP (_E ở )_E (_NP (_N đồng_bằng )_N (_NP (_N sông )_N (_Np Cửu_Long )_Np )_NP )_NP )_PP )_NP (_. . )_. )_S )_ROOT'
31
+
32
+
33
def test_read_vi_tree():
    """
    Test that an individual tree with spaces in the leaves is being processed as we expect
    """
    text = VI_TREEBANK.split("\n")[0]
    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    assert str(trees[0]) == text
    # dig down to the third child of the first NP, eg (Np Đài Loan)
    np_child = trees[0].children[0].children[0].children[2]
    assert np_child.is_preterminal()
    # the space stays inside the single leaf label
    assert np_child.children[0].label == "Đài Loan"
46
+
47
+ VI_EMBEDDING = """
48
+ 4 4
49
+ Đảo 0.11 0.21 0.31 0.41
50
+ Đài Loan 0.12 0.22 0.32 0.42
51
+ đồng bằng 0.13 0.23 0.33 0.43
52
+ sông 0.14 0.24 0.34 0.44
53
+ """.strip()
54
+
55
def test_vi_embedding():
    """
    Test that a VI embedding's words are correctly found when processing trees
    """
    tree = tree_reader.read_trees(VI_TREEBANK.split("\n")[0])[0]
    leaf_words = set(tree.leaf_labels())

    with tempfile.TemporaryDirectory() as tempdir:
        # write the toy embedding, then convert it to a .pt pretrain file
        emb_filename = os.path.join(tempdir, "emb.txt")
        pt_filename = os.path.join(tempdir, "emb.pt")
        with open(emb_filename, "w", encoding="utf-8") as fout:
            fout.write(VI_EMBEDDING)
        pt = pretrain.Pretrain(filename=pt_filename, vec_filename=emb_filename, save_to_file=True)
        pt.load()

        model = build_trainer(pt_filename).model
        # all four embedding entries (including multiword ones) are found
        assert model.num_words_known(leaf_words) == 4
75
+
76
+
77
def test_space_formatting():
    """
    By default, spaces are left as spaces, but there is a format option to change spaces
    """
    text = VI_TREEBANK.split("\n")[0]
    trees = tree_reader.read_trees(text)
    assert len(trees) == 1

    tree = trees[0]
    assert str(tree) == text
    assert "{}".format(tree) == VI_TREEBANK
    # the _O flag writes underscores in place of spaces inside words
    assert "{:_O}".format(tree) == VI_TREEBANK_UNDERSCORE
88
+
89
def test_vlsp_formatting():
    """
    Check the VLSP-style output format, with and without tree ids
    """
    text = VI_TREEBANK_PAREN.split("\n")[0]
    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    tree = trees[0]
    assert str(tree) == text

    assert "{:_V}".format(tree) == VI_TREEBANK_VLSP
    # the _Vi flag includes the tree_id in the <s> element
    for tree_id, expected in ((50, VI_TREEBANK_VLSP_50), (100, VI_TREEBANK_VLSP_100)):
        tree.tree_id = tree_id
        assert "{:_Vi}".format(tree) == expected

    # a childless root cannot be formatted in the VLSP style
    empty = tree_reader.read_trees("(ROOT)")[0]
    with pytest.raises(ValueError):
        "{:V}".format(empty)

    # neither can a root with multiple children
    branches = tree_reader.read_trees("(ROOT (1) (2) (3))")[0]
    with pytest.raises(ValueError):
        "{:V}".format(branches)
108
+
109
def test_language_formatting():
    """
    Test turning the parse tree into a 'language' for GPT
    """
    raw_trees = tree_reader.read_trees(VI_TREEBANK.split("\n")[0])
    simplified = [t.prune_none().simplify_labels() for t in raw_trees]
    assert len(simplified) == 1
    assert str(simplified[0]) == VI_TREEBANK_SIMPLE

    # the L flag produces labeled open/close brackets
    assert "{:L}".format(simplified[0]) == EXPECTED_LABELED_BRACKETS
121
+
stanza/stanza/tests/langid/test_langid.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic tests of langid module
3
+ """
4
+
5
+ import pytest
6
+
7
+ from stanza.models.common.doc import Document
8
+ from stanza.pipeline.core import Pipeline
9
+ from stanza.pipeline.langid_processor import LangIDProcessor
10
+ from stanza.tests import TEST_MODELS_DIR
11
+
12
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
13
+
14
+ #pytestmark = pytest.mark.skip
15
+
16
@pytest.fixture(scope="module")
def basic_multilingual():
    """Langid pipeline over all supported languages."""
    return Pipeline(dir=TEST_MODELS_DIR, lang='multilingual', processors="langid")

@pytest.fixture(scope="module")
def enfr_multilingual():
    """Langid pipeline restricted to English and French."""
    return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en", "fr"])

@pytest.fixture(scope="module")
def en_multilingual():
    """Langid pipeline restricted to English only."""
    return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en"])

@pytest.fixture(scope="module")
def clean_multilingual():
    """Langid pipeline with text cleaning enabled."""
    return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_clean_text=True)
31
+
32
def test_langid(basic_multilingual):
    """
    Basic test of language identification
    """
    texts = ["This is an English sentence.", "C'est une phrase française."]
    docs = [Document([], text=t) for t in texts]
    basic_multilingual(docs)
    # langid annotates each document's lang attribute in place
    assert [doc.lang for doc in docs] == ["en", "fr"]
44
+
45
+ def test_langid_benchmark(basic_multilingual):
46
+ """
47
+ Run lang id model on 500 examples, confirm reasonable accuracy.
48
+ """
49
+ examples = [
50
+ {"text": "contingentiam in naturalibus causis.", "label": "la"},
51
+ {"text": "I jak opowiadał nieżyjący już pan Czesław", "label": "pl"},
52
+ {"text": "Sonera gilt seit längerem als Übernahmekandidat", "label": "de"},
53
+ {"text": "与银类似,汞也可以与空气中的硫化氢反应。", "label": "zh-hans"},
54
+ {"text": "contradictionem implicat.", "label": "la"},
55
+ {"text": "Bis zu Prozent gingen die Offerten etwa im", "label": "de"},
56
+ {"text": "inneren Sicherheit vorgeschlagene Ausweitung der", "label": "de"},
57
+ {"text": "Multimedia-PDA mit Mini-Tastatur", "label": "de"},
58
+ {"text": "Ponášalo sa to na rovnicu o dvoch neznámych.", "label": "sk"},
59
+ {"text": "이처럼 앞으로 심판의 그 날에 다시 올 메시아가 예수 그리스도이며 , 그는 모든 인류의", "label": "ko"},
60
+ {"text": "Die Arbeitsgruppe bedauert , dass der weit über", "label": "de"},
61
+ {"text": "И только раз довелось поговорить с ним не вполне", "label": "ru"},
62
+ {"text": "de a-l lovi cu piciorul și conștiința că era", "label": "ro"},
63
+ {"text": "relación coas pretensións do demandante e que, nos", "label": "gl"},
64
+ {"text": "med petdeset in sedemdeset", "label": "sl"},
65
+ {"text": "Catalunya; el Consell Comarcal del Vallès Oriental", "label": "ca"},
66
+ {"text": "kunnen worden.", "label": "nl"},
67
+ {"text": "Witkin je ve většině ohledů zcela jiný.", "label": "cs"},
68
+ {"text": "lernen, so zu agieren, dass sie positive oder auch", "label": "de"},
69
+ {"text": "olurmuş...", "label": "tr"},
70
+ {"text": "sarcasmo de Altman, desde as «peruas» que discutem", "label": "pt"},
71
+ {"text": "خلاف فوجداری مقدمہ درج کرے۔", "label": "ur"},
72
+ {"text": "Norddal kommune :", "label": "no"},
73
+ {"text": "dem Windows-.-Zeitalter , soll in diesem Jahr", "label": "de"},
74
+ {"text": "przeklętych ucieleśniają mit poety-cygana,", "label": "pl"},
75
+ {"text": "We do not believe the suspect has ties to this", "label": "en"},
76
+ {"text": "groziņu pīšanu.", "label": "lv"},
77
+ {"text": "Senior Vice-President David M. Thomas möchte", "label": "de"},
78
+ {"text": "neomylně vybral nějakou knihu a začetl se.", "label": "cs"},
79
+ {"text": "Statt dessen darf beispielsweise der Browser des", "label": "de"},
80
+ {"text": "outubro, alcançando R $ bilhões em .", "label": "pt"},
81
+ {"text": "(Porte, ), as it does other disciplines", "label": "en"},
82
+ {"text": "uskupení se mylně domnívaly, že podporu", "label": "cs"},
83
+ {"text": "Übernahme von Next Ende an dem System herum , das", "label": "de"},
84
+ {"text": "No podemos decir a la Hacienda que los alemanes", "label": "es"},
85
+ {"text": "и рѣста еи братья", "label": "orv"},
86
+ {"text": "الذي اتخذ قرارا بتجميد اعلان الدولة الفلسطينية", "label": "ar"},
87
+ {"text": "uurides Rootsi sõjaarhiivist toodud . sajandi", "label": "et"},
88
+ {"text": "selskapets penger til å pusse opp sin enebolig på", "label": "no"},
89
+ {"text": "средней полосе и севернее в Ярославской,", "label": "ru"},
90
+ {"text": "il-massa żejda fil-ġemgħat u superġemgħat ta'", "label": "mt"},
91
+ {"text": "The Global Beauties on internetilehekülg, mida", "label": "et"},
92
+ {"text": "이스라엘 인들은 하나님이 그 큰 팔을 펴 이집트 인들을 치는 것을 보고 하나님을 두려워하며", "label": "ko"},
93
+ {"text": "Snad ještě dodejme jeden ekonomický argument.", "label": "cs"},
94
+ {"text": "Spalio d. vykusiame pirmajame rinkimų ture", "label": "lt"},
95
+ {"text": "und schlechter Journalismus ein gutes Geschäft .", "label": "de"},
96
+ {"text": "Du sodiečiai sėdi ant potvynio apsemtų namų stogo.", "label": "lt"},
97
+ {"text": "цей є автентичним.", "label": "uk"},
98
+ {"text": "Și îndegrabă fu cu îngerul mulțime de șireaguri", "label": "ro"},
99
+ {"text": "sobra personal cualificado.", "label": "es"},
100
+ {"text": "Tako se u Njemačkoj dvije trećine liječnika služe", "label": "hr"},
101
+ {"text": "Dual-Athlon-Chipsatz noch in diesem Jahr", "label": "de"},
102
+ {"text": "यहां तक कि चीन के चीफ ऑफ जनरल स्टाफ भी भारत का", "label": "hi"},
103
+ {"text": "Li forestier du mont avale", "label": "fro"},
104
+ {"text": "Netzwerken für Privatanwender zu bewundern .", "label": "de"},
105
+ {"text": "만해는 승적을 가진 중이 결혼할 수 없다는 불교의 계율을 시대에 맞지 않는 것으로 보았다", "label": "ko"},
106
+ {"text": "balance and weight distribution but not really for", "label": "en"},
107
+ {"text": "og så e # tente vi opp den om morgonen å sfyrte", "label": "nn"},
108
+ {"text": "변화는 의심의 여지가 없는 것이지만 반면에 진화는 논쟁의 씨앗이다 .", "label": "ko"},
109
+ {"text": "puteare fac aceastea.", "label": "ro"},
110
+ {"text": "Waitt seine Führungsmannschaft nicht dem", "label": "de"},
111
+ {"text": "juhtimisega, tulid sealt.", "label": "et"},
112
+ {"text": "Veränderungen .", "label": "de"},
113
+ {"text": "banda en el Bayer Leverkusen de la Bundesliga de", "label": "es"},
114
+ {"text": "В туже зиму посла всеволодъ сн҃а своѥго ст҃ослава", "label": "orv"},
115
+ {"text": "пославъ приведе я мастеры ѿ грекъ", "label": "orv"},
116
+ {"text": "En un nou escenari difícil d'imaginar fa poques", "label": "ca"},
117
+ {"text": "καὶ γὰρ τινὲς αὐτοὺς εὐεργεσίαι εἶχον ἐκ Κροίσου", "label": "grc"},
118
+ {"text": "직접적인 관련이 있다 .", "label": "ko"},
119
+ {"text": "가까운 듯하면서도 멀다 .", "label": "ko"},
120
+ {"text": "Er bietet ein ähnliches Leistungsniveau und", "label": "de"},
121
+ {"text": "民都洛水牛是獨居的,並不會以群族聚居。", "label": "zh-hant"},
122
+ {"text": "την τρομοκρατία.", "label": "el"},
123
+ {"text": "hurbiltzen diren neurrian.", "label": "eu"},
124
+ {"text": "Ah dimenticavo, ma tutta sta caciara per fare un", "label": "it"},
125
+ {"text": "На первом этапе (-) прошла так называемая", "label": "ru"},
126
+ {"text": "of games are on the market.", "label": "en"},
127
+ {"text": "находится Мост дружбы, соединяющий узбекский и", "label": "ru"},
128
+ {"text": "lessié je voldroie que li saint fussent aporté", "label": "fro"},
129
+ {"text": "Дошла очередь и до Гималаев.", "label": "ru"},
130
+ {"text": "vzácným suknem táhly pouští, si jednou chtěl do", "label": "cs"},
131
+ {"text": "E no terceiro tipo sitúa a familias (%), nos que a", "label": "gl"},
132
+ {"text": "وجابت دوريات امريكية وعراقية شوارع المدينة، فيما", "label": "ar"},
133
+ {"text": "Jeg har bodd her i år .", "label": "no"},
134
+ {"text": "Pohrozil, že odbory zostří postoj, pokud se", "label": "cs"},
135
+ {"text": "tinham conseguido.", "label": "pt"},
136
+ {"text": "Nicht-Erkrankten einen Anfangsverdacht für einen", "label": "de"},
137
+ {"text": "permanece em aberto.", "label": "pt"},
138
+ {"text": "questi possono promettere rendimenti fino a un", "label": "it"},
139
+ {"text": "Tema juurutatud kahevedurisüsteemita oleksid", "label": "et"},
140
+ {"text": "Поведение внешне простой игрушки оказалось", "label": "ru"},
141
+ {"text": "Bundesländern war vom Börsenverein des Deutschen", "label": "de"},
142
+ {"text": "acció, 'a mesura que avanci l'estiu, amb l'augment", "label": "ca"},
143
+ {"text": "Dove trovare queste risorse? Jay Naidoo, ministro", "label": "it"},
144
+ {"text": "essas gordurinhas.", "label": "pt"},
145
+ {"text": "Im zweiten Schritt sollen im übernächsten Jahr", "label": "de"},
146
+ {"text": "allveelaeva pole enam vaja, kuna külm sõda on läbi", "label": "et"},
147
+ {"text": "उपद्रवी दुकानों को लूटने के साथ ही उनमें आग लगा", "label": "hi"},
148
+ {"text": "@user nella sfortuna sei fortunata ..", "label": "it"},
149
+ {"text": "математических школ в виде грозовых туч.", "label": "ru"},
150
+ {"text": "No cambiaremos nunca nuestra forma de jugar por un", "label": "es"},
151
+ {"text": "dla tej klasy ani wymogów minimalnych, z wyjątkiem", "label": "pl"},
152
+ {"text": "en todo el mundo, mientras que en España consiguió", "label": "es"},
153
+ {"text": "политики считать надежное обеспечение военной", "label": "ru"},
154
+ {"text": "gogoratzen du, genio alemana delakoaren", "label": "eu"},
155
+ {"text": "Бычий глаз.", "label": "ru"},
156
+ {"text": "Opeření se v pravidelných obdobích obnovuje", "label": "cs"},
157
+ {"text": "I no és només la seva, es tracta d'una resposta", "label": "ca"},
158
+ {"text": "오경을 가르쳤다 .", "label": "ko"},
159
+ {"text": "Nach der so genannten Start-up-Periode vergibt die", "label": "de"},
160
+ {"text": "Saulista huomasi jo lapsena , että hänellä on", "label": "fi"},
161
+ {"text": "Министерство культуры сочло нецелесообразным, и", "label": "ru"},
162
+ {"text": "znepřátelené tábory v Tádžikistánu předseda", "label": "cs"},
163
+ {"text": "καὶ ἦν ὁ λαὸς προσδοκῶν τὸν Ζαχαρίαν καὶ ἐθαύμαζον", "label": "grc"},
164
+ {"text": "Вечером, в продукте, этот же человек говорил о", "label": "ru"},
165
+ {"text": "lugar á formación de xuizos máis complexos.", "label": "gl"},
166
+ {"text": "cheaper, in the end?", "label": "en"},
167
+ {"text": "الوزارة في شأن صفقات بيع الشركات العامة التي تم", "label": "ar"},
168
+ {"text": "tärkeintä elämässäni .", "label": "fi"},
169
+ {"text": "Виконання Мінських угод було заблоковано Росією та", "label": "uk"},
170
+ {"text": "Aby szybko rozpoznać żołnierzy desantu, należy", "label": "pl"},
171
+ {"text": "Bankengeschäfte liegen vorn , sagte Strothmann .", "label": "de"},
172
+ {"text": "продолжение работы.", "label": "ru"},
173
+ {"text": "Metro AG plant Online-Offensive", "label": "de"},
174
+ {"text": "nu vor veni, și să vor osîndi, aceia nu pot porni", "label": "ro"},
175
+ {"text": "Ich denke , es geht in Wirklichkeit darum , NT bei", "label": "de"},
176
+ {"text": "de turism care încasează contravaloarea", "label": "ro"},
177
+ {"text": "Aurkaria itotzea da helburua, baloia lapurtu eta", "label": "eu"},
178
+ {"text": "com a centre de formació en Tecnologies de la", "label": "ca"},
179
+ {"text": "oportet igitur quod omne agens in agendo intendat", "label": "la"},
180
+ {"text": "Jerzego Andrzejewskiego, oparty na chińskich", "label": "pl"},
181
+ {"text": "sau một vài câu chuyện xã giao không dính dáng tới", "label": "vi"},
182
+ {"text": "что экономическому прорыву жесткий авторитарный", "label": "ru"},
183
+ {"text": "DRAM-Preisen scheinen DSPs ein", "label": "de"},
184
+ {"text": "Jos dajan nubbái: Mana!", "label": "sme"},
185
+ {"text": "toți carii ascultară de el să răsipiră.", "label": "ro"},
186
+ {"text": "odpowiedzialności, które w systemie własności", "label": "pl"},
187
+ {"text": "Dvomesečno potovanje do Mollenda v Peruju je", "label": "sl"},
188
+ {"text": "d'entre les agències internacionals.", "label": "ca"},
189
+ {"text": "Fahrzeugzugangssysteme gefertigt und an viele", "label": "de"},
190
+ {"text": "in an answer to the sharers' petition in Cuthbert", "label": "en"},
191
+ {"text": "Europa-Domain per Verordnung zu regeln .", "label": "de"},
192
+ {"text": "#Balotelli. Su ebay prezzi stracciati per Silvio", "label": "it"},
193
+ {"text": "Ne na košickém trávníku, ale už včera v letadle se", "label": "cs"},
194
+ {"text": "zaměstnanosti a investičních strategií.", "label": "cs"},
195
+ {"text": "Tatínku, udělej den", "label": "cs"},
196
+ {"text": "frecuencia con Mary.", "label": "es"},
197
+ {"text": "Свеаборге.", "label": "ru"},
198
+ {"text": "opatření slovenské strany o certifikaci nejvíce", "label": "cs"},
199
+ {"text": "En todas me decían: 'Espera que hagamos un estudio", "label": "es"},
200
+ {"text": "Die Demonstration sollte nach Darstellung der", "label": "de"},
201
+ {"text": "Ci vorrà un assoluto rigore se dietro i disavanzi", "label": "it"},
202
+ {"text": "Tatínku, víš, že Honzovi odešla maminka?", "label": "cs"},
203
+ {"text": "Die Anzahl der Rechner wuchs um % auf und die", "label": "de"},
204
+ {"text": "האמריקאית על אדמת סעודיה עלולה לסבך את ישראל, אין", "label": "he"},
205
+ {"text": "Volán Egyesülés, a Közlekedési Főfelügyelet is.", "label": "hu"},
206
+ {"text": "Schejbala, který stejnou hru s velkým úspěchem", "label": "cs"},
207
+ {"text": "depends on the data type of the field.", "label": "en"},
208
+ {"text": "Umsatzwarnung zu Wochenbeginn zeitweise auf ein", "label": "de"},
209
+ {"text": "niin heti nukun .", "label": "fi"},
210
+ {"text": "Mobilfunkunternehmen gegen die Anwendung der so", "label": "de"},
211
+ {"text": "sapessi le intenzioni del governo Monti e dell'UE", "label": "it"},
212
+ {"text": "Di chi è figlia Martine Aubry?", "label": "it"},
213
+ {"text": "avec le reste du monde.", "label": "fr"},
214
+ {"text": "Այդ մաքոքը ինքնին նոր չէ, աշխարհը արդեն մի քանի", "label": "hy"},
215
+ {"text": "și în cazul destrămării cenaclului.", "label": "ro"},
216
+ {"text": "befriedigen kann , und ohne die auftretenden", "label": "de"},
217
+ {"text": "Κύκνον τ̓ ἐξεναρεῖν καὶ ἀπὸ κλυτὰ τεύχεα δῦσαι.", "label": "grc"},
218
+ {"text": "færdiguddannede.", "label": "da"},
219
+ {"text": "Schmidt war Sohn eines Rittergutsbesitzers.", "label": "de"},
220
+ {"text": "и вдаша попадь ѡпрати", "label": "orv"},
221
+ {"text": "cine nu știe învățătură”.", "label": "ro"},
222
+ {"text": "détacha et cette dernière tenta de tuer le jeune", "label": "fr"},
223
+ {"text": "Der har saka også ei lengre forhistorie.", "label": "nn"},
224
+ {"text": "Pieprz roztłuc w moździerzu, dodać do pasty,", "label": "pl"},
225
+ {"text": "Лежа за гребнем оврага, как за бруствером, Ушаков", "label": "ru"},
226
+ {"text": "gesucht habe, vielen Dank nochmals!", "label": "de"},
227
+ {"text": "инструментальных сталей, повышения", "label": "ru"},
228
+ {"text": "im Halbfinale Patrick Smith und im Finale dann", "label": "de"},
229
+ {"text": "البنوك التريث في منح تسهيلات جديدة لمنتجي حديد", "label": "ar"},
230
+ {"text": "una bolsa ventral, la cual se encuentra debajo de", "label": "es"},
231
+ {"text": "za SETimes.", "label": "sr"},
232
+ {"text": "de Irak, a un piloto italiano que había violado el", "label": "es"},
233
+ {"text": "Er könne sich nicht erklären , wie die Zeitung auf", "label": "de"},
234
+ {"text": "Прохорова.", "label": "ru"},
235
+ {"text": "la democrazia perde sulla tecnocrazia? #", "label": "it"},
236
+ {"text": "entre ambas instituciones, confirmó al medio que", "label": "es"},
237
+ {"text": "Austlandet, vart det funne om lag førti", "label": "nn"},
238
+ {"text": "уровнями власти.", "label": "ru"},
239
+ {"text": "Dá tedy primáři úplatek, a často ne malý.", "label": "cs"},
240
+ {"text": "brillantes del acto, al llevar a cabo en el", "label": "es"},
241
+ {"text": "eee druga zadeva je majhen priročen gre kamorkoli", "label": "sl"},
242
+ {"text": "Das ATX-Board paßt in herkömmliche PC-ATX-Gehäuse", "label": "de"},
243
+ {"text": "Za vodné bylo v prvním pololetí zaplaceno v ČR", "label": "cs"},
244
+ {"text": "Даже на полсантиметра.", "label": "ru"},
245
+ {"text": "com la del primer tinent d'alcalde en funcions,", "label": "ca"},
246
+ {"text": "кількох оповідань в цілості — щось на зразок того", "label": "uk"},
247
+ {"text": "sed ad divitias congregandas, vel superfluum", "label": "la"},
248
+ {"text": "Norma Talmadge, spela mot Valentino i en version", "label": "sv"},
249
+ {"text": "Dlatego chciał się jej oświadczyć w niezwykłym", "label": "pl"},
250
+ {"text": "будут выступать на одинаковых снарядах.", "label": "ru"},
251
+ {"text": "Orang-orang terbunuh di sana.", "label": "id"},
252
+ {"text": "لدى رايت شقيق اسمه أوسكار, وهو يعمل كرسام للكتب", "label": "ar"},
253
+ {"text": "Wirklichkeit verlagerten und kaum noch", "label": "de"},
254
+ {"text": "как перемешивают костяшки перед игрой в домино, и", "label": "ru"},
255
+ {"text": "В средине дня, когда солнце светило в нашу", "label": "ru"},
256
+ {"text": "d'aventure aux rôles de jeune romantique avec une", "label": "fr"},
257
+ {"text": "My teď hledáme organizace, jež by s námi chtěly", "label": "cs"},
258
+ {"text": "Urteilsfähigkeit einbüßen , wenn ich eigene", "label": "de"},
259
+ {"text": "sua appartenenza anche a voci diverse da quella in", "label": "it"},
260
+ {"text": "Aufträge dieses Jahr verdoppeln werden .", "label": "de"},
261
+ {"text": "M.E.: Miała szanse mnie odnaleźć, gdyby naprawdę", "label": "pl"},
262
+ {"text": "secundum contactum virtutis, cum careat dimensiva", "label": "la"},
263
+ {"text": "ezinbestekoa dela esan zuen.", "label": "eu"},
264
+ {"text": "Anek hurbiltzeko eskatzen zion besaulkitik, eta", "label": "eu"},
265
+ {"text": "perfectius alio videat, quamvis uterque videat", "label": "la"},
266
+ {"text": "Die Strecke war anspruchsvoll und führte unter", "label": "de"},
267
+ {"text": "саморазоблачительным уроком, западные СМИ не", "label": "ru"},
268
+ {"text": "han representerer radikal islamisme .", "label": "no"},
269
+ {"text": "Què s'hi respira pel que fa a la reforma del", "label": "ca"},
270
+ {"text": "previsto para também ser desconstruido.", "label": "pt"},
271
+ {"text": "Ὠκεανοῦ βαθυκόλποις ἄνθεά τ̓ αἰνυμένην, ῥόδα καὶ", "label": "grc"},
272
+ {"text": "para jovens de a anos nos Cieps.", "label": "pt"},
273
+ {"text": "संघर्ष को अंजाम तक पहुंचाने का ऐलान किया है ।", "label": "hi"},
274
+ {"text": "objeví i u nás.", "label": "cs"},
275
+ {"text": "kvitteringer.", "label": "da"},
276
+ {"text": "This report is no exception.", "label": "en"},
277
+ {"text": "Разлепват доносниците до избирателните списъци", "label": "bg"},
278
+ {"text": "anderem ihre Bewegungsfreiheit in den USA", "label": "de"},
279
+ {"text": "Ñu tegoon ca kaw gor ña ay njotti bopp yu kenn", "label": "wo"},
280
+ {"text": "Struktur kann beispielsweise der Schwerpunkt mehr", "label": "de"},
281
+ {"text": "% la velocidad permitida, la sanción es muy grave.", "label": "es"},
282
+ {"text": "Teles-Einstieg in ADSL-Markt", "label": "de"},
283
+ {"text": "ettekäändeks liiga suure osamaksu.", "label": "et"},
284
+ {"text": "als Indiz für die geänderte Marktpolitik des", "label": "de"},
285
+ {"text": "quod quidem aperte consequitur ponentes", "label": "la"},
286
+ {"text": "de negociación para el próximo de junio.", "label": "es"},
287
+ {"text": "Tyto důmyslné dekorace doznaly v poslední době", "label": "cs"},
288
+ {"text": "največjega uspeha doslej.", "label": "sl"},
289
+ {"text": "Paul Allen je jedan od suosnivača Interval", "label": "hr"},
290
+ {"text": "Federal (Seac / DF) eo Sindicato das Empresas de", "label": "pt"},
291
+ {"text": "Quartal mit . Mark gegenüber dem gleichen Quartal", "label": "de"},
292
+ {"text": "otros clubes y del Barça B saldrán varios", "label": "es"},
293
+ {"text": "Jaskula (Pol.) -", "label": "cs"},
294
+ {"text": "umožnily říci, že je možné přejít k mnohem", "label": "cs"},
295
+ {"text": "اعلن الجنرال تومي فرانكس قائد القوات الامريكية", "label": "ar"},
296
+ {"text": "Telekom-Chef Ron Sommer und der Vorstandssprecher", "label": "de"},
297
+ {"text": "My, jako průmyslový a finanční holding, můžeme", "label": "cs"},
298
+ {"text": "voorlichting onder andere betrekking kan hebben:", "label": "nl"},
299
+ {"text": "Hinrichtung geistig Behinderter applaudiert oder", "label": "de"},
300
+ {"text": "wie beispielsweise Anzahl erzielte Klicks ,", "label": "de"},
301
+ {"text": "Intel-PC-SDRAM-Spezifikation in der Version . (", "label": "de"},
302
+ {"text": "plângere în termen de zile de la comunicarea", "label": "ro"},
303
+ {"text": "и Испания ще изгубят втория си комисар в ЕК.", "label": "bg"},
304
+ {"text": "इसके चलते इस आदिवासी जनजाति का क्षरण हो रहा है ।", "label": "hi"},
305
+ {"text": "aunque se mostró contrario a establecer un", "label": "es"},
306
+ {"text": "des letzten Jahres von auf Millionen Euro .", "label": "de"},
307
+ {"text": "Ankara se također poziva da u cijelosti ratificira", "label": "hr"},
308
+ {"text": "herunterlädt .", "label": "de"},
309
+ {"text": "стрессовую ситуацию для организма, каковой", "label": "ru"},
310
+ {"text": "Státního shromáždění (parlamentu).", "label": "cs"},
311
+ {"text": "diskutieren , ob und wie dieser Dienst weiterhin", "label": "de"},
312
+ {"text": "Verbindungen zu FPÖ-nahen Polizisten gepflegt und", "label": "de"},
313
+ {"text": "Pražského volebního lídra ovšem nevybírá Miloš", "label": "cs"},
314
+ {"text": "Nach einem Bericht der Washington Post bleibt das", "label": "de"},
315
+ {"text": "للوضع آنذاك، لكني في قرارة نفسي كنت سعيداً لما", "label": "ar"},
316
+ {"text": "не желаят запазването на статуквото.", "label": "bg"},
317
+ {"text": "Offenburg gewesen .", "label": "de"},
318
+ {"text": "ἐὰν ὑμῖν εἴπω οὐ μὴ πιστεύσητε", "label": "grc"},
319
+ {"text": "all'odiato compagno di squadra Prost, il quale", "label": "it"},
320
+ {"text": "historischen Gänselieselbrunnens.", "label": "de"},
321
+ {"text": "למידע מלווייני הריגול האמריקאיים העוקבים אחר", "label": "he"},
322
+ {"text": "οὐδὲν ἄρα διαφέρεις Ἀμάσιος τοῦ Ἠλείου, ὃν", "label": "grc"},
323
+ {"text": "movementos migratorios.", "label": "gl"},
324
+ {"text": "Handy und ein Spracherkennungsprogramm sämtliche", "label": "de"},
325
+ {"text": "Kümne aasta jooksul on Eestisse ohjeldamatult", "label": "et"},
326
+ {"text": "H.G. Bücknera.", "label": "pl"},
327
+ {"text": "protiv krijumčarenja, ili pak traženju ukidanja", "label": "hr"},
328
+ {"text": "Topware-Anteile mehrere Millionen Mark gefordert", "label": "de"},
329
+ {"text": "Maar de mensen die nu over Van Dijk bij FC Twente", "label": "nl"},
330
+ {"text": "poidan experimentar as percepcións do interesado,", "label": "gl"},
331
+ {"text": "Miał przecież w kieszeni nóż.", "label": "pl"},
332
+ {"text": "Avšak žádná z nich nepronikla za hranice přímé", "label": "cs"},
333
+ {"text": "esim. helpottamalla luottoja muiden", "label": "fi"},
334
+ {"text": "Podle předběžných výsledků zvítězila v", "label": "cs"},
335
+ {"text": "Nicht nur das Web-Frontend , auch die", "label": "de"},
336
+ {"text": "Regierungsinstitutionen oder Universitäten bei", "label": "de"},
337
+ {"text": "Խուլեն Լոպետեգիին, պատճառաբանելով, որ վերջինս", "label": "hy"},
338
+ {"text": "Афганистана, где в последние дни идут ожесточенные", "label": "ru"},
339
+ {"text": "лѧхове же не идоша", "label": "orv"},
340
+ {"text": "Mit Hilfe von IBMs Chip-Management-Systemen sollen", "label": "de"},
341
+ {"text": ", als Manager zu Telefonica zu wechseln .", "label": "de"},
342
+ {"text": "którym zajmuje się człowiek, zmienia go i pozwala", "label": "pl"},
343
+ {"text": "činí kyperských liber, to je asi USD.", "label": "cs"},
344
+ {"text": "Studienplätze getauscht werden .", "label": "de"},
345
+ {"text": "учёных, орнитологов признают вид.", "label": "ru"},
346
+ {"text": "acordare a concediilor prevăzute de legislațiile", "label": "ro"},
347
+ {"text": "at større innsats for fornybar, berekraftig energi", "label": "nn"},
348
+ {"text": "Politiet veit ikkje kor mange personar som deltok", "label": "nn"},
349
+ {"text": "offentligheten av unge , sinte menn som har", "label": "no"},
350
+ {"text": "însuși în jurul lapunei, care încet DISPARE în", "label": "ro"},
351
+ {"text": "O motivo da decisão é evitar uma sobrecarga ainda", "label": "pt"},
352
+ {"text": "El Apostolado de la prensa contribuye en modo", "label": "es"},
353
+ {"text": "Teltow ( Kreis Teltow-Fläming ) ist Schmitt einer", "label": "de"},
354
+ {"text": "grozījumus un iesniegt tos Apvienoto Nāciju", "label": "lv"},
355
+ {"text": "Gestalt einer deutschen Nationalmannschaft als", "label": "de"},
356
+ {"text": "D überholt zu haben , konterte am heutigen Montag", "label": "de"},
357
+ {"text": "Softwarehersteller Oracle hat im dritten Quartal", "label": "de"},
358
+ {"text": "Během nich se ekonomické podmínky mohou radikálně", "label": "cs"},
359
+ {"text": "Dziki kot w górach zeskakuje z kamienia.", "label": "pl"},
360
+ {"text": "Ačkoliv ligový nováček prohrál, opět potvrdil, že", "label": "cs"},
361
+ {"text": "des Tages , Portraits internationaler Stars sowie", "label": "de"},
362
+ {"text": "Communicator bekannt wurde .", "label": "de"},
363
+ {"text": "τῷ δ’ ἄρα καὶ αὐτῷ ἡ γυνή ἐπίτεξ ἐοῦσα πᾶσαν", "label": "grc"},
364
+ {"text": "Triadú tenia, mentre redactava 'Dies de memòria',", "label": "ca"},
365
+ {"text": "دسته‌جمعی در درخشندگی ماه سیم‌گون زمزمه ستاینده و", "label": "fa"},
366
+ {"text": "Книгу, наполненную мелочной заботой об одежде,", "label": "ru"},
367
+ {"text": "putares canem leporem persequi.", "label": "la"},
368
+ {"text": "В дальнейшем эта яркость слегка померкла, но в", "label": "ru"},
369
+ {"text": "offizielles Verfahren gegen die Telekom", "label": "de"},
370
+ {"text": "podrían haber sido habitantes de la Península", "label": "es"},
371
+ {"text": "Grundlage für dieses Verfahren sind spezielle", "label": "de"},
372
+ {"text": "Rechtsausschuß vorgelegten Entwurf der Richtlinie", "label": "de"},
373
+ {"text": "Im so genannten Portalgeschäft sei das Unternehmen", "label": "de"},
374
+ {"text": "ⲏ ⲉⲓϣⲁⲛϥⲓ ⲛⲉⲓⲇⲱⲗⲟⲛ ⲉⲧϩⲙⲡⲉⲕⲏⲓ ⲙⲏ ⲉⲓⲛⲁϣϩⲱⲡ ⲟⲛ ⲙⲡⲣⲏ", "label": "cop"},
375
+ {"text": "juego podían matar a cualquier herbívoro, pero", "label": "es"},
376
+ {"text": "Nach Angaben von Axent nutzen Unternehmen aus der", "label": "de"},
377
+ {"text": "hrdiny Havlovy Zahradní slavnosti (premiéra ) se", "label": "cs"},
378
+ {"text": "Een zin van heb ik jou daar", "label": "nl"},
379
+ {"text": "hat sein Hirn an der CeBIT-Kasse vergessen .", "label": "de"},
380
+ {"text": "καὶ τοὺς ἐκπλαγέντας οὐκ ἔχειν ἔτι ἐλεγχομένους", "label": "grc"},
381
+ {"text": "nachgewiesenen langfristigen Kosten , sowie den im", "label": "de"},
382
+ {"text": "jučer nakon četiri dana putovanja u Helsinki.", "label": "hr"},
383
+ {"text": "pašto paslaugos teikėjas gali susitarti su", "label": "lt"},
384
+ {"text": "В результате, эти золотые кадры переходят из одной", "label": "ru"},
385
+ {"text": "द फाइव-ईयर एंगेजमेंट में अभिनय किया जिसमें जैसन", "label": "hi"},
386
+ {"text": "výpis o počtu akcií.", "label": "cs"},
387
+ {"text": "Enfin, elles arrivent à un pavillon chinois", "label": "fr"},
388
+ {"text": "Tentu saja, tren yang berhubungandengan", "label": "id"},
389
+ {"text": "Arbeidarpartiet og SV har sikra seg fleirtal mot", "label": "nn"},
390
+ {"text": "eles: 'Tudo isso está errado' , disse um", "label": "pt"},
391
+ {"text": "The islands are in their own time zone, minutes", "label": "en"},
392
+ {"text": "Auswahl debütierte er am .", "label": "de"},
393
+ {"text": "Bu komisyonlar, arazilerini satın almak için", "label": "tr"},
394
+ {"text": "Geschütze gegen Redmond aufgefahren .", "label": "de"},
395
+ {"text": "Time scything the hours, but at the top, over the", "label": "en"},
396
+ {"text": "Di musim semi , berharap mengadaptasi Tintin untuk", "label": "id"},
397
+ {"text": "крупнейшей геополитической катастрофой XX века.", "label": "ru"},
398
+ {"text": "Rajojen avaaminen ei suju ongelmitta .", "label": "fi"},
399
+ {"text": "непроницаемым, как для СССР.", "label": "ru"},
400
+ {"text": "Ma non mancano le polemiche.", "label": "it"},
401
+ {"text": "Internet als Ort politischer Diskussion und auch", "label": "de"},
402
+ {"text": "incomplets.", "label": "ca"},
403
+ {"text": "Su padre luchó al lado de Luis Moya, primer Jefe", "label": "es"},
404
+ {"text": "informazione.", "label": "it"},
405
+ {"text": "Primacom bietet für Telekom-Kabelnetz", "label": "de"},
406
+ {"text": "Oświadczenie prezydencji w imieniu Unii", "label": "pl"},
407
+ {"text": "foran rattet i familiens gamle Baleno hvis døra på", "label": "no"},
408
+ {"text": "[speaker:laughter]", "label": "sl"},
409
+ {"text": "Dog med langt mindre utstyr med seg.", "label": "nn"},
410
+ {"text": "dass es nicht schon mit der anfänglichen", "label": "de"},
411
+ {"text": "इस पर दोनों पक्षों में नोकझोंक शुरू हो गई ।", "label": "hi"},
412
+ {"text": "کے ترجمان منیش تیواری اور دگ وجئے سنگھ نے بھی یہ", "label": "ur"},
413
+ {"text": "dell'Assemblea Costituente che posseggono i", "label": "it"},
414
+ {"text": "и аште вьси съблазнѧтъ сѧ нъ не азъ", "label": "cu"},
415
+ {"text": "In Irvine hat auch das Logistikunternehmen Atlas", "label": "de"},
416
+ {"text": "законодательных норм, принимаемых существующей", "label": "ru"},
417
+ {"text": "Κροίσῳ προτείνων τὰς χεῖρας ἐπικατασφάξαι μιν", "label": "grc"},
418
+ {"text": "МИНУСЫ: ИНФЛЯЦИЯ И КРИЗИС В ЖИВОТНОВОДСТВЕ.", "label": "ru"},
419
+ {"text": "unterschiedlicher Meinung .", "label": "de"},
420
+ {"text": "Jospa joku ystävällinen sielu auttaisi kassieni", "label": "fi"},
421
+ {"text": "Añadió que, en el futuro se harán otros", "label": "es"},
422
+ {"text": "Sessiz tonlama hem Fince, hem de Kuzey Sami", "label": "tr"},
423
+ {"text": "nicht ihnen gehört und sie nicht alles , was sie", "label": "de"},
424
+ {"text": "Etelästä Kuivajärveen laskee Tammelan Liesjärvestä", "label": "fi"},
425
+ {"text": "ICANNs Vorsitzender Vint Cerf warb mit dem Hinweis", "label": "de"},
426
+ {"text": "Norsk politikk frå til kan dermed, i", "label": "nn"},
427
+ {"text": "Głosowało posłów.", "label": "pl"},
428
+ {"text": "Danny Jones -- smithjones@ev.net", "label": "en"},
429
+ {"text": "sebeuvědomění moderní civilizace sehrála lučavka", "label": "cs"},
430
+ {"text": "относительно спокойный сон: тому гарантия", "label": "ru"},
431
+ {"text": "A halte voiz prist li pedra a crïer", "label": "fro"},
432
+ {"text": "آن‌ها امیدوارند این واکسن به‌زودی در دسترس بیماران", "label": "fa"},
433
+ {"text": "vlastní důstojnou vousatou tváří.", "label": "cs"},
434
+ {"text": "ora aprire la strada a nuove cause e alimentare il", "label": "it"},
435
+ {"text": "Die Zahl der Vielleser nahm von auf Prozent zu ,", "label": "de"},
436
+ {"text": "Finanzvorstand von Hotline-Dienstleister InfoGenie", "label": "de"},
437
+ {"text": "entwickeln .", "label": "de"},
438
+ {"text": "incolumità pubblica.", "label": "it"},
439
+ {"text": "lehtija televisiomainonta", "label": "fi"},
440
+ {"text": "joistakin kohdista eri mieltä.", "label": "fi"},
441
+ {"text": "Hlavně anglická nezávislá scéna, Dead Can Dance,", "label": "cs"},
442
+ {"text": "pásmech od do bodů bodové stupnice.", "label": "cs"},
443
+ {"text": "Zu Beginn des Ersten Weltkrieges zählte das", "label": "de"},
444
+ {"text": "Així van sorgir, damunt els antics cementiris,", "label": "ca"},
445
+ {"text": "In manchem Gedicht der spätern Alten, wie zum", "label": "de"},
446
+ {"text": "gaweihaida jah insandida in þana fairƕu jus qiþiþ", "label": "got"},
447
+ {"text": "Beides sollte gelöscht werden!", "label": "de"},
448
+ {"text": "modifiqués la seva petició inicial de anys de", "label": "ca"},
449
+ {"text": "В день открытия симпозиума состоялась закладка", "label": "ru"},
450
+ {"text": "tõestatud.", "label": "et"},
451
+ {"text": "ἵππῳ πίπτει αὐτοῦ ταύτῃ", "label": "grc"},
452
+ {"text": "bisher nie enttäuscht!", "label": "de"},
453
+ {"text": "De bohte ollu tuollárat ja suttolaččat ja", "label": "sme"},
454
+ {"text": "Klarsignal från röstlängdsläsaren, tre tryck i", "label": "sv"},
455
+ {"text": "Tvůrcem nového termínu je Joseph Fisher.", "label": "cs"},
456
+ {"text": "Nie miałem czasu na reakcję twierdzi Norbert,", "label": "pl"},
457
+ {"text": "potentia Schöpfer.", "label": "de"},
458
+ {"text": "Un poquito caro, pero vale mucho la pena;", "label": "es"},
459
+ {"text": "οὔ τε γὰρ ἴφθιμοι Λύκιοι Δαναῶν ἐδύναντο τεῖχος", "label": "grc"},
460
+ {"text": "vajec, sladového výtažku a některých vitamínových", "label": "cs"},
461
+ {"text": "Настоящие герои, те, чьи истории потом", "label": "ru"},
462
+ {"text": "praesumptio:", "label": "la"},
463
+ {"text": "Olin justkui nende vastutusel.", "label": "et"},
464
+ {"text": "Jokainen keinahdus tuo lähemmäksi hetkeä jolloin", "label": "fi"},
465
+ {"text": "ekonomicky výhodných způsobů odvodnění těžkých,", "label": "cs"},
466
+ {"text": "Poprvé ve své historii dokázala v kvalifikaci pro", "label": "cs"},
467
+ {"text": "zpracovatelského a spotřebního průmyslu bude nutné", "label": "cs"},
468
+ {"text": "Windows CE zu integrieren .", "label": "de"},
469
+ {"text": "Armangué, a través d'un decret, ordenés l'aturada", "label": "ca"},
470
+ {"text": "to, co nás Evropany spojuje, než to, co nás od", "label": "cs"},
471
+ {"text": "ergänzt durch einen gesetzlich verankertes", "label": "de"},
472
+ {"text": "Насчитал, что с начала года всего три дня были", "label": "ru"},
473
+ {"text": "Borisovu tražeći od njega da prihvati njenu", "label": "sr"},
474
+ {"text": "la presenza di ben veleni diversi: . chili di", "label": "it"},
475
+ {"text": "καὶ τῶν ἐκλεκτῶν ἀγγέλων ἵνα ταῦτα φυλάξῃς χωρὶς", "label": "grc"},
476
+ {"text": "pretraživale obližnju bolnicu i stambene zgrade u", "label": "hr"},
477
+ {"text": "An rund Katzen habe Wolf seine Spiele getestet ,", "label": "de"},
478
+ {"text": "investigating since March.", "label": "en"},
479
+ {"text": "Tonböden (Mullböden).", "label": "de"},
480
+ {"text": "Stálý dopisovatel LN v SRN Bedřich Utitz", "label": "cs"},
481
+ {"text": "červnu předložené smlouvy.", "label": "cs"},
482
+ {"text": "πνεύματι ᾧ ἐλάλει", "label": "grc"},
483
+ {"text": ".%의 신장세를 보였다.", "label": "ko"},
484
+ {"text": "Foae verde, foi de nuc, Prin pădure, prin colnic,", "label": "ro"},
485
+ {"text": "διαπέμψας ἄλλους ἄλλῃ τοὺς μὲν ἐς Δελφοὺς ἰέναι", "label": "grc"},
486
+ {"text": "المسلمين أو أي تيار سياسي طالما عمل ذلك التيار في", "label": "ar"},
487
+ {"text": "As informações são da Dow Jones.", "label": "pt"},
488
+ {"text": "Milliarde DM ausgestattet sein .", "label": "de"},
489
+ {"text": "De utgår fortfarande från att kvinnans jämlikhet", "label": "sv"},
490
+ {"text": "Sneeuw maakte in Davos bij de voorbereiding een", "label": "nl"},
491
+ {"text": "De ahí que en este mercado puedan negociarse", "label": "es"},
492
+ {"text": "intenzívnějšímu sbírání a studiu.", "label": "cs"},
493
+ {"text": "और औसकर ४.० पैकेज का प्रयोग किया गया है ।", "label": "hi"},
494
+ {"text": "Adipati Kuningan karena Kuningan menjadi bagian", "label": "id"},
495
+ {"text": "Svako je bar jednom poželeo da mašine prosto umeju", "label": "sr"},
496
+ {"text": "Im vergangenen Jahr haben die Regierungen einen", "label": "de"},
497
+ {"text": "durat motus, aliquid fit et non est;", "label": "la"},
498
+ {"text": "Dominować będą piosenki do tekstów Edwarda", "label": "pl"},
499
+ {"text": "beantwortet .", "label": "de"},
500
+ {"text": "О гуманитариях было кому рассказывать, а вот за", "label": "ru"},
501
+ {"text": "Helsingin kaupunki riitautti vuokrasopimuksen", "label": "fi"},
502
+ {"text": "chợt tan biến.", "label": "vi"},
503
+ {"text": "avtomobil ločuje od drugih.", "label": "sl"},
504
+ {"text": "Congress has proven itself ineffective as a body.", "label": "en"},
505
+ {"text": "मैक्सिको ने इस तरह का शो इस समय आयोजित करने का", "label": "hi"},
506
+ {"text": "No minimum order amount.", "label": "en"},
507
+ {"text": "Convertassa .", "label": "fi"},
508
+ {"text": "Как это можно сделать?", "label": "ru"},
509
+ {"text": "tha mi creidsinn gu robh iad ceart cho saor shuas", "label": "gd"},
510
+ {"text": "실제 일제는 이런 만해의 논리를 묵살하고 한반도를 침략한 다음 , 이어 만주를 침략하고", "label": "ko"},
511
+ {"text": "Da un semplice richiamo all'ordine fino a grandi", "label": "it"},
512
+ {"text": "pozoruhodný nejen po umělecké stránce, jež", "label": "cs"},
513
+ {"text": "La comida y el servicio aprueban.", "label": "es"},
514
+ {"text": "again, connected not with each other but to the", "label": "en"},
515
+ {"text": "Protokol výslovně stanoví, že nikdo nemůže být", "label": "cs"},
516
+ {"text": "ఒక విషయం అడగాలని ఉంది .", "label": "te"},
517
+ {"text": "Безгранично почитая дирекцию, ловя на лету каждое", "label": "ru"},
518
+ {"text": "rovnoběžných růstových vrstev, zůstávají krychlové", "label": "cs"},
519
+ {"text": "प्रवेश और पूर्व प्रधानमंत्री लाल बहादुर शास्त्री", "label": "hi"},
520
+ {"text": "Bronzen medaille in de Europese marathon.", "label": "nl"},
521
+ {"text": "- gadu vecumā viņi to nesaprot.", "label": "lv"},
522
+ {"text": "Realizó sus estudios primarios en la Escuela Julia", "label": "es"},
523
+ {"text": "cuartos de final, su clasificación para la final a", "label": "es"},
524
+ {"text": "Sem si pro něho přiletí americký raketoplán, na", "label": "cs"},
525
+ {"text": "Way to go!", "label": "en"},
526
+ {"text": "gehört der neuen SPD-Führung unter Parteichef", "label": "de"},
527
+ {"text": "Somit simuliert der Player mit einer GByte-Platte", "label": "de"},
528
+ {"text": "Berufung auf kommissionsnahe Kreise , die bereits", "label": "de"},
529
+ {"text": "Dist Clarïen", "label": "fro"},
530
+ {"text": "Schon nach den Gerüchten , die Telekom wolle den", "label": "de"},
531
+ {"text": "Software von NetObjects ist nach Angaben des", "label": "de"},
532
+ {"text": "si enim per legem iustitia ergo Christus gratis", "label": "la"},
533
+ {"text": "ducerent in ipsam magis quam in corpus christi,", "label": "la"},
534
+ {"text": "Neustar-Melbourne-IT-Partnerschaft NeuLevel .", "label": "de"},
535
+ {"text": "forderte dagegen seine drastische Verschärfung.", "label": "de"},
536
+ {"text": "pemmican på hundrede forskellige måder.", "label": "da"},
537
+ {"text": "Lehån, själv matematiklärare, visar hur den nya", "label": "sv"},
538
+ {"text": "I highly recommend his shop.", "label": "en"},
539
+ {"text": "verità, giovani fedeli prostratevi #amen", "label": "it"},
540
+ {"text": "उत्तर प्रदेश के अध्यक्ष पद से हटाए गए विनय कटियार", "label": "hi"},
541
+ {"text": "() روزی مےں کشادگی ہوتی ہے۔", "label": "ur"},
542
+ {"text": "Prozessorgeschäft profitieren kann , stellen", "label": "de"},
543
+ {"text": "školy začalo počítat pytle s moukou a zjistilo, že", "label": "cs"},
544
+ {"text": "प्रभावशाली पर गैर सरकारी लोगों के घरों में भी", "label": "hi"},
545
+ {"text": "geschichtslos , oder eine Farce , wie sich", "label": "de"},
546
+ {"text": "Ústrednými mocnosťami v marci však spôsobilo, že", "label": "sk"},
547
+ {"text": "التسليح بدون مبرر، واستمرار الأضرار الناجمة عن فرض", "label": "ar"},
548
+ {"text": "Například Pedagogická fakulta Univerzity Karlovy", "label": "cs"},
549
+ {"text": "nostris ut eriperet nos de praesenti saeculo", "label": "la"}]
550
+
551
+ docs = [Document([], text=example["text"]) for example in examples]
552
+ gold_labels = [example["label"] for example in examples]
553
+ basic_multilingual(docs)
554
+ accuracy = sum([(doc.lang == label) for doc,label in zip(docs,gold_labels)])/len(docs)
555
+ assert accuracy >= 0.98
556
+
557
+
558
+ def test_text_cleaning(basic_multilingual, clean_multilingual):
559
+ """
560
+ Basic test of cleaning text
561
+ """
562
+ docs = ["Bonjour le monde! #thisisfrench #ilovefrance",
563
+ "Bonjour le monde! https://t.co/U0Zjp3tusD"]
564
+ docs = [Document([], text=text) for text in docs]
565
+
566
+ basic_multilingual(docs)
567
+ assert [doc.lang for doc in docs] == ["it", "it"]
568
+
569
+ assert clean_multilingual.processors["langid"]._clean_text
570
+ clean_multilingual(docs)
571
+ assert [doc.lang for doc in docs] == ["fr", "fr"]
572
+
573
+ def test_emoji_cleaning():
574
+ TEXT = ["Sh'reyan has nice antennae :thumbs_up:",
575
+ "This is🐱 a cat"]
576
+ EXPECTED = ["Sh'reyan has nice antennae",
577
+ "This is a cat"]
578
+ for text, expected in zip(TEXT, EXPECTED):
579
+ assert LangIDProcessor.clean_text(text) == expected
580
+
581
+ def test_lang_subset(basic_multilingual, enfr_multilingual, en_multilingual):
582
+ """
583
+ Basic test of restricting output to subset of languages
584
+ """
585
+ docs = ["Bonjour le monde! #thisisfrench #ilovefrance",
586
+ "Bonjour le monde! https://t.co/U0Zjp3tusD"]
587
+ docs = [Document([], text=text) for text in docs]
588
+
589
+ basic_multilingual(docs)
590
+ assert [doc.lang for doc in docs] == ["it", "it"]
591
+
592
+ assert enfr_multilingual.processors["langid"]._model.lang_subset == ["en", "fr"]
593
+ enfr_multilingual(docs)
594
+ assert [doc.lang for doc in docs] == ["fr", "fr"]
595
+
596
+ assert en_multilingual.processors["langid"]._model.lang_subset == ["en"]
597
+ en_multilingual(docs)
598
+ assert [doc.lang for doc in docs] == ["en", "en"]
599
+
600
+ def test_lang_subset_unlikely_language(en_multilingual):
601
+ """
602
+ Test that the language subset masking chooses a legal language, even if all legal languages are supa unlikely
603
+ """
604
+ sentences = ["你好" * 200]
605
+ docs = [Document([], text=text) for text in sentences]
606
+ en_multilingual(docs)
607
+ assert [doc.lang for doc in docs] == ["en"]
608
+
609
+ processor = en_multilingual.processors['langid']
610
+ model = processor._model
611
+ text_tensor = processor._text_to_tensor(sentences)
612
+ en_idx = model.tag_to_idx['en']
613
+ predictions = model(text_tensor)
614
+ assert predictions[0, en_idx] < 0, "If this test fails, then regardless of how unlikely it was, the model is predicting the input string is possibly English. Update the test by picking a different combination of languages & input"
615
+
stanza/stanza/tests/lemma/__init__.py ADDED
File without changes
stanza/stanza/tests/mwt/test_utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test the MWT resplitting of preexisting tokens without word splits
3
+ """
4
+
5
+ import pytest
6
+
7
+ import stanza
8
+ from stanza.models.mwt.utils import resplit_mwt
9
+
10
+ from stanza.tests import TEST_MODELS_DIR
11
+
12
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
13
+
14
+ @pytest.fixture(scope="module")
15
+ def pipeline():
16
+ """
17
+ A reusable pipeline with the NER module
18
+ """
19
+ return stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,mwt", package="gum")
20
+
21
+
22
+ def test_resplit_keep_tokens(pipeline):
23
+ """
24
+ Test splitting with enforced token boundaries
25
+ """
26
+ tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]]
27
+ doc = resplit_mwt(tokens, pipeline)
28
+ assert len(doc.sentences) == 2
29
+ assert len(doc.sentences[0].tokens) == 4
30
+ assert len(doc.sentences[0].tokens[1].words) == 2
31
+ assert doc.sentences[0].tokens[1].words[0].text == "ca"
32
+ assert doc.sentences[0].tokens[1].words[1].text == "n't"
33
+
34
+ assert len(doc.sentences[1].tokens) == 2
35
+ # updated GUM MWT splits "I can't" into three segments
36
+ # the way we want, "I - ca - n't"
37
+ # previously it would split "I - can - 't"
38
+ assert len(doc.sentences[1].tokens[0].words) == 3
39
+ assert doc.sentences[1].tokens[0].words[0].text == "I"
40
+ assert doc.sentences[1].tokens[0].words[1].text == "ca"
41
+ assert doc.sentences[1].tokens[0].words[2].text == "n't"
42
+
43
+
44
+ def test_resplit_no_keep_tokens(pipeline):
45
+ """
46
+ Test splitting without enforced token boundaries
47
+ """
48
+ tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]]
49
+ doc = resplit_mwt(tokens, pipeline, keep_tokens=False)
50
+ assert len(doc.sentences) == 2
51
+ assert len(doc.sentences[0].tokens) == 4
52
+ assert len(doc.sentences[0].tokens[1].words) == 2
53
+ assert doc.sentences[0].tokens[1].words[0].text == "ca"
54
+ assert doc.sentences[0].tokens[1].words[1].text == "n't"
55
+
56
+ assert len(doc.sentences[1].tokens) == 3
57
+ assert len(doc.sentences[1].tokens[1].words) == 2
58
+ assert doc.sentences[1].tokens[1].words[0].text == "ca"
59
+ assert doc.sentences[1].tokens[1].words[1].text == "n't"
stanza/stanza/tests/ner/__init__.py ADDED
File without changes
stanza/stanza/tests/ner/test_combine_ner_datasets.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pytest
4
+
5
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
6
+
7
+ from stanza.models.common.doc import Document
8
+ from stanza.tests.ner.test_ner_training import write_temp_file, EN_TRAIN_BIO, EN_DEV_BIO
9
+ from stanza.utils.datasets.ner import combine_ner_datasets
10
+
11
+
12
+ def test_combine(tmp_path):
13
+ """
14
+ Test that if we write two short datasets and combine them, we get back
15
+ one slightly longer dataset
16
+
17
+ To simplify matters, we just use the same input text with longer
18
+ amounts of text for each shard.
19
+ """
20
+ SHARDS = ("train", "dev", "test")
21
+ for s_num, shard in enumerate(SHARDS):
22
+ t1_json = tmp_path / ("en_t1.%s.json" % shard)
23
+ # eg, 1x, 2x, 3x the test data from test_ner_training
24
+ write_temp_file(t1_json, "\n\n".join([EN_TRAIN_BIO] * (s_num + 1)))
25
+
26
+ t2_json = tmp_path / ("en_t2.%s.json" % shard)
27
+ write_temp_file(t2_json, "\n\n".join([EN_DEV_BIO] * (s_num + 1)))
28
+
29
+ args = ["--output_dataset", "en_c", "en_t1", "en_t2", "--input_dir", str(tmp_path), "--output_dir", str(tmp_path)]
30
+ combine_ner_datasets.main(args)
31
+
32
+ for s_num, shard in enumerate(SHARDS):
33
+ filename = tmp_path / ("en_c.%s.json" % shard)
34
+ assert os.path.exists(filename)
35
+
36
+ with open(filename, encoding="utf-8") as fin:
37
+ doc = Document(json.load(fin))
38
+ assert len(doc.sentences) == (s_num + 1) * 3
39
+
stanza/stanza/tests/ner/test_models_ner_scorer.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple test of the scorer module for NER
3
+ """
4
+
5
+ import pytest
6
+ import stanza
7
+
8
+ from stanza.tests import *
9
+ from stanza.models.ner.scorer import score_by_token, score_by_entity
10
+
11
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
12
+
13
+ def test_ner_scorer():
14
+ pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'],
15
+ ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']]
16
+ gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'],
17
+ ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']]
18
+
19
+ token_p, token_r, token_f, confusion = score_by_token(pred_sequences, gold_sequences)
20
+ assert pytest.approx(token_p, abs=0.00001) == 0.625
21
+ assert pytest.approx(token_r, abs=0.00001) == 0.5
22
+ assert pytest.approx(token_f, abs=0.00001) == 0.55555
23
+
24
+ entity_p, entity_r, entity_f, entity_f1 = score_by_entity(pred_sequences, gold_sequences)
25
+ assert pytest.approx(entity_p, abs=0.00001) == 0.4
26
+ assert pytest.approx(entity_r, abs=0.00001) == 0.33333
27
+ assert pytest.approx(entity_f, abs=0.00001) == 0.36363
28
+ assert entity_f1 == {'LOC': 0.0, 'MISC': 1.0, 'ORG': 0.0, 'PER': 0.5}
stanza/stanza/tests/ner/test_ner_tagger.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic testing of the NER tagger.
3
+ """
4
+
5
+ import os
6
+ import pytest
7
+ import stanza
8
+
9
+ from stanza.tests import *
10
+ from stanza.models import ner_tagger
11
+ from stanza.utils.confusion import confusion_to_macro_f1
12
+ import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file
13
+ from stanza.utils.training.run_ner import build_pretrain_args
14
+
15
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
16
+
17
+ EN_DOC = "Chris Manning is a good man. He works in Stanford University."
18
+
19
+ EN_DOC_GOLD = """
20
+ <Span text=Chris Manning;type=PERSON;start_char=0;end_char=13>
21
+ <Span text=Stanford University;type=ORG;start_char=41;end_char=60>
22
+ """.strip()
23
+
24
+ EN_BIO = """
25
+ Chris B-PERSON
26
+ Manning E-PERSON
27
+ is O
28
+ a O
29
+ good O
30
+ man O
31
+ . O
32
+
33
+ He O
34
+ works O
35
+ in O
36
+ Stanford B-ORG
37
+ University E-ORG
38
+ . O
39
+ """.strip().replace(" ", "\t")
40
+
41
+ EN_EXPECTED_OUTPUT = """
42
+ Chris B-PERSON B-PERSON
43
+ Manning E-PERSON E-PERSON
44
+ is O O
45
+ a O O
46
+ good O O
47
+ man O O
48
+ . O O
49
+
50
+ He O O
51
+ works O O
52
+ in O O
53
+ Stanford B-ORG B-ORG
54
+ University E-ORG E-ORG
55
+ . O O
56
+ """.strip().replace(" ", "\t")
57
+
58
+
59
+ def test_ner():
60
+ nlp = stanza.Pipeline(**{'processors': 'tokenize,ner', 'dir': TEST_MODELS_DIR, 'lang': 'en', 'logging_level': 'error'})
61
+ doc = nlp(EN_DOC)
62
+ assert EN_DOC_GOLD == '\n'.join([ent.pretty_print() for ent in doc.ents])
63
+
64
+ def test_evaluate(tmp_path):
65
+ """
66
+ This simple example should have a 1.0 f1 for the ontonote model
67
+ """
68
+ package = "ontonotes-ww-multi_charlm"
69
+ model_path = os.path.join(TEST_MODELS_DIR, "en", "ner", package + ".pt")
70
+ assert os.path.exists(model_path), "The {} model should be downloaded as part of setup.py".format(package)
71
+
72
+ os.makedirs(tmp_path, exist_ok=True)
73
+
74
+ test_bio_filename = tmp_path / "test.bio"
75
+ test_json_filename = tmp_path / "test.json"
76
+ test_output_filename = tmp_path / "output.bio"
77
+ with open(test_bio_filename, "w", encoding="utf-8") as fout:
78
+ fout.write(EN_BIO)
79
+
80
+ prepare_ner_file.process_dataset(test_bio_filename, test_json_filename)
81
+
82
+ args = ["--save_name", str(model_path),
83
+ "--eval_file", str(test_json_filename),
84
+ "--eval_output_file", str(test_output_filename),
85
+ "--mode", "predict"]
86
+ args = args + build_pretrain_args("en", package, model_dir=TEST_MODELS_DIR)
87
+ args = ner_tagger.parse_args(args=args)
88
+ confusion = ner_tagger.evaluate(args)
89
+ assert confusion_to_macro_f1(confusion) == pytest.approx(1.0)
90
+
91
+ with open(test_output_filename, encoding="utf-8") as fin:
92
+ results = fin.read().strip()
93
+
94
+ assert results == EN_EXPECTED_OUTPUT
stanza/stanza/tests/ner/test_ner_trainer.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from stanza.tests import *
4
+
5
+ from stanza.models.ner import trainer
6
+
7
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
8
+
9
+ def test_fix_singleton_tags():
10
+ TESTS = [
11
+ (["O"], ["O"]),
12
+ (["B-PER"], ["S-PER"]),
13
+ (["B-PER", "I-PER"], ["B-PER", "E-PER"]),
14
+ (["B-PER", "O", "B-PER"], ["S-PER", "O", "S-PER"]),
15
+ (["B-PER", "B-PER", "I-PER"], ["S-PER", "B-PER", "E-PER"]),
16
+ (["B-PER", "I-PER", "O", "B-PER"], ["B-PER", "E-PER", "O", "S-PER"]),
17
+ (["B-PER", "B-PER", "I-PER", "B-PER"], ["S-PER", "B-PER", "E-PER", "S-PER"]),
18
+ (["B-PER", "I-ORG", "O", "B-PER"], ["S-PER", "S-ORG", "O", "S-PER"]),
19
+ (["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
20
+ (["S-PER", "B-PER", "E-PER"], ["S-PER", "B-PER", "E-PER"]),
21
+ (["E-PER"], ["S-PER"]),
22
+ (["E-PER", "O", "E-PER"], ["S-PER", "O", "S-PER"]),
23
+ (["B-PER", "E-ORG", "O", "B-PER"], ["S-PER", "S-ORG", "O", "S-PER"]),
24
+ (["I-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
25
+ (["B-PER", "I-PER", "I-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
26
+ (["B-PER", "I-PER", "E-PER", "O", "I-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
27
+ (["B-PER", "I-PER", "E-PER", "O", "B-PER", "I-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
28
+ (["I-PER", "I-PER", "I-PER", "O", "I-PER", "I-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
29
+ ]
30
+
31
+ for unfixed, expected in TESTS:
32
+ assert trainer.fix_singleton_tags(unfixed) == expected, "Error converting {} to {}".format(unfixed, expected)
stanza/stanza/tests/ner/test_pay_amt_annotators.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple test for tracking AMT annotator work
3
+ """
4
+
5
+ import os
6
+ import zipfile
7
+
8
+ import pytest
9
+
10
+ from stanza.tests import TEST_WORKING_DIR
11
+ from stanza.utils.ner import paying_annotators
12
+
13
+ DATA_SOURCE = os.path.join(TEST_WORKING_DIR, "in", "aws_annotations.zip")
14
+
15
+ @pytest.fixture(scope="module")
16
+ def completed_amt_job_metadata(tmp_path_factory):
17
+ assert os.path.exists(DATA_SOURCE)
18
+ unzip_path = tmp_path_factory.mktemp("amt_test")
19
+ input_path = unzip_path / "ner" / "aws_labeling_copy"
20
+ with zipfile.ZipFile(DATA_SOURCE, 'r') as zin:
21
+ zin.extractall(unzip_path)
22
+ return input_path
23
+
24
+ def test_amt_annotator_track(completed_amt_job_metadata):
25
+ workers = {
26
+ "7efc17ac-3397-4472-afe5-89184ad145d0": "Worker1",
27
+ "afce8c28-969c-4e73-a20f-622ef122f585": "Worker2",
28
+ "91f6236e-63c6-4a84-8fd6-1efbab6dedab": "Worker3",
29
+ "6f202e93-e6b6-4e1d-8f07-0484b9a9093a": "Worker4",
30
+ "2b674d33-f656-44b0-8f90-d70a1ab71ec2": "Worker5"
31
+ } # map AMT annotator subs to relevant identifier
32
+
33
+ tracked_work = paying_annotators.track_tasks(completed_amt_job_metadata, workers)
34
+ assert tracked_work == {'Worker4': 20, 'Worker5': 20, 'Worker2': 3, 'Worker3': 16}
35
+
36
+
37
+ def test_amt_annotator_track_no_map(completed_amt_job_metadata):
38
+ sub_to_count = paying_annotators.track_tasks(completed_amt_job_metadata)
39
+ assert sub_to_count == {'6f202e93-e6b6-4e1d-8f07-0484b9a9093a': 20, '2b674d33-f656-44b0-8f90-d70a1ab71ec2': 20,
40
+ 'afce8c28-969c-4e73-a20f-622ef122f585': 3, '91f6236e-63c6-4a84-8fd6-1efbab6dedab': 16}
41
+
42
+
43
+ def main():
44
+ test_amt_annotator_track()
45
+ test_amt_annotator_track_no_map()
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
50
+ print("TESTS COMPLETED!")
stanza/stanza/tests/ner/test_split_wikiner.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Runs a few tests on the split_wikiner file
3
+ """
4
+
5
+ import os
6
+ import tempfile
7
+
8
+ import pytest
9
+
10
+ from stanza.utils.datasets.ner import split_wikiner
11
+
12
+ from stanza.tests import *
13
+
14
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
15
+
16
+ # two sentences from the Italian dataset, split into many pieces
17
+ # to test the splitting functionality
18
+ FBK_SAMPLE = """
19
+ Il O
20
+ Papa O
21
+ si O
22
+ aggrava O
23
+
24
+ Le O
25
+ condizioni O
26
+ di O
27
+
28
+ Papa O
29
+ Giovanni PER
30
+ Paolo PER
31
+ II PER
32
+ si O
33
+
34
+ sono O
35
+ aggravate O
36
+ in O
37
+ il O
38
+ corso O
39
+
40
+ di O
41
+ la O
42
+ giornata O
43
+ di O
44
+ giovedì O
45
+ . O
46
+
47
+ Il O
48
+ portavoce O
49
+ Navarro PER
50
+ Valls PER
51
+
52
+ ha O
53
+ dichiarato O
54
+ che O
55
+
56
+ il O
57
+ Santo O
58
+ Padre O
59
+
60
+ in O
61
+ la O
62
+ giornata O
63
+
64
+ di O
65
+ oggi O
66
+ è O
67
+ stato O
68
+
69
+ colpito O
70
+ da O
71
+ una O
72
+ affezione O
73
+
74
+ altamente O
75
+ febbrile O
76
+ provocata O
77
+ da O
78
+ una O
79
+
80
+ infezione O
81
+ documentata O
82
+
83
+ di O
84
+ le O
85
+ vie O
86
+ urinarie O
87
+ . O
88
+
89
+ A O
90
+ il O
91
+ momento O
92
+
93
+ non O
94
+ è O
95
+ previsto O
96
+ il O
97
+ ricovero O
98
+
99
+ a O
100
+ il O
101
+ Policlinico LOC
102
+ Gemelli LOC
103
+ , O
104
+
105
+ come O
106
+ ha O
107
+ precisato O
108
+ il O
109
+
110
+ responsabile O
111
+ di O
112
+ il O
113
+ dipartimento O
114
+
115
+ di O
116
+ emergenza O
117
+ professor O
118
+ Rodolfo PER
119
+ Proietti PER
120
+ . O
121
+ """
122
+
123
+
124
+ def test_read_sentences():
125
+ with tempfile.TemporaryDirectory() as tempdir:
126
+ raw_filename = os.path.join(tempdir, "raw.tsv")
127
+ with open(raw_filename, "w") as fout:
128
+ fout.write(FBK_SAMPLE)
129
+
130
+ sentences = split_wikiner.read_sentences(raw_filename, "utf-8")
131
+ assert len(sentences) == 20
132
+ text = [["\t".join(word) for word in sent] for sent in sentences]
133
+ text = ["\n".join(sent) for sent in text]
134
+ text = "\n\n".join(text)
135
+ assert FBK_SAMPLE.strip() == text
136
+
137
+ def test_write_sentences():
138
+ with tempfile.TemporaryDirectory() as tempdir:
139
+ raw_filename = os.path.join(tempdir, "raw.tsv")
140
+ with open(raw_filename, "w") as fout:
141
+ fout.write(FBK_SAMPLE)
142
+
143
+ sentences = split_wikiner.read_sentences(raw_filename, "utf-8")
144
+ copy_filename = os.path.join(tempdir, "copy.tsv")
145
+ split_wikiner.write_sentences_to_file(sentences, copy_filename)
146
+
147
+ sent2 = split_wikiner.read_sentences(raw_filename, "utf-8")
148
+ assert sent2 == sentences
149
+
150
+ def run_split_wikiner(expected_train=14, expected_dev=3, expected_test=3, **kwargs):
151
+ """
152
+ Runs a test using various parameters to check the results of the splitting process
153
+ """
154
+ with tempfile.TemporaryDirectory() as indir:
155
+ raw_filename = os.path.join(indir, "raw.tsv")
156
+ with open(raw_filename, "w") as fout:
157
+ fout.write(FBK_SAMPLE)
158
+
159
+ with tempfile.TemporaryDirectory() as outdir:
160
+ split_wikiner.split_wikiner(outdir, raw_filename, **kwargs)
161
+
162
+ train_file = os.path.join(outdir, "it_fbk.train.bio")
163
+ dev_file = os.path.join(outdir, "it_fbk.dev.bio")
164
+ test_file = os.path.join(outdir, "it_fbk.test.bio")
165
+
166
+ assert os.path.exists(train_file)
167
+ assert os.path.exists(dev_file)
168
+ if kwargs["test_section"]:
169
+ assert os.path.exists(test_file)
170
+ else:
171
+ assert not os.path.exists(test_file)
172
+
173
+ train_sent = split_wikiner.read_sentences(train_file, "utf-8")
174
+ dev_sent = split_wikiner.read_sentences(dev_file, "utf-8")
175
+ assert len(train_sent) == expected_train
176
+ assert len(dev_sent) == expected_dev
177
+ if kwargs["test_section"]:
178
+ test_sent = split_wikiner.read_sentences(test_file, "utf-8")
179
+ assert len(test_sent) == expected_test
180
+ else:
181
+ test_sent = []
182
+
183
+ if kwargs["shuffle"]:
184
+ orig_sents = sorted(split_wikiner.read_sentences(raw_filename, "utf-8"))
185
+ split_sents = sorted(train_sent + dev_sent + test_sent)
186
+ else:
187
+ orig_sents = split_wikiner.read_sentences(raw_filename, "utf-8")
188
+ split_sents = train_sent + dev_sent + test_sent
189
+ assert orig_sents == split_sents
190
+
191
+ def test_no_shuffle_split():
192
+ run_split_wikiner(prefix="it_fbk", shuffle=False, test_section=True)
193
+
194
+ def test_shuffle_split():
195
+ run_split_wikiner(prefix="it_fbk", shuffle=True, test_section=True)
196
+
197
+ def test_resize():
198
+ run_split_wikiner(expected_train=12, expected_dev=2, expected_test=6, train_fraction=0.6, dev_fraction=0.1, prefix="it_fbk", shuffle=True, test_section=True)
199
+
200
+ def test_no_test_split():
201
+ run_split_wikiner(expected_train=17, train_fraction=0.85, prefix="it_fbk", shuffle=False, test_section=False)
202
+
stanza/stanza/tests/ner/test_suc3.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests the conversion code for the SUC3 NER dataset
3
+ """
4
+
5
+ import os
6
+ import tempfile
7
+ from zipfile import ZipFile
8
+
9
+ import pytest
10
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
11
+
12
+ import stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob
13
+
14
+ TEST_CONLL = """
15
+ 1 Den den PN PN UTR|SIN|DEF|SUB/OBJ _ _ _ _ O _ ac01b-030:2328
16
+ 2 Gud Gud PM PM NOM _ _ _ _ B myth ac01b-030:2329
17
+ 3 giver giva VB VB PRS|AKT _ _ _ _ O _ ac01b-030:2330
18
+ 4 ämbetet ämbete NN NN NEU|SIN|DEF|NOM _ _ _ _ O _ ac01b-030:2331
19
+ 5 får få VB VB PRS|AKT _ _ _ _ O _ ac01b-030:2332
20
+ 6 också också AB AB _ _ _ _ O _ ac01b-030:2333
21
+ 7 förståndet förstånd NN NN NEU|SIN|DEF|NOM _ _ _ _ O _ ac01b-030:2334
22
+ 8 . . MAD MAD _ _ _ _ O _ ac01b-030:2335
23
+
24
+ 1 Han han PN PN UTR|SIN|DEF|SUB _ _ _ _ O _ aa01a-017:227
25
+ 2 berättar berätta VB VB PRS|AKT _ _ _ _ O _ aa01a-017:228
26
+ 3 anekdoten anekdot NN NN UTR|SIN|DEF|NOM _ _ _ _ O _ aa01a-017:229
27
+ 4 som som HP HP -|-|- _ _ _ _ O _ aa01a-017:230
28
+ 5 FN-medlaren FN-medlare NN NN UTR|SIN|DEF|NOM _ _ _ _ O _ aa01a-017:231
29
+ 6 Brian Brian PM PM NOM _ _ _ _ B person aa01a-017:232
30
+ 7 Urquhart Urquhart PM PM NOM _ _ _ _ I person aa01a-017:233
31
+ 8 myntat mynta VB VB SUP|AKT _ _ _ _ O _ aa01a-017:234
32
+ 9 : : MAD MAD _ _ _ _ O _ aa01a-017:235
33
+ """
34
+
35
+ EXPECTED_IOB = """
36
+ Den O
37
+ Gud B-myth
38
+ giver O
39
+ ämbetet O
40
+ får O
41
+ också O
42
+ förståndet O
43
+ . O
44
+
45
+ Han O
46
+ berättar O
47
+ anekdoten O
48
+ som O
49
+ FN-medlaren O
50
+ Brian B-person
51
+ Urquhart I-person
52
+ myntat O
53
+ : O
54
+ """
55
+
56
+ def test_read_zip():
57
+ """
58
+ Test creating a fake zip file, then converting it to an .iob file
59
+ """
60
+ with tempfile.TemporaryDirectory() as tempdir:
61
+ zip_name = os.path.join(tempdir, "test.zip")
62
+ in_filename = "conll"
63
+ with ZipFile(zip_name, "w") as zout:
64
+ with zout.open(in_filename, "w") as fout:
65
+ fout.write(TEST_CONLL.encode())
66
+
67
+ out_filename = "iob"
68
+ num = suc_conll_to_iob.extract_from_zip(zip_name, in_filename, out_filename)
69
+ assert num == 2
70
+
71
+ with open(out_filename) as fin:
72
+ result = fin.read()
73
+ assert EXPECTED_IOB.strip() == result.strip()
74
+
75
+ def test_read_raw():
76
+ """
77
+ Test a direct text file conversion w/o the zip file
78
+ """
79
+ with tempfile.TemporaryDirectory() as tempdir:
80
+ in_filename = os.path.join(tempdir, "test.txt")
81
+ with open(in_filename, "w", encoding="utf-8") as fout:
82
+ fout.write(TEST_CONLL)
83
+
84
+ out_filename = "iob"
85
+ with open(in_filename, encoding="utf-8") as fin, open(out_filename, "w", encoding="utf-8") as fout:
86
+ num = suc_conll_to_iob.extract(fin, fout)
87
+ assert num == 2
88
+
89
+ with open(out_filename) as fin:
90
+ result = fin.read()
91
+ assert EXPECTED_IOB.strip() == result.strip()
stanza/stanza/tests/pipeline/test_decorators.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic tests of the depparse processor boolean flags
3
+ """
4
+ import pytest
5
+
6
+ import stanza
7
+ from stanza.models.common.doc import Document
8
+ from stanza.pipeline.core import PipelineRequirementsException
9
+ from stanza.pipeline.processor import Processor, ProcessorVariant, register_processor, register_processor_variant, ProcessorRegisterException
10
+ from stanza.utils.conll import CoNLL
11
+ from stanza.tests import *
12
+
13
+ pytestmark = pytest.mark.pipeline
14
+
15
+ # data for testing
16
+ EN_DOC = "This is a test sentence. This is another!"
17
+
18
+ EN_DOC_LOWERCASE_TOKENS = '''<Token id=1;words=[<Word id=1;text=this>]>
19
+ <Token id=2;words=[<Word id=2;text=is>]>
20
+ <Token id=3;words=[<Word id=3;text=a>]>
21
+ <Token id=4;words=[<Word id=4;text=test>]>
22
+ <Token id=5;words=[<Word id=5;text=sentence>]>
23
+ <Token id=6;words=[<Word id=6;text=.>]>
24
+
25
+ <Token id=1;words=[<Word id=1;text=this>]>
26
+ <Token id=2;words=[<Word id=2;text=is>]>
27
+ <Token id=3;words=[<Word id=3;text=another>]>
28
+ <Token id=4;words=[<Word id=4;text=!>]>'''
29
+
30
+ EN_DOC_LOL_TOKENS = '''<Token id=1;words=[<Word id=1;text=LOL>]>
31
+ <Token id=2;words=[<Word id=2;text=LOL>]>
32
+ <Token id=3;words=[<Word id=3;text=LOL>]>
33
+ <Token id=4;words=[<Word id=4;text=LOL>]>
34
+ <Token id=5;words=[<Word id=5;text=LOL>]>
35
+ <Token id=6;words=[<Word id=6;text=LOL>]>
36
+ <Token id=7;words=[<Word id=7;text=LOL>]>
37
+ <Token id=8;words=[<Word id=8;text=LOL>]>'''
38
+
39
+ EN_DOC_COOL_LEMMAS = '''<Token id=1;words=[<Word id=1;text=This;lemma=cool;upos=PRON;xpos=DT;feats=Number=Sing|PronType=Dem>]>
40
+ <Token id=2;words=[<Word id=2;text=is;lemma=cool;upos=AUX;xpos=VBZ;feats=Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin>]>
41
+ <Token id=3;words=[<Word id=3;text=a;lemma=cool;upos=DET;xpos=DT;feats=Definite=Ind|PronType=Art>]>
42
+ <Token id=4;words=[<Word id=4;text=test;lemma=cool;upos=NOUN;xpos=NN;feats=Number=Sing>]>
43
+ <Token id=5;words=[<Word id=5;text=sentence;lemma=cool;upos=NOUN;xpos=NN;feats=Number=Sing>]>
44
+ <Token id=6;words=[<Word id=6;text=.;lemma=cool;upos=PUNCT;xpos=.>]>
45
+
46
+ <Token id=1;words=[<Word id=1;text=This;lemma=cool;upos=PRON;xpos=DT;feats=Number=Sing|PronType=Dem>]>
47
+ <Token id=2;words=[<Word id=2;text=is;lemma=cool;upos=AUX;xpos=VBZ;feats=Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin>]>
48
+ <Token id=3;words=[<Word id=3;text=another;lemma=cool;upos=DET;xpos=DT;feats=PronType=Ind>]>
49
+ <Token id=4;words=[<Word id=4;text=!;lemma=cool;upos=PUNCT;xpos=.>]>'''
50
+
51
+ @register_processor("lowercase")
52
+ class LowercaseProcessor(Processor):
53
+ ''' Processor that lowercases all text '''
54
+ _requires = set(['tokenize'])
55
+ _provides = set(['lowercase'])
56
+
57
+ def __init__(self, config, pipeline, device):
58
+ pass
59
+
60
+ def _set_up_model(self, *args):
61
+ pass
62
+
63
+ def process(self, doc):
64
+ doc.text = doc.text.lower()
65
+ for sent in doc.sentences:
66
+ for tok in sent.tokens:
67
+ tok.text = tok.text.lower()
68
+
69
+ for word in sent.words:
70
+ word.text = word.text.lower()
71
+
72
+ return doc
73
+
74
+ def test_register_processor():
75
+ nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,lowercase', download_method=None)
76
+ doc = nlp(EN_DOC)
77
+ assert EN_DOC_LOWERCASE_TOKENS == '\n\n'.join(sent.tokens_string() for sent in doc.sentences)
78
+
79
+ def test_register_nonprocessor():
80
+ with pytest.raises(ProcessorRegisterException):
81
+ @register_processor("nonprocessor")
82
+ class NonProcessor:
83
+ pass
84
+
85
+ @register_processor_variant("tokenize", "lol")
86
+ class LOLTokenizer(ProcessorVariant):
87
+ ''' An alternative tokenizer that splits text by space and replaces all tokens with LOL '''
88
+
89
+ def __init__(self, lang):
90
+ pass
91
+
92
+ def process(self, text):
93
+ sentence = [{'id': (i+1, ), 'text': 'LOL'} for i, tok in enumerate(text.split())]
94
+ return Document([sentence], text)
95
+
96
+ def test_register_processor_variant():
97
+ nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors={"tokenize": "lol"}, package=None, download_method=None)
98
+ doc = nlp(EN_DOC)
99
+ assert EN_DOC_LOL_TOKENS == '\n\n'.join(sent.tokens_string() for sent in doc.sentences)
100
+
101
+ @register_processor_variant("lemma", "cool")
102
+ class CoolLemmatizer(ProcessorVariant):
103
+ ''' An alternative lemmatizer that lemmatizes every word to "cool". '''
104
+
105
+ OVERRIDE = True
106
+
107
+ def __init__(self, lang):
108
+ pass
109
+
110
+ def process(self, document):
111
+ for sentence in document.sentences:
112
+ for word in sentence.words:
113
+ word.lemma = "cool"
114
+
115
+ return document
116
+
117
+ def test_register_processor_variant_with_override():
118
+ nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors={"tokenize": "combined", "pos": "combined", "lemma": "cool"}, package=None, download_method=None)
119
+ doc = nlp(EN_DOC)
120
+ result = '\n\n'.join(sent.tokens_string() for sent in doc.sentences)
121
+ assert EN_DOC_COOL_LEMMAS == result
122
+
123
+ def test_register_nonprocessor_variant():
124
+ with pytest.raises(ProcessorRegisterException):
125
+ @register_processor_variant("tokenize", "nonvariant")
126
+ class NonVariant:
127
+ pass
stanza/stanza/tests/pipeline/test_pipeline_mwt_expander.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic testing of multi-word-token expansion
3
+ """
4
+
5
+ import pytest
6
+ import stanza
7
+
8
+ from stanza.tests import *
9
+
10
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
11
+
12
+ # mwt data for testing
13
+ FR_MWT_SENTENCE = "Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de " \
14
+ "l'Industrie et du Numérique."
15
+
16
+
17
+ FR_MWT_TOKEN_TO_WORDS_GOLD = """
18
+ token: Alors words: [<Word id=1;text=Alors>]
19
+ token: encore words: [<Word id=2;text=encore>]
20
+ token: inconnu words: [<Word id=3;text=inconnu>]
21
+ token: du words: [<Word id=4;text=de>, <Word id=5;text=le>]
22
+ token: grand words: [<Word id=6;text=grand>]
23
+ token: public words: [<Word id=7;text=public>]
24
+ token: , words: [<Word id=8;text=,>]
25
+ token: Emmanuel words: [<Word id=9;text=Emmanuel>]
26
+ token: Macron words: [<Word id=10;text=Macron>]
27
+ token: devient words: [<Word id=11;text=devient>]
28
+ token: en words: [<Word id=12;text=en>]
29
+ token: 2014 words: [<Word id=13;text=2014>]
30
+ token: ministre words: [<Word id=14;text=ministre>]
31
+ token: de words: [<Word id=15;text=de>]
32
+ token: l' words: [<Word id=16;text=l'>]
33
+ token: Économie words: [<Word id=17;text=Économie>]
34
+ token: , words: [<Word id=18;text=,>]
35
+ token: de words: [<Word id=19;text=de>]
36
+ token: l' words: [<Word id=20;text=l'>]
37
+ token: Industrie words: [<Word id=21;text=Industrie>]
38
+ token: et words: [<Word id=22;text=et>]
39
+ token: du words: [<Word id=23;text=de>, <Word id=24;text=le>]
40
+ token: Numérique words: [<Word id=25;text=Numérique>]
41
+ token: . words: [<Word id=26;text=.>]
42
+ """.strip()
43
+
44
+ FR_MWT_WORD_TO_TOKEN_GOLD = """
45
+ word: Alors token parent:1-Alors
46
+ word: encore token parent:2-encore
47
+ word: inconnu token parent:3-inconnu
48
+ word: de token parent:4-5-du
49
+ word: le token parent:4-5-du
50
+ word: grand token parent:6-grand
51
+ word: public token parent:7-public
52
+ word: , token parent:8-,
53
+ word: Emmanuel token parent:9-Emmanuel
54
+ word: Macron token parent:10-Macron
55
+ word: devient token parent:11-devient
56
+ word: en token parent:12-en
57
+ word: 2014 token parent:13-2014
58
+ word: ministre token parent:14-ministre
59
+ word: de token parent:15-de
60
+ word: l' token parent:16-l'
61
+ word: Économie token parent:17-Économie
62
+ word: , token parent:18-,
63
+ word: de token parent:19-de
64
+ word: l' token parent:20-l'
65
+ word: Industrie token parent:21-Industrie
66
+ word: et token parent:22-et
67
+ word: de token parent:23-24-du
68
+ word: le token parent:23-24-du
69
+ word: Numérique token parent:25-Numérique
70
+ word: . token parent:26-.
71
+ """.strip()
72
+
73
+
74
+ def test_mwt():
75
+ pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='fr', download_method=None)
76
+ doc = pipeline(FR_MWT_SENTENCE)
77
+ token_to_words = "\n".join(
78
+ [f'token: {token.text.ljust(9)}\t\twords: [{", ".join([word.pretty_print() for word in token.words])}]' for sent in doc.sentences for token in sent.tokens]
79
+ ).strip()
80
+ word_to_token = "\n".join(
81
+ [f'word: {word.text.ljust(9)}\t\ttoken parent:{"-".join([str(x) for x in word.parent.id])}-{word.parent.text}'
82
+ for sent in doc.sentences for word in sent.words]).strip()
83
+ assert token_to_words == FR_MWT_TOKEN_TO_WORDS_GOLD
84
+ assert word_to_token == FR_MWT_WORD_TO_TOKEN_GOLD
85
+
86
+ def test_unknown_character():
87
+ """
88
+ The MWT processor has a mechanism to temporarily add unknown characters to the vocab
89
+
90
+ Here we check that it is properly adding the characters from a test case a user sent us
91
+ """
92
+ pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
93
+ text = "Björkängshallen's"
94
+ mwt_processor = pipeline.processors["mwt"]
95
+ trainer = mwt_processor.trainer
96
+ # verify that the test case is still valid
97
+ # (perhaps an updated MWT model will have all of these characters in the future)
98
+ assert not all(x in trainer.vocab._unit2id for x in text)
99
+ doc = pipeline(text)
100
+ batch = mwt_processor.build_batch(doc)
101
+ # the vocab used in this batch should have the missing characters
102
+ assert all(x in batch.vocab._unit2id for x in text)
103
+
104
+ def test_unknown_word():
105
+ """
106
+ Test a word which wasn't in the MWT training data
107
+
108
+ The seq2seq model for MWT was randomly hallucinating, but with the
109
+ CharacterClassifier, it should be able to process unusual MWT
110
+ without hallucinations
111
+ """
112
+ pipe = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
113
+ doc = pipe("I read the newspaper's report.")
114
+ assert len(doc.sentences) == 1
115
+ assert len(doc.sentences[0].tokens) == 6
116
+ assert len(doc.sentences[0].tokens[3].words) == 2
117
+ assert doc.sentences[0].tokens[3].words[0].text == 'newspaper'
118
+
119
+ # double check that this is something unknown to the model
120
+ mwt_processor = pipe.processors["mwt"]
121
+ trainer = mwt_processor.trainer
122
+ expansion = trainer.dict_expansion("newspaper's")
123
+ assert expansion is None
stanza/stanza/tests/pos/__init__.py ADDED
File without changes
stanza/stanza/tests/pos/test_tagger.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Run the tagger for a couple iterations on some fake data
3
+
4
+ Uses a couple sentences of UD_English-EWT as training/dev data
5
+ """
6
+
7
+ import os
8
+ import pytest
9
+
10
+ import torch
11
+
12
+ import stanza
13
+ from stanza.models import tagger
14
+ from stanza.models.common import pretrain
15
+ from stanza.models.pos.trainer import Trainer
16
+ from stanza.tests import TEST_WORKING_DIR, TEST_MODELS_DIR
17
+ from stanza.utils.training.common import choose_pos_charlm, build_charlm_args
18
+
19
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
20
+
21
+ TRAIN_DATA = """
22
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
23
+ # text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
24
+ 1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
25
+ 2 : : PUNCT : _ 1 punct 1:punct _
26
+ 3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
27
+ 4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
28
+ 5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
29
+ 6 that that SCONJ IN _ 9 mark 9:mark _
30
+ 7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
31
+ 8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
32
+ 9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
33
+ 10 up up ADP RP _ 9 compound:prt 9:compound:prt _
34
+ 11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
35
+ 12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
36
+ 13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
37
+ 14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
38
+ 15 in in ADP IN _ 16 case 16:case _
39
+ 16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
40
+ 17 . . PUNCT . _ 1 punct 1:punct _
41
+
42
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
43
+ # text = Two of them were being run by 2 officials of the Ministry of the Interior!
44
+ 1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
45
+ 2 of of ADP IN _ 3 case 3:case _
46
+ 3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
47
+ 4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
48
+ 5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
49
+ 6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
50
+ 7 by by ADP IN _ 9 case 9:case _
51
+ 8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
52
+ 9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
53
+ 10 of of ADP IN _ 12 case 12:case _
54
+ 11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
55
+ 12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
56
+ 13 of of ADP IN _ 15 case 15:case _
57
+ 14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
58
+ 15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
59
+ 16 ! ! PUNCT . _ 6 punct 6:punct _
60
+
61
+ """.lstrip()
62
+
63
+ TRAIN_DATA_2 = """
64
+ # sent_id = 11
65
+ # text = It's all hers!
66
+ # previous = Which person owns this?
67
+ # comment = predeterminer modifier
68
+ 1 It it PRON PRP Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No
69
+ 2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _
70
+ 3 all all DET DT Case=Nom 4 det:predet _ _
71
+ 4 hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
72
+ 5 ! ! PUNCT . _ 4 punct _ _
73
+
74
+ """.lstrip()
75
+
76
+ TRAIN_DATA_NO_UPOS = """
77
+ # sent_id = 11
78
+ # text = It's all hers!
79
+ # previous = Which person owns this?
80
+ # comment = predeterminer modifier
81
+ 1 It it _ PRP Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No
82
+ 2 's be _ VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _
83
+ 3 all all _ DT Case=Nom 4 det:predet _ _
84
+ 4 hers hers _ PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
85
+ 5 ! ! _ . _ 4 punct _ _
86
+
87
+ """.lstrip()
88
+
89
+ TRAIN_DATA_NO_XPOS = """
90
+ # sent_id = 11
91
+ # text = It's all hers!
92
+ # previous = Which person owns this?
93
+ # comment = predeterminer modifier
94
+ 1 It it PRON _ Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No
95
+ 2 's be AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _
96
+ 3 all all DET _ Case=Nom 4 det:predet _ _
97
+ 4 hers hers PRON _ Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
98
+ 5 ! ! PUNCT _ _ 4 punct _ _
99
+
100
+ """.lstrip()
101
+
102
+ TRAIN_DATA_NO_FEATS = """
103
+ # sent_id = 11
104
+ # text = It's all hers!
105
+ # previous = Which person owns this?
106
+ # comment = predeterminer modifier
107
+ 1 It it PRON PRP _ 4 nsubj _ SpaceAfter=No
108
+ 2 's be AUX VBZ _ 4 cop _ _
109
+ 3 all all DET DT _ 4 det:predet _ _
110
+ 4 hers hers PRON PRP _ 0 root _ SpaceAfter=No
111
+ 5 ! ! PUNCT . _ 4 punct _ _
112
+
113
+ """.lstrip()
114
+
115
+ DEV_DATA = """
116
+ 1 From from ADP IN _ 3 case 3:case _
117
+ 2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _
118
+ 3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _
119
+ 4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
120
+ 5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _
121
+ 6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _
122
+ 7 : : PUNCT : _ 4 punct 4:punct _
123
+
124
+ """.lstrip()
125
+
126
+ class TestTagger:
127
+ @pytest.fixture(scope="class")
128
+ def wordvec_pretrain_file(self):
129
+ return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
130
+
131
+ @pytest.fixture(scope="class")
132
+ def charlm_args(self):
133
+ charlm = choose_pos_charlm("en", "test", "default")
134
+ charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR)
135
+ return charlm_args
136
+
137
+ def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None):
138
+ """
139
+ Run the training for a few iterations, load & return the model
140
+ """
141
+ dev_file = str(tmp_path / "dev.conllu")
142
+ pred_file = str(tmp_path / "pred.conllu")
143
+
144
+ save_name = "test_tagger.pt"
145
+ save_file = str(tmp_path / save_name)
146
+
147
+ if isinstance(train_text, str):
148
+ train_text = [train_text]
149
+ train_files = []
150
+ for idx, train_blob in enumerate(train_text):
151
+ train_file = str(tmp_path / ("train_%d.conllu" % idx))
152
+ with open(train_file, "w", encoding="utf-8") as fout:
153
+ fout.write(train_blob)
154
+ train_files.append(train_file)
155
+ train_file = ";".join(train_files)
156
+
157
+ with open(dev_file, "w", encoding="utf-8") as fout:
158
+ fout.write(dev_text)
159
+
160
+ args = ["--wordvec_pretrain_file", wordvec_pretrain_file,
161
+ "--train_file", train_file,
162
+ "--eval_file", dev_file,
163
+ "--output_file", pred_file,
164
+ "--log_step", "10",
165
+ "--eval_interval", "20",
166
+ "--max_steps", "100",
167
+ "--shorthand", "en_test",
168
+ "--save_dir", str(tmp_path),
169
+ "--save_name", save_name,
170
+ "--lang", "en"]
171
+ if not augment_nopunct:
172
+ args.extend(["--augment_nopunct", "0.0"])
173
+ if extra_args is not None:
174
+ args = args + extra_args
175
+ tagger.main(args)
176
+
177
+ assert os.path.exists(save_file)
178
+ pt = pretrain.Pretrain(wordvec_pretrain_file)
179
+ saved_model = Trainer(pretrain=pt, model_file=save_file)
180
+ return saved_model
181
+
182
+ def test_train(self, tmp_path, wordvec_pretrain_file, augment_nopunct=True):
183
+ """
184
+ Simple test of a few 'epochs' of tagger training
185
+ """
186
+ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA)
187
+
188
+ def test_vocab_cutoff(self, tmp_path, wordvec_pretrain_file):
189
+ """
190
+ Test that the vocab cutoff leaves words we expect in the vocab, but not rare words
191
+ """
192
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=["--word_cutoff", "3"])
193
+ word_vocab = trainer.vocab['word']
194
+ assert 'of' in word_vocab
195
+ assert 'officials' in TRAIN_DATA
196
+ assert 'officials' not in word_vocab
197
+
198
+ def test_multiple_files(self, tmp_path, wordvec_pretrain_file):
199
+ """
200
+ Test that multiple train files works
201
+
202
+ Checks for evidence of it working by looking for words from the second file in the vocab
203
+ """
204
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA, TRAIN_DATA_2 * 3], DEV_DATA, extra_args=["--word_cutoff", "3"])
205
+ word_vocab = trainer.vocab['word']
206
+ assert 'of' in word_vocab
207
+ assert 'officials' in TRAIN_DATA
208
+ assert 'officials' not in word_vocab
209
+
210
+ assert ' hers ' not in TRAIN_DATA
211
+ assert ' hers ' in TRAIN_DATA_2
212
+ assert 'hers' in word_vocab
213
+
214
+ def test_train_zero_augment(self, tmp_path, wordvec_pretrain_file):
215
+ """
216
+ Train with the punct augmentation set to zero
217
+
218
+ Distinguishs cases where training works w/ or w/o augmentation
219
+ """
220
+ extra_args = ['--augment_nopunct', '0.0']
221
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)
222
+
223
+ def test_train_100_augment(self, tmp_path, wordvec_pretrain_file):
224
+ """
225
+ Train with the punct augmentation set to 1.0
226
+
227
+ Distinguishs cases where training works w/ or w/o augmentation
228
+ """
229
+ extra_args = ['--augment_nopunct', '1.0']
230
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)
231
+
232
+ def test_train_charlm(self, tmp_path, wordvec_pretrain_file, charlm_args):
233
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=charlm_args)
234
+
235
+ def test_train_charlm_projection(self, tmp_path, wordvec_pretrain_file, charlm_args):
236
+ extra_args = charlm_args + ['--charlm_transform_dim', '100']
237
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)
238
+
239
+ def test_missing_column(self, tmp_path, wordvec_pretrain_file):
240
+ """
241
+ Test that using train files with missing columns works
242
+
243
+ In this test, we create three separate files, each with a single training entry.
244
+ We then train on an amalgam of those three files with a batch size of 1, saving after each batch.
245
+ This will ensure that only one item is used for each training loop and we can inspect the models which were saved.
246
+
247
+ Since each of the three files have exactly one column missing
248
+ from the training data, we expect to see the output maps for
249
+ each column stay unchanged in one iteration and change in the
250
+ other two.
251
+ """
252
+ # use SGD because some old versions of pytorch with Adam keep
253
+ # learning a value even if the loss is 0 in subsequent steps
254
+ # (perhaps it had a momentum by default?)
255
+ extra_args = ['--save_each', '--eval_interval', '1', '--max_steps', '3', '--batch_size', '1', '--optim', 'sgd']
256
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_FEATS], DEV_DATA, extra_args=extra_args)
257
+ save_each_name = tagger.save_each_file_name(trainer.args)
258
+ model_files = [save_each_name % i for i in range(4)]
259
+ assert all(os.path.exists(x) for x in model_files)
260
+ pt = pretrain.Pretrain(wordvec_pretrain_file)
261
+ saved_trainers = [Trainer(pretrain=pt, model_file=model_file) for model_file in model_files]
262
+
263
+ upos_unchanged = 0
264
+ xpos_unchanged = 0
265
+ ufeats_unchanged = 0
266
+ for t1, t2 in zip(saved_trainers[:-1], saved_trainers[1:]):
267
+ upos_unchanged += torch.allclose(t1.model.upos_clf.weight, t2.model.upos_clf.weight)
268
+ xpos_unchanged += torch.allclose(t1.model.xpos_clf.W_bilin.weight, t2.model.xpos_clf.W_bilin.weight)
269
+ ufeats_unchanged += all(torch.allclose(f1.W_bilin.weight, f2.W_bilin.weight) for f1, f2 in zip(t1.model.ufeats_clf, t2.model.ufeats_clf))
270
+ upos_norms = [torch.linalg.norm(t.model.upos_clf.weight) for t in saved_trainers]
271
+ assert upos_unchanged == 1, "Unchanged: {} {} {} {}".format(upos_unchanged, xpos_unchanged, ufeats_unchanged, upos_norms)
272
+ assert xpos_unchanged == 1, "Unchanged: %d %d %d" % (upos_unchanged, xpos_unchanged, ufeats_unchanged)
273
+ assert ufeats_unchanged == 1, "Unchanged: %d %d %d" % (upos_unchanged, xpos_unchanged, ufeats_unchanged)
274
+
275
+ def test_save_each(self, tmp_path, wordvec_pretrain_file):
276
+ extra_args = ['--save_each']
277
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)
278
+ save_each_name = tagger.save_each_file_name(trainer.args)
279
+ expected_models = sorted(set([save_each_name % i for i in range(0, trainer.args['max_steps']+1, trainer.args['eval_interval'])]))
280
+ assert len(expected_models) == 6
281
+ for model_name in expected_models:
282
+ assert os.path.exists(model_name)
283
+
284
+
285
+ def test_with_bert(self, tmp_path, wordvec_pretrain_file):
286
+ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert'])
287
+
288
+ def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):
289
+ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])
290
+
291
+ def test_with_bert_finetune(self, tmp_path, wordvec_pretrain_file):
292
+ self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_learning_rate', '0.01', '--bert_hidden_layers', '2'])
293
+
294
+ def test_bert_pipeline(self, tmp_path, wordvec_pretrain_file):
295
+ """
296
+ Test training the tagger, then using it in a pipeline
297
+
298
+ The pipeline use of the tagger also tests the longer-than-maxlen workaround for the transformer
299
+ """
300
+ trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert'])
301
+ save_name = trainer.args['save_name']
302
+ save_file = str(tmp_path / save_name)
303
+ assert os.path.exists(save_file)
304
+
305
+ pipe = stanza.Pipeline("en", processors="tokenize,pos", models_dir=TEST_MODELS_DIR, pos_model_path=save_file, pos_pretrain_path=wordvec_pretrain_file)
306
+ trainer = pipe.processors['pos'].trainer
307
+ assert trainer.args['save_name'] == save_name
308
+
309
+ # these should be one chunk only
310
+ doc = pipe("foo " * 100)
311
+ doc = pipe("foo " * 500)
312
+ # this is two chunks of bert embedding
313
+ doc = pipe("foo " * 1000)
314
+ # this is multiple chunks
315
+ doc = pipe("foo " * 2000)
stanza/stanza/tests/resources/__init__.py ADDED
File without changes
stanza/stanza/tests/resources/test_default_packages.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ import stanza
4
+
5
+ from stanza.resources import default_packages
6
+
7
+ def test_default_pretrains():
8
+ """
9
+ Test that all languages with a default treebank have a default pretrain or are specifically marked as not having a pretrain
10
+ """
11
+ for lang in default_packages.default_treebanks.keys():
12
+ assert lang in default_packages.no_pretrain_languages or lang in default_packages.default_pretrains, "Lang %s does not have a default pretrain marked!" % lang
13
+
14
+ def test_no_pretrain_languages():
15
+ """
16
+ Test that no languages have no_default_pretrain marked despite having a pretrain
17
+ """
18
+ for lang in default_packages.no_pretrain_languages:
19
+ assert lang not in default_packages.default_pretrains, "Lang %s is marked as no_pretrain but has a default pretrain!" % lang
20
+
21
+
22
+
23
+
24
+
stanza/stanza/tests/resources/test_prepare_resources.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ import stanza
4
+ import stanza.resources.prepare_resources as prepare_resources
5
+
6
+ from stanza.tests import *
7
+
8
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
9
+
10
+ def test_split_model_name():
11
+ # Basic test
12
+ lang, package, processor = prepare_resources.split_model_name('ro_nonstandard_tagger.pt')
13
+ assert lang == 'ro'
14
+ assert package == 'nonstandard'
15
+ assert processor == 'pos'
16
+
17
+ # Check that nertagger is found even though it also ends with tagger
18
+ # Check that ncbi_disease is correctly partitioned despite the extra _
19
+ lang, package, processor = prepare_resources.split_model_name('en_ncbi_disease_nertagger.pt')
20
+ assert lang == 'en'
21
+ assert package == 'ncbi_disease'
22
+ assert processor == 'ner'
23
+
24
+ # assert that processors with _ in them are also okay
25
+ lang, package, processor = prepare_resources.split_model_name('en_pubmed_forward_charlm.pt')
26
+ assert lang == 'en'
27
+ assert package == 'pubmed'
28
+ assert processor == 'forward_charlm'
29
+
30
+
stanza/stanza/tests/server/test_server_misc.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Misc tests for the server
3
+ """
4
+
5
+ import pytest
6
+ import re
7
+ import stanza.server as corenlp
8
+ from stanza.tests import compare_ignoring_whitespace
9
+
10
+ pytestmark = pytest.mark.client
11
+
12
+ EN_DOC = "Joe Smith lives in California."
13
+
14
+ EN_DOC_GOLD = """
15
+ Sentence #1 (6 tokens):
16
+ Joe Smith lives in California.
17
+
18
+ Tokens:
19
+ [Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP Lemma=Joe NamedEntityTag=PERSON]
20
+ [Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP Lemma=Smith NamedEntityTag=PERSON]
21
+ [Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ Lemma=live NamedEntityTag=O]
22
+ [Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN Lemma=in NamedEntityTag=O]
23
+ [Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP Lemma=California NamedEntityTag=STATE_OR_PROVINCE]
24
+ [Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=. Lemma=. NamedEntityTag=O]
25
+
26
+ Dependency Parse (enhanced plus plus dependencies):
27
+ root(ROOT-0, lives-3)
28
+ compound(Smith-2, Joe-1)
29
+ nsubj(lives-3, Smith-2)
30
+ case(California-5, in-4)
31
+ obl:in(lives-3, California-5)
32
+ punct(lives-3, .-6)
33
+
34
+ Extracted the following NER entity mentions:
35
+ Joe Smith PERSON PERSON:0.9972202681743931
36
+ California STATE_OR_PROVINCE LOCATION:0.9990868267559281
37
+
38
+ Extracted the following KBP triples:
39
+ 1.0 Joe Smith per:statesorprovinces_of_residence California
40
+ """
41
+
42
+
43
+ EN_DOC_POS_ONLY_GOLD = """
44
+ Sentence #1 (6 tokens):
45
+ Joe Smith lives in California.
46
+
47
+ Tokens:
48
+ [Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP]
49
+ [Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP]
50
+ [Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ]
51
+ [Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN]
52
+ [Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP]
53
+ [Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.]
54
+ """
55
+
56
+ def test_english_request():
57
+ """ Test case of starting server with Spanish defaults, and then requesting default English properties """
58
+ with corenlp.CoreNLPClient(properties='spanish', server_id='test_spanish_english_request') as client:
59
+ ann = client.annotate(EN_DOC, properties='english', output_format='text')
60
+ compare_ignoring_whitespace(ann, EN_DOC_GOLD)
61
+
62
+ # Rerun the test with a server created in English mode to verify
63
+ # that the expected output is what the defaults actually give us
64
+ with corenlp.CoreNLPClient(properties='english', server_id='test_english_request') as client:
65
+ ann = client.annotate(EN_DOC, output_format='text')
66
+ compare_ignoring_whitespace(ann, EN_DOC_GOLD)
67
+
68
+
69
+ def test_default_annotators():
70
+ """
71
+ Test case of creating a client with start_server=False and a set of annotators
72
+ The annotators should be used instead of the server's default annotators
73
+ """
74
+ with corenlp.CoreNLPClient(server_id='test_default_annotators',
75
+ output_format='text',
76
+ annotators=['tokenize','ssplit','pos','lemma','ner','depparse']) as client:
77
+ with corenlp.CoreNLPClient(start_server=False,
78
+ output_format='text',
79
+ annotators=['tokenize','ssplit','pos']) as client2:
80
+ ann = client2.annotate(EN_DOC)
81
+
82
+ expected_codepoints = ((0, 1), (2, 4), (5, 8), (9, 15), (16, 20))
83
+ expected_characters = ((0, 1), (2, 4), (5, 10), (11, 17), (18, 22))
84
+ codepoint_doc = "I am 𝒚̂𝒊 random text"
85
+
86
+ def test_codepoints():
87
+ """ Test case of asking for codepoints from the English tokenizer """
88
+ with corenlp.CoreNLPClient(annotators=['tokenize','ssplit'], # 'depparse','coref'],
89
+ properties={'tokenize.codepoint': 'true'}) as client:
90
+ ann = client.annotate(codepoint_doc)
91
+ for i, (codepoints, characters) in enumerate(zip(expected_codepoints, expected_characters)):
92
+ token = ann.sentence[0].token[i]
93
+ assert token.codepointOffsetBegin == codepoints[0]
94
+ assert token.codepointOffsetEnd == codepoints[1]
95
+ assert token.beginChar == characters[0]
96
+ assert token.endChar == characters[1]
97
+
98
+ def test_codepoint_text():
99
+ """ Test case of extracting the correct sentence text using codepoints """
100
+
101
+ text = 'Unban mox opal 🐱. This is a second sentence.'
102
+
103
+ with corenlp.CoreNLPClient(annotators=["tokenize","ssplit"],
104
+ properties={'tokenize.codepoint': 'true'}) as client:
105
+ ann = client.annotate(text)
106
+
107
+ text_start = ann.sentence[0].token[0].codepointOffsetBegin
108
+ text_end = ann.sentence[0].token[-1].codepointOffsetEnd
109
+ sentence_text = text[text_start:text_end]
110
+ assert sentence_text == 'Unban mox opal 🐱.'
111
+
112
+ text_start = ann.sentence[1].token[0].codepointOffsetBegin
113
+ text_end = ann.sentence[1].token[-1].codepointOffsetEnd
114
+ sentence_text = text[text_start:text_end]
115
+ assert sentence_text == 'This is a second sentence.'
stanza/stanza/utils/datasets/common.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import argparse
3
+ from enum import Enum
4
+ import glob
5
+ import logging
6
+ import os
7
+ import re
8
+ import subprocess
9
+ import sys
10
+
11
+ from stanza.models.common.short_name_to_treebank import canonical_treebank_name
12
+ import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data
13
+ import stanza.utils.default_paths as default_paths
14
+
15
+ logger = logging.getLogger('stanza')
16
+
17
+ # RE to see if the index of a conllu line represents an MWT
18
+ MWT_RE = re.compile("^[0-9]+[-][0-9]+")
19
+
20
+ # RE to see if the index of a conllu line represents an MWT or copy node
21
+ MWT_OR_COPY_RE = re.compile("^[0-9]+[-.][0-9]+")
22
+
23
+ # more restrictive than an actual int as we expect certain formats in the conllu files
24
+ INT_RE = re.compile("^[0-9]+$")
25
+
26
+ CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl")
27
+
28
+ class ModelType(Enum):
29
+ TOKENIZER = 1
30
+ MWT = 2
31
+ POS = 3
32
+ LEMMA = 4
33
+ DEPPARSE = 5
34
+
35
+ class UnknownDatasetError(ValueError):
36
+ def __init__(self, dataset, text):
37
+ super().__init__(text)
38
+ self.dataset = dataset
39
+
40
+ def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
41
+ """
42
+ Uses the udtools perl script to convert a conllu file to txt
43
+
44
+ TODO: switch to a python version to get rid of some perl dependence
45
+ """
46
+ for dataset in shards:
47
+ output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
48
+ output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"
49
+
50
+ if not os.path.exists(output_conllu):
51
+ # the perl script doesn't raise an error code for file not found!
52
+ raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu)
53
+ # use an external script to produce the txt files
54
+ subprocess.check_output(f"perl {CONLLU_TO_TXT_PERL} {output_conllu} > {output_txt}", shell=True)
55
+
56
+ def mwt_name(base_dir, short_name, dataset):
57
+ return os.path.join(base_dir, f"{short_name}-ud-{dataset}-mwt.json")
58
+
59
+ def tokenizer_conllu_name(base_dir, short_name, dataset):
60
+ return os.path.join(base_dir, f"{short_name}.{dataset}.gold.conllu")
61
+
62
def prepare_tokenizer_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset):
    """Build the token label (.toklabels) and MWT json files for one shard."""
    labels_filename = f"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels"
    mwt_filename = mwt_name(tokenizer_dir, short_name, dataset)
    script_args = [input_txt, input_conllu,
                   "-o", labels_filename,
                   "-m", mwt_filename]
    prepare_tokenizer_data.main(script_args)
69
+
70
def prepare_tokenizer_treebank_labels(tokenizer_dir, short_name):
    """
    Given the txt and gold.conllu files, prepare mwt and label files for train/dev/test
    """
    for shard in ("train", "dev", "test"):
        txt_file = f"{tokenizer_dir}/{short_name}.{shard}.txt"
        conllu_file = f"{tokenizer_dir}/{short_name}.{shard}.gold.conllu"
        try:
            prepare_tokenizer_dataset_labels(txt_file, conllu_file, tokenizer_dir, short_name, shard)
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            # report which shard failed, then let the error propagate
            print("Failed to convert %s to %s" % (txt_file, conllu_file))
            raise
84
+
85
def read_sentences_from_conllu(filename):
    """
    Reads a conllu file as a list of list of strings

    Finding a blank line separates the lists
    """
    sentences = []
    current = []
    with open(filename, encoding="utf-8") as infile:
        for raw_line in infile:
            stripped = raw_line.strip()
            if stripped:
                current.append(stripped)
            elif current:
                # blank line ends the current sentence (repeated blanks ignored)
                sentences.append(current)
                current = []
    # file may end without a trailing blank line
    if current:
        sentences.append(current)
    return sentences
105
+
106
def maybe_add_fake_dependencies(lines):
    """
    Possibly add fake dependencies in columns 6 and 7 (counting from 0)

    The conllu scripts need the dependencies column filled out, so in
    the case of models we build without dependency data, we need to
    add those fake dependencies in order to use the eval script etc

    lines: a list of strings with 10 tab separated columns
      comments are allowed (they will be skipped)

    returns: the same strings, but with fake dependencies added
      if columns 6 and 7 were empty
    """
    fixed = []
    root_idx = None
    first_idx = None
    for line_idx, line in enumerate(lines):
        if line.startswith("#"):
            fixed.append(line)
            continue

        columns = line.split("\t")
        # MWT ranges and copy nodes never carry a head; pass them through
        if MWT_OR_COPY_RE.match(columns[0]):
            fixed.append(line)
            continue

        token_idx = int(columns[0])
        if columns[6] != '_':
            # line already has a head; remember where the real root is
            if columns[6] == '0':
                root_idx = token_idx
            fixed.append(line)
        elif token_idx == 1:
            # note that the comments might make this not the first line
            # keep the split columns so the head can be patched at the end:
            # either this becomes the root, or it attaches to the real root
            first_idx = line_idx
            fixed.append(columns)
        else:
            # every other headless word becomes a fake dependent of token 1
            columns[6] = "1"
            columns[7] = "dep"
            fixed.append("\t".join(columns))

    if first_idx is not None:
        if root_idx is None:
            head, deprel = "0", "root"
        else:
            head, deprel = str(root_idx), "dep"
        fixed[first_idx][6] = head
        fixed[first_idx][7] = deprel
        fixed[first_idx] = "\t".join(fixed[first_idx])
    return fixed
157
+
158
def write_sentences_to_file(outfile, sents):
    """Write each sentence's lines to outfile, separated by blank lines.

    Fake dependencies are filled in where the head columns are empty."""
    for sentence in sents:
        for line in maybe_add_fake_dependencies(sentence):
            print(line, file=outfile)
        print("", file=outfile)
164
+
165
def write_sentences_to_conllu(filename, sents):
    """Write sentences (lists of conllu lines) to filename as utf-8."""
    with open(filename, mode='w', encoding="utf-8") as outfile:
        write_sentences_to_file(outfile, sents)
168
+
169
def find_treebank_dataset_file(treebank, udbase_dir, dataset, extension, fail=False, env_var="UDBASE"):
    """
    For a given treebank, dataset, extension, look for the exact filename to use.

    Sometimes the short name we use is different from the short name
    used by UD.  For example, Norwegian or Chinese.  Hence the reason
    to not hardcode it based on treebank

    set fail=True to fail if the file is not found
    """
    # the Korean "_seg" variants live in the unsegmented treebank's directory
    if treebank.startswith("UD_Korean") and treebank.endswith("_seg"):
        treebank = treebank[:-4]
    pattern = os.path.join(udbase_dir, treebank, f"*-ud-{dataset}.{extension}")
    matches = glob.glob(pattern)
    if len(matches) == 1:
        return matches[0]
    if not matches:
        if not fail:
            return None
        raise FileNotFoundError("Could not find any treebank files which matched {}\nIf you have the data elsewhere, you can change the base directory for the search by changing the {} environment variable".format(pattern, env_var))
    raise RuntimeError(f"Unexpected number of files matched '{udbase_dir}/{treebank}/*-ud-{dataset}.{extension}'")
192
+
193
def mostly_underscores(filename):
    """
    Certain treebanks have proprietary data, so the text is hidden

    For example:
      UD_Arabic-NYUAD
      UD_English-ESL
      UD_English-GUMReddit
      UD_Hindi_English-HIENCS
      UD_Japanese-BCCWJ

    Returns True when more than half of the word forms are "_" or "-".
    Raises ZeroDivisionError on a file with no content lines (as before).
    """
    underscore_count = 0
    total_count = 0
    # open with an explicit encoding and a context manager: the original
    # leaked the file handle and relied on the platform default encoding
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            if line.startswith("#"):
                continue
            total_count = total_count + 1
            pieces = line.split("\t")
            if pieces[1] in ("_", "-"):
                underscore_count = underscore_count + 1
    return underscore_count / total_count > 0.5
217
+
218
def num_words_in_file(conllu_file):
    """
    Count the number of non-blank, non-comment lines in a conllu file

    Note: MWT range lines and copy nodes are counted as well, so this is
    an approximation of the word count.
    """
    count = 0
    # explicit utf-8: UD files are utf-8 and the platform default may not be
    with open(conllu_file, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            if line.startswith("#"):
                continue
            count = count + 1
    return count
232
+
233
+
234
def get_ud_treebanks(udbase_dir, filtered=True):
    """
    Looks in udbase_dir for all the treebanks which have both train, dev, and test
    """
    paths = sorted(glob.glob(udbase_dir + "/UD_*"))
    candidates = [os.path.split(path)[1] for path in paths]
    # skip UD_English-GUMReddit as it is usually incorporated into UD_English-GUM
    candidates = [tb for tb in candidates if tb != "UD_English-GUMReddit"]
    if filtered:
        candidates = [tb for tb in candidates
                      if (find_treebank_dataset_file(tb, udbase_dir, "train", "conllu") and
                          # this will be fixed using XV
                          #find_treebank_dataset_file(tb, udbase_dir, "dev", "conllu") and
                          find_treebank_dataset_file(tb, udbase_dir, "test", "conllu"))]
        # drop treebanks whose text is hidden (proprietary data)
        candidates = [tb for tb in candidates
                      if not mostly_underscores(find_treebank_dataset_file(tb, udbase_dir, "train", "conllu"))]
        # eliminate partial treebanks (fixed with XV) for which we only have 1000 words or less
        # if the train set is small and the test set is large enough, we'll flip them
        candidates = [tb for tb in candidates
                      if (find_treebank_dataset_file(tb, udbase_dir, "dev", "conllu") or
                          num_words_in_file(find_treebank_dataset_file(tb, udbase_dir, "train", "conllu")) > 1000 or
                          num_words_in_file(find_treebank_dataset_file(tb, udbase_dir, "test", "conllu")) > 5000)]
    return candidates
257
+
258
def build_argparse():
    """Create the argument parser shared by the dataset preparation scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument('treebanks', type=str, nargs='+',
                        help='Which treebanks to run on.  Use all_ud or ud_all for all UD treebanks')
    return parser
263
+
264
+
265
def main(process_treebank, model_type, add_specific_args=None):
    """Shared CLI entry point for dataset preparation scripts.

    Parses the treebank list (expanding ud_all/all_ud), canonicalizes
    names, then calls process_treebank on each one."""
    logger.info("Datasets program called with:\n" + " ".join(sys.argv))

    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)
    args = parser.parse_args()

    paths = default_paths.get_default_paths()

    treebanks = []
    for requested in args.treebanks:
        if requested.lower() in ('ud_all', 'all_ud'):
            treebanks.extend(get_ud_treebanks(paths["UDBASE"]))
        else:
            # If this is a known UD short name, use the official name (we need it for the paths)
            treebanks.append(canonical_treebank_name(requested))

    for treebank in treebanks:
        process_treebank(treebank, model_type, paths, args)
stanza/stanza/utils/datasets/conllu_to_text.pl ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env perl
# Extracts raw text from CoNLL-U file. Uses newdoc and newpar tags when available.
# Copyright © 2017 Dan Zeman <zeman@ufal.mff.cuni.cz>
# License: GNU GPL

use utf8;
use open ':utf8';
binmode(STDIN, ':utf8');
binmode(STDOUT, ':utf8');
binmode(STDERR, ':utf8');
use Getopt::Long;

# Language code 'zh' or 'ja' will trigger Chinese-like text formatting.
my $language = 'en';
GetOptions
(
    'language=s' => \$language
);
my $chinese = $language =~ m/^(zh|ja|lzh|yue)(_|$)/;

my $text = ''; # from the text attribute of the sentence
my $ftext = ''; # from the word forms of the tokens
my $newpar = 0;
my $newdoc = 0;
my $buffer = '';
my $start = 1;
my $mwtlast; # last word id covered by the current multi-word token, if any
while(<>)
{
    if(m/^\#\s*text\s*=\s*(.+)/)
    {
        $text = $1;
    }
    elsif(m/^\#\s*newpar(\s|$)/i)
    {
        $newpar = 1;
    }
    elsif(m/^\#\s*newdoc(\s|$)/i)
    {
        $newdoc = 1;
    }
    # multi-word token line, e.g. "3-4\t..." -- use the surface form and
    # remember the covered range so the individual words are skipped
    elsif(m/^\d+-(\d+)\t/)
    {
        $mwtlast = $1;
        my @f = split(/\t/, $_);
        # Paragraphs may start in the middle of a sentence (bulleted lists, verse etc.)
        # The first token of the new paragraph has "NewPar=Yes" in the MISC column.
        # Multi-word tokens have this in the token-introducing line.
        if($f[9] =~ m/NewPar=Yes/i)
        {
            # Empty line between documents and paragraphs. (There may have been
            # a paragraph break before the first part of this sentence as well!)
            $buffer = print_new_paragraph_if_needed($start, $newdoc, $newpar, $buffer);
            $buffer .= $ftext;
            # Line breaks at word boundaries after at most 80 characters.
            $buffer = print_lines_from_buffer($buffer, 80, $chinese);
            print("$buffer\n\n");
            $buffer = '';
            # Start is only true until we write the first sentence of the input stream.
            $start = 0;
            $newdoc = 0;
            $newpar = 0;
            $text = '';
            $ftext = '';
        }
        $ftext .= $f[1];
        $ftext .= ' ' unless($f[9] =~ m/SpaceAfter=No/);
    }
    # ordinary word line, unless it is inside the current MWT range
    elsif(m/^(\d+)\t/ && !(defined($mwtlast) && $1<=$mwtlast))
    {
        $mwtlast = undef;
        my @f = split(/\t/, $_);
        # Paragraphs may start in the middle of a sentence (bulleted lists, verse etc.)
        # The first token of the new paragraph has "NewPar=Yes" in the MISC column.
        # Multi-word tokens have this in the token-introducing line.
        if($f[9] =~ m/NewPar=Yes/i)
        {
            # Empty line between documents and paragraphs. (There may have been
            # a paragraph break before the first part of this sentence as well!)
            $buffer = print_new_paragraph_if_needed($start, $newdoc, $newpar, $buffer);
            $buffer .= $ftext;
            # Line breaks at word boundaries after at most 80 characters.
            $buffer = print_lines_from_buffer($buffer, 80, $chinese);
            print("$buffer\n\n");
            $buffer = '';
            # Start is only true until we write the first sentence of the input stream.
            $start = 0;
            $newdoc = 0;
            $newpar = 0;
            $text = '';
            $ftext = '';
        }
        $ftext .= $f[1];
        $ftext .= ' ' unless($f[9] =~ m/SpaceAfter=No/);
    }
    # blank line: end of sentence
    elsif(m/^\s*$/)
    {
        # In a valid CoNLL-U file, $text should be equal to $ftext except for the
        # space after the last token. However, if there have been intra-sentential
        # paragraph breaks, $ftext contains only the part after the last such
        # break, and $text is empty. Hence we currently use $ftext everywhere
        # and ignore $text, even though we note it when seeing the text attribute.
        # $text .= ' ' unless($chinese);
        # Empty line between documents and paragraphs.
        $buffer = print_new_paragraph_if_needed($start, $newdoc, $newpar, $buffer);
        $buffer .= $ftext;
        # Line breaks at word boundaries after at most 80 characters.
        $buffer = print_lines_from_buffer($buffer, 80, $chinese);
        # Start is only true until we write the first sentence of the input stream.
        $start = 0;
        $newdoc = 0;
        $newpar = 0;
        $text = '';
        $ftext = '';
        $mwtlast = undef;
    }
}
# There may be unflushed buffer contents after the last sentence, less than 80 characters
# (otherwise we would have already dealt with it), so just flush it.
if($buffer ne '')
{
    print("$buffer\n");
}
+
125
+
126
+
127
+ #------------------------------------------------------------------------------
128
+ # Checks whether we have to print an extra line to separate paragraphs. Does it
129
+ # if necessary. Returns the updated buffer.
130
+ #------------------------------------------------------------------------------
131
+ sub print_new_paragraph_if_needed
132
+ {
133
+ my $start = shift;
134
+ my $newdoc = shift;
135
+ my $newpar = shift;
136
+ my $buffer = shift;
137
+ if(!$start && ($newdoc || $newpar))
138
+ {
139
+ if($buffer ne '')
140
+ {
141
+ print("$buffer\n");
142
+ $buffer = '';
143
+ }
144
+ print("\n");
145
+ }
146
+ return $buffer;
147
+ }
148
+
149
+
150
+
151
+ #------------------------------------------------------------------------------
152
+ # Prints as many complete lines of text as there are in the buffer. Returns the
153
+ # remaining contents of the buffer.
154
+ #------------------------------------------------------------------------------
155
+ sub print_lines_from_buffer
156
+ {
157
+ my $buffer = shift;
158
+ # Maximum number of characters allowed on one line, not counting the line
159
+ # break character(s), which also replace any number of trailing spaces.
160
+ # Exception: If there is a word longer than the limit, it will be printed
161
+ # on one line.
162
+ # Note that this algorithm is not suitable for Chinese and Japanese.
163
+ my $limit = shift;
164
+ # We need a different algorithm for Chinese and Japanese.
165
+ my $chinese = shift;
166
+ if($chinese)
167
+ {
168
+ return print_chinese_lines_from_buffer($buffer, $limit);
169
+ }
170
+ if(length($buffer) >= $limit)
171
+ {
172
+ my @cbuffer = split(//, $buffer);
173
+ # There may be more than one new line waiting in the buffer.
174
+ while(scalar(@cbuffer) >= $limit)
175
+ {
176
+ ###!!! We could make it simpler if we ignored multi-space sequences
177
+ ###!!! between words. It sounds OK to ignore them because at the
178
+ ###!!! line break we do not respect original spacing anyway.
179
+ my $i;
180
+ my $ilastspace;
181
+ for($i = 0; $i<=$#cbuffer; $i++)
182
+ {
183
+ if($i>$limit && defined($ilastspace))
184
+ {
185
+ last;
186
+ }
187
+ if($cbuffer[$i] =~ m/\s/)
188
+ {
189
+ $ilastspace = $i;
190
+ }
191
+ }
192
+ if(defined($ilastspace) && $ilastspace>0)
193
+ {
194
+ my @out = @cbuffer[0..($ilastspace-1)];
195
+ splice(@cbuffer, 0, $ilastspace+1);
196
+ print(join('', @out), "\n");
197
+ }
198
+ else
199
+ {
200
+ print(join('', @cbuffer), "\n");
201
+ splice(@cbuffer);
202
+ }
203
+ }
204
+ $buffer = join('', @cbuffer);
205
+ }
206
+ return $buffer;
207
+ }
208
+
209
+
210
+
211
+ #------------------------------------------------------------------------------
212
+ # Prints as many complete lines of text as there are in the buffer. Returns the
213
+ # remaining contents of the buffer. Assumes that there are no spaces between
214
+ # words and lines can be broken between any two characters, as is the custom in
215
+ # Chinese and Japanese.
216
+ #------------------------------------------------------------------------------
217
+ sub print_chinese_lines_from_buffer
218
+ {
219
+ my $buffer = shift;
220
+ # Maximum number of characters allowed on one line, not counting the line
221
+ # break character(s).
222
+ my $limit = shift;
223
+ # We cannot simply print the first $limit characters from the buffer,
224
+ # followed by a line break. There could be embedded Latin words or
225
+ # numbers and we do not want to insert a line break in the middle of
226
+ # a foreign word.
227
+ my @cbuffer = split(//, $buffer);
228
+ while(scalar(@cbuffer) >= $limit)
229
+ {
230
+ my $nprint = 0;
231
+ for(my $i = 0; $i <= $#cbuffer; $i++)
232
+ {
233
+ if($i > $limit && $nprint > 0)
234
+ {
235
+ last;
236
+ }
237
+ unless($i < $#cbuffer && $cbuffer[$i] =~ m/[\p{Latin}0-9]/ && $cbuffer[$i+1] =~ m/[\p{Latin}0-9]/)
238
+ {
239
+ $nprint = $i+1;
240
+ }
241
+ }
242
+ my @out = @cbuffer[0..($nprint-1)];
243
+ splice(@cbuffer, 0, $nprint);
244
+ print(join('', @out), "\n");
245
+ }
246
+ $buffer = join('', @cbuffer);
247
+ return $buffer;
248
+ }
stanza/stanza/utils/datasets/prepare_lemma_classifier.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from stanza.utils.datasets.common import find_treebank_dataset_file, UnknownDatasetError
5
+ from stanza.utils.default_paths import get_default_paths
6
+ from stanza.models.lemma_classifier import prepare_dataset
7
+ from stanza.models.common.short_name_to_treebank import short_name_to_treebank
8
+ from stanza.utils.conll import CoNLL
9
+
10
# the dataset shards produced for each treebank
SECTIONS = ("train", "dev", "test")
11
+
12
def process_treebank(paths, short_name, word, upos, allowed_lemmas, sections=SECTIONS):
    """Extract lemma classifier data for each shard of one treebank.

    word / upos / allowed_lemmas are forwarded to prepare_dataset.
    Returns the list of output filenames written."""
    treebank = short_name_to_treebank(short_name)
    udbase_dir = paths["UDBASE"]

    output_dir = paths["LEMMA_CLASSIFIER_DATA_DIR"]
    os.makedirs(output_dir, exist_ok=True)

    written = []
    for section in sections:
        conllu_file = find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True)
        out_file = os.path.join(output_dir, "%s.%s.lemma" % (short_name, section))
        prepare_args = ["--conll_path", conllu_file,
                        "--target_word", word,
                        "--target_upos", upos,
                        "--output_path", out_file]
        if allowed_lemmas is not None:
            prepare_args += ["--allowed_lemmas", allowed_lemmas]
        prepare_dataset.main(prepare_args)
        written.append(out_file)

    return written
34
+
35
def _read_lemma_sentences(udbase_dir, treebank, section, target_word, target_upos):
    """Read one treebank section and extract its lemma classifier sentences."""
    filename = find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True)
    doc = CoNLL.conll2doc(filename)
    processor = prepare_dataset.DataProcessor(target_word=target_word, target_upos=target_upos, allowed_lemmas=".*")
    new_sentences = processor.process_document(doc, save_name=None)
    print("Read %d sentences from %s" % (len(new_sentences), filename))
    return new_sentences

def process_en_combined(paths, short_name):
    """Build a combined English 's/AUX lemma dataset from several treebanks.

    The train treebanks contribute their train/dev/test shards to the
    matching combined shards; the test treebanks contribute their test
    shard to the combined test set only.
    """
    udbase_dir = paths["UDBASE"]
    output_dir = paths["LEMMA_CLASSIFIER_DATA_DIR"]
    os.makedirs(output_dir, exist_ok=True)

    train_treebanks = ["UD_English-EWT", "UD_English-GUM", "UD_English-GUMReddit", "UD_English-LinES"]
    test_treebanks = ["UD_English-PUD", "UD_English-Pronouns"]

    target_word = "'s"
    target_upos = ["AUX"]

    # one sentence list per shard, parallel to SECTIONS
    sentences = [[], [], []]
    for treebank in train_treebanks:
        for section_idx, section in enumerate(SECTIONS):
            sentences[section_idx].extend(
                _read_lemma_sentences(udbase_dir, treebank, section, target_word, target_upos))
    for treebank in test_treebanks:
        # these treebanks only have usable test data
        sentences[2].extend(
            _read_lemma_sentences(udbase_dir, treebank, "test", target_word, target_upos))

    for section, section_sentences in zip(SECTIONS, sentences):
        output_filename = os.path.join(output_dir, "%s.%s.lemma" % (short_name, section))
        prepare_dataset.DataProcessor.write_output_file(output_filename, target_upos, section_sentences)
        print("Wrote %s sentences to %s" % (len(section_sentences), output_filename))
68
+
69
def process_ja_gsd(paths, short_name):
    """Prepare the Japanese GSD lemma classifier dataset."""
    # this one looked promising, but only has 10 total dev & test cases
    #   行っ VERB Counter({'行う': 60, '行く': 38})
    # could possibly do
    #   ない AUX Counter({'ない': 383, '無い': 99})
    #   なく AUX Counter({'無い': 53, 'ない': 42})
    # currently this one has enough in the dev & test data
    # and functions well
    #   だ AUX Counter({'だ': 237, 'た': 67})
    process_treebank(paths, short_name, word="だ", upos="AUX", allowed_lemmas=None)
83
+
84
def process_fa_perdt(paths, short_name):
    """Prepare the Persian PerDT lemma classifier dataset."""
    process_treebank(paths, short_name,
                     word="شد",
                     upos="VERB",
                     allowed_lemmas="کرد|شد")
90
+
91
def process_hi_hdtb(paths, short_name):
    """Prepare the Hindi HDTB lemma classifier dataset."""
    process_treebank(paths, short_name,
                     word="के",
                     upos="ADP",
                     allowed_lemmas="का|के")
97
+
98
def process_ar_padt(paths, short_name):
    """Prepare the Arabic PADT lemma classifier dataset."""
    process_treebank(paths, short_name,
                     word="أن",
                     upos="SCONJ",
                     allowed_lemmas="أَن|أَنَّ")
104
+
105
def process_el_gdt(paths, short_name):
    """
    All of the Greek lemmas for these words are εγώ or μου

    τους PRON Counter({'μου': 118, 'εγώ': 32})
    μας PRON Counter({'μου': 89, 'εγώ': 32})
    του PRON Counter({'μου': 82, 'εγώ': 8})
    της PRON Counter({'μου': 80, 'εγώ': 2})
    σας PRON Counter({'μου': 34, 'εγώ': 24})
    μου PRON Counter({'μου': 45, 'εγώ': 10})
    """
    process_treebank(paths, short_name,
                     word="τους|μας|του|της|σας|μου",
                     upos="PRON",
                     allowed_lemmas=None)
121
+
122
# maps a dataset short name to the function which prepares it
DATASET_MAPPING = {
    "ar_padt": process_ar_padt,
    "el_gdt": process_el_gdt,
    "en_combined": process_en_combined,
    "fa_perdt": process_fa_perdt,
    "hi_hdtb": process_hi_hdtb,
    "ja_gsd": process_ja_gsd,
}
130
+
131
+
132
def main(dataset_name):
    """Prepare the lemma classifier dataset named dataset_name."""
    paths = get_default_paths()
    print("Processing %s" % dataset_name)

    # obviously will want to multiplex to multiple languages / datasets
    handler = DATASET_MAPPING.get(dataset_name)
    if handler is None:
        raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_lemma_classifier.py")
    handler(paths, dataset_name)
    print("Done processing %s" % dataset_name)

if __name__ == '__main__':
    main(sys.argv[1])
stanza/stanza/utils/datasets/prepare_mwt_treebank.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A script to prepare all MWT datasets.
3
+
4
+ For example, do
5
+ python -m stanza.utils.datasets.prepare_mwt_treebank TREEBANK
6
+ such as
7
+ python -m stanza.utils.datasets.prepare_mwt_treebank UD_English-EWT
8
+
9
+ and it will prepare each of train, dev, test
10
+ """
11
+
12
+ import argparse
13
+ import os
14
+ import shutil
15
+ import tempfile
16
+
17
+ from stanza.utils.conll import CoNLL
18
+ from stanza.models.common.constant import treebank_to_short_name
19
+ import stanza.utils.datasets.common as common
20
+ import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
21
+
22
+ from stanza.utils.datasets.contract_mwt import contract_mwt
23
+
24
# languages where the MWTs are always a composition of the words themselves
# (language codes, matched against the short name prefix)
KNOWN_COMPOSABLE_MWTS = {"en"}
# ... but partut is not put together that way
MWT_EXCEPTIONS = {"en_partut"}
28
+
29
def copy_conllu(tokenizer_dir, mwt_dir, short_name, dataset, particle):
    """Copy a shard's gold conllu from the tokenizer dir into the MWT dir.

    particle is the suffix used in the destination name ("in" or "gold")."""
    source_path = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
    dest_path = f"{mwt_dir}/{short_name}.{dataset}.{particle}.conllu"
    shutil.copyfile(source_path, dest_path)
33
+
34
def check_mwt_composition(filename):
    """Verify every MWT in filename is exactly the concatenation of its words.

    Raises ValueError at the first token which is not."""
    print("Checking the MWTs in %s" % filename)
    doc = CoNLL.conll2doc(filename)
    for sent_idx, sentence in enumerate(doc.sentences):
        for token_idx, token in enumerate(sentence.tokens):
            if len(token.words) <= 1:
                continue
            expected = "".join(word.text for word in token.words)
            if token.text != expected:
                raise ValueError("Unexpected token composition in filename %s sentence %d id %s token %d: %s instead of %s" % (filename, sent_idx, sentence.sent_id, token_idx, token.text, expected))
43
+
44
def process_treebank(treebank, model_type, paths, args):
    """Prepare the MWT train/dev/test files for one treebank.

    Runs the tokenizer dataset preparation in a temporary directory,
    then copies the resulting conllu and mwt label files into the MWT
    data directory and builds the contracted .in files for dev/test.
    """
    short_name = treebank_to_short_name(treebank)

    mwt_dir = paths["MWT_DATA_DIR"]
    os.makedirs(mwt_dir, exist_ok=True)

    with tempfile.TemporaryDirectory() as tokenizer_dir:
        # shadow the paths dict so tokenizer output goes to the temp dir
        paths = dict(paths)
        paths["TOKENIZE_DATA_DIR"] = tokenizer_dir

        # first we process the tokenization data
        tokenizer_args = argparse.Namespace()
        tokenizer_args.augment = False
        tokenizer_args.prepare_labels = True
        prepare_tokenizer_treebank.process_treebank(treebank, model_type, paths, tokenizer_args)

        # train input keeps MWTs expanded; dev/test keep the gold expansion
        copy_conllu(tokenizer_dir, mwt_dir, short_name, "train", "in")
        copy_conllu(tokenizer_dir, mwt_dir, short_name, "dev", "gold")
        copy_conllu(tokenizer_dir, mwt_dir, short_name, "test", "gold")

        for shard in ("train", "dev", "test"):
            source_filename = common.mwt_name(tokenizer_dir, short_name, shard)
            dest_filename = common.mwt_name(mwt_dir, short_name, shard)
            print("Copying from %s to %s" % (source_filename, dest_filename))
            shutil.copyfile(source_filename, dest_filename)

        language = short_name.split("_", 1)[0]
        if language in KNOWN_COMPOSABLE_MWTS and short_name not in MWT_EXCEPTIONS:
            print("Language %s is known to have all MWT composed of exactly its word pieces. Checking..." % language)
            check_mwt_composition(f"{mwt_dir}/{short_name}.train.in.conllu")
            check_mwt_composition(f"{mwt_dir}/{short_name}.dev.gold.conllu")
            check_mwt_composition(f"{mwt_dir}/{short_name}.test.gold.conllu")

        # build the dev/test inputs with MWTs contracted back to surface tokens
        contract_mwt(f"{mwt_dir}/{short_name}.dev.gold.conllu",
                     f"{mwt_dir}/{short_name}.dev.in.conllu")
        contract_mwt(f"{mwt_dir}/{short_name}.test.gold.conllu",
                     f"{mwt_dir}/{short_name}.test.in.conllu")
81
+
82
def main():
    """Command line entry point for preparing MWT datasets."""
    common.main(process_treebank, common.ModelType.MWT)

if __name__ == '__main__':
    main()
87
+
88
+
stanza/stanza/utils/datasets/prepare_pos_treebank.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A script to prepare all pos datasets.
3
+
4
+ For example, do
5
+ python -m stanza.utils.datasets.prepare_pos_treebank TREEBANK
6
+ such as
7
+ python -m stanza.utils.datasets.prepare_pos_treebank UD_English-EWT
8
+
9
+ and it will prepare each of train, dev, test
10
+ """
11
+
12
+ import os
13
+ import shutil
14
+
15
+ import stanza.utils.datasets.common as common
16
+ import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
17
+
18
def copy_conllu_file_or_zip(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name):
    """Copy a prepared conllu file, preferring a zipped version when present."""
    original = f"{tokenizer_dir}/{short_name}.{tokenizer_file}.zip"
    copied = f"{dest_dir}/{short_name}.{dest_file}.zip"

    if not os.path.exists(original):
        # no zip: fall back to copying the plain conllu file
        prepare_tokenizer_treebank.copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name)
        return
    print("Copying from %s to %s" % (original, copied))
    shutil.copyfile(original, copied)
27
+
28
+
29
def process_treebank(treebank, model_type, paths, args):
    # delegate to the tokenizer treebank conversion, writing into the POS
    # data dir; the postprocess hook copies zips when they exist
    prepare_tokenizer_treebank.copy_conllu_treebank(treebank, model_type, paths, paths["POS_DATA_DIR"], postprocess=copy_conllu_file_or_zip)
31
+
32
def main():
    """Command line entry point for preparing POS datasets."""
    common.main(process_treebank, common.ModelType.POS)

if __name__ == '__main__':
    main()
37
+
38
+
stanza/stanza/utils/datasets/random_split_conllu.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Randomly split a file into train, dev, and test sections
3
+
4
+ Specifically used in the case of building a tagger from the initial
5
+ POS tagging provided by Isra, but obviously can be used to split any
6
+ conllu file
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import random
12
+
13
+ from stanza.models.common.doc import Document
14
+ from stanza.utils.conll import CoNLL
15
+ from stanza.utils.default_paths import get_default_paths
16
+
17
def main():
    """Randomly split one conllu file into train/dev/test conllu files.

    Each sentence is independently assigned to a split with probability
    proportional to the --train/--dev/--test weights.  By default the
    xpos and feats columns are removed from the output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', default='extern_data/sindhi/upos/sindhi_upos.conllu', help='Which file to split')
    parser.add_argument('--train', type=float, default=0.8, help='Fraction of the data to use for train')
    parser.add_argument('--dev', type=float, default=0.1, help='Fraction of the data to use for dev')
    parser.add_argument('--test', type=float, default=0.1, help='Fraction of the data to use for test')
    parser.add_argument('--seed', default='1234', help='Random seed to use')
    parser.add_argument('--short_name', default='sd_isra', help='Dataset name to use when writing output files')
    parser.add_argument('--no_remove_xpos', default=True, action='store_false', dest='remove_xpos', help='By default, we remove the xpos from the dataset')
    parser.add_argument('--no_remove_feats', default=True, action='store_false', dest='remove_feats', help='By default, we remove the feats from the dataset')
    parser.add_argument('--output_directory', default=get_default_paths()["POS_DATA_DIR"], help="Where to put the split conllu")
    args = parser.parse_args()

    # relative weights for random.choices below
    weights = (args.train, args.dev, args.test)

    doc = CoNLL.conll2doc(args.filename)
    # note: the seed is a string; random.seed accepts any hashable value
    random.seed(args.seed)

    # each split is a (list of sentence dicts, list of comment lists) pair
    train_doc = ([], [])
    dev_doc = ([], [])
    test_doc = ([], [])
    splits = [train_doc, dev_doc, test_doc]
    for sentence in doc.sentences:
        sentence_dict = sentence.to_dict()
        if args.remove_xpos:
            for x in sentence_dict:
                x.pop('xpos', None)
        if args.remove_feats:
            for x in sentence_dict:
                x.pop('feats', None)
        # pick one of the three splits at random, weighted
        split = random.choices(splits, weights)[0]
        split[0].append(sentence_dict)
        split[1].append(sentence.comments)

    # rebuild each split as a Document and write it out
    splits = [Document(split[0], comments=split[1]) for split in splits]
    for split_doc, split_name in zip(splits, ("train", "dev", "test")):
        filename = os.path.join(args.output_directory, "%s.%s.in.conllu" % (args.short_name, split_name))
        print("Outputting %d sentences to %s" % (len(split_doc.sentences), filename))
        CoNLL.write_doc2conll(split_doc, filename)

if __name__ == '__main__':
    main()
59
+
stanza/stanza/utils/datasets/thai_syllable_dict_generator.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import pathlib
3
+ import argparse
4
+
5
+
6
def create_dictionary(dataset_dir, save_dir):
    """Build a Thai syllable -> id dictionary from *.ssg files and save it as json.

    dataset_dir: directory searched recursively for .ssg files, where each
      line is a sentence with syllables separated by "<s/>"
    save_dir: path of the json file to write

    Syllables that are empty, contain spaces, or contain any non-Thai
    character (e.g. embedded English words) are discarded.
    """
    import json
    import re

    # keep only syllables made purely of Thai codepoints
    thai_syllable = re.compile("^[\u0E00-\u0E7F]*$")

    syllables = set()
    for path in pathlib.Path(dataset_dir).rglob("*.ssg"):  # iterate through all files
        with open(path) as fin:
            for sentence in fin:
                # "<s/>" marks syllable boundaries; strip the newline first
                pieces = sentence.replace("\n", "").replace("<s/>", "~").split("~")
                syllables.update(pieces)

    print(len(syllables))

    # filter out English / mixed syllables, then assign ids;
    # sorted so repeated runs produce the same dictionary
    # (previously the ids came from arbitrary set iteration order,
    # and a debug print ran once per syllable)
    kept = sorted(s for s in syllables
                  if s and " " not in s and thai_syllable.match(s))
    dictionary = {syllable: idx for idx, syllable in enumerate(kept)}

    print(dictionary)
    print(len(dictionary))
    with open(save_dir, "w") as fp:
        json.dump(dictionary, fp)
46
if __name__ == "__main__":
    # command line entry point: build the syllable dictionary
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--dataset_dir', type=str, default="syllable_segmentation_data", help="Directory for syllable dataset")
    arg_parser.add_argument('--save_dir', type=str, default="thai-syllable.json", help="Directory for generated file")
    cli_args = arg_parser.parse_args()

    create_dictionary(cli_args.dataset_dir, cli_args.save_dir)