diff --git a/stanza/stanza/pipeline/demo/stanza-brat.css b/stanza/stanza/pipeline/demo/stanza-brat.css new file mode 100644 index 0000000000000000000000000000000000000000..147e5275f6162be71c762c7494f9372f4236e521 --- /dev/null +++ b/stanza/stanza/pipeline/demo/stanza-brat.css @@ -0,0 +1,74 @@ + +.red { + color:#990000 +} + +#wrap { + min-height: 100%; + height: auto; + margin: 0 auto -6ex; + padding: 0 0 6ex; +} + +.pattern_tab { + margin: 1ex; +} + +.pattern_brat { + margin-top: 1ex; +} + +.label { + color: #777777; + font-size: small; +} + +.footer { + bottom: 0; + width: 100%; + /* Set the fixed height of the footer here */ + height: 5ex; + padding-top: 1ex; + margin-top: 1ex; + background-color: #f5f5f5; +} + +.corenlp_error { + margin-top: 2ex; +} + +/* Styling for parse graph */ +.node rect { + stroke: #333; + fill: #fff; +} + +.parse-RULE rect { + fill: #C0D9AF; +} + +.parse-TERMINAL rect { + stroke: #333; + fill: #EEE8AA; +} + +.node.highlighted { + stroke: #ffff00; +} + +.edgePath path { + stroke: #333; + fill: #333; + stroke-width: 1.5px; +} + +.parse-EDGE path { + stroke: DarkGray; + fill: DarkGray; + stroke-width: 1.5px; +} + +.logo { + font-family: "Lato", "Gill Sans MT", "Gill Sans", "Helvetica", "Arial", sans-serif; + font-style: italic; +} diff --git a/stanza/stanza/pipeline/demo/stanza-parseviewer.js b/stanza/stanza/pipeline/demo/stanza-parseviewer.js new file mode 100644 index 0000000000000000000000000000000000000000..9bfcdbf0d0fc69bed8d1c4eb661f6f378b5d2f26 --- /dev/null +++ b/stanza/stanza/pipeline/demo/stanza-parseviewer.js @@ -0,0 +1,215 @@ +//'use strict'; + +//d3 || require('d3'); +//var dagreD3 = require('dagre-d3'); +//var jquery = require('jquery'); +//var $ = jquery; + +var ParseViewer = function(params) { + // Container in which the scene template is displayed + this.selector = params.selector; + this.container = $(this.selector); + this.fitToGraph = true; + this.onClickNodeCallback = params.onClickNodeCallback; + 
this.onHoverNodeCallback = params.onHoverNodeCallback; + this.init(); + return this; +}; + +ParseViewer.MIN_WIDTH = 100; +ParseViewer.MIN_HEIGHT = 100; + +ParseViewer.prototype.constructor = ParseViewer; + +ParseViewer.prototype.getAutoWidth = function () { + return Math.max(ParseViewer.MIN_WIDTH, this.container.width()); +}; + +ParseViewer.prototype.getAutoHeight = function () { + return Math.max(ParseViewer.MIN_HEIGHT, this.container.height() - 20); +}; + +ParseViewer.prototype.init = function () { + var canvasWidth = this.getAutoWidth(); + var canvasHeight = this.getAutoHeight(); + this.parseElem = d3.select(this.selector) + .append('svg') + .attr({'width': canvasWidth, 'height': canvasHeight}) + .style({'width': canvasWidth, 'height': canvasHeight}); + console.log(this.parseElem); + this.graph = null; + this.graphRendered = false; + + this.controls = $('
'); + this.container.append(this.controls); +}; + +var GraphBuilder = function(roots) { + // Create the input graph + this.graph = new dagreD3.graphlib.Graph() + .setGraph({}) + .setDefaultEdgeLabel(function () { + return {}; + }); + this.visitIndex = 0; + //console.log('building graph', roots); + for (var i = 0; i < roots.length; i++) { + this.build(roots[i]); + } +}; + +GraphBuilder.prototype.build = function(node) { + console.log(node); + // Track my visit index + this.visitIndex++; + node.visitIndex = this.visitIndex; + + // Add a node + var nodeData = node; // TODO: replace with semantic data + var nodeLabel = node.label; + var nodeIndex = node.visitIndex; + var nodeClass = 'parse-RULE'; + + this.graph.setNode(nodeIndex, { label: nodeLabel, class: nodeClass, data: nodeData }); + if (node.parent) { + this.graph.setEdge(node.parent.visitIndex, nodeIndex, { + class: 'parse-EDGE' + }); + } + + if (node.isTerminal) { + this.visitIndex++; + nodeIndex = this.visitIndex; + nodeLabel = node.text; + nodeClass = 'parse-TERMINAL'; + + this.graph.setNode(nodeIndex, { label: nodeLabel, class: nodeClass, data: nodeData }); + this.graph.setEdge(node.visitIndex, nodeIndex, { + class: 'parse-EDGE' + }); + } else if (node.children) { + for (var i = 0; i < node.children.length; i++) { + this.build(node.children[i]); + } + } +}; + +ParseViewer.prototype.updateGraphPosition = function (svg, g, minWidth, minHeight) { + if (this.fitToGraph) { + minWidth = g.graph().width; + minHeight = this.getAutoHeight(); + } + adjustGraphPositioning(svg, g, minWidth, minHeight); +}; + +function adjustGraphPositioning(svg, g, minWidth, minHeight) { + // Resize svg + var newWidth = Math.max(minWidth, g.graph().width); + var newHeight = Math.max(minHeight, g.graph().height + 40); + svg.attr({'width': newWidth, 'height': newHeight}); + svg.style({'width': newWidth, 'height': newHeight}); + // Center the graph + var svgGroup = svg.select('g'); + var xCenterOffset = (svg.attr('width') - g.graph().width) 
/ 2; + svgGroup.attr('transform', 'translate(' + xCenterOffset + ', 20)'); + svg.attr('height', g.graph().height + 40); + svg.style('height', g.graph().height + 40); +} + +ParseViewer.prototype.renderGraph = function (svg, g, parse) { + // Create the renderer + var render = new dagreD3.render(); + // Run the renderer. This is what draws the final graph. + var svgGroup = svg.select('g'); + render(svgGroup, g); + + var scope = this; + var nodes = svgGroup.selectAll('g.node'); + nodes.on('click', + function (d) { + var v = d; + var node = g.node(v); + if (scope.onClickNodeCallback) { + scope.onClickNodeCallback(node.data); + } + console.log(g.node(v)); + } + ); + + nodes.on('mouseover', + function (d) { + var v = d; + var node = g.node(v); + if (scope.onHoverNodeCallback) { + scope.onHoverNodeCallback(node.data); + } + } + ); + + this.updateGraphPosition(svg, g, svg.attr('width'), svg.attr('height')); + this.graphRendered = true; +}; + +ParseViewer.prototype.showParse = function (root) { + this.showParses([root]); +}; + +ParseViewer.prototype.showParses = function (roots) { + // Take parse and create a graph + var gb = new GraphBuilder(roots); + var g = gb.graph; + + g.nodes().forEach(function (v) { + var node = g.node(v); + // Round the corners of the nodes + node.rx = node.ry = 5; + }); + + var svg = this.parseElem; + svg.selectAll('*').remove(); + var svgGroup = svg.append('g'); + this.graph = g; + this.parse = roots; + if (this.container.is(':visible')) { + if (roots.length > 0) { + this.renderGraph(svg, this.graph, this.parse); + } + } else { + this.graphRendered = false; + } +}; + +ParseViewer.prototype.showAnnotation = function (annotation) { + var parses = []; + for (var i = 0; i < annotation.sentences.length; i++) { + var s = annotation.sentences[i]; + if (s && s.parseTree) { + parses.push(s.parseTree); + } + } + this.showParses(parses); +}; + +ParseViewer.prototype.onResize = function () { + var canvasWidth = this.getAutoWidth(); + var canvasHeight = 
this.getAutoHeight(); + var svg = this.parseElem; + + // Center the graph + var svgGroup = svg.select('g'); + if (svgGroup && this.graph) { + if (!this.graphRendered) { + svg.attr({'width': canvasWidth, 'height': canvasHeight}); + svg.style({'width': canvasWidth, 'height': canvasHeight}); + this.renderGraph(svg, this.graph, this.parse); + } else { + this.updateGraphPosition(svg, this.graph, canvasWidth, canvasHeight); + } + } else { + svg.attr({'width': canvasWidth, 'height': canvasHeight}); + svg.style({'width': canvasWidth, 'height': canvasHeight}); + } +}; + +// Exports +//module.exports = ParseViewer; diff --git a/stanza/stanza/pipeline/external/__init__.py b/stanza/stanza/pipeline/external/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/tests/classifiers/__init__.py b/stanza/stanza/tests/classifiers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/tests/classifiers/test_classifier.py b/stanza/stanza/tests/classifiers/test_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..52e202fad3e41987b7871c3802ca327d412bf472 --- /dev/null +++ b/stanza/stanza/tests/classifiers/test_classifier.py @@ -0,0 +1,317 @@ +import glob +import os + +import pytest + +import numpy as np +import torch + +import stanza +import stanza.models.classifier as classifier +import stanza.models.classifiers.data as data +from stanza.models.classifiers.trainer import Trainer +from stanza.models.common import pretrain +from stanza.models.common import utils + +from stanza.tests import TEST_MODELS_DIR +from stanza.tests.classifiers.test_data import train_file, dev_file, test_file, DATASET, SENTENCES + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +EMB_DIM = 5 + +@pytest.fixture(scope="module") +def fake_embeddings(tmp_path_factory): + """ + will return a 
path to a fake embeddings file with the words in SENTENCES + """ + # could set np random seed here + words = sorted(set([x.lower() for y in SENTENCES for x in y])) + words = words[:-1] + embedding_dir = tmp_path_factory.mktemp("data") + embedding_txt = embedding_dir / "embedding.txt" + embedding_pt = embedding_dir / "embedding.pt" + embedding = np.random.random((len(words), EMB_DIM)) + + with open(embedding_txt, "w", encoding="utf-8") as fout: + for word, emb in zip(words, embedding): + fout.write(word) + fout.write("\t") + fout.write("\t".join(str(x) for x in emb)) + fout.write("\n") + + pt = pretrain.Pretrain(str(embedding_pt), str(embedding_txt)) + pt.load() + assert os.path.exists(embedding_pt) + return embedding_pt + +class TestClassifier: + def build_model(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None): + """ + Build a model to be used by one of the later tests + """ + save_dir = str(tmp_path / "classifier") + save_name = "model.pt" + args = ["--save_dir", save_dir, + "--save_name", save_name, + "--wordvec_pretrain_file", str(fake_embeddings), + "--filter_channels", "20", + "--fc_shapes", "20,10", + "--train_file", str(train_file), + "--dev_file", str(dev_file), + "--max_epochs", "2", + "--batch_size", "60"] + if extra_args is not None: + args = args + extra_args + args = classifier.parse_args(args) + train_set = data.read_dataset(args.train_file, args.wordvec_type, args.min_train_len) + if checkpoint_file: + trainer = Trainer.load(checkpoint_file, args, load_optimizer=True) + else: + trainer = Trainer.build_new_model(args, train_set) + return trainer, train_set, args + + def run_training(self, tmp_path, fake_embeddings, train_file, dev_file, extra_args=None, checkpoint_file=None): + """ + Iterate a couple times over a model + """ + trainer, train_set, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args, checkpoint_file) + dev_set = data.read_dataset(args.dev_file, 
args.wordvec_type, args.min_train_len) + labels = data.dataset_labels(train_set) + + save_filename = os.path.join(args.save_dir, args.save_name) + if checkpoint_file is None: + checkpoint_file = utils.checkpoint_name(args.save_dir, save_filename, args.checkpoint_save_name) + classifier.train_model(trainer, save_filename, checkpoint_file, args, train_set, dev_set, labels) + return trainer, save_filename, checkpoint_file + + def test_build_model(self, tmp_path, fake_embeddings, train_file, dev_file): + """ + Test that building a basic model works + """ + self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"]) + + def test_save_load(self, tmp_path, fake_embeddings, train_file, dev_file): + """ + Test that a basic model can save & load + """ + trainer, _, args = self.build_model(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"]) + + save_filename = os.path.join(args.save_dir, args.save_name) + trainer.save(save_filename) + + args.load_name = args.save_name + trainer = Trainer.load(args.load_name, args) + args.load_name = save_filename + trainer = Trainer.load(args.load_name, args) + + def test_train_basic(self, tmp_path, fake_embeddings, train_file, dev_file): + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20"]) + + def test_train_bilstm(self, tmp_path, fake_embeddings, train_file, dev_file): + """ + Test w/ and w/o bilstm variations of the classifier + """ + args = ["--bilstm", "--bilstm_hidden_dim", "20"] + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + args = ["--no_bilstm"] + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + def test_train_maxpool_width(self, tmp_path, fake_embeddings, train_file, dev_file): + """ + Test various maxpool widths + + Also sets --filter_channels to a multiple of 2 but not of 3 for + the test to make sure the math is done correctly 
on a non-divisible width + """ + args = ["--maxpool_width", "1", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + args = ["--maxpool_width", "2", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + args = ["--maxpool_width", "3", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + def test_train_conv_2d(self, tmp_path, fake_embeddings, train_file, dev_file): + args = ["--filter_sizes", "(3,4,5)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + args = ["--filter_sizes", "((3,2),)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--bilstm_hidden_dim", "20"] + self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + def test_train_filter_channels(self, tmp_path, fake_embeddings, train_file, dev_file): + args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20", "--no_bilstm"] + trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + assert trainer.model.fc_input_size == 40 + + args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "15,20", "--no_bilstm"] + trainer, _, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + # 50 = 2x15 for the 2d conv (over 5 dim embeddings) + 20 + assert trainer.model.fc_input_size == 50 + + def test_train_bert(self, tmp_path, fake_embeddings, train_file, dev_file): + """ + Test on a tiny Bert WITHOUT finetuning, which hopefully does not take up too much disk space or memory + """ + bert_model = "hf-internal-testing/tiny-bert" + + trainer, save_filename, _ = 
self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model]) + assert os.path.exists(save_filename) + saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True) + # check that the bert model wasn't saved as part of the classifier + assert not saved_model['params']['config']['force_bert_saved'] + assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys()) + + def test_finetune_bert(self, tmp_path, fake_embeddings, train_file, dev_file): + """ + Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory + """ + bert_model = "hf-internal-testing/tiny-bert" + + trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune"]) + assert os.path.exists(save_filename) + saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True) + # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer + assert saved_model['params']['config']['force_bert_saved'] + assert any(x.startswith("bert_model") for x in saved_model['params']['model'].keys()) + + def test_finetune_bert_layers(self, tmp_path, fake_embeddings, train_file, dev_file): + """Test on a tiny Bert WITH finetuning, which hopefully does not take up too much disk space or memory, using 2 layers + + As an added bonus (or eager test), load the finished model and continue + training from there. 
Then check that the initial model and + the middle model are different, then that the middle model and + final model are different + + """ + bert_model = "hf-internal-testing/tiny-bert" + + trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models"]) + assert os.path.exists(save_filename) + + save_path = os.path.split(save_filename)[0] + + initial_model = glob.glob(os.path.join(save_path, "*E0000*")) + assert len(initial_model) == 1 + initial_model = initial_model[0] + initial_model = torch.load(initial_model, lambda storage, loc: storage, weights_only=True) + + second_model_file = glob.glob(os.path.join(save_path, "*E0002*")) + assert len(second_model_file) == 1 + second_model_file = second_model_file[0] + second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True) + + for layer_idx in range(2): + bert_names = [x for x in second_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." 
% layer_idx in x] + assert len(bert_names) > 0 + assert all(x in initial_model['params']['model'] and x in second_model['params']['model'] for x in bert_names) + assert not all(torch.allclose(initial_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names) + + # put some random marker in the file to look for later, + # check the continued training didn't clobber the expected file + assert "asdf" not in second_model + second_model["asdf"] = 1234 + torch.save(second_model, second_model_file) + + trainer, save_filename, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--bert_hidden_layers", "2", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file) + + second_model_file_redo = glob.glob(os.path.join(save_path, "*E0002*")) + assert len(second_model_file_redo) == 1 + assert second_model_file == second_model_file_redo[0] + second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True) + assert "asdf" in second_model + + fifth_model_file = glob.glob(os.path.join(save_path, "*E0005*")) + assert len(fifth_model_file) == 1 + + final_model = torch.load(fifth_model_file[0], lambda storage, loc: storage, weights_only=True) + for layer_idx in range(2): + bert_names = [x for x in final_model['params']['model'].keys() if x.startswith("bert_model") and "layer.%d." 
% layer_idx in x] + assert len(bert_names) > 0 + assert all(x in final_model['params']['model'] and x in second_model['params']['model'] for x in bert_names) + assert not all(torch.allclose(final_model['params']['model'].get(x), second_model['params']['model'].get(x)) for x in bert_names) + + def test_finetune_peft(self, tmp_path, fake_embeddings, train_file, dev_file): + """ + Test on a tiny Bert with PEFT finetuning + """ + bert_model = "hf-internal-testing/tiny-bert" + + trainer, save_filename, _ = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler"]) + assert os.path.exists(save_filename) + saved_model = torch.load(save_filename, lambda storage, loc: storage, weights_only=True) + # after finetuning the bert model, make sure that the save file DOES contain parts of the transformer, but only in peft form + assert saved_model['params']['config']['bert_model'] == bert_model + assert saved_model['params']['config']['force_bert_saved'] + assert saved_model['params']['config']['use_peft'] + + assert not saved_model['params']['config']['has_charlm_forward'] + assert not saved_model['params']['config']['has_charlm_backward'] + + assert len(saved_model['params']['bert_lora']) > 0 + assert any(x.find(".pooler.") >= 0 for x in saved_model['params']['bert_lora']) + assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora']) + assert not any(x.startswith("bert_model") for x in saved_model['params']['model'].keys()) + + # The Pipeline should load and run a PEFT trained model, + # although obviously we don't expect the results to do + # anything correct + pipeline = stanza.Pipeline("en", download_method=None, model_dir=TEST_MODELS_DIR, processors="tokenize,sentiment", sentiment_model_path=save_filename, sentiment_pretrain_path=str(fake_embeddings)) + doc = pipeline("This is a test") + + def 
test_finetune_peft_restart(self, tmp_path, fake_embeddings, train_file, dev_file): + """ + Test that if we restart training on a peft model, the peft weights change + """ + bert_model = "hf-internal-testing/tiny-bert" + + trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models"]) + + assert os.path.exists(save_file) + saved_model = torch.load(save_file, lambda storage, loc: storage, weights_only=True) + assert any(x.find(".encoder.") >= 0 for x in saved_model['params']['bert_lora']) + + + trainer, save_file, checkpoint_file = self.run_training(tmp_path, fake_embeddings, train_file, dev_file, extra_args=["--bilstm_hidden_dim", "20", "--bert_model", bert_model, "--bert_finetune", "--use_peft", "--lora_modules_to_save", "pooler", "--save_intermediate_models", "--max_epochs", "5"], checkpoint_file=checkpoint_file) + + save_path = os.path.split(save_file)[0] + + initial_model_file = glob.glob(os.path.join(save_path, "*E0000*")) + assert len(initial_model_file) == 1 + initial_model_file = initial_model_file[0] + initial_model = torch.load(initial_model_file, lambda storage, loc: storage, weights_only=True) + + second_model_file = glob.glob(os.path.join(save_path, "*E0002*")) + assert len(second_model_file) == 1 + second_model_file = second_model_file[0] + second_model = torch.load(second_model_file, lambda storage, loc: storage, weights_only=True) + + final_model_file = glob.glob(os.path.join(save_path, "*E0005*")) + assert len(final_model_file) == 1 + final_model_file = final_model_file[0] + final_model = torch.load(final_model_file, lambda storage, loc: storage, weights_only=True) + + # params in initial_model & second_model start with "base_model.model." 
+ # whereas params in final_model start directly with "encoder" or "pooler" + initial_lora = initial_model['params']['bert_lora'] + second_lora = second_model['params']['bert_lora'] + final_lora = final_model['params']['bert_lora'] + for side in ("_A.", "_B."): + for layer in (".0.", ".1."): + initial_params = sorted([x for x in initial_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0]) + second_params = sorted([x for x in second_lora if x.find(".encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0]) + final_params = sorted([x for x in final_lora if x.startswith("encoder.") > 0 and x.find(side) > 0 and x.find(layer) > 0]) + assert len(initial_params) > 0 + assert len(initial_params) == len(second_params) + assert len(initial_params) == len(final_params) + for x, y in zip(second_params, final_params): + assert x.endswith(y) + if side != "_A.": # the A tensors don't move very much, if at all + assert not torch.allclose(initial_lora.get(x), second_lora.get(x)) + assert not torch.allclose(second_lora.get(x), final_lora.get(y)) + diff --git a/stanza/stanza/tests/classifiers/test_process_utils.py b/stanza/stanza/tests/classifiers/test_process_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c8dcfa2900ede1f6a4ea7a82c687708698a11194 --- /dev/null +++ b/stanza/stanza/tests/classifiers/test_process_utils.py @@ -0,0 +1,83 @@ +""" +A few tests of the utils module for the sentiment datasets +""" + +import os +import pytest + +import stanza + +from stanza.models.classifiers import data +from stanza.models.classifiers.data import SentimentDatum +from stanza.models.classifiers.utils import WVType +from stanza.utils.datasets.sentiment import process_utils + +from stanza.tests import TEST_MODELS_DIR +from stanza.tests.classifiers.test_data import train_file, dev_file, test_file + + +def test_write_list(tmp_path, train_file): + """ + Test that writing a single list of items to an output file works + """ + train_set = 
data.read_dataset(train_file, WVType.OTHER, 1) + + dataset_file = tmp_path / "foo.json" + process_utils.write_list(dataset_file, train_set) + + train_copy = data.read_dataset(dataset_file, WVType.OTHER, 1) + assert train_copy == train_set + +def test_write_dataset(tmp_path, train_file, dev_file, test_file): + """ + Test that writing all three parts of a dataset works + """ + dataset = [data.read_dataset(filename, WVType.OTHER, 1) for filename in (train_file, dev_file, test_file)] + process_utils.write_dataset(dataset, tmp_path, "en_test") + + expected_files = ['en_test.train.json', 'en_test.dev.json', 'en_test.test.json'] + dataset_files = os.listdir(tmp_path) + assert sorted(dataset_files) == sorted(expected_files) + + for filename, expected in zip(expected_files, dataset): + written = data.read_dataset(tmp_path / filename, WVType.OTHER, 1) + assert written == expected + +def test_read_snippets(tmp_path): + """ + Test the basic operation of the read_snippets function + """ + filename = tmp_path / "foo.csv" + with open(filename, "w", encoding="utf-8") as fout: + fout.write("FOO\tThis is a test\thappy\n") + fout.write("FOO\tThis is a second sentence\tsad\n") + + nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None) + + mapping = {"happy": 0, "sad": 1} + + snippets = process_utils.read_snippets(filename, 2, 1, "en", mapping, nlp=nlp) + assert len(snippets) == 2 + assert snippets == [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']), + SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence'])] + +def test_read_snippets_two_columns(tmp_path): + """ + Test what happens when multiple columns are combined for the sentiment value + """ + filename = tmp_path / "foo.csv" + with open(filename, "w", encoding="utf-8") as fout: + fout.write("FOO\tThis is a test\thappy\tfoo\n") + fout.write("FOO\tThis is a second sentence\tsad\tbar\n") + fout.write("FOO\tThis is a third sentence\tsad\tfoo\n") + + nlp = 
stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None) + + mapping = {("happy", "foo"): 0, ("sad", "bar"): 1, ("sad", "foo"): 2} + + snippets = process_utils.read_snippets(filename, (2,3), 1, "en", mapping, nlp=nlp) + assert len(snippets) == 3 + assert snippets == [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']), + SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence']), + SentimentDatum(sentiment=2, text=['This', 'is', 'a', 'third', 'sentence'])] + diff --git a/stanza/stanza/tests/common/test_bert_embedding.py b/stanza/stanza/tests/common/test_bert_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc061557ad583bc764cf0d806c31acd58a60b64 --- /dev/null +++ b/stanza/stanza/tests/common/test_bert_embedding.py @@ -0,0 +1,33 @@ +import pytest +import torch + +from stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +BERT_MODEL = "hf-internal-testing/tiny-bert" + +@pytest.fixture(scope="module") +def tiny_bert(): + m, t = load_bert(BERT_MODEL) + return m, t + +def test_load_bert(tiny_bert): + """ + Empty method that just tests loading the bert + """ + m, t = tiny_bert + +def test_run_bert(tiny_bert): + m, t = tiny_bert + device = next(m.parameters()).device + extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "a", "test"]], device, True) + +def test_run_bert_empty_word(tiny_bert): + m, t = tiny_bert + device = next(m.parameters()).device + foo = extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "-", "a", "test"]], device, True) + bar = extract_bert_embeddings(BERT_MODEL, t, m, [["This", "is", "", "a", "test"]], device, True) + + assert len(foo) == 1 + assert torch.allclose(foo[0], bar[0]) diff --git a/stanza/stanza/tests/common/test_char_model.py b/stanza/stanza/tests/common/test_char_model.py new file mode 100644 index 
0000000000000000000000000000000000000000..3535c27e74b1b01af0fb68b78eb2d8152b1d2839 --- /dev/null +++ b/stanza/stanza/tests/common/test_char_model.py @@ -0,0 +1,190 @@ +""" +Currently tests a few configurations of files for creating a charlm vocab + +Also has a skeleton test of loading & saving a charlm +""" + +from collections import Counter +import glob +import lzma +import os +import tempfile + +import pytest + +from stanza.models import charlm +from stanza.models.common import char_model +from stanza.tests import TEST_MODELS_DIR + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +fake_text_1 = """ +Unban mox opal! +I hate watching Peppa Pig +""" + +fake_text_2 = """ +This is plastic cheese +""" + +class TestCharModel: + def test_single_file_vocab(self): + with tempfile.TemporaryDirectory() as tempdir: + sample_file = os.path.join(tempdir, "text.txt") + with open(sample_file, "w", encoding="utf-8") as fout: + fout.write(fake_text_1) + vocab = char_model.build_charlm_vocab(sample_file) + + for i in fake_text_1: + assert i in vocab + assert "Q" not in vocab + + def test_single_file_xz_vocab(self): + with tempfile.TemporaryDirectory() as tempdir: + sample_file = os.path.join(tempdir, "text.txt.xz") + with lzma.open(sample_file, "wt", encoding="utf-8") as fout: + fout.write(fake_text_1) + vocab = char_model.build_charlm_vocab(sample_file) + + for i in fake_text_1: + assert i in vocab + assert "Q" not in vocab + + def test_single_file_dir_vocab(self): + with tempfile.TemporaryDirectory() as tempdir: + sample_file = os.path.join(tempdir, "text.txt") + with open(sample_file, "w", encoding="utf-8") as fout: + fout.write(fake_text_1) + vocab = char_model.build_charlm_vocab(tempdir) + + for i in fake_text_1: + assert i in vocab + assert "Q" not in vocab + + def test_multiple_files_vocab(self): + with tempfile.TemporaryDirectory() as tempdir: + sample_file = os.path.join(tempdir, "t1.txt") + with open(sample_file, "w", encoding="utf-8") as fout: + 
 fout.write(fake_text_1) + sample_file = os.path.join(tempdir, "t2.txt.xz") + with lzma.open(sample_file, "wt", encoding="utf-8") as fout: + fout.write(fake_text_2) + vocab = char_model.build_charlm_vocab(tempdir) + + for i in fake_text_1: + assert i in vocab + for i in fake_text_2: + assert i in vocab + assert "Q" not in vocab + + def test_cutoff_vocab(self): + with tempfile.TemporaryDirectory() as tempdir: + sample_file = os.path.join(tempdir, "t1.txt") + with open(sample_file, "w", encoding="utf-8") as fout: + fout.write(fake_text_1) + sample_file = os.path.join(tempdir, "t2.txt.xz") + with lzma.open(sample_file, "wt", encoding="utf-8") as fout: + fout.write(fake_text_2) + + vocab = char_model.build_charlm_vocab(tempdir, cutoff=2) + + counts = Counter(fake_text_1) + Counter(fake_text_2) + for letter, count in counts.most_common(): + if count < 2: + assert letter not in vocab + else: + assert letter in vocab + + def test_build_model(self): + """ + Test the whole thing on a small dataset for an iteration or two + """ + with tempfile.TemporaryDirectory() as tempdir: + eval_file = os.path.join(tempdir, "en_test.dev.txt") + with open(eval_file, "w", encoding="utf-8") as fout: + fout.write(fake_text_1) + train_file = os.path.join(tempdir, "en_test.train.txt") + with open(train_file, "w", encoding="utf-8") as fout: + for i in range(1000): + fout.write(fake_text_1) + fout.write("\n") + fout.write(fake_text_2) + fout.write("\n") + save_name = 'en_test.forward.pt' + vocab_save_name = 'en_text.vocab.pt' + checkpoint_save_name = 'en_text.checkpoint.pt' + args = ['--train_file', train_file, + '--eval_file', eval_file, + '--eval_steps', '0', # eval once per epoch + '--epochs', '2', + '--cutoff', '1', + '--batch_size', '%d' % len(fake_text_1), + '--shorthand', 'en_test', + '--save_dir', tempdir, + '--save_name', save_name, + '--vocab_save_name', vocab_save_name, + '--checkpoint_save_name', checkpoint_save_name] + args = charlm.parse_args(args) + charlm.train(args) + + assert 
os.path.exists(os.path.join(tempdir, vocab_save_name)) + + # test that saving & loading of the model worked + assert os.path.exists(os.path.join(tempdir, save_name)) + model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, save_name)) + + # test that saving & loading of the checkpoint worked + assert os.path.exists(os.path.join(tempdir, checkpoint_save_name)) + model = char_model.CharacterLanguageModel.load(os.path.join(tempdir, checkpoint_save_name)) + trainer = char_model.CharacterLanguageModelTrainer.load(args, os.path.join(tempdir, checkpoint_save_name)) + + assert trainer.global_step > 0 + assert trainer.epoch == 2 + + # quick test to verify this method works with a trained model + charlm.get_current_lr(trainer, args) + + # test loading a vocab built by the training method... + vocab = charlm.load_char_vocab(os.path.join(tempdir, vocab_save_name)) + trainer = char_model.CharacterLanguageModelTrainer.from_new_model(args, vocab) + # ... and test the get_current_lr for an untrained model as well + # this test is super "eager" + assert charlm.get_current_lr(trainer, args) == args['lr0'] + + @pytest.fixture(scope="class") + def english_forward(self): + # eg, stanza_test/models/en/forward_charlm/1billion.pt + models_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "*") + models = glob.glob(models_path) + # we expect at least one English model downloaded for the tests + assert len(models) >= 1 + model_file = models[0] + return char_model.CharacterLanguageModel.load(model_file) + + @pytest.fixture(scope="class") + def english_backward(self): + # eg, stanza_test/models/en/forward_charlm/1billion.pt + models_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "*") + models = glob.glob(models_path) + # we expect at least one English model downloaded for the tests + assert len(models) >= 1 + model_file = models[0] + return char_model.CharacterLanguageModel.load(model_file) + + def test_load_model(self, english_forward, 
english_backward): + """ + Check that basic loading functions work + """ + assert english_forward.is_forward_lm + assert not english_backward.is_forward_lm + + def test_save_load_model(self, english_forward, english_backward): + """ + Load, save, and load again + """ + with tempfile.TemporaryDirectory() as tempdir: + for model in (english_forward, english_backward): + save_file = os.path.join(tempdir, "resaved", "charlm.pt") + model.save(save_file) + reloaded = char_model.CharacterLanguageModel.load(save_file) + assert model.is_forward_lm == reloaded.is_forward_lm diff --git a/stanza/stanza/tests/common/test_common_data.py b/stanza/stanza/tests/common/test_common_data.py new file mode 100644 index 0000000000000000000000000000000000000000..52c9d636d5c8cfe4d1080d7e5c68f800920dbd21 --- /dev/null +++ b/stanza/stanza/tests/common/test_common_data.py @@ -0,0 +1,32 @@ +import pytest +import stanza + +from stanza.tests import * +from stanza.models.common.data import get_augment_ratio, augment_punct + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +def test_augment_ratio(): + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + should_augment = lambda x: x >= 3 + can_augment = lambda x: x >= 4 + # check that zero is returned if no augmentation is needed + # which will be the case since 2 are already satisfactory + assert get_augment_ratio(data, should_augment, can_augment, desired_ratio=0.1) == 0.0 + + # this should throw an error + with pytest.raises(AssertionError): + get_augment_ratio(data, can_augment, should_augment) + + # with a desired ratio of 0.4, + # there are already 2 that don't need augmenting + # and 7 that are eligible to be augmented + # so 2/7 will need to be augmented + assert get_augment_ratio(data, should_augment, can_augment, desired_ratio=0.4) == pytest.approx(2/7) + +def test_augment_punct(): + data = [["Simple", "test", "."]] + should_augment = lambda x: x[-1] == "." 
+ can_augment = should_augment + new_data = augment_punct(data, 1.0, should_augment, can_augment) + assert new_data == [["Simple", "test"]] diff --git a/stanza/stanza/tests/common/test_data_objects.py b/stanza/stanza/tests/common/test_data_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6d847aca33904aca43ad4c8e757455f180edac --- /dev/null +++ b/stanza/stanza/tests/common/test_data_objects.py @@ -0,0 +1,60 @@ +""" +Basic tests of the stanza data objects, especially the setter/getter routines +""" +import pytest + +import stanza +from stanza.models.common.doc import Document, Sentence, Word +from stanza.tests import * + +pytestmark = pytest.mark.pipeline + +# data for testing +EN_DOC = "This is a test document. Pretty cool!" + +EN_DOC_UPOS_XPOS = (('PRON_DT', 'AUX_VBZ', 'DET_DT', 'NOUN_NN', 'NOUN_NN', 'PUNCT_.'), ('ADV_RB', 'ADJ_JJ', 'PUNCT_.')) + +EN_DOC2 = "Chris Manning wrote a sentence. Then another." + +@pytest.fixture(scope="module") +def nlp_pipeline(): + nlp = stanza.Pipeline(dir=TEST_MODELS_DIR, lang='en') + return nlp + +def test_readonly(nlp_pipeline): + Document.add_property('some_property', 123) + doc = nlp_pipeline(EN_DOC) + assert doc.some_property == 123 + with pytest.raises(ValueError): + doc.some_property = 456 + + +def test_getter(nlp_pipeline): + Word.add_property('upos_xpos', getter=lambda self: f"{self.upos}_{self.xpos}") + + doc = nlp_pipeline(EN_DOC) + + assert EN_DOC_UPOS_XPOS == tuple(tuple(word.upos_xpos for word in sentence.words) for sentence in doc.sentences) + +def test_setter_getter(nlp_pipeline): + int2str = {0: 'ok', 1: 'good', 2: 'bad'} + str2int = {'ok': 0, 'good': 1, 'bad': 2} + def setter(self, value): + self._classname = str2int[value] + Sentence.add_property('classname', getter=lambda self: int2str[self._classname] if self._classname is not None else None, setter=setter) + + doc = nlp_pipeline(EN_DOC) + sentence = doc.sentences[0] + sentence.classname = 'good' + assert sentence._classname == 1 
+ + # don't try this at home + sentence._classname = 2 + assert sentence.classname == 'bad' + +def test_backpointer(nlp_pipeline): + doc = nlp_pipeline(EN_DOC2) + ent = doc.ents[0] + assert ent.sent is doc.sentences[0] + assert list(doc.iter_words())[0].sent is doc.sentences[0] + assert list(doc.iter_tokens())[-1].sent is doc.sentences[-1] diff --git a/stanza/stanza/tests/common/test_doc.py b/stanza/stanza/tests/common/test_doc.py new file mode 100644 index 0000000000000000000000000000000000000000..960020c87f04f6ef582440c3d538c20ab619a1e4 --- /dev/null +++ b/stanza/stanza/tests/common/test_doc.py @@ -0,0 +1,174 @@ +import pytest + +import stanza +from stanza.tests import * +from stanza.models.common.doc import Document, ID, TEXT, NER, CONSTITUENCY, SENTIMENT + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +@pytest.fixture +def sentences_dict(): + return [[{ID: 1, TEXT: "unban"}, + {ID: 2, TEXT: "mox"}, + {ID: 3, TEXT: "opal"}], + [{ID: 4, TEXT: "ban"}, + {ID: 5, TEXT: "Lurrus"}]] + +@pytest.fixture +def doc(sentences_dict): + doc = Document(sentences_dict) + return doc + +def test_basic_values(doc, sentences_dict): + """ + Test that sentences & token text are properly set when constructing a doc + """ + assert len(doc.sentences) == len(sentences_dict) + + for sentence, raw_sentence in zip(doc.sentences, sentences_dict): + assert sentence.doc == doc + assert len(sentence.tokens) == len(raw_sentence) + for token, raw_token in zip(sentence.tokens, raw_sentence): + assert token.text == raw_token[TEXT] + +def test_set_sentence(doc): + """ + Test setting a field on the sentences themselves + """ + doc.set(fields="sentiment", + contents=["4", "0"], + to_sentence=True) + + assert doc.sentences[0].sentiment == "4" + assert doc.sentences[1].sentiment == "0" + +def test_set_tokens(doc): + """ + Test setting values on tokens + """ + ner_contents = ["O", "ARTIFACT", "ARTIFACT", "O", "CAT"] + doc.set(fields=NER, + contents=ner_contents, + to_token=True) + + result = 
doc.get(NER, from_token=True) + assert result == ner_contents + + +def test_constituency_comment(doc): + """ + Test that setting the constituency tree on a doc sets the constituency comment + """ + for sentence in doc.sentences: + assert len([x for x in sentence.comments if x.startswith("# constituency")]) == 0 + + # currently nothing is checking that the items are actually trees + trees = ["asdf", "zzzz"] + doc.set(fields=CONSTITUENCY, + contents=trees, + to_sentence=True) + + for sentence, expected in zip(doc.sentences, trees): + constituency_comments = [x for x in sentence.comments if x.startswith("# constituency")] + assert len(constituency_comments) == 1 + assert constituency_comments[0].endswith(expected) + + # Test that if we replace the trees with an updated tree, the comment is also replaced + trees = ["zzzz", "asdf"] + doc.set(fields=CONSTITUENCY, + contents=trees, + to_sentence=True) + + for sentence, expected in zip(doc.sentences, trees): + constituency_comments = [x for x in sentence.comments if x.startswith("# constituency")] + assert len(constituency_comments) == 1 + assert constituency_comments[0].endswith(expected) + +def test_sentiment_comment(doc): + """ + Test that setting the sentiment on a doc sets the sentiment comment + """ + for sentence in doc.sentences: + assert len([x for x in sentence.comments if x.startswith("# sentiment")]) == 0 + + # currently nothing is checking that the items are actually trees + sentiments = ["1", "2"] + doc.set(fields=SENTIMENT, + contents=sentiments, + to_sentence=True) + + for sentence, expected in zip(doc.sentences, sentiments): + sentiment_comments = [x for x in sentence.comments if x.startswith("# sentiment")] + assert len(sentiment_comments) == 1 + assert sentiment_comments[0].endswith(expected) + + # Test that if we replace the trees with an updated tree, the comment is also replaced + sentiments = ["3", "4"] + doc.set(fields=SENTIMENT, + contents=sentiments, + to_sentence=True) + + for sentence, expected 
in zip(doc.sentences, sentiments): + sentiment_comments = [x for x in sentence.comments if x.startswith("# sentiment")] + assert len(sentiment_comments) == 1 + assert sentiment_comments[0].endswith(expected) + +def test_sent_id_comment(doc): + """ + Test that setting the sent_id on a sentence sets the sentiment comment + """ + for sent_idx, sentence in enumerate(doc.sentences): + assert len([x for x in sentence.comments if x.startswith("# sent_id")]) == 1 + assert sentence.sent_id == "%d" % sent_idx + doc.sentences[0].sent_id = "foo" + assert doc.sentences[0].sent_id == "foo" + assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1 + assert "# sent_id = foo" in doc.sentences[0].comments + + doc.reindex_sentences(10) + for sent_idx, sentence in enumerate(doc.sentences): + assert sentence.sent_id == "%d" % (sent_idx + 10) + assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1 + assert "# sent_id = %d" % (sent_idx + 10) in sentence.comments + + doc.sentences[0].add_comment("# sent_id = bar") + assert doc.sentences[0].sent_id == "bar" + assert "# sent_id = bar" in doc.sentences[0].comments + assert len([x for x in doc.sentences[0].comments if x.startswith("# sent_id")]) == 1 + +def test_doc_id_comment(doc): + """ + Test that setting the doc_id on a sentence sets the document comment + """ + assert doc.sentences[0].doc_id is None + assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 0 + + doc.sentences[0].doc_id = "foo" + assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 1 + assert "# doc_id = foo" in doc.sentences[0].comments + assert doc.sentences[0].doc_id == "foo" + + doc.sentences[0].add_comment("# doc_id = bar") + assert len([x for x in doc.sentences[0].comments if x.startswith("# doc_id")]) == 1 + assert doc.sentences[0].doc_id == "bar" + +@pytest.fixture(scope="module") +def pipeline(): + return stanza.Pipeline(dir=TEST_MODELS_DIR) + +def 
test_serialized(pipeline): + """ + Brief test of the serialized format + + Checks that NER entities are correctly set. + Also checks that constituency & sentiment are set on the sentences. + """ + text = "John Bauer works at Stanford" + doc = pipeline(text) + assert len(doc.ents) == 2 + serialized = doc.to_serialized() + doc2 = Document.from_serialized(serialized) + assert len(doc2.sentences) == 1 + assert len(doc2.ents) == 2 + assert doc.sentences[0].constituency == doc2.sentences[0].constituency + assert doc.sentences[0].sentiment == doc2.sentences[0].sentiment diff --git a/stanza/stanza/tests/common/test_dropout.py b/stanza/stanza/tests/common/test_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..8ee42b429432e745a7f1f6540cea4d57af9cdb4c --- /dev/null +++ b/stanza/stanza/tests/common/test_dropout.py @@ -0,0 +1,28 @@ +import pytest + +import torch + +import stanza +from stanza.models.common.dropout import WordDropout + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +def test_word_dropout(): + """ + Test that word_dropout is randomly dropping out the entire final dimension of a tensor + + Doing 600 small rows should be super fast, but it leaves us with + something like a 1 in 10^180 chance of the test failing. 
Not very + common, in other words + """ + wd = WordDropout(0.5) + batch = torch.randn(600, 4) + dropped = wd(batch) + # the one time any of this happens, it's going to be really confusing + assert not torch.allclose(batch, dropped) + num_zeros = 0 + for i in range(batch.shape[0]): + assert torch.allclose(dropped[i], batch[i]) or torch.sum(dropped[i]) == 0.0 + if torch.sum(dropped[i]) == 0.0: + num_zeros += 1 + assert num_zeros > 0 and num_zeros < batch.shape[0] diff --git a/stanza/stanza/tests/common/test_short_name_to_treebank.py b/stanza/stanza/tests/common/test_short_name_to_treebank.py new file mode 100644 index 0000000000000000000000000000000000000000..7f4513acc6e61f77105665397668d4bad5c2dace --- /dev/null +++ b/stanza/stanza/tests/common/test_short_name_to_treebank.py @@ -0,0 +1,14 @@ +import pytest + +import stanza +from stanza.models.common import short_name_to_treebank + +pytestmark = [pytest.mark.travis, pytest.mark.pipeline] + +def test_short_name(): + assert short_name_to_treebank.short_name_to_treebank("en_ewt") == "UD_English-EWT" + +def test_canonical_name(): + assert short_name_to_treebank.canonical_treebank_name("UD_URDU-UDTB") == "UD_Urdu-UDTB" + assert short_name_to_treebank.canonical_treebank_name("ur_udtb") == "UD_Urdu-UDTB" + assert short_name_to_treebank.canonical_treebank_name("Unban_Mox_Opal") == "Unban_Mox_Opal" diff --git a/stanza/stanza/tests/constituency/test_convert_it_vit.py b/stanza/stanza/tests/constituency/test_convert_it_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f2441a3c62ac9f4dfcbbc92c1d50b36bd8705077 --- /dev/null +++ b/stanza/stanza/tests/constituency/test_convert_it_vit.py @@ -0,0 +1,228 @@ +""" +Test a couple different classes of trees to check the output of the VIT conversion + +A couple representative trees are included, but hopefully not enough +to be a problem in terms of our license. 
+ +One of the tests is currently disabled as it relies on tregex & tsurgeon features +not yet released +""" + +import io +import os +import tempfile + +import pytest + +from stanza.server import tsurgeon +from stanza.utils.conll import CoNLL +from stanza.utils.datasets.constituency import convert_it_vit + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +# just a sample! don't sue us please +CON_SAMPLE = """ +#ID=sent_00002 cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]] + +#ID=sent_00318 dirsp-[fc-[congf-tuttavia, f-[sn-[sq-[ind-qualche], n-problema], ir_infl-[vsupir-potrebbe, vcl-esserci], compc-[clit-ci, sp-[p-per, sn-[art-la, n-commissione, sa-[ag-esteri], f2-[sp-[part-alla, relob-cui, sn-[n-presidenza]], f-[ibar-[vc-è], compc-[sn-[n-candidato], sn-[art-l, n-esponente, spd-[pd-di, sn-[mw-Alleanza, npro-Nazionale]], sn-[mw-Mirko, nh-Tremaglia]]]]]]]]]], dirs-':', f3-[sn-[art-una, n-candidatura, sc-[q-più, sa-[ppas-subìta], sc-[ccong-che, sa-[ppas-gradita]], compt-[spda-[partda-dalla, sn-[mw-Lega, npro-Nord, punt-',', f2-[rel-che, fc-[congf-tuttavia, f-[ir_infl-[vsupir-dovrebbe, vit-rispettare], compt-[sn-[art-gli, n-accordi]]]]]]]]]], 
punto-.]] + +#ID=sent_00589 f-[sn-[art-l, n-ottimismo, spd-[pd-di, sn-[nh-Kantor]]], ir_infl-[vsupir-potrebbe, congf-però, vcl-rivelarsi], compc-[sn-[in-ancora, art-una, nt-volta], sa-[ag-prematuro]], punto-.] +""" + +UD_SAMPLE = """ +# sent_id = VIT-2 +# text = Negli ultimi anni la dinamica dei polo di attrazione è stata sempre più caratterizzata dall'emergere di una crescente concorrenza che si è progressivamente spostata dalle singole imprese ai sistemi economici e territoriali, determinando l'esigenza di una riconsiderazione dei rapporti esistenti tra soggetti produttivi e ambiente in cui questi operano. +1-2 Negli _ _ _ _ _ _ _ _ +1 In in ADP E _ 4 case _ _ +2 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 4 det _ _ +3 ultimi ultimo ADJ A Gender=Masc|Number=Plur 4 amod _ _ +4 anni anno NOUN S Gender=Masc|Number=Plur 16 obl _ _ +5 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 6 det _ _ +6 dinamica dinamica NOUN S Gender=Fem|Number=Sing 16 nsubj:pass _ _ +7-8 dei _ _ _ _ _ _ _ _ +7 di di ADP E _ 9 case _ _ +8 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 9 det _ _ +9 polo polo NOUN S Gender=Masc|Number=Sing 6 nmod _ _ +10 di di ADP E _ 11 case _ _ +11 attrazione attrazione NOUN S Gender=Fem|Number=Sing 9 nmod _ _ +12 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux _ _ +13 stata essere AUX VA Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 16 aux:pass _ _ +14 sempre sempre ADV B _ 15 advmod _ _ +15 più più ADV B _ 16 advmod _ _ +16 caratterizzata caratterizzare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 0 root _ _ +17-18 dall' _ _ _ _ _ _ _ SpaceAfter=No +17 da da ADP E _ 19 case _ _ +18 l' il DET RD Definite=Def|Number=Sing|PronType=Art 19 det _ _ +19 emergere emergere NOUN S Gender=Masc|Number=Sing 16 obl _ _ +20 di di ADP E _ 23 case _ _ +21 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 23 det _ _ +22 crescente crescente ADJ A Number=Sing 23 amod _ _ +23 
concorrenza concorrenza NOUN S Gender=Fem|Number=Sing 19 nmod _ _ +24 che che PRON PR PronType=Rel 28 nsubj _ _ +25 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 28 expl _ _ +26 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 28 aux _ _ +27 progressivamente progressivamente ADV B _ 28 advmod _ _ +28 spostata spostare VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 23 acl:relcl _ _ +29-30 dalle _ _ _ _ _ _ _ _ +29 da da ADP E _ 32 case _ _ +30 le il DET RD Definite=Def|Gender=Fem|Number=Plur|PronType=Art 32 det _ _ +31 singole singolo ADJ A Gender=Fem|Number=Plur 32 amod _ _ +32 imprese impresa NOUN S Gender=Fem|Number=Plur 28 obl _ _ +33-34 ai _ _ _ _ _ _ _ _ +33 a a ADP E _ 35 case _ _ +34 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 35 det _ _ +35 sistemi sistema NOUN S Gender=Masc|Number=Plur 28 obl _ _ +36 economici economico ADJ A Gender=Masc|Number=Plur 35 amod _ _ +37 e e CCONJ CC _ 38 cc _ _ +38 territoriali territoriale ADJ A Number=Plur 36 conj _ SpaceAfter=No +39 , , PUNCT FF _ 28 punct _ _ +40 determinando determinare VERB V VerbForm=Ger 28 advcl _ _ +41 l' il DET RD Definite=Def|Number=Sing|PronType=Art 42 det _ SpaceAfter=No +42 esigenza esigenza NOUN S Gender=Fem|Number=Sing 40 obj _ _ +43 di di ADP E _ 45 case _ _ +44 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 45 det _ _ +45 riconsiderazione riconsiderazione NOUN S Gender=Fem|Number=Sing 42 nmod _ _ +46-47 dei _ _ _ _ _ _ _ _ +46 di di ADP E _ 48 case _ _ +47 i il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 48 det _ _ +48 rapporti rapporto NOUN S Gender=Masc|Number=Plur 45 nmod _ _ +49 esistenti esistente VERB V Number=Plur 48 acl _ _ +50 tra tra ADP E _ 51 case _ _ +51 soggetti soggetto NOUN S Gender=Masc|Number=Plur 49 obl _ _ +52 produttivi produttivo ADJ A Gender=Masc|Number=Plur 51 amod _ _ +53 e e CCONJ CC _ 54 cc _ _ +54 ambiente ambiente NOUN S Gender=Masc|Number=Sing 51 conj _ _ +55 in in ADP E _ 56 case _ _ 
+56 cui cui PRON PR PronType=Rel 58 obl _ _ +57 questi questo PRON PD Gender=Masc|Number=Plur|PronType=Dem 58 nsubj _ _ +58 operano operare VERB V Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 54 acl:relcl _ SpaceAfter=No +59 . . PUNCT FS _ 16 punct _ _ + +# sent_id = VIT-318 +# text = Tuttavia qualche problema potrebbe esserci per la commissione esteri alla cui presidenza è candidato l'esponente di Alleanza Nazionale Mirko Tremaglia: una candidatura più subìta che gradita dalla Lega Nord, che tuttavia dovrebbe rispettare gli accordi. +1 Tuttavia tuttavia CCONJ CC _ 5 cc _ _ +2 qualche qualche DET DI Number=Sing|PronType=Ind 3 det _ _ +3 problema problema NOUN S Gender=Masc|Number=Sing 5 nsubj _ _ +4 potrebbe potere AUX VA Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux _ _ +5-6 esserci _ _ _ _ _ _ _ _ +5 esser essere VERB V VerbForm=Inf 0 root _ _ +6 ci ci PRON PC Clitic=Yes|Number=Plur|Person=1|PronType=Prs 5 expl _ _ +7 per per ADP E _ 9 case _ _ +8 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 9 det _ _ +9 commissione commissione NOUN S Gender=Fem|Number=Sing 5 obl _ _ +10 esteri estero ADJ A Gender=Masc|Number=Plur 9 amod _ _ +11-12 alla _ _ _ _ _ _ _ _ +11 a a ADP E _ 14 case _ _ +12 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 14 det _ _ +13 cui cui DET DR PronType=Rel 14 det:poss _ _ +14 presidenza presidenza NOUN S Gender=Fem|Number=Sing 16 obl _ _ +15 è essere AUX VA Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 16 aux:pass _ _ +16 candidato candidare VERB V Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part 9 acl:relcl _ _ +17 l' il DET RD Definite=Def|Number=Sing|PronType=Art 18 det _ SpaceAfter=No +18 esponente esponente NOUN S Number=Sing 16 nsubj:pass _ _ +19 di di ADP E _ 20 case _ _ +20 Alleanza Alleanza PROPN SP _ 18 nmod _ _ +21 Nazionale Nazionale PROPN SP _ 20 flat:name _ _ +22 Mirko Mirko PROPN SP _ 18 nmod _ _ +23 Tremaglia Tremaglia PROPN SP _ 22 flat:name _ SpaceAfter=No +24 : : 
PUNCT FC _ 22 punct _ _ +25 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 26 det _ _ +26 candidatura candidatura NOUN S Gender=Fem|Number=Sing 22 appos _ _ +27 più più ADV B _ 28 advmod _ _ +28 subìta subire VERB V Gender=Fem|Number=Sing|Tense=Past|VerbForm=Part 26 advcl _ _ +29 che che CCONJ CC _ 30 cc _ _ +30 gradita gradito ADJ A Gender=Fem|Number=Sing 28 amod _ _ +31-32 dalla _ _ _ _ _ _ _ _ +31 da da ADP E _ 33 case _ _ +32 la il DET RD Definite=Def|Gender=Fem|Number=Sing|PronType=Art 33 det _ _ +33 Lega Lega PROPN SP _ 28 obl:agent _ _ +34 Nord Nord PROPN SP _ 33 flat:name _ SpaceAfter=No +35 , , PUNCT FC _ 33 punct _ _ +36 che che PRON PR PronType=Rel 39 nsubj _ _ +37 tuttavia tuttavia CCONJ CC _ 39 cc _ _ +38 dovrebbe dovere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 39 aux _ _ +39 rispettare rispettare VERB V VerbForm=Inf 33 acl:relcl _ _ +40 gli il DET RD Definite=Def|Gender=Masc|Number=Plur|PronType=Art 41 det _ _ +41 accordi accordio NOUN S Gender=Masc|Number=Plur 39 obj _ SpaceAfter=No +42 . . PUNCT FS _ 5 punct _ _ + +# sent_id = VIT-591 +# text = L'ottimismo di Kantor potrebbe però rivelarsi ancora una volta prematuro. +1 L' il DET RD Definite=Def|Number=Sing|PronType=Art 2 det _ SpaceAfter=No +2 ottimismo ottimismo NOUN S Gender=Masc|Number=Sing 7 nsubj _ _ +3 di di ADP E _ 4 case _ _ +4 Kantor Kantor PROPN SP _ 2 nmod _ _ +5 potrebbe potere AUX VM Mood=Cnd|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux _ _ +6 però però ADV B _ 7 advmod _ _ +7-8 rivelarsi _ _ _ _ _ _ _ _ +7 rivelar rivelare VERB V VerbForm=Inf 0 root _ _ +8 si si PRON PC Clitic=Yes|Person=3|PronType=Prs 7 expl _ _ +9 ancora ancora ADV B _ 7 advmod _ _ +10 una uno DET RI Definite=Ind|Gender=Fem|Number=Sing|PronType=Art 11 det _ _ +11 volta volta NOUN S Gender=Fem|Number=Sing 7 obl _ _ +12 prematuro prematuro ADJ A Gender=Masc|Number=Sing 7 xcomp _ SpaceAfter=No +13 . . 
PUNCT FS _ 7 punct _ _ +""" + + +def test_process_mwts(): + # dei appears multiple times + # the verb/pron esserci will be ignored + expected_mwts = {'Negli': ('In', 'gli'), 'dei': ('di', 'i'), "dall'": ('da', "l'"), 'dalle': ('da', 'le'), 'ai': ('a', 'i'), 'alla': ('a', 'la'), 'dalla': ('da', 'la')} + + ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE) + + mwts = convert_it_vit.get_mwt(ud_train_data) + assert expected_mwts == mwts + +def test_raw_tree(): + con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE)) + expected_ids = ["#ID=sent_00002", "#ID=sent_00318", "#ID=sent_00589"] + expected_trees = ["(ROOT (cp (sp (part negli) (sn (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd dei) (sn (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda dall) (sn (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda dalle) (sn (sa (ag singole)) (n imprese))) (sp (part ai) (sn (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd dei) (sn (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))", + "(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part alla) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) 
(dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda dalla) (sn (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))", + "(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"] + assert len(con_sentences) == 3 + for sentence, expected_id, expected_tree in zip(con_sentences, expected_ids, expected_trees): + assert sentence[0] == expected_id + tree = convert_it_vit.raw_tree(sentence[1]) + assert str(tree) == expected_tree + +def test_update_mwts(): + con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_SAMPLE)) + ud_train_data = CoNLL.conll2doc(input_str=UD_SAMPLE) + mwt_map = convert_it_vit.get_mwt(ud_train_data) + expected_trees=["(ROOT (cp (sp (part In) (sn (art gli) (sa (ag ultimi)) (nt anni))) (f (sn (art la) (n dinamica) (spd (partd di) (sn (art i) (n polo) (n di) (n attrazione)))) (ibar (ause è) (ausep stata) (savv (savv (avv sempre)) (avv più)) (vppt caratterizzata)) (compin (spda (partda da) (sn (art l') (n emergere) (spd (pd di) (sn (art una) (sa (ag crescente)) (n concorrenza) (f2 (rel che) (f (ibar (clit si) (ause è) (avv progressivamente) (vppin spostata)) (compin (spda (partda da) (sn (art le) (sa (ag singole)) (n imprese))) (sp (part a) (sn (art i) (n sistemi) (sa (coord (ag economici) (cong e) (ag territoriali))))) (fp (punt ,) (sv5 (vgt determinando) (compt (sn (art l') (nf esigenza) (spd (pd di) (sn (art una) (n riconsiderazione) (spd (partd di) (sn (art i) (n rapporti) (sv3 (ppre esistenti) (compin (sp (p tra) (sn (n soggetti) (sa (ag produttivi)))) (cong e) (sn (n ambiente) (f2 (sp (p in) (sn (relob cui))) (f (sn (deit questi)) (ibar (vin operano) (punto .))))))))))))))))))))))))))", + 
"(ROOT (dirsp (fc (congf tuttavia) (f (sn (sq (ind qualche)) (n problema)) (ir_infl (vsupir potrebbe) (vcl esserci)) (compc (clit ci) (sp (p per) (sn (art la) (n commissione) (sa (ag esteri)) (f2 (sp (part a) (art la) (relob cui) (sn (n presidenza))) (f (ibar (vc è)) (compc (sn (n candidato)) (sn (art l) (n esponente) (spd (pd di) (sn (mw Alleanza) (npro Nazionale))) (sn (mw Mirko) (nh Tremaglia))))))))))) (dirs :) (f3 (sn (art una) (n candidatura) (sc (q più) (sa (ppas subìta)) (sc (ccong che) (sa (ppas gradita))) (compt (spda (partda da) (sn (art la) (mw Lega) (npro Nord) (punt ,) (f2 (rel che) (fc (congf tuttavia) (f (ir_infl (vsupir dovrebbe) (vit rispettare)) (compt (sn (art gli) (n accordi))))))))))) (punto .))))", + "(ROOT (f (sn (art l) (n ottimismo) (spd (pd di) (sn (nh Kantor)))) (ir_infl (vsupir potrebbe) (congf però) (vcl rivelarsi)) (compc (clit si) (sn (in ancora) (art una) (nt volta)) (sa (ag prematuro))) (punto .)))"] + with tsurgeon.Tsurgeon() as tsurgeon_processor: + for con_sentence, ud_sentence, expected_tree in zip(con_sentences, ud_train_data.sentences, expected_trees): + con_tree = convert_it_vit.raw_tree(con_sentence[1]) + updated_tree, _ = convert_it_vit.update_mwts_and_special_cases(con_tree, ud_sentence, mwt_map, tsurgeon_processor) + assert str(updated_tree) == expected_tree + + +CON_PERCENT_SAMPLE = """ +ID#sent_00020 f-[sn-[art-il, n-tesoro], ibar-[vt-mette], compt-[sp-[part-sul, sn-[n-mercato]], sn-[art-il, num-51%, sp-[p-a, sn-[num-2, n-lire]], sp-[p-per, sn-[n-azione]]]], punto-.] 
+ID#sent_00022 dirsp-[f3-[sn-[art-le, n-novità]], dirs-':', f3-[coord-[sn-[n-voto, spd-[pd-di, sn-[n-lista]]], cong-e, sn-[n-tetto, sp-[part-agli, sn-[n-acquisti]], sv3-[vppt-limitato, comppas-[sp-[part-allo, sn-[num-0/5%]]]]]], punto-.]] +ID#sent_00517 dirsp-[fc-[f-[sn-[art-l, n-aumento, sa-[ag-mensile], spd-[pd-di, sn-[nt-aprile]]], ibar-[ause-è, vppc-stato], compc-[sq-[q-dell_, sn-[num-1/3%]], sp-[p-contro, sn-[art-lo, num-0/7/0/8%, spd-[partd-degli, sn-[sa-[ag-ultimi], num-due, sn-[nt-mesi]]]]]]]]] +ID#sent_01117 fc-[f-[sn-[art-La, sa-[ag-crescente], n-ripresa, spd-[partd-dei, sn-[n-beni, spd-[pd-di, sn-[n-consumo]]]]], ibar-[vin-deriva], savv-[avv-esclusivamente], compin-[spda-[partda-dal, sn-[n-miglioramento, f2-[spd-[pd-di, sn-[relob-cui]], f-[ibar-[ausa-hanno, vppin-beneficiato], compin-[sn-[n-beni, coord-[sa-[ag-durevoli, fp-[par-'(', sn-[num-plus4/5%], par-')']], cong-e, sa-[ag-semidurevoli, fp-[par-'(', sn-[num-plus1/5%], par-')']]]]]]]]]]], punt-',', fs-[cosu-mentre, f-[sn-[art-i, n-beni, sa-[neg-non, ag-durevoli], fp-[par-'(', sn-[num-min1%], par-')']], ibar-[vt-accusano], cong-ancora, compt-[sn-[art-un, sa-[ag-evidente], n-ritardo]]]], punto-.] 
+""" + +CON_PERCENT_LEAVES = [ + ['il', 'tesoro', 'mette', 'sul', 'mercato', 'il', '51', '%%', 'a', '2', 'lire', 'per', 'azione', '.'], + ['le', 'novità', ':', 'voto', 'di', 'lista', 'e', 'tetto', 'agli', 'acquisti', 'limitato', 'allo', '0,5', '%%', '.'], + ['l', 'aumento', 'mensile', 'di', 'aprile', 'è', 'stato', "dell'", '1,3', '%%', 'contro', 'lo', '0/7,0/8', '%%', 'degli', 'ultimi', 'due', 'mesi'], + # the plus and min look bad, but they get cleaned up when merging with the UD version of the dataset + ['La', 'crescente', 'ripresa', 'dei', 'beni', 'di', 'consumo', 'deriva', 'esclusivamente', 'dal', 'miglioramento', 'di', 'cui', 'hanno', 'beneficiato', 'beni', 'durevoli', '(', 'plus4,5', '%%', ')', 'e', 'semidurevoli', '(', 'plus1,5', '%%', ')', ',', 'mentre', 'i', 'beni', 'non', 'durevoli', '(', 'min1', '%%', ')', 'accusano', 'ancora', 'un', 'evidente', 'ritardo', '.'] +] + +def test_read_percent(): + con_sentences = convert_it_vit.read_constituency_sentences(io.StringIO(CON_PERCENT_SAMPLE)) + assert len(con_sentences) == len(CON_PERCENT_LEAVES) + for (_, raw_tree), expected_leaves in zip(con_sentences, CON_PERCENT_LEAVES): + tree = convert_it_vit.raw_tree(raw_tree) + words = tree.leaf_labels() + if expected_leaves is None: + print(words) + else: + assert words == expected_leaves diff --git a/stanza/stanza/tests/constituency/test_convert_starlang.py b/stanza/stanza/tests/constituency/test_convert_starlang.py new file mode 100644 index 0000000000000000000000000000000000000000..183b0f1608e2357764093fe18cb4d231636914bd --- /dev/null +++ b/stanza/stanza/tests/constituency/test_convert_starlang.py @@ -0,0 +1,37 @@ +""" +Test a couple different classes of trees to check the output of the Starlang conversion +""" + +import os +import tempfile + +import pytest + +from stanza.utils.datasets.constituency import convert_starlang + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +TREE="( (S (NP (NP 
{morphologicalAnalysis=bayan+NOUN+A3SG+PNON+NOM}{metaMorphemes=bayan}{turkish=Bayan}{english=Ms.}{semantics=TUR10-0396530}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580}{englishSemantics=ENG31-06352895-n}) (NP {morphologicalAnalysis=haag+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=haag}{turkish=Haag}{english=Haag}{semantics=TUR10-0000000}{namedEntity=PERSON}{propBank=ARG0$TUR10-0148580})) (VP (NP {morphologicalAnalysis=elianti+NOUN+PROP+A3SG+PNON+NOM}{metaMorphemes=elianti}{turkish=Elianti}{english=Elianti}{semantics=TUR10-0000000}{namedEntity=NONE}{propBank=ARG1$TUR10-0148580}) (VP {morphologicalAnalysis=çal+VERB+POS+AOR+A3SG}{metaMorphemes=çal+Ar}{turkish=çalar}{english=plays}{semantics=TUR10-0148580}{namedEntity=NONE}{propBank=PREDICATE$TUR10-0148580}{englishSemantics=ENG31-01730049-v})) (. {morphologicalAnalysis=.+PUNC}{metaMorphemes=.}{metaMorphemesMoved=.}{turkish=.}{english=.}{semantics=TUR10-1081860}{namedEntity=NONE}{propBank=NONE})) )" + +def test_read_tree(): + """ + Test a basic tree read + """ + tree = convert_starlang.read_tree(TREE) + assert "(ROOT (S (NP (NP Bayan) (NP Haag)) (VP (NP Elianti) (VP çalar)) (. 
.)))" == str(tree) + +def test_missing_word(): + """ + Test that an error is thrown if the word is missing + """ + tree_text = TREE.replace("turkish=", "foo=") + with pytest.raises(ValueError): + tree = convert_starlang.read_tree(tree_text) + +def test_bad_label(): + """ + Test that an unexpected label results in an error + """ + tree_text = TREE.replace("(S", "(s") + with pytest.raises(ValueError): + tree = convert_starlang.read_tree(tree_text) diff --git a/stanza/stanza/tests/constituency/test_in_order_oracle.py b/stanza/stanza/tests/constituency/test_in_order_oracle.py new file mode 100644 index 0000000000000000000000000000000000000000..e85b988820d41453d52cb1424f3f917f81887291 --- /dev/null +++ b/stanza/stanza/tests/constituency/test_in_order_oracle.py @@ -0,0 +1,522 @@ +import itertools +import pytest + +from stanza.models.constituency import parse_transitions +from stanza.models.constituency import tree_reader +from stanza.models.constituency.base_model import SimpleModel +from stanza.models.constituency.in_order_oracle import * +from stanza.models.constituency.parse_transitions import CloseConstituent, OpenConstituent, Shift, TransitionScheme +from stanza.models.constituency.transition_sequence import build_treebank + +from stanza.tests import * +from stanza.tests.constituency.test_transition_sequence import reconstruct_tree + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +# A sample tree from PTB with a single unary transition (at a location other than root) +SINGLE_UNARY_TREE = """ +( (S + (NP-SBJ-1 (DT A) (NN record) (NN date) ) + (VP (VBZ has) (RB n't) + (VP (VBN been) + (VP (VBN set) + (NP (-NONE- *-1) )))) + (. .) 
)) +""" + +# [Shift, OpenConstituent(('NP-SBJ-1',)), Shift, Shift, CloseConstituent, OpenConstituent(('S',)), Shift, OpenConstituent(('VP',)), Shift, Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)), CloseConstituent, CloseConstituent, CloseConstituent, CloseConstituent, Shift, CloseConstituent, OpenConstituent(('ROOT',)), CloseConstituent] + +# A sample tree from PTB with a double unary transition (at a location other than root) +DOUBLE_UNARY_TREE = """ +( (S + (NP-SBJ + (NP (RB Not) (PDT all) (DT those) ) + (SBAR + (WHNP-3 (WP who) ) + (S + (NP-SBJ (-NONE- *T*-3) ) + (VP (VBD wrote) )))) + (VP (VBP oppose) + (NP (DT the) (NNS changes) )) + (. .) )) +""" + +# A sample tree from PTB with a triple unary transition (at a location other than root) +# The triple unary is at the START of the next bracket, which affects how the +# dynamic oracle repairs the transition sequence +TRIPLE_UNARY_START_TREE = """ +( (S + (PRN + (S + (NP-SBJ (-NONE- *) ) + (VP (VB See) ))) + (, ,) + (NP-SBJ + (NP (DT the) (JJ other) (NN rule) ) + (PP (IN of) + (NP (NN thumb) )) + (PP (IN about) + (NP (NN ballooning) ))))) +""" + +# A sample tree from PTB with a triple unary transition (at a location other than root) +# The triple unary is at the END of the next bracket, which affects how the +# dynamic oracle repairs the transition sequence +TRIPLE_UNARY_END_TREE = """ +( (S + (NP (NNS optimists) ) + (VP (VBP expect) + (S + (NP-SBJ-4 (NNP Hong) (NNP Kong) ) + (VP (TO to) + (VP (VB hum) + (ADVP-CLR (RB along) ) + (SBAR-MNR (RB as) + (S + (NP-SBJ (-NONE- *-4) ) + (VP (-NONE- *?*) + (ADVP-TMP (IN before) )))))))))) +""" + +TREES = [SINGLE_UNARY_TREE, DOUBLE_UNARY_TREE, TRIPLE_UNARY_START_TREE, TRIPLE_UNARY_END_TREE] +TREEBANK = "\n".join(TREES) + +NOUN_PHRASE_TREE = """ +( (NP + (NP (NNP Chicago) (POS 's)) + (NNP Goodman) + (NNP Theatre))) +""" + +WIDE_NP_TREE = """ +( (S + (NP-SBJ (DT These) (NNS studies)) + (VP (VBP demonstrate) + (SBAR (IN that) 
+ (S + (NP-SBJ (NNS mice)) + (VP (VBP are) + (NP-PRD + (NP (DT a) + (ADJP (JJ practical) + (CC and) + (JJ powerful)) + (JJ experimental) (NN system)) + (SBAR + (WHADVP-2 (-NONE- *0*)) + (S + (NP-SBJ (-NONE- *PRO*)) + (VP (TO to) + (VP (VB study) + (NP (DT the) (NN genetics))))))))))))) +""" + +WIDE_TREES = [NOUN_PHRASE_TREE, WIDE_NP_TREE] +WIDE_TREEBANK = "\n".join(WIDE_TREES) + +ROOT_LABELS = ["ROOT"] + +def get_repairs(gold_sequence, wrong_transition, repair_fn): + """ + Use the repair function and the wrong transition to iterate over the gold sequence + + Returns a list of possible repairs, one for each position in the sequence + Repairs are tuples, (idx, seq) + """ + repairs = [(idx, repair_fn(gold_transition, wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None)) + for idx, gold_transition in enumerate(gold_sequence)] + repairs = [x for x in repairs if x[1] is not None] + return repairs + +@pytest.fixture(scope="module") +def unary_trees(): + trees = tree_reader.read_trees(TREEBANK) + trees = [t.prune_none().simplify_labels() for t in trees] + assert len(trees) == len(TREES) + + return trees + +@pytest.fixture(scope="module") +def gold_sequences(unary_trees): + gold_sequences = build_treebank(unary_trees, TransitionScheme.IN_ORDER) + return gold_sequences + +@pytest.fixture(scope="module") +def wide_trees(): + trees = tree_reader.read_trees(WIDE_TREEBANK) + trees = [t.prune_none().simplify_labels() for t in trees] + assert len(trees) == len(WIDE_TREES) + + return trees + +def test_wrong_open_root(gold_sequences): + """ + Test the results of the dynamic oracle on a few trees if the ROOT is mishandled. 
+ """ + wrong_transition = OpenConstituent("S") + gold_transition = OpenConstituent("ROOT") + close_transition = CloseConstituent() + + for gold_sequence in gold_sequences: + # each of the sequences should be ended with ROOT, Close + assert gold_sequence[-2] == gold_transition + + repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_root_error) + # there is only spot in the sequence with a ROOT, so there should + # be exactly one location which affords a S/ROOT replacement + assert len(repairs) == 1 + repair = repairs[0] + + # the repair should occur at the -2 position, which is where ROOT is + assert repair[0] == len(gold_sequence) - 2 + # and the resulting list should have the wrong transition followed by a Close + # to give the model another chance to close the tree + expected = gold_sequence[:-2] + [wrong_transition, close_transition] + gold_sequence[-2:] + assert repair[1] == expected + +def test_missed_unary(gold_sequences): + """ + Test the repairs of an open/open error if it is effectively a skipped unary transition + """ + wrong_transition = OpenConstituent("S") + + repairs = get_repairs(gold_sequences[0], wrong_transition, fix_wrong_open_unary_chain) + assert len(repairs) == 0 + + # here we are simulating picking NT-S instead of NT-VP + # the DOUBLE_UNARY tree has one location where this is relevant, index 11 + repairs = get_repairs(gold_sequences[1], wrong_transition, fix_wrong_open_unary_chain) + assert len(repairs) == 1 + assert repairs[0][0] == 11 + assert repairs[0][1] == gold_sequences[1][:11] + gold_sequences[1][13:] + + # the TRIPLE_UNARY_START tree has two locations where this is relevant + # at index 1, the pattern goes (S (VP ...)) + # so choosing S instead of VP means you can skip the VP and only miss that one bracket + # at index 5, the pattern goes (S (PRN (S (VP ...))) (...)) + # note that this is capturing a unary transition into a larger constituent + # skipping the PRN is satisfactory + repairs = 
get_repairs(gold_sequences[2], wrong_transition, fix_wrong_open_unary_chain) + assert len(repairs) == 2 + assert repairs[0][0] == 1 + assert repairs[0][1] == gold_sequences[2][:1] + gold_sequences[2][3:] + assert repairs[1][0] == 5 + assert repairs[1][1] == gold_sequences[2][:5] + gold_sequences[2][7:] + + # The TRIPLE_UNARY_END tree has 2 sections of tree for a total of 3 locations + # where the repair might happen + # Surprisingly the unary transition at the very start can only be + # repaired by skipping it and using the outer S transition instead + # The second repair overall (first repair in the second location) + # should have a double skip to reach the S node + repairs = get_repairs(gold_sequences[3], wrong_transition, fix_wrong_open_unary_chain) + assert len(repairs) == 3 + assert repairs[0][0] == 1 + assert repairs[0][1] == gold_sequences[3][:1] + gold_sequences[3][3:] + assert repairs[1][0] == 21 + assert repairs[1][1] == gold_sequences[3][:21] + gold_sequences[3][25:] + assert repairs[2][0] == 23 + assert repairs[2][1] == gold_sequences[3][:23] + gold_sequences[3][25:] + + +def test_open_with_stuff(unary_trees, gold_sequences): + wrong_transition = OpenConstituent("S") + expected_trees = [ + "(ROOT (S (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))", + "(ROOT (S (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. 
.)))", + None, + "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))" + ] + + for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees): + repairs = get_repairs(gold_sequence, wrong_transition, fix_wrong_open_stuff_unary) + if expected is None: + assert len(repairs) == 0 + else: + assert len(repairs) == 1 + result = reconstruct_tree(tree, repairs[0][1]) + assert str(result) == expected + +def test_general_open(gold_sequences): + wrong_transition = OpenConstituent("SBARQ") + + for sequence in gold_sequences: + repairs = get_repairs(sequence, wrong_transition, fix_wrong_open_general) + assert len(repairs) == sum(isinstance(x, OpenConstituent) for x in sequence) - 1 + for repair in repairs: + assert len(repair[1]) == len(sequence) + assert repair[1][repair[0]] == wrong_transition + assert repair[1][:repair[0]] == sequence[:repair[0]] + assert repair[1][repair[0]+1:] == sequence[repair[0]+1:] + +def test_missed_unary(unary_trees, gold_sequences): + shift_transition = Shift() + close_transition = CloseConstituent() + + expected_close_results = [ + [(12, 2)], + [(11, 4), (13, 2)], + # (NP NN thumb) and (NP NN ballooning) are both candidates for this repair + [(18, 2), (24, 2)], + [(21, 6), (23, 4), (25, 2)], + ] + + expected_shift_results = [ + (), + (), + (), + # (ADVP-CLR (RB along)) is followed by a shift + [(16, 2)], + ] + + for tree, sequence, expected_close, expected_shift in zip(unary_trees, gold_sequences, expected_close_results, expected_shift_results): + repairs = get_repairs(sequence, close_transition, fix_missed_unary) + assert len(repairs) == len(expected_close) + for repair, (expected_idx, expected_len) in zip(repairs, expected_close): + assert repair[0] == expected_idx + assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:] + + repairs = get_repairs(sequence, shift_transition, 
fix_missed_unary) + assert len(repairs) == len(expected_shift) + for repair, (expected_idx, expected_len) in zip(repairs, expected_shift): + assert repair[0] == expected_idx + assert repair[1] == sequence[:expected_idx] + sequence[expected_idx+expected_len:] + +def test_open_shift(unary_trees, gold_sequences): + shift_transition = Shift() + + expected_repairs = [ + [(7, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .)))"), + (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VBN been) (VP (VBN set))) (. .)))")], + [(7, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WP who) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"), + (9, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"), + (19, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose) (NP (DT the) (NNS changes)) (. .)))"), + (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (DT the) (NNS changes)) (. 
.)))")], + [(14, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"), + (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"), + (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about) (NP (NN ballooning)))))")], + [(5, "(ROOT (S (NP (NNS optimists)) (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), + (10, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), + (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), + (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), + (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (RB as) (S (VP (ADVP (IN before))))))))))")] + ] + + for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs): + repairs = get_repairs(sequence, shift_transition, fix_open_shift) + assert len(repairs) == len(expected) + for repair, (idx, expected_tree) in zip(repairs, expected): + assert repair[0] == idx + result_tree = reconstruct_tree(tree, repair[1]) + assert str(result_tree) == expected_tree + + +def test_open_close(unary_trees, gold_sequences): + close_transition = CloseConstituent() + + expected_repairs = [ + [(7, "(ROOT (S (S (NP (DT A) (NN record) (NN date)) (VBZ has)) (RB n't) (VP (VBN been) (VP (VBN set))) (. 
.)))"), + (10, "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VP (VBZ has) (RB n't) (VBN been)) (VP (VBN set))) (. .)))")], + # missed the WHNP. The surrounding SBAR cannot be created, either + [(7, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"), + # missed the SBAR + (9, "(ROOT (S (NP (NP (NP (RB Not) (PDT all) (DT those)) (WHNP (WP who))) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"), + # missed the VP around "oppose the changes" + (19, "(ROOT (S (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VBP oppose)) (NP (DT the) (NNS changes)) (. .)))"), + # missed the NP in "the changes", looks pretty bad tbh + (21, "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VP (VBP oppose) (DT the)) (NNS changes)) (. .)))")], + [(14, "(ROOT (S (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule))) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))"), + (16, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (IN of)) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning))))))"), + (22, "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (IN about)) (NP (NN ballooning)))))")], + [(5, "(ROOT (S (S (NP (NNS optimists)) (VBP expect)) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), + (10, "(ROOT (S (NP (NNS optimists)) (VP (VP (VBP expect) (NP (NNP Hong) (NNP Kong))) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), + (12, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (S (NP (NNP Hong) (NNP Kong)) (TO to)) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), + (14, "(ROOT (S (NP (NNS optimists)) (VP (VBP 
expect) (S (NP (NNP Hong) (NNP Kong)) (VP (VP (TO to) (VB hum)) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))"), + (19, "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (ADVP (RB along)) (RB as)) (S (VP (ADVP (IN before))))))))))")] + ] + + for tree, sequence, expected in zip(unary_trees, gold_sequences, expected_repairs): + repairs = get_repairs(sequence, close_transition, fix_open_close) + + assert len(repairs) == len(expected) + for repair, (idx, expected_tree) in zip(repairs, expected): + assert repair[0] == idx + result_tree = reconstruct_tree(tree, repair[1]) + assert str(result_tree) == expected_tree + +def test_shift_close(unary_trees, gold_sequences): + """ + Test the fix for a shift -> close + + These errors can occur pretty much everywhere, and the fix is quite simple, + so we only test a few cases. + """ + + close_transition = CloseConstituent() + + expected_tree = "(ROOT (S (NP (NP (DT A)) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)))" + + repairs = get_repairs(gold_sequences[0], close_transition, fix_shift_close) + assert len(repairs) == 7 + result_tree = reconstruct_tree(unary_trees[0], repairs[0][1]) + assert str(result_tree) == expected_tree + + repairs = get_repairs(gold_sequences[1], close_transition, fix_shift_close) + assert len(repairs) == 8 + + repairs = get_repairs(gold_sequences[2], close_transition, fix_shift_close) + assert len(repairs) == 8 + + repairs = get_repairs(gold_sequences[3], close_transition, fix_shift_close) + assert len(repairs) == 9 + for rep in repairs: + if rep[0] == 16: + # This one is special because it occurs as part of a unary + # in other words, it should go unary, shift + # and instead we are making it close where the unary should be + # ... 
the unary would create "(ADVP (RB along))" + expected_tree = "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VP (VB hum) (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before)))))))))))" + result_tree = reconstruct_tree(unary_trees[3], rep[1]) + assert str(result_tree) == expected_tree + break + else: + raise AssertionError("Did not find an expected repair location") + +def test_close_open_shift_nested(unary_trees, gold_sequences): + shift_transition = Shift() + + expected_trees = [{}, + {4: "(ROOT (S (NP (RB Not) (PDT all) (DT those) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"}, + {4: "(ROOT (S (VP (VB See)) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))", + 13: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (DT the) (JJ other) (NN rule) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"}, + {}] + + for tree, gold_sequence, expected in zip(unary_trees, gold_sequences, expected_trees): + repairs = get_repairs(gold_sequence, shift_transition, fix_close_open_shift_nested) + assert len(repairs) == len(expected) + if len(expected) >= 1: + for repair in repairs: + assert repair[0] in expected.keys() + result_tree = reconstruct_tree(tree, repair[1]) + assert str(result_tree) == expected[repair[0]] + +def check_repairs(trees, gold_sequences, expected_trees, transition, repair_fn): + for tree_idx, (gold_tree, gold_sequence, expected) in enumerate(zip(trees, gold_sequences, expected_trees)): + repairs = get_repairs(gold_sequence, transition, repair_fn) + if expected is not None: + assert len(repairs) == len(expected) + for repair in repairs: + assert repair[0] in expected + result_tree = reconstruct_tree(gold_tree, repair[1]) + assert str(result_tree) == expected[repair[0]] + else: + print("---------------------") + print("{:P}".format(gold_tree)) + print(gold_sequence) + #print(repairs) + for 
repair in repairs: + print("---------------------") + print(gold_sequence) + print(repair[1]) + result_tree = reconstruct_tree(gold_tree, repair[1]) + print("{:P}".format(gold_tree)) + print("{:P}".format(result_tree)) + print(tree_idx) + print(repair[0]) + print(result_tree) + +def test_close_open_shift_unambiguous(unary_trees, gold_sequences): + shift_transition = Shift() + + expected_trees = [{}, + {8: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who) (S (VP (VBD wrote)))))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .)))"}, + {}, + {2: "(ROOT (S (NP (NNS optimists) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))", + 9: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong) (VP (TO to) (VP (VB hum) (ADVP (RB along)) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"}] + check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_unambiguous_bracket) + +def test_close_open_shift_ambiguous_early(unary_trees, gold_sequences): + shift_transition = Shift() + + expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))))) (. .)))"}, + {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes)))) (. 
.)))"}, + {2: "(ROOT (S (PRN (S (VP (VB See) (, ,)))) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))", + 6: "(ROOT (S (PRN (S (VP (VB See))) (, ,)) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning))))))"}, + {}] + check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_early) + +def test_close_open_shift_ambiguous_late(unary_trees, gold_sequences): + shift_transition = Shift() + + expected_trees = [{4: "(ROOT (S (NP (DT A) (NN record) (NN date) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .))))"}, + {16: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote)))) (VP (VBP oppose) (NP (DT the) (NNS changes))) (. .))))"}, + {2: "(ROOT (S (PRN (S (VP (VB See) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))))", + 6: "(ROOT (S (PRN (S (VP (VB See))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb))) (PP (IN about) (NP (NN ballooning)))))))"}, + {}] + check_repairs(unary_trees, gold_sequences, expected_trees, shift_transition, fix_close_open_shift_ambiguous_bracket_late) + + +def test_close_shift_shift(unary_trees, wide_trees): + """ + Test that close -> shift works when there is a single block shifted after + + Includes a test specifically that there is no oracle action when there are two blocks after the missed close + """ + shift_transition = Shift() + + expected_trees = [{15: "(ROOT (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set))) (. .))))"}, + {24: "(ROOT (S (NP (NP (RB Not) (PDT all) (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBD wrote))))) (VP (VBP oppose) (NP (DT the) (NNS changes)) (. 
.))))"}, + {20: "(ROOT (S (PRN (S (VP (VB See)))) (, ,) (NP (NP (DT the) (JJ other) (NN rule)) (PP (IN of) (NP (NN thumb)) (PP (IN about) (NP (NN ballooning)))))))"}, + {17: "(ROOT (S (NP (NNS optimists)) (VP (VBP expect) (S (NP (NNP Hong) (NNP Kong)) (VP (TO to) (VP (VB hum) (ADVP (RB along) (SBAR (RB as) (S (VP (ADVP (IN before))))))))))))"}, + {}, + {}] + + test_trees = unary_trees + wide_trees + gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER) + + check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_unambiguous) + + +def test_close_shift_shift_early(unary_trees, wide_trees): + """ + Test that close -> shift works when there are multiple blocks shifted after + + Also checks that the single block case is skipped, so as to keep them separate when testing + + A tree with the expected property was specifically added for this test + """ + shift_transition = Shift() + + test_trees = unary_trees + wide_trees + gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER) + + expected_trees = [{}, + {}, + {}, + {}, + {}, + {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental)) (NN system)) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}] + + check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_early) + +def test_close_shift_shift_late(unary_trees, wide_trees): + """ + Test that close -> shift works when there are multiple blocks shifted after + + Also checks that the single block case is skipped, so as to keep them separate when testing + + A tree with the expected property was specifically added for this test + """ + shift_transition = Shift() + + test_trees = unary_trees + wide_trees + gold_sequences = build_treebank(test_trees, TransitionScheme.IN_ORDER) + + expected_trees 
= [{}, + {}, + {}, + {}, + {}, + {21: "(ROOT (S (NP (DT These) (NNS studies)) (VP (VBP demonstrate) (SBAR (IN that) (S (NP (NNS mice)) (VP (VBP are) (NP (NP (DT a) (ADJP (JJ practical) (CC and) (JJ powerful) (JJ experimental) (NN system))) (SBAR (S (VP (TO to) (VP (VB study) (NP (DT the) (NN genetics)))))))))))))"}] + + check_repairs(test_trees, gold_sequences, expected_trees, shift_transition, fix_close_shift_shift_ambiguous_late) diff --git a/stanza/stanza/tests/constituency/test_lstm_model.py b/stanza/stanza/tests/constituency/test_lstm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..421872cc1d4ddcea89ad8bf869f5ebcceef7ee90 --- /dev/null +++ b/stanza/stanza/tests/constituency/test_lstm_model.py @@ -0,0 +1,552 @@ +import os + +import pytest +import torch + +from stanza.models.common import pretrain +from stanza.models.common.utils import set_random_seed +from stanza.models.constituency import parse_transitions +from stanza.tests import * +from stanza.tests.constituency import test_parse_transitions +from stanza.tests.constituency.test_trainer import build_trainer + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +@pytest.fixture(scope="module") +def pretrain_file(): + return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' + +def build_model(pretrain_file, *args): + # By default, we turn off multistage, since that can turn off various other structures in the initial training + args = ['--no_multistage', '--pattn_num_layers', '4', '--pattn_d_model', '256', '--hidden_size', '128', '--use_lattn'] + list(args) + trainer = build_trainer(pretrain_file, *args) + return trainer.model + +@pytest.fixture(scope="module") +def unary_model(pretrain_file): + return build_model(pretrain_file, "--transition_scheme", "TOP_DOWN_UNARY") + +def test_initial_state(unary_model): + test_parse_transitions.test_initial_state(unary_model) + +def test_shift(pretrain_file): + # TODO: might be good to include some tests specifically for shift + # in the context 
of a model with unaries + model = build_model(pretrain_file) + test_parse_transitions.test_shift(model) + +def test_unary(unary_model): + test_parse_transitions.test_unary(unary_model) + +def test_unary_requires_root(unary_model): + test_parse_transitions.test_unary_requires_root(unary_model) + +def test_open(unary_model): + test_parse_transitions.test_open(unary_model) + +def test_compound_open(pretrain_file): + model = build_model(pretrain_file, '--transition_scheme', "TOP_DOWN_COMPOUND") + test_parse_transitions.test_compound_open(model) + +def test_in_order_open(pretrain_file): + model = build_model(pretrain_file, '--transition_scheme', "IN_ORDER") + test_parse_transitions.test_in_order_open(model) + +def test_close(unary_model): + test_parse_transitions.test_close(unary_model) + +def run_forward_checks(model, num_states=1): + """ + Run a couple small transitions and a forward pass on the given model + + Results are not checked in any way. This function allows for + testing that building models with various options results in a + functional model. 
+ """ + states = test_parse_transitions.build_initial_state(model, num_states) + model(states) + + shift = parse_transitions.Shift() + shifts = [shift for _ in range(num_states)] + states = model.bulk_apply(states, shifts) + model(states) + + open_transition = parse_transitions.OpenConstituent("NP") + open_transitions = [open_transition for _ in range(num_states)] + assert open_transition.is_legal(states[0], model) + states = model.bulk_apply(states, open_transitions) + assert states[0].num_opens == 1 + model(states) + + states = model.bulk_apply(states, shifts) + model(states) + states = model.bulk_apply(states, shifts) + model(states) + assert states[0].num_opens == 1 + # now should have "mox", "opal" on the constituents + + close_transition = parse_transitions.CloseConstituent() + close_transitions = [close_transition for _ in range(num_states)] + assert close_transition.is_legal(states[0], model) + states = model.bulk_apply(states, close_transitions) + assert states[0].num_opens == 0 + + model(states) + +def test_unary_forward(unary_model): + """ + Checks that the forward pass doesn't crash when run after various operations + + Doesn't check the forward pass for making reasonable answers + """ + run_forward_checks(unary_model) + +def test_lstm_forward(pretrain_file): + model = build_model(pretrain_file) + run_forward_checks(model, num_states=1) + run_forward_checks(model, num_states=2) + +def test_lstm_layers(pretrain_file): + model = build_model(pretrain_file, '--num_lstm_layers', '1') + run_forward_checks(model) + model = build_model(pretrain_file, '--num_lstm_layers', '2') + run_forward_checks(model) + model = build_model(pretrain_file, '--num_lstm_layers', '3') + run_forward_checks(model) + +def test_multiple_output_forward(pretrain_file): + """ + Test a couple different sizes of output layers + """ + model = build_model(pretrain_file, '--num_output_layers', '1', '--num_lstm_layers', '2') + run_forward_checks(model) + + model = build_model(pretrain_file, 
'--num_output_layers', '2', '--num_lstm_layers', '2') + run_forward_checks(model) + + model = build_model(pretrain_file, '--num_output_layers', '3', '--num_lstm_layers', '2') + run_forward_checks(model) + +def test_no_tag_embedding_forward(pretrain_file): + """ + Test that the model continues to work if the tag embedding is turned on or off + """ + model = build_model(pretrain_file, '--tag_embedding_dim', '20') + run_forward_checks(model) + + model = build_model(pretrain_file, '--tag_embedding_dim', '0') + run_forward_checks(model) + +def test_forward_combined_dummy(pretrain_file): + """ + Tests combined dummy and open node embeddings + """ + model = build_model(pretrain_file, '--combined_dummy_embedding') + run_forward_checks(model) + + model = build_model(pretrain_file, '--no_combined_dummy_embedding') + run_forward_checks(model) + +def test_nonlinearity_init(pretrain_file): + """ + Tests that different initialization methods of the nonlinearities result in valid tensors + """ + model = build_model(pretrain_file, '--nonlinearity', 'relu') + run_forward_checks(model) + + model = build_model(pretrain_file, '--nonlinearity', 'tanh') + run_forward_checks(model) + + model = build_model(pretrain_file, '--nonlinearity', 'silu') + run_forward_checks(model) + +def test_forward_charlm(pretrain_file): + """ + Tests loading and running a charlm + + Note that this doesn't test the results of the charlm itself, + just that the model is shaped correctly + """ + forward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "forward_charlm", "1billion.pt") + backward_charlm_path = os.path.join(TEST_MODELS_DIR, "en", "backward_charlm", "1billion.pt") + assert os.path.exists(forward_charlm_path), "Need to download en test models (or update path to the forward charlm)" + assert os.path.exists(backward_charlm_path), "Need to download en test models (or update path to the backward charlm)" + + model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, 
'--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'none') + run_forward_checks(model) + + model = build_model(pretrain_file, '--charlm_forward_file', forward_charlm_path, '--charlm_backward_file', backward_charlm_path, '--sentence_boundary_vectors', 'words') + run_forward_checks(model) + +def test_forward_bert(pretrain_file): + """ + Test on a tiny Bert, which hopefully does not take up too much disk space or memory + """ + bert_model = "hf-internal-testing/tiny-bert" + + model = build_model(pretrain_file, '--bert_model', bert_model) + run_forward_checks(model) + + +def test_forward_xlnet(pretrain_file): + """ + Test on a tiny xlnet, which hopefully does not take up too much disk space or memory + """ + bert_model = "hf-internal-testing/tiny-random-xlnet" + + model = build_model(pretrain_file, '--bert_model', bert_model) + run_forward_checks(model) + + +def test_forward_sentence_boundaries(pretrain_file): + """ + Test start & stop boundary vectors + """ + model = build_model(pretrain_file, '--sentence_boundary_vectors', 'everything') + run_forward_checks(model) + + model = build_model(pretrain_file, '--sentence_boundary_vectors', 'words') + run_forward_checks(model) + + model = build_model(pretrain_file, '--sentence_boundary_vectors', 'none') + run_forward_checks(model) + +def test_forward_constituency_composition(pretrain_file): + """ + Test different constituency composition functions + """ + model = build_model(pretrain_file, '--constituency_composition', 'bilstm') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'max') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'key') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'untied_key') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 
'untied_max') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'bilstm_max') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'tree_lstm') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'tree_lstm_cx') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'bigram') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'attn') + run_forward_checks(model, num_states=2) + +def test_forward_key_position(pretrain_file): + """ + Test KEY and UNTIED_KEY either with or without reduce_position + """ + model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '0') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'untied_key', '--reduce_position', '32') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '0') + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'key', '--reduce_position', '32') + run_forward_checks(model, num_states=2) + + +def test_forward_attn_hidden_size(pretrain_file): + """ + Test that when attn is used with hidden sizes not evenly divisible by reduce_heads, the model reconfigures the hidden_size + """ + model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129') + assert model.hidden_size >= 129 + assert model.hidden_size % model.reduce_heads == 0 + run_forward_checks(model, num_states=2) + + model = build_model(pretrain_file, '--constituency_composition', 'attn', '--hidden_size', '129', '--reduce_heads', '10') + assert model.hidden_size == 130 + assert model.reduce_heads == 
10 + +def test_forward_partitioned_attention(pretrain_file): + """ + Test with & without partitioned attention layers + """ + model = build_model(pretrain_file, '--pattn_num_heads', '8', '--pattn_num_layers', '8') + run_forward_checks(model) + + model = build_model(pretrain_file, '--pattn_num_heads', '0', '--pattn_num_layers', '0') + run_forward_checks(model) + +def test_forward_labeled_attention(pretrain_file): + """ + Test with & without labeled attention layers + """ + model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16') + run_forward_checks(model) + + model = build_model(pretrain_file, '--lattn_d_proj', '0', '--lattn_d_l', '0') + run_forward_checks(model) + + model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_combined_input') + run_forward_checks(model) + +def test_lattn_partitioned(pretrain_file): + model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_partitioned') + run_forward_checks(model) + + model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned') + run_forward_checks(model) + + +def test_lattn_projection(pretrain_file): + """ + Test with & without labeled attention layers + """ + with pytest.raises(ValueError): + # this is too small + model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '256', '--lattn_partitioned') + run_forward_checks(model) + + model = build_model(pretrain_file, '--pattn_d_model', '1024', '--lattn_d_proj', '64', '--lattn_d_l', '16', '--no_lattn_partitioned', '--lattn_d_input_proj', '256') + run_forward_checks(model) + + model = build_model(pretrain_file, '--lattn_d_proj', '64', '--lattn_d_l', '16', '--lattn_d_input_proj', '768') + run_forward_checks(model) + + # check that it works if we turn off the projection, + # in case having it on beccomes the default + model = build_model(pretrain_file, '--lattn_d_proj', 
'64', '--lattn_d_l', '16', '--lattn_d_input_proj', '0') + run_forward_checks(model) + +def test_forward_timing_choices(pretrain_file): + """ + Test different timing / position encodings + """ + model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', 'sin') + run_forward_checks(model) + + model = build_model(pretrain_file, '--pattn_num_heads', '4', '--pattn_num_layers', '4', '--pattn_timing', 'learned') + run_forward_checks(model) + +def test_transition_stack(pretrain_file): + """ + Test different transition stack types: lstm & attention + """ + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--transition_stack', 'attn', '--transition_heads', '1') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--transition_stack', 'attn', '--transition_heads', '4') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--transition_stack', 'lstm') + run_forward_checks(model) + +def test_constituent_stack(pretrain_file): + """ + Test different constituent stack types: lstm & attention + """ + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--constituent_stack', 'attn', '--constituent_heads', '1') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--constituent_stack', 'attn', '--constituent_heads', '4') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--constituent_stack', 'lstm') + run_forward_checks(model) + +def test_different_transition_sizes(pretrain_file): + """ + If the transition hidden size and embedding size are different, the model should still work + """ + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + 
'--transition_embedding_dim', '10', '--transition_hidden_size', '10', + '--sentence_boundary_vectors', 'everything') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--transition_embedding_dim', '20', '--transition_hidden_size', '10', + '--sentence_boundary_vectors', 'everything') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--transition_embedding_dim', '10', '--transition_hidden_size', '20', + '--sentence_boundary_vectors', 'everything') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--transition_embedding_dim', '10', '--transition_hidden_size', '10', + '--sentence_boundary_vectors', 'none') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--transition_embedding_dim', '20', '--transition_hidden_size', '10', + '--sentence_boundary_vectors', 'none') + run_forward_checks(model) + + model = build_model(pretrain_file, + '--pattn_num_layers', '0', '--lattn_d_proj', '0', + '--transition_embedding_dim', '10', '--transition_hidden_size', '20', + '--sentence_boundary_vectors', 'none') + run_forward_checks(model) + + +def test_lstm_tree_forward(pretrain_file): + """ + Test the LSTM_TREE forward pass + """ + model = build_model(pretrain_file, '--num_tree_lstm_layers', '1', '--constituency_composition', 'tree_lstm') + run_forward_checks(model) + model = build_model(pretrain_file, '--num_tree_lstm_layers', '2', '--constituency_composition', 'tree_lstm') + run_forward_checks(model) + model = build_model(pretrain_file, '--num_tree_lstm_layers', '3', '--constituency_composition', 'tree_lstm') + run_forward_checks(model) + +def test_lstm_tree_cx_forward(pretrain_file): + """ + Test the LSTM_TREE_CX forward pass + """ + model = build_model(pretrain_file, '--num_tree_lstm_layers', '1', 
'--constituency_composition', 'tree_lstm_cx') + run_forward_checks(model) + model = build_model(pretrain_file, '--num_tree_lstm_layers', '2', '--constituency_composition', 'tree_lstm_cx') + run_forward_checks(model) + model = build_model(pretrain_file, '--num_tree_lstm_layers', '3', '--constituency_composition', 'tree_lstm_cx') + run_forward_checks(model) + +def test_maxout(pretrain_file): + """ + Test with and without maxout layers for output + """ + model = build_model(pretrain_file, '--maxout_k', '0') + run_forward_checks(model) + # check the output size & implicitly check the type + # to check for a particularly silly bug + assert model.output_layers[-1].weight.shape[0] == len(model.transitions) + + model = build_model(pretrain_file, '--maxout_k', '2') + run_forward_checks(model) + assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 2 + + model = build_model(pretrain_file, '--maxout_k', '3') + run_forward_checks(model) + assert model.output_layers[-1].linear.weight.shape[0] == len(model.transitions) * 3 + +def check_structure_test(pretrain_file, args1, args2): + """ + Test that the "copy" method copies the parameters from one model to another + + Also check that the copied models produce the same results + """ + set_random_seed(1000) + other = build_model(pretrain_file, *args1) + other.eval() + + set_random_seed(1001) + model = build_model(pretrain_file, *args2) + model.eval() + + assert not torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight) + assert not torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight) + + model.copy_with_new_structure(other) + + assert torch.allclose(model.delta_embedding.weight, other.delta_embedding.weight) + assert torch.allclose(model.output_layers[0].weight, other.output_layers[0].weight) + # the norms will be the same, as the non-zero values are all the same + assert torch.allclose(torch.linalg.norm(model.word_lstm.weight_ih_l0), 
torch.linalg.norm(other.word_lstm.weight_ih_l0)) + + # now, check that applying one transition to an initial state + # results in the same values in the output states for both models + # as the pattn layer inputs are 0, the output values should be equal + shift = [parse_transitions.Shift()] + model_states = test_parse_transitions.build_initial_state(model, 1) + model_states = model.bulk_apply(model_states, shift) + + other_states = test_parse_transitions.build_initial_state(other, 1) + other_states = other.bulk_apply(other_states, shift) + + for i, j in zip(other_states[0].word_queue, model_states[0].word_queue): + assert torch.allclose(i.hx, j.hx, atol=1e-07) + for i, j in zip(other_states[0].transitions, model_states[0].transitions): + assert torch.allclose(i.lstm_hx, j.lstm_hx) + assert torch.allclose(i.lstm_cx, j.lstm_cx) + for i, j in zip(other_states[0].constituents, model_states[0].constituents): + assert (i.value is None) == (j.value is None) + if i.value is not None: + assert torch.allclose(i.value.tree_hx, j.value.tree_hx, atol=1e-07) + assert torch.allclose(i.lstm_hx, j.lstm_hx) + assert torch.allclose(i.lstm_cx, j.lstm_cx) + +def test_copy_with_new_structure_same(pretrain_file): + """ + Test that copying the structure with no changes works as expected + """ + check_structure_test(pretrain_file, + ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'], + ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10']) + +def test_copy_with_new_structure_untied(pretrain_file): + """ + Test that copying the structure with no changes works as expected + """ + check_structure_test(pretrain_file, + ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--constituency_composition', 'MAX'], + ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', 
'--constituency_composition', 'UNTIED_MAX']) + +def test_copy_with_new_structure_pattn(pretrain_file): + check_structure_test(pretrain_file, + ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'], + ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2']) + +def test_copy_with_new_structure_both(pretrain_file): + check_structure_test(pretrain_file, + ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'], + ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2']) + +def test_copy_with_new_structure_lattn(pretrain_file): + check_structure_test(pretrain_file, + ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2'], + ['--pattn_num_layers', '1', '--lattn_d_proj', '32', '--hidden_size', '20', '--delta_embedding_dim', '10', '--pattn_d_model', '20', '--pattn_num_heads', '2']) + +def test_parse_tagged_words(pretrain_file): + """ + Small test which doesn't check results, just execution + """ + model = build_model(pretrain_file) + + sentence = [("I", "PRP"), ("am", "VBZ"), ("Luffa", "NNP")] + + # we don't expect a useful tree out of a random model + # so we don't check the result + # just check that it works without crashing + result = model.parse_tagged_words([sentence], 10) + assert len(result) == 1 + pts = [x for x in result[0].yield_preterminals()] + + for word, pt in zip(sentence, pts): + assert pt.children[0].label == word[0] + assert pt.label == word[1] diff --git a/stanza/stanza/tests/constituency/test_text_processing.py b/stanza/stanza/tests/constituency/test_text_processing.py new file mode 100644 index 
0000000000000000000000000000000000000000..98d6adbeffe5c719a77686a3923f3acb6070972d --- /dev/null +++ b/stanza/stanza/tests/constituency/test_text_processing.py @@ -0,0 +1,109 @@ +""" +Run through the various text processing methods for using the parser on text files / directories + +Uses a simple tree where the parser should always get it right, but things could potentially go wrong +""" + +import glob +import os +import pytest + +from stanza import Pipeline + +from stanza.models.constituency import text_processing +from stanza.models.constituency import tree_reader +from stanza.tests import TEST_MODELS_DIR + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +@pytest.fixture(scope="module") +def pipeline(): + return Pipeline(dir=TEST_MODELS_DIR, lang="en", processors="tokenize, pos, constituency", tokenize_pretokenized=True) + +def test_read_tokenized_file(tmp_path): + filename = str(tmp_path / "test_input.txt") + with open(filename, "w") as fout: + # test that the underscore token comes back with spaces + fout.write("This is a_small test\nLine two\n") + text, ids = text_processing.read_tokenized_file(filename) + assert text == [['This', 'is', 'a small', 'test'], ['Line', 'two']] + assert ids == [None, None] + +def test_parse_tokenized_sentences(pipeline): + con_processor = pipeline.processors["constituency"] + model = con_processor._model + args = model.args + + sentences = [["This", "is", "a", "test"]] + trees = text_processing.parse_tokenized_sentences(args, model, [pipeline], sentences) + predictions = [x.predictions for x in trees] + assert len(predictions) == 1 + scored_trees = predictions[0] + assert len(scored_trees) == 1 + result = "{}".format(scored_trees[0].tree) + expected = "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))" + assert result == expected + +def test_parse_text(tmp_path, pipeline): + con_processor = pipeline.processors["constituency"] + model = con_processor._model + args = model.args + + raw_file = str(tmp_path / 
"test_input.txt") + with open(raw_file, "w") as fout: + fout.write("This is a test\nThis is another test\n") + output_file = str(tmp_path / "test_output.txt") + text_processing.parse_text(args, model, [pipeline], tokenized_file=raw_file, predict_file=output_file) + + trees = tree_reader.read_treebank(output_file) + trees = ["{}".format(x) for x in trees] + expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))", + "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"] + assert trees == expected_trees + +def test_parse_dir(tmp_path, pipeline): + con_processor = pipeline.processors["constituency"] + model = con_processor._model + args = model.args + + raw_dir = str(tmp_path / "input") + os.makedirs(raw_dir) + raw_f1 = str(tmp_path / "input" / "f1.txt") + raw_f2 = str(tmp_path / "input" / "f2.txt") + output_dir = str(tmp_path / "output") + + with open(raw_f1, "w") as fout: + fout.write("This is a test") + with open(raw_f2, "w") as fout: + fout.write("This is another test") + + text_processing.parse_dir(args, model, [pipeline], raw_dir, output_dir) + output_files = sorted(glob.glob(os.path.join(output_dir, "*"))) + expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))", + "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"] + for output_file, expected_tree in zip(output_files, expected_trees): + trees = tree_reader.read_treebank(output_file) + assert len(trees) == 1 + assert "{}".format(trees[0]) == expected_tree + +def test_parse_text(tmp_path, pipeline): + con_processor = pipeline.processors["constituency"] + model = con_processor._model + args = dict(model.args) + + model_path = con_processor._config['model_path'] + + raw_file = str(tmp_path / "test_input.txt") + with open(raw_file, "w") as fout: + fout.write("This is a test\nThis is another test\n") + output_file = str(tmp_path / "test_output.txt") + + args['tokenized_file'] = raw_file + args['predict_file'] = output_file + + 
text_processing.load_model_parse_text(args, model_path, [pipeline]) + trees = tree_reader.read_treebank(output_file) + trees = ["{}".format(x) for x in trees] + expected_trees = ["(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT a) (NN test)))))", + "(ROOT (S (NP (DT This)) (VP (VBZ is) (NP (DT another) (NN test)))))"] + assert trees == expected_trees diff --git a/stanza/stanza/tests/constituency/test_top_down_oracle.py b/stanza/stanza/tests/constituency/test_top_down_oracle.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bf1864be9a6c93bba6579e3a27beacf3a2adfd --- /dev/null +++ b/stanza/stanza/tests/constituency/test_top_down_oracle.py @@ -0,0 +1,443 @@ +import pytest + +from stanza.models.constituency.base_model import SimpleModel +from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent, TransitionScheme +from stanza.models.constituency.top_down_oracle import * +from stanza.models.constituency.transition_sequence import build_sequence +from stanza.models.constituency.tree_reader import read_trees + +from stanza.tests.constituency.test_transition_sequence import reconstruct_tree + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +OPEN_SHIFT_EXAMPLE_TREE = """ +( (S + (NP (NNP Jennifer) (NNP Sh\'reyan)) + (VP (VBZ has) + (NP (RB nice) (NNS antennae))))) +""" + +OPEN_SHIFT_PROBLEM_TREE = """ +(ROOT (S (NP (NP (NP (DT The) (`` ``) (JJ Thin) (NNP Man) ('' '') (NN series)) (PP (IN of) (NP (NNS movies)))) (, ,) (CONJP (RB as) (RB well) (IN as)) (NP (JJ many) (NNS others)) (, ,)) (VP (VBD based) (NP (PRP$ their) (JJ entire) (JJ comedic) (NN appeal)) (PP (IN on) (NP (NP (DT the) (NN star) (NNS detectives) (POS ')) (JJ witty) (NNS quips) (CC and) (NNS puns))) (SBAR (IN as) (S (NP (NP (JJ other) (NNS characters)) (PP (IN in) (NP (DT the) (NNS movies)))) (VP (VBD were) (VP (VBN murdered)))))) (. 
.))) +""" + +ROOT_LABELS = ["ROOT"] + +def get_single_repair(gold_sequence, wrong_transition, repair_fn, idx, *args, **kwargs): + return repair_fn(gold_sequence[idx], wrong_transition, gold_sequence, idx, ROOT_LABELS, None, None, *args, **kwargs) + +def build_state(model, tree, num_transitions): + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + states = model.initial_state_from_gold_trees([tree], [transitions]) + for idx, t in enumerate(transitions[:num_transitions]): + assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, sequence) + states = model.bulk_apply(states, [t]) + state = states[0] + return state + +def test_fix_open_shift(): + trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + EXPECTED_FIX_EARLY = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + EXPECTED_FIX_LATE = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + + assert transitions == EXPECTED_ORIG + + new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 2) + assert new_transitions == EXPECTED_FIX_EARLY + + new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 8) + assert new_transitions == EXPECTED_FIX_LATE + +def 
test_fix_open_shift_observed_error(): + """ + Ran into an error on this tree, need to fix it + + The problem is the multiple Open in a row all need to be removed when a Shift happens + """ + trees = read_trees(OPEN_SHIFT_PROBLEM_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + new_transitions = get_single_repair(transitions, Shift(), fix_one_open_shift, 2) + assert new_transitions is None + + new_transitions = get_single_repair(transitions, Shift(), fix_multiple_open_shift, 2) + + # Can break the expected transitions down like this: + # [OpenConstituent(('ROOT',)), OpenConstituent(('S',)), + # all gone: OpenConstituent(('NP',)), OpenConstituent(('NP',)), OpenConstituent(('NP',)), + # Shift, Shift, Shift, Shift, Shift, Shift, + # gone: CloseConstituent, + # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), Shift, CloseConstituent, CloseConstituent, + # gone: CloseConstituent, + # Shift, OpenConstituent(('CONJP',)), Shift, Shift, Shift, CloseConstituent, OpenConstituent(('NP',)), Shift, Shift, CloseConstituent, Shift, + # gone: CloseConstituent, + # and then the rest: + # OpenConstituent(('VP',)), Shift, OpenConstituent(('NP',)), + # Shift, Shift, Shift, Shift, CloseConstituent, + # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), + # OpenConstituent(('NP',)), Shift, Shift, Shift, Shift, + # CloseConstituent, Shift, Shift, Shift, Shift, CloseConstituent, + # CloseConstituent, OpenConstituent(('SBAR',)), Shift, + # OpenConstituent(('S',)), OpenConstituent(('NP',)), + # OpenConstituent(('NP',)), Shift, Shift, CloseConstituent, + # OpenConstituent(('PP',)), Shift, OpenConstituent(('NP',)), + # Shift, Shift, CloseConstituent, CloseConstituent, + # CloseConstituent, OpenConstituent(('VP',)), Shift, + # OpenConstituent(('VP',)), Shift, CloseConstituent, + # CloseConstituent, CloseConstituent, CloseConstituent, + # CloseConstituent, Shift, CloseConstituent, 
CloseConstituent] + expected_transitions = [OpenConstituent('ROOT'), OpenConstituent('S'), Shift(), Shift(), Shift(), Shift(), Shift(), Shift(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), Shift(), OpenConstituent('CONJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), Shift(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('SBAR'), Shift(), OpenConstituent('S'), OpenConstituent('NP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('VP'), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] + + assert new_transitions == expected_transitions + +def test_open_open_ambiguous_unary_fix(): + trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + EXPECTED_FIX = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), 
CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + assert transitions == EXPECTED_ORIG + new_transitions = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_unary, 2) + assert new_transitions == EXPECTED_FIX + + +def test_open_open_ambiguous_later_fix(): + trees = read_trees(OPEN_SHIFT_EXAMPLE_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + EXPECTED_ORIG = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + EXPECTED_FIX = [OpenConstituent('ROOT'), OpenConstituent('S'), OpenConstituent('VP'), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + assert transitions == EXPECTED_ORIG + new_transitions = get_single_repair(transitions, OpenConstituent('VP'), fix_open_open_ambiguous_later, 2) + assert new_transitions == EXPECTED_FIX + + +CLOSE_SHIFT_EXAMPLE_TREE = """ +( (NP (DT a) + (ADJP (NN stock) (HYPH -) (VBG picking)) + (NN tool))) +""" + +# not intended to be a correct tree +CLOSE_SHIFT_DEEP_EXAMPLE_TREE = """ +( (NP (DT a) + (VP (ADJP (NN stock) (HYPH -) (VBG picking))) + (NN tool))) +""" + +# not intended to be a correct tree +CLOSE_SHIFT_OPEN_EXAMPLE_TREE = """ +( (NP (DT a) + (ADJP (NN stock) (HYPH -) (VBG picking)) + (NP (NN tool)))) +""" + +CLOSE_SHIFT_AMBIGUOUS_TREE = """ +( (NP (DT a) + (ADJP (NN stock) (HYPH -) (VBG picking)) + (NN tool) + (NN foo))) +""" + +def test_fix_close_shift_ambiguous_immediate(): + """ + Test the result when a close/shift error occurs and we want to close the new, incorrect constituent immediately + 
""" + trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_later, 7) + expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] + expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + assert transitions == expected_original + assert new_sequence == expected_update + +def test_fix_close_shift_ambiguous_later(): + # test that the one with two shifts, which is ambiguous, gets rejected + trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift_ambiguous_immediate, 7) + expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] + expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] + assert transitions == expected_original + assert new_sequence == expected_update + +def test_oracle_with_optional_level(): + tree = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE)[0] + gold_sequence = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] + expected_update 
= [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + assert transitions == gold_sequence + + oracle = TopDownOracle(ROOT_LABELS, 1, "", "") + + model = SimpleModel(transition_scheme=TransitionScheme.TOP_DOWN_UNARY, root_labels=ROOT_LABELS) + state = build_state(model, tree, 7) + fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8], + model=model, + state=state) + assert fix is RepairType.OTHER_CLOSE_SHIFT + assert new_sequence is None + + oracle = TopDownOracle(ROOT_LABELS, 1, "CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR", "") + fix, new_sequence = oracle.fix_error(pred_transition=gold_sequence[8], + model=model, + state=state) + assert fix is RepairType.CLOSE_SHIFT_AMBIGUOUS_IMMEDIATE_ERROR + assert new_sequence == expected_update + + +def test_fix_close_shift(): + """ + Test a tree of the kind we expect the close/shift to be able to get right + """ + trees = read_trees(CLOSE_SHIFT_EXAMPLE_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + + new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift, 7) + + expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] + expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + assert transitions == expected_original + assert new_sequence == expected_update + + # test that the one with two shifts, which is ambiguous, gets rejected + trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) + assert len(trees) == 1 + tree = 
trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + new_sequence = get_single_repair(transitions, transitions[8], fix_close_shift, 7) + assert new_sequence is None + +def test_fix_close_shift_deeper_tree(): + """ + Test a tree of the kind we expect the close/shift to be able to get right + """ + trees = read_trees(CLOSE_SHIFT_DEEP_EXAMPLE_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + + for count_opens in [True, False]: + new_sequence = get_single_repair(transitions, transitions[10], fix_close_shift, 8, count_opens=count_opens) + + expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), CloseConstituent(), CloseConstituent()] + expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('VP'), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + assert transitions == expected_original + assert new_sequence == expected_update + +def test_fix_close_shift_open_tree(): + """ + We would like the close/shift to get this case right as well + """ + trees = read_trees(CLOSE_SHIFT_OPEN_EXAMPLE_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + + new_sequence = get_single_repair(transitions, transitions[9], fix_close_shift, 7, count_opens=False) + assert new_sequence is None + + new_sequence = get_single_repair(transitions, transitions[9], fix_close_shift_with_opens, 7) + + expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), OpenConstituent('NP'), Shift(), CloseConstituent(), CloseConstituent(), 
CloseConstituent()] + expected_update = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + assert transitions == expected_original + assert new_sequence == expected_update + +CLOSE_OPEN_EXAMPLE_TREE = """ +( (VP (VBZ eat) + (NP (NN spaghetti)) + (PP (IN with) (DT a) (NN fork)))) +""" + +CLOSE_OPEN_DIFFERENT_LABEL_TREE = """ +( (VP (VBZ eat) + (NP (NN spaghetti)) + (NP (DT a) (NN fork)))) +""" + +CLOSE_OPEN_TWO_LABELS_TREE = """ +( (VP (VBZ eat) + (NP (NN spaghetti)) + (PP (IN with) (DT a) (NN fork)) + (PP (IN in) (DT a) (NN restaurant)))) +""" + +def test_fix_close_open(): + trees = read_trees(CLOSE_OPEN_EXAMPLE_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + + assert isinstance(transitions[5], CloseConstituent) + assert transitions[6] == OpenConstituent("PP") + + new_transitions = get_single_repair(transitions, transitions[6], fix_close_open_correct_open, 5) + + expected_original = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), CloseConstituent(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + expected_update = [OpenConstituent('ROOT'), OpenConstituent('VP'), Shift(), OpenConstituent('NP'), Shift(), OpenConstituent('PP'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), CloseConstituent(), CloseConstituent()] + + assert transitions == expected_original + assert new_transitions == expected_update + +def test_fix_close_open_invalid(): + for TREE in (CLOSE_OPEN_DIFFERENT_LABEL_TREE, CLOSE_OPEN_TWO_LABELS_TREE): + trees = read_trees(TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + + assert isinstance(transitions[5], CloseConstituent) + assert 
isinstance(transitions[6], OpenConstituent) + + new_transitions = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open, 5) + assert new_transitions is None + +def test_fix_close_open_ambiguous_immediate(): + """ + Test that a fix for an ambiguous close/open works as expected + """ + trees = read_trees(CLOSE_OPEN_TWO_LABELS_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + assert isinstance(transitions[5], CloseConstituent) + assert isinstance(transitions[6], OpenConstituent) + + reconstructed = reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN) + assert tree == reconstructed + + new_transitions = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open, 5, check_close=False) + reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN) + + expected = """ + ( (VP (VBZ eat) + (NP (NN spaghetti) + (PP (IN with) (DT a) (NN fork))) + (PP (IN in) (DT a) (NN restaurant)))) + """ + expected = read_trees(expected)[0] + assert reconstructed == expected + +def test_fix_close_open_ambiguous_later(): + """ + Test that a fix for an ambiguous close/open works as expected + """ + trees = read_trees(CLOSE_OPEN_TWO_LABELS_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + assert isinstance(transitions[5], CloseConstituent) + assert isinstance(transitions[6], OpenConstituent) + + reconstructed = reconstruct_tree(tree, transitions, transition_scheme=TransitionScheme.TOP_DOWN) + assert tree == reconstructed + + new_transitions = get_single_repair(transitions, OpenConstituent("PP"), fix_close_open_correct_open_ambiguous_later, 5, check_close=False) + reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN) + + expected = """ + ( (VP (VBZ eat) + (NP (NN 
spaghetti) + (PP (IN with) (DT a) (NN fork)) + (PP (IN in) (DT a) (NN restaurant))))) + """ + expected = read_trees(expected)[0] + assert reconstructed == expected + + +SHIFT_CLOSE_EXAMPLES = [ + ("((S (NP (DT an) (NML (NNP Oct) (CD 19)) (NN review))))", "((S (NP (DT an) (NML (NNP Oct) (CD 19))) (NN review)))", 8), + ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", + "((S (NP (` `) (NP (DT The)) (NN Misanthrope) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", 6), + ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", + "((S (NP (` `) (NP (DT The) (NN Misanthrope))) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre)))))", 8), + ("((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman) (NNP Theatre))))))", + "((S (NP (` `) (NP (DT The) (NN Misanthrope)) (` `) (PP (IN at) (NP (NNP Goodman)) (NNP Theatre)))))", 13), +] + +def test_shift_close(): + for idx, (orig_tree, expected_tree, shift_position) in enumerate(SHIFT_CLOSE_EXAMPLES): + trees = read_trees(orig_tree) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + if shift_position is None: + print(transitions) + continue + + assert isinstance(transitions[shift_position], Shift) + new_transitions = get_single_repair(transitions, CloseConstituent(), fix_shift_close, shift_position) + reconstructed = reconstruct_tree(tree, new_transitions, transition_scheme=TransitionScheme.TOP_DOWN) + if expected_tree is None: + print(transitions) + print(new_transitions) + + print("{:P}".format(reconstructed)) + else: + expected_tree = read_trees(expected_tree) + assert len(expected_tree) == 1 + expected_tree = expected_tree[0] + + assert reconstructed == expected_tree + +def test_shift_open_ambiguous_unary(): + """ + Test what happens if a Shift is turned into an Open in an ambiguous manner + """ + trees = 
read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] + assert transitions == expected_original + + new_sequence = get_single_repair(transitions, OpenConstituent("ZZ"), fix_shift_open_ambiguous_unary, 4) + expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] + assert new_sequence == expected_updated + +def test_shift_open_ambiguous_later(): + """ + Test what happens if a Shift is turned into an Open in an ambiguous manner + """ + trees = read_trees(CLOSE_SHIFT_AMBIGUOUS_TREE) + assert len(trees) == 1 + tree = trees[0] + + transitions = build_sequence(tree, transition_scheme=TransitionScheme.TOP_DOWN) + expected_original = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), Shift(), Shift(), Shift(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] + assert transitions == expected_original + + new_sequence = get_single_repair(transitions, OpenConstituent("ZZ"), fix_shift_open_ambiguous_later, 4) + expected_updated = [OpenConstituent('ROOT'), OpenConstituent('NP'), Shift(), OpenConstituent('ADJP'), OpenConstituent('ZZ'), Shift(), Shift(), Shift(), CloseConstituent(), CloseConstituent(), Shift(), Shift(), CloseConstituent(), CloseConstituent()] + assert new_sequence == expected_updated diff --git a/stanza/stanza/tests/constituency/test_trainer.py b/stanza/stanza/tests/constituency/test_trainer.py new file mode 100644 index 
0000000000000000000000000000000000000000..8b8cc133529c7f680a35c8dc1629163cc93b4654 --- /dev/null +++ b/stanza/stanza/tests/constituency/test_trainer.py @@ -0,0 +1,639 @@ +from collections import defaultdict +import logging +import pathlib +import tempfile + +import pytest +import torch +from torch import nn +from torch import optim + +from stanza import Pipeline + +from stanza.models import constituency_parser +from stanza.models.common import pretrain +from stanza.models.common.bert_embedding import load_bert, load_tokenizer +from stanza.models.common.foundation_cache import FoundationCache +from stanza.models.common.utils import set_random_seed +from stanza.models.constituency import lstm_model +from stanza.models.constituency.parse_transitions import Transition +from stanza.models.constituency import parser_training +from stanza.models.constituency import trainer +from stanza.models.constituency import tree_reader +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +logger = logging.getLogger('stanza.constituency.trainer') +logger.setLevel(logging.WARNING) + +TREEBANK = """ +( (S + (VP (VBG Enjoying) + (NP (PRP$ my) (JJ favorite) (NN Friday) (NN tradition))) + (. .))) + +( (NP + (VP (VBG Sitting) + (PP (IN in) + (NP (DT a) (RB stifling) (JJ hot) (NNP South) (NNP Station))) + (VP (VBG waiting) + (PP (IN for) + (NP (PRP$ my) (JJ delayed) (NNP @MBTA) (NN train))))) + (. .))) + +( (S + (NP (PRP I)) + (VP + (ADVP (RB really)) + (VBP hate) + (NP (DT the) (NNP @MBTA))))) + +( (S + (S (VP (VB Seek))) + (CC and) + (S (NP (PRP ye)) + (VP (MD shall) + (VP (VB find)))) + (. .))) +""" + +def build_trainer(wordvec_pretrain_file, *args, treebank=TREEBANK): + # TODO: build a fake embedding some other way? 
+ train_trees = tree_reader.read_trees(treebank) + dev_trees = train_trees[-1:] + silver_trees = [] + + args = ['--wordvec_pretrain_file', wordvec_pretrain_file] + list(args) + args = constituency_parser.parse_args(args) + + foundation_cache = FoundationCache() + # might be None, unless we're testing loading an existing model + model_load_name = args['load_name'] + + model, _, _, _ = parser_training.build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_name) + assert isinstance(model.model, lstm_model.LSTMModel) + return model + +class TestTrainer: + @pytest.fixture(scope="class") + def wordvec_pretrain_file(self): + return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' + + @pytest.fixture(scope="class") + def tiny_random_xlnet(self, tmp_path_factory): + """ + Download the tiny-random-xlnet model and make a concrete copy of it + + The issue here is that the "random" nature of the original + makes it difficult or impossible to test that the values in + the transformer don't change during certain operations. 
+ Saving a concrete instantiation of those random numbers makes + it so we can test there is no difference when training only a + subset of the layers, for example + """ + xlnet_name = 'hf-internal-testing/tiny-random-xlnet' + xlnet_model, xlnet_tokenizer = load_bert(xlnet_name) + path = str(tmp_path_factory.mktemp('tiny-random-xlnet')) + xlnet_model.save_pretrained(path) + xlnet_tokenizer.save_pretrained(path) + return path + + @pytest.fixture(scope="class") + def tiny_random_bart(self, tmp_path_factory): + """ + Download the tiny-random-bart model and make a concrete copy of it + + Issue is the same as with tiny_random_xlnet + """ + bart_name = 'hf-internal-testing/tiny-random-bart' + bart_model, bart_tokenizer = load_bert(bart_name) + path = str(tmp_path_factory.mktemp('tiny-random-bart')) + bart_model.save_pretrained(path) + bart_tokenizer.save_pretrained(path) + return path + + def test_initial_model(self, wordvec_pretrain_file): + """ + does nothing, just tests that the construction went okay + """ + args = ['wordvec_pretrain_file', wordvec_pretrain_file] + build_trainer(wordvec_pretrain_file) + + + def test_save_load_model(self, wordvec_pretrain_file): + """ + Just tests that saving and loading works without crashs. 
+ + Currently no test of the values themselves + (checks some fields to make sure they are regenerated correctly) + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + tr = build_trainer(wordvec_pretrain_file) + transitions = tr.model.transitions + + # attempt saving + filename = os.path.join(tmpdirname, "parser.pt") + tr.save(filename) + + assert os.path.exists(filename) + + # load it back in + tr2 = tr.load(filename) + trans2 = tr2.model.transitions + assert(transitions == trans2) + assert all(isinstance(x, Transition) for x in trans2) + + def test_relearn_structure(self, wordvec_pretrain_file): + """ + Test that starting a trainer with --relearn_structure copies the old model + """ + + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + set_random_seed(1000) + args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'] + tr = build_trainer(wordvec_pretrain_file, *args) + + # attempt saving + filename = os.path.join(tmpdirname, "parser.pt") + tr.save(filename) + + set_random_seed(1001) + args = ['--pattn_num_layers', '1', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10', '--relearn_structure', '--load_name', filename] + tr2 = build_trainer(wordvec_pretrain_file, *args) + + assert torch.allclose(tr.model.delta_embedding.weight, tr2.model.delta_embedding.weight) + assert torch.allclose(tr.model.output_layers[0].weight, tr2.model.output_layers[0].weight) + # the norms will be the same, as the non-zero values are all the same + assert torch.allclose(torch.linalg.norm(tr.model.word_lstm.weight_ih_l0), torch.linalg.norm(tr2.model.word_lstm.weight_ih_l0)) + + def write_treebanks(self, tmpdirname): + train_treebank_file = os.path.join(tmpdirname, "train.mrg") + with open(train_treebank_file, 'w', encoding='utf-8') as fout: + fout.write(TREEBANK) + fout.write(TREEBANK) + + eval_treebank_file = os.path.join(tmpdirname, "eval.mrg") + with 
open(eval_treebank_file, 'w', encoding='utf-8') as fout: + fout.write(TREEBANK) + + return train_treebank_file, eval_treebank_file + + def training_args(self, wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *additional_args): + # let's not make the model huge... + args = ['--pattn_num_layers', '0', '--pattn_d_model', '128', '--lattn_d_proj', '0', '--use_lattn', '--hidden_size', '20', '--delta_embedding_dim', '10', + '--wordvec_pretrain_file', wordvec_pretrain_file, '--data_dir', tmpdirname, + '--save_dir', tmpdirname, '--save_name', 'test.pt', '--save_each_start', '0', '--save_each_name', os.path.join(tmpdirname, 'each_%02d.pt'), + '--train_file', train_treebank_file, '--eval_file', eval_treebank_file, + '--epoch_size', '6', '--train_batch_size', '3', + '--shorthand', 'en_test'] + args = args + list(additional_args) + args = constituency_parser.parse_args(args) + # just in case we change the defaults in the future + args['wandb'] = None + return args + + def run_train_test(self, wordvec_pretrain_file, tmpdirname, num_epochs=5, extra_args=None, use_silver=False, exists_ok=False, foundation_cache=None): + """ + Runs a test of the trainer for a few iterations. 
+ + Checks some basic properties of the saved model, but doesn't + check for the accuracy of the results + """ + if extra_args is None: + extra_args = [] + extra_args += ['--epochs', '%d' % num_epochs] + + train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname) + if use_silver: + extra_args += ['--silver_file', str(eval_treebank_file)] + args = self.training_args(wordvec_pretrain_file, tmpdirname, train_treebank_file, eval_treebank_file, *extra_args) + + each_name = args['save_each_name'] + if not exists_ok: + assert not os.path.exists(args['save_name']) + retag_pipeline = Pipeline(lang="en", processors="tokenize, pos", tokenize_pretokenized=True, dir=TEST_MODELS_DIR, foundation_cache=foundation_cache) + trained_model = parser_training.train(args, None, [retag_pipeline]) + # check that hooks are in the model if expected + for p in trained_model.model.parameters(): + if p.requires_grad: + if args['grad_clipping'] is not None: + assert len(p._backward_hooks) == 1 + else: + assert p._backward_hooks is None + + # check that the model can be loaded back + assert os.path.exists(args['save_name']) + peft_name = trained_model.model.peft_name + tr = trainer.Trainer.load(args['save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name) + assert tr.optimizer is not None + assert tr.scheduler is not None + assert tr.epochs_trained >= 1 + for p in tr.model.parameters(): + if p.requires_grad: + assert p._backward_hooks is None + + tr = trainer.Trainer.load(args['checkpoint_save_name'], load_optimizer=True, foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name) + assert tr.optimizer is not None + assert tr.scheduler is not None + assert tr.epochs_trained == num_epochs + + for i in range(1, num_epochs+1): + model_name = each_name % i + assert os.path.exists(model_name) + tr = trainer.Trainer.load(model_name, load_optimizer=True, 
foundation_cache=retag_pipeline.foundation_cache, peft_name=trained_model.model.peft_name) + assert tr.epochs_trained == i + assert tr.batches_trained == (4 * i if use_silver else 2 * i) + + return args, trained_model + + def test_train(self, wordvec_pretrain_file): + """ + Test the whole thing for a few iterations on the fake data + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + self.run_train_test(wordvec_pretrain_file, tmpdirname) + + def test_early_dropout(self, wordvec_pretrain_file): + """ + Test the whole thing for a few iterations on the fake data + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + args = ['--early_dropout', '3'] + _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args) + model = model.model + dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)] + assert len(dropouts) > 0, "Didn't find any dropouts in the model!" + for name, module in dropouts: + assert module.p == 0.0, "Dropout module %s was not set to 0 with early_dropout" + + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + # test that when turned off, early_dropout doesn't happen + args = ['--early_dropout', '-1'] + _, model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args) + model = model.model + dropouts = [(name, module) for name, module in model.named_children() if isinstance(module, nn.Dropout)] + assert len(dropouts) > 0, "Didn't find any dropouts in the model!" 
+ if all(module.p == 0.0 for _, module in dropouts): + raise AssertionError("All dropouts were 0 after training even though early_dropout was set to -1") + + def test_train_silver(self, wordvec_pretrain_file): + """ + Test the whole thing for a few iterations on the fake data + + This tests that it works if you give it a silver file + The check for the use of the silver data is that the + number of batches trained should go up + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=True) + + def test_train_checkpoint(self, wordvec_pretrain_file): + """ + Test the whole thing for a few iterations, then restart + + This tests that the 5th iteration save file is not rewritten + and that the iterations continue to 10 + + TODO: could make it more robust by verifying that only 5 more + epochs are trained. Perhaps a "most recent epochs" could be + saved in the trainer + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, use_silver=False) + save_5 = args['save_each_name'] % 5 + save_10 = args['save_each_name'] % 10 + assert os.path.exists(save_5) + assert not os.path.exists(save_10) + + save_5_stat = pathlib.Path(save_5).stat() + + self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=10, use_silver=False, exists_ok=True) + assert os.path.exists(save_5) + assert os.path.exists(save_10) + + assert pathlib.Path(save_5).stat().st_mtime == save_5_stat.st_mtime + + def run_multistage_tests(self, wordvec_pretrain_file, tmpdirname, use_lattn, extra_args=None): + train_treebank_file, eval_treebank_file = self.write_treebanks(tmpdirname) + args = ['--multistage', '--pattn_num_layers', '1'] + if use_lattn: + args += ['--lattn_d_proj', '16'] + if extra_args: + args += extra_args + args, _ = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=8, extra_args=args) + each_name = 
os.path.join(args['save_dir'], 'each_%02d.pt') + + word_input_sizes = defaultdict(list) + for i in range(1, 9): + model_name = each_name % i + assert os.path.exists(model_name) + tr = trainer.Trainer.load(model_name, load_optimizer=True) + assert tr.epochs_trained == i + word_input_sizes[tr.model.word_input_size].append(i) + if use_lattn: + # there should be three stages: no attn, pattn, pattn+lattn + assert len(word_input_sizes) == 3 + word_input_keys = sorted(word_input_sizes.keys()) + assert word_input_sizes[word_input_keys[0]] == [1, 2, 3] + assert word_input_sizes[word_input_keys[1]] == [4, 5] + assert word_input_sizes[word_input_keys[2]] == [6, 7, 8] + else: + # with no lattn, there are two stages: no attn, pattn + assert len(word_input_sizes) == 2 + word_input_keys = sorted(word_input_sizes.keys()) + assert word_input_sizes[word_input_keys[0]] == [1, 2, 3] + assert word_input_sizes[word_input_keys[1]] == [4, 5, 6, 7, 8] + + def test_multistage_lattn(self, wordvec_pretrain_file): + """ + Test a multistage training for a few iterations on the fake data + + This should start with no pattn or lattn, have pattn in the middle, then lattn at the end + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=True) + + def test_multistage_no_lattn(self, wordvec_pretrain_file): + """ + Test a multistage training for a few iterations on the fake data + + This should start with no pattn or lattn, have pattn in the middle, then lattn at the end + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False) + + def test_multistage_optimizer(self, wordvec_pretrain_file): + """ + Test that the correct optimizers are built for a multistage training process + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + extra_args = ['--optim', 'adamw'] + 
self.run_multistage_tests(wordvec_pretrain_file, tmpdirname, use_lattn=False, extra_args=extra_args) + + # check that the optimizers which get rebuilt when loading + # the models are adadelta for the first half of the + # multistage, then adamw + each_name = os.path.join(tmpdirname, 'each_%02d.pt') + for i in range(1, 3): + model_name = each_name % i + tr = trainer.Trainer.load(model_name, load_optimizer=True) + assert tr.epochs_trained == i + assert isinstance(tr.optimizer, optim.Adadelta) + # double check that this is actually a valid test + assert not isinstance(tr.optimizer, optim.AdamW) + + for i in range(4, 8): + model_name = each_name % i + tr = trainer.Trainer.load(model_name, load_optimizer=True) + assert tr.epochs_trained == i + assert isinstance(tr.optimizer, optim.AdamW) + + + def test_grad_clip_hooks(self, wordvec_pretrain_file): + """ + Verify that grad clipping is not saved with the model, but is attached at training time + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + args = ['--grad_clipping', '25'] + self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args) + + def test_analyze_trees(self, wordvec_pretrain_file): + test_str = "(ROOT (S (NP (PRP I)) (VP (VBP wan) (S (VP (TO na) (VP (VB lick) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))))) (ROOT (S (NP (DT This) (NN interface)) (VP (VBZ sucks))))" + + test_tree = tree_reader.read_trees(test_str) + assert len(test_tree) == 2 + + args = ['--pattn_num_layers', '0', '--lattn_d_proj', '0', '--hidden_size', '20', '--delta_embedding_dim', '10'] + tr = build_trainer(wordvec_pretrain_file, *args) + + results = tr.model.analyze_trees(test_tree) + assert len(results) == 2 + assert len(results[0].predictions) == 1 + assert results[0].predictions[0].tree == test_tree[0] + assert results[0].state is not None + assert isinstance(results[0].state.score, torch.Tensor) + assert results[0].state.score.shape == torch.Size([]) + assert len(results[0].constituents) == 
9 + assert results[0].constituents[-1].value == test_tree[0] + # the way the results are built, the next-to-last entry + # should be the thing just below the root + assert results[0].constituents[-2].value == test_tree[0].children[0] + + assert len(results[1].predictions) == 1 + assert results[1].predictions[0].tree == test_tree[1] + assert results[1].state is not None + assert isinstance(results[1].state.score, torch.Tensor) + assert results[1].state.score.shape == torch.Size([]) + assert len(results[1].constituents) == 4 + assert results[1].constituents[-1].value == test_tree[1] + assert results[1].constituents[-2].value == test_tree[1].children[0] + + def bert_weights_allclose(self, bert_model, parser_model): + """ + Return True if all bert weights are close, False otherwise + """ + for name, parameter in bert_model.named_parameters(): + other_name = "bert_model." + name + other_parameter = parser_model.model.get_parameter(other_name) + if not torch.allclose(parameter.cpu(), other_parameter.cpu()): + return False + return True + + def frozen_transformer_test(self, wordvec_pretrain_file, transformer_name): + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + foundation_cache = FoundationCache() + args = ['--bert_model', transformer_name] + args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args, foundation_cache=foundation_cache) + bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name) + assert self.bert_weights_allclose(bert_model, trained_model) + + checkpoint = torch.load(args['save_name'], lambda storage, loc: storage, weights_only=True) + params = checkpoint['params'] + # check that the bert model wasn't saved in the model + assert all(not x.startswith("bert_model.") for x in params['model'].keys()) + # make sure we're looking at the right thing + assert any(x.startswith("output_layers.") for x in params['model'].keys()) + + # check that the cached model is used as expected when 
loading a bert model + trained_model = trainer.Trainer.load(args['save_name'], foundation_cache=foundation_cache) + assert trained_model.model.bert_model is bert_model + + def test_bert_frozen(self, wordvec_pretrain_file): + """ + Check that the parameters of the bert model don't change when training a basic model + """ + self.frozen_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert') + + def test_xlnet_frozen(self, wordvec_pretrain_file, tiny_random_xlnet): + """ + Check that the parameters of an xlnet model don't change when training a basic model + """ + self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_xlnet) + + def test_bart_frozen(self, wordvec_pretrain_file, tiny_random_bart): + """ + Check that the parameters of an xlnet model don't change when training a basic model + """ + self.frozen_transformer_test(wordvec_pretrain_file, tiny_random_bart) + + def test_bert_finetune_one_epoch(self, wordvec_pretrain_file): + """ + Check that the parameters the bert model DO change over a single training step + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + transformer_name = 'hf-internal-testing/tiny-bert' + args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adadelta'] + args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=1, extra_args=args) + + # check that the weights are different + foundation_cache = FoundationCache() + bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name) + assert not self.bert_weights_allclose(bert_model, trained_model) + + # double check that a new bert is created instead of using the FoundationCache when the bert has been trained + model_name = args['save_name'] + assert os.path.exists(model_name) + no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name) + tr = 
trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache) + assert tr.model.bert_model is not bert_model + assert not self.bert_weights_allclose(bert_model, tr) + assert self.bert_weights_allclose(trained_model.model.bert_model, tr) + + new_save_name = os.path.join(tmpdirname, "test_resave_bert.pt") + assert not os.path.exists(new_save_name) + tr.save(new_save_name, save_optimizer=False) + tr2 = trainer.Trainer.load(new_save_name, args=no_finetune_args, foundation_cache=foundation_cache) + # check that the resaved model included its finetuned bert weights + assert tr2.model.bert_model is not bert_model + # the finetuned bert weights should also be scheduled for saving the next time as well + assert not tr2.model.is_unsaved_module("bert_model") + + def finetune_transformer_test(self, wordvec_pretrain_file, transformer_name): + """ + Check that the parameters of the transformer DO change when using bert_finetune + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw'] + args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args) + + # check that the weights are different + foundation_cache = FoundationCache() + bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name) + assert not self.bert_weights_allclose(bert_model, trained_model) + + # double check that a new bert is created instead of using the FoundationCache when the bert has been trained + no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', transformer_name) + trained_model = trainer.Trainer.load(args['save_name'], args=no_finetune_args, foundation_cache=foundation_cache) + assert not trained_model.model.args['bert_finetune'] + assert not trained_model.model.args['stage1_bert_finetune'] + assert trained_model.model.bert_model 
is not bert_model + + def test_bert_finetune(self, wordvec_pretrain_file): + """ + Check that the parameters of a bert model DO change when using bert_finetune + """ + self.finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert') + + def test_xlnet_finetune(self, wordvec_pretrain_file, tiny_random_xlnet): + """ + Check that the parameters of an xlnet model DO change when using bert_finetune + """ + self.finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet) + + def test_stage1_bert_finetune(self, wordvec_pretrain_file): + """ + Check that the parameters the bert model DO change when using stage1_bert_finetune, but only for the first couple steps + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + bert_model_name = 'hf-internal-testing/tiny-bert' + args = ['--bert_model', bert_model_name, '--stage1_bert_finetune', '--optim', 'adamw'] + # need to use num_epochs==6 so that epochs 1 and 2 are saved to be different + # a test of 5 or less means that sometimes it will reload the params + # at step 2 to get ready for the following iterations with adamw + args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=6, extra_args=args) + + # check that the weights are different + foundation_cache = FoundationCache() + bert_model, bert_tokenizer = foundation_cache.load_bert(bert_model_name) + assert not self.bert_weights_allclose(bert_model, trained_model) + + # double check that a new bert is created instead of using the FoundationCache when the bert has been trained + no_finetune_args = self.training_args(wordvec_pretrain_file, tmpdirname, None, None, "--no_bert_finetune", "--no_stage1_bert_finetune", '--bert_model', bert_model_name, '--optim', 'adamw') + num_epochs = trained_model.model.args['epochs'] + each_name = os.path.join(tmpdirname, 'each_%02d.pt') + for i in range(1, num_epochs+1): + model_name = each_name % i + assert os.path.exists(model_name) + tr = 
trainer.Trainer.load(model_name, args=no_finetune_args, foundation_cache=foundation_cache) + assert tr.model.bert_model is not bert_model + assert not self.bert_weights_allclose(bert_model, tr) + if i >= num_epochs // 2: + assert self.bert_weights_allclose(trained_model.model.bert_model, tr) + + # verify that models 1 and 2 are saved to be different + model_name_1 = each_name % 1 + model_name_2 = each_name % 2 + tr_1 = trainer.Trainer.load(model_name_1, args=no_finetune_args, foundation_cache=foundation_cache) + tr_2 = trainer.Trainer.load(model_name_2, args=no_finetune_args, foundation_cache=foundation_cache) + assert not self.bert_weights_allclose(tr_1.model.bert_model, tr_2) + + + def one_layer_finetune_transformer_test(self, wordvec_pretrain_file, transformer_name): + """ + Check that the parameters the bert model DO change when using bert_finetune + """ + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + args = ['--bert_model', transformer_name, '--bert_finetune', '--bert_finetune_layers', '1', '--optim', 'adamw', '--bert_finetune_layers', '1'] + args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, extra_args=args) + + # check that the weights of the last layer are different, + # but the weights of the earlier layers and + # non-transformer-layers are the same + foundation_cache = FoundationCache() + bert_model, bert_tokenizer = foundation_cache.load_bert(transformer_name) + assert bert_model.config.num_hidden_layers > 1 + layer_name = "layer.%d." % (bert_model.config.num_hidden_layers - 1) + for name, parameter in bert_model.named_parameters(): + other_name = "bert_model." 
+ name + other_parameter = trained_model.model.get_parameter(other_name) + if layer_name in name: + if 'rel_attn.seg_embed' in name or 'rel_attn.r_s_bias' in name: + # not sure why this happens for xlnet, just roll with it + continue + assert not torch.allclose(parameter.cpu(), other_parameter.cpu()) + else: + assert torch.allclose(parameter.cpu(), other_parameter.cpu()) + + def test_bert_finetune_one_layer(self, wordvec_pretrain_file): + self.one_layer_finetune_transformer_test(wordvec_pretrain_file, 'hf-internal-testing/tiny-bert') + + def test_xlnet_finetune_one_layer(self, wordvec_pretrain_file, tiny_random_xlnet): + self.one_layer_finetune_transformer_test(wordvec_pretrain_file, tiny_random_xlnet) + + def test_peft_finetune(self, tmp_path, wordvec_pretrain_file): + transformer_name = 'hf-internal-testing/tiny-bert' + args = ['--bert_model', transformer_name, '--bert_finetune', '--optim', 'adamw', '--use_peft'] + args, trained_model = self.run_train_test(wordvec_pretrain_file, str(tmp_path), extra_args=args) + + def test_peft_twostage_finetune(self, wordvec_pretrain_file): + with tempfile.TemporaryDirectory(dir=TEST_WORKING_DIR) as tmpdirname: + num_epochs = 6 + transformer_name = 'hf-internal-testing/tiny-bert' + args = ['--bert_model', transformer_name, '--stage1_bert_finetune', '--optim', 'adamw', '--use_peft'] + args, trained_model = self.run_train_test(wordvec_pretrain_file, tmpdirname, num_epochs=num_epochs, extra_args=args) + for epoch in range(num_epochs): + filename_prev = args['save_each_name'] % epoch + filename_next = args['save_each_name'] % (epoch+1) + trainer_prev = trainer.Trainer.load(filename_prev, args=args, load_optimizer=False) + trainer_next = trainer.Trainer.load(filename_next, args=args, load_optimizer=False) + + lora_names = [name for name, _ in trainer_prev.model.bert_model.named_parameters() if name.find("lora") >= 0] + if epoch < 2: + assert not any(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(), + 
trainer_next.model.bert_model.get_parameter(name).cpu()) + for name in lora_names) + elif epoch > 2: + assert all(torch.allclose(trainer_prev.model.bert_model.get_parameter(name).cpu(), + trainer_next.model.bert_model.get_parameter(name).cpu()) + for name in lora_names) diff --git a/stanza/stanza/tests/constituency/test_transformer_tree_stack.py b/stanza/stanza/tests/constituency/test_transformer_tree_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..5045e06b46fbf256feffbe0593ca6eadf7af13ef --- /dev/null +++ b/stanza/stanza/tests/constituency/test_transformer_tree_stack.py @@ -0,0 +1,195 @@ +import pytest + +import torch + +from stanza.models.constituency.transformer_tree_stack import TransformerTreeStack + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def test_initial_state(): + """ + Test that the initial state has the expected shapes + """ + ts = TransformerTreeStack(3, 5, 0.0) + initial = ts.initial_state() + assert len(initial) == 1 + assert initial.value.output.shape == torch.Size([5]) + assert initial.value.key_stack.shape == torch.Size([1, 5]) + assert initial.value.value_stack.shape == torch.Size([1, 5]) + +def test_output(): + """ + Test that you can get an expected output shape from the TTS + """ + ts = TransformerTreeStack(3, 5, 0.0) + initial = ts.initial_state() + out = ts.output(initial) + assert out.shape == torch.Size([5]) + assert torch.allclose(initial.value.output, out) + +def test_push_state_single(): + """ + Test that stacks are being updated correctly when using a single stack + + Values of the attention are not verified, though + """ + ts = TransformerTreeStack(3, 5, 0.0) + initial = ts.initial_state() + rand_input = torch.randn(1, 3) + stacks = ts.push_states([initial], ["A"], rand_input) + stacks = ts.push_states(stacks, ["B"], rand_input) + assert len(stacks) == 1 + assert len(stacks[0]) == 3 + assert stacks[0].value.value == "B" + assert stacks[0].pop().value.value == "A" + assert 
stacks[0].pop().pop().value.value is None + +def test_push_state_same_length(): + """ + Test that stacks are being updated correctly when using 3 stacks of the same length + + Values of the attention are not verified, though + """ + ts = TransformerTreeStack(3, 5, 0.0) + initial = ts.initial_state() + rand_input = torch.randn(3, 3) + stacks = ts.push_states([initial, initial, initial], ["A", "A", "A"], rand_input) + stacks = ts.push_states(stacks, ["B", "B", "B"], rand_input) + stacks = ts.push_states(stacks, ["C", "C", "C"], rand_input) + assert len(stacks) == 3 + for s in stacks: + assert len(s) == 4 + assert s.value.key_stack.shape == torch.Size([4, 5]) + assert s.value.value_stack.shape == torch.Size([4, 5]) + assert s.value.value == "C" + assert s.pop().value.value == "B" + assert s.pop().pop().value.value == "A" + assert s.pop().pop().pop().value.value is None + +def test_push_state_different_length(): + """ + Test what happens if stacks of different lengths are passed in + """ + ts = TransformerTreeStack(3, 5, 0.0) + initial = ts.initial_state() + rand_input = torch.randn(2, 3) + one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0] + stacks = [one_step, initial] + stacks = ts.push_states(stacks, ["B", "C"], rand_input) + assert len(stacks) == 2 + assert len(stacks[0]) == 3 + assert len(stacks[1]) == 2 + assert stacks[0].pop().value.value == 'A' + assert stacks[0].value.value == 'B' + assert stacks[1].value.value == 'C' + + assert stacks[0].value.key_stack.shape == torch.Size([3, 5]) + assert stacks[1].value.key_stack.shape == torch.Size([2, 5]) + +def test_mask(): + """ + Test that a mask prevents the softmax from picking up unwanted values + """ + ts = TransformerTreeStack(3, 5, 0.0) + + random_v = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5]]]) + double_v = random_v * 2 + value = torch.cat([random_v, double_v], axis=1) + random_k = torch.randn(1, 1, 5) + key = torch.cat([random_k, random_k], axis=1) + query = torch.randn(1, 5) + + output = 
ts.attention(key, query, value) + # when the two keys are equal, we expect the attention to be 50/50 + expected_output = (random_v + double_v) / 2 + assert torch.allclose(output, expected_output) + + # If the first entry is masked out, the second one should be the + # only one represented + mask = torch.zeros(1, 2, dtype=torch.bool) + mask[0][0] = True + output = ts.attention(key, query, value, mask) + assert torch.allclose(output, double_v) + + # If the second entry is masked out, the first one should be the + # only one represented + mask = torch.zeros(1, 2, dtype=torch.bool) + mask[0][1] = True + output = ts.attention(key, query, value, mask) + assert torch.allclose(output, random_v) + +def test_position(): + """ + Test that nothing goes horribly wrong when position encodings are used + + Does not actually test the results of the encodings + """ + ts = TransformerTreeStack(4, 5, 0.0, use_position=True) + initial = ts.initial_state() + assert len(initial) == 1 + assert initial.value.output.shape == torch.Size([5]) + assert initial.value.key_stack.shape == torch.Size([1, 5]) + assert initial.value.value_stack.shape == torch.Size([1, 5]) + + rand_input = torch.randn(2, 4) + one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0] + stacks = [one_step, initial] + stacks = ts.push_states(stacks, ["B", "C"], rand_input) + +def test_length_limit(): + """ + Test that the length limit drops nodes as the length limit is exceeded + """ + ts = TransformerTreeStack(4, 5, 0.0, length_limit = 2) + initial = ts.initial_state() + assert len(initial) == 1 + assert initial.value.output.shape == torch.Size([5]) + assert initial.value.key_stack.shape == torch.Size([1, 5]) + assert initial.value.value_stack.shape == torch.Size([1, 5]) + + data = torch.tensor([[0.1, 0.2, 0.3, 0.4]]) + stacks = ts.push_states([initial], ["A"], data) + + stacks = ts.push_states(stacks, ["B"], data) + assert len(stacks) == 1 + assert len(stacks[0]) == 3 + assert 
stacks[0].value.key_stack.shape[0] == 3 + assert stacks[0].value.value_stack.shape[0] == 3 + + stacks = ts.push_states(stacks, ["C"], data) + assert len(stacks) == 1 + assert len(stacks[0]) == 4 + assert stacks[0].value.key_stack.shape[0] == 3 + assert stacks[0].value.value_stack.shape[0] == 3 + + stacks = ts.push_states(stacks, ["D"], data) + assert len(stacks) == 1 + assert len(stacks[0]) == 5 + assert stacks[0].value.key_stack.shape[0] == 3 + assert stacks[0].value.value_stack.shape[0] == 3 + +def test_two_heads(): + """ + Test that the length limit drops nodes as the length limit is exceeded + """ + ts = TransformerTreeStack(4, 6, 0.0, num_heads = 2) + initial = ts.initial_state() + assert len(initial) == 1 + assert initial.value.output.shape == torch.Size([6]) + assert initial.value.key_stack.shape == torch.Size([1, 6]) + assert initial.value.value_stack.shape == torch.Size([1, 6]) + + rand_input = torch.randn(2, 4) + one_step = ts.push_states([initial], ["A"], rand_input[0:1, :])[0] + stacks = [one_step, initial] + stacks = ts.push_states(stacks, ["B", "C"], rand_input) + assert len(stacks) == 2 + assert len(stacks[0]) == 3 + assert len(stacks[1]) == 2 + assert stacks[0].pop().value.value == 'A' + assert stacks[0].value.value == 'B' + assert stacks[1].value.value == 'C' + + assert stacks[0].value.key_stack.shape == torch.Size([3, 6]) + assert stacks[1].value.key_stack.shape == torch.Size([2, 6]) + diff --git a/stanza/stanza/tests/constituency/test_transition_sequence.py b/stanza/stanza/tests/constituency/test_transition_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8de02f646f242be872b4fb06f0ed76f1f900b8 --- /dev/null +++ b/stanza/stanza/tests/constituency/test_transition_sequence.py @@ -0,0 +1,156 @@ +import pytest +from stanza.models.constituency import parse_transitions +from stanza.models.constituency import transition_sequence +from stanza.models.constituency import tree_reader +from 
stanza.models.constituency.base_model import SimpleModel, UNARY_LIMIT +from stanza.models.constituency.parse_transitions import * + +from stanza.tests import * +from stanza.tests.constituency.test_parse_tree import CHINESE_LONG_LIST_TREE + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def reconstruct_tree(tree, sequence, transition_scheme=TransitionScheme.IN_ORDER, unary_limit=UNARY_LIMIT, reverse=False): + """ + Starting from a tree and a list of transitions, build the tree caused by the transitions + """ + model = SimpleModel(transition_scheme=transition_scheme, unary_limit=unary_limit, reverse_sentence=reverse) + states = model.initial_state_from_gold_trees([tree]) + assert(len(states)) == 1 + assert states[0].num_transitions == 0 + + # TODO: could fold this into parse_sentences (similar to verify_transitions in trainer.py) + for idx, t in enumerate(sequence): + assert t.is_legal(states[0], model), "Transition {} not legal at step {} in sequence {}".format(t, idx, sequence) + states = model.bulk_apply(states, [t]) + + result_tree = states[0].constituents.value + if reverse: + result_tree = result_tree.reverse() + return result_tree + +def check_reproduce_tree(transition_scheme): + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" + trees = tree_reader.read_trees(text) + + model = SimpleModel(transition_scheme) + transitions = transition_sequence.build_sequence(trees[0], transition_scheme) + states = model.initial_state_from_gold_trees(trees) + assert(len(states)) == 1 + state = states[0] + assert state.num_transitions == 0 + + for t in transitions: + assert t.is_legal(state, model) + state = t.apply(state, model) + + # one item for the final tree + # one item for the sentinel at the end + assert len(state.constituents) == 2 + # the transition sequence should put all of the words + # from the buffer onto the tree + # one spot left for the sentinel value + assert len(state.word_queue) == 8 + assert state.sentence_length == 6 + assert state.word_position == state.sentence_length + assert len(state.transitions) == len(transitions) + 1 + + result_tree = state.constituents.value + assert result_tree == trees[0] + +def test_top_down_unary(): + check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN_UNARY) + +def test_top_down_no_unary(): + check_reproduce_tree(transition_scheme=TransitionScheme.TOP_DOWN) + +def test_in_order(): + check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER) + +def test_in_order_compound(): + check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_COMPOUND) + +def test_in_order_unary(): + check_reproduce_tree(transition_scheme=TransitionScheme.IN_ORDER_UNARY) + +def test_all_transitions(): + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. 
?)))" + trees = tree_reader.read_trees(text) + model = SimpleModel() + transitions = transition_sequence.build_treebank(trees) + + expected = [Shift(), CloseConstituent(), CompoundUnary("ROOT"), CompoundUnary("SQ"), CompoundUnary("WHNP"), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("SBARQ"), OpenConstituent("VP")] + assert transition_sequence.all_transitions(transitions) == expected + + +def test_all_transitions_no_unary(): + text="((SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + trees = tree_reader.read_trees(text) + model = SimpleModel() + transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN) + + expected = [Shift(), CloseConstituent(), OpenConstituent("NP"), OpenConstituent("PP"), OpenConstituent("ROOT"), OpenConstituent("SBARQ"), OpenConstituent("SQ"), OpenConstituent("VP"), OpenConstituent("WHNP")] + assert transition_sequence.all_transitions(transitions) == expected + +def test_top_down_compound_unary(): + text = "(ROOT (S (NP (DT The) (NNP Arizona) (NNPS Corporations) (NNP Commission)) (VP (VBD authorized) (NP (NP (DT an) (ADJP (CD 11.5)) (NN %) (NN rate) (NN increase)) (PP (IN at) (NP (NNP Tucson) (NNP Electric) (NNP Power) (NNP Co.))) (, ,) (UCP (ADJP (ADJP (RB substantially) (JJR lower)) (SBAR (IN than) (S (VP (VBN recommended) (NP (JJ last) (NN month)) (PP (IN by) (NP (DT a) (NN commission) (NN hearing) (NN officer))))))) (CC and) (NP (NP (QP (RB barely) (PDT half)) (DT the) (NN rise)) (VP (VBN sought) (PP (IN by) (NP (DT the) (NN utility)))))))) (. 
.)))" + + trees = tree_reader.read_trees(text) + assert len(trees) == 1 + + model = SimpleModel() + transitions = transition_sequence.build_sequence(trees[0], transition_scheme=TransitionScheme.TOP_DOWN_COMPOUND) + + states = model.initial_state_from_gold_trees(trees) + assert len(states) == 1 + state = states[0] + + for t in transitions: + assert t.is_legal(state, model) + state = t.apply(state, model) + + result = model.get_top_constituent(state.constituents) + assert trees[0] == result + + +def test_chinese_tree(): + trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE) + + transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN) + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN) + assert redone == trees[0] + + transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER) + with pytest.raises(AssertionError): + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER) + + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6) + assert redone == trees[0] + + +def test_chinese_tree_reversed(): + """ + test that the reversed transitions also work + """ + trees = tree_reader.read_trees(CHINESE_LONG_LIST_TREE) + + transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.TOP_DOWN, reverse=True) + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN, reverse=True) + assert redone == trees[0] + + with pytest.raises(AssertionError): + # turn off reverse - it should fail to rebuild the tree + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.TOP_DOWN) + assert redone == trees[0] + + transitions = transition_sequence.build_treebank(trees, transition_scheme=TransitionScheme.IN_ORDER, reverse=True) + with pytest.raises(AssertionError): + 
redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, reverse=True) + + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6, reverse=True) + assert redone == trees[0] + + with pytest.raises(AssertionError): + # turn off reverse - it should fail to rebuild the tree + redone = reconstruct_tree(trees[0], transitions[0], transition_scheme=TransitionScheme.IN_ORDER, unary_limit=6) + assert redone == trees[0] diff --git a/stanza/stanza/tests/constituency/test_tree_reader.py b/stanza/stanza/tests/constituency/test_tree_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..501015af3996b5e7d0c53e6be50056521e921ab8 --- /dev/null +++ b/stanza/stanza/tests/constituency/test_tree_reader.py @@ -0,0 +1,119 @@ +import pytest +from stanza.models.constituency import tree_reader +from stanza.models.constituency.tree_reader import MixedTreeError, UnclosedTreeError, UnlabeledTreeError + +from stanza.tests import * + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +def test_simple(): + """ + Tests reading two simple trees from the same text + """ + text = "(VB Unban) (NNP Opal)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + assert trees[0].is_preterminal() + assert trees[0].label == 'VB' + assert trees[0].children[0].label == 'Unban' + assert trees[1].is_preterminal() + assert trees[1].label == 'NNP' + assert trees[1].children[0].label == 'Opal' + +def test_newlines(): + """ + The same test should work if there are newlines + """ + text = "(VB Unban)\n\n(NNP Opal)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + +def test_parens(): + """ + Parens should be escaped in the tree files and escaped when written + """ + text = "(-LRB- -LRB-) (-RRB- -RRB-)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + + assert trees[0].label == '-LRB-' + assert trees[0].children[0].label == '(' + assert 
"{}".format(trees[0]) == '(-LRB- -LRB-)' + + assert trees[1].label == '-RRB-' + assert trees[1].children[0].label == ')' + assert "{}".format(trees[1]) == '(-RRB- -RRB-)' + +def test_complicated(): + """ + A more complicated tree that should successfully read + """ + text="( (SBARQ (WHNP (WP Who)) (SQ (VP (VBZ sits) (PP (IN in) (NP (DT this) (NN seat))))) (. ?)))" + trees = tree_reader.read_trees(text) + assert len(trees) == 1 + tree = trees[0] + assert not tree.is_leaf() + assert not tree.is_preterminal() + assert tree.label == 'ROOT' + assert len(tree.children) == 1 + assert tree.children[0].label == 'SBARQ' + assert len(tree.children[0].children) == 3 + assert [x.label for x in tree.children[0].children] == ['WHNP', 'SQ', '.'] + # etc etc + +def test_one_word(): + """ + Check that one node trees are correctly read + + probably not super relevant for the parsing use case + """ + text="(FOO) (BAR)" + trees = tree_reader.read_trees(text) + assert len(trees) == 2 + + assert trees[0].is_leaf() + assert trees[0].label == 'FOO' + + assert trees[1].is_leaf() + assert trees[1].label == 'BAR' + +def test_missing_close_parens(): + """ + Test the unclosed error condition + """ + text = "(Foo) \n (Bar \n zzz" + try: + trees = tree_reader.read_trees(text) + raise AssertionError("Expected an exception") + except UnclosedTreeError as e: + assert e.line_num == 1 + +def test_mixed_tree(): + """ + Test the mixed error condition + """ + text = "(Foo) \n (Bar) \n (Unban (Mox) Opal)" + try: + trees = tree_reader.read_trees(text) + raise AssertionError("Expected an exception") + except MixedTreeError as e: + assert e.line_num == 2 + + trees = tree_reader.read_trees(text, broken_ok=True) + assert len(trees) == 3 + +def test_unlabeled_tree(): + """ + Test the unlabeled error condition + """ + text = "(ROOT ((Foo) (Bar)))" + try: + trees = tree_reader.read_trees(text) + raise AssertionError("Expected an exception") + except UnlabeledTreeError as e: + assert e.line_num == 0 + + trees = 
tree_reader.read_trees(text, broken_ok=True) + assert len(trees) == 1 + + diff --git a/stanza/stanza/tests/constituency/test_vietnamese.py b/stanza/stanza/tests/constituency/test_vietnamese.py new file mode 100644 index 0000000000000000000000000000000000000000..b764570d97e02d7b16cc1e161ce5260644a0cbff --- /dev/null +++ b/stanza/stanza/tests/constituency/test_vietnamese.py @@ -0,0 +1,121 @@ +""" +A few tests for Vietnamese parsing, which has some difficulties related to spaces in words + +Technically some other languages can have this, too, like that one French token +""" + +import os +import tempfile + +import pytest + +from stanza.models.common import pretrain +from stanza.models.constituency import tree_reader + +from stanza.tests.constituency.test_trainer import build_trainer + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +VI_TREEBANK = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài Loan) (" ") (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))' + +VI_TREEBANK_UNDERSCORE = '(ROOT (S-TTL (NP (" ") (N-H Đảo) (Np Đài_Loan) (" ") (PP (E-H ở) (NP (N-H đồng_bằng) (NP (N-H sông) (Np Cửu_Long))))) (. .)))' + +VI_TREEBANK_SIMPLE = '(ROOT (S (NP (" ") (N Đảo) (Np Đài Loan) (" ") (PP (E ở) (NP (N đồng bằng) (NP (N sông) (Np Cửu Long))))) (. .)))' + +VI_TREEBANK_PAREN = '(ROOT (S-TTL (NP (PUNCT -LRB-) (N-H Đảo) (Np Đài Loan) (PUNCT -RRB-) (PP (E-H ở) (NP (N-H đồng bằng) (NP (N-H sông) (Np Cửu Long))))) (. .)))' +VI_TREEBANK_VLSP = '