# stanza/tests/classifiers/test_process_utils.py
# Author: Albin Thörn Cleland
# (Initial commit 19b8775, "Clean initial commit with LFS")
"""
A few tests of the utils module for the sentiment datasets
"""
import os
import pytest
import stanza
from stanza.models.classifiers import data
from stanza.models.classifiers.data import SentimentDatum
from stanza.models.classifiers.utils import WVType
from stanza.utils.datasets.sentiment import process_utils
from stanza.tests import TEST_MODELS_DIR
from stanza.tests.classifiers.test_data import train_file, dev_file, test_file
def test_write_list(tmp_path, train_file):
    """
    Check that a single list of items survives a write/read round trip.

    Reads the train fixture, writes it back out with write_list, rereads
    the written file, and verifies the two datasets compare equal.
    """
    original = data.read_dataset(train_file, WVType.OTHER, 1)
    output_path = tmp_path / "foo.json"
    process_utils.write_list(output_path, original)
    reread = data.read_dataset(output_path, WVType.OTHER, 1)
    assert original == reread
def test_write_dataset(tmp_path, train_file, dev_file, test_file):
    """
    Check that writing all three splits of a dataset works.

    write_dataset should produce exactly one json file per split, named
    after the dataset shortname, and each file should reread to a dataset
    equal to the one written.
    """
    splits = []
    for split_file in (train_file, dev_file, test_file):
        splits.append(data.read_dataset(split_file, WVType.OTHER, 1))
    process_utils.write_dataset(splits, tmp_path, "en_test")

    expected_files = ['en_test.train.json', 'en_test.dev.json', 'en_test.test.json']
    # exactly the three expected files, nothing more
    assert sorted(os.listdir(tmp_path)) == sorted(expected_files)

    for written_name, original in zip(expected_files, splits):
        reread = data.read_dataset(tmp_path / written_name, WVType.OTHER, 1)
        assert reread == original
def test_read_snippets(tmp_path):
    """
    Basic operation of read_snippets on a small tab-separated file.

    Builds a two line TSV, tokenizes it with an English pipeline, and
    checks the resulting SentimentDatum list against the expected labels
    and tokenizations.
    """
    snippet_file = tmp_path / "foo.csv"
    rows = ["FOO\tThis is a test\thappy\n",
            "FOO\tThis is a second sentence\tsad\n"]
    with open(snippet_file, "w", encoding="utf-8") as fout:
        fout.writelines(rows)

    nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
    mapping = {"happy": 0, "sad": 1}
    # column 2 is the sentiment label, column 1 the text
    snippets = process_utils.read_snippets(snippet_file, 2, 1, "en", mapping, nlp=nlp)

    expected = [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),
                SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence'])]
    assert len(snippets) == 2
    assert snippets == expected
def test_read_snippets_two_columns(tmp_path):
    """
    read_snippets with the sentiment taken from a combination of columns.

    When a tuple of column indices is given, the label is looked up in the
    mapping as a tuple of the values from those columns.
    """
    snippet_file = tmp_path / "foo.csv"
    rows = ["FOO\tThis is a test\thappy\tfoo\n",
            "FOO\tThis is a second sentence\tsad\tbar\n",
            "FOO\tThis is a third sentence\tsad\tfoo\n"]
    with open(snippet_file, "w", encoding="utf-8") as fout:
        fout.writelines(rows)

    nlp = stanza.Pipeline("en", dir=TEST_MODELS_DIR, processors="tokenize", download_method=None)
    # keys are (column 2, column 3) pairs
    mapping = {("happy", "foo"): 0, ("sad", "bar"): 1, ("sad", "foo"): 2}
    snippets = process_utils.read_snippets(snippet_file, (2,3), 1, "en", mapping, nlp=nlp)

    expected = [SentimentDatum(sentiment=0, text=['This', 'is', 'a', 'test']),
                SentimentDatum(sentiment=1, text=['This', 'is', 'a', 'second', 'sentence']),
                SentimentDatum(sentiment=2, text=['This', 'is', 'a', 'third', 'sentence'])]
    assert len(snippets) == 3
    assert snippets == expected