Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- stanza/stanza/pipeline/demo/stanza-brat.css +74 -0
- stanza/stanza/pipeline/demo/stanza-parseviewer.js +215 -0
- stanza/stanza/pipeline/external/__init__.py +0 -0
- stanza/stanza/tests/classifiers/__init__.py +0 -0
- stanza/stanza/tests/classifiers/test_classifier.py +317 -0
- stanza/stanza/tests/classifiers/test_process_utils.py +83 -0
- stanza/stanza/tests/common/test_bert_embedding.py +33 -0
- stanza/stanza/tests/common/test_char_model.py +190 -0
- stanza/stanza/tests/common/test_common_data.py +32 -0
- stanza/stanza/tests/common/test_data_objects.py +60 -0
- stanza/stanza/tests/common/test_doc.py +174 -0
- stanza/stanza/tests/common/test_dropout.py +28 -0
- stanza/stanza/tests/common/test_short_name_to_treebank.py +14 -0
- stanza/stanza/tests/constituency/test_convert_it_vit.py +228 -0
- stanza/stanza/tests/constituency/test_convert_starlang.py +37 -0
- stanza/stanza/tests/constituency/test_in_order_oracle.py +522 -0
- stanza/stanza/tests/constituency/test_lstm_model.py +552 -0
- stanza/stanza/tests/constituency/test_text_processing.py +109 -0
- stanza/stanza/tests/constituency/test_top_down_oracle.py +443 -0
- stanza/stanza/tests/constituency/test_trainer.py +639 -0
- stanza/stanza/tests/constituency/test_transformer_tree_stack.py +195 -0
- stanza/stanza/tests/constituency/test_transition_sequence.py +156 -0
- stanza/stanza/tests/constituency/test_tree_reader.py +119 -0
- stanza/stanza/tests/constituency/test_vietnamese.py +121 -0
- stanza/stanza/tests/langid/test_langid.py +615 -0
- stanza/stanza/tests/lemma/__init__.py +0 -0
- stanza/stanza/tests/mwt/test_utils.py +59 -0
- stanza/stanza/tests/ner/__init__.py +0 -0
- stanza/stanza/tests/ner/test_combine_ner_datasets.py +39 -0
- stanza/stanza/tests/ner/test_models_ner_scorer.py +28 -0
- stanza/stanza/tests/ner/test_ner_tagger.py +94 -0
- stanza/stanza/tests/ner/test_ner_trainer.py +32 -0
- stanza/stanza/tests/ner/test_pay_amt_annotators.py +50 -0
- stanza/stanza/tests/ner/test_split_wikiner.py +202 -0
- stanza/stanza/tests/ner/test_suc3.py +91 -0
- stanza/stanza/tests/pipeline/test_decorators.py +127 -0
- stanza/stanza/tests/pipeline/test_pipeline_mwt_expander.py +123 -0
- stanza/stanza/tests/pos/__init__.py +0 -0
- stanza/stanza/tests/pos/test_tagger.py +315 -0
- stanza/stanza/tests/resources/__init__.py +0 -0
- stanza/stanza/tests/resources/test_default_packages.py +24 -0
- stanza/stanza/tests/resources/test_prepare_resources.py +30 -0
- stanza/stanza/tests/server/test_server_misc.py +115 -0
- stanza/stanza/utils/datasets/common.py +286 -0
- stanza/stanza/utils/datasets/conllu_to_text.pl +248 -0
- stanza/stanza/utils/datasets/prepare_lemma_classifier.py +144 -0
- stanza/stanza/utils/datasets/prepare_mwt_treebank.py +88 -0
- stanza/stanza/utils/datasets/prepare_pos_treebank.py +38 -0
- stanza/stanza/utils/datasets/random_split_conllu.py +59 -0
- stanza/stanza/utils/datasets/thai_syllable_dict_generator.py +53 -0
stanza/stanza/pipeline/demo/stanza-brat.css
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
.red {
|
| 3 |
+
color:#990000
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
#wrap {
|
| 7 |
+
min-height: 100%;
|
| 8 |
+
height: auto;
|
| 9 |
+
margin: 0 auto -6ex;
|
| 10 |
+
padding: 0 0 6ex;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
.pattern_tab {
|
| 14 |
+
margin: 1ex;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
.pattern_brat {
|
| 18 |
+
margin-top: 1ex;
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
.label {
|
| 22 |
+
color: #777777;
|
| 23 |
+
font-size: small;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
.footer {
|
| 27 |
+
bottom: 0;
|
| 28 |
+
width: 100%;
|
| 29 |
+
/* Set the fixed height of the footer here */
|
| 30 |
+
height: 5ex;
|
| 31 |
+
padding-top: 1ex;
|
| 32 |
+
margin-top: 1ex;
|
| 33 |
+
background-color: #f5f5f5;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
.corenlp_error {
|
| 37 |
+
margin-top: 2ex;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
/* Styling for parse graph */
|
| 41 |
+
.node rect {
|
| 42 |
+
stroke: #333;
|
| 43 |
+
fill: #fff;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
.parse-RULE rect {
|
| 47 |
+
fill: #C0D9AF;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
.parse-TERMINAL rect {
|
| 51 |
+
stroke: #333;
|
| 52 |
+
fill: #EEE8AA;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
.node.highlighted {
|
| 56 |
+
stroke: #ffff00;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
.edgePath path {
|
| 60 |
+
stroke: #333;
|
| 61 |
+
fill: #333;
|
| 62 |
+
stroke-width: 1.5px;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
.parse-EDGE path {
|
| 66 |
+
stroke: DarkGray;
|
| 67 |
+
fill: DarkGray;
|
| 68 |
+
stroke-width: 1.5px;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
.logo {
|
| 72 |
+
font-family: "Lato", "Gill Sans MT", "Gill Sans", "Helvetica", "Arial", sans-serif;
|
| 73 |
+
font-style: italic;
|
| 74 |
+
}
|
stanza/stanza/pipeline/demo/stanza-parseviewer.js
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//'use strict';
|
| 2 |
+
|
| 3 |
+
//d3 || require('d3');
|
| 4 |
+
//var dagreD3 = require('dagre-d3');
|
| 5 |
+
//var jquery = require('jquery');
|
| 6 |
+
//var $ = jquery;
|
| 7 |
+
|
| 8 |
+
var ParseViewer = function(params) {
|
| 9 |
+
// Container in which the scene template is displayed
|
| 10 |
+
this.selector = params.selector;
|
| 11 |
+
this.container = $(this.selector);
|
| 12 |
+
this.fitToGraph = true;
|
| 13 |
+
this.onClickNodeCallback = params.onClickNodeCallback;
|
| 14 |
+
this.onHoverNodeCallback = params.onHoverNodeCallback;
|
| 15 |
+
this.init();
|
| 16 |
+
return this;
|
| 17 |
+
};
|
| 18 |
+
|
| 19 |
+
ParseViewer.MIN_WIDTH = 100;
|
| 20 |
+
ParseViewer.MIN_HEIGHT = 100;
|
| 21 |
+
|
| 22 |
+
ParseViewer.prototype.constructor = ParseViewer;
|
| 23 |
+
|
| 24 |
+
ParseViewer.prototype.getAutoWidth = function () {
|
| 25 |
+
return Math.max(ParseViewer.MIN_WIDTH, this.container.width());
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
ParseViewer.prototype.getAutoHeight = function () {
|
| 29 |
+
return Math.max(ParseViewer.MIN_HEIGHT, this.container.height() - 20);
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
ParseViewer.prototype.init = function () {
|
| 33 |
+
var canvasWidth = this.getAutoWidth();
|
| 34 |
+
var canvasHeight = this.getAutoHeight();
|
| 35 |
+
this.parseElem = d3.select(this.selector)
|
| 36 |
+
.append('svg')
|
| 37 |
+
.attr({'width': canvasWidth, 'height': canvasHeight})
|
| 38 |
+
.style({'width': canvasWidth, 'height': canvasHeight});
|
| 39 |
+
console.log(this.parseElem);
|
| 40 |
+
this.graph = null;
|
| 41 |
+
this.graphRendered = false;
|
| 42 |
+
|
| 43 |
+
this.controls = $('<div class="text"></div>');
|
| 44 |
+
this.container.append(this.controls);
|
| 45 |
+
};
|
| 46 |
+
|
| 47 |
+
var GraphBuilder = function(roots) {
|
| 48 |
+
// Create the input graph
|
| 49 |
+
this.graph = new dagreD3.graphlib.Graph()
|
| 50 |
+
.setGraph({})
|
| 51 |
+
.setDefaultEdgeLabel(function () {
|
| 52 |
+
return {};
|
| 53 |
+
});
|
| 54 |
+
this.visitIndex = 0;
|
| 55 |
+
//console.log('building graph', roots);
|
| 56 |
+
for (var i = 0; i < roots.length; i++) {
|
| 57 |
+
this.build(roots[i]);
|
| 58 |
+
}
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
GraphBuilder.prototype.build = function(node) {
|
| 62 |
+
console.log(node);
|
| 63 |
+
// Track my visit index
|
| 64 |
+
this.visitIndex++;
|
| 65 |
+
node.visitIndex = this.visitIndex;
|
| 66 |
+
|
| 67 |
+
// Add a node
|
| 68 |
+
var nodeData = node; // TODO: replace with semantic data
|
| 69 |
+
var nodeLabel = node.label;
|
| 70 |
+
var nodeIndex = node.visitIndex;
|
| 71 |
+
var nodeClass = 'parse-RULE';
|
| 72 |
+
|
| 73 |
+
this.graph.setNode(nodeIndex, { label: nodeLabel, class: nodeClass, data: nodeData });
|
| 74 |
+
if (node.parent) {
|
| 75 |
+
this.graph.setEdge(node.parent.visitIndex, nodeIndex, {
|
| 76 |
+
class: 'parse-EDGE'
|
| 77 |
+
});
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
if (node.isTerminal) {
|
| 81 |
+
this.visitIndex++;
|
| 82 |
+
nodeIndex = this.visitIndex;
|
| 83 |
+
nodeLabel = node.text;
|
| 84 |
+
nodeClass = 'parse-TERMINAL';
|
| 85 |
+
|
| 86 |
+
this.graph.setNode(nodeIndex, { label: nodeLabel, class: nodeClass, data: nodeData });
|
| 87 |
+
this.graph.setEdge(node.visitIndex, nodeIndex, {
|
| 88 |
+
class: 'parse-EDGE'
|
| 89 |
+
});
|
| 90 |
+
} else if (node.children) {
|
| 91 |
+
for (var i = 0; i < node.children.length; i++) {
|
| 92 |
+
this.build(node.children[i]);
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
ParseViewer.prototype.updateGraphPosition = function (svg, g, minWidth, minHeight) {
|
| 98 |
+
if (this.fitToGraph) {
|
| 99 |
+
minWidth = g.graph().width;
|
| 100 |
+
minHeight = this.getAutoHeight();
|
| 101 |
+
}
|
| 102 |
+
adjustGraphPositioning(svg, g, minWidth, minHeight);
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
function adjustGraphPositioning(svg, g, minWidth, minHeight) {
|
| 106 |
+
// Resize svg
|
| 107 |
+
var newWidth = Math.max(minWidth, g.graph().width);
|
| 108 |
+
var newHeight = Math.max(minHeight, g.graph().height + 40);
|
| 109 |
+
svg.attr({'width': newWidth, 'height': newHeight});
|
| 110 |
+
svg.style({'width': newWidth, 'height': newHeight});
|
| 111 |
+
// Center the graph
|
| 112 |
+
var svgGroup = svg.select('g');
|
| 113 |
+
var xCenterOffset = (svg.attr('width') - g.graph().width) / 2;
|
| 114 |
+
svgGroup.attr('transform', 'translate(' + xCenterOffset + ', 20)');
|
| 115 |
+
svg.attr('height', g.graph().height + 40);
|
| 116 |
+
svg.style('height', g.graph().height + 40);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
ParseViewer.prototype.renderGraph = function (svg, g, parse) {
|
| 120 |
+
// Create the renderer
|
| 121 |
+
var render = new dagreD3.render();
|
| 122 |
+
// Run the renderer. This is what draws the final graph.
|
| 123 |
+
var svgGroup = svg.select('g');
|
| 124 |
+
render(svgGroup, g);
|
| 125 |
+
|
| 126 |
+
var scope = this;
|
| 127 |
+
var nodes = svgGroup.selectAll('g.node');
|
| 128 |
+
nodes.on('click',
|
| 129 |
+
function (d) {
|
| 130 |
+
var v = d;
|
| 131 |
+
var node = g.node(v);
|
| 132 |
+
if (scope.onClickNodeCallback) {
|
| 133 |
+
scope.onClickNodeCallback(node.data);
|
| 134 |
+
}
|
| 135 |
+
console.log(g.node(v));
|
| 136 |
+
}
|
| 137 |
+
);
|
| 138 |
+
|
| 139 |
+
nodes.on('mouseover',
|
| 140 |
+
function (d) {
|
| 141 |
+
var v = d;
|
| 142 |
+
var node = g.node(v);
|
| 143 |
+
if (scope.onHoverNodeCallback) {
|
| 144 |
+
scope.onHoverNodeCallback(node.data);
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
);
|
| 148 |
+
|
| 149 |
+
this.updateGraphPosition(svg, g, svg.attr('width'), svg.attr('height'));
|
| 150 |
+
this.graphRendered = true;
|
| 151 |
+
};
|
| 152 |
+
|
| 153 |
+
ParseViewer.prototype.showParse = function (root) {
|
| 154 |
+
this.showParses([root]);
|
| 155 |
+
};
|
| 156 |
+
|
| 157 |
+
ParseViewer.prototype.showParses = function (roots) {
|
| 158 |
+
// Take parse and create a graph
|
| 159 |
+
var gb = new GraphBuilder(roots);
|
| 160 |
+
var g = gb.graph;
|
| 161 |
+
|
| 162 |
+
g.nodes().forEach(function (v) {
|
| 163 |
+
var node = g.node(v);
|
| 164 |
+
// Round the corners of the nodes
|
| 165 |
+
node.rx = node.ry = 5;
|
| 166 |
+
});
|
| 167 |
+
|
| 168 |
+
var svg = this.parseElem;
|
| 169 |
+
svg.selectAll('*').remove();
|
| 170 |
+
var svgGroup = svg.append('g');
|
| 171 |
+
this.graph = g;
|
| 172 |
+
this.parse = roots;
|
| 173 |
+
if (this.container.is(':visible')) {
|
| 174 |
+
if (roots.length > 0) {
|
| 175 |
+
this.renderGraph(svg, this.graph, this.parse);
|
| 176 |
+
}
|
| 177 |
+
} else {
|
| 178 |
+
this.graphRendered = false;
|
| 179 |
+
}
|
| 180 |
+
};
|
| 181 |
+
|
| 182 |
+
ParseViewer.prototype.showAnnotation = function (annotation) {
|
| 183 |
+
var parses = [];
|
| 184 |
+
for (var i = 0; i < annotation.sentences.length; i++) {
|
| 185 |
+
var s = annotation.sentences[i];
|
| 186 |
+
if (s && s.parseTree) {
|
| 187 |
+
parses.push(s.parseTree);
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
this.showParses(parses);
|
| 191 |
+
};
|
| 192 |
+
|
| 193 |
+
ParseViewer.prototype.onResize = function () {
|
| 194 |
+
var canvasWidth = this.getAutoWidth();
|
| 195 |
+
var canvasHeight = this.getAutoHeight();
|
| 196 |
+
var svg = this.parseElem;
|
| 197 |
+
|
| 198 |
+
// Center the graph
|
| 199 |
+
var svgGroup = svg.select('g');
|
| 200 |
+
if (svgGroup && this.graph) {
|
| 201 |
+
if (!this.graphRendered) {
|
| 202 |
+
svg.attr({'width': canvasWidth, 'height': canvasHeight});
|
| 203 |
+
svg.style({'width': canvasWidth, 'height': canvasHeight});
|
| 204 |
+
this.renderGraph(svg, this.graph, this.parse);
|
| 205 |
+
} else {
|
| 206 |
+
this.updateGraphPosition(svg, this.graph, canvasWidth, canvasHeight);
|
| 207 |
+
}
|
| 208 |
+
} else {
|
| 209 |
+
svg.attr({'width': canvasWidth, 'height': canvasHeight});
|
| 210 |
+
svg.style({'width': canvasWidth, 'height': canvasHeight});
|
| 211 |
+
}
|
| 212 |
+
};
|
| 213 |
+
|
| 214 |
+
// Exports
|
| 215 |
+
//module.exports = ParseViewer;
|
stanza/stanza/pipeline/external/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/classifiers/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/classifiers/test_classifier.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
import stanza
|
| 10 |
+
import stanza.models.classifier as classifier
|
| 11 |
+
import stanza.models.classifiers.data as data
|
| 12 |
+
from stanza.models.classifiers.trainer import Trainer
|
| 13 |
+
from stanza.models.common import pretrain
|
| 14 |
+
from stanza.models.common import utils
|
| 15 |
+
|
| 16 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 17 |
+
from stanza.tests.classifiers.test_data import train_file, dev_file, test_file, DATASET, SENTENCES
|
| 18 |
+
|
| 19 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 20 |
+
|
| 21 |
+
EMB_DIM = 5
|
| 22 |
+
|
| 23 |
+
@pytest.fixture(scope="module")
|
| 24 |
+
def fake_embeddings(tmp_path_factory):
|
| 25 |
+
"""
|
| 26 |
+
will return a path to a fake embeddings file with the words in SENTENCES
|
| 27 |
+
"""
|
| 28 |
+
# could set np random seed here
|
| 29 |
+
words = sorted(set([x.lower() for y in SENTENCES for x in y]))
|
| 30 |
+
words = words[:-1]
|
| 31 |
+
embedding_dir = tmp_path_factory.mktemp("data")
|
| 32 |
+
embedding_txt = embedding_dir / "embedding.txt"
|
| 33 |
+
embedding_pt = embedding_dir / "embedding.pt"
|
| 34 |
+
embedding = np.random.random((len(words), EMB_DIM))
|
| 35 |
+
|
| 36 |
+
with open(embedding_txt, "w", encoding="utf-8") as fout:
|
| 37 |
+
for word, emb in zip(words, embedding):
|
| 38 |
+
fout.write(word)
|
| 39 |
+
fout.write("\t")
|
| 40 |
+
fout.write("\t".join(str(x) for x in emb))
|
| 41 |
+
fout.write("\n")
|
| 42 |
+
|
| 43 |
+
pt = pretrain.Pretrain(str(embedding_pt), str(embedding_txt))
|
| 44 |
+
pt.load()
|
| 45 |
+
assert os.path.exists(embedding_pt)
|
| 46 |
+
return embedding_pt
|
| 47 |
+
|
| 48 |
+
class TestClassifier:
|
| 49 |
+
def build_model(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None):
|
| 50 |
+
"""
|
| 51 |
+
Build a model to be used by one of the later tests
|
| 52 |
+
"""
|
| 53 |
+
save_dir = str(tmp_path / "classifier")
|
| 54 |
+
save_name = "model.pt"
|
| 55 |
+
args = ["--save_dir", save_dir,
|
| 56 |
+
"--save_name", save_name,
|
| 57 |
+
"--wordvec_pretrain_file", str(fake_embeddings),
|
| 58 |
+
"--filter_channels", "20",
|
| 59 |
+
"--fc_shapes", "20,10",
|
| 60 |
+
"--train_file", str(train_file),
|
| 61 |
+
"--dev_file", str(dev_file),
|
| 62 |
+
"--max_epochs", "2",
|
| 63 |
+
"--batch_size", "60"]
|
| 64 |
+
if extra_args is not None:
|
| 65 |
+
args = args + extra_args
|
| 66 |
+
args = classifier.parse_args(args)
|
| 67 |
+
train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len)
|
| 68 |
+
if checkpoint_file:
|
| 69 |
+
trainer = Trainer.load(checkpoint_file, args, load_optimizer=True)
|
| 70 |
+
else:
|
| 71 |
+
trainer = Trainer.build_new_model(args, train_set)
|
| 72 |
+
return trainer, train_set, args
|
| 73 |
+
|
| 74 |
+
def run_training(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None):
|
| 75 |
+
"""
|
| 76 |
+
Iterate a couple times over a model
|
| 77 |
+
"""
|
| 78 |
+
trainer, train_set, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args, checkpoint_file)
|
| 79 |
+
dev_set = data.read_dataset(args.dev_file, args.wordvec_type, args.min_train_len)
|
| 80 |
+
labels = data.dataset_labels(train_set)
|
| 81 |
+
|
| 82 |
+
save_filename = os.path.join(args.save_dir, args.save_name)
|
| 83 |
+
if checkpoint_file is None:
|
| 84 |
+
checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name)
|
| 85 |
+
classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels)
|
| 86 |
+
return trainer, save_filename, checkpoint_file
|
| 87 |
+
|
| 88 |
+
def test_build_model(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 89 |
+
"""
|
| 90 |
+
Test that building a basic model works
|
| 91 |
+
"""
|
| 92 |
+
self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])
|
| 93 |
+
|
| 94 |
+
def test_save_load(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 95 |
+
"""
|
| 96 |
+
Test that a basic model can save & load
|
| 97 |
+
"""
|
| 98 |
+
trainer, _, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])
|
| 99 |
+
|
| 100 |
+
save_filename = os.path.join(args.save_dir, args.save_name)
|
| 101 |
+
trainer.save(save_filename)
|
| 102 |
+
|
| 103 |
+
args.load_name = args.save_name
|
| 104 |
+
trainer = Trainer.load(args.load_name, args)
|
| 105 |
+
args.load_name = save_filename
|
| 106 |
+
trainer = Trainer.load(args.load_name, args)
|
| 107 |
+
|
| 108 |
+
def test_train_basic(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 109 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"])
|
| 110 |
+
|
| 111 |
+
def test_train_bilstm(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 112 |
+
"""
|
| 113 |
+
Test w/ and w/o bilstm variations of the classifier
|
| 114 |
+
"""
|
| 115 |
+
args = ["--bilstm", "--bilstm_hidden_dim", "20"]
|
| 116 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 117 |
+
|
| 118 |
+
args = ["--no_bilstm"]
|
| 119 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 120 |
+
|
| 121 |
+
def test_train_maxpool_width(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 122 |
+
"""
|
| 123 |
+
Test various maxpool widths
|
| 124 |
+
|
| 125 |
+
Also sets --filter_channels to a multiple of 2 but not of 3 for
|
| 126 |
+
the test to make sure the math is done correctly on a non-divisible width
|
| 127 |
+
"""
|
| 128 |
+
args = ["--maxpool_width", "1", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
|
| 129 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 130 |
+
|
| 131 |
+
args = ["--maxpool_width", "2", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
|
| 132 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 133 |
+
|
| 134 |
+
args = ["--maxpool_width", "3", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
|
| 135 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 136 |
+
|
| 137 |
+
def test_train_conv_2d(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 138 |
+
args = ["--filter_sizes", "(3,4,5)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
|
| 139 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 140 |
+
|
| 141 |
+
args = ["--filter_sizes", "((3,2),)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
|
| 142 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 143 |
+
|
| 144 |
+
args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"]
|
| 145 |
+
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 146 |
+
|
| 147 |
+
def test_train_filter_channels(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 148 |
+
args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--no_bilstm"]
|
| 149 |
+
trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 150 |
+
assert trainer.model.fc_input_size == 40
|
| 151 |
+
|
| 152 |
+
args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "15,20", "--no_bilstm"]
|
| 153 |
+
trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
|
| 154 |
+
# 50 = 2x15 for the 2d conv (over 5 dim embeddings) + 20
|
| 155 |
+
assert trainer.model.fc_input_size == 50
|
| 156 |
+
|
| 157 |
+
def test_train_bert(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 158 |
+
"""
|
| 159 |
+
Test on a tiny Bert WITHOUT finetuning, which hopefully does not take up too much disk space or memory
|
| 160 |
+
"""
|
| 161 |
+
bert_model = "hf-internal-testing/tiny-bert"
|
| 162 |
+
|
| 163 |
+
trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model])
|
| 164 |
+
assert os.path.exists(save_filename)
|
| 165 |
+
saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
|
| 166 |
+
# check that the bert model wasn't saved as part of the classifier
|
| 167 |
+
assert not saved_model['params']['config']['force_bert_saved']
|
| 168 |
+
assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())
|
| 169 |
+
|
| 170 |
+
def test_finetune_bert(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 171 |
+
"""
|
| 172 |
+
Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory
|
| 173 |
+
"""
|
| 174 |
+
bert_model = "hf-internal-testing/tiny-bert"
|
| 175 |
+
|
| 176 |
+
trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune"])
|
| 177 |
+
assert os.path.exists(save_filename)
|
| 178 |
+
saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
|
| 179 |
+
# after finetuning the bert model, make sure that the save file DOES contain parts of the transformer
|
| 180 |
+
assert saved_model['params']['config']['force_bert_saved']
|
| 181 |
+
assert any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())
|
| 182 |
+
|
| 183 |
+
def test_finetune_bert_layers(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 184 |
+
"""Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory, using 2 layers
|
| 185 |
+
|
| 186 |
+
As an added bonus (or eager test), load the finished model and continue
|
| 187 |
+
training from there. Then check that the initial model and
|
| 188 |
+
the middle model are different, then that the middle model and
|
| 189 |
+
final model are different
|
| 190 |
+
|
| 191 |
+
"""
|
| 192 |
+
bert_model = "hf-internal-testing/tiny-bert"
|
| 193 |
+
|
| 194 |
+
trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models"])
|
| 195 |
+
assert os.path.exists(save_filename)
|
| 196 |
+
|
| 197 |
+
save_path = os.path.split(save_filename)[0]
|
| 198 |
+
|
| 199 |
+
initial_model = glob.glob(os.path.join(save_path, "*E0000*"))
|
| 200 |
+
assert len(initial_model) == 1
|
| 201 |
+
initial_model = initial_model[0]
|
| 202 |
+
initial_model = torch.load(initial_model, lambda storage, loc: storage, weights_only=True)
|
| 203 |
+
|
| 204 |
+
second_model_file = glob.glob(os.path.join(save_path, "*E0002*"))
|
| 205 |
+
assert len(second_model_file) == 1
|
| 206 |
+
second_model_file = second_model_file[0]
|
| 207 |
+
second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)
|
| 208 |
+
|
| 209 |
+
for layer_idx in range(2):
|
| 210 |
+
bert_names = [x for x in second_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." % layer_idx in x]
|
| 211 |
+
assert len(bert_names) > 0
|
| 212 |
+
assert all(x in initial_model['params']['model'] and x in second_model['params']['model'] for x in bert_names)
|
| 213 |
+
assert not all(torch.allclose(initial_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names)
|
| 214 |
+
|
| 215 |
+
# put some random marker in the file to look for later,
|
| 216 |
+
# check the continued training didn't clobber the expected file
|
| 217 |
+
assert "asdf" not in second_model
|
| 218 |
+
second_model["asdf"] = 1234
|
| 219 |
+
torch.save(second_model, second_model_file)
|
| 220 |
+
|
| 221 |
+
trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file)
|
| 222 |
+
|
| 223 |
+
second_model_file_redo = glob.glob(os.path.join(save_path, "*E0002*"))
|
| 224 |
+
assert len(second_model_file_redo) == 1
|
| 225 |
+
assert second_model_file == second_model_file_redo[0]
|
| 226 |
+
second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)
|
| 227 |
+
assert "asdf" in second_model
|
| 228 |
+
|
| 229 |
+
fifth_model_file = glob.glob(os.path.join(save_path, "*E0005*"))
|
| 230 |
+
assert len(fifth_model_file) == 1
|
| 231 |
+
|
| 232 |
+
final_model = torch.load(fifth_model_file[0], lambda storage, loc: storage, weights_only=True)
|
| 233 |
+
for layer_idx in range(2):
|
| 234 |
+
bert_names = [x for x in final_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." % layer_idx in x]
|
| 235 |
+
assert len(bert_names) > 0
|
| 236 |
+
assert all(x in final_model['params']['model'] and x in second_model['params']['model'] for x in bert_names)
|
| 237 |
+
assert not all(torch.allclose(final_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names)
|
| 238 |
+
|
| 239 |
+
def test_finetune_peft(self, tmp_path, fake_embeddings, train_file, dev_file):
|
| 240 |
+
"""
|
| 241 |
+
Test on a tiny Bert with PEFT finetuning
|
| 242 |
+
"""
|
| 243 |
+
bert_model = "hf-internal-testing/tiny-bert"
|
| 244 |
+
|
| 245 |
+
trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler"])
|
| 246 |
+
assert os.path.exists(save_filename)
|
| 247 |
+
saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True)
|
| 248 |
+
# after finetuning the bert model, make sure that the save file DOES contain parts of the transformer, but only in peft form
|
| 249 |
+
assert saved_model['params']['config']['bert_model'] == bert_model
|
| 250 |
+
assert saved_model['params']['config']['force_bert_saved']
|
| 251 |
+
assert saved_model['params']['config']['use_peft']
|
| 252 |
+
|
| 253 |
+
assert not saved_model['params']['config']['has_charlm_forward']
|
| 254 |
+
assert not saved_model['params']['config']['has_charlm_backward']
|
| 255 |
+
|
| 256 |
+
assert len(saved_model['params']['bert_lora']) > 0
|
| 257 |
+
assert any(x.find(".pooler.") >= 0 for x in saved_model['params']['bert_lora'])
|
| 258 |
+
assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora'])
|
| 259 |
+
assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys())
|
| 260 |
+
|
| 261 |
+
# The Pipeline should load and run a PEFT trained model,
|
| 262 |
+
# although obviously we don't expect the results to do
|
| 263 |
+
# anything correct
|
| 264 |
+
pipeline = stanza.Pipeline("en", download_method=None, model_dir=TEST_MODELS_DIR, processors="tokenize,sentiment", sentiment_model_path=save_filename, sentiment_pretrain_path=str(fake_embeddings))
|
| 265 |
+
doc = pipeline("This is a test")
|
| 266 |
+
|
| 267 |
+
    def test_finetune_peft_restart(self, tmp_path, fake_embeddings, train_file, dev_file):
        """
        Test that if we restart training on a peft model, the peft weights change

        Runs a training, then restarts from the checkpoint with more
        epochs, and compares the LoRA tensors saved at epoch 0, epoch 2,
        and epoch 5 of the intermediate model files.
        """
        bert_model = "hf-internal-testing/tiny-bert"

        trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models"])

        assert os.path.exists(save_file)
        saved_model = torch.load(save_file, lambda storage, loc: storage, weights_only=True)
        assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora'])


        # restart from the checkpoint, this time training to epoch 5
        trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file)

        save_path = os.path.split(save_file)[0]

        # intermediate model files are named with an E%04d epoch marker
        initial_model_file = glob.glob(os.path.join(save_path, "*E0000*"))
        assert len(initial_model_file) == 1
        initial_model_file = initial_model_file[0]
        initial_model = torch.load(initial_model_file, lambda storage, loc: storage, weights_only=True)

        second_model_file = glob.glob(os.path.join(save_path, "*E0002*"))
        assert len(second_model_file) == 1
        second_model_file = second_model_file[0]
        second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True)

        final_model_file = glob.glob(os.path.join(save_path, "*E0005*"))
        assert len(final_model_file) == 1
        final_model_file = final_model_file[0]
        final_model = torch.load(final_model_file, lambda storage, loc: storage, weights_only=True)

        # params in initial_model & second_model start with "base_model.model."
        # whereas params in final_model start directly with "encoder" or "pooler"
        initial_lora = initial_model['params']['bert_lora']
        second_lora = second_model['params']['bert_lora']
        final_lora = final_model['params']['bert_lora']
        # compare the LoRA A and B matrices of encoder layers 0 and 1
        for side in ("_A.", "_B."):
            for layer in (".0.", ".1."):
                initial_params = sorted([x for x in initial_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0])
                second_params = sorted([x for x in second_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0])
                final_params = sorted([x for x in final_lora if x.startswith("encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0])
                assert len(initial_params) > 0
                assert len(initial_params) == len(second_params)
                assert len(initial_params) == len(final_params)
                for x, y in zip(second_params, final_params):
                    # same tensor name modulo the "base_model.model." prefix
                    assert x.endswith(y)
                    if side != "_A.": # the A tensors don't move very much, if at all
                        assert not torch.allclose(initial_lora.get(x), second_lora.get(x))
                        assert not torch.allclose(second_lora.get(x), final_lora.get(y))
stanza/stanza/tests/classifiers/test_process_utils.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A few tests of the utils module for the sentiment datasets
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
import stanza
|
| 9 |
+
|
| 10 |
+
from stanza.models.classifiers import data
|
| 11 |
+
from stanza.models.classifiers.data import SentimentDatum
|
| 12 |
+
from stanza.models.classifiers.utils import WVType
|
| 13 |
+
from stanza.utils.datasets.sentiment import process_utils
|
| 14 |
+
|
| 15 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 16 |
+
from stanza.tests.classifiers.test_data import train_file, dev_file, test_file
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_write_list(tmp_path, train_file):
    """
    Round-trip a single list of items through write_list / read_dataset
    """
    original = data.read_dataset(train_file, WVType.OTHER, 1)

    output_file = tmp_path / "foo.json"
    process_utils.write_list(output_file, original)

    reread = data.read_dataset(output_file, WVType.OTHER, 1)
    assert original == reread
def test_write_dataset(tmp_path, train_file, dev_file, test_file):
    """
    Writing all three splits of a dataset produces exactly the expected files
    """
    splits = [data.read_dataset(filename, WVType.OTHER, 1)
              for filename in (train_file, dev_file, test_file)]
    process_utils.write_dataset(splits, tmp_path, "en_test")

    expected_files = ['en_test.train.json', 'en_test.dev.json', 'en_test.test.json']
    # no extra files should appear in the output directory
    assert sorted(os.listdir(tmp_path)) == sorted(expected_files)

    # each written split reads back equal to the original
    for name, split in zip(expected_files, splits):
        assert data.read_dataset(tmp_path / name, WVType.OTHER, 1) == split
def test_read_snippets(tmp_path):
    """
    Basic operation of read_snippets on a two line TSV file
    """
    csv_file = tmp_path / "foo.csv"
    rows = ["FOO\tThis is a test\thappy\n",
            "FOO\tThis is a second sentence\tsad\n"]
    with open(csv_file, "w", encoding="utf-8") as fout:
        fout.writelines(rows)

    nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)

    mapping = {"happy": 0, "sad": 1}

    # sentiment is in column 2, text in column 1
    snippets = process_utils.read_snippets(csv_file, 2, 1, "en", mapping, nlp=nlp)
    expected = [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),
                SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence'])]
    assert len(snippets) == 2
    assert snippets == expected
def test_read_snippets_two_columns(tmp_path):
    """
    Multiple columns can be combined into a tuple key for the sentiment mapping
    """
    csv_file = tmp_path / "foo.csv"
    rows = ["FOO\tThis is a test\thappy\tfoo\n",
            "FOO\tThis is a second sentence\tsad\tbar\n",
            "FOO\tThis is a third sentence\tsad\tfoo\n"]
    with open(csv_file, "w", encoding="utf-8") as fout:
        fout.writelines(rows)

    nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)

    # the mapping keys are tuples of the values in columns 2 and 3
    mapping = {("happy", "foo"): 0, ("sad", "bar"): 1, ("sad", "foo"): 2}

    snippets = process_utils.read_snippets(csv_file, (2,3), 1, "en", mapping, nlp=nlp)
    expected = [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),
                SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence']),
                SentimentDatum(sentiment=2, text=['This', 'is', 'a', 'third', 'sentence'])]
    assert len(snippets) == 3
    assert snippets == expected
stanza/stanza/tests/common/test_bert_embedding.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings
|
| 5 |
+
|
| 6 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 7 |
+
|
| 8 |
+
BERT_MODEL = "hf-internal-testing/tiny-bert"
|
| 9 |
+
|
| 10 |
+
@pytest.fixture(scope="module")
def tiny_bert():
    """Load the tiny test bert (model, tokenizer) once per module"""
    return load_bert(BERT_MODEL)
def test_load_bert(tiny_bert):
    """
    Loading the fixture is the whole test - just unpack it
    """
    model, tokenizer = tiny_bert
def test_run_bert(tiny_bert):
    """Extracting embeddings for a short sentence runs without error"""
    model, tokenizer = tiny_bert
    device = next(model.parameters()).device
    extract_bert_embeddings(BERT_MODEL, tokenizer, model, [["This", "is", "a", "test"]], device, True)
def test_run_bert_empty_word(tiny_bert):
    """An empty word should be embedded the same way as a "-" placeholder"""
    model, tokenizer = tiny_bert
    device = next(model.parameters()).device
    with_dash = extract_bert_embeddings(BERT_MODEL, tokenizer, model, [["This", "is", "-", "a", "test"]], device, True)
    with_empty = extract_bert_embeddings(BERT_MODEL, tokenizer, model, [["This", "is", "", "a", "test"]], device, True)

    assert len(with_dash) == 1
    assert torch.allclose(with_dash[0], with_empty[0])
stanza/stanza/tests/common/test_char_model.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Currently tests a few configurations of files for creating a charlm vocab
|
| 3 |
+
|
| 4 |
+
Also has a skeleton test of loading & saving a charlm
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from collections import Counter
|
| 8 |
+
import glob
|
| 9 |
+
import lzma
|
| 10 |
+
import os
|
| 11 |
+
import tempfile
|
| 12 |
+
|
| 13 |
+
import pytest
|
| 14 |
+
|
| 15 |
+
from stanza.models import charlm
|
| 16 |
+
from stanza.models.common import char_model
|
| 17 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 18 |
+
|
| 19 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 20 |
+
|
| 21 |
+
# Two tiny text samples used to build charlm vocabs in the tests below.
# Neither sample contains the letter "Q", which the vocab tests use as a
# negative example.
fake_text_1 = """
Unban mox opal!
I hate watching Peppa Pig
"""

fake_text_2 = """
This is plastic cheese
"""
+
class TestCharModel:
    """
    Tests charlm vocab construction from various file layouts, a short
    end-to-end training run, and save/load round trips of trained models.
    """
    def test_single_file_vocab(self):
        """A vocab built from one plaintext file contains exactly its characters"""
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "text.txt")
            with open(sample_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            vocab = char_model.build_charlm_vocab(sample_file)

        for i in fake_text_1:
            assert i in vocab
        # "Q" never appears in the sample text
        assert "Q" not in vocab

    def test_single_file_xz_vocab(self):
        """build_charlm_vocab transparently reads xz compressed files"""
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "text.txt.xz")
            with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            vocab = char_model.build_charlm_vocab(sample_file)

        for i in fake_text_1:
            assert i in vocab
        assert "Q" not in vocab

    def test_single_file_dir_vocab(self):
        """A directory containing a single file also works as the vocab source"""
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "text.txt")
            with open(sample_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            vocab = char_model.build_charlm_vocab(tempdir)

        for i in fake_text_1:
            assert i in vocab
        assert "Q" not in vocab

    def test_multiple_files_vocab(self):
        """All files in a directory (mixed plain & xz) contribute to the vocab"""
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "t1.txt")
            with open(sample_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            sample_file = os.path.join(tempdir, "t2.txt.xz")
            with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
                fout.write(fake_text_2)
            vocab = char_model.build_charlm_vocab(tempdir)

        for i in fake_text_1:
            assert i in vocab
        for i in fake_text_2:
            assert i in vocab
        assert "Q" not in vocab

    def test_cutoff_vocab(self):
        """Characters occurring fewer than cutoff times are excluded from the vocab"""
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "t1.txt")
            with open(sample_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            sample_file = os.path.join(tempdir, "t2.txt.xz")
            with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
                fout.write(fake_text_2)

            vocab = char_model.build_charlm_vocab(tempdir, cutoff=2)

        # recompute the combined character counts and check against the cutoff
        counts = Counter(fake_text_1) + Counter(fake_text_2)
        for letter, count in counts.most_common():
            if count < 2:
                assert letter not in vocab
            else:
                assert letter in vocab

    def test_build_model(self):
        """
        Test the whole thing on a small dataset for an iteration or two

        Trains for two epochs, then checks the saved model, the saved
        checkpoint, the saved vocab, and get_current_lr on both trained
        and freshly initialized trainers.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            eval_file = os.path.join(tempdir, "en_test.dev.txt")
            with open(eval_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            train_file = os.path.join(tempdir, "en_test.train.txt")
            with open(train_file, "w", encoding="utf-8") as fout:
                # repeat the sample texts to get enough training data
                for i in range(1000):
                    fout.write(fake_text_1)
                    fout.write("\n")
                    fout.write(fake_text_2)
                    fout.write("\n")
            save_name = 'en_test.forward.pt'
            vocab_save_name = 'en_text.vocab.pt'
            checkpoint_save_name = 'en_text.checkpoint.pt'
            args = ['--train_file', train_file,
                    '--eval_file', eval_file,
                    '--eval_steps', '0', # eval once per epoch
                    '--epochs', '2',
                    '--cutoff', '1',
                    '--batch_size', '%d' % len(fake_text_1),
                    '--shorthand', 'en_test',
                    '--save_dir', tempdir,
                    '--save_name', save_name,
                    '--vocab_save_name', vocab_save_name,
                    '--checkpoint_save_name', checkpoint_save_name]
            args = charlm.parse_args(args)
            charlm.train(args)

            assert os.path.exists(os.path.join(tempdir, vocab_save_name))

            # test that saving & loading of the model worked
            assert os.path.exists(os.path.join(tempdir, save_name))
            model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, save_name))

            # test that saving & loading of the checkpoint worked
            assert os.path.exists(os.path.join(tempdir, checkpoint_save_name))
            model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, checkpoint_save_name))
            trainer = char_model.CharacterLanguageModelTrainer.load(args, os.path.join(tempdir, checkpoint_save_name))

            assert trainer.global_step > 0
            assert trainer.epoch == 2

            # quick test to verify this method works with a trained model
            charlm.get_current_lr(trainer, args)

            # test loading a vocab built by the training method...
            vocab = charlm.load_char_vocab(os.path.join(tempdir, vocab_save_name))
            trainer = char_model.CharacterLanguageModelTrainer.from_new_model(args, vocab)
            # ... and test the get_current_lr for an untrained model as well
            # this test is super "eager"
            assert charlm.get_current_lr(trainer, args) == args['lr0']

    @pytest.fixture(scope="class")
    def english_forward(self):
        # eg, stanza_test/models/en/forward_charlm/1billion.pt
        models_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "*")
        models = glob.glob(models_path)
        # we expect at least one English model downloaded for the tests
        assert len(models) >= 1
        model_file = models[0]
        return char_model.CharacterLanguageModel.load(model_file)

    @pytest.fixture(scope="class")
    def english_backward(self):
        # eg, stanza_test/models/en/backward_charlm/1billion.pt
        models_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "*")
        models = glob.glob(models_path)
        # we expect at least one English model downloaded for the tests
        assert len(models) >= 1
        model_file = models[0]
        return char_model.CharacterLanguageModel.load(model_file)

    def test_load_model(self, english_forward, english_backward):
        """
        Check that basic loading functions work
        """
        assert english_forward.is_forward_lm
        assert not english_backward.is_forward_lm

    def test_save_load_model(self, english_forward, english_backward):
        """
        Load, save, and load again
        """
        with tempfile.TemporaryDirectory() as tempdir:
            for model in (english_forward, english_backward):
                save_file = os.path.join(tempdir, "resaved", "charlm.pt")
                model.save(save_file)
                reloaded = char_model.CharacterLanguageModel.load(save_file)
                assert model.is_forward_lm == reloaded.is_forward_lm
stanza/stanza/tests/common/test_common_data.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import stanza
|
| 3 |
+
|
| 4 |
+
from stanza.tests import *
|
| 5 |
+
from stanza.models.common.data import get_augment_ratio, augment_punct
|
| 6 |
+
|
| 7 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 8 |
+
|
| 9 |
+
def test_augment_ratio():
    """
    Check get_augment_ratio on a small list of ints
    """
    items = list(range(1, 11))
    needs_augment = lambda x: x >= 3
    augmentable = lambda x: x >= 4
    # zero is returned when no augmentation is needed: two items (1, 2)
    # already fail needs_augment, which meets the 0.1 desired ratio
    assert get_augment_ratio(items, needs_augment, augmentable, desired_ratio=0.1) == 0.0

    # swapping the predicates violates the function's precondition
    with pytest.raises(AssertionError):
        get_augment_ratio(items, augmentable, needs_augment)

    # with a desired ratio of 0.4: two items are already satisfactory and
    # seven are eligible for augmentation, so 2/7 of those must be augmented
    assert get_augment_ratio(items, needs_augment, augmentable, desired_ratio=0.4) == pytest.approx(2/7)
def test_augment_punct():
    """augment_punct with ratio 1.0 strips the trailing period from the sentence"""
    sentences = [["Simple", "test", "."]]
    ends_with_period = lambda words: words[-1] == "."
    augmented = augment_punct(sentences, 1.0, ends_with_period, ends_with_period)
    assert augmented == [["Simple", "test"]]
stanza/stanza/tests/common/test_data_objects.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic tests of the stanza data objects, especially the setter/getter routines
|
| 3 |
+
"""
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
import stanza
|
| 7 |
+
from stanza.models.common.doc import Document, Sentence, Word
|
| 8 |
+
from stanza.tests import *
|
| 9 |
+
|
| 10 |
+
pytestmark = pytest.mark.pipeline
|
| 11 |
+
|
| 12 |
+
# data for testing
|
| 13 |
+
EN_DOC = "This is a test document. Pretty cool!"
|
| 14 |
+
|
| 15 |
+
EN_DOC_UPOS_XPOS = (('PRON_DT', 'AUX_VBZ', 'DET_DT', 'NOUN_NN', 'NOUN_NN', 'PUNCT_.'), ('ADV_RB', 'ADJ_JJ', 'PUNCT_.'))
|
| 16 |
+
|
| 17 |
+
EN_DOC2 = "Chris Manning wrote a sentence. Then another."
|
| 18 |
+
|
| 19 |
+
@pytest.fixture(scope="module")
def nlp_pipeline():
    """Build the default English pipeline once for all tests in this module"""
    return stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en')
def test_readonly(nlp_pipeline):
    """
    A property added with a plain value (no setter) is read-only
    """
    Document.add_property('some_property', 123)
    doc = nlp_pipeline(EN_DOC)
    assert doc.some_property == 123
    # assigning to a property registered without a setter raises ValueError
    with pytest.raises(ValueError):
        doc.some_property = 456
def test_getter(nlp_pipeline):
    """
    A property with a custom getter computes its value from the Word
    """
    # combine the upos & xpos tags into a single string per word
    Word.add_property('upos_xpos', getter=lambda self: f"{self.upos}_{self.xpos}")

    doc = nlp_pipeline(EN_DOC)

    # expected tags for EN_DOC are listed in EN_DOC_UPOS_XPOS above
    assert EN_DOC_UPOS_XPOS == tuple(tuple(word.upos_xpos for word in sentence.words) for sentence in doc.sentences)
def test_setter_getter(nlp_pipeline):
    """
    A property with both getter & setter converts between representations
    """
    int2str = {0: 'ok', 1: 'good', 2: 'bad'}
    str2int = {'ok': 0, 'good': 1, 'bad': 2}
    def setter(self, value):
        # store the string label as its int code on the backing field
        self._classname = str2int[value]
    Sentence.add_property('classname', getter=lambda self: int2str[self._classname] if self._classname is not None else None, setter=setter)

    doc = nlp_pipeline(EN_DOC)
    sentence = doc.sentences[0]
    sentence.classname = 'good'
    # the setter stored the int code
    assert sentence._classname == 1

    # don't try this at home
    sentence._classname = 2
    assert sentence.classname == 'bad'
def test_backpointer(nlp_pipeline):
    """Entities, words, and tokens all point back at their containing sentences"""
    doc = nlp_pipeline(EN_DOC2)
    first_sentence = doc.sentences[0]
    assert doc.ents[0].sent is first_sentence
    all_words = list(doc.iter_words())
    assert all_words[0].sent is first_sentence
    all_tokens = list(doc.iter_tokens())
    assert all_tokens[-1].sent is doc.sentences[-1]
stanza/stanza/tests/common/test_doc.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
import stanza
|
| 4 |
+
from stanza.tests import *
|
| 5 |
+
from stanza.models.common.doc import Document, ID, TEXT, NER, CONSTITUENCY, SENTIMENT
|
| 6 |
+
|
| 7 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 8 |
+
|
| 9 |
+
@pytest.fixture
def sentences_dict():
    """Two sentences of raw token dicts for building a Document"""
    first = [{ID: 1, TEXT: "unban"},
             {ID: 2, TEXT: "mox"},
             {ID: 3, TEXT: "opal"}]
    second = [{ID: 4, TEXT: "ban"},
              {ID: 5, TEXT: "Lurrus"}]
    return [first, second]
@pytest.fixture
def doc(sentences_dict):
    """A Document built from the raw sentence dicts"""
    return Document(sentences_dict)
def test_basic_values(doc, sentences_dict):
    """
    Sentences & token text are properly set when constructing a doc
    """
    assert len(doc.sentences) == len(sentences_dict)

    for sentence, raw_sentence in zip(doc.sentences, sentences_dict):
        assert sentence.doc == doc
        assert len(sentence.tokens) == len(raw_sentence)
        # every token keeps the text it was constructed with
        assert all(token.text == raw_token[TEXT]
                   for token, raw_token in zip(sentence.tokens, raw_sentence))
def test_set_sentence(doc):
    """
    Setting a field directly on the sentences via Document.set
    """
    sentiments = ["4", "0"]
    doc.set(fields="sentiment", contents=sentiments, to_sentence=True)

    for sentence, expected in zip(doc.sentences, sentiments):
        assert sentence.sentiment == expected
def test_set_tokens(doc):
    """
    Values set on tokens read back unchanged
    """
    tags = ["O", "ARTIFACT", "ARTIFACT", "O", "CAT"]
    doc.set(fields=NER, contents=tags, to_token=True)

    assert doc.get(NER, from_token=True) == tags
def test_constituency_comment(doc):
    """
    Setting the constituency tree on a doc (re)sets the constituency comment
    """
    for sentence in doc.sentences:
        assert not any(x.startswith("# constituency") for x in sentence.comments)

    # currently nothing is checking that the items are actually trees;
    # setting a second batch verifies that the updated trees replace the
    # old comments rather than accumulating
    for trees in (["asdf", "zzzz"], ["zzzz", "asdf"]):
        doc.set(fields=CONSTITUENCY,
                contents=trees,
                to_sentence=True)

        for sentence, expected in zip(doc.sentences, trees):
            matching = [x for x in sentence.comments if x.startswith("# constituency")]
            assert len(matching) == 1
            assert matching[0].endswith(expected)
def test_sentiment_comment(doc):
    """
    Setting the sentiment on a doc (re)sets the sentiment comment
    """
    for sentence in doc.sentences:
        assert not any(x.startswith("# sentiment") for x in sentence.comments)

    # setting a second batch verifies that the updated sentiments replace
    # the old comments rather than accumulating
    for sentiments in (["1", "2"], ["3", "4"]):
        doc.set(fields=SENTIMENT,
                contents=sentiments,
                to_sentence=True)

        for sentence, expected in zip(doc.sentences, sentiments):
            matching = [x for x in sentence.comments if x.startswith("# sentiment")]
            assert len(matching) == 1
            assert matching[0].endswith(expected)
def test_sent_id_comment(doc):
    """
    Test that setting the sent_id on a sentence sets the sent_id comment

    Covers the default ids, direct assignment, reindex_sentences, and
    adding a sent_id comment directly via add_comment.
    """
    # freshly constructed docs get 0-based sent_ids with matching comments
    for sent_idx, sentence in enumerate(doc.sentences):
        assert len([x for x in sentence.comments if x.startswith("# sent_id")]) == 1
        assert sentence.sent_id == "%d" % sent_idx
    doc.sentences[0].sent_id = "foo"
    assert doc.sentences[0].sent_id == "foo"
    assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1
    assert "# sent_id = foo" in doc.sentences[0].comments

    doc.reindex_sentences(10)
    for sent_idx, sentence in enumerate(doc.sentences):
        assert sentence.sent_id == "%d" % (sent_idx + 10)
        # check each sentence's comments, not just sentence 0
        # (previously this line inspected doc.sentences[0] on every iteration)
        assert len([x for x in sentence.comments if x.startswith("# sent_id")]) == 1
        assert "# sent_id = %d" % (sent_idx + 10) in sentence.comments

    # a sent_id comment added directly should replace the old id
    doc.sentences[0].add_comment("# sent_id = bar")
    assert doc.sentences[0].sent_id == "bar"
    assert "# sent_id = bar" in doc.sentences[0].comments
    assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1
def test_doc_id_comment(doc):
    """
    Test that setting the doc_id on a sentence sets the doc_id comment
    """
    # a newly built doc has no doc_id and no doc_id comment
    assert doc.sentences[0].doc_id is None
    assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 0

    # assigning the attribute creates exactly one comment
    doc.sentences[0].doc_id = "foo"
    assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 1
    assert "# doc_id = foo" in doc.sentences[0].comments
    assert doc.sentences[0].doc_id == "foo"

    # adding a doc_id comment directly replaces the old value
    doc.sentences[0].add_comment("# doc_id = bar")
    assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 1
    assert doc.sentences[0].doc_id == "bar"
@pytest.fixture(scope="module")
def pipeline():
    """Default pipeline shared by the serialization test"""
    return stanza.Pipeline(dir=TEST_MODELS_DIR)
def test_serialized(pipeline):
    """
    Brief test of the serialized format

    Checks that NER entities are correctly set.
    Also checks that constituency & sentiment are set on the sentences.
    """
    doc = pipeline("John Bauer works at Stanford")
    assert len(doc.ents) == 2
    roundtrip = Document.from_serialized(doc.to_serialized())
    assert len(roundtrip.sentences) == 1
    assert len(roundtrip.ents) == 2
    original_sentence = doc.sentences[0]
    copied_sentence = roundtrip.sentences[0]
    assert original_sentence.constituency == copied_sentence.constituency
    assert original_sentence.sentiment == copied_sentence.sentiment
stanza/stanza/tests/common/test_dropout.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import stanza
|
| 6 |
+
from stanza.models.common.dropout import WordDropout
|
| 7 |
+
|
| 8 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 9 |
+
|
| 10 |
+
def test_word_dropout():
|
| 11 |
+
"""
|
| 12 |
+
Test that word_dropout is randomly dropping out the entire final dimension of a tensor
|
| 13 |
+
|
| 14 |
+
Doing 600 small rows should be super fast, but it leaves us with
|
| 15 |
+
something like a 1 in 10^180 chance of the test failing. Not very
|
| 16 |
+
common, in other words
|
| 17 |
+
"""
|
| 18 |
+
wd = WordDropout(0.5)
|
| 19 |
+
batch = torch.randn(600, 4)
|
| 20 |
+
dropped = wd(batch)
|
| 21 |
+
# the one time any of this happens, it's going to be really confusing
|
| 22 |
+
assert not torch.allclose(batch, dropped)
|
| 23 |
+
num_zeros = 0
|
| 24 |
+
for i in range(batch.shape[0]):
|
| 25 |
+
assert torch.allclose(dropped[i], batch[i]) or torch.sum(dropped[i]) == 0.0
|
| 26 |
+
if torch.sum(dropped[i]) == 0.0:
|
| 27 |
+
num_zeros += 1
|
| 28 |
+
assert num_zeros > 0 and num_zeros < batch.shape[0]
|
stanza/stanza/tests/common/test_short_name_to_treebank.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
import stanza
|
| 4 |
+
from stanza.models.common import short_name_to_treebank
|
| 5 |
+
|
| 6 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 7 |
+
|
| 8 |
+
def test_short_name():
|
| 9 |
+
assert short_name_to_treebank.short_name_to_treebank("en_ewt") == "UD_English-EWT"
|
| 10 |
+
|
| 11 |
+
def test_canonical_name():
|
| 12 |
+
assert short_name_to_treebank.canonical_treebank_name("UD_URDU-UDTB") == "UD_Urdu-UDTB"
|
| 13 |
+
assert short_name_to_treebank.canonical_treebank_name("ur_udtb") == "UD_Urdu-UDTB"
|
| 14 |
+
assert short_name_to_treebank.canonical_treebank_name("Unban_Mox_Opal") == "Unban_Mox_Opal"
|
stanza/stanza/tests/constituency/test_convert_it_vit.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a couple different classes of trees to check the output of the VIT conversion
|
| 3 |
+
|
| 4 |
+
A couple representative trees are included, but hopefully not enough
|
| 5 |
+
to be a problem in terms of our license.
|
| 6 |
+
|
| 7 |
+
One of the tests is currently disabled as it relies on tregex & tsurgeon features
|
| 8 |
+
not yet released
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import io
|
| 12 |
+
import os
|
| 13 |
+
import tempfile
|
| 14 |
+
|
| 15 |
+
import pytest
|
| 16 |
+
|
| 17 |
+
from stanza.server import tsurgeon
|
| 18 |
+
from stanza.utils.conll import CoNLL
|
| 19 |
+
from stanza.utils.datasets.constituency import convert_it_vit
|
| 20 |
+
|
| 21 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 22 |
+
|
| 23 |
+
# just a sample! don't sue us please
|
| 24 |
+
CON_SAMPLE = """
|
| 25 |
+
#ID=sent_00002 cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]]
|
| 26 |
+
|
| 27 |
+
#ID=sent_00318 dirsp-[fc-[congf-tuttavia, f-[sn-[sq-[ind-qualche], n-problema], ir_infl-[vsupir-potrebbe, vcl-esserci], compc-[clit-ci, sp-[p-per, sn-[art-la, n-commissione, sa-[ag-esteri], f2-[sp-[part-alla, relob-cui, sn-[n-presidenza]], f-[ibar-[vc-è], compc-[sn-[n-candidato], sn-[art-l, n-esponente, spd-[pd-di, sn-[mw-Alleanza, npro-Nazionale]], sn-[mw-Mirko, nh-Tremaglia]]]]]]]]]], dirs-':', f3-[sn-[art-una, n-candidatura, sc-[q-più, sa-[ppas-subìta], sc-[ccong-che, sa-[ppas-gradita]], compt-[spda-[partda-dalla, sn-[mw-Lega, npro-Nord, punt-',', f2-[rel-che, fc-[congf-tuttavia, f-[ir_infl-[vsupir-dovrebbe, vit-rispettare], compt-[sn-[art-gli, n-accordi]]]]]]]]]], punto-.]]
|
| 28 |
+
|
| 29 |
+
#ID=sent_00589 f-[sn-[art-l, n-ottimismo, spd-[pd-di, sn-[nh-Kantor]]], ir_infl-[vsupir-potrebbe, congf-però, vcl-rivelarsi], compc-[sn-[in-ancora, art-una, nt-volta], sa-[ag-prematuro]], punto-.]
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
UD_SAMPLE = """
|
| 33 |
+
# sent_id = VIT-2
|
| 34 |
+
# text = Negli ultimi anni la dinamica dei polo di attrazione è stata sempre più caratterizzata dall'emergere di una crescente concorrenza che si è progressivamente spostata dalle singole imprese ai sistemi economici e territoriali, determinando l'esigenza di una riconsiderazione dei rapporti esistenti tra soggetti produttivi e ambiente in cui questi operano.
|
| 35 |
+
1-2 Negli _ _ _ _ _ _ _ _
|
| 36 |
+
1 In in ADP E _ 4 case _ _
|
| 37 |
+
2 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 4 det _ _
|
| 38 |
+
3 ultimi ultimo ADJ A Gender=Masc|Number=Plur 4 amod _ _
|
| 39 |
+
4 anni anno NOUN S Gender=Masc|Number=Plur 16 obl _ _
|
| 40 |
+
5 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 6 det _ _
|
| 41 |
+
6 dinamica dinamica NOUN S Gender=Fem|Number=Sing 16 nsubj:pass _ _
|
| 42 |
+
7-8 dei _ _ _ _ _ _ _ _
|
| 43 |
+
7 di di ADP E _ 9 case _ _
|
| 44 |
+
8 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 9 det _ _
|
| 45 |
+
9 polo polo NOUN S Gender=Masc|Number=Sing 6 nmod _ _
|
| 46 |
+
10 di di ADP E _ 11 case _ _
|
| 47 |
+
11 attrazione attrazione NOUN S Gender=Fem|Number=Sing 9 nmod _ _
|
| 48 |
+
12 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux _ _
|
| 49 |
+
13 stata essere AUX VA Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 16 aux:pass _ _
|
| 50 |
+
14 sempre sempre ADV B _ 15 advmod _ _
|
| 51 |
+
15 più più ADV B _ 16 advmod _ _
|
| 52 |
+
16 caratterizzata caratterizzare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 0 root _ _
|
| 53 |
+
17-18 dall' _ _ _ _ _ _ _ SpaceAfter=No
|
| 54 |
+
17 da da ADP E _ 19 case _ _
|
| 55 |
+
18 l' il DET RD Definite=Def|Number=Sing|PronType=Art 19 det _ _
|
| 56 |
+
19 emergere emergere NOUN S Gender=Masc|Number=Sing 16 obl _ _
|
| 57 |
+
20 di di ADP E _ 23 case _ _
|
| 58 |
+
21 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 23 det _ _
|
| 59 |
+
22 crescente crescente ADJ A Number=Sing 23 amod _ _
|
| 60 |
+
23 concorrenza concorrenza NOUN S Gender=Fem|Number=Sing 19 nmod _ _
|
| 61 |
+
24 che che PRON PR PronType=Rel 28 nsubj _ _
|
| 62 |
+
25 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 28 expl _ _
|
| 63 |
+
26 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 28 aux _ _
|
| 64 |
+
27 progressivamente progressivamente ADV B _ 28 advmod _ _
|
| 65 |
+
28 spostata spostare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 23 acl:relcl _ _
|
| 66 |
+
29-30 dalle _ _ _ _ _ _ _ _
|
| 67 |
+
29 da da ADP E _ 32 case _ _
|
| 68 |
+
30 le il DET RD Definite=Def|Gender=Fem|Number=Plur|PronType=Art 32 det _ _
|
| 69 |
+
31 singole singolo ADJ A Gender=Fem|Number=Plur 32 amod _ _
|
| 70 |
+
32 imprese impresa NOUN S Gender=Fem|Number=Plur 28 obl _ _
|
| 71 |
+
33-34 ai _ _ _ _ _ _ _ _
|
| 72 |
+
33 a a ADP E _ 35 case _ _
|
| 73 |
+
34 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 35 det _ _
|
| 74 |
+
35 sistemi sistema NOUN S Gender=Masc|Number=Plur 28 obl _ _
|
| 75 |
+
36 economici economico ADJ A Gender=Masc|Number=Plur 35 amod _ _
|
| 76 |
+
37 e e CCONJ CC _ 38 cc _ _
|
| 77 |
+
38 territoriali territoriale ADJ A Number=Plur 36 conj _ SpaceAfter=No
|
| 78 |
+
39 , , PUNCT FF _ 28 punct _ _
|
| 79 |
+
40 determinando determinare VERB V VerbForm=Ger 28 advcl _ _
|
| 80 |
+
41 l' il DET RD Definite=Def|Number=Sing|PronType=Art 42 det _ SpaceAfter=No
|
| 81 |
+
42 esigenza esigenza NOUN S Gender=Fem|Number=Sing 40 obj _ _
|
| 82 |
+
43 di di ADP E _ 45 case _ _
|
| 83 |
+
44 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 45 det _ _
|
| 84 |
+
45 riconsiderazione riconsiderazione NOUN S Gender=Fem|Number=Sing 42 nmod _ _
|
| 85 |
+
46-47 dei _ _ _ _ _ _ _ _
|
| 86 |
+
46 di di ADP E _ 48 case _ _
|
| 87 |
+
47 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 48 det _ _
|
| 88 |
+
48 rapporti rapporto NOUN S Gender=Masc|Number=Plur 45 nmod _ _
|
| 89 |
+
49 esistenti esistente VERB V Number=Plur 48 acl _ _
|
| 90 |
+
50 tra tra ADP E _ 51 case _ _
|
| 91 |
+
51 soggetti soggetto NOUN S Gender=Masc|Number=Plur 49 obl _ _
|
| 92 |
+
52 produttivi produttivo ADJ A Gender=Masc|Number=Plur 51 amod _ _
|
| 93 |
+
53 e e CCONJ CC _ 54 cc _ _
|
| 94 |
+
54 ambiente ambiente NOUN S Gender=Masc|Number=Sing 51 conj _ _
|
| 95 |
+
55 in in ADP E _ 56 case _ _
|
| 96 |
+
56 cui cui PRON PR PronType=Rel 58 obl _ _
|
| 97 |
+
57 questi questo PRON PD Gender=Masc|Number=Plur|PronType=Dem 58 nsubj _ _
|
| 98 |
+
58 operano operare VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 54 acl:relcl _ SpaceAfter=No
|
| 99 |
+
59 . . PUNCT FS _ 16 punct _ _
|
| 100 |
+
|
| 101 |
+
# sent_id = VIT-318
|
| 102 |
+
# text = Tuttavia qualche problema potrebbe esserci per la commissione esteri alla cui presidenza è candidato l'esponente di Alleanza Nazionale Mirko Tremaglia: una candidatura più subìta che gradita dalla Lega Nord, che tuttavia dovrebbe rispettare gli accordi.
|
| 103 |
+
1 Tuttavia tuttavia CCONJ CC _ 5 cc _ _
|
| 104 |
+
2 qualche qualche DET DI Number=Sing|PronType=Ind 3 det _ _
|
| 105 |
+
3 problema problema NOUN S Gender=Masc|Number=Sing 5 nsubj _ _
|
| 106 |
+
4 potrebbe potere AUX VA Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux _ _
|
| 107 |
+
5-6 esserci _ _ _ _ _ _ _ _
|
| 108 |
+
5 esser essere VERB V VerbForm=Inf 0 root _ _
|
| 109 |
+
6 ci ci PRON PC Clitic=Yes|Number=Plur|Person=1|PronType=Prs 5 expl _ _
|
| 110 |
+
7 per per ADP E _ 9 case _ _
|
| 111 |
+
8 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 9 det _ _
|
| 112 |
+
9 commissione commissione NOUN S Gender=Fem|Number=Sing 5 obl _ _
|
| 113 |
+
10 esteri estero ADJ A Gender=Masc|Number=Plur 9 amod _ _
|
| 114 |
+
11-12 alla _ _ _ _ _ _ _ _
|
| 115 |
+
11 a a ADP E _ 14 case _ _
|
| 116 |
+
12 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 14 det _ _
|
| 117 |
+
13 cui cui DET DR PronType=Rel 14 det:poss _ _
|
| 118 |
+
14 presidenza presidenza NOUN S Gender=Fem|Number=Sing 16 obl _ _
|
| 119 |
+
15 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux:pass _ _
|
| 120 |
+
16 candidato candidare VERB V Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part 9 acl:relcl _ _
|
| 121 |
+
17 l' il DET RD Definite=Def|Number=Sing|PronType=Art 18 det _ SpaceAfter=No
|
| 122 |
+
18 esponente esponente NOUN S Number=Sing 16 nsubj:pass _ _
|
| 123 |
+
19 di di ADP E _ 20 case _ _
|
| 124 |
+
20 Alleanza Alleanza PROPN SP _ 18 nmod _ _
|
| 125 |
+
21 Nazionale Nazionale PROPN SP _ 20 flat:name _ _
|
| 126 |
+
22 Mirko Mirko PROPN SP _ 18 nmod _ _
|
| 127 |
+
23 Tremaglia Tremaglia PROPN SP _ 22 flat:name _ SpaceAfter=No
|
| 128 |
+
24 : : PUNCT FC _ 22 punct _ _
|
| 129 |
+
25 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 26 det _ _
|
| 130 |
+
26 candidatura candidatura NOUN S Gender=Fem|Number=Sing 22 appos _ _
|
| 131 |
+
27 più più ADV B _ 28 advmod _ _
|
| 132 |
+
28 subìta subire VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 26 advcl _ _
|
| 133 |
+
29 che che CCONJ CC _ 30 cc _ _
|
| 134 |
+
30 gradita gradito ADJ A Gender=Fem|Number=Sing 28 amod _ _
|
| 135 |
+
31-32 dalla _ _ _ _ _ _ _ _
|
| 136 |
+
31 da da ADP E _ 33 case _ _
|
| 137 |
+
32 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 33 det _ _
|
| 138 |
+
33 Lega Lega PROPN SP _ 28 obl:agent _ _
|
| 139 |
+
34 Nord Nord PROPN SP _ 33 flat:name _ SpaceAfter=No
|
| 140 |
+
35 , , PUNCT FC _ 33 punct _ _
|
| 141 |
+
36 che che PRON PR PronType=Rel 39 nsubj _ _
|
| 142 |
+
37 tuttavia tuttavia CCONJ CC _ 39 cc _ _
|
| 143 |
+
38 dovrebbe dovere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 39 aux _ _
|
| 144 |
+
39 rispettare rispettare VERB V VerbForm=Inf 33 acl:relcl _ _
|
| 145 |
+
40 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 41 det _ _
|
| 146 |
+
41 accordi accordio NOUN S Gender=Masc|Number=Plur 39 obj _ SpaceAfter=No
|
| 147 |
+
42 . . PUNCT FS _ 5 punct _ _
|
| 148 |
+
|
| 149 |
+
# sent_id = VIT-591
|
| 150 |
+
# text = L'ottimismo di Kantor potrebbe però rivelarsi ancora una volta prematuro.
|
| 151 |
+
1 L' il DET RD Definite=Def|Number=Sing|PronType=Art 2 det _ SpaceAfter=No
|
| 152 |
+
2 ottimismo ottimismo NOUN S Gender=Masc|Number=Sing 7 nsubj _ _
|
| 153 |
+
3 di di ADP E _ 4 case _ _
|
| 154 |
+
4 Kantor Kantor PROPN SP _ 2 nmod _ _
|
| 155 |
+
5 potrebbe potere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux _ _
|
| 156 |
+
6 però però ADV B _ 7 advmod _ _
|
| 157 |
+
7-8 rivelarsi _ _ _ _ _ _ _ _
|
| 158 |
+
7 rivelar rivelare VERB V VerbForm=Inf 0 root _ _
|
| 159 |
+
8 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 7 expl _ _
|
| 160 |
+
9 ancora ancora ADV B _ 7 advmod _ _
|
| 161 |
+
10 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 11 det _ _
|
| 162 |
+
11 volta volta NOUN S Gender=Fem|Number=Sing 7 obl _ _
|
| 163 |
+
12 prematuro prematuro ADJ A Gender=Masc|Number=Sing 7 xcomp _ SpaceAfter=No
|
| 164 |
+
13 . . PUNCT FS _ 7 punct _ _
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def test_process_mwts():
|
| 169 |
+
# dei appears multiple times
|
| 170 |
+
# the verb/pron esserci will be ignored
|
| 171 |
+
expected_mwts = {'Negli': ('In', 'gli'), 'dei': ('di', 'i'), "dall'": ('da', "l'"), 'dalle': ('da', 'le'), 'ai': ('a', 'i'), 'alla': ('a', 'la'), 'dalla': ('da', 'la')}
|
| 172 |
+
|
| 173 |
+
ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE)
|
| 174 |
+
|
| 175 |
+
mwts = convert_it_vit.get_mwt(ud_train_data)
|
| 176 |
+
assert expected_mwts == mwts
|
| 177 |
+
|
| 178 |
+
def test_raw_tree():
|
| 179 |
+
con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE))
|
| 180 |
+
expected_ids = ["#ID=sent_00002", "#ID=sent_00318", "#ID=sent_00589"]
|
| 181 |
+
expected_trees = ["(ROOT (cp (sp (part negli) (sn (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd dei) (sn (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda dall) (sn (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda dalle) (sn (sa (ag singole)) (n imprese))) (sp (part ai) (sn (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd dei) (sn (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))",
|
| 182 |
+
"(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part alla) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda dalla) (sn (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))",
|
| 183 |
+
"(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"]
|
| 184 |
+
assert len(con_sentences) == 3
|
| 185 |
+
for sentence, expected_id, expected_tree in zip(con_sentences, expected_ids, expected_trees):
|
| 186 |
+
assert sentence[0] == expected_id
|
| 187 |
+
tree = convert_it_vit.raw_tree(sentence[1])
|
| 188 |
+
assert str(tree) == expected_tree
|
| 189 |
+
|
| 190 |
+
def test_update_mwts():
|
| 191 |
+
con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE))
|
| 192 |
+
ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE)
|
| 193 |
+
mwt_map = convert_it_vit.get_mwt(ud_train_data)
|
| 194 |
+
expected_trees=["(ROOT (cp (sp (part In) (sn (art gli) (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd di) (sn (art i) (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda da) (sn (art l') (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda da) (sn (art le) (sa (ag singole)) (n imprese))) (sp (part a) (sn (art i) (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd di) (sn (art i) (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))",
|
| 195 |
+
"(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part a) (art la) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda da) (sn (art la) (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))",
|
| 196 |
+
"(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (clit si) (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"]
|
| 197 |
+
with tsurgeon.Tsurgeon() as tsurgeon_processor:
|
| 198 |
+
for con_sentence, ud_sentence, expected_tree in zip(con_sentences, ud_train_data.sentences, expected_trees):
|
| 199 |
+
con_tree = convert_it_vit.raw_tree(con_sentence[1])
|
| 200 |
+
updated_tree, _ = convert_it_vit.update_mwts_and_special_cases(con_tree, ud_sentence, mwt_map, tsurgeon_processor)
|
| 201 |
+
assert str(updated_tree) == expected_tree
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
CON_PERCENT_SAMPLE = """
|
| 205 |
+
ID#sent_00020 f-[sn-[art-il, n-tesoro], ibar-[vt-mette], compt-[sp-[part-sul, sn-[n-mercato]], sn-[art-il, num-51%, sp-[p-a, sn-[num-2, n-lire]], sp-[p-per, sn-[n-azione]]]], punto-.]
|
| 206 |
+
ID#sent_00022 dirsp-[f3-[sn-[art-le, n-novità]], dirs-':', f3-[coord-[sn-[n-voto, spd-[pd-di, sn-[n-lista]]], cong-e, sn-[n-tetto, sp-[part-agli, sn-[n-acquisti]], sv3-[vppt-limitato, comppas-[sp-[part-allo, sn-[num-0/5%]]]]]], punto-.]]
|
| 207 |
+
ID#sent_00517 dirsp-[fc-[f-[sn-[art-l, n-aumento, sa-[ag-mensile], spd-[pd-di, sn-[nt-aprile]]], ibar-[ause-è, vppc-stato], compc-[sq-[q-dell_, sn-[num-1/3%]], sp-[p-contro, sn-[art-lo, num-0/7/0/8%, spd-[partd-degli, sn-[sa-[ag-ultimi], num-due, sn-[nt-mesi]]]]]]]]]
|
| 208 |
+
ID#sent_01117 fc-[f-[sn-[art-La, sa-[ag-crescente], n-ripresa, spd-[partd-dei, sn-[n-beni, spd-[pd-di, sn-[n-consumo]]]]], ibar-[vin-deriva], savv-[avv-esclusivamente], compin-[spda-[partda-dal, sn-[n-miglioramento, f2-[spd-[pd-di, sn-[relob-cui]], f-[ibar-[ausa-hanno, vppin-beneficiato], compin-[sn-[n-beni, coord-[sa-[ag-durevoli, fp-[par-'(', sn-[num-plus4/5%], par-')']], cong-e, sa-[ag-semidurevoli, fp-[par-'(', sn-[num-plus1/5%], par-')']]]]]]]]]]], punt-',', fs-[cosu-mentre, f-[sn-[art-i, n-beni, sa-[neg-non, ag-durevoli], fp-[par-'(', sn-[num-min1%], par-')']], ibar-[vt-accusano], cong-ancora, compt-[sn-[art-un, sa-[ag-evidente], n-ritardo]]]], punto-.]
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
CON_PERCENT_LEAVES = [
|
| 212 |
+
['il', 'tesoro', 'mette', 'sul', 'mercato', 'il', '51', '%%', 'a', '2', 'lire', 'per', 'azione', '.'],
|
| 213 |
+
['le', 'novità', ':', 'voto', 'di', 'lista', 'e', 'tetto', 'agli', 'acquisti', 'limitato', 'allo', '0,5', '%%', '.'],
|
| 214 |
+
['l', 'aumento', 'mensile', 'di', 'aprile', 'è', 'stato', "dell'", '1,3', '%%', 'contro', 'lo', '0/7,0/8', '%%', 'degli', 'ultimi', 'due', 'mesi'],
|
| 215 |
+
# the plus and min look bad, but they get cleaned up when merging with the UD version of the dataset
|
| 216 |
+
['La', 'crescente', 'ripresa', 'dei', 'beni', 'di', 'consumo', 'deriva', 'esclusivamente', 'dal', 'miglioramento', 'di', 'cui', 'hanno', 'beneficiato', 'beni', 'durevoli', '(', 'plus4,5', '%%', ')', 'e', 'semidurevoli', '(', 'plus1,5', '%%', ')', ',', 'mentre', 'i', 'beni', 'non', 'durevoli', '(', 'min1', '%%', ')', 'accusano', 'ancora', 'un', 'evidente', 'ritardo', '.']
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
def test_read_percent():
|
| 220 |
+
con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_PERCENT_SAMPLE))
|
| 221 |
+
assert len(con_sentences) == len(CON_PERCENT_LEAVES)
|
| 222 |
+
for (_, raw_tree), expected_leaves in zip(con_sentences, CON_PERCENT_LEAVES):
|
| 223 |
+
tree = convert_it_vit.raw_tree(raw_tree)
|
| 224 |
+
words = tree.leaf_labels()
|
| 225 |
+
if expected_leaves is None:
|
| 226 |
+
print(words)
|
| 227 |
+
else:
|
| 228 |
+
assert words == expected_leaves
|
stanza/stanza/tests/constituency/test_convert_starlang.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test a couple different classes of trees to check the output of the Starlang conversion
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from stanza.utils.datasets.constituency import convert_starlang
|
| 11 |
+
|
| 12 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 13 |
+
|
| 14 |
+
TREE="( (S (NP (NP {morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )"
|
| 15 |
+
|
| 16 |
+
def test_read_tree():
|
| 17 |
+
"""
|
| 18 |
+
Test a basic tree read
|
| 19 |
+
"""
|
| 20 |
+
tree = convert_starlang.read_tree(TREE)
|
| 21 |
+
assert "(ROOT (S (NP (NP Bayan) (NP Haag)) (VP (NP Elianti) (VP çalar)) (. .)))" == str(tree)
|
| 22 |
+
|
| 23 |
+
def test_missing_word():
|
| 24 |
+
"""
|
| 25 |
+
Test that an error is thrown if the word is missing
|
| 26 |
+
"""
|
| 27 |
+
tree_text = TREE.replace("turkish=", "foo=")
|
| 28 |
+
with pytest.raises(ValueError):
|
| 29 |
+
tree = convert_starlang.read_tree(tree_text)
|
| 30 |
+
|
| 31 |
+
def test_bad_label():
|
| 32 |
+
"""
|
| 33 |
+
Test that an unexpected label results in an error
|
| 34 |
+
"""
|
| 35 |
+
tree_text = TREE.replace("(S", "(s")
|
| 36 |
+
with pytest.raises(ValueError):
|
| 37 |
+
tree = convert_starlang.read_tree(tree_text)
|
stanza/stanza/tests/constituency/test_in_order_oracle.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import pytest
|
| 3 |
+
|
| 4 |
+
from stanza.models.constituency import parse_transitions
|
| 5 |
+
from stanza.models.constituency import tree_reader
|
| 6 |
+
from stanza.models.constituency.base_model import SimpleModel
|
| 7 |
+
from stanza.models.constituency.in_order_oracle import *
|
| 8 |
+
from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme
|
| 9 |
+
from stanza.models.constituency.transition_sequence import build_treebank
|
| 10 |
+
|
| 11 |
+
from stanza.tests import *
|
| 12 |
+
from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
|
| 13 |
+
|
| 14 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

# Tree constants below are PTB-style bracketings fed to tree_reader; the
# reader is bracket-driven, so internal whitespace/indentation of these
# strings does not affect behavior (indentation here is reconstructed).

# A sample tree from PTB with a single unary transition (at a location other than root)
SINGLE_UNARY_TREE = """
( (S
    (NP-SBJ-1 (DT A) (NN record) (NN date) )
    (VP (VBZ has) (RB n't)
      (VP (VBN been)
        (VP (VBN set)
          (NP (-NONE- *-1) ))))
    (. .) ))
"""

# In-order transition sequence for SINGLE_UNARY_TREE, kept as a reference:
# [Shift, OpenConstituent(('NP-SBJ-1',)), Shift, Shift, CloseConstituent, OpenConstituent(('S',)), Shift, OpenConstituent(('VP',)), Shift, Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)), CloseConstituent, CloseConstituent, CloseConstituent, CloseConstituent, Shift, CloseConstituent, OpenConstituent(('ROOT',)), CloseConstituent]

# A sample tree from PTB with a double unary transition (at a location other than root)
DOUBLE_UNARY_TREE = """
( (S
    (NP-SBJ
      (NP (RB Not) (PDT all) (DT those) )
      (SBAR
        (WHNP-3 (WP who) )
        (S
          (NP-SBJ (-NONE- *T*-3) )
          (VP (VBD wrote) ))))
    (VP (VBP oppose)
      (NP (DT the) (NNS changes) ))
    (. .) ))
"""

# A sample tree from PTB with a triple unary transition (at a location other than root)
# The triple unary is at the START of the next bracket, which affects how the
# dynamic oracle repairs the transition sequence
TRIPLE_UNARY_START_TREE = """
( (S
    (PRN
      (S
        (NP-SBJ (-NONE- *) )
        (VP (VB See) )))
    (, ,)
    (NP-SBJ
      (NP (DT the) (JJ other) (NN rule) )
      (PP (IN of)
        (NP (NN thumb) ))
      (PP (IN about)
        (NP (NN ballooning) )))))
"""

# A sample tree from PTB with a triple unary transition (at a location other than root)
# The triple unary is at the END of the next bracket, which affects how the
# dynamic oracle repairs the transition sequence
TRIPLE_UNARY_END_TREE = """
( (S
    (NP (NNS optimists) )
    (VP (VBP expect)
      (S
        (NP-SBJ-4 (NNP Hong) (NNP Kong) )
        (VP (TO to)
          (VP (VB hum)
            (ADVP-CLR (RB along) )
            (SBAR-MNR (RB as)
              (S
                (NP-SBJ (-NONE- *-4) )
                (VP (-NONE- *?*)
                  (ADVP-TMP (IN before) ))))))))))
"""

# The four unary-focused trees, concatenated into one treebank for the fixtures
TREES = [SINGLE_UNARY_TREE, DOUBLE_UNARY_TREE, TRIPLE_UNARY_START_TREE, TRIPLE_UNARY_END_TREE]
TREEBANK = "\n".join(TREES)

NOUN_PHRASE_TREE = """
( (NP
    (NP (NNP Chicago) (POS 's))
    (NNP Goodman)
    (NNP Theatre)))
"""

WIDE_NP_TREE = """
( (S
    (NP-SBJ (DT These) (NNS studies))
    (VP (VBP demonstrate)
      (SBAR (IN that)
        (S
          (NP-SBJ (NNS mice))
          (VP (VBP are)
            (NP-PRD
              (NP (DT a)
                (ADJP (JJ practical)
                  (CC and)
                  (JJ powerful))
                (JJ experimental) (NN system))
              (SBAR
                (WHADVP-2 (-NONE- *0*))
                (S
                  (NP-SBJ (-NONE- *PRO*))
                  (VP (TO to)
                    (VP (VB study)
                      (NP (DT the) (NN genetics)))))))))))))
"""

# Trees with wide constituents, used by the close-shift-shift tests
WIDE_TREES = [NOUN_PHRASE_TREE, WIDE_NP_TREE]
WIDE_TREEBANK = "\n".join(WIDE_TREES)

# Root label passed to every repair function in these tests
ROOT_LABELS = ["ROOT"]
|
| 118 |
+
|
| 119 |
+
def get_repairs(gold_sequence, wrong_transition, repair_fn):
    """
    Try the repair function with the wrong transition at every position of the gold sequence.

    Returns a list of the positions where the repair succeeded.
    Each entry is a tuple (idx, seq): the index of the simulated error and
    the repaired transition sequence returned by repair_fn (never None).
    """
    found = []
    for position, gold_transition in enumerate(gold_sequence):
        fixed = repair_fn(gold_transition, wrong_transition, gold_sequence, position, ROOT_LABELS, None, None)
        if fixed is not None:
            found.append((position, fixed))
    return found
|
| 130 |
+
|
| 131 |
+
@pytest.fixture(scope="module")
def unary_trees():
    """Read, prune, and relabel the unary-transition treebank, one tree per entry in TREES."""
    parsed = tree_reader.read_trees(TREEBANK)
    cleaned = [tree.prune_none().simplify_labels() for tree in parsed]
    assert len(cleaned) == len(TREES)
    return cleaned
|
| 138 |
+
|
| 139 |
+
@pytest.fixture(scope="module")
def gold_sequences(unary_trees):
    """In-order gold transition sequences, one per tree in the unary_trees fixture."""
    return build_treebank(unary_trees, TransitionScheme.IN_ORDER)
|
| 143 |
+
|
| 144 |
+
@pytest.fixture(scope="module")
def wide_trees():
    """Read, prune, and relabel the wide-constituent treebank, one tree per entry in WIDE_TREES."""
    parsed = tree_reader.read_trees(WIDE_TREEBANK)
    cleaned = [tree.prune_none().simplify_labels() for tree in parsed]
    assert len(cleaned) == len(WIDE_TREES)
    return cleaned
|
| 151 |
+
|
| 152 |
+
def test_wrong_open_root(gold_sequences):
    """
    Test the results of the dynamic oracle on a few trees if the ROOT is mishandled.
    """
    wrong = OpenConstituent("S")
    root_open = OpenConstituent("ROOT")
    close = CloseConstituent()

    for sequence in gold_sequences:
        # every in-order sequence should end with Open(ROOT), Close
        assert sequence[-2] == root_open

        repairs = get_repairs(sequence, wrong, fix_wrong_open_root_error)
        # ROOT occurs exactly once per sequence, so exactly one
        # location affords the S/ROOT replacement
        assert len(repairs) == 1
        idx, repaired = repairs[0]

        # the repair happens at the -2 position, where the ROOT open is...
        assert idx == len(sequence) - 2
        # ...and the repaired list has the wrong transition followed by a
        # Close inserted before the original ending, giving the model
        # another chance to close the tree
        assert repaired == sequence[:-2] + [wrong, close] + sequence[-2:]
|
| 176 |
+
|
| 177 |
+
def test_wrong_open_unary_chain(gold_sequences):
    """
    Test the repairs of an open/open error if it is effectively a skipped unary transition

    NOTE: this test was previously also named test_missed_unary, which meant
    it was shadowed by the later function of the same name and was never
    collected or run by pytest.  It is renamed here after the repair
    function it exercises, fix_wrong_open_unary_chain.
    """
    wrong_transition = OpenConstituent("S")

    # the SINGLE_UNARY tree has no location where this repair applies
    repairs = get_repairs(gold_sequences[0], wrong_transition, fix_wrong_open_unary_chain)
    assert len(repairs) == 0

    # here we are simulating picking NT-S instead of NT-VP
    # the DOUBLE_UNARY tree has one location where this is relevant, index 11
    repairs = get_repairs(gold_sequences[1], wrong_transition, fix_wrong_open_unary_chain)
    assert len(repairs) == 1
    assert repairs[0][0] == 11
    assert repairs[0][1] == gold_sequences[1][:11] + gold_sequences[1][13:]

    # the TRIPLE_UNARY_START tree has two locations where this is relevant
    # at index 1, the pattern goes (S (VP ...))
    # so choosing S instead of VP means you can skip the VP and only miss that one bracket
    # at index 5, the pattern goes (S (PRN (S (VP ...))) (...))
    # note that this is capturing a unary transition into a larger constituent
    # skipping the PRN is satisfactory
    repairs = get_repairs(gold_sequences[2], wrong_transition, fix_wrong_open_unary_chain)
    assert len(repairs) == 2
    assert repairs[0][0] == 1
    assert repairs[0][1] == gold_sequences[2][:1] + gold_sequences[2][3:]
    assert repairs[1][0] == 5
    assert repairs[1][1] == gold_sequences[2][:5] + gold_sequences[2][7:]

    # The TRIPLE_UNARY_END tree has 2 sections of tree for a total of 3 locations
    # where the repair might happen
    # Surprisingly the unary transition at the very start can only be
    # repaired by skipping it and using the outer S transition instead
    # The second repair overall (first repair in the second location)
    # should have a double skip to reach the S node
    repairs = get_repairs(gold_sequences[3], wrong_transition, fix_wrong_open_unary_chain)
    assert len(repairs) == 3
    assert repairs[0][0] == 1
    assert repairs[0][1] == gold_sequences[3][:1] + gold_sequences[3][3:]
    assert repairs[1][0] == 21
    assert repairs[1][1] == gold_sequences[3][:21] + gold_sequences[3][25:]
    assert repairs[2][0] == 23
    assert repairs[2][1] == gold_sequences[3][:23] + gold_sequences[3][25:]
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def test_open_with_stuff(unary_trees, gold_sequences):
    """
    Test fix_wrong_open_stuff_unary against each of the unary test trees.

    expected_trees[i] is the single tree rebuilt from the one repair found
    for gold_sequences[i]; None means no repair should be possible.
    """
    wrong_transition = OpenConstituent("S")
    expected_trees = [
        "(ROOT (S (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))",
        "(ROOT (S (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))",
        None,
        "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))"
    ]

    for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees):
        repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_stuff_unary)
        if expected is None:
            assert len(repairs) == 0
        else:
            # exactly one repair site, and rebuilding the tree from the
            # repaired sequence gives the expected bracket structure
            assert len(repairs) == 1
            result = reconstruct_tree(tree, repairs[0][1])
            assert str(result) == expected
|
| 239 |
+
|
| 240 |
+
def test_general_open(gold_sequences):
    """
    Test fix_wrong_open_general: substituting a wrong label at an Open position.

    The repair should apply at every Open except one (per sequence), and each
    repaired sequence differs from gold only at the single repaired index.
    """
    wrong = OpenConstituent("SBARQ")

    for sequence in gold_sequences:
        open_count = sum(1 for t in sequence if isinstance(t, OpenConstituent))
        repairs = get_repairs(sequence, wrong, fix_wrong_open_general)
        assert len(repairs) == open_count - 1
        for idx, repaired in repairs:
            # same length, wrong transition in place, everything else untouched
            assert len(repaired) == len(sequence)
            assert repaired[idx] == wrong
            assert repaired[:idx] == sequence[:idx]
            assert repaired[idx+1:] == sequence[idx+1:]
|
| 251 |
+
|
| 252 |
+
def test_missed_unary(unary_trees, gold_sequences):
    """
    Test fix_missed_unary when the error transition is a Close or a Shift.

    Expected results are (index, length) pairs: the repair found at sequence
    position index removes length transitions from the gold sequence.
    """
    shift_transition = Shift()
    close_transition = CloseConstituent()

    # expected repairs when the wrong transition is Close
    expected_close_results = [
        [(12, 2)],
        [(11, 4), (13, 2)],
        # (NP NN thumb) and (NP NN ballooning) are both candidates for this repair
        [(18, 2), (24, 2)],
        [(21, 6), (23, 4), (25, 2)],
    ]

    # expected repairs when the wrong transition is Shift
    expected_shift_results = [
        (),
        (),
        (),
        # (ADVP-CLR (RB along)) is followed by a shift
        [(16, 2)],
    ]

    for tree, sequence, expected_close, expected_shift in zip(unary_trees, gold_sequences, expected_close_results, expected_shift_results):
        repairs = get_repairs(sequence, close_transition, fix_missed_unary)
        assert len(repairs) == len(expected_close)
        for repair, (expected_idx, expected_len) in zip(repairs, expected_close):
            assert repair[0] == expected_idx
            # the repaired sequence is the gold sequence with expected_len
            # transitions cut out at expected_idx
            assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:]

        repairs = get_repairs(sequence, shift_transition, fix_missed_unary)
        assert len(repairs) == len(expected_shift)
        for repair, (expected_idx, expected_len) in zip(repairs, expected_shift):
            assert repair[0] == expected_idx
            assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:]
|
| 284 |
+
|
| 285 |
+
def test_open_shift(unary_trees, gold_sequences):
    """
    Test fix_open_shift against each of the unary test trees.

    expected_repairs[i] lists (index, tree) pairs: the repair found at that
    sequence index, when rebuilt, should produce exactly that tree.
    """
    shift_transition = Shift()

    expected_repairs = [
        [(7, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))"),
         (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VBN been) (VP (VBN set))) (. .)))")],
        [(7, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WP who) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
         (9, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
         (19, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose) (NP (DT the) (NNS changes)) (. .)))"),
         (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (DT the) (NNS changes)) (. .)))")],
        [(14, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"),
         (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"),
         (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about) (NP (NN ballooning)))))")],
        [(5, "(ROOT (S (NP (NNS optimists)) (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (10, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (RB as) (S (VP (ADVP (IN before))))))))))")]
    ]

    for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs):
        repairs = get_repairs(sequence, shift_transition, fix_open_shift)
        assert len(repairs) == len(expected)
        for repair, (idx, expected_tree) in zip(repairs, expected):
            assert repair[0] == idx
            # the repaired sequence should rebuild into the expected bracketing
            result_tree = reconstruct_tree(tree, repair[1])
            assert str(result_tree) == expected_tree
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def test_open_close(unary_trees, gold_sequences):
    """
    Test fix_open_close against each of the unary test trees.

    expected_repairs[i] lists (index, tree) pairs: the repair found at that
    sequence index, when rebuilt, should produce exactly that tree.
    """
    close_transition = CloseConstituent()

    expected_repairs = [
        [(7, "(ROOT (S (S (NP (DT A) (NN record) (NN date)) (VBZ has)) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))"),
         (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VP (VBZ has) (RB n't) (VBN been)) (VP (VBN set))) (. .)))")],
        # missed the WHNP. The surrounding SBAR cannot be created, either
        [(7, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
         # missed the SBAR
         (9, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who))) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"),
         # missed the VP around "oppose the changes"
         (19, "(ROOT (S (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose)) (NP (DT the) (NNS changes)) (. .)))"),
         # missed the NP in "the changes", looks pretty bad tbh
         (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VP (VBP oppose) (DT the)) (NNS changes)) (. .)))")],
        [(14, "(ROOT (S (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule))) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"),
         (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (IN of)) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"),
         (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about)) (NP (NN ballooning)))))")],
        [(5, "(ROOT (S (S (NP (NNS optimists)) (VBP expect)) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (10, "(ROOT (S (NP (NNS optimists)) (VP (VP (VBP expect) (NP (NNP Hong) (NNP Kong))) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (S (NP (NNP Hong) (NNP Kong)) (TO to)) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (VP (TO to) (VB hum)) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"),
         (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (ADVP (RB along)) (RB as)) (S (VP (ADVP (IN before))))))))))")]
    ]

    for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs):
        repairs = get_repairs(sequence, close_transition, fix_open_close)

        assert len(repairs) == len(expected)
        for repair, (idx, expected_tree) in zip(repairs, expected):
            assert repair[0] == idx
            # the repaired sequence should rebuild into the expected bracketing
            result_tree = reconstruct_tree(tree, repair[1])
            assert str(result_tree) == expected_tree
|
| 346 |
+
|
| 347 |
+
def test_shift_close(unary_trees, gold_sequences):
    """
    Test the fix for a shift -> close

    These errors can occur pretty much everywhere, and the fix is quite simple,
    so we only test a few cases.  For most trees only the number of repair
    sites is checked; the rebuilt tree is verified for one repair per tree 0
    and for the special unary-adjacent repair in tree 3.
    """

    close_transition = CloseConstituent()

    expected_tree = "(ROOT (S (NP (NP (DT A)) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))"

    repairs = get_repairs(gold_sequences[0], close_transition, fix_shift_close)
    assert len(repairs) == 7
    result_tree = reconstruct_tree(unary_trees[0], repairs[0][1])
    assert str(result_tree) == expected_tree

    repairs = get_repairs(gold_sequences[1], close_transition, fix_shift_close)
    assert len(repairs) == 8

    repairs = get_repairs(gold_sequences[2], close_transition, fix_shift_close)
    assert len(repairs) == 8

    repairs = get_repairs(gold_sequences[3], close_transition, fix_shift_close)
    assert len(repairs) == 9
    for rep in repairs:
        if rep[0] == 16:
            # This one is special because it occurs as part of a unary
            # in other words, it should go unary, shift
            # and instead we are making it close where the unary should be
            # ... the unary would create "(ADVP (RB along))"
            expected_tree = "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))"
            result_tree = reconstruct_tree(unary_trees[3], rep[1])
            assert str(result_tree) == expected_tree
            break
    else:
        # for/else: only reached if the loop never hit the break above
        raise AssertionError("Did not find an expected repair location")
|
| 384 |
+
|
| 385 |
+
def test_close_open_shift_nested(unary_trees, gold_sequences):
    """
    Test fix_close_open_shift_nested when a Shift is predicted in place of a Close.

    expected_trees[i] maps repair index -> expected rebuilt tree; an empty
    dict means no repair should apply for that tree.
    """
    shift_transition = Shift()

    expected_trees = [{},
                      {4: "(ROOT (S (NP (RB Not) (PDT all) (DT those) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"},
                      {4: "(ROOT (S (VP (VB See)) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
                       13: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"},
                      {}]

    for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees):
        repairs = get_repairs(gold_sequence, shift_transition, fix_close_open_shift_nested)
        assert len(repairs) == len(expected)
        if len(expected) >= 1:
            for repair in repairs:
                # each repair must occur at an expected index and rebuild
                # into the corresponding expected tree
                assert repair[0] in expected.keys()
                result_tree = reconstruct_tree(tree, repair[1])
                assert str(result_tree) == expected[repair[0]]
|
| 402 |
+
|
| 403 |
+
def check_repairs(trees, gold_sequences, expected_trees, transition, repair_fn):
    """
    Check repair_fn against each (tree, gold_sequence) pair.

    expected_trees[i] is a dict mapping repair index -> expected rebuilt tree
    (an empty dict means no repairs are expected for that tree).  If
    expected_trees[i] is None, the repairs are printed instead of asserted —
    a debugging aid for writing the expected values of a new test case.
    """
    for tree_idx, (gold_tree, gold_sequence, expected) in enumerate(zip(trees, gold_sequences, expected_trees)):
        repairs = get_repairs(gold_sequence, transition, repair_fn)
        if expected is not None:
            assert len(repairs) == len(expected)
            for repair in repairs:
                assert repair[0] in expected
                result_tree = reconstruct_tree(gold_tree, repair[1])
                assert str(result_tree) == expected[repair[0]]
        else:
            # debugging path: dump the tree, the gold sequence, and every
            # repair so the expected dict can be written by hand
            print("---------------------")
            print("{:P}".format(gold_tree))
            print(gold_sequence)
            #print(repairs)
            for repair in repairs:
                print("---------------------")
                print(gold_sequence)
                print(repair[1])
                result_tree = reconstruct_tree(gold_tree, repair[1])
                print("{:P}".format(gold_tree))
                print("{:P}".format(result_tree))
                print(tree_idx)
                print(repair[0])
                print(result_tree)
|
| 427 |
+
|
| 428 |
+
def test_close_open_shift_unambiguous(unary_trees, gold_sequences):
    """
    Test fix_close_open_shift_unambiguous_bracket with a wrong Shift.

    Expected dicts map repair index -> rebuilt tree; empty means no repairs.
    """
    shift_transition = Shift()

    expected_trees = [{},
                      {8: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who) (S (VP (VBD wrote)))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"},
                      {},
                      {2: "(ROOT (S (NP (NNS optimists) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))",
                       9: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"}]
    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_unambiguous_bracket)
|
| 437 |
+
|
| 438 |
+
def test_close_open_shift_ambiguous_early(unary_trees, gold_sequences):
    """
    Test fix_close_open_shift_ambiguous_bracket_early with a wrong Shift.

    Expected dicts map repair index -> rebuilt tree; empty means no repairs.
    """
    shift_transition = Shift()

    expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))))) (. .)))"},
                      {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes)))) (. .)))"},
                      {2: "(ROOT (S (PRN (S (VP (VB See) (, ,)))) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))",
                       6: "(ROOT (S (PRN (S (VP (VB See))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"},
                      {}]
    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_early)
|
| 447 |
+
|
| 448 |
+
def test_close_open_shift_ambiguous_late(unary_trees, gold_sequences):
    """
    Test fix_close_open_shift_ambiguous_bracket_late with a wrong Shift.

    Expected dicts map repair index -> rebuilt tree; empty means no repairs.
    """
    shift_transition = Shift()

    expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .))))"},
                      {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .))))"},
                      {2: "(ROOT (S (PRN (S (VP (VB See) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))))",
                       6: "(ROOT (S (PRN (S (VP (VB See))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))"},
                      {}]
    check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_late)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def test_close_shift_shift(unary_trees, wide_trees):
    """
    Test that close -> shift works when there is a single block shifted after

    Includes a test specifically that there is no oracle action when there are two blocks after the missed close

    Expected dicts map repair index -> rebuilt tree; empty means no repairs.
    """
    shift_transition = Shift()

    expected_trees = [{15: "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .))))"},
                      {24: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes)) (. .))))"},
                      {20: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning)))))))"},
                      {17: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"},
                      {},
                      {}]

    # this test also uses the wide trees, so the sequences are rebuilt here
    # rather than taken from the gold_sequences fixture
    test_trees = unary_trees + wide_trees
    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)

    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_unambiguous)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def test_close_shift_shift_early(unary_trees, wide_trees):
    """
    Test that close -> shift works when there are multiple blocks shifted after

    Also checks that the single block case is skipped, so as to keep them separate when testing

    A tree with the expected property was specifically added for this test
    (WIDE_NP_TREE, the last entry in the combined list below)
    """
    shift_transition = Shift()

    # rebuild the sequences since this test includes the wide trees
    test_trees = unary_trees + wide_trees
    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)

    expected_trees = [{},
                      {},
                      {},
                      {},
                      {},
                      {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental)) (NN system)) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}]

    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_early)
|
| 501 |
+
|
| 502 |
+
def test_close_shift_shift_late(unary_trees, wide_trees):
    """
    Test that close -> shift works when there are multiple blocks shifted after

    Also checks that the single block case is skipped, so as to keep them separate when testing

    A tree with the expected property was specifically added for this test
    (WIDE_NP_TREE, the last entry in the combined list below)
    """
    shift_transition = Shift()

    # rebuild the sequences since this test includes the wide trees
    test_trees = unary_trees + wide_trees
    gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER)

    expected_trees = [{},
                      {},
                      {},
                      {},
                      {},
                      {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental) (NN system))) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}]

    check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_late)
|
stanza/stanza/tests/constituency/test_lstm_model.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from stanza.models.common import pretrain
|
| 7 |
+
from stanza.models.common.utils import set_random_seed
|
| 8 |
+
from stanza.models.constituency import parse_transitions
|
| 9 |
+
from stanza.tests import *
|
| 10 |
+
from stanza.tests.constituency import test_parse_transitions
|
| 11 |
+
from stanza.tests.constituency.test_trainer import build_trainer
|
| 12 |
+
|
| 13 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 14 |
+
|
| 15 |
+
@pytest.fixture(scope="module")
|
| 16 |
+
def pretrain_file():
|
| 17 |
+
return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
|
| 18 |
+
|
| 19 |
+
def build_model(pretrain_file, *args):
|
| 20 |
+
# By default, we turn off multistage, since that can turn off various other structures in the initial training
|
| 21 |
+
args = ['--no_multistage', '--pattn_num_layers', '4', '--pattn_d_model', '256', '--hidden_size', '128', '--use_lattn'] + list(args)
|
| 22 |
+
trainer = build_trainer(pretrain_file, *args)
|
| 23 |
+
return trainer.model
|
| 24 |
+
|
| 25 |
+
@pytest.fixture(scope="module")
|
| 26 |
+
def unary_model(pretrain_file):
|
| 27 |
+
return build_model(pretrain_file, "--transition_scheme", "TOP_DOWN_UNARY")
|
| 28 |
+
|
| 29 |
+
def test_initial_state(unary_model):
|
| 30 |
+
test_parse_transitions.test_initial_state(unary_model)
|
| 31 |
+
|
| 32 |
+
def test_shift(pretrain_file):
|
| 33 |
+
# TODO: might be good to include some tests specifically for shift
|
| 34 |
+
# in the context of a model with unaries
|
| 35 |
+
model = build_model(pretrain_file)
|
| 36 |
+
test_parse_transitions.test_shift(model)
|
| 37 |
+
|
| 38 |
+
def test_unary(unary_model):
|
| 39 |
+
test_parse_transitions.test_unary(unary_model)
|
| 40 |
+
|
| 41 |
+
def test_unary_requires_root(unary_model):
|
| 42 |
+
test_parse_transitions.test_unary_requires_root(unary_model)
|
| 43 |
+
|
| 44 |
+
def test_open(unary_model):
|
| 45 |
+
test_parse_transitions.test_open(unary_model)
|
| 46 |
+
|
| 47 |
+
def test_compound_open(pretrain_file):
|
| 48 |
+
model = build_model(pretrain_file, '--transition_scheme', "TOP_DOWN_COMPOUND")
|
| 49 |
+
test_parse_transitions.test_compound_open(model)
|
| 50 |
+
|
| 51 |
+
def test_in_order_open(pretrain_file):
|
| 52 |
+
model = build_model(pretrain_file, '--transition_scheme', "IN_ORDER")
|
| 53 |
+
test_parse_transitions.test_in_order_open(model)
|
| 54 |
+
|
| 55 |
+
def test_close(unary_model):
|
| 56 |
+
test_parse_transitions.test_close(unary_model)
|
| 57 |
+
|
| 58 |
+
def run_forward_checks(model, num_states=1):
|
| 59 |
+
"""
|
| 60 |
+
Run a couple small transitions and a forward pass on the given model
|
| 61 |
+
|
| 62 |
+
Results are not checked in any way. This function allows for
|
| 63 |
+
testing that building models with various options results in a
|
| 64 |
+
functional model.
|
| 65 |
+
"""
|
| 66 |
+
states = test_parse_transitions.build_initial_state(model, num_states)
|
| 67 |
+
model(states)
|
| 68 |
+
|
| 69 |
+
shift = parse_transitions.Shift()
|
| 70 |
+
shifts = [shift for _ in range(num_states)]
|
| 71 |
+
states = model.bulk_apply(states, shifts)
|
| 72 |
+
model(states)
|
| 73 |
+
|
| 74 |
+
open_transition = parse_transitions.OpenConstituent("NP")
|
| 75 |
+
open_transitions = [open_transition for _ in range(num_states)]
|
| 76 |
+
assert open_transition.is_legal(states[0], model)
|
| 77 |
+
states = model.bulk_apply(states, open_transitions)
|
| 78 |
+
assert states[0].num_opens == 1
|
| 79 |
+
model(states)
|
| 80 |
+
|
| 81 |
+
states = model.bulk_apply(states, shifts)
|
| 82 |
+
model(states)
|
| 83 |
+
states = model.bulk_apply(states, shifts)
|
| 84 |
+
model(states)
|
| 85 |
+
assert states[0].num_opens == 1
|
| 86 |
+
# now should have "mox", "opal" on the constituents
|
| 87 |
+
|
| 88 |
+
close_transition = parse_transitions.CloseConstituent()
|
| 89 |
+
close_transitions = [close_transition for _ in range(num_states)]
|
| 90 |
+
assert close_transition.is_legal(states[0], model)
|
| 91 |
+
states = model.bulk_apply(states, close_transitions)
|
| 92 |
+
assert states[0].num_opens == 0
|
| 93 |
+
|
| 94 |
+
model(states)
|
| 95 |
+
|
| 96 |
+
def test_unary_forward(unary_model):
|
| 97 |
+
"""
|
| 98 |
+
Checks that the forward pass doesn't crash when run after various operations
|
| 99 |
+
|
| 100 |
+
Doesn't check the forward pass for making reasonable answers
|
| 101 |
+
"""
|
| 102 |
+
run_forward_checks(unary_model)
|
| 103 |
+
|
| 104 |
+
def test_lstm_forward(pretrain_file):
|
| 105 |
+
model = build_model(pretrain_file)
|
| 106 |
+
run_forward_checks(model, num_states=1)
|
| 107 |
+
run_forward_checks(model, num_states=2)
|
| 108 |
+
|
| 109 |
+
def test_lstm_layers(pretrain_file):
|
| 110 |
+
model = build_model(pretrain_file, '--num_lstm_layers', '1')
|
| 111 |
+
run_forward_checks(model)
|
| 112 |
+
model = build_model(pretrain_file, '--num_lstm_layers', '2')
|
| 113 |
+
run_forward_checks(model)
|
| 114 |
+
model = build_model(pretrain_file, '--num_lstm_layers', '3')
|
| 115 |
+
run_forward_checks(model)
|
| 116 |
+
|
| 117 |
+
def test_multiple_output_forward(pretrain_file):
|
| 118 |
+
"""
|
| 119 |
+
Test a couple different sizes of output layers
|
| 120 |
+
"""
|
| 121 |
+
model = build_model(pretrain_file, '--num_output_layers', '1', '--num_lstm_layers', '2')
|
| 122 |
+
run_forward_checks(model)
|
| 123 |
+
|
| 124 |
+
model = build_model(pretrain_file, '--num_output_layers', '2', '--num_lstm_layers', '2')
|
| 125 |
+
run_forward_checks(model)
|
| 126 |
+
|
| 127 |
+
model = build_model(pretrain_file, '--num_output_layers', '3', '--num_lstm_layers', '2')
|
| 128 |
+
run_forward_checks(model)
|
| 129 |
+
|
| 130 |
+
def test_no_tag_embedding_forward(pretrain_file):
|
| 131 |
+
"""
|
| 132 |
+
Test that the model continues to work if the tag embedding is turned on or off
|
| 133 |
+
"""
|
| 134 |
+
model = build_model(pretrain_file, '--tag_embedding_dim', '20')
|
| 135 |
+
run_forward_checks(model)
|
| 136 |
+
|
| 137 |
+
model = build_model(pretrain_file, '--tag_embedding_dim', '0')
|
| 138 |
+
run_forward_checks(model)
|
| 139 |
+
|
| 140 |
+
def test_forward_combined_dummy(pretrain_file):
|
| 141 |
+
"""
|
| 142 |
+
Tests combined dummy and open node embeddings
|
| 143 |
+
"""
|
| 144 |
+
model = build_model(pretrain_file, '--combined_dummy_embedding')
|
| 145 |
+
run_forward_checks(model)
|
| 146 |
+
|
| 147 |
+
model = build_model(pretrain_file, '--no_combined_dummy_embedding')
|
| 148 |
+
run_forward_checks(model)
|
| 149 |
+
|
| 150 |
+
def test_nonlinearity_init(pretrain_file):
|
| 151 |
+
"""
|
| 152 |
+
Tests that different initialization methods of the nonlinearities result in valid tensors
|
| 153 |
+
"""
|
| 154 |
+
model = build_model(pretrain_file, '--nonlinearity', 'relu')
|
| 155 |
+
run_forward_checks(model)
|
| 156 |
+
|
| 157 |
+
model = build_model(pretrain_file, '--nonlinearity', 'tanh')
|
| 158 |
+
run_forward_checks(model)
|
| 159 |
+
|
| 160 |
+
model = build_model(pretrain_file, '--nonlinearity', 'silu')
|
| 161 |
+
run_forward_checks(model)
|
| 162 |
+
|
| 163 |
+
def test_forward_charlm(pretrain_file):
|
| 164 |
+
"""
|
| 165 |
+
Tests loading and running a charlm
|
| 166 |
+
|
| 167 |
+
Note that this doesn't test the results of the charlm itself,
|
| 168 |
+
just that the model is shaped correctly
|
| 169 |
+
"""
|
| 170 |
+
forward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "1billion.pt")
|
| 171 |
+
backward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "1billion.pt")
|
| 172 |
+
assert os.path.exists(forward_charlm_path), "Need to download en test models (or update path to the forward charlm)"
|
| 173 |
+
assert os.path.exists(backward_charlm_path), "Need to download en test models (or update path to the backward charlm)"
|
| 174 |
+
|
| 175 |
+
model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'none')
|
| 176 |
+
run_forward_checks(model)
|
| 177 |
+
|
| 178 |
+
model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'words')
|
| 179 |
+
run_forward_checks(model)
|
| 180 |
+
|
| 181 |
+
def test_forward_bert(pretrain_file):
|
| 182 |
+
"""
|
| 183 |
+
Test on a tiny Bert, which hopefully does not take up too much disk space or memory
|
| 184 |
+
"""
|
| 185 |
+
bert_model = "hf-internal-testing/tiny-bert"
|
| 186 |
+
|
| 187 |
+
model = build_model(pretrain_file, '--bert_model', bert_model)
|
| 188 |
+
run_forward_checks(model)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def test_forward_xlnet(pretrain_file):
|
| 192 |
+
"""
|
| 193 |
+
Test on a tiny xlnet, which hopefully does not take up too much disk space or memory
|
| 194 |
+
"""
|
| 195 |
+
bert_model = "hf-internal-testing/tiny-random-xlnet"
|
| 196 |
+
|
| 197 |
+
model = build_model(pretrain_file, '--bert_model', bert_model)
|
| 198 |
+
run_forward_checks(model)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def test_forward_sentence_boundaries(pretrain_file):
|
| 202 |
+
"""
|
| 203 |
+
Test start & stop boundary vectors
|
| 204 |
+
"""
|
| 205 |
+
model = build_model(pretrain_file, '--sentence_boundary_vectors', 'everything')
|
| 206 |
+
run_forward_checks(model)
|
| 207 |
+
|
| 208 |
+
model = build_model(pretrain_file, '--sentence_boundary_vectors', 'words')
|
| 209 |
+
run_forward_checks(model)
|
| 210 |
+
|
| 211 |
+
model = build_model(pretrain_file, '--sentence_boundary_vectors', 'none')
|
| 212 |
+
run_forward_checks(model)
|
| 213 |
+
|
| 214 |
+
def test_forward_constituency_composition(pretrain_file):
|
| 215 |
+
"""
|
| 216 |
+
Test different constituency composition functions
|
| 217 |
+
"""
|
| 218 |
+
model = build_model(pretrain_file, '--constituency_composition', 'bilstm')
|
| 219 |
+
run_forward_checks(model, num_states=2)
|
| 220 |
+
|
| 221 |
+
model = build_model(pretrain_file, '--constituency_composition', 'max')
|
| 222 |
+
run_forward_checks(model, num_states=2)
|
| 223 |
+
|
| 224 |
+
model = build_model(pretrain_file, '--constituency_composition', 'key')
|
| 225 |
+
run_forward_checks(model, num_states=2)
|
| 226 |
+
|
| 227 |
+
model = build_model(pretrain_file, '--constituency_composition', 'untied_key')
|
| 228 |
+
run_forward_checks(model, num_states=2)
|
| 229 |
+
|
| 230 |
+
model = build_model(pretrain_file, '--constituency_composition', 'untied_max')
|
| 231 |
+
run_forward_checks(model, num_states=2)
|
| 232 |
+
|
| 233 |
+
model = build_model(pretrain_file, '--constituency_composition', 'bilstm_max')
|
| 234 |
+
run_forward_checks(model, num_states=2)
|
| 235 |
+
|
| 236 |
+
model = build_model(pretrain_file, '--constituency_composition', 'tree_lstm')
|
| 237 |
+
run_forward_checks(model, num_states=2)
|
| 238 |
+
|
| 239 |
+
model = build_model(pretrain_file, '--constituency_composition', 'tree_lstm_cx')
|
| 240 |
+
run_forward_checks(model, num_states=2)
|
| 241 |
+
|
| 242 |
+
model = build_model(pretrain_file, '--constituency_composition', 'bigram')
|
| 243 |
+
run_forward_checks(model, num_states=2)
|
| 244 |
+
|
| 245 |
+
model = build_model(pretrain_file, '--constituency_composition', 'attn')
|
| 246 |
+
run_forward_checks(model, num_states=2)
|
| 247 |
+
|
| 248 |
+
def test_forward_key_position(pretrain_file):
|
| 249 |
+
"""
|
| 250 |
+
Test KEY and UNTIED_KEY either with or without reduce_position
|
| 251 |
+
"""
|
| 252 |
+
model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '0')
|
| 253 |
+
run_forward_checks(model, num_states=2)
|
| 254 |
+
|
| 255 |
+
model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '32')
|
| 256 |
+
run_forward_checks(model, num_states=2)
|
| 257 |
+
|
| 258 |
+
model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '0')
|
| 259 |
+
run_forward_checks(model, num_states=2)
|
| 260 |
+
|
| 261 |
+
model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '32')
|
| 262 |
+
run_forward_checks(model, num_states=2)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def test_forward_attn_hidden_size(pretrain_file):
|
| 266 |
+
"""
|
| 267 |
+
Test that when attn is used with hidden sizes not evenly divisible by reduce_heads, the model reconfigures the hidden_size
|
| 268 |
+
"""
|
| 269 |
+
model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129')
|
| 270 |
+
assert model.hidden_size >= 129
|
| 271 |
+
assert model.hidden_size % model.reduce_heads == 0
|
| 272 |
+
run_forward_checks(model, num_states=2)
|
| 273 |
+
|
| 274 |
+
model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129', '--reduce_heads', '10')
|
| 275 |
+
assert model.hidden_size == 130
|
| 276 |
+
assert model.reduce_heads == 10
|
| 277 |
+
|
| 278 |
+
def test_forward_partitioned_attention(pretrain_file):
|
| 279 |
+
"""
|
| 280 |
+
Test with & without partitioned attention layers
|
| 281 |
+
"""
|
| 282 |
+
model = build_model(pretrain_file, '--pattn_num_heads', '8', '--pattn_num_layers', '8')
|
| 283 |
+
run_forward_checks(model)
|
| 284 |
+
|
| 285 |
+
model = build_model(pretrain_file, '--pattn_num_heads', '0', '--pattn_num_layers', '0')
|
| 286 |
+
run_forward_checks(model)
|
| 287 |
+
|
| 288 |
+
def test_forward_labeled_attention(pretrain_file):
|
| 289 |
+
"""
|
| 290 |
+
Test with & without labeled attention layers
|
| 291 |
+
"""
|
| 292 |
+
model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16')
|
| 293 |
+
run_forward_checks(model)
|
| 294 |
+
|
| 295 |
+
model = build_model(pretrain_file, '--lattn_d_proj', '0', '--lattn_d_l', '0')
|
| 296 |
+
run_forward_checks(model)
|
| 297 |
+
|
| 298 |
+
model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_combined_input')
|
| 299 |
+
run_forward_checks(model)
|
| 300 |
+
|
| 301 |
+
def test_lattn_partitioned(pretrain_file):
|
| 302 |
+
model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_partitioned')
|
| 303 |
+
run_forward_checks(model)
|
| 304 |
+
|
| 305 |
+
model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned')
|
| 306 |
+
run_forward_checks(model)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def test_lattn_projection(pretrain_file):
|
| 310 |
+
"""
|
| 311 |
+
Test with & without labeled attention layers
|
| 312 |
+
"""
|
| 313 |
+
with pytest.raises(ValueError):
|
| 314 |
+
# this is too small
|
| 315 |
+
model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '256', '--lattn_partitioned')
|
| 316 |
+
run_forward_checks(model)
|
| 317 |
+
|
| 318 |
+
model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned', '--lattn_d_input_proj', '256')
|
| 319 |
+
run_forward_checks(model)
|
| 320 |
+
|
| 321 |
+
model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '768')
|
| 322 |
+
run_forward_checks(model)
|
| 323 |
+
|
| 324 |
+
# check that it works if we turn off the projection,
|
| 325 |
+
# in case having it on beccomes the default
|
| 326 |
+
model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '0')
|
| 327 |
+
run_forward_checks(model)
|
| 328 |
+
|
| 329 |
+
def test_forward_timing_choices(pretrain_file):
|
| 330 |
+
"""
|
| 331 |
+
Test different timing / position encodings
|
| 332 |
+
"""
|
| 333 |
+
model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', 'sin')
|
| 334 |
+
run_forward_checks(model)
|
| 335 |
+
|
| 336 |
+
model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', 'learned')
|
| 337 |
+
run_forward_checks(model)
|
| 338 |
+
|
| 339 |
+
def test_transition_stack(pretrain_file):
|
| 340 |
+
"""
|
| 341 |
+
Test different transition stack types: lstm & attention
|
| 342 |
+
"""
|
| 343 |
+
model = build_model(pretrain_file,
|
| 344 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 345 |
+
'--transition_stack', 'attn', '--transition_heads', '1')
|
| 346 |
+
run_forward_checks(model)
|
| 347 |
+
|
| 348 |
+
model = build_model(pretrain_file,
|
| 349 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 350 |
+
'--transition_stack', 'attn', '--transition_heads', '4')
|
| 351 |
+
run_forward_checks(model)
|
| 352 |
+
|
| 353 |
+
model = build_model(pretrain_file,
|
| 354 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 355 |
+
'--transition_stack', 'lstm')
|
| 356 |
+
run_forward_checks(model)
|
| 357 |
+
|
| 358 |
+
def test_constituent_stack(pretrain_file):
|
| 359 |
+
"""
|
| 360 |
+
Test different constituent stack types: lstm & attention
|
| 361 |
+
"""
|
| 362 |
+
model = build_model(pretrain_file,
|
| 363 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 364 |
+
'--constituent_stack', 'attn', '--constituent_heads', '1')
|
| 365 |
+
run_forward_checks(model)
|
| 366 |
+
|
| 367 |
+
model = build_model(pretrain_file,
|
| 368 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 369 |
+
'--constituent_stack', 'attn', '--constituent_heads', '4')
|
| 370 |
+
run_forward_checks(model)
|
| 371 |
+
|
| 372 |
+
model = build_model(pretrain_file,
|
| 373 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 374 |
+
'--constituent_stack', 'lstm')
|
| 375 |
+
run_forward_checks(model)
|
| 376 |
+
|
| 377 |
+
def test_different_transition_sizes(pretrain_file):
|
| 378 |
+
"""
|
| 379 |
+
If the transition hidden size and embedding size are different, the model should still work
|
| 380 |
+
"""
|
| 381 |
+
model = build_model(pretrain_file,
|
| 382 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 383 |
+
'--transition_embedding_dim', '10', '--transition_hidden_size', '10',
|
| 384 |
+
'--sentence_boundary_vectors', 'everything')
|
| 385 |
+
run_forward_checks(model)
|
| 386 |
+
|
| 387 |
+
model = build_model(pretrain_file,
|
| 388 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 389 |
+
'--transition_embedding_dim', '20', '--transition_hidden_size', '10',
|
| 390 |
+
'--sentence_boundary_vectors', 'everything')
|
| 391 |
+
run_forward_checks(model)
|
| 392 |
+
|
| 393 |
+
model = build_model(pretrain_file,
|
| 394 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 395 |
+
'--transition_embedding_dim', '10', '--transition_hidden_size', '20',
|
| 396 |
+
'--sentence_boundary_vectors', 'everything')
|
| 397 |
+
run_forward_checks(model)
|
| 398 |
+
|
| 399 |
+
model = build_model(pretrain_file,
|
| 400 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 401 |
+
'--transition_embedding_dim', '10', '--transition_hidden_size', '10',
|
| 402 |
+
'--sentence_boundary_vectors', 'none')
|
| 403 |
+
run_forward_checks(model)
|
| 404 |
+
|
| 405 |
+
model = build_model(pretrain_file,
|
| 406 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 407 |
+
'--transition_embedding_dim', '20', '--transition_hidden_size', '10',
|
| 408 |
+
'--sentence_boundary_vectors', 'none')
|
| 409 |
+
run_forward_checks(model)
|
| 410 |
+
|
| 411 |
+
model = build_model(pretrain_file,
|
| 412 |
+
'--pattn_num_layers', '0', '--lattn_d_proj', '0',
|
| 413 |
+
'--transition_embedding_dim', '10', '--transition_hidden_size', '20',
|
| 414 |
+
'--sentence_boundary_vectors', 'none')
|
| 415 |
+
run_forward_checks(model)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def test_lstm_tree_forward(pretrain_file):
|
| 419 |
+
"""
|
| 420 |
+
Test the LSTM_TREE forward pass
|
| 421 |
+
"""
|
| 422 |
+
model = build_model(pretrain_file, '--num_tree_lstm_layers', '1', '--constituency_composition', 'tree_lstm')
|
| 423 |
+
run_forward_checks(model)
|
| 424 |
+
model = build_model(pretrain_file, '--num_tree_lstm_layers', '2', '--constituency_composition', 'tree_lstm')
|
| 425 |
+
run_forward_checks(model)
|
| 426 |
+
model = build_model(pretrain_file, '--num_tree_lstm_layers', '3', '--constituency_composition', 'tree_lstm')
|
| 427 |
+
run_forward_checks(model)
|
| 428 |
+
|
| 429 |
+
def test_lstm_tree_cx_forward(pretrain_file):
|
| 430 |
+
"""
|
| 431 |
+
Test the LSTM_TREE_CX forward pass
|
| 432 |
+
"""
|
| 433 |
+
model = build_model(pretrain_file, '--num_tree_lstm_layers', '1', '--constituency_composition', 'tree_lstm_cx')
|
| 434 |
+
run_forward_checks(model)
|
| 435 |
+
model = build_model(pretrain_file, '--num_tree_lstm_layers', '2', '--constituency_composition', 'tree_lstm_cx')
|
| 436 |
+
run_forward_checks(model)
|
| 437 |
+
model = build_model(pretrain_file, '--num_tree_lstm_layers', '3', '--constituency_composition', 'tree_lstm_cx')
|
| 438 |
+
run_forward_checks(model)
|
| 439 |
+
|
| 440 |
+
def test_maxout(pretrain_file):
|
| 441 |
+
"""
|
| 442 |
+
Test with and without maxout layers for output
|
| 443 |
+
"""
|
| 444 |
+
model = build_model(pretrain_file, '--maxout_k', '0')
|
| 445 |
+
run_forward_checks(model)
|
| 446 |
+
# check the output size & implicitly check the type
|
| 447 |
+
# to check for a particularly silly bug
|
| 448 |
+
assert model.output_layers[-1].weight.shape[0] == len(model.transitions)
|
| 449 |
+
|
| 450 |
+
model = build_model(pretrain_file, '--maxout_k', '2')
|
| 451 |
+
run_forward_checks(model)
|
| 452 |
+
assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 2
|
| 453 |
+
|
| 454 |
+
model = build_model(pretrain_file, '--maxout_k', '3')
|
| 455 |
+
run_forward_checks(model)
|
| 456 |
+
assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 3
|
| 457 |
+
|
| 458 |
+
def check_structure_test(pretrain_file, args1, args2):
|
| 459 |
+
"""
|
| 460 |
+
Test that the "copy" method copies the parameters from one model to another
|
| 461 |
+
|
| 462 |
+
Also check that the copied models produce the same results
|
| 463 |
+
"""
|
| 464 |
+
set_random_seed(1000)
|
| 465 |
+
other = build_model(pretrain_file, *args1)
|
| 466 |
+
other.eval()
|
| 467 |
+
|
| 468 |
+
set_random_seed(1001)
|
| 469 |
+
model = build_model(pretrain_file, *args2)
|
| 470 |
+
model.eval()
|
| 471 |
+
|
| 472 |
+
assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
|
| 473 |
+
assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
|
| 474 |
+
|
| 475 |
+
model.copy_with_new_structure(other)
|
| 476 |
+
|
| 477 |
+
assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight)
|
| 478 |
+
assert torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight)
|
| 479 |
+
# the norms will be the same, as the non-zero values are all the same
|
| 480 |
+
assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), torch.linalg.norm(other.word_lstm.weight_ih_l0))
|
| 481 |
+
|
| 482 |
+
# now, check that applying one transition to an initial state
|
| 483 |
+
# results in the same values in the output states for both models
|
| 484 |
+
# as the pattn layer inputs are 0, the output values should be equal
|
| 485 |
+
shift = [parse_transitions.Shift()]
|
| 486 |
+
model_states = test_parse_transitions.build_initial_state(model, 1)
|
| 487 |
+
model_states = model.bulk_apply(model_states, shift)
|
| 488 |
+
|
| 489 |
+
other_states = test_parse_transitions.build_initial_state(other, 1)
|
| 490 |
+
other_states = other.bulk_apply(other_states, shift)
|
| 491 |
+
|
| 492 |
+
for i, j in zip(other_states[0].word_queue, model_states[0].word_queue):
|
| 493 |
+
assert torch.allclose(i.hx, j.hx, atol=1e-07)
|
| 494 |
+
for i, j in zip(other_states[0].transitions, model_states[0].transitions):
|
| 495 |
+
assert torch.allclose(i.lstm_hx, j.lstm_hx)
|
| 496 |
+
assert torch.allclose(i.lstm_cx, j.lstm_cx)
|
| 497 |
+
for i, j in zip(other_states[0].constituents, model_states[0].constituents):
|
| 498 |
+
assert (i.value is None) == (j.value is None)
|
| 499 |
+
if i.value is not None:
|
| 500 |
+
assert torch.allclose(i.value.tree_hx, j.value.tree_hx, atol=1e-07)
|
| 501 |
+
assert torch.allclose(i.lstm_hx, j.lstm_hx)
|
| 502 |
+
assert torch.allclose(i.lstm_cx, j.lstm_cx)
|
| 503 |
+
|
| 504 |
+
def test_copy_with_new_structure_same(pretrain_file):
|
| 505 |
+
"""
|
| 506 |
+
Test that copying the structure with no changes works as expected
|
| 507 |
+
"""
|
| 508 |
+
check_structure_test(pretrain_file,
|
| 509 |
+
['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
|
| 510 |
+
['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'])
|
| 511 |
+
|
| 512 |
+
def test_copy_with_new_structure_untied(pretrain_file):
|
| 513 |
+
"""
|
| 514 |
+
Test that copying the structure with no changes works as expected
|
| 515 |
+
"""
|
| 516 |
+
check_structure_test(pretrain_file,
|
| 517 |
+
['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'MAX'],
|
| 518 |
+
['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'UNTIED_MAX'])
|
| 519 |
+
|
| 520 |
+
def test_copy_with_new_structure_pattn(pretrain_file):
|
| 521 |
+
check_structure_test(pretrain_file,
|
| 522 |
+
['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
|
| 523 |
+
['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
|
| 524 |
+
|
| 525 |
+
def test_copy_with_new_structure_both(pretrain_file):
|
| 526 |
+
check_structure_test(pretrain_file,
|
| 527 |
+
['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'],
|
| 528 |
+
['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
|
| 529 |
+
|
| 530 |
+
def test_copy_with_new_structure_lattn(pretrain_file):
|
| 531 |
+
check_structure_test(pretrain_file,
|
| 532 |
+
['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'],
|
| 533 |
+
['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'])
|
| 534 |
+
|
| 535 |
+
def test_parse_tagged_words(pretrain_file):
|
| 536 |
+
"""
|
| 537 |
+
Small test which doesn't check results, just execution
|
| 538 |
+
"""
|
| 539 |
+
model = build_model(pretrain_file)
|
| 540 |
+
|
| 541 |
+
sentence = [("I", "PRP"), ("am", "VBZ"), ("Luffa", "NNP")]
|
| 542 |
+
|
| 543 |
+
# we don't expect a useful tree out of a random model
|
| 544 |
+
# so we don't check the result
|
| 545 |
+
# just check that it works without crashing
|
| 546 |
+
result = model.parse_tagged_words([sentence], 10)
|
| 547 |
+
assert len(result) == 1
|
| 548 |
+
pts = [x for x in result[0].yield_preterminals()]
|
| 549 |
+
|
| 550 |
+
for word, pt in zip(sentence, pts):
|
| 551 |
+
assert pt.children[0].label == word[0]
|
| 552 |
+
assert pt.label == word[1]
|
stanza/stanza/tests/constituency/test_text_processing.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Run through the various text processing methods for using the parser on text files / directories
|
| 3 |
+
|
| 4 |
+
Uses a simple tree where the parser should always get it right, but things could potentially go wrong
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import glob
|
| 8 |
+
import os
|
| 9 |
+
import pytest
|
| 10 |
+
|
| 11 |
+
from stanza import Pipeline
|
| 12 |
+
|
| 13 |
+
from stanza.models.constituency import text_processing
|
| 14 |
+
from stanza.models.constituency import tree_reader
|
| 15 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 16 |
+
|
| 17 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 18 |
+
|
| 19 |
+
@pytest.fixture(scope="module")
|
| 20 |
+
def pipeline():
|
| 21 |
+
return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos, constituency", tokenize_pretokenized=True)
|
| 22 |
+
|
| 23 |
+
def test_read_tokenized_file(tmp_path):
|
| 24 |
+
filename = str(tmp_path / "test_input.txt")
|
| 25 |
+
with open(filename, "w") as fout:
|
| 26 |
+
# test that the underscore token comes back with spaces
|
| 27 |
+
fout.write("This is a_small test\nLine two\n")
|
| 28 |
+
text, ids = text_processing.read_tokenized_file(filename)
|
| 29 |
+
assert text == [['This', 'is', 'a small', 'test'], ['Line', 'two']]
|
| 30 |
+
assert ids == [None, None]
|
| 31 |
+
|
| 32 |
+
def test_parse_tokenized_sentences(pipeline):
|
| 33 |
+
con_processor = pipeline.processors["constituency"]
|
| 34 |
+
model = con_processor._model
|
| 35 |
+
args = model.args
|
| 36 |
+
|
| 37 |
+
sentences = [["This", "is", "a", "test"]]
|
| 38 |
+
trees = text_processing.parse_tokenized_sentences(args, model, [pipeline], sentences)
|
| 39 |
+
predictions = [x.predictions for x in trees]
|
| 40 |
+
assert len(predictions) == 1
|
| 41 |
+
scored_trees = predictions[0]
|
| 42 |
+
assert len(scored_trees) == 1
|
| 43 |
+
result = "{}".format(scored_trees[0].tree)
|
| 44 |
+
expected = "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))"
|
| 45 |
+
assert result == expected
|
| 46 |
+
|
| 47 |
+
def test_parse_text(tmp_path, pipeline):
|
| 48 |
+
con_processor = pipeline.processors["constituency"]
|
| 49 |
+
model = con_processor._model
|
| 50 |
+
args = model.args
|
| 51 |
+
|
| 52 |
+
raw_file = str(tmp_path / "test_input.txt")
|
| 53 |
+
with open(raw_file, "w") as fout:
|
| 54 |
+
fout.write("This is a test\nThis is another test\n")
|
| 55 |
+
output_file = str(tmp_path / "test_output.txt")
|
| 56 |
+
text_processing.parse_text(args, model, [pipeline], tokenized_file=raw_file, predict_file=output_file)
|
| 57 |
+
|
| 58 |
+
trees = tree_reader.read_treebank(output_file)
|
| 59 |
+
trees = ["{}".format(x) for x in trees]
|
| 60 |
+
expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
|
| 61 |
+
"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
|
| 62 |
+
assert trees == expected_trees
|
| 63 |
+
|
| 64 |
+
def test_parse_dir(tmp_path, pipeline):
|
| 65 |
+
con_processor = pipeline.processors["constituency"]
|
| 66 |
+
model = con_processor._model
|
| 67 |
+
args = model.args
|
| 68 |
+
|
| 69 |
+
raw_dir = str(tmp_path / "input")
|
| 70 |
+
os.makedirs(raw_dir)
|
| 71 |
+
raw_f1 = str(tmp_path / "input" / "f1.txt")
|
| 72 |
+
raw_f2 = str(tmp_path / "input" / "f2.txt")
|
| 73 |
+
output_dir = str(tmp_path / "output")
|
| 74 |
+
|
| 75 |
+
with open(raw_f1, "w") as fout:
|
| 76 |
+
fout.write("This is a test")
|
| 77 |
+
with open(raw_f2, "w") as fout:
|
| 78 |
+
fout.write("This is another test")
|
| 79 |
+
|
| 80 |
+
text_processing.parse_dir(args, model, [pipeline], raw_dir, output_dir)
|
| 81 |
+
output_files = sorted(glob.glob(os.path.join(output_dir, "*")))
|
| 82 |
+
expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
|
| 83 |
+
"(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
|
| 84 |
+
for output_file, expected_tree in zip(output_files, expected_trees):
|
| 85 |
+
trees = tree_reader.read_treebank(output_file)
|
| 86 |
+
assert len(trees) == 1
|
| 87 |
+
assert "{}".format(trees[0]) == expected_tree
|
| 88 |
+
|
| 89 |
+
def test_parse_text(tmp_path, pipeline):
    """Round trip a tokenized file through load_model_parse_text.

    The model path and args are pulled off the pipeline's constituency
    processor; the parsed output file is read back and compared to the
    expected bracketed trees.
    """
    con_processor = pipeline.processors["constituency"]
    model = con_processor._model
    model_path = con_processor._config['model_path']

    input_path = str(tmp_path / "test_input.txt")
    output_path = str(tmp_path / "test_output.txt")
    with open(input_path, "w") as handle:
        handle.write("This is a test\nThis is another test\n")

    # copy the args so the processor's own configuration stays untouched
    args = dict(model.args)
    args['tokenized_file'] = input_path
    args['predict_file'] = output_path

    text_processing.load_model_parse_text(args, model_path, [pipeline])

    parsed = ["{}".format(t) for t in tree_reader.read_treebank(output_path)]
    expected = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))",
                "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"]
    assert parsed == expected
|
stanza/stanza/tests/constituency/test_top_down_oracle.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from stanza.models.constituency.base_model import SimpleModel
|
| 4 |
+
from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, TransitionScheme
|
| 5 |
+
from stanza.models.constituency.top_down_oracle import *
|
| 6 |
+
from stanza.models.constituency.transition_sequence import build_sequence
|
| 7 |
+
from stanza.models.constituency.tree_reader import read_trees
|
| 8 |
+
|
| 9 |
+
from stanza.tests.constituency.test_transition_sequence import reconstruct_tree
|
| 10 |
+
|
| 11 |
+
# run these tests in the pipeline / travis test groups
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]

# small well-formed tree used by the open/shift repair tests
OPEN_SHIFT_EXAMPLE_TREE = """
( (S
    (NP (NNP Jennifer) (NNP Sh\'reyan))
    (VP (VBZ has)
        (NP (RB nice) (NNS antennae)))))
"""

# a real tree which previously triggered an error in the open/shift oracle:
# several consecutive Opens must all be removed when a Shift happens instead
OPEN_SHIFT_PROBLEM_TREE = """
(ROOT (S (NP (NP (NP (DT The) (`` ``) (JJ Thin) (NNP Man) ('' '') (NN series)) (PP (IN of) (NP (NNS movies)))) (, ,) (CONJP (RB as) (RB well) (IN as)) (NP (JJ many) (NNS others)) (, ,)) (VP (VBD based) (NP (PRP$ their) (JJ entire) (JJ comedic) (NN appeal)) (PP (IN on) (NP (NP (DT the) (NN star) (NNS detectives) (POS ')) (JJ witty) (NNS quips) (CC and) (NNS puns))) (SBAR (IN as) (S (NP (NP (JJ other) (NNS characters)) (PP (IN in) (NP (DT the) (NNS movies)))) (VP (VBD were) (VP (VBN murdered)))))) (. .)))
"""

# root label set passed to every repair function in these tests
ROOT_LABELS = ["ROOT"]
|
| 25 |
+
|
| 26 |
+
def get_single_repair(gold_sequence, wrong_transition, repair_fn, idx, *args, **kwargs):
    """Apply one oracle repair function at position idx of gold_sequence.

    The repair function is given the gold transition at idx, the incorrect
    prediction, the full gold sequence, and the module-level ROOT_LABELS.
    The model and state arguments are passed as None, since these unit
    tests exercise the repair functions in isolation.
    """
    gold_transition = gold_sequence[idx]
    return repair_fn(gold_transition, wrong_transition, gold_sequence, idx,
                     ROOT_LABELS, None, None, *args, **kwargs)
|
| 28 |
+
|
| 29 |
+
def build_state(model, tree, num_transitions):
    """Return the parser state after applying a prefix of the gold transitions.

    Builds the top-down transition sequence for `tree`, then applies the
    first `num_transitions` of them to a fresh initial state, asserting that
    each transition is legal along the way.
    """
    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    states = model.initial_state_from_gold_trees([tree], [transitions])
    for idx, t in enumerate(transitions[:num_transitions]):
        # bug fix: the message previously formatted an undefined name
        # "sequence", which would raise NameError instead of showing the
        # intended assertion failure
        assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, transitions)
        states = model.bulk_apply(states, [t])
    state = states[0]
    return state
|
| 37 |
+
|
| 38 |
+
def test_fix_open_shift():
    """Repair a single Open wrongly predicted as a Shift.

    Checks both an early error (the first NP's Open, index 2) and a late
    error (the inner NP's Open, index 8) against hand-written sequences.
    """
    trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    # the gold sequence for the example tree
    EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    # repair when the error occurs at index 2: the NP bracket disappears
    EXPECTED_FIX_EARLY = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    # repair when the error occurs at index 8: the inner NP bracket disappears
    EXPECTED_FIX_LATE = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]

    assert transitions == EXPECTED_ORIG

    new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 2)
    assert new_transitions == EXPECTED_FIX_EARLY

    new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 8)
    assert new_transitions == EXPECTED_FIX_LATE
|
| 55 |
+
|
| 56 |
+
def test_fix_open_shift_observed_error():
    """
    Ran into an error on this tree, need to fix it

    The problem is the multiple Open in a row all need to be removed when a Shift happens
    """
    trees = read_trees(OPEN_SHIFT_PROBLEM_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    # the single-open repair cannot handle consecutive Opens, so it refuses
    new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 2)
    assert new_transitions is None

    # the multi-open variant removes the whole run of Opens (and their Closes)
    new_transitions = get_single_repair(transitions, Shift(), fix_multiple_open_shift, 2)

    # Can break the expected transitions down like this:
    # [OpenConstituent(('ROOT',)), OpenConstituent(('S',)),
    # all gone: OpenConstituent(('NP',)), OpenConstituent(('NP',)), OpenConstituent(('NP',)),
    # Shift, Shift, Shift, Shift, Shift, Shift,
    # gone: CloseConstituent,
    # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), Shift, CloseConstituent, CloseConstituent,
    # gone: CloseConstituent,
    # Shift, OpenConstituent(('CONJP',)), Shift, Shift, Shift, CloseConstituent, OpenConstituent(('NP',)), Shift, Shift, CloseConstituent, Shift,
    # gone: CloseConstituent,
    # and then the rest:
    # OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)),
    # Shift, Shift, Shift, Shift, CloseConstituent,
    # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)),
    # OpenConstituent(('NP',)), Shift, Shift, Shift, Shift,
    # CloseConstituent, Shift, Shift, Shift, Shift, CloseConstituent,
    # CloseConstituent, OpenConstituent(('SBAR',)), Shift,
    # OpenConstituent(('S',)), OpenConstituent(('NP',)),
    # OpenConstituent(('NP',)), Shift, Shift, CloseConstituent,
    # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)),
    # Shift, Shift, CloseConstituent, CloseConstituent,
    # CloseConstituent, OpenConstituent(('VP',)), Shift,
    # OpenConstituent(('VP',)), Shift, CloseConstituent,
    # CloseConstituent, CloseConstituent, CloseConstituent,
    # CloseConstituent, Shift, CloseConstituent, CloseConstituent]
    expected_transitions = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), Shift(), Shift(), Shift(), Shift(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), Shift(), OpenConstituent('CONJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('SBAR'), Shift(), OpenConstituent('S'), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('VP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]

    assert new_transitions == expected_transitions
|
| 99 |
+
|
| 100 |
+
def test_open_open_ambiguous_unary_fix():
    """Repair an Open predicted with the wrong label, treated as a unary.

    The incorrect VP is opened around the NP and closed immediately after
    it, leaving the rest of the gold sequence intact.
    """
    trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    # the spurious VP wraps just the NP (unary interpretation)
    EXPECTED_FIX = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == EXPECTED_ORIG
    new_transitions = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_unary, 2)
    assert new_transitions == EXPECTED_FIX
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def test_open_open_ambiguous_later_fix():
    """Repair an Open predicted with the wrong label, closed later.

    Unlike the unary variant, the spurious VP stays open longer and gains
    an extra CloseConstituent at the end of the sequence.
    """
    trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    # note the extra trailing CloseConstituent closing the spurious VP
    EXPECTED_FIX = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == EXPECTED_ORIG
    new_transitions = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_later, 2)
    assert new_transitions == EXPECTED_FIX
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# base tree for the close/shift repair tests: one constituent after the ADJP
CLOSE_SHIFT_EXAMPLE_TREE = """
( (NP (DT a)
      (ADJP (NN stock) (HYPH -) (VBG picking))
      (NN tool)))
"""

# not intended to be a correct tree
CLOSE_SHIFT_DEEP_EXAMPLE_TREE = """
( (NP (DT a)
      (VP (ADJP (NN stock) (HYPH -) (VBG picking)))
      (NN tool)))
"""

# not intended to be a correct tree
CLOSE_SHIFT_OPEN_EXAMPLE_TREE = """
( (NP (DT a)
      (ADJP (NN stock) (HYPH -) (VBG picking))
      (NP (NN tool))))
"""

# two words follow the ADJP, so a close/shift repair is ambiguous:
# the extra shifted word could be closed immediately or later
CLOSE_SHIFT_AMBIGUOUS_TREE = """
( (NP (DT a)
      (ADJP (NN stock) (HYPH -) (VBG picking))
      (NN tool)
      (NN foo)))
"""
|
| 152 |
+
|
| 153 |
+
def test_fix_close_shift_ambiguous_immediate():
    """
    Test the result when a close/shift error occurs and we want to close the new, incorrect constituent immediately
    """
    # NOTE(review): despite the "immediate" name, this test invokes
    # fix_close_shift_ambiguous_later (and the expected sequence keeps the
    # ADJP open across both extra shifts) — confirm the name/function pairing
    # against top_down_oracle's definitions
    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    # the error: the CloseConstituent at index 7 was predicted as transitions[8] (a Shift)
    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_later, 7)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original
    assert new_sequence == expected_update
|
| 167 |
+
|
| 168 |
+
def test_fix_close_shift_ambiguous_later():
    # test that the one with two shifts, which is ambiguous, gets rejected
    # NOTE(review): despite the "later" name, this test invokes
    # fix_close_shift_ambiguous_immediate (the expected sequence closes the
    # ADJP right after the first extra shift) — confirm the pairing
    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    # the error: the CloseConstituent at index 7 was predicted as transitions[8] (a Shift)
    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_immediate, 7)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original
    assert new_sequence == expected_update
|
| 180 |
+
|
| 181 |
+
def test_oracle_with_optional_level():
    """Check that an optional repair is only applied when its flag is enabled.

    With an empty "additional oracle levels" string the ambiguous
    close/shift error is classified but not repaired; enabling
    CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR makes the oracle return the
    repaired sequence.
    """
    tree = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)[0]
    gold_sequence = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    assert transitions == gold_sequence

    # oracle with no optional repair levels enabled
    oracle = TopDownOracle(ROOT_LABELS, 1, "", "")

    # NOTE(review): the model is built with TOP_DOWN_UNARY while the
    # transitions use TOP_DOWN — presumably intentional, but worth confirming
    model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY, root_labels=ROOT_LABELS)
    state = build_state(model, tree, 7)
    fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8],
                                         model=model,
                                         state=state)
    # classified as an unrepaired close/shift error: no new sequence returned
    assert fix is RepairType.OTHER_CLOSE_SHIFT
    assert new_sequence is None

    # now enable the ambiguous-immediate repair and expect a fixed sequence
    oracle = TopDownOracle(ROOT_LABELS, 1, "CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR", "")
    fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8],
                                         model=model,
                                         state=state)
    assert fix is RepairType.CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR
    assert new_sequence == expected_update
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def test_fix_close_shift():
    """
    Test a tree of the kind we expect the close/shift to be able to get right
    """
    trees = read_trees(CLOSE_SHIFT_EXAMPLE_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

    # the Close at index 7 was predicted as transitions[8] (a Shift)
    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift, 7)

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
    # only one word follows the ADJP, so pulling it inside is unambiguous
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original
    assert new_sequence == expected_update

    # test that the one with two shifts, which is ambiguous, gets rejected
    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift, 7)
    assert new_sequence is None
|
| 232 |
+
|
| 233 |
+
def test_fix_close_shift_deeper_tree():
    """
    Test a tree of the kind we expect the close/shift to be able to get right
    """
    trees = read_trees(CLOSE_SHIFT_DEEP_EXAMPLE_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

    # the repair should behave the same with and without the count_opens option
    for count_opens in [True, False]:
        # the Close at index 8 was predicted as transitions[10] (a Shift)
        new_sequence = get_single_repair(transitions, transitions[10], fix_close_shift, 8, count_opens=count_opens)

        expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()]
        expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
        assert transitions == expected_original
        assert new_sequence == expected_update
|
| 250 |
+
|
| 251 |
+
def test_fix_close_shift_open_tree():
    """
    We would like the close/shift to get this case right as well
    """
    trees = read_trees(CLOSE_SHIFT_OPEN_EXAMPLE_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

    # plain fix_close_shift refuses: the shifted word sits under its own NP
    new_sequence = get_single_repair(transitions, transitions[9], fix_close_shift, 7, count_opens=False)
    assert new_sequence is None

    # the variant that accounts for intervening Opens can do the repair
    new_sequence = get_single_repair(transitions, transitions[9], fix_close_shift_with_opens, 7)

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    # the inner NP around the final word is dropped by the repair
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original
    assert new_sequence == expected_update
|
| 270 |
+
|
| 271 |
+
# base tree for the close/open repair tests: exactly one later PP,
# so a close/open confusion can be repaired unambiguously
CLOSE_OPEN_EXAMPLE_TREE = """
( (VP (VBZ eat)
      (NP (NN spaghetti))
      (PP (IN with) (DT a) (NN fork))))
"""

# the later constituent has a different label (NP, not PP),
# so the "correct open" repair must refuse
CLOSE_OPEN_DIFFERENT_LABEL_TREE = """
( (VP (VBZ eat)
      (NP (NN spaghetti))
      (NP (DT a) (NN fork))))
"""

# two later PPs make the close/open repair ambiguous
CLOSE_OPEN_TWO_LABELS_TREE = """
( (VP (VBZ eat)
      (NP (NN spaghetti))
      (PP (IN with) (DT a) (NN fork))
      (PP (IN in) (DT a) (NN restaurant))))
"""
|
| 289 |
+
|
| 290 |
+
def test_fix_close_open():
    """A Close wrongly predicted as the following Open: the PP ends up
    nested inside the NP instead of following it."""
    trees = read_trees(CLOSE_OPEN_EXAMPLE_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

    # sanity check the error site: Close at 5, Open(PP) at 6
    assert isinstance(transitions[5], CloseConstituent)
    assert transitions[6] == OpenConstituent("PP")

    new_transitions = get_single_repair(transitions, transitions[6], fix_close_open_correct_open, 5)

    expected_original = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()]
    # the NP's Close moves after the PP's Close: PP becomes a child of NP
    expected_update = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()]

    assert transitions == expected_original
    assert new_transitions == expected_update
|
| 307 |
+
|
| 308 |
+
def test_fix_close_open_invalid():
    """The "correct open" close/open repair must refuse both a mismatched
    label and an ambiguous (two candidate PPs) tree."""
    invalid_trees = (CLOSE_OPEN_DIFFERENT_LABEL_TREE, CLOSE_OPEN_TWO_LABELS_TREE)
    for tree_text in invalid_trees:
        parsed = read_trees(tree_text)
        assert len(parsed) == 1
        tree = parsed[0]

        transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)

        # sanity check the error site: Close at 5 followed by an Open at 6
        assert isinstance(transitions[5], CloseConstituent)
        assert isinstance(transitions[6], OpenConstituent)

        repaired = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open, 5)
        assert repaired is None
|
| 321 |
+
|
| 322 |
+
def test_fix_close_open_ambiguous_immediate():
    """
    Test that a fix for an ambiguous close/open works as expected
    """
    trees = read_trees(CLOSE_OPEN_TWO_LABELS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    # sanity check the error site: Close at 5 followed by an Open at 6
    assert isinstance(transitions[5], CloseConstituent)
    assert isinstance(transitions[6], OpenConstituent)

    # the gold sequence round trips to the original tree
    reconstructed = reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN)
    assert tree == reconstructed

    # with check_close=False the ambiguity is accepted; only the first PP
    # is pulled inside the NP
    new_transitions = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open, 5, check_close=False)
    reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN)

    expected = """
( (VP (VBZ eat)
      (NP (NN spaghetti)
          (PP (IN with) (DT a) (NN fork)))
      (PP (IN in) (DT a) (NN restaurant))))
"""
    expected = read_trees(expected)[0]
    assert reconstructed == expected
|
| 348 |
+
|
| 349 |
+
def test_fix_close_open_ambiguous_later():
    """
    Test that a fix for an ambiguous close/open works as expected
    """
    trees = read_trees(CLOSE_OPEN_TWO_LABELS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    # sanity check the error site: Close at 5 followed by an Open at 6
    assert isinstance(transitions[5], CloseConstituent)
    assert isinstance(transitions[6], OpenConstituent)

    # the gold sequence round trips to the original tree
    reconstructed = reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN)
    assert tree == reconstructed

    # the "later" variant keeps the NP open across BOTH trailing PPs
    new_transitions = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open_ambiguous_later, 5, check_close=False)
    reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN)

    expected = """
( (VP (VBZ eat)
      (NP (NN spaghetti)
          (PP (IN with) (DT a) (NN fork))
          (PP (IN in) (DT a) (NN restaurant)))))
"""
    expected = read_trees(expected)[0]
    assert reconstructed == expected
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# (original tree, expected tree after the shift/close repair, index of the
# Shift that was wrongly predicted as a Close); a None position or expected
# tree makes test_shift_close print debug output instead of asserting
SHIFT_CLOSE_EXAMPLES = [
    ("((S (NP (DT an) (NML (NNP Oct) (CD 19)) (NN review))))", "((S (NP (DT an) (NML (NNP Oct) (CD 19))) (NN review)))", 8),
    ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
     "((S (NP (` `) (NP (DT The)) (NN Misanthrope) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", 6),
    ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
     "((S (NP (` `) (NP (DT The) (NN Misanthrope))) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre)))))", 8),
    ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))",
     "((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman)) (NNP Theatre)))))", 13),
]
|
| 386 |
+
|
| 387 |
+
def test_shift_close():
    """For each SHIFT_CLOSE_EXAMPLES entry, repair a Shift predicted as a
    Close and compare the reconstructed tree to the expected tree.

    Entries with a None position or None expected tree are development
    scaffolding: they print the sequences/tree so a new example's values
    can be filled in, instead of asserting.
    """
    for idx, (orig_tree, expected_tree, shift_position) in enumerate(SHIFT_CLOSE_EXAMPLES):
        trees = read_trees(orig_tree)
        assert len(trees) == 1
        tree = trees[0]

        transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
        if shift_position is None:
            # scaffolding: dump the gold sequence so the author can pick a position
            print(transitions)
            continue

        assert isinstance(transitions[shift_position], Shift)
        new_transitions = get_single_repair(transitions, CloseConstituent(), fix_shift_close, shift_position)
        reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN)
        if expected_tree is None:
            # scaffolding: dump the repaired result so the author can record it
            print(transitions)
            print(new_transitions)

            print("{:P}".format(reconstructed))
        else:
            expected_tree = read_trees(expected_tree)
            assert len(expected_tree) == 1
            expected_tree = expected_tree[0]

            assert reconstructed == expected_tree
|
| 412 |
+
|
| 413 |
+
def test_shift_open_ambiguous_unary():
    """
    Test what happens if a Shift is turned into an Open in an ambiguous manner
    """
    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    # a spurious ZZ is opened where the Shift at index 4 belonged; the unary
    # repair closes ZZ immediately around that single word
    new_sequence = get_single_repair(transitions, OpenConstituent("ZZ"), fix_shift_open_ambiguous_unary, 4)
    expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    assert new_sequence == expected_updated
|
| 428 |
+
|
| 429 |
+
def test_shift_open_ambiguous_later():
    """
    Test what happens if a Shift is turned into an Open in an ambiguous manner
    """
    trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)
    assert len(trees) == 1
    tree = trees[0]

    transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN)
    expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    assert transitions == expected_original

    # the "later" repair keeps the spurious ZZ open until the enclosing
    # ADJP's contents are done, closing both before continuing
    new_sequence = get_single_repair(transitions, OpenConstituent("ZZ"), fix_shift_open_ambiguous_later, 4)
    expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()]
    assert new_sequence == expected_updated
|
stanza/stanza/tests/constituency/test_trainer.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
import logging
import os
import pathlib
import tempfile

import pytest
import torch
from torch import nn
from torch import optim

from stanza import Pipeline

from stanza.models import constituency_parser
from stanza.models.common import pretrain
from stanza.models.common.bert_embedding import load_bert, load_tokenizer
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.common.utils import set_random_seed
from stanza.models.constituency import lstm_model
from stanza.models.constituency.parse_transitions import Transition
from stanza.models.constituency import parser_training
from stanza.models.constituency import trainer
from stanza.models.constituency import tree_reader
from stanza.tests import *
|
| 24 |
+
|
| 25 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger('stanza.constituency.trainer')
|
| 28 |
+
logger.setLevel(logging.WARNING)
|
| 29 |
+
|
| 30 |
+
TREEBANK = """
|
| 31 |
+
( (S
|
| 32 |
+
(VP (VBG Enjoying)
|
| 33 |
+
(NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition)))
|
| 34 |
+
(. .)))
|
| 35 |
+
|
| 36 |
+
( (NP
|
| 37 |
+
(VP (VBG Sitting)
|
| 38 |
+
(PP (IN in)
|
| 39 |
+
(NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station)))
|
| 40 |
+
(VP (VBG waiting)
|
| 41 |
+
(PP (IN for)
|
| 42 |
+
(NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train)))))
|
| 43 |
+
(. .)))
|
| 44 |
+
|
| 45 |
+
( (S
|
| 46 |
+
(NP (PRP I))
|
| 47 |
+
(VP
|
| 48 |
+
(ADVP (RB really))
|
| 49 |
+
(VBP hate)
|
| 50 |
+
(NP (DT the) (NNP @MBTA)))))
|
| 51 |
+
|
| 52 |
+
( (S
|
| 53 |
+
(S (VP (VB Seek)))
|
| 54 |
+
(CC and)
|
| 55 |
+
(S (NP (PRP ye))
|
| 56 |
+
(VP (MD shall)
|
| 57 |
+
(VP (VB find))))
|
| 58 |
+
(. .)))
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK):
    """
    Build a constituency parser Trainer over the given treebank text.

    Extra positional args are forwarded verbatim to the parser's
    command line argument parser.  Returns the Trainer.
    """
    # TODO: build a fake embedding some other way?
    training_trees = tree_reader.read_trees(treebank)
    # the last tree doubles as a tiny dev set; no silver data
    validation_trees = training_trees[-1:]

    full_args = constituency_parser.parse_args(
        ['--wordvec_pretrain_file', wordvec_pretrain_file] + list(args))

    # might be None, unless we're testing loading an existing model
    load_name = full_args['load_name']

    built_trainer, _, _, _ = parser_training.build_trainer(
        full_args, training_trees, validation_trees, [], FoundationCache(), load_name)
    assert isinstance(built_trainer.model, lstm_model.LSTMModel)
    return built_trainer
|
| 77 |
+
|
| 78 |
+
class TestTrainer:
|
| 79 |
+
@pytest.fixture(scope="class")
|
| 80 |
+
def wordvec_pretrain_file(self):
|
| 81 |
+
return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
|
| 82 |
+
|
| 83 |
+
@pytest.fixture(scope="class")
|
| 84 |
+
def tiny_random_xlnet(self, tmp_path_factory):
|
| 85 |
+
"""
|
| 86 |
+
Download the tiny-random-xlnet model and make a concrete copy of it
|
| 87 |
+
|
| 88 |
+
The issue here is that the "random" nature of the original
|
| 89 |
+
makes it difficult or impossible to test that the values in
|
| 90 |
+
the transformer don't change during certain operations.
|
| 91 |
+
Saving a concrete instantiation of those random numbers makes
|
| 92 |
+
it so we can test there is no difference when training only a
|
| 93 |
+
subset of the layers, for example
|
| 94 |
+
"""
|
| 95 |
+
xlnet_name = 'hf-internal-testing/tiny-random-xlnet'
|
| 96 |
+
xlnet_model, xlnet_tokenizer = load_bert(xlnet_name)
|
| 97 |
+
path = str(tmp_path_factory.mktemp('tiny-random-xlnet'))
|
| 98 |
+
xlnet_model.save_pretrained(path)
|
| 99 |
+
xlnet_tokenizer.save_pretrained(path)
|
| 100 |
+
return path
|
| 101 |
+
|
| 102 |
+
@pytest.fixture(scope="class")
|
| 103 |
+
def tiny_random_bart(self, tmp_path_factory):
|
| 104 |
+
"""
|
| 105 |
+
Download the tiny-random-bart model and make a concrete copy of it
|
| 106 |
+
|
| 107 |
+
Issue is the same as with tiny_random_xlnet
|
| 108 |
+
"""
|
| 109 |
+
bart_name = 'hf-internal-testing/tiny-random-bart'
|
| 110 |
+
bart_model, bart_tokenizer = load_bert(bart_name)
|
| 111 |
+
path = str(tmp_path_factory.mktemp('tiny-random-bart'))
|
| 112 |
+
bart_model.save_pretrained(path)
|
| 113 |
+
bart_tokenizer.save_pretrained(path)
|
| 114 |
+
return path
|
| 115 |
+
|
| 116 |
+
def test_initial_model(self, wordvec_pretrain_file):
|
| 117 |
+
"""
|
| 118 |
+
does nothing, just tests that the construction went okay
|
| 119 |
+
"""
|
| 120 |
+
args = ['wordvec_pretrain_file', wordvec_pretrain_file]
|
| 121 |
+
build_trainer(wordvec_pretrain_file)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def test_save_load_model(self, wordvec_pretrain_file):
|
| 125 |
+
"""
|
| 126 |
+
Just tests that saving and loading works without crashs.
|
| 127 |
+
|
| 128 |
+
Currently no test of the values themselves
|
| 129 |
+
(checks some fields to make sure they are regenerated correctly)
|
| 130 |
+
"""
|
| 131 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
|
| 132 |
+
tr = build_trainer(wordvec_pretrain_file)
|
| 133 |
+
transitions = tr.model.transitions
|
| 134 |
+
|
| 135 |
+
# attempt saving
|
| 136 |
+
filename = os.path.join(tmpdirname, "parser.pt")
|
| 137 |
+
tr.save(filename)
|
| 138 |
+
|
| 139 |
+
assert os.path.exists(filename)
|
| 140 |
+
|
| 141 |
+
# load it back in
|
| 142 |
+
tr2 = tr.load(filename)
|
| 143 |
+
trans2 = tr2.model.transitions
|
| 144 |
+
assert(transitions == trans2)
|
| 145 |
+
assert all(isinstance(x, Transition) for x in trans2)
|
| 146 |
+
|
| 147 |
+
    def test_relearn_structure(self, wordvec_pretrain_file):
        """
        Test that starting a trainer with --relearn_structure copies the old model
        """

        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
            # fixed seeds so the two builds would differ unless weights are copied
            set_random_seed(1000)
            args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
            tr = build_trainer(wordvec_pretrain_file, *args)

            # attempt saving
            filename = os.path.join(tmpdirname, "parser.pt")
            tr.save(filename)

            # different seed + one pattn layer; --relearn_structure should
            # still copy the shared weights from the saved model
            set_random_seed(1001)
            args = ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--relearn_structure', '--load_name', filename]
            tr2 = build_trainer(wordvec_pretrain_file, *args)

            assert torch.allclose(tr.model.delta_embedding.weight, tr2.model.delta_embedding.weight)
            assert torch.allclose(tr.model.output_layers[0].weight, tr2.model.output_layers[0].weight)
            # the norms will be the same, as the non-zero values are all the same
            assert torch.allclose(torch.linalg.norm(tr.model.word_lstm.weight_ih_l0), torch.linalg.norm(tr2.model.word_lstm.weight_ih_l0))
|
| 169 |
+
|
| 170 |
+
def write_treebanks(self, tmpdirname):
|
| 171 |
+
train_treebank_file = os.path.join(tmpdirname, "train.mrg")
|
| 172 |
+
with open(train_treebank_file, 'w', encoding='utf-8') as fout:
|
| 173 |
+
fout.write(TREEBANK)
|
| 174 |
+
fout.write(TREEBANK)
|
| 175 |
+
|
| 176 |
+
eval_treebank_file = os.path.join(tmpdirname, "eval.mrg")
|
| 177 |
+
with open(eval_treebank_file, 'w', encoding='utf-8') as fout:
|
| 178 |
+
fout.write(TREEBANK)
|
| 179 |
+
|
| 180 |
+
return train_treebank_file, eval_treebank_file
|
| 181 |
+
|
| 182 |
+
    def training_args(self, wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *additional_args):
        """
        Build a parsed args dict for a tiny training run in tmpdirname.

        additional_args are appended after the defaults, so callers can
        override any of them.  wandb is forcibly disabled.
        """
        # let's not make the model huge...
        # NOTE(review): --use_lattn appears alongside --lattn_d_proj 0;
        # presumably the d_proj value wins and lattn stays off - confirm
        args = ['--pattn_num_layers', '0', '--pattn_d_model', '128', '--lattn_d_proj', '0', '--use_lattn', '--hidden_size', '20', '--delta_embedding_dim', '10',
                '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname,
                '--save_dir', tmpdirname, '--save_name', 'test.pt', '--save_each_start', '0', '--save_each_name', os.path.join(tmpdirname, 'each_%02d.pt'),
                '--train_file', train_treebank_file, '--eval_file', eval_treebank_file,
                '--epoch_size', '6', '--train_batch_size', '3',
                '--shorthand', 'en_test']
        args = args + list(additional_args)
        args = constituency_parser.parse_args(args)
        # just in case we change the defaults in the future
        args['wandb'] = None
        return args
|
| 195 |
+
|
| 196 |
+
    def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False, exists_ok=False, foundation_cache=None):
        """
        Runs a test of the trainer for a few iterations.

        Checks some basic properties of the saved model, but doesn't
        check for the accuracy of the results

        Returns (parsed_args, trained_model) for further checks by the caller.
        """
        if extra_args is None:
            extra_args = []
        # NOTE(review): += mutates a caller-supplied extra_args list in place
        extra_args += ['--epochs', '%d' % num_epochs]

        train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
        if use_silver:
            # reuse the eval file as silver data; this doubles the
            # number of batches per epoch (asserted below)
            extra_args += ['--silver_file', str(eval_treebank_file)]
        args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args)

        each_name = args['save_each_name']
        if not exists_ok:
            assert not os.path.exists(args['save_name'])
        retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True, dir=TEST_MODELS_DIR, foundation_cache=foundation_cache)
        trained_model = parser_training.train(args, None, [retag_pipeline])
        # check that hooks are in the model if expected
        for p in trained_model.model.parameters():
            if p.requires_grad:
                if args['grad_clipping'] is not None:
                    assert len(p._backward_hooks) == 1
                else:
                    assert p._backward_hooks is None

        # check that the model can be loaded back
        assert os.path.exists(args['save_name'])
        peft_name = trained_model.model.peft_name
        tr = trainer.Trainer.load(args['save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
        assert tr.optimizer is not None
        assert tr.scheduler is not None
        assert tr.epochs_trained >= 1
        # grad clipping hooks must NOT survive a save/load round trip
        for p in tr.model.parameters():
            if p.requires_grad:
                assert p._backward_hooks is None

        # the checkpoint should reflect the full training run
        tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
        assert tr.optimizer is not None
        assert tr.scheduler is not None
        assert tr.epochs_trained == num_epochs

        # each per-epoch snapshot exists and recorded the right counters
        for i in range(1, num_epochs+1):
            model_name = each_name % i
            assert os.path.exists(model_name)
            tr = trainer.Trainer.load(model_name, load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name)
            assert tr.epochs_trained == i
            # 2 batches/epoch from the gold data; silver data doubles it
            assert tr.batches_trained == (4 * i if use_silver else 2 * i)

        return args, trained_model
|
| 249 |
+
|
| 250 |
+
def test_train(self, wordvec_pretrain_file):
|
| 251 |
+
"""
|
| 252 |
+
Test the whole thing for a few iterations on the fake data
|
| 253 |
+
"""
|
| 254 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
|
| 255 |
+
self.run_train_test(wordvec_pretrain_file, tmpdirname)
|
| 256 |
+
|
| 257 |
+
def test_early_dropout(self, wordvec_pretrain_file):
|
| 258 |
+
"""
|
| 259 |
+
Test the whole thing for a few iterations on the fake data
|
| 260 |
+
"""
|
| 261 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
|
| 262 |
+
args = ['--early_dropout', '3']
|
| 263 |
+
_, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
|
| 264 |
+
model = model.model
|
| 265 |
+
dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
|
| 266 |
+
assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
|
| 267 |
+
for name, module in dropouts:
|
| 268 |
+
assert module.p == 0.0, "Dropout module %s was not set to 0 with early_dropout"
|
| 269 |
+
|
| 270 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
|
| 271 |
+
# test that when turned off, early_dropout doesn't happen
|
| 272 |
+
args = ['--early_dropout', '-1']
|
| 273 |
+
_, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)
|
| 274 |
+
model = model.model
|
| 275 |
+
dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)]
|
| 276 |
+
assert len(dropouts) > 0, "Didn't find any dropouts in the model!"
|
| 277 |
+
if all(module.p == 0.0 for _, module in dropouts):
|
| 278 |
+
raise AssertionError("All dropouts were 0 after training even though early_dropout was set to -1")
|
| 279 |
+
|
| 280 |
+
def test_train_silver(self, wordvec_pretrain_file):
|
| 281 |
+
"""
|
| 282 |
+
Test the whole thing for a few iterations on the fake data
|
| 283 |
+
|
| 284 |
+
This tests that it works if you give it a silver file
|
| 285 |
+
The check for the use of the silver data is that the
|
| 286 |
+
number of batches trained should go up
|
| 287 |
+
"""
|
| 288 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
|
| 289 |
+
self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True)
|
| 290 |
+
|
| 291 |
+
    def test_train_checkpoint(self, wordvec_pretrain_file):
        """
        Test the whole thing for a few iterations, then restart

        This tests that the 5th iteration save file is not rewritten
        and that the iterations continue to 10

        TODO: could make it more robust by verifying that only 5 more
        epochs are trained. Perhaps a "most recent epochs" could be
        saved in the trainer
        """
        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
            # first run: default 5 epochs
            args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=False)
            save_5 = args['save_each_name'] % 5
            save_10 = args['save_each_name'] % 10
            assert os.path.exists(save_5)
            assert not os.path.exists(save_10)

            # remember the epoch-5 file's mtime so we can prove it is untouched
            save_5_stat = pathlib.Path(save_5).stat()

            # second run resumes from the checkpoint and continues to epoch 10
            self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=10, use_silver=False, exists_ok=True)
            assert os.path.exists(save_5)
            assert os.path.exists(save_10)

            # epoch-5 snapshot was not rewritten by the resumed run
            assert pathlib.Path(save_5).stat().st_mtime == save_5_stat.st_mtime
|
| 316 |
+
|
| 317 |
+
    def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None):
        """
        Train 8 epochs in --multistage mode and verify the stage boundaries.

        Stages are detected by the change in word_input_size across the
        per-epoch snapshots.
        """
        train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname)
        args = ['--multistage', '--pattn_num_layers', '1']
        if use_lattn:
            args += ['--lattn_d_proj', '16']
        if extra_args:
            args += extra_args
        args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args)
        each_name = os.path.join(args['save_dir'], 'each_%02d.pt')

        # map each observed word_input_size to the epochs that used it
        word_input_sizes = defaultdict(list)
        for i in range(1, 9):
            model_name = each_name % i
            assert os.path.exists(model_name)
            tr = trainer.Trainer.load(model_name, load_optimizer=True)
            assert tr.epochs_trained == i
            word_input_sizes[tr.model.word_input_size].append(i)
        if use_lattn:
            # there should be three stages: no attn, pattn, pattn+lattn
            assert len(word_input_sizes) == 3
            word_input_keys = sorted(word_input_sizes.keys())
            assert word_input_sizes[word_input_keys[0]] == [1, 2, 3]
            assert word_input_sizes[word_input_keys[1]] == [4, 5]
            assert word_input_sizes[word_input_keys[2]] == [6, 7, 8]
        else:
            # with no lattn, there are two stages: no attn, pattn
            assert len(word_input_sizes) == 2
            word_input_keys = sorted(word_input_sizes.keys())
            assert word_input_sizes[word_input_keys[0]] == [1, 2, 3]
            assert word_input_sizes[word_input_keys[1]] == [4, 5, 6, 7, 8]
|
| 347 |
+
|
| 348 |
+
def test_multistage_lattn(self, wordvec_pretrain_file):
|
| 349 |
+
"""
|
| 350 |
+
Test a multistage training for a few iterations on the fake data
|
| 351 |
+
|
| 352 |
+
This should start with no pattn or lattn, have pattn in the middle, then lattn at the end
|
| 353 |
+
"""
|
| 354 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
|
| 355 |
+
self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True)
|
| 356 |
+
|
| 357 |
+
    def test_multistage_no_lattn(self, wordvec_pretrain_file):
        """
        Test a multistage training for a few iterations on the fake data

        With use_lattn=False there are only two stages: no attn at the
        start, then pattn for the rest of the run (no lattn stage).
        """
        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False)
|
| 365 |
+
|
| 366 |
+
    def test_multistage_optimizer(self, wordvec_pretrain_file):
        """
        Test that the correct optimizers are built for a multistage training process
        """
        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
            extra_args = ['--optim', 'adamw']
            self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args)

            # check that the optimizers which get rebuilt when loading
            # the models are adadelta for the first half of the
            # multistage, then adamw
            each_name = os.path.join(tmpdirname, 'each_%02d.pt')
            # NOTE(review): epochs 3 and 8 are deliberately skipped here,
            # presumably to stay clear of the stage-switch boundary - confirm
            for i in range(1, 3):
                model_name = each_name % i
                tr = trainer.Trainer.load(model_name, load_optimizer=True)
                assert tr.epochs_trained == i
                assert isinstance(tr.optimizer, optim.Adadelta)
                # double check that this is actually a valid test
                assert not isinstance(tr.optimizer, optim.AdamW)

            for i in range(4, 8):
                model_name = each_name % i
                tr = trainer.Trainer.load(model_name, load_optimizer=True)
                assert tr.epochs_trained == i
                assert isinstance(tr.optimizer, optim.AdamW)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def test_grad_clip_hooks(self, wordvec_pretrain_file):
|
| 394 |
+
"""
|
| 395 |
+
Verify that grad clipping is not saved with the model, but is attached at training time
|
| 396 |
+
"""
|
| 397 |
+
with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
|
| 398 |
+
args = ['--grad_clipping', '25']
|
| 399 |
+
self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)
|
| 400 |
+
|
| 401 |
+
    def test_analyze_trees(self, wordvec_pretrain_file):
        """
        Run analyze_trees on two gold trees and check the returned structure.
        """
        test_str = "(ROOT (S (NP (PRP I)) (VP (VBP wan) (S (VP (TO na) (VP (VB lick) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))))) (ROOT (S (NP (DT This) (NN interface)) (VP (VBZ sucks))))"

        test_tree = tree_reader.read_trees(test_str)
        assert len(test_tree) == 2

        # small model: no pattn/lattn, tiny hidden sizes
        args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']
        tr = build_trainer(wordvec_pretrain_file, *args)

        results = tr.model.analyze_trees(test_tree)
        assert len(results) == 2
        # gold analysis: exactly one prediction per tree, which is the tree itself
        assert len(results[0].predictions) == 1
        assert results[0].predictions[0].tree == test_tree[0]
        assert results[0].state is not None
        # the score is a scalar tensor
        assert isinstance(results[0].state.score, torch.Tensor)
        assert results[0].state.score.shape == torch.Size([])
        # one constituent per bracket in the first tree
        assert len(results[0].constituents) == 9
        assert results[0].constituents[-1].value == test_tree[0]
        # the way the results are built, the next-to-last entry
        # should be the thing just below the root
        assert results[0].constituents[-2].value == test_tree[0].children[0]

        # same checks for the second, smaller tree
        assert len(results[1].predictions) == 1
        assert results[1].predictions[0].tree == test_tree[1]
        assert results[1].state is not None
        assert isinstance(results[1].state.score, torch.Tensor)
        assert results[1].state.score.shape == torch.Size([])
        assert len(results[1].constituents) == 4
        assert results[1].constituents[-1].value == test_tree[1]
        assert results[1].constituents[-2].value == test_tree[1].children[0]
|
| 431 |
+
|
| 432 |
+
def bert_weights_allclose(self, bert_model, parser_model):
|
| 433 |
+
"""
|
| 434 |
+
Return True if all bert weights are close, False otherwise
|
| 435 |
+
"""
|
| 436 |
+
for name, parameter in bert_model.named_parameters():
|
| 437 |
+
other_name = "bert_model." + name
|
| 438 |
+
other_parameter = parser_model.model.get_parameter(other_name)
|
| 439 |
+
if not torch.allclose(parameter.cpu(), other_parameter.cpu()):
|
| 440 |
+
return False
|
| 441 |
+
return True
|
| 442 |
+
|
| 443 |
+
    def frozen_transformer_test(self, wordvec_pretrain_file, transformer_name):
        """
        Train with a transformer attached (no finetuning) and verify
        its weights are unchanged, not saved, and shared from the cache.
        """
        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
            foundation_cache = FoundationCache()
            args = ['--bert_model', transformer_name]
            args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args, foundation_cache=foundation_cache)
            # without --bert_finetune the transformer weights must not move
            bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
            assert self.bert_weights_allclose(bert_model, trained_model)

            checkpoint = torch.load(args['save_name'], lambda storage, loc: storage, weights_only=True)
            params = checkpoint['params']
            # check that the bert model wasn't saved in the model
            assert all(not x.startswith("bert_model.") for x in params['model'].keys())
            # make sure we're looking at the right thing
            assert any(x.startswith("output_layers.") for x in params['model'].keys())

            # check that the cached model is used as expected when loading a bert model
            trained_model = trainer.Trainer.load(args['save_name'], foundation_cache=foundation_cache)
            assert trained_model.model.bert_model is bert_model
|
| 461 |
+
|
| 462 |
+
def test_bert_frozen(self, wordvec_pretrain_file):
|
| 463 |
+
"""
|
| 464 |
+
Check that the parameters of the bert model don't change when training a basic model
|
| 465 |
+
"""
|
| 466 |
+
self.frozen_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')
|
| 467 |
+
|
| 468 |
+
def test_xlnet_frozen(self, wordvec_pretrain_file, tiny_random_xlnet):
|
| 469 |
+
"""
|
| 470 |
+
Check that the parameters of an xlnet model don't change when training a basic model
|
| 471 |
+
"""
|
| 472 |
+
self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
|
| 473 |
+
|
| 474 |
+
    def test_bart_frozen(self, wordvec_pretrain_file, tiny_random_bart):
        """
        Check that the parameters of a bart model don't change when training a basic model
        """
        self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_bart)
|
| 479 |
+
|
| 480 |
+
    def test_bert_finetune_one_epoch(self, wordvec_pretrain_file):
        """
        Check that the parameters the bert model DO change over a single training step
        """
        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
            transformer_name = 'hf-internal-testing/tiny-bert'
            args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adadelta']
            args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=1, extra_args=args)

            # check that the weights are different
            foundation_cache = FoundationCache()
            bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
            assert not self.bert_weights_allclose(bert_model, trained_model)

            # double check that a new bert is created instead of using the FoundationCache when the bert has been trained
            model_name = args['save_name']
            assert os.path.exists(model_name)
            no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name)
            tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache)
            # the loaded model carries the finetuned weights, not the cached ones
            assert tr.model.bert_model is not bert_model
            assert not self.bert_weights_allclose(bert_model, tr)
            assert self.bert_weights_allclose(trained_model.model.bert_model, tr)

            # resave, reload, and verify the finetuned weights round-trip
            new_save_name = os.path.join(tmpdirname, "test_resave_bert.pt")
            assert not os.path.exists(new_save_name)
            tr.save(new_save_name, save_optimizer=False)
            tr2 = trainer.Trainer.load(new_save_name, args=no_finetune_args, foundation_cache=foundation_cache)
            # check that the resaved model included its finetuned bert weights
            assert tr2.model.bert_model is not bert_model
            # the finetuned bert weights should also be scheduled for saving the next time as well
            assert not tr2.model.is_unsaved_module("bert_model")
|
| 511 |
+
|
| 512 |
+
    def finetune_transformer_test(self, wordvec_pretrain_file, transformer_name):
        """
        Check that the parameters of the transformer DO change when using bert_finetune
        """
        with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
            args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw']
            args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)

            # check that the weights are different
            foundation_cache = FoundationCache()
            bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
            assert not self.bert_weights_allclose(bert_model, trained_model)

            # double check that a new bert is created instead of using the FoundationCache when the bert has been trained
            no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name)
            trained_model = trainer.Trainer.load(args['save_name'], args=no_finetune_args, foundation_cache=foundation_cache)
            # the reload must honor the no-finetune flags and still keep
            # its own (finetuned) copy of the transformer
            assert not trained_model.model.args['bert_finetune']
            assert not trained_model.model.args['stage1_bert_finetune']
            assert trained_model.model.bert_model is not bert_model
|
| 531 |
+
|
| 532 |
+
def test_bert_finetune(self, wordvec_pretrain_file):
|
| 533 |
+
"""
|
| 534 |
+
Check that the parameters of a bert model DO change when using bert_finetune
|
| 535 |
+
"""
|
| 536 |
+
self.finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert')
|
| 537 |
+
|
| 538 |
+
def test_xlnet_finetune(self, wordvec_pretrain_file, tiny_random_xlnet):
    """
    Check that the parameters of an xlnet model DO change when using bert_finetune

    Uses the tiny_random_xlnet fixture as the transformer under test.
    """
    self.finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
|
| 543 |
+
|
| 544 |
+
def test_stage1_bert_finetune(self, wordvec_pretrain_file):
    """
    Check that the parameters the bert model DO change when using stage1_bert_finetune, but only for the first couple steps
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
        bert_model_name = 'hf-internal-testing/tiny-bert'
        args = ['--bert_model', bert_model_name, '--stage1_bert_finetune', '--optim', 'adamw']
        # need to use num_epochs==6 so that epochs 1 and 2 are saved to be different
        # a test of 5 or less means that sometimes it will reload the params
        # at step 2 to get ready for the following iterations with adamw
        args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args)

        # check that the weights are different
        foundation_cache = FoundationCache()
        bert_model, bert_tokenizer = foundation_cache.load_bert(bert_model_name)
        assert not self.bert_weights_allclose(bert_model, trained_model)

        # double check that a new bert is created instead of using the FoundationCache when the bert has been trained
        no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', bert_model_name, '--optim', 'adamw')
        num_epochs = trained_model.model.args['epochs']
        # the per-epoch checkpoints written during training
        each_name = os.path.join(tmpdirname, 'each_%02d.pt')
        for i in range(1, num_epochs+1):
            model_name = each_name % i
            assert os.path.exists(model_name)
            tr = trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache)
            assert tr.model.bert_model is not bert_model
            assert not self.bert_weights_allclose(bert_model, tr)
            if i >= num_epochs // 2:
                # in the later epochs the bert weights should match the final
                # model's — presumably stage1 finetuning has stopped by then
                assert self.bert_weights_allclose(trained_model.model.bert_model, tr)

        # verify that models 1 and 2 are saved to be different
        model_name_1 = each_name % 1
        model_name_2 = each_name % 2
        tr_1 = trainer.Trainer.load(model_name_1, args=no_finetune_args, foundation_cache=foundation_cache)
        tr_2 = trainer.Trainer.load(model_name_2, args=no_finetune_args, foundation_cache=foundation_cache)
        assert not self.bert_weights_allclose(tr_1.model.bert_model, tr_2)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def one_layer_finetune_transformer_test(self, wordvec_pretrain_file, transformer_name):
    """
    Check that only the LAST transformer layer changes when finetuning with --bert_finetune_layers 1

    Earlier layers and non-layer parameters (embeddings, pooler, etc.)
    must be unchanged from the freshly loaded transformer.
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
        # bugfix: '--bert_finetune_layers', '1' used to be passed twice in this list
        args = ['--bert_model', transformer_name, '--bert_finetune', '--bert_finetune_layers', '1', '--optim', 'adamw']
        args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args)

        # check that the weights of the last layer are different,
        # but the weights of the earlier layers and
        # non-transformer-layers are the same
        foundation_cache = FoundationCache()
        bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name)
        # the test is only meaningful if there is an untouched earlier layer
        assert bert_model.config.num_hidden_layers > 1
        layer_name = "layer.%d." % (bert_model.config.num_hidden_layers - 1)
        for name, parameter in bert_model.named_parameters():
            other_name = "bert_model." + name
            other_parameter = trained_model.model.get_parameter(other_name)
            if layer_name in name:
                if 'rel_attn.seg_embed' in name or 'rel_attn.r_s_bias' in name:
                    # not sure why this happens for xlnet, just roll with it
                    continue
                # last layer: finetuning must have moved the weights
                assert not torch.allclose(parameter.cpu(), other_parameter.cpu())
            else:
                # everything else: untouched by the one-layer finetune
                assert torch.allclose(parameter.cpu(), other_parameter.cpu())
|
| 607 |
+
|
| 608 |
+
def test_bert_finetune_one_layer(self, wordvec_pretrain_file):
    """Finetuning one layer of a tiny bert should only change that layer."""
    bert_name = 'hf-internal-testing/tiny-bert'
    self.one_layer_finetune_transformer_test(wordvec_pretrain_file, bert_name)
|
| 610 |
+
|
| 611 |
+
def test_xlnet_finetune_one_layer(self, wordvec_pretrain_file, tiny_random_xlnet):
    """Finetuning one layer of the tiny_random_xlnet fixture should only change that layer."""
    self.one_layer_finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet)
|
| 613 |
+
|
| 614 |
+
def test_peft_finetune(self, tmp_path, wordvec_pretrain_file):
    """Smoke test: training with --use_peft and --bert_finetune runs to completion."""
    bert_name = 'hf-internal-testing/tiny-bert'
    extra = ['--bert_model', bert_name, '--bert_finetune', '--optim', 'adamw', '--use_peft']
    args, trained_model = self.run_train_test(wordvec_pretrain_file, str(tmp_path), extra_args=extra)
|
| 618 |
+
|
| 619 |
+
def test_peft_twostage_finetune(self, wordvec_pretrain_file):
    """
    Check that the lora (PEFT) parameters move during early stage1 epochs and then stop changing

    Compares the lora parameters between each pair of consecutive per-epoch checkpoints.
    """
    with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname:
        num_epochs = 6
        transformer_name = 'hf-internal-testing/tiny-bert'
        args = ['--bert_model', transformer_name, '--stage1_bert_finetune', '--optim', 'adamw', '--use_peft']
        args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=num_epochs, extra_args=args)
        for epoch in range(num_epochs):
            filename_prev = args['save_each_name'] % epoch
            filename_next = args['save_each_name'] % (epoch+1)
            trainer_prev = trainer.Trainer.load(filename_prev, args=args, load_optimizer=False)
            trainer_next = trainer.Trainer.load(filename_next, args=args, load_optimizer=False)

            # the lora adapter parameters are the only ones PEFT trains
            lora_names = [name for name, _ in trainer_prev.model.bert_model.named_parameters() if name.find("lora") >= 0]
            if epoch < 2:
                # stage1 finetuning active: every lora parameter should have moved
                assert not any(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(),
                                              trainer_next.model.bert_model.get_parameter(name).cpu())
                               for name in lora_names)
            elif epoch > 2:
                # after stage1 stops, the lora parameters should be unchanged
                # (epoch == 2 is the boundary and is deliberately not checked)
                assert all(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(),
                                          trainer_next.model.bert_model.get_parameter(name).cpu())
                           for name in lora_names)
|
stanza/stanza/tests/constituency/test_transformer_tree_stack.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack
|
| 6 |
+
|
| 7 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 8 |
+
|
| 9 |
+
def test_initial_state():
    """
    A freshly built TTS state is a single node with the expected tensor shapes
    """
    tree_stack = TransformerTreeStack(3, 5, 0.0)
    state = tree_stack.initial_state()
    assert len(state) == 1
    node = state.value
    assert node.output.shape == torch.Size([5])
    assert node.key_stack.shape == torch.Size([1, 5])
    assert node.value_stack.shape == torch.Size([1, 5])
|
| 19 |
+
|
| 20 |
+
def test_output():
    """
    output() returns the stored output vector of the top of the stack
    """
    tree_stack = TransformerTreeStack(3, 5, 0.0)
    state = tree_stack.initial_state()
    result = tree_stack.output(state)
    assert result.shape == torch.Size([5])
    assert torch.allclose(state.value.output, result)
|
| 29 |
+
|
| 30 |
+
def test_push_state_single():
    """
    Pushing twice onto a single stack yields a depth-3 stack with values in LIFO order

    Values of the attention are not verified, though
    """
    tree_stack = TransformerTreeStack(3, 5, 0.0)
    inputs = torch.randn(1, 3)
    stacks = tree_stack.push_states([tree_stack.initial_state()], ["A"], inputs)
    stacks = tree_stack.push_states(stacks, ["B"], inputs)
    assert len(stacks) == 1
    stack = stacks[0]
    assert len(stack) == 3
    assert stack.value.value == "B"
    assert stack.pop().value.value == "A"
    # the bottom of the stack is the sentinel from initial_state
    assert stack.pop().pop().value.value is None
|
| 46 |
+
|
| 47 |
+
def test_push_state_same_length():
    """
    Three parallel stacks of equal depth update in lockstep

    Values of the attention are not verified, though
    """
    tree_stack = TransformerTreeStack(3, 5, 0.0)
    start = tree_stack.initial_state()
    inputs = torch.randn(3, 3)
    stacks = tree_stack.push_states([start] * 3, ["A"] * 3, inputs)
    stacks = tree_stack.push_states(stacks, ["B"] * 3, inputs)
    stacks = tree_stack.push_states(stacks, ["C"] * 3, inputs)
    assert len(stacks) == 3
    for stack in stacks:
        assert len(stack) == 4
        assert stack.value.key_stack.shape == torch.Size([4, 5])
        assert stack.value.value_stack.shape == torch.Size([4, 5])
        assert stack.value.value == "C"
        assert stack.pop().value.value == "B"
        assert stack.pop().pop().value.value == "A"
        # bottom of each stack is the initial sentinel
        assert stack.pop().pop().pop().value.value is None
|
| 68 |
+
|
| 69 |
+
def test_push_state_different_length():
    """
    Stacks of unequal depth keep their own depths when pushed together
    """
    tree_stack = TransformerTreeStack(3, 5, 0.0)
    start = tree_stack.initial_state()
    inputs = torch.randn(2, 3)
    deeper = tree_stack.push_states([start], ["A"], inputs[0:1, :])[0]
    stacks = tree_stack.push_states([deeper, start], ["B", "C"], inputs)
    assert len(stacks) == 2
    first, second = stacks
    assert len(first) == 3
    assert len(second) == 2
    assert first.pop().value.value == 'A'
    assert first.value.value == 'B'
    assert second.value.value == 'C'

    # key stacks grow with each stack's own depth, not the batch's
    assert first.value.key_stack.shape == torch.Size([3, 5])
    assert second.value.key_stack.shape == torch.Size([2, 5])
|
| 88 |
+
|
| 89 |
+
def test_mask():
    """
    Masked positions must contribute nothing to the attention softmax
    """
    tree_stack = TransformerTreeStack(3, 5, 0.0)

    base_value = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5]]])
    doubled_value = base_value * 2
    value = torch.cat([base_value, doubled_value], axis=1)
    single_key = torch.randn(1, 1, 5)
    key = torch.cat([single_key, single_key], axis=1)
    query = torch.randn(1, 5)

    # identical keys: the attention splits 50/50 between the two values
    unmasked = tree_stack.attention(key, query, value)
    assert torch.allclose(unmasked, (base_value + doubled_value) / 2)

    # masking the first entry leaves only the second value represented
    mask = torch.zeros(1, 2, dtype=torch.bool)
    mask[0][0] = True
    assert torch.allclose(tree_stack.attention(key, query, value, mask), doubled_value)

    # masking the second entry leaves only the first value represented
    mask = torch.zeros(1, 2, dtype=torch.bool)
    mask[0][1] = True
    assert torch.allclose(tree_stack.attention(key, query, value, mask), base_value)
|
| 120 |
+
|
| 121 |
+
def test_position():
    """
    Smoke test: enabling position encodings should not break state creation or pushes

    Does not actually test the results of the encodings
    """
    tree_stack = TransformerTreeStack(4, 5, 0.0, use_position=True)
    start = tree_stack.initial_state()
    assert len(start) == 1
    assert start.value.output.shape == torch.Size([5])
    assert start.value.key_stack.shape == torch.Size([1, 5])
    assert start.value.value_stack.shape == torch.Size([1, 5])

    inputs = torch.randn(2, 4)
    deeper = tree_stack.push_states([start], ["A"], inputs[0:1, :])[0]
    tree_stack.push_states([deeper, start], ["B", "C"], inputs)
|
| 138 |
+
|
| 139 |
+
def test_length_limit():
    """
    With length_limit set, the key/value stacks stop growing while the logical stack keeps its depth
    """
    tree_stack = TransformerTreeStack(4, 5, 0.0, length_limit = 2)
    start = tree_stack.initial_state()
    assert len(start) == 1
    assert start.value.output.shape == torch.Size([5])
    assert start.value.key_stack.shape == torch.Size([1, 5])
    assert start.value.value_stack.shape == torch.Size([1, 5])

    data = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
    stacks = tree_stack.push_states([start], ["A"], data)

    # after the first push the tensors are capped at 3 rows even as len() grows
    for label, expected_depth in zip(("B", "C", "D"), (3, 4, 5)):
        stacks = tree_stack.push_states(stacks, [label], data)
        assert len(stacks) == 1
        assert len(stacks[0]) == expected_depth
        assert stacks[0].value.key_stack.shape[0] == 3
        assert stacks[0].value.value_stack.shape[0] == 3
|
| 170 |
+
|
| 171 |
+
def test_two_heads():
    """
    Check that a stack built with num_heads=2 still produces the expected shapes and stack structure

    (The previous docstring was a copy-paste from test_length_limit.)
    """
    ts = TransformerTreeStack(4, 6, 0.0, num_heads = 2)
    initial = ts.initial_state()
    assert len(initial) == 1
    assert initial.value.output.shape == torch.Size([6])
    assert initial.value.key_stack.shape == torch.Size([1, 6])
    assert initial.value.value_stack.shape == torch.Size([1, 6])

    rand_input = torch.randn(2, 4)
    one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0]
    stacks = [one_step, initial]
    stacks = ts.push_states(stacks, ["B", "C"], rand_input)
    assert len(stacks) == 2
    assert len(stacks[0]) == 3
    assert len(stacks[1]) == 2
    assert stacks[0].pop().value.value == 'A'
    assert stacks[0].value.value == 'B'
    assert stacks[1].value.value == 'C'

    assert stacks[0].value.key_stack.shape == torch.Size([3, 6])
    assert stacks[1].value.key_stack.shape == torch.Size([2, 6])
|
| 195 |
+
|
stanza/stanza/tests/constituency/test_transition_sequence.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from stanza.models.constituency import parse_transitions
|
| 3 |
+
from stanza.models.constituency import transition_sequence
|
| 4 |
+
from stanza.models.constituency import tree_reader
|
| 5 |
+
from stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT
|
| 6 |
+
from stanza.models.constituency.parse_transitions import *
|
| 7 |
+
|
| 8 |
+
from stanza.tests import *
|
| 9 |
+
from stanza.tests.constituency.test_parse_tree import CHINESE_LONG_LIST_TREE
|
| 10 |
+
|
| 11 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 12 |
+
|
| 13 |
+
def reconstruct_tree(tree, sequence, transition_scheme=TransitionScheme.IN_ORDER, unary_limit=UNARY_LIMIT, reverse=False):
    """
    Starting from a tree and a list of transitions, build the tree caused by the transitions
    """
    model = SimpleModel(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse)
    states = model.initial_state_from_gold_trees([tree])
    assert(len(states)) == 1
    assert states[0].num_transitions == 0

    # TODO: could fold this into parse_sentences (similar to verify_transitions in trainer.py)
    for idx, t in enumerate(sequence):
        # note: callers deliberately rely on this AssertionError for
        # sequences that cannot rebuild the tree (e.g. unary limit too low)
        assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, sequence)
        states = model.bulk_apply(states, [t])

    result_tree = states[0].constituents.value
    if reverse:
        # the model parsed a reversed sentence, so flip the result back
        result_tree = result_tree.reverse()
    return result_tree
|
| 31 |
+
|
| 32 |
+
def check_reproduce_tree(transition_scheme):
    """
    Build the transition sequence for a known tree, apply it step by step, and verify the rebuilt tree
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)

    model = SimpleModel(transition_scheme)
    transitions = transition_sequence.build_sequence(trees[0], transition_scheme)
    states = model.initial_state_from_gold_trees(trees)
    assert(len(states)) == 1
    state = states[0]
    assert state.num_transitions == 0

    for t in transitions:
        # every transition must be legal at the point it is applied
        assert t.is_legal(state, model)
        state = t.apply(state, model)

    # one item for the final tree
    # one item for the sentinel at the end
    assert len(state.constituents) == 2
    # the transition sequence should put all of the words
    # from the buffer onto the tree
    # one spot left for the sentinel value
    assert len(state.word_queue) == 8
    assert state.sentence_length == 6
    assert state.word_position == state.sentence_length
    assert len(state.transitions) == len(transitions) + 1

    result_tree = state.constituents.value
    assert result_tree == trees[0]
|
| 60 |
+
|
| 61 |
+
def test_top_down_unary():
    """Round-trip a tree through the TOP_DOWN_UNARY transition scheme."""
    scheme = TransitionScheme.TOP_DOWN_UNARY
    check_reproduce_tree(transition_scheme=scheme)
|
| 63 |
+
|
| 64 |
+
def test_top_down_no_unary():
    """Round-trip a tree through the plain TOP_DOWN transition scheme."""
    scheme = TransitionScheme.TOP_DOWN
    check_reproduce_tree(transition_scheme=scheme)
|
| 66 |
+
|
| 67 |
+
def test_in_order():
    """Round-trip a tree through the IN_ORDER transition scheme."""
    scheme = TransitionScheme.IN_ORDER
    check_reproduce_tree(transition_scheme=scheme)
|
| 69 |
+
|
| 70 |
+
def test_in_order_compound():
    """Round-trip a tree through the IN_ORDER_COMPOUND transition scheme."""
    scheme = TransitionScheme.IN_ORDER_COMPOUND
    check_reproduce_tree(transition_scheme=scheme)
|
| 72 |
+
|
| 73 |
+
def test_in_order_unary():
    """Round-trip a tree through the IN_ORDER_UNARY transition scheme."""
    scheme = TransitionScheme.IN_ORDER_UNARY
    check_reproduce_tree(transition_scheme=scheme)
|
| 75 |
+
|
| 76 |
+
def test_all_transitions():
    """
    all_transitions should return the sorted unique transitions for a treebank (default scheme)
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)
    model = SimpleModel()
    transitions = transition_sequence.build_treebank(trees)

    expected = [Shift(), CloseConstituent()]
    expected += [CompoundUnary(label) for label in ("ROOT", "SQ", "WHNP")]
    expected += [OpenConstituent(label) for label in ("NP", "PP", "SBARQ", "VP")]
    assert transition_sequence.all_transitions(transitions) == expected
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def test_all_transitions_no_unary():
    """
    all_transitions with TOP_DOWN should have Open transitions only, no CompoundUnary
    """
    text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)
    model = SimpleModel()
    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)

    expected = [Shift(), CloseConstituent()]
    expected += [OpenConstituent(label)
                 for label in ("NP", "PP", "ROOT", "SBARQ", "SQ", "VP", "WHNP")]
    assert transition_sequence.all_transitions(transitions) == expected
|
| 94 |
+
|
| 95 |
+
def test_top_down_compound_unary():
    """
    Round-trip a large real-world tree through the TOP_DOWN_COMPOUND transition scheme
    """
    text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. .)))"

    trees = tree_reader.read_trees(text)
    assert len(trees) == 1

    model = SimpleModel()
    transitions = transition_sequence.build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN_COMPOUND)

    states = model.initial_state_from_gold_trees(trees)
    assert len(states) == 1
    state = states[0]

    for t in transitions:
        # each transition must be legal before it is applied
        assert t.is_legal(state, model)
        state = t.apply(state, model)

    # applying the full sequence must reproduce the original tree exactly
    result = model.get_top_constituent(state.constituents)
    assert trees[0] == result
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_chinese_tree():
    """
    A tree with a very long flat list rebuilds under TOP_DOWN, but IN_ORDER needs a larger unary limit
    """
    trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)

    # TOP_DOWN works at the default unary limit
    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN)
    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN)
    assert redone == trees[0]

    # IN_ORDER fails at the default limit (is_legal asserts) ...
    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER)
    with pytest.raises(AssertionError):
        redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER)

    # ... but succeeds once the unary limit is raised
    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)
    assert redone == trees[0]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def test_chinese_tree_reversed():
    """
    test that the reversed transitions also work
    """
    trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE)

    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN, reverse=True)
    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN, reverse=True)
    assert redone == trees[0]

    with pytest.raises(AssertionError):
        # turn off reverse - it should fail to rebuild the tree
        redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN)
        assert redone == trees[0]

    # as in test_chinese_tree, IN_ORDER needs a higher unary limit
    transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER, reverse=True)
    with pytest.raises(AssertionError):
        redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, reverse=True)

    redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6, reverse=True)
    assert redone == trees[0]

    with pytest.raises(AssertionError):
        # turn off reverse - it should fail to rebuild the tree
        redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6)
        assert redone == trees[0]
|
stanza/stanza/tests/constituency/test_tree_reader.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from stanza.models.constituency import tree_reader
|
| 3 |
+
from stanza.models.constituency.tree_reader import MixedTreeError, UnclosedTreeError, UnlabeledTreeError
|
| 4 |
+
|
| 5 |
+
from stanza.tests import *
|
| 6 |
+
|
| 7 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 8 |
+
|
| 9 |
+
def test_simple():
    """
    Read two simple preterminal trees from one string
    """
    trees = tree_reader.read_trees("(VB Unban) (NNP Opal)")
    assert len(trees) == 2
    for tree, (tag, word) in zip(trees, [('VB', 'Unban'), ('NNP', 'Opal')]):
        assert tree.is_preterminal()
        assert tree.label == tag
        assert tree.children[0].label == word
|
| 22 |
+
|
| 23 |
+
def test_newlines():
    """
    Trees separated by blank lines should read just like space-separated trees
    """
    trees = tree_reader.read_trees("(VB Unban)\n\n(NNP Opal)")
    assert len(trees) == 2
|
| 30 |
+
|
| 31 |
+
def test_parens():
    """
    Escaped parens are unescaped into the leaves and re-escaped when printed
    """
    trees = tree_reader.read_trees("(-LRB- -LRB-) (-RRB- -RRB-)")
    assert len(trees) == 2

    for tree, (tag, leaf) in zip(trees, [('-LRB-', '('), ('-RRB-', ')')]):
        assert tree.label == tag
        assert tree.children[0].label == leaf
        # printing escapes the paren again
        assert "{}".format(tree) == '(%s %s)' % (tag, tag)
|
| 46 |
+
|
| 47 |
+
def test_complicated():
    """
    A multi-level tree parses with the correct structure under an implicit ROOT
    """
    text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))"
    trees = tree_reader.read_trees(text)
    assert len(trees) == 1
    root = trees[0]
    assert not root.is_leaf()
    assert not root.is_preterminal()
    assert root.label == 'ROOT'
    assert len(root.children) == 1
    sbarq = root.children[0]
    assert sbarq.label == 'SBARQ'
    assert len(sbarq.children) == 3
    assert [child.label for child in sbarq.children] == ['WHNP', 'SQ', '.']
    # etc etc
|
| 63 |
+
|
| 64 |
+
def test_one_word():
    """
    Single-node trees read as bare leaves

    probably not super relevant for the parsing use case
    """
    trees = tree_reader.read_trees("(FOO) (BAR)")
    assert len(trees) == 2

    for tree, label in zip(trees, ('FOO', 'BAR')):
        assert tree.is_leaf()
        assert tree.label == label
|
| 79 |
+
|
| 80 |
+
def test_missing_close_parens():
    """
    An unclosed tree raises UnclosedTreeError reporting the line where it began
    """
    text = "(Foo) \n (Bar \n zzz"
    with pytest.raises(UnclosedTreeError) as excinfo:
        tree_reader.read_trees(text)
    assert excinfo.value.line_num == 1
|
| 90 |
+
|
| 91 |
+
def test_mixed_tree():
    """
    A node mixing constituent and leaf children raises MixedTreeError unless broken_ok is set
    """
    text = "(Foo) \n (Bar) \n (Unban (Mox) Opal)"
    with pytest.raises(MixedTreeError) as excinfo:
        tree_reader.read_trees(text)
    assert excinfo.value.line_num == 2

    # with broken_ok the malformed tree is kept instead of raising
    assert len(tree_reader.read_trees(text, broken_ok=True)) == 3
|
| 104 |
+
|
| 105 |
+
def test_unlabeled_tree():
    """
    An unlabeled interior node raises UnlabeledTreeError unless broken_ok is set
    """
    text = "(ROOT ((Foo) (Bar)))"
    with pytest.raises(UnlabeledTreeError) as excinfo:
        tree_reader.read_trees(text)
    assert excinfo.value.line_num == 0

    # with broken_ok the malformed tree is kept instead of raising
    assert len(tree_reader.read_trees(text, broken_ok=True)) == 1
|
| 118 |
+
|
| 119 |
+
|
stanza/stanza/tests/constituency/test_vietnamese.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A few tests for Vietnamese parsing, which has some difficulties related to spaces in words
|
| 3 |
+
|
| 4 |
+
Technically some other languages can have this, too, like that one French token
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
|
| 12 |
+
from stanza.models.common import pretrain
|
| 13 |
+
from stanza.models.constituency import tree_reader
|
| 14 |
+
|
| 15 |
+
from stanza.tests.constituency.test_trainer import build_trainer
|
| 16 |
+
|
| 17 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 18 |
+
|
| 19 |
+
VI_TREEBANK = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài Loan) (" ") (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))'
|
| 20 |
+
|
| 21 |
+
VI_TREEBANK_UNDERSCORE = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài_Loan) (" ") (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .)))'
|
| 22 |
+
|
| 23 |
+
VI_TREEBANK_SIMPLE = '(ROOT (S (NP (" ") (N Đảo) (Np Đài Loan) (" ") (PP (E ở) (NP (N đồng bằng) (NP (N sông) (Np Cửu Long))))) (. .)))'
|
| 24 |
+
|
| 25 |
+
VI_TREEBANK_PAREN = '(ROOT (S-TTL (NP (PUNCT -LRB-) (N-H Đảo) (Np Đài Loan) (PUNCT -RRB-) (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))'
|
| 26 |
+
VI_TREEBANK_VLSP = '<s>\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n</s>'
|
| 27 |
+
VI_TREEBANK_VLSP_50 = '<s id=50>\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n</s>'
|
| 28 |
+
VI_TREEBANK_VLSP_100 = '<s id=100>\n(S-TTL (NP (PUNCT LBKT) (N-H Đảo) (Np Đài_Loan) (PUNCT RBKT) (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .))\n</s>'
|
| 29 |
+
|
| 30 |
+
EXPECTED_LABELED_BRACKETS = '(_ROOT (_S (_NP (_" " )_" (_N Đảo )_N (_Np Đài_Loan )_Np (_" " )_" (_PP (_E ở )_E (_NP (_N đồng_bằng )_N (_NP (_N sông )_N (_Np Cửu_Long )_Np )_NP )_NP )_PP )_NP (_. . )_. )_S )_ROOT'
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_read_vi_tree():
|
| 34 |
+
"""
|
| 35 |
+
Test that an individual tree with spaces in the leaves is being processed as we expect
|
| 36 |
+
"""
|
| 37 |
+
text = VI_TREEBANK.split("\n")[0]
|
| 38 |
+
trees = tree_reader.read_trees(text)
|
| 39 |
+
assert len(trees) == 1
|
| 40 |
+
assert str(trees[0]) == text
|
| 41 |
+
# this is the first NP
|
| 42 |
+
# the third node of that NP, eg (Np Đài Loan)
|
| 43 |
+
node = trees[0].children[0].children[0].children[2]
|
| 44 |
+
assert node.is_preterminal()
|
| 45 |
+
assert node.children[0].label == "Đài Loan"
|
| 46 |
+
|
| 47 |
+
VI_EMBEDDING = """
|
| 48 |
+
4 4
|
| 49 |
+
Đảo 0.11 0.21 0.31 0.41
|
| 50 |
+
Đài Loan 0.12 0.22 0.32 0.42
|
| 51 |
+
đồng bằng 0.13 0.23 0.33 0.43
|
| 52 |
+
sông 0.14 0.24 0.34 0.44
|
| 53 |
+
""".strip()
|
| 54 |
+
|
| 55 |
+
def test_vi_embedding():
|
| 56 |
+
"""
|
| 57 |
+
Test that a VI embedding's words are correctly found when processing trees
|
| 58 |
+
"""
|
| 59 |
+
text = VI_TREEBANK.split("\n")[0]
|
| 60 |
+
trees = tree_reader.read_trees(text)
|
| 61 |
+
words = set(trees[0].leaf_labels())
|
| 62 |
+
|
| 63 |
+
with tempfile.TemporaryDirectory() as tempdir:
|
| 64 |
+
emb_filename = os.path.join(tempdir, "emb.txt")
|
| 65 |
+
pt_filename = os.path.join(tempdir, "emb.pt")
|
| 66 |
+
with open(emb_filename, "w", encoding="utf-8") as fout:
|
| 67 |
+
fout.write(VI_EMBEDDING)
|
| 68 |
+
pt = pretrain.Pretrain(filename=pt_filename, vec_filename=emb_filename, save_to_file=True)
|
| 69 |
+
pt.load()
|
| 70 |
+
|
| 71 |
+
trainer = build_trainer(pt_filename)
|
| 72 |
+
model = trainer.model
|
| 73 |
+
|
| 74 |
+
assert model.num_words_known(words) == 4
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_space_formatting():
|
| 78 |
+
"""
|
| 79 |
+
By default, spaces are left as spaces, but there is a format option to change spaces
|
| 80 |
+
"""
|
| 81 |
+
text = VI_TREEBANK.split("\n")[0]
|
| 82 |
+
trees = tree_reader.read_trees(text)
|
| 83 |
+
assert len(trees) == 1
|
| 84 |
+
assert str(trees[0]) == text
|
| 85 |
+
|
| 86 |
+
assert "{}".format(trees[0]) == VI_TREEBANK
|
| 87 |
+
assert "{:_O}".format(trees[0]) == VI_TREEBANK_UNDERSCORE
|
| 88 |
+
|
| 89 |
+
def test_vlsp_formatting():
|
| 90 |
+
text = VI_TREEBANK_PAREN.split("\n")[0]
|
| 91 |
+
trees = tree_reader.read_trees(text)
|
| 92 |
+
assert len(trees) == 1
|
| 93 |
+
assert str(trees[0]) == text
|
| 94 |
+
|
| 95 |
+
assert "{:_V}".format(trees[0]) == VI_TREEBANK_VLSP
|
| 96 |
+
trees[0].tree_id = 50
|
| 97 |
+
assert "{:_Vi}".format(trees[0]) == VI_TREEBANK_VLSP_50
|
| 98 |
+
trees[0].tree_id = 100
|
| 99 |
+
assert "{:_Vi}".format(trees[0]) == VI_TREEBANK_VLSP_100
|
| 100 |
+
|
| 101 |
+
empty = tree_reader.read_trees("(ROOT)")[0]
|
| 102 |
+
with pytest.raises(ValueError):
|
| 103 |
+
"{:V}".format(empty)
|
| 104 |
+
|
| 105 |
+
branches = tree_reader.read_trees("(ROOT (1) (2) (3))")[0]
|
| 106 |
+
with pytest.raises(ValueError):
|
| 107 |
+
"{:V}".format(branches)
|
| 108 |
+
|
| 109 |
+
def test_language_formatting():
|
| 110 |
+
"""
|
| 111 |
+
Test turning the parse tree into a 'language' for GPT
|
| 112 |
+
"""
|
| 113 |
+
text = VI_TREEBANK.split("\n")[0]
|
| 114 |
+
trees = tree_reader.read_trees(text)
|
| 115 |
+
trees = [t.prune_none().simplify_labels() for t in trees]
|
| 116 |
+
assert len(trees) == 1
|
| 117 |
+
assert str(trees[0]) == VI_TREEBANK_SIMPLE
|
| 118 |
+
|
| 119 |
+
text = "{:L}".format(trees[0])
|
| 120 |
+
assert text == EXPECTED_LABELED_BRACKETS
|
| 121 |
+
|
stanza/stanza/tests/langid/test_langid.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic tests of langid module
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from stanza.models.common.doc import Document
|
| 8 |
+
from stanza.pipeline.core import Pipeline
|
| 9 |
+
from stanza.pipeline.langid_processor import LangIDProcessor
|
| 10 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 11 |
+
|
| 12 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 13 |
+
|
| 14 |
+
#pytestmark = pytest.mark.skip
|
| 15 |
+
|
| 16 |
+
@pytest.fixture(scope="module")
|
| 17 |
+
def basic_multilingual():
|
| 18 |
+
return Pipeline(dir=TEST_MODELS_DIR, lang='multilingual', processors="langid")
|
| 19 |
+
|
| 20 |
+
@pytest.fixture(scope="module")
|
| 21 |
+
def enfr_multilingual():
|
| 22 |
+
return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en", "fr"])
|
| 23 |
+
|
| 24 |
+
@pytest.fixture(scope="module")
|
| 25 |
+
def en_multilingual():
|
| 26 |
+
return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_lang_subset=["en"])
|
| 27 |
+
|
| 28 |
+
@pytest.fixture(scope="module")
|
| 29 |
+
def clean_multilingual():
|
| 30 |
+
return Pipeline(dir=TEST_MODELS_DIR, lang="multilingual", processors="langid", langid_clean_text=True)
|
| 31 |
+
|
| 32 |
+
def test_langid(basic_multilingual):
|
| 33 |
+
"""
|
| 34 |
+
Basic test of language identification
|
| 35 |
+
"""
|
| 36 |
+
english_text = "This is an English sentence."
|
| 37 |
+
french_text = "C'est une phrase française."
|
| 38 |
+
docs = [english_text, french_text]
|
| 39 |
+
|
| 40 |
+
docs = [Document([], text=text) for text in docs]
|
| 41 |
+
basic_multilingual(docs)
|
| 42 |
+
predictions = [doc.lang for doc in docs]
|
| 43 |
+
assert predictions == ["en", "fr"]
|
| 44 |
+
|
| 45 |
+
def test_langid_benchmark(basic_multilingual):
|
| 46 |
+
"""
|
| 47 |
+
Run lang id model on 500 examples, confirm reasonable accuracy.
|
| 48 |
+
"""
|
| 49 |
+
examples = [
|
| 50 |
+
{"text": "contingentiam in naturalibus causis.", "label": "la"},
|
| 51 |
+
{"text": "I jak opowiadał nieżyjący już pan Czesław", "label": "pl"},
|
| 52 |
+
{"text": "Sonera gilt seit längerem als Übernahmekandidat", "label": "de"},
|
| 53 |
+
{"text": "与银类似,汞也可以与空气中的硫化氢反应。", "label": "zh-hans"},
|
| 54 |
+
{"text": "contradictionem implicat.", "label": "la"},
|
| 55 |
+
{"text": "Bis zu Prozent gingen die Offerten etwa im", "label": "de"},
|
| 56 |
+
{"text": "inneren Sicherheit vorgeschlagene Ausweitung der", "label": "de"},
|
| 57 |
+
{"text": "Multimedia-PDA mit Mini-Tastatur", "label": "de"},
|
| 58 |
+
{"text": "Ponášalo sa to na rovnicu o dvoch neznámych.", "label": "sk"},
|
| 59 |
+
{"text": "이처럼 앞으로 심판의 그 날에 다시 올 메시아가 예수 그리스도이며 , 그는 모든 인류의", "label": "ko"},
|
| 60 |
+
{"text": "Die Arbeitsgruppe bedauert , dass der weit über", "label": "de"},
|
| 61 |
+
{"text": "И только раз довелось поговорить с ним не вполне", "label": "ru"},
|
| 62 |
+
{"text": "de a-l lovi cu piciorul și conștiința că era", "label": "ro"},
|
| 63 |
+
{"text": "relación coas pretensións do demandante e que, nos", "label": "gl"},
|
| 64 |
+
{"text": "med petdeset in sedemdeset", "label": "sl"},
|
| 65 |
+
{"text": "Catalunya; el Consell Comarcal del Vallès Oriental", "label": "ca"},
|
| 66 |
+
{"text": "kunnen worden.", "label": "nl"},
|
| 67 |
+
{"text": "Witkin je ve většině ohledů zcela jiný.", "label": "cs"},
|
| 68 |
+
{"text": "lernen, so zu agieren, dass sie positive oder auch", "label": "de"},
|
| 69 |
+
{"text": "olurmuş...", "label": "tr"},
|
| 70 |
+
{"text": "sarcasmo de Altman, desde as «peruas» que discutem", "label": "pt"},
|
| 71 |
+
{"text": "خلاف فوجداری مقدمہ درج کرے۔", "label": "ur"},
|
| 72 |
+
{"text": "Norddal kommune :", "label": "no"},
|
| 73 |
+
{"text": "dem Windows-.-Zeitalter , soll in diesem Jahr", "label": "de"},
|
| 74 |
+
{"text": "przeklętych ucieleśniają mit poety-cygana,", "label": "pl"},
|
| 75 |
+
{"text": "We do not believe the suspect has ties to this", "label": "en"},
|
| 76 |
+
{"text": "groziņu pīšanu.", "label": "lv"},
|
| 77 |
+
{"text": "Senior Vice-President David M. Thomas möchte", "label": "de"},
|
| 78 |
+
{"text": "neomylně vybral nějakou knihu a začetl se.", "label": "cs"},
|
| 79 |
+
{"text": "Statt dessen darf beispielsweise der Browser des", "label": "de"},
|
| 80 |
+
{"text": "outubro, alcançando R $ bilhões em .", "label": "pt"},
|
| 81 |
+
{"text": "(Porte, ), as it does other disciplines", "label": "en"},
|
| 82 |
+
{"text": "uskupení se mylně domnívaly, že podporu", "label": "cs"},
|
| 83 |
+
{"text": "Übernahme von Next Ende an dem System herum , das", "label": "de"},
|
| 84 |
+
{"text": "No podemos decir a la Hacienda que los alemanes", "label": "es"},
|
| 85 |
+
{"text": "и рѣста еи братья", "label": "orv"},
|
| 86 |
+
{"text": "الذي اتخذ قرارا بتجميد اعلان الدولة الفلسطينية", "label": "ar"},
|
| 87 |
+
{"text": "uurides Rootsi sõjaarhiivist toodud . sajandi", "label": "et"},
|
| 88 |
+
{"text": "selskapets penger til å pusse opp sin enebolig på", "label": "no"},
|
| 89 |
+
{"text": "средней полосе и севернее в Ярославской,", "label": "ru"},
|
| 90 |
+
{"text": "il-massa żejda fil-ġemgħat u superġemgħat ta'", "label": "mt"},
|
| 91 |
+
{"text": "The Global Beauties on internetilehekülg, mida", "label": "et"},
|
| 92 |
+
{"text": "이스라엘 인들은 하나님이 그 큰 팔을 펴 이집트 인들을 치는 것을 보고 하나님을 두려워하며", "label": "ko"},
|
| 93 |
+
{"text": "Snad ještě dodejme jeden ekonomický argument.", "label": "cs"},
|
| 94 |
+
{"text": "Spalio d. vykusiame pirmajame rinkimų ture", "label": "lt"},
|
| 95 |
+
{"text": "und schlechter Journalismus ein gutes Geschäft .", "label": "de"},
|
| 96 |
+
{"text": "Du sodiečiai sėdi ant potvynio apsemtų namų stogo.", "label": "lt"},
|
| 97 |
+
{"text": "цей є автентичним.", "label": "uk"},
|
| 98 |
+
{"text": "Și îndegrabă fu cu îngerul mulțime de șireaguri", "label": "ro"},
|
| 99 |
+
{"text": "sobra personal cualificado.", "label": "es"},
|
| 100 |
+
{"text": "Tako se u Njemačkoj dvije trećine liječnika služe", "label": "hr"},
|
| 101 |
+
{"text": "Dual-Athlon-Chipsatz noch in diesem Jahr", "label": "de"},
|
| 102 |
+
{"text": "यहां तक कि चीन के चीफ ऑफ जनरल स्टाफ भी भारत का", "label": "hi"},
|
| 103 |
+
{"text": "Li forestier du mont avale", "label": "fro"},
|
| 104 |
+
{"text": "Netzwerken für Privatanwender zu bewundern .", "label": "de"},
|
| 105 |
+
{"text": "만해는 승적을 가진 중이 결혼할 수 없다는 불교의 계율을 시대에 맞지 않는 것으로 보았다", "label": "ko"},
|
| 106 |
+
{"text": "balance and weight distribution but not really for", "label": "en"},
|
| 107 |
+
{"text": "og så e # tente vi opp den om morgonen å sfyrte", "label": "nn"},
|
| 108 |
+
{"text": "변화는 의심의 여지가 없는 것이지만 반면에 진화는 논쟁의 씨앗이다 .", "label": "ko"},
|
| 109 |
+
{"text": "puteare fac aceastea.", "label": "ro"},
|
| 110 |
+
{"text": "Waitt seine Führungsmannschaft nicht dem", "label": "de"},
|
| 111 |
+
{"text": "juhtimisega, tulid sealt.", "label": "et"},
|
| 112 |
+
{"text": "Veränderungen .", "label": "de"},
|
| 113 |
+
{"text": "banda en el Bayer Leverkusen de la Bundesliga de", "label": "es"},
|
| 114 |
+
{"text": "В туже зиму посла всеволодъ сн҃а своѥго ст҃ослава", "label": "orv"},
|
| 115 |
+
{"text": "пославъ приведе я мастеры ѿ грекъ", "label": "orv"},
|
| 116 |
+
{"text": "En un nou escenari difícil d'imaginar fa poques", "label": "ca"},
|
| 117 |
+
{"text": "καὶ γὰρ τινὲς αὐτοὺς εὐεργεσίαι εἶχον ἐκ Κροίσου", "label": "grc"},
|
| 118 |
+
{"text": "직접적인 관련이 있다 .", "label": "ko"},
|
| 119 |
+
{"text": "가까운 듯하면서도 멀다 .", "label": "ko"},
|
| 120 |
+
{"text": "Er bietet ein ähnliches Leistungsniveau und", "label": "de"},
|
| 121 |
+
{"text": "民都洛水牛是獨居的,並不會以群族聚居。", "label": "zh-hant"},
|
| 122 |
+
{"text": "την τρομοκρατία.", "label": "el"},
|
| 123 |
+
{"text": "hurbiltzen diren neurrian.", "label": "eu"},
|
| 124 |
+
{"text": "Ah dimenticavo, ma tutta sta caciara per fare un", "label": "it"},
|
| 125 |
+
{"text": "На первом этапе (-) прошла так называемая", "label": "ru"},
|
| 126 |
+
{"text": "of games are on the market.", "label": "en"},
|
| 127 |
+
{"text": "находится Мост дружбы, соединяющий узбекский и", "label": "ru"},
|
| 128 |
+
{"text": "lessié je voldroie que li saint fussent aporté", "label": "fro"},
|
| 129 |
+
{"text": "Дошла очередь и до Гималаев.", "label": "ru"},
|
| 130 |
+
{"text": "vzácným suknem táhly pouští, si jednou chtěl do", "label": "cs"},
|
| 131 |
+
{"text": "E no terceiro tipo sitúa a familias (%), nos que a", "label": "gl"},
|
| 132 |
+
{"text": "وجابت دوريات امريكية وعراقية شوارع المدينة، فيما", "label": "ar"},
|
| 133 |
+
{"text": "Jeg har bodd her i år .", "label": "no"},
|
| 134 |
+
{"text": "Pohrozil, že odbory zostří postoj, pokud se", "label": "cs"},
|
| 135 |
+
{"text": "tinham conseguido.", "label": "pt"},
|
| 136 |
+
{"text": "Nicht-Erkrankten einen Anfangsverdacht für einen", "label": "de"},
|
| 137 |
+
{"text": "permanece em aberto.", "label": "pt"},
|
| 138 |
+
{"text": "questi possono promettere rendimenti fino a un", "label": "it"},
|
| 139 |
+
{"text": "Tema juurutatud kahevedurisüsteemita oleksid", "label": "et"},
|
| 140 |
+
{"text": "Поведение внешне простой игрушки оказалось", "label": "ru"},
|
| 141 |
+
{"text": "Bundesländern war vom Börsenverein des Deutschen", "label": "de"},
|
| 142 |
+
{"text": "acció, 'a mesura que avanci l'estiu, amb l'augment", "label": "ca"},
|
| 143 |
+
{"text": "Dove trovare queste risorse? Jay Naidoo, ministro", "label": "it"},
|
| 144 |
+
{"text": "essas gordurinhas.", "label": "pt"},
|
| 145 |
+
{"text": "Im zweiten Schritt sollen im übernächsten Jahr", "label": "de"},
|
| 146 |
+
{"text": "allveelaeva pole enam vaja, kuna külm sõda on läbi", "label": "et"},
|
| 147 |
+
{"text": "उपद्रवी दुकानों को लूटने के साथ ही उनमें आग लगा", "label": "hi"},
|
| 148 |
+
{"text": "@user nella sfortuna sei fortunata ..", "label": "it"},
|
| 149 |
+
{"text": "математических школ в виде грозовых туч.", "label": "ru"},
|
| 150 |
+
{"text": "No cambiaremos nunca nuestra forma de jugar por un", "label": "es"},
|
| 151 |
+
{"text": "dla tej klasy ani wymogów minimalnych, z wyjątkiem", "label": "pl"},
|
| 152 |
+
{"text": "en todo el mundo, mientras que en España consiguió", "label": "es"},
|
| 153 |
+
{"text": "политики считать надежное обеспечение военной", "label": "ru"},
|
| 154 |
+
{"text": "gogoratzen du, genio alemana delakoaren", "label": "eu"},
|
| 155 |
+
{"text": "Бычий глаз.", "label": "ru"},
|
| 156 |
+
{"text": "Opeření se v pravidelných obdobích obnovuje", "label": "cs"},
|
| 157 |
+
{"text": "I no és només la seva, es tracta d'una resposta", "label": "ca"},
|
| 158 |
+
{"text": "오경을 가르쳤다 .", "label": "ko"},
|
| 159 |
+
{"text": "Nach der so genannten Start-up-Periode vergibt die", "label": "de"},
|
| 160 |
+
{"text": "Saulista huomasi jo lapsena , että hänellä on", "label": "fi"},
|
| 161 |
+
{"text": "Министерство культуры сочло нецелесообразным, и", "label": "ru"},
|
| 162 |
+
{"text": "znepřátelené tábory v Tádžikistánu předseda", "label": "cs"},
|
| 163 |
+
{"text": "καὶ ἦν ὁ λαὸς προσδοκῶν τὸν Ζαχαρίαν καὶ ἐθαύμαζον", "label": "grc"},
|
| 164 |
+
{"text": "Вечером, в продукте, этот же человек говорил о", "label": "ru"},
|
| 165 |
+
{"text": "lugar á formación de xuizos máis complexos.", "label": "gl"},
|
| 166 |
+
{"text": "cheaper, in the end?", "label": "en"},
|
| 167 |
+
{"text": "الوزارة في شأن صفقات بيع الشركات العامة التي تم", "label": "ar"},
|
| 168 |
+
{"text": "tärkeintä elämässäni .", "label": "fi"},
|
| 169 |
+
{"text": "Виконання Мінських угод було заблоковано Росією та", "label": "uk"},
|
| 170 |
+
{"text": "Aby szybko rozpoznać żołnierzy desantu, należy", "label": "pl"},
|
| 171 |
+
{"text": "Bankengeschäfte liegen vorn , sagte Strothmann .", "label": "de"},
|
| 172 |
+
{"text": "продолжение работы.", "label": "ru"},
|
| 173 |
+
{"text": "Metro AG plant Online-Offensive", "label": "de"},
|
| 174 |
+
{"text": "nu vor veni, și să vor osîndi, aceia nu pot porni", "label": "ro"},
|
| 175 |
+
{"text": "Ich denke , es geht in Wirklichkeit darum , NT bei", "label": "de"},
|
| 176 |
+
{"text": "de turism care încasează contravaloarea", "label": "ro"},
|
| 177 |
+
{"text": "Aurkaria itotzea da helburua, baloia lapurtu eta", "label": "eu"},
|
| 178 |
+
{"text": "com a centre de formació en Tecnologies de la", "label": "ca"},
|
| 179 |
+
{"text": "oportet igitur quod omne agens in agendo intendat", "label": "la"},
|
| 180 |
+
{"text": "Jerzego Andrzejewskiego, oparty na chińskich", "label": "pl"},
|
| 181 |
+
{"text": "sau một vài câu chuyện xã giao không dính dáng tới", "label": "vi"},
|
| 182 |
+
{"text": "что экономическому прорыву жесткий авторитарный", "label": "ru"},
|
| 183 |
+
{"text": "DRAM-Preisen scheinen DSPs ein", "label": "de"},
|
| 184 |
+
{"text": "Jos dajan nubbái: Mana!", "label": "sme"},
|
| 185 |
+
{"text": "toți carii ascultară de el să răsipiră.", "label": "ro"},
|
| 186 |
+
{"text": "odpowiedzialności, które w systemie własności", "label": "pl"},
|
| 187 |
+
{"text": "Dvomesečno potovanje do Mollenda v Peruju je", "label": "sl"},
|
| 188 |
+
{"text": "d'entre les agències internacionals.", "label": "ca"},
|
| 189 |
+
{"text": "Fahrzeugzugangssysteme gefertigt und an viele", "label": "de"},
|
| 190 |
+
{"text": "in an answer to the sharers' petition in Cuthbert", "label": "en"},
|
| 191 |
+
{"text": "Europa-Domain per Verordnung zu regeln .", "label": "de"},
|
| 192 |
+
{"text": "#Balotelli. Su ebay prezzi stracciati per Silvio", "label": "it"},
|
| 193 |
+
{"text": "Ne na košickém trávníku, ale už včera v letadle se", "label": "cs"},
|
| 194 |
+
{"text": "zaměstnanosti a investičních strategií.", "label": "cs"},
|
| 195 |
+
{"text": "Tatínku, udělej den", "label": "cs"},
|
| 196 |
+
{"text": "frecuencia con Mary.", "label": "es"},
|
| 197 |
+
{"text": "Свеаборге.", "label": "ru"},
|
| 198 |
+
{"text": "opatření slovenské strany o certifikaci nejvíce", "label": "cs"},
|
| 199 |
+
{"text": "En todas me decían: 'Espera que hagamos un estudio", "label": "es"},
|
| 200 |
+
{"text": "Die Demonstration sollte nach Darstellung der", "label": "de"},
|
| 201 |
+
{"text": "Ci vorrà un assoluto rigore se dietro i disavanzi", "label": "it"},
|
| 202 |
+
{"text": "Tatínku, víš, že Honzovi odešla maminka?", "label": "cs"},
|
| 203 |
+
{"text": "Die Anzahl der Rechner wuchs um % auf und die", "label": "de"},
|
| 204 |
+
{"text": "האמריקאית על אדמת סעודיה עלולה לסבך את ישראל, אין", "label": "he"},
|
| 205 |
+
{"text": "Volán Egyesülés, a Közlekedési Főfelügyelet is.", "label": "hu"},
|
| 206 |
+
{"text": "Schejbala, který stejnou hru s velkým úspěchem", "label": "cs"},
|
| 207 |
+
{"text": "depends on the data type of the field.", "label": "en"},
|
| 208 |
+
{"text": "Umsatzwarnung zu Wochenbeginn zeitweise auf ein", "label": "de"},
|
| 209 |
+
{"text": "niin heti nukun .", "label": "fi"},
|
| 210 |
+
{"text": "Mobilfunkunternehmen gegen die Anwendung der so", "label": "de"},
|
| 211 |
+
{"text": "sapessi le intenzioni del governo Monti e dell'UE", "label": "it"},
|
| 212 |
+
{"text": "Di chi è figlia Martine Aubry?", "label": "it"},
|
| 213 |
+
{"text": "avec le reste du monde.", "label": "fr"},
|
| 214 |
+
{"text": "Այդ մաքոքը ինքնին նոր չէ, աշխարհը արդեն մի քանի", "label": "hy"},
|
| 215 |
+
{"text": "și în cazul destrămării cenaclului.", "label": "ro"},
|
| 216 |
+
{"text": "befriedigen kann , und ohne die auftretenden", "label": "de"},
|
| 217 |
+
{"text": "Κύκνον τ̓ ἐξεναρεῖν καὶ ἀπὸ κλυτὰ τεύχεα δῦσαι.", "label": "grc"},
|
| 218 |
+
{"text": "færdiguddannede.", "label": "da"},
|
| 219 |
+
{"text": "Schmidt war Sohn eines Rittergutsbesitzers.", "label": "de"},
|
| 220 |
+
{"text": "и вдаша попадь ѡпрати", "label": "orv"},
|
| 221 |
+
{"text": "cine nu știe învățătură”.", "label": "ro"},
|
| 222 |
+
{"text": "détacha et cette dernière tenta de tuer le jeune", "label": "fr"},
|
| 223 |
+
{"text": "Der har saka også ei lengre forhistorie.", "label": "nn"},
|
| 224 |
+
{"text": "Pieprz roztłuc w moździerzu, dodać do pasty,", "label": "pl"},
|
| 225 |
+
{"text": "Лежа за гребнем оврага, как за бруствером, Ушаков", "label": "ru"},
|
| 226 |
+
{"text": "gesucht habe, vielen Dank nochmals!", "label": "de"},
|
| 227 |
+
{"text": "инструментальных сталей, повышения", "label": "ru"},
|
| 228 |
+
{"text": "im Halbfinale Patrick Smith und im Finale dann", "label": "de"},
|
| 229 |
+
{"text": "البنوك التريث في منح تسهيلات جديدة لمنتجي حديد", "label": "ar"},
|
| 230 |
+
{"text": "una bolsa ventral, la cual se encuentra debajo de", "label": "es"},
|
| 231 |
+
{"text": "za SETimes.", "label": "sr"},
|
| 232 |
+
{"text": "de Irak, a un piloto italiano que había violado el", "label": "es"},
|
| 233 |
+
{"text": "Er könne sich nicht erklären , wie die Zeitung auf", "label": "de"},
|
| 234 |
+
{"text": "Прохорова.", "label": "ru"},
|
| 235 |
+
{"text": "la democrazia perde sulla tecnocrazia? #", "label": "it"},
|
| 236 |
+
{"text": "entre ambas instituciones, confirmó al medio que", "label": "es"},
|
| 237 |
+
{"text": "Austlandet, vart det funne om lag førti", "label": "nn"},
|
| 238 |
+
{"text": "уровнями власти.", "label": "ru"},
|
| 239 |
+
{"text": "Dá tedy primáři úplatek, a často ne malý.", "label": "cs"},
|
| 240 |
+
{"text": "brillantes del acto, al llevar a cabo en el", "label": "es"},
|
| 241 |
+
{"text": "eee druga zadeva je majhen priročen gre kamorkoli", "label": "sl"},
|
| 242 |
+
{"text": "Das ATX-Board paßt in herkömmliche PC-ATX-Gehäuse", "label": "de"},
|
| 243 |
+
{"text": "Za vodné bylo v prvním pololetí zaplaceno v ČR", "label": "cs"},
|
| 244 |
+
{"text": "Даже на полсантиметра.", "label": "ru"},
|
| 245 |
+
{"text": "com la del primer tinent d'alcalde en funcions,", "label": "ca"},
|
| 246 |
+
{"text": "кількох оповідань в цілості — щось на зразок того", "label": "uk"},
|
| 247 |
+
{"text": "sed ad divitias congregandas, vel superfluum", "label": "la"},
|
| 248 |
+
{"text": "Norma Talmadge, spela mot Valentino i en version", "label": "sv"},
|
| 249 |
+
{"text": "Dlatego chciał się jej oświadczyć w niezwykłym", "label": "pl"},
|
| 250 |
+
{"text": "будут выступать на одинаковых снарядах.", "label": "ru"},
|
| 251 |
+
{"text": "Orang-orang terbunuh di sana.", "label": "id"},
|
| 252 |
+
{"text": "لدى رايت شقيق اسمه أوسكار, وهو يعمل كرسام للكتب", "label": "ar"},
|
| 253 |
+
{"text": "Wirklichkeit verlagerten und kaum noch", "label": "de"},
|
| 254 |
+
{"text": "как перемешивают костяшки перед игрой в домино, и", "label": "ru"},
|
| 255 |
+
{"text": "В средине дня, когда солнце светило в нашу", "label": "ru"},
|
| 256 |
+
{"text": "d'aventure aux rôles de jeune romantique avec une", "label": "fr"},
|
| 257 |
+
{"text": "My teď hledáme organizace, jež by s námi chtěly", "label": "cs"},
|
| 258 |
+
{"text": "Urteilsfähigkeit einbüßen , wenn ich eigene", "label": "de"},
|
| 259 |
+
{"text": "sua appartenenza anche a voci diverse da quella in", "label": "it"},
|
| 260 |
+
{"text": "Aufträge dieses Jahr verdoppeln werden .", "label": "de"},
|
| 261 |
+
{"text": "M.E.: Miała szanse mnie odnaleźć, gdyby naprawdę", "label": "pl"},
|
| 262 |
+
{"text": "secundum contactum virtutis, cum careat dimensiva", "label": "la"},
|
| 263 |
+
{"text": "ezinbestekoa dela esan zuen.", "label": "eu"},
|
| 264 |
+
{"text": "Anek hurbiltzeko eskatzen zion besaulkitik, eta", "label": "eu"},
|
| 265 |
+
{"text": "perfectius alio videat, quamvis uterque videat", "label": "la"},
|
| 266 |
+
{"text": "Die Strecke war anspruchsvoll und führte unter", "label": "de"},
|
| 267 |
+
{"text": "саморазоблачительным уроком, западные СМИ не", "label": "ru"},
|
| 268 |
+
{"text": "han representerer radikal islamisme .", "label": "no"},
|
| 269 |
+
{"text": "Què s'hi respira pel que fa a la reforma del", "label": "ca"},
|
| 270 |
+
{"text": "previsto para também ser desconstruido.", "label": "pt"},
|
| 271 |
+
{"text": "Ὠκεανοῦ βαθυκόλποις ἄνθεά τ̓ αἰνυμένην, ῥόδα καὶ", "label": "grc"},
|
| 272 |
+
{"text": "para jovens de a anos nos Cieps.", "label": "pt"},
|
| 273 |
+
{"text": "संघर्ष को अंजाम तक पहुंचाने का ऐलान किया है ।", "label": "hi"},
|
| 274 |
+
{"text": "objeví i u nás.", "label": "cs"},
|
| 275 |
+
{"text": "kvitteringer.", "label": "da"},
|
| 276 |
+
{"text": "This report is no exception.", "label": "en"},
|
| 277 |
+
{"text": "Разлепват доносниците до избирателните списъци", "label": "bg"},
|
| 278 |
+
{"text": "anderem ihre Bewegungsfreiheit in den USA", "label": "de"},
|
| 279 |
+
{"text": "Ñu tegoon ca kaw gor ña ay njotti bopp yu kenn", "label": "wo"},
|
| 280 |
+
{"text": "Struktur kann beispielsweise der Schwerpunkt mehr", "label": "de"},
|
| 281 |
+
{"text": "% la velocidad permitida, la sanción es muy grave.", "label": "es"},
|
| 282 |
+
{"text": "Teles-Einstieg in ADSL-Markt", "label": "de"},
|
| 283 |
+
{"text": "ettekäändeks liiga suure osamaksu.", "label": "et"},
|
| 284 |
+
{"text": "als Indiz für die geänderte Marktpolitik des", "label": "de"},
|
| 285 |
+
{"text": "quod quidem aperte consequitur ponentes", "label": "la"},
|
| 286 |
+
{"text": "de negociación para el próximo de junio.", "label": "es"},
|
| 287 |
+
{"text": "Tyto důmyslné dekorace doznaly v poslední době", "label": "cs"},
|
| 288 |
+
{"text": "največjega uspeha doslej.", "label": "sl"},
|
| 289 |
+
{"text": "Paul Allen je jedan od suosnivača Interval", "label": "hr"},
|
| 290 |
+
{"text": "Federal (Seac / DF) eo Sindicato das Empresas de", "label": "pt"},
|
| 291 |
+
{"text": "Quartal mit . Mark gegenüber dem gleichen Quartal", "label": "de"},
|
| 292 |
+
{"text": "otros clubes y del Barça B saldrán varios", "label": "es"},
|
| 293 |
+
{"text": "Jaskula (Pol.) -", "label": "cs"},
|
| 294 |
+
{"text": "umožnily říci, že je možné přejít k mnohem", "label": "cs"},
|
| 295 |
+
{"text": "اعلن الجنرال تومي فرانكس قائد القوات الامريكية", "label": "ar"},
|
| 296 |
+
{"text": "Telekom-Chef Ron Sommer und der Vorstandssprecher", "label": "de"},
|
| 297 |
+
{"text": "My, jako průmyslový a finanční holding, můžeme", "label": "cs"},
|
| 298 |
+
{"text": "voorlichting onder andere betrekking kan hebben:", "label": "nl"},
|
| 299 |
+
{"text": "Hinrichtung geistig Behinderter applaudiert oder", "label": "de"},
|
| 300 |
+
{"text": "wie beispielsweise Anzahl erzielte Klicks ,", "label": "de"},
|
| 301 |
+
{"text": "Intel-PC-SDRAM-Spezifikation in der Version . (", "label": "de"},
|
| 302 |
+
{"text": "plângere în termen de zile de la comunicarea", "label": "ro"},
|
| 303 |
+
{"text": "и Испания ще изгубят втория си комисар в ЕК.", "label": "bg"},
|
| 304 |
+
{"text": "इसके चलते इस आदिवासी जनजाति का क्षरण हो रहा है ।", "label": "hi"},
|
| 305 |
+
{"text": "aunque se mostró contrario a establecer un", "label": "es"},
|
| 306 |
+
{"text": "des letzten Jahres von auf Millionen Euro .", "label": "de"},
|
| 307 |
+
{"text": "Ankara se također poziva da u cijelosti ratificira", "label": "hr"},
|
| 308 |
+
{"text": "herunterlädt .", "label": "de"},
|
| 309 |
+
{"text": "стрессовую ситуацию для организма, каковой", "label": "ru"},
|
| 310 |
+
{"text": "Státního shromáždění (parlamentu).", "label": "cs"},
|
| 311 |
+
{"text": "diskutieren , ob und wie dieser Dienst weiterhin", "label": "de"},
|
| 312 |
+
{"text": "Verbindungen zu FPÖ-nahen Polizisten gepflegt und", "label": "de"},
|
| 313 |
+
{"text": "Pražského volebního lídra ovšem nevybírá Miloš", "label": "cs"},
|
| 314 |
+
{"text": "Nach einem Bericht der Washington Post bleibt das", "label": "de"},
|
| 315 |
+
{"text": "للوضع آنذاك، لكني في قرارة نفسي كنت سعيداً لما", "label": "ar"},
|
| 316 |
+
{"text": "не желаят запазването на статуквото.", "label": "bg"},
|
| 317 |
+
{"text": "Offenburg gewesen .", "label": "de"},
|
| 318 |
+
{"text": "ἐὰν ὑμῖν εἴπω οὐ μὴ πιστεύσητε", "label": "grc"},
|
| 319 |
+
{"text": "all'odiato compagno di squadra Prost, il quale", "label": "it"},
|
| 320 |
+
{"text": "historischen Gänselieselbrunnens.", "label": "de"},
|
| 321 |
+
{"text": "למידע מלווייני הריגול האמריקאיים העוקבים אחר", "label": "he"},
|
| 322 |
+
{"text": "οὐδὲν ἄρα διαφέρεις Ἀμάσιος τοῦ Ἠλείου, ὃν", "label": "grc"},
|
| 323 |
+
{"text": "movementos migratorios.", "label": "gl"},
|
| 324 |
+
{"text": "Handy und ein Spracherkennungsprogramm sämtliche", "label": "de"},
|
| 325 |
+
{"text": "Kümne aasta jooksul on Eestisse ohjeldamatult", "label": "et"},
|
| 326 |
+
{"text": "H.G. Bücknera.", "label": "pl"},
|
| 327 |
+
{"text": "protiv krijumčarenja, ili pak traženju ukidanja", "label": "hr"},
|
| 328 |
+
{"text": "Topware-Anteile mehrere Millionen Mark gefordert", "label": "de"},
|
| 329 |
+
{"text": "Maar de mensen die nu over Van Dijk bij FC Twente", "label": "nl"},
|
| 330 |
+
{"text": "poidan experimentar as percepcións do interesado,", "label": "gl"},
|
| 331 |
+
{"text": "Miał przecież w kieszeni nóż.", "label": "pl"},
|
| 332 |
+
{"text": "Avšak žádná z nich nepronikla za hranice přímé", "label": "cs"},
|
| 333 |
+
{"text": "esim. helpottamalla luottoja muiden", "label": "fi"},
|
| 334 |
+
{"text": "Podle předběžných výsledků zvítězila v", "label": "cs"},
|
| 335 |
+
{"text": "Nicht nur das Web-Frontend , auch die", "label": "de"},
|
| 336 |
+
{"text": "Regierungsinstitutionen oder Universitäten bei", "label": "de"},
|
| 337 |
+
{"text": "Խուլեն Լոպետեգիին, պատճառաբանելով, որ վերջինս", "label": "hy"},
|
| 338 |
+
{"text": "Афганистана, где в последние дни идут ожесточенные", "label": "ru"},
|
| 339 |
+
{"text": "лѧхове же не идоша", "label": "orv"},
|
| 340 |
+
{"text": "Mit Hilfe von IBMs Chip-Management-Systemen sollen", "label": "de"},
|
| 341 |
+
{"text": ", als Manager zu Telefonica zu wechseln .", "label": "de"},
|
| 342 |
+
{"text": "którym zajmuje się człowiek, zmienia go i pozwala", "label": "pl"},
|
| 343 |
+
{"text": "činí kyperských liber, to je asi USD.", "label": "cs"},
|
| 344 |
+
{"text": "Studienplätze getauscht werden .", "label": "de"},
|
| 345 |
+
{"text": "учёных, орнитологов признают вид.", "label": "ru"},
|
| 346 |
+
{"text": "acordare a concediilor prevăzute de legislațiile", "label": "ro"},
|
| 347 |
+
{"text": "at større innsats for fornybar, berekraftig energi", "label": "nn"},
|
| 348 |
+
{"text": "Politiet veit ikkje kor mange personar som deltok", "label": "nn"},
|
| 349 |
+
{"text": "offentligheten av unge , sinte menn som har", "label": "no"},
|
| 350 |
+
{"text": "însuși în jurul lapunei, care încet DISPARE în", "label": "ro"},
|
| 351 |
+
{"text": "O motivo da decisão é evitar uma sobrecarga ainda", "label": "pt"},
|
| 352 |
+
{"text": "El Apostolado de la prensa contribuye en modo", "label": "es"},
|
| 353 |
+
{"text": "Teltow ( Kreis Teltow-Fläming ) ist Schmitt einer", "label": "de"},
|
| 354 |
+
{"text": "grozījumus un iesniegt tos Apvienoto Nāciju", "label": "lv"},
|
| 355 |
+
{"text": "Gestalt einer deutschen Nationalmannschaft als", "label": "de"},
|
| 356 |
+
{"text": "D überholt zu haben , konterte am heutigen Montag", "label": "de"},
|
| 357 |
+
{"text": "Softwarehersteller Oracle hat im dritten Quartal", "label": "de"},
|
| 358 |
+
{"text": "Během nich se ekonomické podmínky mohou radikálně", "label": "cs"},
|
| 359 |
+
{"text": "Dziki kot w górach zeskakuje z kamienia.", "label": "pl"},
|
| 360 |
+
{"text": "Ačkoliv ligový nováček prohrál, opět potvrdil, že", "label": "cs"},
|
| 361 |
+
{"text": "des Tages , Portraits internationaler Stars sowie", "label": "de"},
|
| 362 |
+
{"text": "Communicator bekannt wurde .", "label": "de"},
|
| 363 |
+
{"text": "τῷ δ’ ἄρα καὶ αὐτῷ ἡ γυνή ἐπίτεξ ἐοῦσα πᾶσαν", "label": "grc"},
|
| 364 |
+
{"text": "Triadú tenia, mentre redactava 'Dies de memòria',", "label": "ca"},
|
| 365 |
+
{"text": "دستهجمعی در درخشندگی ماه سیمگون زمزمه ستاینده و", "label": "fa"},
|
| 366 |
+
{"text": "Книгу, наполненную мелочной заботой об одежде,", "label": "ru"},
|
| 367 |
+
{"text": "putares canem leporem persequi.", "label": "la"},
|
| 368 |
+
{"text": "В дальнейшем эта яркость слегка померкла, но в", "label": "ru"},
|
| 369 |
+
{"text": "offizielles Verfahren gegen die Telekom", "label": "de"},
|
| 370 |
+
{"text": "podrían haber sido habitantes de la Península", "label": "es"},
|
| 371 |
+
{"text": "Grundlage für dieses Verfahren sind spezielle", "label": "de"},
|
| 372 |
+
{"text": "Rechtsausschuß vorgelegten Entwurf der Richtlinie", "label": "de"},
|
| 373 |
+
{"text": "Im so genannten Portalgeschäft sei das Unternehmen", "label": "de"},
|
| 374 |
+
{"text": "ⲏ ⲉⲓϣⲁⲛϥⲓ ⲛⲉⲓⲇⲱⲗⲟⲛ ⲉⲧϩⲙⲡⲉⲕⲏⲓ ⲙⲏ ⲉⲓⲛⲁϣϩⲱⲡ ⲟⲛ ⲙⲡⲣⲏ", "label": "cop"},
|
| 375 |
+
{"text": "juego podían matar a cualquier herbívoro, pero", "label": "es"},
|
| 376 |
+
{"text": "Nach Angaben von Axent nutzen Unternehmen aus der", "label": "de"},
|
| 377 |
+
{"text": "hrdiny Havlovy Zahradní slavnosti (premiéra ) se", "label": "cs"},
|
| 378 |
+
{"text": "Een zin van heb ik jou daar", "label": "nl"},
|
| 379 |
+
{"text": "hat sein Hirn an der CeBIT-Kasse vergessen .", "label": "de"},
|
| 380 |
+
{"text": "καὶ τοὺς ἐκπλαγέντας οὐκ ἔχειν ἔτι ἐλεγχομένους", "label": "grc"},
|
| 381 |
+
{"text": "nachgewiesenen langfristigen Kosten , sowie den im", "label": "de"},
|
| 382 |
+
{"text": "jučer nakon četiri dana putovanja u Helsinki.", "label": "hr"},
|
| 383 |
+
{"text": "pašto paslaugos teikėjas gali susitarti su", "label": "lt"},
|
| 384 |
+
{"text": "В результате, эти золотые кадры переходят из одной", "label": "ru"},
|
| 385 |
+
{"text": "द फाइव-ईयर एंगेजमेंट में अभिनय किया जिसमें जैसन", "label": "hi"},
|
| 386 |
+
{"text": "výpis o počtu akcií.", "label": "cs"},
|
| 387 |
+
{"text": "Enfin, elles arrivent à un pavillon chinois", "label": "fr"},
|
| 388 |
+
{"text": "Tentu saja, tren yang berhubungandengan", "label": "id"},
|
| 389 |
+
{"text": "Arbeidarpartiet og SV har sikra seg fleirtal mot", "label": "nn"},
|
| 390 |
+
{"text": "eles: 'Tudo isso está errado' , disse um", "label": "pt"},
|
| 391 |
+
{"text": "The islands are in their own time zone, minutes", "label": "en"},
|
| 392 |
+
{"text": "Auswahl debütierte er am .", "label": "de"},
|
| 393 |
+
{"text": "Bu komisyonlar, arazilerini satın almak için", "label": "tr"},
|
| 394 |
+
{"text": "Geschütze gegen Redmond aufgefahren .", "label": "de"},
|
| 395 |
+
{"text": "Time scything the hours, but at the top, over the", "label": "en"},
|
| 396 |
+
{"text": "Di musim semi , berharap mengadaptasi Tintin untuk", "label": "id"},
|
| 397 |
+
{"text": "крупнейшей геополитической катастрофой XX века.", "label": "ru"},
|
| 398 |
+
{"text": "Rajojen avaaminen ei suju ongelmitta .", "label": "fi"},
|
| 399 |
+
{"text": "непроницаемым, как для СССР.", "label": "ru"},
|
| 400 |
+
{"text": "Ma non mancano le polemiche.", "label": "it"},
|
| 401 |
+
{"text": "Internet als Ort politischer Diskussion und auch", "label": "de"},
|
| 402 |
+
{"text": "incomplets.", "label": "ca"},
|
| 403 |
+
{"text": "Su padre luchó al lado de Luis Moya, primer Jefe", "label": "es"},
|
| 404 |
+
{"text": "informazione.", "label": "it"},
|
| 405 |
+
{"text": "Primacom bietet für Telekom-Kabelnetz", "label": "de"},
|
| 406 |
+
{"text": "Oświadczenie prezydencji w imieniu Unii", "label": "pl"},
|
| 407 |
+
{"text": "foran rattet i familiens gamle Baleno hvis døra på", "label": "no"},
|
| 408 |
+
{"text": "[speaker:laughter]", "label": "sl"},
|
| 409 |
+
{"text": "Dog med langt mindre utstyr med seg.", "label": "nn"},
|
| 410 |
+
{"text": "dass es nicht schon mit der anfänglichen", "label": "de"},
|
| 411 |
+
{"text": "इस पर दोनों पक्षों में नोकझोंक शुरू हो गई ।", "label": "hi"},
|
| 412 |
+
{"text": "کے ترجمان منیش تیواری اور دگ وجئے سنگھ نے بھی یہ", "label": "ur"},
|
| 413 |
+
{"text": "dell'Assemblea Costituente che posseggono i", "label": "it"},
|
| 414 |
+
{"text": "и аште вьси съблазнѧтъ сѧ нъ не азъ", "label": "cu"},
|
| 415 |
+
{"text": "In Irvine hat auch das Logistikunternehmen Atlas", "label": "de"},
|
| 416 |
+
{"text": "законодательных норм, принимаемых существующей", "label": "ru"},
|
| 417 |
+
{"text": "Κροίσῳ προτείνων τὰς χεῖρας ἐπικατασφάξαι μιν", "label": "grc"},
|
| 418 |
+
{"text": "МИНУСЫ: ИНФЛЯЦИЯ И КРИЗИС В ЖИВОТНОВОДСТВЕ.", "label": "ru"},
|
| 419 |
+
{"text": "unterschiedlicher Meinung .", "label": "de"},
|
| 420 |
+
{"text": "Jospa joku ystävällinen sielu auttaisi kassieni", "label": "fi"},
|
| 421 |
+
{"text": "Añadió que, en el futuro se harán otros", "label": "es"},
|
| 422 |
+
{"text": "Sessiz tonlama hem Fince, hem de Kuzey Sami", "label": "tr"},
|
| 423 |
+
{"text": "nicht ihnen gehört und sie nicht alles , was sie", "label": "de"},
|
| 424 |
+
{"text": "Etelästä Kuivajärveen laskee Tammelan Liesjärvestä", "label": "fi"},
|
| 425 |
+
{"text": "ICANNs Vorsitzender Vint Cerf warb mit dem Hinweis", "label": "de"},
|
| 426 |
+
{"text": "Norsk politikk frå til kan dermed, i", "label": "nn"},
|
| 427 |
+
{"text": "Głosowało posłów.", "label": "pl"},
|
| 428 |
+
{"text": "Danny Jones -- smithjones@ev.net", "label": "en"},
|
| 429 |
+
{"text": "sebeuvědomění moderní civilizace sehrála lučavka", "label": "cs"},
|
| 430 |
+
{"text": "относительно спокойный сон: тому гарантия", "label": "ru"},
|
| 431 |
+
{"text": "A halte voiz prist li pedra a crïer", "label": "fro"},
|
| 432 |
+
{"text": "آنها امیدوارند این واکسن بهزودی در دسترس بیماران", "label": "fa"},
|
| 433 |
+
{"text": "vlastní důstojnou vousatou tváří.", "label": "cs"},
|
| 434 |
+
{"text": "ora aprire la strada a nuove cause e alimentare il", "label": "it"},
|
| 435 |
+
{"text": "Die Zahl der Vielleser nahm von auf Prozent zu ,", "label": "de"},
|
| 436 |
+
{"text": "Finanzvorstand von Hotline-Dienstleister InfoGenie", "label": "de"},
|
| 437 |
+
{"text": "entwickeln .", "label": "de"},
|
| 438 |
+
{"text": "incolumità pubblica.", "label": "it"},
|
| 439 |
+
{"text": "lehtija televisiomainonta", "label": "fi"},
|
| 440 |
+
{"text": "joistakin kohdista eri mieltä.", "label": "fi"},
|
| 441 |
+
{"text": "Hlavně anglická nezávislá scéna, Dead Can Dance,", "label": "cs"},
|
| 442 |
+
{"text": "pásmech od do bodů bodové stupnice.", "label": "cs"},
|
| 443 |
+
{"text": "Zu Beginn des Ersten Weltkrieges zählte das", "label": "de"},
|
| 444 |
+
{"text": "Així van sorgir, damunt els antics cementiris,", "label": "ca"},
|
| 445 |
+
{"text": "In manchem Gedicht der spätern Alten, wie zum", "label": "de"},
|
| 446 |
+
{"text": "gaweihaida jah insandida in þana fairƕu jus qiþiþ", "label": "got"},
|
| 447 |
+
{"text": "Beides sollte gelöscht werden!", "label": "de"},
|
| 448 |
+
{"text": "modifiqués la seva petició inicial de anys de", "label": "ca"},
|
| 449 |
+
{"text": "В день открытия симпозиума состоялась закладка", "label": "ru"},
|
| 450 |
+
{"text": "tõestatud.", "label": "et"},
|
| 451 |
+
{"text": "ἵππῳ πίπτει αὐτοῦ ταύτῃ", "label": "grc"},
|
| 452 |
+
{"text": "bisher nie enttäuscht!", "label": "de"},
|
| 453 |
+
{"text": "De bohte ollu tuollárat ja suttolaččat ja", "label": "sme"},
|
| 454 |
+
{"text": "Klarsignal från röstlängdsläsaren, tre tryck i", "label": "sv"},
|
| 455 |
+
{"text": "Tvůrcem nového termínu je Joseph Fisher.", "label": "cs"},
|
| 456 |
+
{"text": "Nie miałem czasu na reakcję twierdzi Norbert,", "label": "pl"},
|
| 457 |
+
{"text": "potentia Schöpfer.", "label": "de"},
|
| 458 |
+
{"text": "Un poquito caro, pero vale mucho la pena;", "label": "es"},
|
| 459 |
+
{"text": "οὔ τε γὰρ ἴφθιμοι Λύκιοι Δαναῶν ἐδύναντο τεῖχος", "label": "grc"},
|
| 460 |
+
{"text": "vajec, sladového výtažku a některých vitamínových", "label": "cs"},
|
| 461 |
+
{"text": "Настоящие герои, те, чьи истории потом", "label": "ru"},
|
| 462 |
+
{"text": "praesumptio:", "label": "la"},
|
| 463 |
+
{"text": "Olin justkui nende vastutusel.", "label": "et"},
|
| 464 |
+
{"text": "Jokainen keinahdus tuo lähemmäksi hetkeä jolloin", "label": "fi"},
|
| 465 |
+
{"text": "ekonomicky výhodných způsobů odvodnění těžkých,", "label": "cs"},
|
| 466 |
+
{"text": "Poprvé ve své historii dokázala v kvalifikaci pro", "label": "cs"},
|
| 467 |
+
{"text": "zpracovatelského a spotřebního průmyslu bude nutné", "label": "cs"},
|
| 468 |
+
{"text": "Windows CE zu integrieren .", "label": "de"},
|
| 469 |
+
{"text": "Armangué, a través d'un decret, ordenés l'aturada", "label": "ca"},
|
| 470 |
+
{"text": "to, co nás Evropany spojuje, než to, co nás od", "label": "cs"},
|
| 471 |
+
{"text": "ergänzt durch einen gesetzlich verankertes", "label": "de"},
|
| 472 |
+
{"text": "Насчитал, что с начала года всего три дня были", "label": "ru"},
|
| 473 |
+
{"text": "Borisovu tražeći od njega da prihvati njenu", "label": "sr"},
|
| 474 |
+
{"text": "la presenza di ben veleni diversi: . chili di", "label": "it"},
|
| 475 |
+
{"text": "καὶ τῶν ἐκλεκτῶν ἀγγέλων ἵνα ταῦτα φυλάξῃς χωρὶς", "label": "grc"},
|
| 476 |
+
{"text": "pretraživale obližnju bolnicu i stambene zgrade u", "label": "hr"},
|
| 477 |
+
{"text": "An rund Katzen habe Wolf seine Spiele getestet ,", "label": "de"},
|
| 478 |
+
{"text": "investigating since March.", "label": "en"},
|
| 479 |
+
{"text": "Tonböden (Mullböden).", "label": "de"},
|
| 480 |
+
{"text": "Stálý dopisovatel LN v SRN Bedřich Utitz", "label": "cs"},
|
| 481 |
+
{"text": "červnu předložené smlouvy.", "label": "cs"},
|
| 482 |
+
{"text": "πνεύματι ᾧ ἐλάλει", "label": "grc"},
|
| 483 |
+
{"text": ".%의 신장세를 보였다.", "label": "ko"},
|
| 484 |
+
{"text": "Foae verde, foi de nuc, Prin pădure, prin colnic,", "label": "ro"},
|
| 485 |
+
{"text": "διαπέμψας ἄλλους ἄλλῃ τοὺς μὲν ἐς Δελφοὺς ἰέναι", "label": "grc"},
|
| 486 |
+
{"text": "المسلمين أو أي تيار سياسي طالما عمل ذلك التيار في", "label": "ar"},
|
| 487 |
+
{"text": "As informações são da Dow Jones.", "label": "pt"},
|
| 488 |
+
{"text": "Milliarde DM ausgestattet sein .", "label": "de"},
|
| 489 |
+
{"text": "De utgår fortfarande från att kvinnans jämlikhet", "label": "sv"},
|
| 490 |
+
{"text": "Sneeuw maakte in Davos bij de voorbereiding een", "label": "nl"},
|
| 491 |
+
{"text": "De ahí que en este mercado puedan negociarse", "label": "es"},
|
| 492 |
+
{"text": "intenzívnějšímu sbírání a studiu.", "label": "cs"},
|
| 493 |
+
{"text": "और औसकर ४.० पैकेज का प्रयोग किया गया है ।", "label": "hi"},
|
| 494 |
+
{"text": "Adipati Kuningan karena Kuningan menjadi bagian", "label": "id"},
|
| 495 |
+
{"text": "Svako je bar jednom poželeo da mašine prosto umeju", "label": "sr"},
|
| 496 |
+
{"text": "Im vergangenen Jahr haben die Regierungen einen", "label": "de"},
|
| 497 |
+
{"text": "durat motus, aliquid fit et non est;", "label": "la"},
|
| 498 |
+
{"text": "Dominować będą piosenki do tekstów Edwarda", "label": "pl"},
|
| 499 |
+
{"text": "beantwortet .", "label": "de"},
|
| 500 |
+
{"text": "О гуманитариях было кому рассказывать, а вот за", "label": "ru"},
|
| 501 |
+
{"text": "Helsingin kaupunki riitautti vuokrasopimuksen", "label": "fi"},
|
| 502 |
+
{"text": "chợt tan biến.", "label": "vi"},
|
| 503 |
+
{"text": "avtomobil ločuje od drugih.", "label": "sl"},
|
| 504 |
+
{"text": "Congress has proven itself ineffective as a body.", "label": "en"},
|
| 505 |
+
{"text": "मैक्सिको ने इस तरह का शो इस समय आयोजित करने का", "label": "hi"},
|
| 506 |
+
{"text": "No minimum order amount.", "label": "en"},
|
| 507 |
+
{"text": "Convertassa .", "label": "fi"},
|
| 508 |
+
{"text": "Как это можно сделать?", "label": "ru"},
|
| 509 |
+
{"text": "tha mi creidsinn gu robh iad ceart cho saor shuas", "label": "gd"},
|
| 510 |
+
{"text": "실제 일제는 이런 만해의 논리를 묵살하고 한반도를 침략한 다음 , 이어 만주를 침략하고", "label": "ko"},
|
| 511 |
+
{"text": "Da un semplice richiamo all'ordine fino a grandi", "label": "it"},
|
| 512 |
+
{"text": "pozoruhodný nejen po umělecké stránce, jež", "label": "cs"},
|
| 513 |
+
{"text": "La comida y el servicio aprueban.", "label": "es"},
|
| 514 |
+
{"text": "again, connected not with each other but to the", "label": "en"},
|
| 515 |
+
{"text": "Protokol výslovně stanoví, že nikdo nemůže být", "label": "cs"},
|
| 516 |
+
{"text": "ఒక విషయం అడగాలని ఉంది .", "label": "te"},
|
| 517 |
+
{"text": "Безгранично почитая дирекцию, ловя на лету каждое", "label": "ru"},
|
| 518 |
+
{"text": "rovnoběžných růstových vrstev, zůstávají krychlové", "label": "cs"},
|
| 519 |
+
{"text": "प्रवेश और पूर्व प्रधानमंत्री लाल बहादुर शास्त्री", "label": "hi"},
|
| 520 |
+
{"text": "Bronzen medaille in de Europese marathon.", "label": "nl"},
|
| 521 |
+
{"text": "- gadu vecumā viņi to nesaprot.", "label": "lv"},
|
| 522 |
+
{"text": "Realizó sus estudios primarios en la Escuela Julia", "label": "es"},
|
| 523 |
+
{"text": "cuartos de final, su clasificación para la final a", "label": "es"},
|
| 524 |
+
{"text": "Sem si pro něho přiletí americký raketoplán, na", "label": "cs"},
|
| 525 |
+
{"text": "Way to go!", "label": "en"},
|
| 526 |
+
{"text": "gehört der neuen SPD-Führung unter Parteichef", "label": "de"},
|
| 527 |
+
{"text": "Somit simuliert der Player mit einer GByte-Platte", "label": "de"},
|
| 528 |
+
{"text": "Berufung auf kommissionsnahe Kreise , die bereits", "label": "de"},
|
| 529 |
+
{"text": "Dist Clarïen", "label": "fro"},
|
| 530 |
+
{"text": "Schon nach den Gerüchten , die Telekom wolle den", "label": "de"},
|
| 531 |
+
{"text": "Software von NetObjects ist nach Angaben des", "label": "de"},
|
| 532 |
+
{"text": "si enim per legem iustitia ergo Christus gratis", "label": "la"},
|
| 533 |
+
{"text": "ducerent in ipsam magis quam in corpus christi,", "label": "la"},
|
| 534 |
+
{"text": "Neustar-Melbourne-IT-Partnerschaft NeuLevel .", "label": "de"},
|
| 535 |
+
{"text": "forderte dagegen seine drastische Verschärfung.", "label": "de"},
|
| 536 |
+
{"text": "pemmican på hundrede forskellige måder.", "label": "da"},
|
| 537 |
+
{"text": "Lehån, själv matematiklärare, visar hur den nya", "label": "sv"},
|
| 538 |
+
{"text": "I highly recommend his shop.", "label": "en"},
|
| 539 |
+
{"text": "verità, giovani fedeli prostratevi #amen", "label": "it"},
|
| 540 |
+
{"text": "उत्तर प्रदेश के अध्यक्ष पद से हटाए गए विनय कटियार", "label": "hi"},
|
| 541 |
+
{"text": "() روزی مےں کشادگی ہوتی ہے۔", "label": "ur"},
|
| 542 |
+
{"text": "Prozessorgeschäft profitieren kann , stellen", "label": "de"},
|
| 543 |
+
{"text": "školy začalo počítat pytle s moukou a zjistilo, že", "label": "cs"},
|
| 544 |
+
{"text": "प्रभावशाली पर गैर सरकारी लोगों के घरों में भी", "label": "hi"},
|
| 545 |
+
{"text": "geschichtslos , oder eine Farce , wie sich", "label": "de"},
|
| 546 |
+
{"text": "Ústrednými mocnosťami v marci však spôsobilo, že", "label": "sk"},
|
| 547 |
+
{"text": "التسليح بدون مبرر، واستمرار الأضرار الناجمة عن فرض", "label": "ar"},
|
| 548 |
+
{"text": "Například Pedagogická fakulta Univerzity Karlovy", "label": "cs"},
|
| 549 |
+
{"text": "nostris ut eriperet nos de praesenti saeculo", "label": "la"}]
|
| 550 |
+
|
| 551 |
+
docs = [Document([], text=example["text"]) for example in examples]
|
| 552 |
+
gold_labels = [example["label"] for example in examples]
|
| 553 |
+
basic_multilingual(docs)
|
| 554 |
+
accuracy = sum([(doc.lang == label) for doc,label in zip(docs,gold_labels)])/len(docs)
|
| 555 |
+
assert accuracy >= 0.98
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def test_text_cleaning(basic_multilingual, clean_multilingual):
    """
    Verify that enabling text cleaning changes the langid prediction.

    The two French tweets below carry hashtags / a URL; without cleaning
    the model labels them Italian, with cleaning it labels them French.
    """
    raw_texts = ["Bonjour le monde! #thisisfrench #ilovefrance",
                 "Bonjour le monde! https://t.co/U0Zjp3tusD"]
    docs = [Document([], text=raw) for raw in raw_texts]

    # no cleaning: hashtags/URL confuse the model
    basic_multilingual(docs)
    assert ["it", "it"] == [doc.lang for doc in docs]

    # cleaning enabled: same docs now come back as French
    assert clean_multilingual.processors["langid"]._clean_text
    clean_multilingual(docs)
    assert ["fr", "fr"] == [doc.lang for doc in docs]
|
| 572 |
+
|
| 573 |
+
def test_emoji_cleaning():
    """Emoji characters and :shortcode: emoji should be stripped by clean_text."""
    cases = [
        ("Sh'reyan has nice antennae :thumbs_up:",
         "Sh'reyan has nice antennae"),
        ("This is🐱 a cat",
         "This is a cat"),
    ]
    for raw, cleaned in cases:
        assert LangIDProcessor.clean_text(raw) == cleaned
|
| 580 |
+
|
| 581 |
+
def test_lang_subset(basic_multilingual, enfr_multilingual, en_multilingual):
    """
    Restricting langid output to a subset of languages forces every
    prediction into that subset.
    """
    texts = ["Bonjour le monde! #thisisfrench #ilovefrance",
             "Bonjour le monde! https://t.co/U0Zjp3tusD"]
    docs = [Document([], text=t) for t in texts]

    # unrestricted model mislabels these noisy tweets as Italian
    basic_multilingual(docs)
    assert ["it", "it"] == [doc.lang for doc in docs]

    # en+fr subset: French is chosen
    assert enfr_multilingual.processors["langid"]._model.lang_subset == ["en", "fr"]
    enfr_multilingual(docs)
    assert ["fr", "fr"] == [doc.lang for doc in docs]

    # en-only subset: everything is forced to English
    assert en_multilingual.processors["langid"]._model.lang_subset == ["en"]
    en_multilingual(docs)
    assert ["en", "en"] == [doc.lang for doc in docs]
|
| 599 |
+
|
| 600 |
+
def test_lang_subset_unlikely_language(en_multilingual):
    """
    Test that the language subset masking chooses a legal language, even if all legal languages are supa unlikely
    """
    sentences = ["你好" * 200]
    docs = [Document([], text=s) for s in sentences]

    # en-restricted pipeline must still emit "en" for Chinese text
    en_multilingual(docs)
    assert [doc.lang for doc in docs] == ["en"]

    # sanity check: the unmasked model score for English really is negative
    processor = en_multilingual.processors['langid']
    model = processor._model
    tensor = processor._text_to_tensor(sentences)
    scores = model(tensor)
    en_idx = model.tag_to_idx['en']
    assert scores[0, en_idx] < 0, "If this test fails, then regardless of how unlikely it was, the model is predicting the input string is possibly English. Update the test by picking a different combination of languages & input"
|
| 615 |
+
|
stanza/stanza/tests/lemma/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/mwt/test_utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test the MWT resplitting of preexisting tokens without word splits
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
import stanza
|
| 8 |
+
from stanza.models.mwt.utils import resplit_mwt
|
| 9 |
+
|
| 10 |
+
from stanza.tests import TEST_MODELS_DIR
|
| 11 |
+
|
| 12 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 13 |
+
|
| 14 |
+
@pytest.fixture(scope="module")
def pipeline():
    """
    A reusable English tokenize+MWT pipeline (GUM package), shared by all tests in this module
    """
    return stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize,mwt", package="gum")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_resplit_keep_tokens(pipeline):
    """
    Test splitting with enforced token boundaries
    """
    tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]]
    doc = resplit_mwt(tokens, pipeline)
    assert len(doc.sentences) == 2

    first, second = doc.sentences

    # first sentence keeps its 4 tokens; "can't" becomes ca + n't
    assert len(first.tokens) == 4
    assert [w.text for w in first.tokens[1].words] == ["ca", "n't"]

    # second sentence keeps its 2 tokens
    assert len(second.tokens) == 2
    # updated GUM MWT splits "I can't" into three segments
    # the way we want, "I - ca - n't"
    # previously it would split "I - can - 't"
    assert [w.text for w in second.tokens[0].words] == ["I", "ca", "n't"]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_resplit_no_keep_tokens(pipeline):
    """
    Test splitting without enforced token boundaries
    """
    tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]]
    doc = resplit_mwt(tokens, pipeline, keep_tokens=False)
    assert len(doc.sentences) == 2

    first, second = doc.sentences

    assert len(first.tokens) == 4
    assert [w.text for w in first.tokens[1].words] == ["ca", "n't"]

    # without boundary enforcement, "I can't" is retokenized into
    # three tokens, with "can't" split into ca + n't
    assert len(second.tokens) == 3
    assert [w.text for w in second.tokens[1].words] == ["ca", "n't"]
|
stanza/stanza/tests/ner/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/ner/test_combine_ner_datasets.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 6 |
+
|
| 7 |
+
from stanza.models.common.doc import Document
|
| 8 |
+
from stanza.tests.ner.test_ner_training import write_temp_file, EN_TRAIN_BIO, EN_DEV_BIO
|
| 9 |
+
from stanza.utils.datasets.ner import combine_ner_datasets
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_combine(tmp_path):
    """
    Test that if we write two short datasets and combine them, we get back
    one slightly longer dataset

    To simplify matters, we just use the same input text with longer
    amounts of text for each shard.
    """
    SHARDS = ("train", "dev", "test")
    for s_num, shard in enumerate(SHARDS):
        # eg, 1x, 2x, 3x the test data from test_ner_training
        copies = s_num + 1
        write_temp_file(tmp_path / ("en_t1.%s.json" % shard),
                        "\n\n".join([EN_TRAIN_BIO] * copies))
        write_temp_file(tmp_path / ("en_t2.%s.json" % shard),
                        "\n\n".join([EN_DEV_BIO] * copies))

    combine_ner_datasets.main(["--output_dataset", "en_c", "en_t1", "en_t2",
                               "--input_dir", str(tmp_path),
                               "--output_dir", str(tmp_path)])

    for s_num, shard in enumerate(SHARDS):
        combined = tmp_path / ("en_c.%s.json" % shard)
        assert os.path.exists(combined)

        with open(combined, encoding="utf-8") as fin:
            doc = Document(json.load(fin))
        # each copy of each input contributes sentences; both inputs together
        # give (s_num + 1) * 3 sentences per shard
        assert len(doc.sentences) == (s_num + 1) * 3
|
| 39 |
+
|
stanza/stanza/tests/ner/test_models_ner_scorer.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple test of the scorer module for NER
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import stanza
|
| 7 |
+
|
| 8 |
+
from stanza.tests import *
|
| 9 |
+
from stanza.models.ner.scorer import score_by_token, score_by_entity
|
| 10 |
+
|
| 11 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 12 |
+
|
| 13 |
+
def test_ner_scorer():
    """Check token- and entity-level P/R/F1 from the NER scorer on a fixed example."""
    pred_sequences = [['O', 'S-LOC', 'O', 'O', 'B-PER', 'E-PER'],
                      ['O', 'S-MISC', 'O', 'E-ORG', 'O', 'B-PER', 'I-PER', 'E-PER']]
    gold_sequences = [['O', 'B-LOC', 'E-LOC', 'O', 'B-PER', 'E-PER'],
                      ['O', 'S-MISC', 'B-ORG', 'E-ORG', 'O', 'B-PER', 'E-PER', 'S-LOC']]

    token_p, token_r, token_f, confusion = score_by_token(pred_sequences, gold_sequences)
    assert token_p == pytest.approx(0.625, abs=0.00001)
    assert token_r == pytest.approx(0.5, abs=0.00001)
    assert token_f == pytest.approx(0.55555, abs=0.00001)

    entity_p, entity_r, entity_f, entity_f1 = score_by_entity(pred_sequences, gold_sequences)
    assert entity_p == pytest.approx(0.4, abs=0.00001)
    assert entity_r == pytest.approx(0.33333, abs=0.00001)
    assert entity_f == pytest.approx(0.36363, abs=0.00001)
    # per-type F1 breakdown
    assert entity_f1 == {'LOC': 0.0, 'MISC': 1.0, 'ORG': 0.0, 'PER': 0.5}
|
stanza/stanza/tests/ner/test_ner_tagger.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic testing of the NER tagger.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import pytest
|
| 7 |
+
import stanza
|
| 8 |
+
|
| 9 |
+
from stanza.tests import *
|
| 10 |
+
from stanza.models import ner_tagger
|
| 11 |
+
from stanza.utils.confusion import confusion_to_macro_f1
|
| 12 |
+
import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file
|
| 13 |
+
from stanza.utils.training.run_ner import build_pretrain_args
|
| 14 |
+
|
| 15 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 16 |
+
|
| 17 |
+
# Short input text with two easy entities for the English NER pipeline
EN_DOC = "Chris Manning is a good man. He works in Stanford University."

# Expected pretty_print() output for the entities found in EN_DOC
EN_DOC_GOLD = """
<Span text=Chris Manning;type=PERSON;start_char=0;end_char=13>
<Span text=Stanford University;type=ORG;start_char=41;end_char=60>
""".strip()

# EN_DOC in two-column BIOES format (token<TAB>gold tag), one blank
# line between sentences; spaces replaced with tabs below
EN_BIO = """
Chris B-PERSON
Manning E-PERSON
is O
a O
good O
man O
. O

He O
works O
in O
Stanford B-ORG
University E-ORG
. O
""".strip().replace(" ", "\t")

# Expected predict-mode output: token<TAB>gold tag<TAB>predicted tag
EN_EXPECTED_OUTPUT = """
Chris B-PERSON B-PERSON
Manning E-PERSON E-PERSON
is O O
a O O
good O O
man O O
. O O

He O O
works O O
in O O
Stanford B-ORG B-ORG
University E-ORG E-ORG
. O O
""".strip().replace(" ", "\t")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_ner():
    """Run the English tokenize+ner pipeline and compare entities against EN_DOC_GOLD."""
    nlp = stanza.Pipeline(processors='tokenize,ner', dir=TEST_MODELS_DIR,
                          lang='en', logging_level='error')
    doc = nlp(EN_DOC)
    found = '\n'.join(ent.pretty_print() for ent in doc.ents)
    assert found == EN_DOC_GOLD
|
| 63 |
+
|
| 64 |
+
def test_evaluate(tmp_path):
    """
    This simple example should have a 1.0 f1 for the ontonote model

    Writes a tiny two-sentence BIO file, converts it to the json format,
    runs the tagger in predict mode, and checks both the macro f1 and the
    exact predicted output file.
    """
    package = "ontonotes-ww-multi_charlm"
    model_path = os.path.join(TEST_MODELS_DIR, "en", "ner", package + ".pt")
    assert os.path.exists(model_path), "The {} model should be downloaded as part of setup.py".format(package)

    os.makedirs(tmp_path, exist_ok=True)

    bio_file = tmp_path / "test.bio"
    json_file = tmp_path / "test.json"
    output_file = tmp_path / "output.bio"
    with open(bio_file, "w", encoding="utf-8") as fout:
        fout.write(EN_BIO)

    prepare_ner_file.process_dataset(bio_file, json_file)

    tagger_args = ["--save_name", str(model_path),
                   "--eval_file", str(json_file),
                   "--eval_output_file", str(output_file),
                   "--mode", "predict"]
    tagger_args.extend(build_pretrain_args("en", package, model_dir=TEST_MODELS_DIR))
    parsed_args = ner_tagger.parse_args(args=tagger_args)
    confusion = ner_tagger.evaluate(parsed_args)
    assert confusion_to_macro_f1(confusion) == pytest.approx(1.0)

    with open(output_file, encoding="utf-8") as fin:
        results = fin.read().strip()

    assert results == EN_EXPECTED_OUTPUT
|
stanza/stanza/tests/ner/test_ner_trainer.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from stanza.tests import *
|
| 4 |
+
|
| 5 |
+
from stanza.models.ner import trainer
|
| 6 |
+
|
| 7 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 8 |
+
|
| 9 |
+
def test_fix_singleton_tags():
    """Check fix_singleton_tags on a variety of malformed BIOES sequences."""
    # each case is (input tag sequence, expected repaired sequence)
    CASES = [
        (["O"], ["O"]),
        (["B-PER"], ["S-PER"]),
        (["B-PER", "I-PER"], ["B-PER", "E-PER"]),
        (["B-PER", "O", "B-PER"], ["S-PER", "O", "S-PER"]),
        (["B-PER", "B-PER", "I-PER"], ["S-PER", "B-PER", "E-PER"]),
        (["B-PER", "I-PER", "O", "B-PER"], ["B-PER", "E-PER", "O", "S-PER"]),
        (["B-PER", "B-PER", "I-PER", "B-PER"], ["S-PER", "B-PER", "E-PER", "S-PER"]),
        (["B-PER", "I-ORG", "O", "B-PER"], ["S-PER", "S-ORG", "O", "S-PER"]),
        (["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["S-PER", "B-PER", "E-PER"], ["S-PER", "B-PER", "E-PER"]),
        (["E-PER"], ["S-PER"]),
        (["E-PER", "O", "E-PER"], ["S-PER", "O", "S-PER"]),
        (["B-PER", "E-ORG", "O", "B-PER"], ["S-PER", "S-ORG", "O", "S-PER"]),
        (["I-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["B-PER", "I-PER", "I-PER", "O", "B-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["B-PER", "I-PER", "E-PER", "O", "I-PER", "E-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["B-PER", "I-PER", "E-PER", "O", "B-PER", "I-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
        (["I-PER", "I-PER", "I-PER", "O", "I-PER", "I-PER"], ["B-PER", "I-PER", "E-PER", "O", "B-PER", "E-PER"]),
    ]

    for raw_tags, fixed_tags in CASES:
        assert trainer.fix_singleton_tags(raw_tags) == fixed_tags, "Error converting {} to {}".format(raw_tags, fixed_tags)
|
stanza/stanza/tests/ner/test_pay_amt_annotators.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple test for tracking AMT annotator work
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import zipfile
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from stanza.tests import TEST_WORKING_DIR
|
| 11 |
+
from stanza.utils.ner import paying_annotators
|
| 12 |
+
|
| 13 |
+
DATA_SOURCE = os.path.join(TEST_WORKING_DIR, "in", "aws_annotations.zip")
|
| 14 |
+
|
| 15 |
+
@pytest.fixture(scope="module")
def completed_amt_job_metadata(tmp_path_factory):
    """Unzip the canned AWS annotation results once per module; return the metadata path."""
    assert os.path.exists(DATA_SOURCE)
    extract_dir = tmp_path_factory.mktemp("amt_test")
    with zipfile.ZipFile(DATA_SOURCE, 'r') as zin:
        zin.extractall(extract_dir)
    # the zip contains a ner/aws_labeling_copy directory with the job metadata
    return extract_dir / "ner" / "aws_labeling_copy"
|
| 23 |
+
|
| 24 |
+
def test_amt_annotator_track(completed_amt_job_metadata):
    """Tally completed tasks per annotator, mapping AMT sub ids to worker names."""
    # map AMT annotator subs to a relevant identifier
    workers = {
        "7efc17ac-3397-4472-afe5-89184ad145d0": "Worker1",
        "afce8c28-969c-4e73-a20f-622ef122f585": "Worker2",
        "91f6236e-63c6-4a84-8fd6-1efbab6dedab": "Worker3",
        "6f202e93-e6b6-4e1d-8f07-0484b9a9093a": "Worker4",
        "2b674d33-f656-44b0-8f90-d70a1ab71ec2": "Worker5",
    }

    tracked_work = paying_annotators.track_tasks(completed_amt_job_metadata, workers)
    assert tracked_work == {'Worker4': 20, 'Worker5': 20, 'Worker2': 3, 'Worker3': 16}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_amt_annotator_track_no_map(completed_amt_job_metadata):
    """Without a worker map, counts are keyed by the raw AMT sub ids."""
    expected = {'6f202e93-e6b6-4e1d-8f07-0484b9a9093a': 20,
                '2b674d33-f656-44b0-8f90-d70a1ab71ec2': 20,
                'afce8c28-969c-4e73-a20f-622ef122f585': 3,
                '91f6236e-63c6-4a84-8fd6-1efbab6dedab': 16}
    assert paying_annotators.track_tasks(completed_amt_job_metadata) == expected
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main():
    """Run the AMT annotator tests directly, outside of pytest.

    The test functions take the unzipped job metadata path as a parameter
    (normally injected by the `completed_amt_job_metadata` fixture), so
    calling them with no arguments would raise a TypeError.  Materialize
    the same directory layout the fixture builds, then pass it along.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as unzip_path:
        # mirror the layout produced by the completed_amt_job_metadata fixture
        input_path = os.path.join(unzip_path, "ner", "aws_labeling_copy")
        with zipfile.ZipFile(DATA_SOURCE, 'r') as zin:
            zin.extractall(unzip_path)
        test_amt_annotator_track(input_path)
        test_amt_annotator_track_no_map(input_path)


if __name__ == "__main__":
    main()
    print("TESTS COMPLETED!")
|
stanza/stanza/tests/ner/test_split_wikiner.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Runs a few tests on the split_wikiner file
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from stanza.utils.datasets.ner import split_wikiner
|
| 11 |
+
|
| 12 |
+
from stanza.tests import *
|
| 13 |
+
|
| 14 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 15 |
+
|
| 16 |
+
# two sentences from the Italian dataset, split into many pieces
|
| 17 |
+
# to test the splitting functionality
|
| 18 |
+
FBK_SAMPLE = """
|
| 19 |
+
Il O
|
| 20 |
+
Papa O
|
| 21 |
+
si O
|
| 22 |
+
aggrava O
|
| 23 |
+
|
| 24 |
+
Le O
|
| 25 |
+
condizioni O
|
| 26 |
+
di O
|
| 27 |
+
|
| 28 |
+
Papa O
|
| 29 |
+
Giovanni PER
|
| 30 |
+
Paolo PER
|
| 31 |
+
II PER
|
| 32 |
+
si O
|
| 33 |
+
|
| 34 |
+
sono O
|
| 35 |
+
aggravate O
|
| 36 |
+
in O
|
| 37 |
+
il O
|
| 38 |
+
corso O
|
| 39 |
+
|
| 40 |
+
di O
|
| 41 |
+
la O
|
| 42 |
+
giornata O
|
| 43 |
+
di O
|
| 44 |
+
giovedì O
|
| 45 |
+
. O
|
| 46 |
+
|
| 47 |
+
Il O
|
| 48 |
+
portavoce O
|
| 49 |
+
Navarro PER
|
| 50 |
+
Valls PER
|
| 51 |
+
|
| 52 |
+
ha O
|
| 53 |
+
dichiarato O
|
| 54 |
+
che O
|
| 55 |
+
|
| 56 |
+
il O
|
| 57 |
+
Santo O
|
| 58 |
+
Padre O
|
| 59 |
+
|
| 60 |
+
in O
|
| 61 |
+
la O
|
| 62 |
+
giornata O
|
| 63 |
+
|
| 64 |
+
di O
|
| 65 |
+
oggi O
|
| 66 |
+
è O
|
| 67 |
+
stato O
|
| 68 |
+
|
| 69 |
+
colpito O
|
| 70 |
+
da O
|
| 71 |
+
una O
|
| 72 |
+
affezione O
|
| 73 |
+
|
| 74 |
+
altamente O
|
| 75 |
+
febbrile O
|
| 76 |
+
provocata O
|
| 77 |
+
da O
|
| 78 |
+
una O
|
| 79 |
+
|
| 80 |
+
infezione O
|
| 81 |
+
documentata O
|
| 82 |
+
|
| 83 |
+
di O
|
| 84 |
+
le O
|
| 85 |
+
vie O
|
| 86 |
+
urinarie O
|
| 87 |
+
. O
|
| 88 |
+
|
| 89 |
+
A O
|
| 90 |
+
il O
|
| 91 |
+
momento O
|
| 92 |
+
|
| 93 |
+
non O
|
| 94 |
+
è O
|
| 95 |
+
previsto O
|
| 96 |
+
il O
|
| 97 |
+
ricovero O
|
| 98 |
+
|
| 99 |
+
a O
|
| 100 |
+
il O
|
| 101 |
+
Policlinico LOC
|
| 102 |
+
Gemelli LOC
|
| 103 |
+
, O
|
| 104 |
+
|
| 105 |
+
come O
|
| 106 |
+
ha O
|
| 107 |
+
precisato O
|
| 108 |
+
il O
|
| 109 |
+
|
| 110 |
+
responsabile O
|
| 111 |
+
di O
|
| 112 |
+
il O
|
| 113 |
+
dipartimento O
|
| 114 |
+
|
| 115 |
+
di O
|
| 116 |
+
emergenza O
|
| 117 |
+
professor O
|
| 118 |
+
Rodolfo PER
|
| 119 |
+
Proietti PER
|
| 120 |
+
. O
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def test_read_sentences():
    """Round-trip FBK_SAMPLE through read_sentences and rebuild the original text."""
    with tempfile.TemporaryDirectory() as tempdir:
        raw_filename = os.path.join(tempdir, "raw.tsv")
        with open(raw_filename, "w") as fout:
            fout.write(FBK_SAMPLE)

        sentences = split_wikiner.read_sentences(raw_filename, "utf-8")
        assert len(sentences) == 20
        # reassemble word\tlabel lines into sentences, sentences into blocks
        rebuilt = "\n\n".join("\n".join("\t".join(word) for word in sent)
                              for sent in sentences)
        assert FBK_SAMPLE.strip() == rebuilt
|
| 136 |
+
|
| 137 |
+
def test_write_sentences():
    """Writing sentences back out and rereading them should be a lossless round trip."""
    with tempfile.TemporaryDirectory() as tempdir:
        raw_filename = os.path.join(tempdir, "raw.tsv")
        with open(raw_filename, "w") as fout:
            fout.write(FBK_SAMPLE)

        sentences = split_wikiner.read_sentences(raw_filename, "utf-8")
        copy_filename = os.path.join(tempdir, "copy.tsv")
        split_wikiner.write_sentences_to_file(sentences, copy_filename)

        # previously this reread raw_filename, which meant the written copy
        # was never actually checked
        sent2 = split_wikiner.read_sentences(copy_filename, "utf-8")
        assert sent2 == sentences
|
| 149 |
+
|
| 150 |
+
def run_split_wikiner(expected_train=14, expected_dev=3, expected_test=3, **kwargs):
    """
    Run one splitting scenario and verify the resulting section sizes.

    kwargs are forwarded to split_wikiner.split_wikiner; the 'shuffle' and
    'test_section' entries must be present since they also drive the checks.
    """
    with tempfile.TemporaryDirectory() as indir:
        raw_filename = os.path.join(indir, "raw.tsv")
        with open(raw_filename, "w") as fout:
            fout.write(FBK_SAMPLE)

        with tempfile.TemporaryDirectory() as outdir:
            split_wikiner.split_wikiner(outdir, raw_filename, **kwargs)

            filenames = {section: os.path.join(outdir, "it_fbk.%s.bio" % section)
                         for section in ("train", "dev", "test")}

            assert os.path.exists(filenames["train"])
            assert os.path.exists(filenames["dev"])
            if kwargs["test_section"]:
                assert os.path.exists(filenames["test"])
            else:
                assert not os.path.exists(filenames["test"])

            train_sent = split_wikiner.read_sentences(filenames["train"], "utf-8")
            dev_sent = split_wikiner.read_sentences(filenames["dev"], "utf-8")
            assert len(train_sent) == expected_train
            assert len(dev_sent) == expected_dev
            if kwargs["test_section"]:
                test_sent = split_wikiner.read_sentences(filenames["test"], "utf-8")
                assert len(test_sent) == expected_test
            else:
                test_sent = []

            # every original sentence must land in exactly one section;
            # order is only preserved when the split was not shuffled
            combined = train_sent + dev_sent + test_sent
            original = split_wikiner.read_sentences(raw_filename, "utf-8")
            if kwargs["shuffle"]:
                combined = sorted(combined)
                original = sorted(original)
            assert original == combined
|
| 190 |
+
|
| 191 |
+
def test_no_shuffle_split():
    """Default split sizes, in original order, with a test section."""
    run_split_wikiner(prefix="it_fbk", shuffle=False, test_section=True)

def test_shuffle_split():
    """Default split sizes with shuffling enabled."""
    run_split_wikiner(prefix="it_fbk", shuffle=True, test_section=True)

def test_resize():
    """Non-default train/dev fractions change the section sizes (12/2/6 of 20)."""
    run_split_wikiner(expected_train=12, expected_dev=2, expected_test=6, train_fraction=0.6, dev_fraction=0.1, prefix="it_fbk", shuffle=True, test_section=True)

def test_no_test_split():
    """With test_section=False, the 20 sentences split into train=17 and dev=3 only."""
    run_split_wikiner(expected_train=17, train_fraction=0.85, prefix="it_fbk", shuffle=False, test_section=False)
|
| 202 |
+
|
stanza/stanza/tests/ner/test_suc3.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests the conversion code for the SUC3 NER dataset
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
from zipfile import ZipFile
|
| 8 |
+
|
| 9 |
+
import pytest
|
| 10 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 11 |
+
|
| 12 |
+
import stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob
|
| 13 |
+
|
| 14 |
+
TEST_CONLL = """
|
| 15 |
+
1 Den den PN PN UTR|SIN|DEF|SUB/OBJ _ _ _ _ O _ ac01b-030:2328
|
| 16 |
+
2 Gud Gud PM PM NOM _ _ _ _ B myth ac01b-030:2329
|
| 17 |
+
3 giver giva VB VB PRS|AKT _ _ _ _ O _ ac01b-030:2330
|
| 18 |
+
4 ämbetet ämbete NN NN NEU|SIN|DEF|NOM _ _ _ _ O _ ac01b-030:2331
|
| 19 |
+
5 får få VB VB PRS|AKT _ _ _ _ O _ ac01b-030:2332
|
| 20 |
+
6 också också AB AB _ _ _ _ O _ ac01b-030:2333
|
| 21 |
+
7 förståndet förstånd NN NN NEU|SIN|DEF|NOM _ _ _ _ O _ ac01b-030:2334
|
| 22 |
+
8 . . MAD MAD _ _ _ _ O _ ac01b-030:2335
|
| 23 |
+
|
| 24 |
+
1 Han han PN PN UTR|SIN|DEF|SUB _ _ _ _ O _ aa01a-017:227
|
| 25 |
+
2 berättar berätta VB VB PRS|AKT _ _ _ _ O _ aa01a-017:228
|
| 26 |
+
3 anekdoten anekdot NN NN UTR|SIN|DEF|NOM _ _ _ _ O _ aa01a-017:229
|
| 27 |
+
4 som som HP HP -|-|- _ _ _ _ O _ aa01a-017:230
|
| 28 |
+
5 FN-medlaren FN-medlare NN NN UTR|SIN|DEF|NOM _ _ _ _ O _ aa01a-017:231
|
| 29 |
+
6 Brian Brian PM PM NOM _ _ _ _ B person aa01a-017:232
|
| 30 |
+
7 Urquhart Urquhart PM PM NOM _ _ _ _ I person aa01a-017:233
|
| 31 |
+
8 myntat mynta VB VB SUP|AKT _ _ _ _ O _ aa01a-017:234
|
| 32 |
+
9 : : MAD MAD _ _ _ _ O _ aa01a-017:235
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
EXPECTED_IOB = """
|
| 36 |
+
Den O
|
| 37 |
+
Gud B-myth
|
| 38 |
+
giver O
|
| 39 |
+
ämbetet O
|
| 40 |
+
får O
|
| 41 |
+
också O
|
| 42 |
+
förståndet O
|
| 43 |
+
. O
|
| 44 |
+
|
| 45 |
+
Han O
|
| 46 |
+
berättar O
|
| 47 |
+
anekdoten O
|
| 48 |
+
som O
|
| 49 |
+
FN-medlaren O
|
| 50 |
+
Brian B-person
|
| 51 |
+
Urquhart I-person
|
| 52 |
+
myntat O
|
| 53 |
+
: O
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def test_read_zip():
    """
    Test creating a fake zip file, then converting it to an .iob file
    """
    with tempfile.TemporaryDirectory() as tempdir:
        zip_name = os.path.join(tempdir, "test.zip")
        in_filename = "conll"
        with ZipFile(zip_name, "w") as zout:
            with zout.open(in_filename, "w") as fout:
                fout.write(TEST_CONLL.encode())

        # write the output inside the temp dir as well, rather than
        # littering the current working directory with an "iob" file
        out_filename = os.path.join(tempdir, "iob")
        num = suc_conll_to_iob.extract_from_zip(zip_name, in_filename, out_filename)
        assert num == 2

        with open(out_filename) as fin:
            result = fin.read()
        assert EXPECTED_IOB.strip() == result.strip()
|
| 74 |
+
|
| 75 |
+
def test_read_raw():
    """
    Test a direct text file conversion w/o the zip file
    """
    with tempfile.TemporaryDirectory() as tempdir:
        in_filename = os.path.join(tempdir, "test.txt")
        with open(in_filename, "w", encoding="utf-8") as fout:
            fout.write(TEST_CONLL)

        # keep the output in the temp dir instead of the current working directory
        out_filename = os.path.join(tempdir, "iob")
        with open(in_filename, encoding="utf-8") as fin, open(out_filename, "w", encoding="utf-8") as fout:
            num = suc_conll_to_iob.extract(fin, fout)
        assert num == 2

        with open(out_filename) as fin:
            result = fin.read()
        assert EXPECTED_IOB.strip() == result.strip()
|
stanza/stanza/tests/pipeline/test_decorators.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic tests of the depparse processor boolean flags
|
| 3 |
+
"""
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
import stanza
|
| 7 |
+
from stanza.models.common.doc import Document
|
| 8 |
+
from stanza.pipeline.core import PipelineRequirementsException
|
| 9 |
+
from stanza.pipeline.processor import Processor, ProcessorVariant, register_processor, register_processor_variant, ProcessorRegisterException
|
| 10 |
+
from stanza.utils.conll import CoNLL
|
| 11 |
+
from stanza.tests import *
|
| 12 |
+
|
| 13 |
+
pytestmark = pytest.mark.pipeline
|
| 14 |
+
|
| 15 |
+
# data for testing
|
| 16 |
+
EN_DOC = "This is a test sentence. This is another!"
|
| 17 |
+
|
| 18 |
+
EN_DOC_LOWERCASE_TOKENS = '''<Token id=1;words=[<Word id=1;text=this>]>
|
| 19 |
+
<Token id=2;words=[<Word id=2;text=is>]>
|
| 20 |
+
<Token id=3;words=[<Word id=3;text=a>]>
|
| 21 |
+
<Token id=4;words=[<Word id=4;text=test>]>
|
| 22 |
+
<Token id=5;words=[<Word id=5;text=sentence>]>
|
| 23 |
+
<Token id=6;words=[<Word id=6;text=.>]>
|
| 24 |
+
|
| 25 |
+
<Token id=1;words=[<Word id=1;text=this>]>
|
| 26 |
+
<Token id=2;words=[<Word id=2;text=is>]>
|
| 27 |
+
<Token id=3;words=[<Word id=3;text=another>]>
|
| 28 |
+
<Token id=4;words=[<Word id=4;text=!>]>'''
|
| 29 |
+
|
| 30 |
+
EN_DOC_LOL_TOKENS = '''<Token id=1;words=[<Word id=1;text=LOL>]>
|
| 31 |
+
<Token id=2;words=[<Word id=2;text=LOL>]>
|
| 32 |
+
<Token id=3;words=[<Word id=3;text=LOL>]>
|
| 33 |
+
<Token id=4;words=[<Word id=4;text=LOL>]>
|
| 34 |
+
<Token id=5;words=[<Word id=5;text=LOL>]>
|
| 35 |
+
<Token id=6;words=[<Word id=6;text=LOL>]>
|
| 36 |
+
<Token id=7;words=[<Word id=7;text=LOL>]>
|
| 37 |
+
<Token id=8;words=[<Word id=8;text=LOL>]>'''
|
| 38 |
+
|
| 39 |
+
EN_DOC_COOL_LEMMAS = '''<Token id=1;words=[<Word id=1;text=This;lemma=cool;upos=PRON;xpos=DT;feats=Number=Sing|PronType=Dem>]>
|
| 40 |
+
<Token id=2;words=[<Word id=2;text=is;lemma=cool;upos=AUX;xpos=VBZ;feats=Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin>]>
|
| 41 |
+
<Token id=3;words=[<Word id=3;text=a;lemma=cool;upos=DET;xpos=DT;feats=Definite=Ind|PronType=Art>]>
|
| 42 |
+
<Token id=4;words=[<Word id=4;text=test;lemma=cool;upos=NOUN;xpos=NN;feats=Number=Sing>]>
|
| 43 |
+
<Token id=5;words=[<Word id=5;text=sentence;lemma=cool;upos=NOUN;xpos=NN;feats=Number=Sing>]>
|
| 44 |
+
<Token id=6;words=[<Word id=6;text=.;lemma=cool;upos=PUNCT;xpos=.>]>
|
| 45 |
+
|
| 46 |
+
<Token id=1;words=[<Word id=1;text=This;lemma=cool;upos=PRON;xpos=DT;feats=Number=Sing|PronType=Dem>]>
|
| 47 |
+
<Token id=2;words=[<Word id=2;text=is;lemma=cool;upos=AUX;xpos=VBZ;feats=Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin>]>
|
| 48 |
+
<Token id=3;words=[<Word id=3;text=another;lemma=cool;upos=DET;xpos=DT;feats=PronType=Ind>]>
|
| 49 |
+
<Token id=4;words=[<Word id=4;text=!;lemma=cool;upos=PUNCT;xpos=.>]>'''
|
| 50 |
+
|
| 51 |
+
@register_processor("lowercase")
class LowercaseProcessor(Processor):
    ''' Processor that lowercases all text '''
    _requires = {'tokenize'}
    _provides = {'lowercase'}

    def __init__(self, config, pipeline, device):
        pass

    def _set_up_model(self, *args):
        pass

    def process(self, doc):
        """Lowercase the document text along with every token and word text."""
        doc.text = doc.text.lower()
        for sentence in doc.sentences:
            for token in sentence.tokens:
                token.text = token.text.lower()

            for word in sentence.words:
                word.text = word.text.lower()

        return doc
|
| 73 |
+
|
| 74 |
+
def test_register_processor():
    """A pipeline built with the custom 'lowercase' processor lowercases every token."""
    pipeline = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors='tokenize,lowercase', download_method=None)
    doc = pipeline(EN_DOC)
    rendered = '\n\n'.join(sent.tokens_string() for sent in doc.sentences)
    assert rendered == EN_DOC_LOWERCASE_TOKENS
|
| 78 |
+
|
| 79 |
+
def test_register_nonprocessor():
    """Registering a class which is not a Processor subclass should raise."""
    with pytest.raises(ProcessorRegisterException):
        @register_processor("nonprocessor")
        class NonProcessor:
            pass
|
| 84 |
+
|
| 85 |
+
@register_processor_variant("tokenize", "lol")
class LOLTokenizer(ProcessorVariant):
    ''' An alternative tokenizer that splits text by space and replaces all tokens with LOL '''

    def __init__(self, lang):
        pass

    def process(self, text):
        """Split on whitespace and emit one 'LOL' token per original token."""
        tokens = []
        for idx, _ in enumerate(text.split()):
            tokens.append({'id': (idx + 1,), 'text': 'LOL'})
        return Document([tokens], text)
|
| 95 |
+
|
| 96 |
+
def test_register_processor_variant():
    """The 'lol' tokenizer variant should replace every token with LOL."""
    pipeline = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors={"tokenize": "lol"}, package=None, download_method=None)
    doc = pipeline(EN_DOC)
    rendered = '\n\n'.join(sent.tokens_string() for sent in doc.sentences)
    assert rendered == EN_DOC_LOL_TOKENS
|
| 100 |
+
|
| 101 |
+
@register_processor_variant("lemma", "cool")
class CoolLemmatizer(ProcessorVariant):
    ''' An alternative lemmatizer that lemmatizes every word to "cool". '''

    # OVERRIDE marks this variant as replacing the lemma processor entirely
    OVERRIDE = True

    def __init__(self, lang):
        pass

    def process(self, document):
        """Set every word's lemma to the literal string 'cool'."""
        for sentence in document.sentences:
            for word in sentence.words:
                word.lemma = "cool"

        return document
|
| 116 |
+
|
| 117 |
+
def test_register_processor_variant_with_override():
    """The 'cool' lemmatizer variant overrides the standard lemma processor."""
    pipeline = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en', processors={"tokenize": "combined", "pos": "combined", "lemma": "cool"}, package=None, download_method=None)
    doc = pipeline(EN_DOC)
    rendered = '\n\n'.join(sent.tokens_string() for sent in doc.sentences)
    assert EN_DOC_COOL_LEMMAS == rendered
|
| 122 |
+
|
| 123 |
+
def test_register_nonprocessor_variant():
    """Registering a variant which is not a ProcessorVariant subclass should raise."""
    with pytest.raises(ProcessorRegisterException):
        @register_processor_variant("tokenize", "nonvariant")
        class NonVariant:
            pass
|
stanza/stanza/tests/pipeline/test_pipeline_mwt_expander.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic testing of multi-word-token expansion
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import stanza
|
| 7 |
+
|
| 8 |
+
from stanza.tests import *
|
| 9 |
+
|
| 10 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 11 |
+
|
| 12 |
+
# mwt data for testing
|
| 13 |
+
FR_MWT_SENTENCE = "Alors encore inconnu du grand public, Emmanuel Macron devient en 2014 ministre de l'Économie, de " \
|
| 14 |
+
"l'Industrie et du Numérique."
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
FR_MWT_TOKEN_TO_WORDS_GOLD = """
|
| 18 |
+
token: Alors words: [<Word id=1;text=Alors>]
|
| 19 |
+
token: encore words: [<Word id=2;text=encore>]
|
| 20 |
+
token: inconnu words: [<Word id=3;text=inconnu>]
|
| 21 |
+
token: du words: [<Word id=4;text=de>, <Word id=5;text=le>]
|
| 22 |
+
token: grand words: [<Word id=6;text=grand>]
|
| 23 |
+
token: public words: [<Word id=7;text=public>]
|
| 24 |
+
token: , words: [<Word id=8;text=,>]
|
| 25 |
+
token: Emmanuel words: [<Word id=9;text=Emmanuel>]
|
| 26 |
+
token: Macron words: [<Word id=10;text=Macron>]
|
| 27 |
+
token: devient words: [<Word id=11;text=devient>]
|
| 28 |
+
token: en words: [<Word id=12;text=en>]
|
| 29 |
+
token: 2014 words: [<Word id=13;text=2014>]
|
| 30 |
+
token: ministre words: [<Word id=14;text=ministre>]
|
| 31 |
+
token: de words: [<Word id=15;text=de>]
|
| 32 |
+
token: l' words: [<Word id=16;text=l'>]
|
| 33 |
+
token: Économie words: [<Word id=17;text=Économie>]
|
| 34 |
+
token: , words: [<Word id=18;text=,>]
|
| 35 |
+
token: de words: [<Word id=19;text=de>]
|
| 36 |
+
token: l' words: [<Word id=20;text=l'>]
|
| 37 |
+
token: Industrie words: [<Word id=21;text=Industrie>]
|
| 38 |
+
token: et words: [<Word id=22;text=et>]
|
| 39 |
+
token: du words: [<Word id=23;text=de>, <Word id=24;text=le>]
|
| 40 |
+
token: Numérique words: [<Word id=25;text=Numérique>]
|
| 41 |
+
token: . words: [<Word id=26;text=.>]
|
| 42 |
+
""".strip()
|
| 43 |
+
|
| 44 |
+
FR_MWT_WORD_TO_TOKEN_GOLD = """
|
| 45 |
+
word: Alors token parent:1-Alors
|
| 46 |
+
word: encore token parent:2-encore
|
| 47 |
+
word: inconnu token parent:3-inconnu
|
| 48 |
+
word: de token parent:4-5-du
|
| 49 |
+
word: le token parent:4-5-du
|
| 50 |
+
word: grand token parent:6-grand
|
| 51 |
+
word: public token parent:7-public
|
| 52 |
+
word: , token parent:8-,
|
| 53 |
+
word: Emmanuel token parent:9-Emmanuel
|
| 54 |
+
word: Macron token parent:10-Macron
|
| 55 |
+
word: devient token parent:11-devient
|
| 56 |
+
word: en token parent:12-en
|
| 57 |
+
word: 2014 token parent:13-2014
|
| 58 |
+
word: ministre token parent:14-ministre
|
| 59 |
+
word: de token parent:15-de
|
| 60 |
+
word: l' token parent:16-l'
|
| 61 |
+
word: Économie token parent:17-Économie
|
| 62 |
+
word: , token parent:18-,
|
| 63 |
+
word: de token parent:19-de
|
| 64 |
+
word: l' token parent:20-l'
|
| 65 |
+
word: Industrie token parent:21-Industrie
|
| 66 |
+
word: et token parent:22-et
|
| 67 |
+
word: de token parent:23-24-du
|
| 68 |
+
word: le token parent:23-24-du
|
| 69 |
+
word: Numérique token parent:25-Numérique
|
| 70 |
+
word: . token parent:26-.
|
| 71 |
+
""".strip()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_mwt():
    """French MWT: 'du' expands to 'de le'; check token->words and word->parent-token maps."""
    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='fr', download_method=None)
    doc = pipeline(FR_MWT_SENTENCE)

    token_lines = []
    for sent in doc.sentences:
        for token in sent.tokens:
            token_lines.append(f'token: {token.text.ljust(9)}\t\twords: [{", ".join([word.pretty_print() for word in token.words])}]')
    token_to_words = "\n".join(token_lines).strip()

    word_lines = []
    for sent in doc.sentences:
        for word in sent.words:
            word_lines.append(f'word: {word.text.ljust(9)}\t\ttoken parent:{"-".join([str(x) for x in word.parent.id])}-{word.parent.text}')
    word_to_token = "\n".join(word_lines).strip()

    assert token_to_words == FR_MWT_TOKEN_TO_WORDS_GOLD
    assert word_to_token == FR_MWT_WORD_TO_TOKEN_GOLD
|
| 85 |
+
|
| 86 |
+
def test_unknown_character():
    """
    The MWT processor has a mechanism to temporarily add unknown characters to the vocab

    Here we check that it is properly adding the characters from a test case a user sent us
    """
    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
    text = "Björkängshallen's"
    mwt_processor = pipeline.processors["mwt"]
    known_units = mwt_processor.trainer.vocab._unit2id
    # verify that the test case is still valid
    # (perhaps an updated MWT model will have all of these characters in the future)
    assert any(ch not in known_units for ch in text)
    doc = pipeline(text)
    batch = mwt_processor.build_batch(doc)
    # the vocab used in this batch should have the missing characters
    assert all(ch in batch.vocab._unit2id for ch in text)
|
| 103 |
+
|
| 104 |
+
def test_unknown_word():
    """
    Test a word which wasn't in the MWT training data

    The seq2seq model for MWT was randomly hallucinating, but with the
    CharacterClassifier, it should be able to process unusual MWT
    without hallucinations
    """
    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
    doc = pipeline("I read the newspaper's report.")
    assert len(doc.sentences) == 1
    tokens = doc.sentences[0].tokens
    assert len(tokens) == 6
    assert len(tokens[3].words) == 2
    assert tokens[3].words[0].text == 'newspaper'

    # double check that this is something unknown to the model
    trainer = pipeline.processors["mwt"].trainer
    assert trainer.dict_expansion("newspaper's") is None
|
stanza/stanza/tests/pos/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/pos/test_tagger.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Run the tagger for a couple iterations on some fake data
|
| 3 |
+
|
| 4 |
+
Uses a couple sentences of UD_English-EWT as training/dev data
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import stanza
|
| 13 |
+
from stanza.models import tagger
|
| 14 |
+
from stanza.models.common import pretrain
|
| 15 |
+
from stanza.models.pos.trainer import Trainer
|
| 16 |
+
from stanza.tests import TEST_WORKING_DIR, TEST_MODELS_DIR
|
| 17 |
+
from stanza.utils.training.common import choose_pos_charlm, build_charlm_args
|
| 18 |
+
|
| 19 |
+
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
|
| 20 |
+
|
| 21 |
+
TRAIN_DATA = """
|
| 22 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
|
| 23 |
+
# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
|
| 24 |
+
1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
|
| 25 |
+
2 : : PUNCT : _ 1 punct 1:punct _
|
| 26 |
+
3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
|
| 27 |
+
4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
|
| 28 |
+
5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
|
| 29 |
+
6 that that SCONJ IN _ 9 mark 9:mark _
|
| 30 |
+
7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
|
| 31 |
+
8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
|
| 32 |
+
9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
|
| 33 |
+
10 up up ADP RP _ 9 compound:prt 9:compound:prt _
|
| 34 |
+
11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
|
| 35 |
+
12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
|
| 36 |
+
13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
|
| 37 |
+
14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
|
| 38 |
+
15 in in ADP IN _ 16 case 16:case _
|
| 39 |
+
16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
|
| 40 |
+
17 . . PUNCT . _ 1 punct 1:punct _
|
| 41 |
+
|
| 42 |
+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
|
| 43 |
+
# text = Two of them were being run by 2 officials of the Ministry of the Interior!
|
| 44 |
+
1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
|
| 45 |
+
2 of of ADP IN _ 3 case 3:case _
|
| 46 |
+
3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
|
| 47 |
+
4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
|
| 48 |
+
5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
|
| 49 |
+
6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
|
| 50 |
+
7 by by ADP IN _ 9 case 9:case _
|
| 51 |
+
8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
|
| 52 |
+
9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
|
| 53 |
+
10 of of ADP IN _ 12 case 12:case _
|
| 54 |
+
11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
|
| 55 |
+
12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
|
| 56 |
+
13 of of ADP IN _ 15 case 15:case _
|
| 57 |
+
14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
|
| 58 |
+
15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
|
| 59 |
+
16 ! ! PUNCT . _ 6 punct 6:punct _
|
| 60 |
+
|
| 61 |
+
""".lstrip()
|
| 62 |
+
|
| 63 |
+
TRAIN_DATA_2 = """
|
| 64 |
+
# sent_id = 11
|
| 65 |
+
# text = It's all hers!
|
| 66 |
+
# previous = Which person owns this?
|
| 67 |
+
# comment = predeterminer modifier
|
| 68 |
+
1 It it PRON PRP Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No
|
| 69 |
+
2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _
|
| 70 |
+
3 all all DET DT Case=Nom 4 det:predet _ _
|
| 71 |
+
4 hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
|
| 72 |
+
5 ! ! PUNCT . _ 4 punct _ _
|
| 73 |
+
|
| 74 |
+
""".lstrip()
|
| 75 |
+
|
| 76 |
+
TRAIN_DATA_NO_UPOS = """
|
| 77 |
+
# sent_id = 11
|
| 78 |
+
# text = It's all hers!
|
| 79 |
+
# previous = Which person owns this?
|
| 80 |
+
# comment = predeterminer modifier
|
| 81 |
+
1 It it _ PRP Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No
|
| 82 |
+
2 's be _ VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _
|
| 83 |
+
3 all all _ DT Case=Nom 4 det:predet _ _
|
| 84 |
+
4 hers hers _ PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
|
| 85 |
+
5 ! ! _ . _ 4 punct _ _
|
| 86 |
+
|
| 87 |
+
""".lstrip()
|
| 88 |
+
|
| 89 |
+
TRAIN_DATA_NO_XPOS = """
|
| 90 |
+
# sent_id = 11
|
| 91 |
+
# text = It's all hers!
|
| 92 |
+
# previous = Which person owns this?
|
| 93 |
+
# comment = predeterminer modifier
|
| 94 |
+
1 It it PRON _ Number=Sing|Person=3|PronType=Prs 4 nsubj _ SpaceAfter=No
|
| 95 |
+
2 's be AUX _ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 cop _ _
|
| 96 |
+
3 all all DET _ Case=Nom 4 det:predet _ _
|
| 97 |
+
4 hers hers PRON _ Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 0 root _ SpaceAfter=No
|
| 98 |
+
5 ! ! PUNCT _ _ 4 punct _ _
|
| 99 |
+
|
| 100 |
+
""".lstrip()
|
| 101 |
+
|
| 102 |
+
TRAIN_DATA_NO_FEATS = """
|
| 103 |
+
# sent_id = 11
|
| 104 |
+
# text = It's all hers!
|
| 105 |
+
# previous = Which person owns this?
|
| 106 |
+
# comment = predeterminer modifier
|
| 107 |
+
1 It it PRON PRP _ 4 nsubj _ SpaceAfter=No
|
| 108 |
+
2 's be AUX VBZ _ 4 cop _ _
|
| 109 |
+
3 all all DET DT _ 4 det:predet _ _
|
| 110 |
+
4 hers hers PRON PRP _ 0 root _ SpaceAfter=No
|
| 111 |
+
5 ! ! PUNCT . _ 4 punct _ _
|
| 112 |
+
|
| 113 |
+
""".lstrip()
|
| 114 |
+
|
| 115 |
+
DEV_DATA = """
|
| 116 |
+
1 From from ADP IN _ 3 case 3:case _
|
| 117 |
+
2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _
|
| 118 |
+
3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _
|
| 119 |
+
4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
|
| 120 |
+
5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _
|
| 121 |
+
6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _
|
| 122 |
+
7 : : PUNCT : _ 4 punct 4:punct _
|
| 123 |
+
|
| 124 |
+
""".lstrip()
|
| 125 |
+
|
| 126 |
+
class TestTagger:
    """Run the POS tagger for a couple iterations on fake data.

    The training / dev data is a couple sentences of UD_English-EWT,
    kept in the module level TRAIN_DATA* / DEV_DATA constants.
    """

    @pytest.fixture(scope="class")
    def wordvec_pretrain_file(self):
        """Path to a tiny embedding file used as the pretrain for every test."""
        return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'

    @pytest.fixture(scope="class")
    def charlm_args(self):
        """Extra command line args which point the tagger at the test charlm."""
        charlm = choose_pos_charlm("en", "test", "default")
        charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR)
        return charlm_args

    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None):
        """
        Run the training for a few iterations, load & return the model

        train_text: one conllu blob, or a list of blobs.  Each blob is
          written to its own file; the filenames are joined with ";" so
          the tagger's multiple-train-file path is exercised for lists.
        dev_text: a single conllu blob used as the dev set
        augment_nopunct: unless set, --augment_nopunct is pinned to 0.0
        extra_args: appended after the standard args, so tests can
          override any of the defaults below
        """
        dev_file = str(tmp_path / "dev.conllu")
        pred_file = str(tmp_path / "pred.conllu")

        save_name = "test_tagger.pt"
        save_file = str(tmp_path / save_name)

        if isinstance(train_text, str):
            train_text = [train_text]
        train_files = []
        for idx, train_blob in enumerate(train_text):
            train_file = str(tmp_path / ("train_%d.conllu" % idx))
            with open(train_file, "w", encoding="utf-8") as fout:
                fout.write(train_blob)
            train_files.append(train_file)
        # the tagger accepts multiple train files separated by ";"
        train_file = ";".join(train_files)

        with open(dev_file, "w", encoding="utf-8") as fout:
            fout.write(dev_text)

        args = ["--wordvec_pretrain_file", wordvec_pretrain_file,
                "--train_file", train_file,
                "--eval_file", dev_file,
                "--output_file", pred_file,
                "--log_step", "10",
                "--eval_interval", "20",
                "--max_steps", "100",
                "--shorthand", "en_test",
                "--save_dir", str(tmp_path),
                "--save_name", save_name,
                "--lang", "en"]
        if not augment_nopunct:
            args.extend(["--augment_nopunct", "0.0"])
        if extra_args is not None:
            args = args + extra_args
        tagger.main(args)

        # reload the freshly trained model so tests can inspect it
        assert os.path.exists(save_file)
        pt = pretrain.Pretrain(wordvec_pretrain_file)
        saved_model = Trainer(pretrain=pt, model_file=save_file)
        return saved_model

    def test_train(self, tmp_path, wordvec_pretrain_file):
        """
        Simple test of a few 'epochs' of tagger training

        (A previous version took an unused augment_nopunct=True parameter
        which was never forwarded to run_training, so augmentation was in
        fact disabled; the dead parameter is removed to match the actual
        behavior.)
        """
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA)

    def test_vocab_cutoff(self, tmp_path, wordvec_pretrain_file):
        """
        Test that the vocab cutoff leaves words we expect in the vocab, but not rare words
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=["--word_cutoff", "3"])
        word_vocab = trainer.vocab['word']
        # 'of' occurs often enough in TRAIN_DATA to survive the cutoff
        assert 'of' in word_vocab
        # 'officials' is present in the training data but rare, so it
        # must have been cut
        assert 'officials' in TRAIN_DATA
        assert 'officials' not in word_vocab

    def test_multiple_files(self, tmp_path, wordvec_pretrain_file):
        """
        Test that multiple train files works

        Checks for evidence of it working by looking for words from the second file in the vocab
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA, TRAIN_DATA_2 * 3], DEV_DATA, extra_args=["--word_cutoff", "3"])
        word_vocab = trainer.vocab['word']
        assert 'of' in word_vocab
        assert 'officials' in TRAIN_DATA
        assert 'officials' not in word_vocab

        # 'hers' only appears in the second blob, which was repeated 3x,
        # so its appearance in the vocab proves the second file was used
        assert ' hers ' not in TRAIN_DATA
        assert ' hers ' in TRAIN_DATA_2
        assert 'hers' in word_vocab

    def test_train_zero_augment(self, tmp_path, wordvec_pretrain_file):
        """
        Train with the punct augmentation set to zero

        Distinguishes cases where training works w/ or w/o augmentation
        """
        extra_args = ['--augment_nopunct', '0.0']
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)

    def test_train_100_augment(self, tmp_path, wordvec_pretrain_file):
        """
        Train with the punct augmentation set to 1.0

        Distinguishes cases where training works w/ or w/o augmentation
        """
        extra_args = ['--augment_nopunct', '1.0']
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)

    def test_train_charlm(self, tmp_path, wordvec_pretrain_file, charlm_args):
        """Train with a charlm as an additional input"""
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=charlm_args)

    def test_train_charlm_projection(self, tmp_path, wordvec_pretrain_file, charlm_args):
        """Train with a charlm projected to a different dimension"""
        extra_args = charlm_args + ['--charlm_transform_dim', '100']
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)

    def test_missing_column(self, tmp_path, wordvec_pretrain_file):
        """
        Test that using train files with missing columns works

        In this test, we create three separate files, each with a single training entry.
        We then train on an amalgam of those three files with a batch size of 1, saving after each batch.
        This will ensure that only one item is used for each training loop and we can inspect the models which were saved.

        Since each of the three files have exactly one column missing
        from the training data, we expect to see the output maps for
        each column stay unchanged in one iteration and change in the
        other two.
        """
        # use SGD because some old versions of pytorch with Adam keep
        # learning a value even if the loss is 0 in subsequent steps
        # (perhaps it had a momentum by default?)
        extra_args = ['--save_each', '--eval_interval', '1', '--max_steps', '3', '--batch_size', '1', '--optim', 'sgd']
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA_NO_UPOS, TRAIN_DATA_NO_XPOS, TRAIN_DATA_NO_FEATS], DEV_DATA, extra_args=extra_args)
        save_each_name = tagger.save_each_file_name(trainer.args)
        model_files = [save_each_name % i for i in range(4)]
        assert all(os.path.exists(x) for x in model_files)
        pt = pretrain.Pretrain(wordvec_pretrain_file)
        saved_trainers = [Trainer(pretrain=pt, model_file=model_file) for model_file in model_files]

        # count, over the 3 consecutive pairs of checkpoints, how often
        # each classifier's weights did NOT change
        upos_unchanged = 0
        xpos_unchanged = 0
        ufeats_unchanged = 0
        for t1, t2 in zip(saved_trainers[:-1], saved_trainers[1:]):
            upos_unchanged += torch.allclose(t1.model.upos_clf.weight, t2.model.upos_clf.weight)
            xpos_unchanged += torch.allclose(t1.model.xpos_clf.W_bilin.weight, t2.model.xpos_clf.W_bilin.weight)
            ufeats_unchanged += all(torch.allclose(f1.W_bilin.weight, f2.W_bilin.weight) for f1, f2 in zip(t1.model.ufeats_clf, t2.model.ufeats_clf))
        upos_norms = [torch.linalg.norm(t.model.upos_clf.weight) for t in saved_trainers]
        assert upos_unchanged == 1, "Unchanged: {} {} {} {}".format(upos_unchanged, xpos_unchanged, ufeats_unchanged, upos_norms)
        assert xpos_unchanged == 1, "Unchanged: %d %d %d" % (upos_unchanged, xpos_unchanged, ufeats_unchanged)
        assert ufeats_unchanged == 1, "Unchanged: %d %d %d" % (upos_unchanged, xpos_unchanged, ufeats_unchanged)

    def test_save_each(self, tmp_path, wordvec_pretrain_file):
        """--save_each should leave a checkpoint behind for each eval step"""
        extra_args = ['--save_each']
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)
        save_each_name = tagger.save_each_file_name(trainer.args)
        expected_models = sorted(set([save_each_name % i for i in range(0, trainer.args['max_steps']+1, trainer.args['eval_interval'])]))
        assert len(expected_models) == 6
        for model_name in expected_models:
            assert os.path.exists(model_name)

    def test_with_bert(self, tmp_path, wordvec_pretrain_file):
        """Train with a tiny bert as an extra embedding"""
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert'])

    def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):
        """Train using multiple hidden layers of the bert"""
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])

    def test_with_bert_finetune(self, tmp_path, wordvec_pretrain_file):
        """Train while also finetuning the bert"""
        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_learning_rate', '0.01', '--bert_hidden_layers', '2'])

    def test_bert_pipeline(self, tmp_path, wordvec_pretrain_file):
        """
        Test training the tagger, then using it in a pipeline

        The pipeline use of the tagger also tests the longer-than-maxlen workaround for the transformer
        """
        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert'])
        save_name = trainer.args['save_name']
        save_file = str(tmp_path / save_name)
        assert os.path.exists(save_file)

        # NOTE(review): other tests in this repo pass dir=TEST_MODELS_DIR to
        # stanza.Pipeline; confirm that models_dir= is an accepted alias
        pipe = stanza.Pipeline("en", processors="tokenize,pos", models_dir=TEST_MODELS_DIR, pos_model_path=save_file, pos_pretrain_path=wordvec_pretrain_file)
        trainer = pipe.processors['pos'].trainer
        assert trainer.args['save_name'] == save_name

        # these should be one chunk only
        doc = pipe("foo " * 100)
        doc = pipe("foo " * 500)
        # this is two chunks of bert embedding
        doc = pipe("foo " * 1000)
        # this is multiple chunks
        doc = pipe("foo " * 2000)
|
stanza/stanza/tests/resources/__init__.py
ADDED
|
File without changes
|
stanza/stanza/tests/resources/test_default_packages.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
import stanza
|
| 4 |
+
|
| 5 |
+
from stanza.resources import default_packages
|
| 6 |
+
|
| 7 |
+
def test_default_pretrains():
    """
    Every language with a default treebank must either have a default
    pretrain or be explicitly marked as having no pretrain at all.
    """
    no_pretrain = default_packages.no_pretrain_languages
    pretrains = default_packages.default_pretrains
    for lang in default_packages.default_treebanks:
        marked = lang in no_pretrain or lang in pretrains
        assert marked, "Lang %s does not have a default pretrain marked!" % lang
|
| 13 |
+
|
| 14 |
+
def test_no_pretrain_languages():
    """
    A language marked no_default_pretrain must not also have a default pretrain.
    """
    pretrains = default_packages.default_pretrains
    for lang in default_packages.no_pretrain_languages:
        assert lang not in pretrains, "Lang %s is marked as no_pretrain but has a default pretrain!" % lang
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
stanza/stanza/tests/resources/test_prepare_resources.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
import stanza
|
| 4 |
+
import stanza.resources.prepare_resources as prepare_resources
|
| 5 |
+
|
| 6 |
+
from stanza.tests import *
|
| 7 |
+
|
| 8 |
+
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
|
| 9 |
+
|
| 10 |
+
def test_split_model_name():
    """
    Check that model filenames split into (lang, package, processor) as expected.
    """
    cases = [
        # Basic test
        ('ro_nonstandard_tagger.pt', ('ro', 'nonstandard', 'pos')),
        # nertagger is found even though it also ends with tagger, and
        # ncbi_disease is correctly partitioned despite the extra _
        ('en_ncbi_disease_nertagger.pt', ('en', 'ncbi_disease', 'ner')),
        # processors with _ in them are also okay
        ('en_pubmed_forward_charlm.pt', ('en', 'pubmed', 'forward_charlm')),
    ]
    for filename, expected in cases:
        pieces = prepare_resources.split_model_name(filename)
        assert pieces == expected
|
| 29 |
+
|
| 30 |
+
|
stanza/stanza/tests/server/test_server_misc.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Misc tests for the server
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import re
|
| 7 |
+
import stanza.server as corenlp
|
| 8 |
+
from stanza.tests import compare_ignoring_whitespace
|
| 9 |
+
|
| 10 |
+
pytestmark = pytest.mark.client
|
| 11 |
+
|
| 12 |
+
EN_DOC = "Joe Smith lives in California."
|
| 13 |
+
|
| 14 |
+
EN_DOC_GOLD = """
|
| 15 |
+
Sentence #1 (6 tokens):
|
| 16 |
+
Joe Smith lives in California.
|
| 17 |
+
|
| 18 |
+
Tokens:
|
| 19 |
+
[Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP Lemma=Joe NamedEntityTag=PERSON]
|
| 20 |
+
[Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP Lemma=Smith NamedEntityTag=PERSON]
|
| 21 |
+
[Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ Lemma=live NamedEntityTag=O]
|
| 22 |
+
[Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN Lemma=in NamedEntityTag=O]
|
| 23 |
+
[Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP Lemma=California NamedEntityTag=STATE_OR_PROVINCE]
|
| 24 |
+
[Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=. Lemma=. NamedEntityTag=O]
|
| 25 |
+
|
| 26 |
+
Dependency Parse (enhanced plus plus dependencies):
|
| 27 |
+
root(ROOT-0, lives-3)
|
| 28 |
+
compound(Smith-2, Joe-1)
|
| 29 |
+
nsubj(lives-3, Smith-2)
|
| 30 |
+
case(California-5, in-4)
|
| 31 |
+
obl:in(lives-3, California-5)
|
| 32 |
+
punct(lives-3, .-6)
|
| 33 |
+
|
| 34 |
+
Extracted the following NER entity mentions:
|
| 35 |
+
Joe Smith PERSON PERSON:0.9972202681743931
|
| 36 |
+
California STATE_OR_PROVINCE LOCATION:0.9990868267559281
|
| 37 |
+
|
| 38 |
+
Extracted the following KBP triples:
|
| 39 |
+
1.0 Joe Smith per:statesorprovinces_of_residence California
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
EN_DOC_POS_ONLY_GOLD = """
|
| 44 |
+
Sentence #1 (6 tokens):
|
| 45 |
+
Joe Smith lives in California.
|
| 46 |
+
|
| 47 |
+
Tokens:
|
| 48 |
+
[Text=Joe CharacterOffsetBegin=0 CharacterOffsetEnd=3 PartOfSpeech=NNP]
|
| 49 |
+
[Text=Smith CharacterOffsetBegin=4 CharacterOffsetEnd=9 PartOfSpeech=NNP]
|
| 50 |
+
[Text=lives CharacterOffsetBegin=10 CharacterOffsetEnd=15 PartOfSpeech=VBZ]
|
| 51 |
+
[Text=in CharacterOffsetBegin=16 CharacterOffsetEnd=18 PartOfSpeech=IN]
|
| 52 |
+
[Text=California CharacterOffsetBegin=19 CharacterOffsetEnd=29 PartOfSpeech=NNP]
|
| 53 |
+
[Text=. CharacterOffsetBegin=29 CharacterOffsetEnd=30 PartOfSpeech=.]
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def test_english_request():
    """Start a server with Spanish defaults, then request default English properties."""
    with corenlp.CoreNLPClient(properties='spanish', server_id='test_spanish_english_request') as client:
        annotated = client.annotate(EN_DOC, properties='english', output_format='text')
        compare_ignoring_whitespace(annotated, EN_DOC_GOLD)

    # Rerun with a server created in English mode, verifying that the
    # expected output is what the defaults actually give us
    with corenlp.CoreNLPClient(properties='english', server_id='test_english_request') as client:
        annotated = client.annotate(EN_DOC, output_format='text')
        compare_ignoring_whitespace(annotated, EN_DOC_GOLD)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_default_annotators():
    """
    Test case of creating a client with start_server=False and a set of annotators

    The annotators passed to the second client should be used instead of
    the server's default annotators.
    """
    full_annotators = ['tokenize', 'ssplit', 'pos', 'lemma', 'ner', 'depparse']
    small_annotators = ['tokenize', 'ssplit', 'pos']
    with corenlp.CoreNLPClient(server_id='test_default_annotators',
                               output_format='text',
                               annotators=full_annotators) as client:
        # second client reuses the running server but requests fewer annotators
        with corenlp.CoreNLPClient(start_server=False,
                                   output_format='text',
                                   annotators=small_annotators) as client2:
            ann = client2.annotate(EN_DOC)
|
| 81 |
+
|
| 82 |
+
# Per-token (begin, end) offsets of codepoint_doc, counted two ways:
# expected_codepoints counts unicode codepoints, expected_characters
# counts raw character positions (they diverge at the non-BMP glyphs).
expected_codepoints = ((0, 1), (2, 4), (5, 8), (9, 15), (16, 20))
expected_characters = ((0, 1), (2, 4), (5, 10), (11, 17), (18, 22))
# test document containing astral-plane characters
codepoint_doc = "I am 𝒚̂𝒊 random text"
|
| 85 |
+
|
| 86 |
+
def test_codepoints():
    """Ask the English tokenizer for codepoints and check each token's offsets."""
    with corenlp.CoreNLPClient(annotators=['tokenize','ssplit'], # 'depparse','coref'],
                               properties={'tokenize.codepoint': 'true'}) as client:
        ann = client.annotate(codepoint_doc)
        tokens = ann.sentence[0].token
        for idx, (codepoints, characters) in enumerate(zip(expected_codepoints, expected_characters)):
            tok = tokens[idx]
            # codepoint offsets and raw character offsets must both match
            assert tok.codepointOffsetBegin == codepoints[0]
            assert tok.codepointOffsetEnd == codepoints[1]
            assert tok.beginChar == characters[0]
            assert tok.endChar == characters[1]
|
| 97 |
+
|
| 98 |
+
def test_codepoint_text():
    """Extract the correct sentence text from the document using codepoint offsets."""

    text = 'Unban mox opal 🐱. This is a second sentence.'

    with corenlp.CoreNLPClient(annotators=["tokenize","ssplit"],
                               properties={'tokenize.codepoint': 'true'}) as client:
        ann = client.annotate(text)

        # slicing the original text with the first/last token's codepoint
        # offsets must recover each sentence exactly
        expected_sentences = ['Unban mox opal 🐱.', 'This is a second sentence.']
        for sent_idx, expected in enumerate(expected_sentences):
            sentence = ann.sentence[sent_idx]
            text_start = sentence.token[0].codepointOffsetBegin
            text_end = sentence.token[-1].codepointOffsetEnd
            assert text[text_start:text_end] == expected
|
stanza/stanza/utils/datasets/common.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import argparse
|
| 3 |
+
from enum import Enum
|
| 4 |
+
import glob
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import subprocess
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
from stanza.models.common.short_name_to_treebank import canonical_treebank_name
|
| 12 |
+
import stanza.utils.datasets.prepare_tokenizer_data as prepare_tokenizer_data
|
| 13 |
+
import stanza.utils.default_paths as default_paths
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger('stanza')
|
| 16 |
+
|
| 17 |
+
# RE to see if the index of a conllu line represents an MWT
|
| 18 |
+
MWT_RE = re.compile("^[0-9]+[-][0-9]+")
|
| 19 |
+
|
| 20 |
+
# RE to see if the index of a conllu line represents an MWT or copy node
|
| 21 |
+
MWT_OR_COPY_RE = re.compile("^[0-9]+[-.][0-9]+")
|
| 22 |
+
|
| 23 |
+
# more restrictive than an actual int as we expect certain formats in the conllu files
|
| 24 |
+
INT_RE = re.compile("^[0-9]+$")
|
| 25 |
+
|
| 26 |
+
CONLLU_TO_TXT_PERL = os.path.join(os.path.split(__file__)[0], "conllu_to_text.pl")
|
| 27 |
+
|
| 28 |
+
class ModelType(Enum):
    """Which stanza model a dataset preparation step is building data for."""
    TOKENIZER = 1
    MWT = 2
    POS = 3
    LEMMA = 4
    DEPPARSE = 5
|
| 34 |
+
|
| 35 |
+
class UnknownDatasetError(ValueError):
    """Raised when a dataset name is not recognized.

    The offending dataset name is kept on the exception so callers can
    report exactly which dataset was unknown.
    """
    def __init__(self, dataset, text):
        super().__init__(text)
        # name of the dataset which could not be found
        self.dataset = dataset
|
| 39 |
+
|
| 40 |
+
def convert_conllu_to_txt(tokenizer_dir, short_name, shards=("train", "dev", "test")):
    """
    Uses the udtools perl script to convert a conllu file to txt

    For each shard, reads {short_name}.{shard}.gold.conllu from
    tokenizer_dir and writes {short_name}.{shard}.txt next to it.

    Raises FileNotFoundError if a gold conllu file is missing.

    TODO: switch to a python version to get rid of some perl dependence
    """
    for dataset in shards:
        output_conllu = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
        output_txt = f"{tokenizer_dir}/{short_name}.{dataset}.txt"

        if not os.path.exists(output_conllu):
            # the perl script doesn't raise an error code for file not found!
            raise FileNotFoundError("Cannot convert %s as the file cannot be found" % output_conllu)
        # use an external script to produce the txt files.
        # pass the arguments as a list and redirect stdout ourselves,
        # instead of interpolating paths into a shell command line, so
        # paths containing spaces or shell metacharacters cannot break
        # (or inject into) the command
        with open(output_txt, "wb") as fout:
            subprocess.check_call(["perl", CONLLU_TO_TXT_PERL, output_conllu], stdout=fout)
|
| 55 |
+
|
| 56 |
+
def mwt_name(base_dir, short_name, dataset):
    """Return the path of the MWT json file for one shard of a treebank."""
    filename = f"{short_name}-ud-{dataset}-mwt.json"
    return os.path.join(base_dir, filename)
|
| 58 |
+
|
| 59 |
+
def tokenizer_conllu_name(base_dir, short_name, dataset):
    """Return the path of the gold conllu file for one shard of a treebank."""
    filename = f"{short_name}.{dataset}.gold.conllu"
    return os.path.join(base_dir, filename)
|
| 61 |
+
|
| 62 |
+
def prepare_tokenizer_dataset_labels(input_txt, input_conllu, tokenizer_dir, short_name, dataset):
    """Run prepare_tokenizer_data to build the token label and MWT files for one shard."""
    labels_filename = f"{tokenizer_dir}/{short_name}-ud-{dataset}.toklabels"
    mwt_filename = mwt_name(tokenizer_dir, short_name, dataset)
    script_args = [input_txt, input_conllu,
                   "-o", labels_filename,
                   "-m", mwt_filename]
    prepare_tokenizer_data.main(script_args)
|
| 69 |
+
|
| 70 |
+
def prepare_tokenizer_treebank_labels(tokenizer_dir, short_name):
    """
    Given the txt and gold.conllu files, prepare mwt and label files for train/dev/test
    """
    for shard in ("train", "dev", "test"):
        txt_file = f"{tokenizer_dir}/{short_name}.{shard}.txt"
        conllu_file = f"{tokenizer_dir}/{short_name}.{shard}.gold.conllu"
        try:
            prepare_tokenizer_dataset_labels(txt_file, conllu_file, tokenizer_dir, short_name, shard)
        except (KeyboardInterrupt, SystemExit):
            # never swallow a deliberate interrupt
            raise
        except:
            # report which shard failed, then propagate the original error
            print("Failed to convert %s to %s" % (txt_file, conllu_file))
            raise
|
| 84 |
+
|
| 85 |
+
def read_sentences_from_conllu(filename):
    """
    Reads a conllu file as a list of list of strings

    Finding a blank line separates the lists
    """
    sentences = []
    current = []
    with open(filename, encoding="utf-8") as fin:
        for raw_line in fin:
            stripped = raw_line.strip()
            if stripped:
                current.append(stripped)
            elif current:
                # blank line closes the current sentence
                sentences.append(current)
                current = []
    # the file may not end with a blank line
    if current:
        sentences.append(current)
    return sentences
|
| 105 |
+
|
| 106 |
+
def maybe_add_fake_dependencies(lines):
    """
    Possibly add fake dependencies in columns 6 and 7 (counting from 0)

    The conllu scripts need the dependencies column filled out, so in
    the case of models we build without dependency data, we need to
    add those fake dependencies in order to use the eval script etc

    lines: a list of strings with 10 tab separated columns
      comments are allowed (they will be skipped)

    returns: the same strings, but with fake dependencies added
      if columns 6 and 7 were empty
    """
    result = []
    # position in result of the id-1 token which needs a head assigned
    first_pos = None
    # word id of the sentence root, if one was present in the input
    root_id = None
    for line in lines:
        if line.startswith("#"):
            result.append(line)
            continue

        pieces = line.split("\t")
        # MWT ranges and copy nodes never carry a basic dependency
        if MWT_OR_COPY_RE.match(pieces[0]):
            result.append(line)
            continue

        word_id = int(pieces[0])
        if pieces[6] != '_':
            # dependency already present; just remember where the root is
            if pieces[6] == '0':
                root_id = word_id
            result.append(line)
        elif word_id == 1:
            # note that the comments might make this not the first line
            # we keep track of this separately so we can either make this the root,
            # or set this to be the root later
            first_pos = len(result)
            result.append(pieces)
        else:
            # everything else hangs off word 1
            pieces[6] = "1"
            pieces[7] = "dep"
            result.append("\t".join(pieces))
    if first_pos is not None:
        pieces = result[first_pos]
        if root_id is None:
            # no root seen anywhere: word 1 becomes the root
            pieces[6] = "0"
            pieces[7] = "root"
        else:
            # attach word 1 to the existing root
            pieces[6] = str(root_id)
            pieces[7] = "dep"
        result[first_pos] = "\t".join(pieces)
    return result
|
| 157 |
+
|
| 158 |
+
def write_sentences_to_file(outfile, sents):
    """Write each sentence to outfile, filling in fake deps when missing,
    with a blank line after every sentence."""
    for sentence in sents:
        for line in maybe_add_fake_dependencies(sentence):
            print(line, file=outfile)
        print("", file=outfile)
|
| 164 |
+
|
| 165 |
+
def write_sentences_to_conllu(filename, sents):
    """Write the sentences to filename as a utf-8 conllu file."""
    with open(filename, 'w', encoding="utf-8") as fout:
        write_sentences_to_file(fout, sents)
|
| 168 |
+
|
| 169 |
+
def find_treebank_dataset_file(treebank, udbase_dir, dataset, extension, fail=False, env_var="UDBASE"):
    """
    For a given treebank, dataset, extension, look for the exact filename to use.

    Sometimes the short name we use is different from the short name
    used by UD.  For example, Norwegian or Chinese.  Hence the reason
    to not hardcode it based on treebank

    set fail=True to fail if the file is not found
    """
    # our Korean "_seg" variants live in the un-suffixed UD directory
    if treebank.startswith("UD_Korean") and treebank.endswith("_seg"):
        treebank = treebank[:-4]
    pattern = os.path.join(udbase_dir, treebank, f"*-ud-{dataset}.{extension}")
    matches = glob.glob(pattern)
    if len(matches) == 1:
        return matches[0]
    if not matches:
        if not fail:
            return None
        raise FileNotFoundError("Could not find any treebank files which matched {}\nIf you have the data elsewhere, you can change the base directory for the search by changing the {} environment variable".format(pattern, env_var))
    raise RuntimeError(f"Unexpected number of files matched '{udbase_dir}/{treebank}/*-ud-{dataset}.{extension}'")
|
| 192 |
+
|
| 193 |
+
def mostly_underscores(filename):
    """
    Returns True if over half of the word forms in a conllu file are hidden.

    Certain treebanks have proprietary data, so the text is hidden

    For example:
      UD_Arabic-NYUAD
      UD_English-ESL
      UD_English-GUMReddit
      UD_Hindi_English-HIENCS
      UD_Japanese-BCCWJ
    """
    underscore_count = 0
    total_count = 0
    # use a context manager and explicit utf-8 so the handle is closed
    # promptly and the read does not depend on the locale encoding
    # (the original left the file object dangling)
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            if line.startswith("#"):
                continue
            total_count = total_count + 1
            pieces = line.split("\t")
            # "_" or "-" in the FORM column means the word is redacted
            if pieces[1] in ("_", "-"):
                underscore_count = underscore_count + 1
    return underscore_count / total_count > 0.5
|
| 217 |
+
|
| 218 |
+
def num_words_in_file(conllu_file):
    """
    Count the number of non-blank, non-comment lines in a conllu file

    Note that MWT and copy-node lines are counted as well as word lines.
    """
    count = 0
    # explicit utf-8 for consistency with the other conllu readers here,
    # rather than relying on the locale default encoding
    with open(conllu_file, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            if line.startswith("#"):
                continue
            count = count + 1
    return count
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def get_ud_treebanks(udbase_dir, filtered=True):
    """
    Looks in udbase_dir for all the treebanks which have both train, dev, and test
    """
    found = sorted(glob.glob(udbase_dir + "/UD_*"))
    treebanks = [os.path.split(path)[1] for path in found]
    # skip UD_English-GUMReddit as it is usually incorporated into UD_English-GUM
    treebanks = [t for t in treebanks if t != "UD_English-GUMReddit"]
    if filtered:
        # must have at least train and test shards
        treebanks = [t for t in treebanks
                     if (find_treebank_dataset_file(t, udbase_dir, "train", "conllu") and
                         # this will be fixed using XV
                         #find_treebank_dataset_file(t, udbase_dir, "dev", "conllu") and
                         find_treebank_dataset_file(t, udbase_dir, "test", "conllu"))]
        # drop treebanks whose text is redacted
        treebanks = [t for t in treebanks
                     if not mostly_underscores(find_treebank_dataset_file(t, udbase_dir, "train", "conllu"))]
        # eliminate partial treebanks (fixed with XV) for which we only have 1000 words or less
        # if the train set is small and the test set is large enough, we'll flip them
        treebanks = [t for t in treebanks
                     if (find_treebank_dataset_file(t, udbase_dir, "dev", "conllu") or
                         num_words_in_file(find_treebank_dataset_file(t, udbase_dir, "train", "conllu")) > 1000 or
                         num_words_in_file(find_treebank_dataset_file(t, udbase_dir, "test", "conllu")) > 5000)]
    return treebanks
|
| 257 |
+
|
| 258 |
+
def build_argparse():
    """Build the argument parser shared by the dataset preparation scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument('treebanks', type=str, nargs='+', help='Which treebanks to run on.  Use all_ud or ud_all for all UD treebanks')
    return parser
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def main(process_treebank, model_type, add_specific_args=None):
    """Generic entry point for the dataset preparation scripts.

    process_treebank: callback invoked as (treebank, model_type, paths, args)
    model_type: a ModelType value passed through to the callback
    add_specific_args: optional hook to add script-specific arguments
    """
    logger.info("Datasets program called with:\n" + " ".join(sys.argv))

    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)
    args = parser.parse_args()

    paths = default_paths.get_default_paths()

    treebanks = []
    for requested in args.treebanks:
        if requested.lower() in ('ud_all', 'all_ud'):
            treebanks.extend(get_ud_treebanks(paths["UDBASE"]))
        else:
            # If this is a known UD short name, use the official name (we need it for the paths)
            treebanks.append(canonical_treebank_name(requested))

    for treebank in treebanks:
        process_treebank(treebank, model_type, paths, args)
|
stanza/stanza/utils/datasets/conllu_to_text.pl
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env perl
# Extracts raw text from CoNLL-U file. Uses newdoc and newpar tags when available.
# Copyright © 2017 Dan Zeman <zeman@ufal.mff.cuni.cz>
# License: GNU GPL

use utf8;
use open ':utf8';
binmode(STDIN, ':utf8');
binmode(STDOUT, ':utf8');
binmode(STDERR, ':utf8');
use Getopt::Long;

# Language code 'zh' or 'ja' will trigger Chinese-like text formatting.
my $language = 'en';
GetOptions
(
    'language=s' => \$language
);
my $chinese = $language =~ m/^(zh|ja|lzh|yue)(_|$)/;

# State carried across input lines:
my $text = ''; # from the text attribute of the sentence
my $ftext = ''; # from the word forms of the tokens
my $newpar = 0; # saw a "# newpar" comment for the current sentence
my $newdoc = 0; # saw a "# newdoc" comment for the current sentence
my $buffer = ''; # text accumulated but not yet printed
my $start = 1; # true until the first sentence has been written
my $mwtlast; # last word id covered by the current multi-word token, if any
while(<>)
{
    if(m/^\#\s*text\s*=\s*(.+)/)
    {
        $text = $1;
    }
    elsif(m/^\#\s*newpar(\s|$)/i)
    {
        $newpar = 1;
    }
    elsif(m/^\#\s*newdoc(\s|$)/i)
    {
        $newdoc = 1;
    }
    # A multi-word token line, e.g. "3-4\t...".  Its covered words are
    # skipped in the branch below via $mwtlast.
    elsif(m/^\d+-(\d+)\t/)
    {
        $mwtlast = $1;
        my @f = split(/\t/, $_);
        # Paragraphs may start in the middle of a sentence (bulleted lists, verse etc.)
        # The first token of the new paragraph has "NewPar=Yes" in the MISC column.
        # Multi-word tokens have this in the token-introducing line.
        if($f[9] =~ m/NewPar=Yes/i)
        {
            # Empty line between documents and paragraphs. (There may have been
            # a paragraph break before the first part of this sentence as well!)
            $buffer = print_new_paragraph_if_needed($start, $newdoc, $newpar, $buffer);
            $buffer .= $ftext;
            # Line breaks at word boundaries after at most 80 characters.
            $buffer = print_lines_from_buffer($buffer, 80, $chinese);
            print("$buffer\n\n");
            $buffer = '';
            # Start is only true until we write the first sentence of the input stream.
            $start = 0;
            $newdoc = 0;
            $newpar = 0;
            $text = '';
            $ftext = '';
        }
        $ftext .= $f[1];
        $ftext .= ' ' unless($f[9] =~ m/SpaceAfter=No/);
    }
    # An ordinary word line, unless it is covered by the preceding
    # multi-word token (in which case its surface form was already taken).
    elsif(m/^(\d+)\t/ && !(defined($mwtlast) && $1<=$mwtlast))
    {
        $mwtlast = undef;
        my @f = split(/\t/, $_);
        # Paragraphs may start in the middle of a sentence (bulleted lists, verse etc.)
        # The first token of the new paragraph has "NewPar=Yes" in the MISC column.
        # Multi-word tokens have this in the token-introducing line.
        if($f[9] =~ m/NewPar=Yes/i)
        {
            # Empty line between documents and paragraphs. (There may have been
            # a paragraph break before the first part of this sentence as well!)
            $buffer = print_new_paragraph_if_needed($start, $newdoc, $newpar, $buffer);
            $buffer .= $ftext;
            # Line breaks at word boundaries after at most 80 characters.
            $buffer = print_lines_from_buffer($buffer, 80, $chinese);
            print("$buffer\n\n");
            $buffer = '';
            # Start is only true until we write the first sentence of the input stream.
            $start = 0;
            $newdoc = 0;
            $newpar = 0;
            $text = '';
            $ftext = '';
        }
        $ftext .= $f[1];
        $ftext .= ' ' unless($f[9] =~ m/SpaceAfter=No/);
    }
    # A blank line terminates the current sentence.
    elsif(m/^\s*$/)
    {
        # In a valid CoNLL-U file, $text should be equal to $ftext except for the
        # space after the last token. However, if there have been intra-sentential
        # paragraph breaks, $ftext contains only the part after the last such
        # break, and $text is empty. Hence we currently use $ftext everywhere
        # and ignore $text, even though we note it when seeing the text attribute.
        # $text .= ' ' unless($chinese);
        # Empty line between documents and paragraphs.
        $buffer = print_new_paragraph_if_needed($start, $newdoc, $newpar, $buffer);
        $buffer .= $ftext;
        # Line breaks at word boundaries after at most 80 characters.
        $buffer = print_lines_from_buffer($buffer, 80, $chinese);
        # Start is only true until we write the first sentence of the input stream.
        $start = 0;
        $newdoc = 0;
        $newpar = 0;
        $text = '';
        $ftext = '';
        $mwtlast = undef;
    }
}
# There may be unflushed buffer contents after the last sentence, less than 80 characters
# (otherwise we would have already dealt with it), so just flush it.
if($buffer ne '')
{
    print("$buffer\n");
}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
#------------------------------------------------------------------------------
# Checks whether we have to print an extra line to separate paragraphs. Does it
# if necessary. Returns the updated buffer.
#------------------------------------------------------------------------------
sub print_new_paragraph_if_needed
{
    my $start = shift; # true until the first sentence has been printed
    my $newdoc = shift; # a "# newdoc" comment was seen for this sentence
    my $newpar = shift; # a "# newpar" comment was seen for this sentence
    my $buffer = shift; # text accumulated but not yet printed
    # No separator is needed before the very first sentence of the stream.
    if(!$start && ($newdoc || $newpar))
    {
        # Flush any pending text before the paragraph break.
        if($buffer ne '')
        {
            print("$buffer\n");
            $buffer = '';
        }
        print("\n");
    }
    return $buffer;
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
#------------------------------------------------------------------------------
# Prints as many complete lines of text as there are in the buffer. Returns the
# remaining contents of the buffer.
#------------------------------------------------------------------------------
sub print_lines_from_buffer
{
    my $buffer = shift;
    # Maximum number of characters allowed on one line, not counting the line
    # break character(s), which also replace any number of trailing spaces.
    # Exception: If there is a word longer than the limit, it will be printed
    # on one line.
    # Note that this algorithm is not suitable for Chinese and Japanese.
    my $limit = shift;
    # We need a different algorithm for Chinese and Japanese.
    my $chinese = shift;
    if($chinese)
    {
        return print_chinese_lines_from_buffer($buffer, $limit);
    }
    if(length($buffer) >= $limit)
    {
        # Work on the buffer as an array of single characters.
        my @cbuffer = split(//, $buffer);
        # There may be more than one new line waiting in the buffer.
        while(scalar(@cbuffer) >= $limit)
        {
            ###!!! We could make it simpler if we ignored multi-space sequences
            ###!!! between words. It sounds OK to ignore them because at the
            ###!!! line break we do not respect original spacing anyway.
            my $i;
            my $ilastspace;
            # Find the last space at or shortly after the limit so the line
            # breaks at a word boundary.
            for($i = 0; $i<=$#cbuffer; $i++)
            {
                if($i>$limit && defined($ilastspace))
                {
                    last;
                }
                if($cbuffer[$i] =~ m/\s/)
                {
                    $ilastspace = $i;
                }
            }
            if(defined($ilastspace) && $ilastspace>0)
            {
                # Print up to the space; the space itself is dropped.
                my @out = @cbuffer[0..($ilastspace-1)];
                splice(@cbuffer, 0, $ilastspace+1);
                print(join('', @out), "\n");
            }
            else
            {
                # A single overlong word: print it whole on one line.
                print(join('', @cbuffer), "\n");
                splice(@cbuffer);
            }
        }
        $buffer = join('', @cbuffer);
    }
    return $buffer;
}
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
#------------------------------------------------------------------------------
# Prints as many complete lines of text as there are in the buffer. Returns the
# remaining contents of the buffer. Assumes that there are no spaces between
# words and lines can be broken between any two characters, as is the custom in
# Chinese and Japanese.
#------------------------------------------------------------------------------
sub print_chinese_lines_from_buffer
{
    my $buffer = shift;
    # Maximum number of characters allowed on one line, not counting the line
    # break character(s).
    my $limit = shift;
    # We cannot simply print the first $limit characters from the buffer,
    # followed by a line break. There could be embedded Latin words or
    # numbers and we do not want to insert a line break in the middle of
    # a foreign word.
    my @cbuffer = split(//, $buffer);
    while(scalar(@cbuffer) >= $limit)
    {
        # $nprint = number of characters that will go on this line.
        my $nprint = 0;
        for(my $i = 0; $i <= $#cbuffer; $i++)
        {
            if($i > $limit && $nprint > 0)
            {
                last;
            }
            # Only allow a break where the next character does not continue
            # a Latin-script word or a number.
            unless($i < $#cbuffer && $cbuffer[$i] =~ m/[\p{Latin}0-9]/ && $cbuffer[$i+1] =~ m/[\p{Latin}0-9]/)
            {
                $nprint = $i+1;
            }
        }
        my @out = @cbuffer[0..($nprint-1)];
        splice(@cbuffer, 0, $nprint);
        print(join('', @out), "\n");
    }
    $buffer = join('', @cbuffer);
    return $buffer;
}
|
stanza/stanza/utils/datasets/prepare_lemma_classifier.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
from stanza.utils.datasets.common import find_treebank_dataset_file, UnknownDatasetError
|
| 5 |
+
from stanza.utils.default_paths import get_default_paths
|
| 6 |
+
from stanza.models.lemma_classifier import prepare_dataset
|
| 7 |
+
from stanza.models.common.short_name_to_treebank import short_name_to_treebank
|
| 8 |
+
from stanza.utils.conll import CoNLL
|
| 9 |
+
|
| 10 |
+
# the standard UD shard names, in the order the datasets are written
SECTIONS = ("train", "dev", "test")
|
| 11 |
+
|
| 12 |
+
def process_treebank(paths, short_name, word, upos, allowed_lemmas, sections=SECTIONS):
    """Build the lemma classifier dataset for one UD treebank.

    For each shard, runs prepare_dataset on the treebank's conllu file,
    targeting the given word/upos (optionally restricted to allowed_lemmas),
    and writes {short_name}.{shard}.lemma in LEMMA_CLASSIFIER_DATA_DIR.

    Returns the list of output filenames.
    """
    treebank = short_name_to_treebank(short_name)
    udbase_dir = paths["UDBASE"]

    output_dir = paths["LEMMA_CLASSIFIER_DATA_DIR"]
    os.makedirs(output_dir, exist_ok=True)

    output_filenames = []
    for section in sections:
        conllu_file = find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True)
        output_filename = os.path.join(output_dir, "%s.%s.lemma" % (short_name, section))
        dataset_args = ["--conll_path", conllu_file,
                        "--target_word", word,
                        "--target_upos", upos,
                        "--output_path", output_filename]
        if allowed_lemmas is not None:
            dataset_args += ["--allowed_lemmas", allowed_lemmas]
        prepare_dataset.main(dataset_args)
        output_filenames.append(output_filename)

    return output_filenames
|
| 34 |
+
|
| 35 |
+
def process_en_combined(paths, short_name):
    """Build a combined English "'s" AUX dataset from several UD treebanks.

    The train/dev/test sections of the train treebanks go to the matching
    combined sections; the test-only treebanks are all added to the combined
    test section.  Output goes to LEMMA_CLASSIFIER_DATA_DIR.
    """
    udbase_dir = paths["UDBASE"]
    output_dir = paths["LEMMA_CLASSIFIER_DATA_DIR"]
    os.makedirs(output_dir, exist_ok=True)

    train_treebanks = ["UD_English-EWT", "UD_English-GUM", "UD_English-GUMReddit", "UD_English-LinES"]
    test_treebanks = ["UD_English-PUD", "UD_English-Pronouns"]

    target_word = "'s"
    target_upos = ["AUX"]

    def read_section(treebank, section):
        # extract the matching sentences from one shard of one treebank
        # (previously this logic was duplicated for train and test treebanks)
        filename = find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True)
        doc = CoNLL.conll2doc(filename)
        processor = prepare_dataset.DataProcessor(target_word=target_word, target_upos=target_upos, allowed_lemmas=".*")
        new_sentences = processor.process_document(doc, save_name=None)
        print("Read %d sentences from %s" % (len(new_sentences), filename))
        return new_sentences

    sentences = [ [], [], [] ]
    for treebank in train_treebanks:
        for section_idx, section in enumerate(SECTIONS):
            sentences[section_idx].extend(read_section(treebank, section))
    for treebank in test_treebanks:
        # these treebanks only provide a test shard; it joins the combined test set
        sentences[2].extend(read_section(treebank, "test"))

    for section, section_sentences in zip(SECTIONS, sentences):
        output_filename = os.path.join(output_dir, "%s.%s.lemma" % (short_name, section))
        prepare_dataset.DataProcessor.write_output_file(output_filename, target_upos, section_sentences)
        print("Wrote %s sentences to %s" % (len(section_sentences), output_filename))
|
| 68 |
+
|
| 69 |
+
def process_ja_gsd(paths, short_name):
    """Build the Japanese GSD だ dataset (AUX, だ vs た lemmas).

    this one looked promising, but only has 10 total dev & test cases
      行っ VERB Counter({'行う': 60, '行く': 38})
    could possibly do
      ない AUX Counter({'ない': 383, '無い': 99})
      なく AUX Counter({'無い': 53, 'ない': 42})
    currently this one has enough in the dev & test data
    and functions well
      だ AUX Counter({'だ': 237, 'た': 67})
    """
    process_treebank(paths, short_name, "だ", "AUX", None)
|
| 83 |
+
|
| 84 |
+
def process_fa_perdt(paths, short_name):
    """Build the Persian PerDT dataset: شد VERB, limited to the کرد|شد lemmas."""
    process_treebank(paths, short_name, "شد", "VERB", "کرد|شد")
|
| 90 |
+
|
| 91 |
+
def process_hi_hdtb(paths, short_name):
    """Build the Hindi HDTB dataset: के ADP, limited to the का|के lemmas."""
    process_treebank(paths, short_name, "के", "ADP", "का|के")
|
| 97 |
+
|
| 98 |
+
def process_ar_padt(paths, short_name):
    """Build the Arabic PADT dataset: أن SCONJ, limited to the أَن|أَنَّ lemmas."""
    process_treebank(paths, short_name, "أن", "SCONJ", "أَن|أَنَّ")
|
| 104 |
+
|
| 105 |
+
def process_el_gdt(paths, short_name):
    """
    All of the Greek lemmas for these words are εγώ or μου

    τους PRON Counter({'μου': 118, 'εγώ': 32})
    μας PRON Counter({'μου': 89, 'εγώ': 32})
    του PRON Counter({'μου': 82, 'εγώ': 8})
    της PRON Counter({'μου': 80, 'εγώ': 2})
    σας PRON Counter({'μου': 34, 'εγώ': 24})
    μου PRON Counter({'μου': 45, 'εγώ': 10})
    """
    process_treebank(paths, short_name, "τους|μας|του|της|σας|μου", "PRON", None)
|
| 121 |
+
|
| 122 |
+
# maps a dataset short name to the function which prepares it
DATASET_MAPPING = {
    "ar_padt":     process_ar_padt,
    "el_gdt":      process_el_gdt,
    "en_combined": process_en_combined,
    "fa_perdt":    process_fa_perdt,
    "hi_hdtb":     process_hi_hdtb,
    "ja_gsd":      process_ja_gsd,
}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main(dataset_name):
    """Prepare the lemma classifier dataset named by dataset_name.

    Raises UnknownDatasetError for names not in DATASET_MAPPING.
    """
    paths = get_default_paths()
    print("Processing %s" % dataset_name)

    # obviously will want to multiplex to multiple languages / datasets
    handler = DATASET_MAPPING.get(dataset_name)
    if handler is None:
        raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_lemma_classifier.py")
    handler(paths, dataset_name)
    print("Done processing %s" % dataset_name)

if __name__ == '__main__':
    main(sys.argv[1])
|
stanza/stanza/utils/datasets/prepare_mwt_treebank.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A script to prepare all MWT datasets.
|
| 3 |
+
|
| 4 |
+
For example, do
|
| 5 |
+
python -m stanza.utils.datasets.prepare_mwt_treebank TREEBANK
|
| 6 |
+
such as
|
| 7 |
+
python -m stanza.utils.datasets.prepare_mwt_treebank UD_English-EWT
|
| 8 |
+
|
| 9 |
+
and it will prepare each of train, dev, test
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import os
|
| 14 |
+
import shutil
|
| 15 |
+
import tempfile
|
| 16 |
+
|
| 17 |
+
from stanza.utils.conll import CoNLL
|
| 18 |
+
from stanza.models.common.constant import treebank_to_short_name
|
| 19 |
+
import stanza.utils.datasets.common as common
|
| 20 |
+
import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
|
| 21 |
+
|
| 22 |
+
from stanza.utils.datasets.contract_mwt import contract_mwt
|
| 23 |
+
|
| 24 |
+
# languages where the MWTs are always a composition of the words themselves
|
| 25 |
+
# languages where the MWTs are always a composition of the words themselves
# (language codes, matched against the prefix of the short name)
KNOWN_COMPOSABLE_MWTS = {"en"}
# ... but partut is not put together that way
# (full short names exempted from the composition check)
MWT_EXCEPTIONS = {"en_partut"}
|
| 28 |
+
|
| 29 |
+
def copy_conllu(tokenizer_dir, mwt_dir, short_name, dataset, particle):
    """Copy one shard's gold conllu from the tokenizer dir into the MWT dir.

    The destination is named {short_name}.{dataset}.{particle}.conllu,
    where particle is typically "in" or "gold".
    """
    source = f"{tokenizer_dir}/{short_name}.{dataset}.gold.conllu"
    destination = f"{mwt_dir}/{short_name}.{dataset}.{particle}.conllu"
    shutil.copyfile(source, destination)
|
| 33 |
+
|
| 34 |
+
def check_mwt_composition(filename):
    """Verify every MWT in the file is exactly the concatenation of its words.

    Raises ValueError identifying the first offending token.
    """
    print("Checking the MWTs in %s" % filename)
    doc = CoNLL.conll2doc(filename)
    for sent_idx, sentence in enumerate(doc.sentences):
        for token_idx, token in enumerate(sentence.tokens):
            if len(token.words) <= 1:
                continue
            expected = "".join(word.text for word in token.words)
            if token.text != expected:
                raise ValueError("Unexpected token composition in filename %s sentence %d id %s token %d: %s instead of %s" % (filename, sent_idx, sentence.sent_id, token_idx, token.text, expected))
|
| 43 |
+
|
| 44 |
+
def process_treebank(treebank, model_type, paths, args):
    """Prepare train/dev/test MWT data files for a single treebank.

    Runs the tokenizer data preparation into a temporary directory, copies
    the resulting conllu and MWT label files into the MWT data directory,
    sanity checks MWT composition for languages where that is known to hold,
    and finally contracts the dev/test gold files into "in" files.
    """
    short_name = treebank_to_short_name(treebank)

    mwt_dir = paths["MWT_DATA_DIR"]
    os.makedirs(mwt_dir, exist_ok=True)

    with tempfile.TemporaryDirectory() as tokenizer_dir:
        # point the tokenizer preparation at the temporary directory
        tokenizer_paths = dict(paths)
        tokenizer_paths["TOKENIZE_DATA_DIR"] = tokenizer_dir

        tok_args = argparse.Namespace()
        tok_args.augment = False
        tok_args.prepare_labels = True
        prepare_tokenizer_treebank.process_treebank(treebank, model_type, tokenizer_paths, tok_args)

        # train data keeps the "in" particle; dev/test stay "gold" and are
        # contracted into "in" files below
        for dataset, particle in (("train", "in"), ("dev", "gold"), ("test", "gold")):
            copy_conllu(tokenizer_dir, mwt_dir, short_name, dataset, particle)

        for shard in ("train", "dev", "test"):
            source_filename = common.mwt_name(tokenizer_dir, short_name, shard)
            dest_filename = common.mwt_name(mwt_dir, short_name, shard)
            print("Copying from %s to %s" % (source_filename, dest_filename))
            shutil.copyfile(source_filename, dest_filename)

    language = short_name.split("_", 1)[0]
    if language in KNOWN_COMPOSABLE_MWTS and short_name not in MWT_EXCEPTIONS:
        print("Language %s is known to have all MWT composed of exactly its word pieces. Checking..." % language)
        for check_file in (f"{mwt_dir}/{short_name}.train.in.conllu",
                           f"{mwt_dir}/{short_name}.dev.gold.conllu",
                           f"{mwt_dir}/{short_name}.test.gold.conllu"):
            check_mwt_composition(check_file)

    contract_mwt(f"{mwt_dir}/{short_name}.dev.gold.conllu",
                 f"{mwt_dir}/{short_name}.dev.in.conllu")
    contract_mwt(f"{mwt_dir}/{short_name}.test.gold.conllu",
                 f"{mwt_dir}/{short_name}.test.in.conllu")
|
| 81 |
+
|
| 82 |
+
def main():
    """Command line entry point: run the shared dataset driver for MWT."""
    common.main(process_treebank, common.ModelType.MWT)

if __name__ == '__main__':
    main()
|
| 87 |
+
|
| 88 |
+
|
stanza/stanza/utils/datasets/prepare_pos_treebank.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A script to prepare all pos datasets.
|
| 3 |
+
|
| 4 |
+
For example, do
|
| 5 |
+
python -m stanza.utils.datasets.prepare_pos_treebank TREEBANK
|
| 6 |
+
such as
|
| 7 |
+
python -m stanza.utils.datasets.prepare_pos_treebank UD_English-EWT
|
| 8 |
+
|
| 9 |
+
and it will prepare each of train, dev, test
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import shutil
|
| 14 |
+
|
| 15 |
+
import stanza.utils.datasets.common as common
|
| 16 |
+
import stanza.utils.datasets.prepare_tokenizer_treebank as prepare_tokenizer_treebank
|
| 17 |
+
|
| 18 |
+
def copy_conllu_file_or_zip(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name):
    """Copy a dataset file, preferring a .zip archive when one exists.

    If {short_name}.{tokenizer_file}.zip is present in tokenizer_dir, copy
    the archive as-is to the destination name; otherwise fall back to
    copying the plain conllu file via the tokenizer preparation helper.
    """
    zip_source = "%s/%s.%s.zip" % (tokenizer_dir, short_name, tokenizer_file)
    zip_dest = "%s/%s.%s.zip" % (dest_dir, short_name, dest_file)

    if not os.path.exists(zip_source):
        prepare_tokenizer_treebank.copy_conllu_file(tokenizer_dir, tokenizer_file, dest_dir, dest_file, short_name)
        return

    print("Copying from %s to %s" % (zip_source, zip_dest))
    shutil.copyfile(zip_source, zip_dest)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def process_treebank(treebank, model_type, paths, args):
    """Prepare a POS treebank by reusing the tokenizer's conllu conversion."""
    pos_dir = paths["POS_DATA_DIR"]
    prepare_tokenizer_treebank.copy_conllu_treebank(
        treebank, model_type, paths, pos_dir, postprocess=copy_conllu_file_or_zip)
|
| 31 |
+
|
| 32 |
+
def main():
    """Command line entry point: run the shared dataset driver for POS."""
    common.main(process_treebank, common.ModelType.POS)

if __name__ == '__main__':
    main()
|
| 37 |
+
|
| 38 |
+
|
stanza/stanza/utils/datasets/random_split_conllu.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Randomly split a file into train, dev, and test sections
|
| 3 |
+
|
| 4 |
+
Specifically used in the case of building a tagger from the initial
|
| 5 |
+
POS tagging provided by Isra, but obviously can be used to split any
|
| 6 |
+
conllu file
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
from stanza.models.common.doc import Document
|
| 14 |
+
from stanza.utils.conll import CoNLL
|
| 15 |
+
from stanza.utils.default_paths import get_default_paths
|
| 16 |
+
|
| 17 |
+
def main():
    """Randomly shuffle a conllu file's sentences into train/dev/test files.

    Each sentence is assigned to a split with probability proportional to
    the --train/--dev/--test weights, optionally stripped of xpos and feats,
    and written to {output_directory}/{short_name}.{split}.in.conllu.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', default='extern_data/sindhi/upos/sindhi_upos.conllu', help='Which file to split')
    parser.add_argument('--train', type=float, default=0.8, help='Fraction of the data to use for train')
    parser.add_argument('--dev', type=float, default=0.1, help='Fraction of the data to use for dev')
    parser.add_argument('--test', type=float, default=0.1, help='Fraction of the data to use for test')
    parser.add_argument('--seed', default='1234', help='Random seed to use')
    parser.add_argument('--short_name', default='sd_isra', help='Dataset name to use when writing output files')
    parser.add_argument('--no_remove_xpos', default=True, action='store_false', dest='remove_xpos', help='By default, we remove the xpos from the dataset')
    parser.add_argument('--no_remove_feats', default=True, action='store_false', dest='remove_feats', help='By default, we remove the feats from the dataset')
    parser.add_argument('--output_directory', default=get_default_paths()["POS_DATA_DIR"], help="Where to put the split conllu")
    args = parser.parse_args()

    weights = (args.train, args.dev, args.test)

    doc = CoNLL.conll2doc(args.filename)
    random.seed(args.seed)

    # each bucket accumulates (sentence word dicts, sentence comments)
    buckets = [([], []), ([], []), ([], [])]
    for sentence in doc.sentences:
        words = sentence.to_dict()
        if args.remove_xpos:
            for word in words:
                word.pop('xpos', None)
        if args.remove_feats:
            for word in words:
                word.pop('feats', None)
        chosen = random.choices(buckets, weights)[0]
        chosen[0].append(words)
        chosen[1].append(sentence.comments)

    for (sentences, comments), split_name in zip(buckets, ("train", "dev", "test")):
        split_doc = Document(sentences, comments=comments)
        filename = os.path.join(args.output_directory, "%s.%s.in.conllu" % (args.short_name, split_name))
        print("Outputting %d sentences to %s" % (len(split_doc.sentences), filename))
        CoNLL.write_doc2conll(split_doc, filename)

if __name__ == '__main__':
    main()
|
| 59 |
+
|
stanza/stanza/utils/datasets/thai_syllable_dict_generator.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import pathlib
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_dictionary(dataset_dir, save_dir):
    """Build a Thai syllable -> index dictionary from .ssg syllable files.

    Recursively reads every *.ssg file under dataset_dir, splits each line
    on the "~" and "<s/>" syllable separators, keeps only syllables made
    entirely of Thai codepoints (non-empty, no spaces), and writes a JSON
    mapping from syllable to integer index to save_dir.

    Fixes over the earlier version: removes the per-syllable debug print,
    assigns indices from the *sorted* syllable list so the output is
    deterministic across runs (set iteration order is not), and opens files
    with an explicit utf-8 encoding.
    """
    import json
    import re

    # only syllables consisting purely of Thai codepoints are kept
    thai_pattern = re.compile(r"^[\u0E00-\u0E7F]*$")

    syllables = set()
    for path in pathlib.Path(dataset_dir).rglob("*.ssg"):
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                # normalize both separator spellings, then split into syllables
                line = line.replace("\n", "").replace("<s/>", "~")
                syllables.update(line.split("~"))

    print(len(syllables))

    # drop empty strings, syllables containing spaces, and any syllable
    # with non-Thai characters (e.g. embedded English words)
    kept = [s for s in syllables if thai_pattern.match(s) and s != "" and " " not in s]

    # sort before numbering so the assigned indices are deterministic
    dictionary = {syllable: idx for idx, syllable in enumerate(sorted(kept))}

    print(len(dictionary))
    with open(save_dir, "w", encoding="utf-8") as fp:
        json.dump(dictionary, fp)
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
    # Command line driver: read the syllable dataset and write the JSON dict
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--dataset_dir', type=str, default="syllable_segmentation_data", help="Directory for syllable dataset")
    arg_parser.add_argument('--save_dir', type=str, default="thai-syllable.json", help="Directory for generated file")
    options = arg_parser.parse_args()
    create_dictionary(options.dataset_dir, options.save_dir)
|