import json

import pytest

import stanza.models.classifiers.data as data
from stanza.models.classifiers.utils import WVType
from stanza.models.common.vocab import PAD, UNK
from stanza.models.constituency.parse_tree import Tree

SENTENCES = [
    ["I", "hate", "the", "Opal", "banning"],
    ["Tell", "my", "wife", "hello"],
    ["I", "like", "Sh'reyan", "'s", "antennae"],
]

DATASET = [
    {"sentiment": "0", "text": SENTENCES[0]},
    {"sentiment": "1", "text": SENTENCES[1]},
    {"sentiment": "2", "text": SENTENCES[2]},
]

TREES = [
    "(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))",
    "(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))",
    "(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))",
]

DATASET_WITH_TREES = [
    {"sentiment": "0", "text": SENTENCES[0], "constituency": TREES[0]},
    {"sentiment": "1", "text": SENTENCES[1], "constituency": TREES[1]},
    {"sentiment": "2", "text": SENTENCES[2], "constituency": TREES[2]},
]


@pytest.fixture(scope="module")
def train_file(tmp_path_factory):
    train_set = DATASET * 20
    train_filename = tmp_path_factory.mktemp("data") / "train.json"
    with open(train_filename, "w", encoding="utf-8") as fout:
        json.dump(train_set, fout, ensure_ascii=False)
    return train_filename


@pytest.fixture(scope="module")
def dev_file(tmp_path_factory):
    dev_set = DATASET * 2
    dev_filename = tmp_path_factory.mktemp("data") / "dev.json"
    with open(dev_filename, "w", encoding="utf-8") as fout:
        json.dump(dev_set, fout, ensure_ascii=False)
    return dev_filename


@pytest.fixture(scope="module")
def test_file(tmp_path_factory):
    test_set = DATASET
    test_filename = tmp_path_factory.mktemp("data") / "test.json"
    with open(test_filename, "w", encoding="utf-8") as fout:
        json.dump(test_set, fout, ensure_ascii=False)
    return test_filename


@pytest.fixture(scope="module")
def train_file_with_trees(tmp_path_factory):
    train_set = DATASET_WITH_TREES * 20
    train_filename = tmp_path_factory.mktemp("data") / "train_trees.json"
    with open(train_filename, "w", encoding="utf-8") as fout:
        json.dump(train_set, fout, ensure_ascii=False)
    return train_filename


@pytest.fixture(scope="module")
def dev_file_with_trees(tmp_path_factory):
    dev_set = DATASET_WITH_TREES * 2
    dev_filename = tmp_path_factory.mktemp("data") / "dev_trees.json"
    with open(dev_filename, "w", encoding="utf-8") as fout:
        json.dump(dev_set, fout, ensure_ascii=False)
    return dev_filename


class TestClassifierData:
    def test_read_data(self, train_file):
        """
        Test reading of the json format
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        assert len(train_set) == 60

    def test_read_data_with_trees(self, train_file, train_file_with_trees):
        """
        Test reading of the json format with constituency trees attached
        """
        train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1)
        assert len(train_trees_set) == 60
        for idx, x in enumerate(train_trees_set):
            assert isinstance(x.constituency, Tree)
            assert str(x.constituency) == TREES[idx % len(TREES)]

        # the original file without trees should still read cleanly
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        assert len(train_set) == 60

    def test_dataset_vocab(self, train_file):
        """
        Converting a dataset to a vocab should produce a specific set of words, plus PAD and UNK
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        vocab = data.dataset_vocab(train_set)
        expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y])
        assert set(vocab) == expected

    def test_dataset_labels(self, train_file):
        """
        Test the extraction of labels from a dataset
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        labels = data.dataset_labels(train_set)
        assert labels == ["0", "1", "2"]

    def test_sort_by_length(self, train_file):
        """
        There are two unique sentence lengths in the toy dataset
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        sorted_dataset = data.sort_dataset_by_len(train_set)
        # one of the three sentences has 4 words; the other two have 5
        assert list(sorted_dataset.keys()) == [4, 5]
        assert len(sorted_dataset[4]) == len(train_set) // 3
        assert len(sorted_dataset[5]) == 2 * len(train_set) // 3

    def test_check_labels(self, train_file):
        """
        Check that an exception is thrown for an unknown label
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        labels = sorted(set([x["sentiment"] for x in DATASET]))
        assert len(labels) > 1
        data.check_labels(labels, train_set)
        with pytest.raises(RuntimeError):
            data.check_labels(labels[:1], train_set)
|
|
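# Optional convenience, assuming only the pytest dependency imported above:
# lets this test file be run directly with the Python interpreter, in which case
# pytest.main collects and runs just the tests defined in this file.
if __name__ == '__main__':
    pytest.main([__file__])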