File size: 3,007 Bytes
19b8775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
import pytest

from stanza.models import ner_tagger
from stanza.models.common.doc import Document
from stanza.models.ner.data import DataLoader
from stanza.tests import TEST_WORKING_DIR

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]


# JSON serialization of a single tokenized sentence: a list of sentences,
# each sentence a list of token dicts carrying "text" and a BIO "ner" tag.
# The two-token "Mox Opal" span (B-MISC, I-MISC) gives the tests a
# multi-token entity; everything else is single-token or O.
ONE_SENTENCE = """
[
 [
  {
   "text": "EU",
   "ner": "B-ORG"
  },
  {
   "text": "rejects",
   "ner": "O"
  },
  {
   "text": "German",
   "ner": "B-MISC"
  },
  {
   "text": "call",
   "ner": "O"
  },
  {
   "text": "to",
   "ner": "O"
  },
  {
   "text": "boycott",
   "ner": "O"
  },
  {
   "text": "Mox",
   "ner": "B-MISC"
  },
  {
   "text": "Opal",
   "ner": "I-MISC"
  },
  {
   "text": ".",
   "ner": "O"
  }
 ]
]
"""

@pytest.fixture(scope="module")
def pretrain_file():
    """Path of the tiny pretrain embedding file shared by these tests."""
    return '{}/in/tiny_emb.pt'.format(TEST_WORKING_DIR)


@pytest.fixture(scope="module")
def one_sentence_json_path(tmpdir_factory):
    """Write ONE_SENTENCE to a temporary .json file and return its path.

    Module-scoped so the file is created once and reused by every test
    in this module.
    """
    filename = tmpdir_factory.mktemp('data').join("sentence.json")
    # Pin the encoding: JSON should be written as UTF-8, not whatever the
    # platform locale default happens to be.
    with open(filename, 'w', encoding='utf-8') as fout:
        fout.write(ONE_SENTENCE)
    return filename


def test_build_vocab(pretrain_file, one_sentence_json_path, tmp_path):
    """
    Test that when loading a data file, we get back the expected vocabs.

    The 'word' vocab reflects the pretrain's words (note 'unban', which is
    not in the training sentence), the 'delta' vocab holds the lowercased
    training words, and the BIO tags are converted to a BIOES 'tag' vocab
    (e.g. B-ORG becomes S-ORG for the single-token entity).
    """
    args = ner_tagger.parse_args(["--wordvec_pretrain_file", pretrain_file])
    pt = ner_tagger.load_pretrain(args)

    # read the JSON back with a pinned encoding to match how it was written
    with open(one_sentence_json_path, encoding='utf-8') as fin:
        train_doc = Document(json.load(fin))

    # building the DataLoader with vocab=None forces it to build the vocab
    # from the training document + pretrain
    train_batch = DataLoader(train_doc, args['batch_size'], args, pt, vocab=None, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])

    vocab = train_batch.vocab
    pt_words = list(vocab['word'])
    assert pt_words == ['<PAD>', '<UNK>', '<EMPTY>', '<ROOT>', 'unban', 'mox', 'opal']
    delta_words = list(vocab['delta'])
    assert delta_words == ['<PAD>', '<UNK>', '<EMPTY>', '<ROOT>', 'eu', 'rejects', 'german', 'call', 'to', 'boycott', 'mox', 'opal', '.']
    tags = list(vocab['tag'])
    assert tags == [['<PAD>'], ['<UNK>'], [], ['<ROOT>'], ['S-ORG'], ['O'], ['S-MISC'], ['B-MISC'], ['E-MISC']]


def test_build_vocab_ignore_repeats(pretrain_file, one_sentence_json_path, tmp_path):
    """
    Test building the vocab with --emb_finetune_known_only set.

    With that flag, the 'delta' vocab keeps only the training words that
    are also in the pretrain ('mox', 'opal'); the 'word' and 'tag' vocabs
    are unchanged from the default-flags case.
    """
    args = ner_tagger.parse_args(["--wordvec_pretrain_file", pretrain_file, "--emb_finetune_known_only"])
    pt = ner_tagger.load_pretrain(args)

    # read the JSON back with a pinned encoding to match how it was written
    with open(one_sentence_json_path, encoding='utf-8') as fin:
        train_doc = Document(json.load(fin))

    # building the DataLoader with vocab=None forces it to build the vocab
    # from the training document + pretrain
    train_batch = DataLoader(train_doc, args['batch_size'], args, pt, vocab=None, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])

    vocab = train_batch.vocab
    pt_words = list(vocab['word'])
    assert pt_words == ['<PAD>', '<UNK>', '<EMPTY>', '<ROOT>', 'unban', 'mox', 'opal']
    delta_words = list(vocab['delta'])
    assert delta_words == ['<PAD>', '<UNK>', '<EMPTY>', '<ROOT>', 'mox', 'opal']
    tags = list(vocab['tag'])
    assert tags == [['<PAD>'], ['<UNK>'], [], ['<ROOT>'], ['S-ORG'], ['O'], ['S-MISC'], ['B-MISC'], ['E-MISC']]