import json

import pytest

import stanza.models.classifiers.data as data
from stanza.models.classifiers.utils import WVType
from stanza.models.common.vocab import PAD, UNK
from stanza.models.constituency.parse_tree import Tree

SENTENCES = [
    ["I", "hate", "the", "Opal", "banning"],
    ["Tell", "my", "wife", "hello"],   # obviously this is the neutral result
    ["I", "like", "Sh'reyan", "'s", "antennae"],
]

DATASET = [
    {"sentiment": "0", "text": SENTENCES[0]},
    {"sentiment": "1", "text": SENTENCES[1]},
    {"sentiment": "2", "text": SENTENCES[2]},
]

TREES = [
    "(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))",
    "(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))",
    "(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))",
]

DATASET_WITH_TREES = [
    {"sentiment": "0", "text": SENTENCES[0], "constituency": TREES[0]},
    {"sentiment": "1", "text": SENTENCES[1], "constituency": TREES[1]},
    {"sentiment": "2", "text": SENTENCES[2], "constituency": TREES[2]},
]
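# The fixtures below dump these records to temporary json files in the
# shape the classifier's read_dataset consumes: a list of dicts, each
# with a "sentiment" label, a pre-tokenized "text" list, and optionally
# a bracketed "constituency" parse.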
@pytest.fixture(scope="module")
def train_file(tmp_path_factory):
    train_set = DATASET * 20
    train_filename = tmp_path_factory.mktemp("data") / "train.json"
    with open(train_filename, "w", encoding="utf-8") as fout:
        json.dump(train_set, fout, ensure_ascii=False)
    return train_filename

@pytest.fixture(scope="module")
def dev_file(tmp_path_factory):
    dev_set = DATASET * 2
    dev_filename = tmp_path_factory.mktemp("data") / "dev.json"
    with open(dev_filename, "w", encoding="utf-8") as fout:
        json.dump(dev_set, fout, ensure_ascii=False)
    return dev_filename

@pytest.fixture(scope="module")
def test_file(tmp_path_factory):
    test_set = DATASET
    test_filename = tmp_path_factory.mktemp("data") / "test.json"
    with open(test_filename, "w", encoding="utf-8") as fout:
        json.dump(test_set, fout, ensure_ascii=False)
    return test_filename

@pytest.fixture(scope="module")
def train_file_with_trees(tmp_path_factory):
    train_set = DATASET_WITH_TREES * 20
    train_filename = tmp_path_factory.mktemp("data") / "train_trees.json"
    with open(train_filename, "w", encoding="utf-8") as fout:
        json.dump(train_set, fout, ensure_ascii=False)
    return train_filename

@pytest.fixture(scope="module")
def dev_file_with_trees(tmp_path_factory):
    dev_set = DATASET_WITH_TREES * 2
    dev_filename = tmp_path_factory.mktemp("data") / "dev_trees.json"
    with open(dev_filename, "w", encoding="utf-8") as fout:
        json.dump(dev_set, fout, ensure_ascii=False)
    return dev_filename
class TestClassifierData:
    def test_read_data(self, train_file):
        """
        Test reading of the json format
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        # the train fixture writes DATASET (3 items) repeated 20 times
        assert len(train_set) == 60
    def test_read_data_with_trees(self, train_file, train_file_with_trees):
        """
        Test reading of the json format when constituency trees are attached
        """
        train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1)
        assert len(train_trees_set) == 60
        for idx, x in enumerate(train_trees_set):
            assert isinstance(x.constituency, Tree)
            assert str(x.constituency) == TREES[idx % len(TREES)]

        # reading the original dataset without trees should still work
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
    def test_dataset_vocab(self, train_file):
        """
        The vocab built from a dataset should be exactly the dataset's words, lowercased, plus PAD and UNK
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        vocab = data.dataset_vocab(train_set)
        expected = set([PAD, UNK] + [x.lower() for y in SENTENCES for x in y])
        assert set(vocab) == expected
    def test_dataset_labels(self, train_file):
        """
        Test the extraction of labels from a dataset
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        labels = data.dataset_labels(train_set)
        assert labels == ["0", "1", "2"]
    def test_sort_by_length(self, train_file):
        """
        There are two unique lengths in the toy dataset
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        sorted_dataset = data.sort_dataset_by_len(train_set)
        # SENTENCES have lengths 5, 4, 5, so the 60 item train set has
        # 20 items of length 4 and 40 items of length 5
        assert list(sorted_dataset.keys()) == [4, 5]
        assert len(sorted_dataset[4]) == len(train_set) // 3
        assert len(sorted_dataset[5]) == 2 * len(train_set) // 3
    def test_check_labels(self, train_file):
        """
        Check that an exception is thrown for an unknown label
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        labels = sorted(set([x["sentiment"] for x in DATASET]))
        assert len(labels) > 1
        data.check_labels(labels, train_set)
        with pytest.raises(RuntimeError):
            data.check_labels(labels[:1], train_set)