# stanza-digphil/stanza/tests/classifiers/test_constituency_classifier.py
# Albin Thörn Cleland
# Clean initial commit with LFS
# 19b8775
import os
import pytest
import stanza
import stanza.models.classifier as classifier
import stanza.models.classifiers.data as data
from stanza.models.classifiers.trainer import Trainer
from stanza.tests import TEST_MODELS_DIR
from stanza.tests.classifiers.test_classifier import fake_embeddings
from stanza.tests.classifiers.test_data import train_file_with_trees, dev_file_with_trees
from stanza.models.common import utils
from stanza.tests.constituency.test_trainer import build_trainer, TREEBANK
# Module-level pytest marks: every test in this file is tagged with
# both the `pipeline` and the `travis` markers.
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
class TestConstituencyClassifier:
    """
    Tests for classifier models created with `--model_type constituency`
    """

    @pytest.fixture(scope="class")
    def constituency_model(self, fake_embeddings, tmp_path_factory):
        """
        Build a constituency parser, save it to a temp dir, and return the saved path

        The fixture is class scoped, so the tests in this class share one saved file.
        """
        parser_flags = [
            '--pattn_num_layers', '0',
            '--lattn_d_proj', '0',
            '--hidden_size', '20',
            '--delta_embedding_dim', '10',
        ]
        parser_trainer = build_trainer(str(fake_embeddings), *parser_flags, treebank=TREEBANK)

        model_dir = tmp_path_factory.mktemp("constituency")
        model_path = str(model_dir / "constituency.pt")
        # the optimizer is not written to the saved file
        parser_trainer.save(model_path, save_optimizer=False)
        return model_path

    @staticmethod
    def _model_file(args):
        """
        Path of the classifier file: args.save_name inside args.save_dir
        """
        return os.path.join(args.save_dir, args.save_name)

    def build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):
        """
        Create a new Constituency Classifier for the other tests to work with

        extra_args, if not None, is a list of command line flags added after the defaults.

        Returns the trainer, the train set read from the train file, and the parsed args.
        """
        cli_args = [
            "--save_dir", str(tmp_path / "classifier"),
            "--save_name", "model.pt",
            "--model_type", "constituency",
            "--constituency_model", constituency_model,
            "--wordvec_pretrain_file", str(fake_embeddings),
            "--fc_shapes", "20,10",
            "--train_file", str(train_file_with_trees),
            "--dev_file", str(dev_file_with_trees),
            "--max_epochs", "2",
            "--batch_size", "60",
        ]
        if extra_args is not None:
            cli_args = cli_args + extra_args

        parsed = classifier.parse_args(cli_args)
        training_data = data.read_dataset(parsed.train_file, parsed.wordvec_type, parsed.min_train_len)
        new_trainer = Trainer.build_new_model(parsed, training_data)
        return new_trainer, training_data, parsed

    def run_training(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees, extra_args=None):
        """
        Build a model, then run it through classifier.train_model

        Returns the same (trainer, train set, args) triple as build_model.
        """
        model, training_data, parsed = self.build_model(
            tmp_path,
            constituency_model,
            fake_embeddings,
            train_file_with_trees,
            dev_file_with_trees,
            extra_args,
        )
        dev_data = data.read_dataset(parsed.dev_file, parsed.wordvec_type, parsed.min_train_len)
        label_set = data.dataset_labels(training_data)

        model_file = self._model_file(parsed)
        checkpoint_path = utils.checkpoint_name(parsed.save_dir, model_file, parsed.checkpoint_save_name)

        classifier.train_model(model, model_file, checkpoint_path, parsed, training_data, dev_data, label_set)
        return model, training_data, parsed

    def test_build_model(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Building a basic constituency-based model should not raise
        """
        self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)

    def test_save_load(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Save a newly built constituency classifier, then load it again

        Nothing is asserted: the test passes as long as neither step raises.
        """
        model, _, parsed = self.build_model(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
        model.save(self._model_file(parsed))

        # the load is requested using only the save name, not the joined path
        parsed.load_name = parsed.save_name
        model = Trainer.load(parsed.load_name, parsed)

    def test_train_basic(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Run training with the default test arguments
        """
        self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)

    def test_train_pipeline(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Train and write out a model, then check it can be used from a Pipeline
        """
        _, _, parsed = self.run_training(tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)

        classifier_file = self._model_file(parsed)
        for required_file in (classifier_file, parsed.constituency_model):
            assert os.path.exists(required_file)

        pretrain_file = parsed.wordvec_pretrain_file
        pipeline_config = {
            "lang": "en",
            "download_method": None,
            "model_dir": TEST_MODELS_DIR,
            "processors": "tokenize,pos,constituency,sentiment",
            "tokenize_pretokenized": True,
            "constituency_model_path": parsed.constituency_model,
            "constituency_pretrain_path": pretrain_file,
            "constituency_backward_charlm_path": None,
            "constituency_forward_charlm_path": None,
            "sentiment_model_path": classifier_file,
            "sentiment_pretrain_path": pretrain_file,
            "sentiment_backward_charlm_path": None,
            "sentiment_forward_charlm_path": None,
        }
        nlp = stanza.Pipeline(**pipeline_config)
        doc = nlp("This is a test")

        # the model is random, so only check that some sentiment was produced
        assert doc.sentences[0].sentiment is not None

    def test_train_all_words(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Run training with --constituency_all_words, then with --no_constituency_all_words
        """
        fixtures = (tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
        for flag in ('--constituency_all_words', '--no_constituency_all_words'):
            self.run_training(*fixtures, [flag])

    def test_train_top_layer(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Run training with --constituency_top_layer, then with --no_constituency_top_layer
        """
        fixtures = (tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
        for flag in ('--constituency_top_layer', '--no_constituency_top_layer'):
            self.run_training(*fixtures, [flag])

    def test_train_attn(self, tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees):
        """
        Run training with node attention on (without and with all_words), then with it off
        """
        fixtures = (tmp_path, constituency_model, fake_embeddings, train_file_with_trees, dev_file_with_trees)
        flag_variants = (
            ['--constituency_node_attn', '--no_constituency_all_words'],
            ['--constituency_node_attn', '--constituency_all_words'],
            ['--no_constituency_node_attn'],
        )
        for flags in flag_variants:
            self.run_training(*fixtures, flags)