File size: 4,949 Bytes
19b8775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import pytest

import stanza.models.classifiers.data as data
from stanza.models.classifiers.utils import WVType
from stanza.models.common.vocab import PAD, UNK
from stanza.models.constituency.parse_tree import Tree

# Three toy pre-tokenized sentences, one per sentiment class (negative / neutral / positive).
SENTENCES = [
    ["I", "hate", "the", "Opal", "banning"],
    ["Tell", "my", "wife", "hello"], # obviously this is the neutral result
    ["I", "like", "Sh'reyan", "'s", "antennae"],
]

# Minimal classifier dataset: one example per sentiment label "0", "1", "2",
# matching SENTENCES by index.  Labels are strings, as in the on-disk format.
DATASET = [
    {"sentiment": "0", "text": SENTENCES[0]},
    {"sentiment": "1", "text": SENTENCES[1]},
    {"sentiment": "2", "text": SENTENCES[2]},
]

# PTB-style constituency parses of the same three SENTENCES, stored as
# bracketed strings the reader is expected to parse into Tree objects.
TREES = [
    "(ROOT (S (NP (PRP I)) (VP (VBP hate) (NP (DT the) (NN Opal) (NN banning)))))",
    "(ROOT (S (VP (VB Tell) (NP (PRP$ my) (NN wife)) (NP (UH hello)))))",
    "(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NP (NNP Sh'reyan) (POS 's)) (NNS antennae)))))",
]

# Same examples as DATASET but with a "constituency" field attached,
# exercising the tree-reading path of the dataset loader.
DATASET_WITH_TREES = [
    {"sentiment": "0", "text": SENTENCES[0], "constituency": TREES[0]},
    {"sentiment": "1", "text": SENTENCES[1], "constituency": TREES[1]},
    {"sentiment": "2", "text": SENTENCES[2], "constituency": TREES[2]},
]

@pytest.fixture(scope="module")
def train_file(tmp_path_factory):
    """Write DATASET repeated 20x (60 items) to a temp train.json and return its path."""
    destination = tmp_path_factory.mktemp("data") / "train.json"
    destination.write_text(json.dumps(DATASET * 20, ensure_ascii=False), encoding="utf-8")
    return destination

@pytest.fixture(scope="module")
def dev_file(tmp_path_factory):
    """Write DATASET repeated 2x (6 items) to a temp dev.json and return its path."""
    destination = tmp_path_factory.mktemp("data") / "dev.json"
    destination.write_text(json.dumps(DATASET * 2, ensure_ascii=False), encoding="utf-8")
    return destination

@pytest.fixture(scope="module")
def test_file(tmp_path_factory):
    """Write the 3-item DATASET to a temp test.json and return its path."""
    destination = tmp_path_factory.mktemp("data") / "test.json"
    destination.write_text(json.dumps(DATASET, ensure_ascii=False), encoding="utf-8")
    return destination

@pytest.fixture(scope="module")
def train_file_with_trees(tmp_path_factory):
    """Write DATASET_WITH_TREES repeated 20x to a temp train_trees.json and return its path."""
    destination = tmp_path_factory.mktemp("data") / "train_trees.json"
    destination.write_text(json.dumps(DATASET_WITH_TREES * 20, ensure_ascii=False), encoding="utf-8")
    return destination

@pytest.fixture(scope="module")
def dev_file_with_trees(tmp_path_factory):
    """Write DATASET_WITH_TREES repeated 2x to a temp dev_trees.json and return its path."""
    destination = tmp_path_factory.mktemp("data") / "dev_trees.json"
    destination.write_text(json.dumps(DATASET_WITH_TREES * 2, ensure_ascii=False), encoding="utf-8")
    return destination

class TestClassifierData:
    """Tests for the classifier dataset reader and its helper utilities."""

    def test_read_data(self, train_file):
        """
        Test reading of the json format
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        assert len(train_set) == 60

    def test_read_data_with_trees(self, train_file_with_trees):
        """
        Reading a dataset with "constituency" fields should parse each one into a Tree
        """
        train_trees_set = data.read_dataset(str(train_file_with_trees), WVType.OTHER, 1)
        assert len(train_trees_set) == 60
        for idx, item in enumerate(train_trees_set):
            assert isinstance(item.constituency, Tree)
            # the dataset is DATASET_WITH_TREES * 20, so trees repeat every len(TREES) items
            assert str(item.constituency) == TREES[idx % len(TREES)]

    def test_dataset_vocab(self, train_file):
        """
        Converting a dataset to vocab should have a specific set of words along with PAD and UNK
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        vocab = data.dataset_vocab(train_set)
        # the reader lowercases, so compare against lowercased tokens
        expected = {PAD, UNK} | {x.lower() for y in SENTENCES for x in y}
        assert set(vocab) == expected

    def test_dataset_labels(self, train_file):
        """
        Test the extraction of labels from a dataset
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        labels = data.dataset_labels(train_set)
        assert labels == ["0", "1", "2"]

    def test_sort_by_length(self, train_file):
        """
        There are two unique lengths in the toy dataset
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        sorted_dataset = data.sort_dataset_by_len(train_set)
        assert list(sorted_dataset.keys()) == [4, 5]
        # one of the three examples has 4 tokens, the other two have 5
        assert len(sorted_dataset[4]) == len(train_set) // 3
        assert len(sorted_dataset[5]) == 2 * len(train_set) // 3

    def test_check_labels(self, train_file):
        """
        Check that an exception is thrown for an unknown label
        """
        train_set = data.read_dataset(str(train_file), WVType.OTHER, 1)
        labels = sorted({x["sentiment"] for x in DATASET})
        assert len(labels) > 1
        # full label set is fine; a truncated label set must raise
        data.check_labels(labels, train_set)
        with pytest.raises(RuntimeError):
            data.check_labels(labels[:1], train_set)