|
|
import lzma |
|
|
import os |
|
|
import tempfile |
|
|
|
|
|
import pytest |
|
|
|
|
|
import stanza |
|
|
import stanza.models.common.utils as utils |
|
|
from stanza.tests import * |
|
|
|
|
|
# Module-level pytest marks: applied to every test collected from this file
pytestmark = [
    pytest.mark.travis,
    pytest.mark.pipeline,
]
|
|
|
|
|
def test_wordvec_not_found():
    """
    An empty wordvec directory has neither word2vec nor fasttext, so the lookup must raise
    """
    scratch_root = f'{TEST_WORKING_DIR}/out'
    # the temp dir is created first, then the FileNotFoundError check wraps only the lookup
    with tempfile.TemporaryDirectory(dir=scratch_root) as temp_dir, pytest.raises(FileNotFoundError):
        utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
|
|
|
|
|
|
|
|
def test_word2vec_xz():
    """
    Test searching for word2vec and xz files
    """
    with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
        # build the <wordvec_dir>/word2vec/English directory layout
        word2vec_dir = os.path.join(temp_dir, 'word2vec', 'English')
        os.makedirs(word2vec_dir)

        # create an empty placeholder vectors file; only its existence matters,
        # and the context manager guarantees the handle is closed
        fake_file = os.path.join(word2vec_dir, 'en.vectors.xz')
        with open(fake_file, 'w'):
            pass

        # the lookup for shorthand 'en_foo' is expected to resolve to the placeholder
        filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
        assert filename == fake_file
|
|
|
|
|
def test_fasttext_txt():
    """
    Test searching for fasttext and txt files
    """
    with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
        # build the <wordvec_dir>/fasttext/English directory layout
        fasttext_dir = os.path.join(temp_dir, 'fasttext', 'English')
        os.makedirs(fasttext_dir)

        # create an empty placeholder vectors file; only its existence matters,
        # and the context manager guarantees the handle is closed
        fake_file = os.path.join(fasttext_dir, 'en.vectors.txt')
        with open(fake_file, 'w'):
            pass

        # the lookup for shorthand 'en_foo' is expected to resolve to the placeholder
        filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
        assert filename == fake_file
|
|
|
|
|
def test_wordvec_type():
    """
    If we supply our own wordvec type, get_wordvec_file should find that
    """
    with tempfile.TemporaryDirectory(dir=f'{TEST_WORKING_DIR}/out') as temp_dir:
        # build the <wordvec_dir>/google/English directory layout
        google_dir = os.path.join(temp_dir, 'google', 'English')
        os.makedirs(google_dir)

        # create an empty placeholder vectors file; only its existence matters,
        # and the context manager guarantees the handle is closed
        fake_file = os.path.join(google_dir, 'en.vectors.txt')
        with open(fake_file, 'w'):
            pass

        # with wordvec_type='google' the placeholder should be found
        filename = utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo', wordvec_type='google')
        assert filename == fake_file

        # without an explicit wordvec_type the same directory must not resolve
        with pytest.raises(FileNotFoundError):
            utils.get_wordvec_file(wordvec_dir=temp_dir, shorthand='en_foo')
|
|
|
|
|
def test_sort_with_indices():
    """
    Sort a few lists by length, check the recorded original positions, then undo the sort
    """
    data = [[1, 2, 3], [4, 5], [6]]

    by_length, positions = utils.sort_with_indices(data, key=len)
    # shortest item first, along with where each item sat in the input
    assert by_length == ([6], [4, 5], [1, 2, 3])
    assert positions == (2, 1, 0)

    # unsort should hand back the input order
    restored = utils.unsort(by_length, positions)
    assert data == restored
|
|
|
|
|
def test_empty_sort_with_indices():
    """
    Sorting and unsorting an empty list should yield empty results rather than failing
    """
    by_default, positions = utils.sort_with_indices([])
    assert len(by_default) == 0
    assert len(positions) == 0

    # round trip of the empty result
    restored = utils.unsort(by_default, positions)
    assert [] == restored
|
|
|
|
|
|
|
|
def test_split_into_batches():
    """
    Check the (start, end) batch boundaries produced for several batch sizes and orderings
    """
    # five 4-token items, then one 10-token item and one 2-token item
    data = [["Unban", "mox", "opal", str(i)] for i in range(5)]
    data.extend([
        ["Do", "n't", "ban", "Urza", "'s", "Saga", "that", "card", "is", "great"],
        ["Ban", "Ragavan"],
    ])

    # batch size 5: every item is expected to land in its own batch
    batches = utils.split_into_batches(data, 5)
    assert batches == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]

    # batch size 8: the 4-token items pair up, the rest stay alone
    batches = utils.split_into_batches(data, 8)
    assert batches == [(0, 2), (2, 4), (4, 5), (5, 6), (6, 7)]

    # same batch size with the first item dropped
    batches = utils.split_into_batches(data[1:], 8)
    assert batches == [(0, 2), (2, 4), (4, 5), (5, 6)]

    # longest-first ordering
    ordered, orig_idx = utils.sort_with_indices(data, key=len, reverse=True)
    lengths = [len(item) for item in ordered]
    assert lengths == [10, 4, 4, 4, 4, 4, 2]

    # batching the reordered data
    batches = utils.split_into_batches(ordered, 8)
    assert batches == [(0, 1), (1, 3), (3, 5), (5, 7)]

    # undoing the sort must recover the original order
    assert data == utils.unsort(ordered, orig_idx)
|
|
|
|
|
|
|
|
def test_find_missing_tags():
    """
    find_missing_tags should report tags present only in the second argument, for flat or nested lists
    """
    # identical tag sets: nothing is missing
    missing = utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC"])
    assert missing == []

    # one extra tag on the second side
    missing = utils.find_missing_tags(["O", "PER", "LOC"], ["O", "PER", "LOC", "ORG"])
    assert missing == ['ORG']

    # nested tag lists
    missing = utils.find_missing_tags([["O", "PER"], ["O", "LOC"]], [["O", "PER"], ["LOC", "ORG"]])
    assert missing == ['ORG']
|
|
|
|
|
|
|
|
def test_open_read_text():
    """
    test that we can read either .xz or regular txt

    For both file types, the handle must be closed after the context exits,
    whether it exits normally or because an exception was raised inside it.
    """
    TEXT = "this is a test"
    with tempfile.TemporaryDirectory() as tempdir:
        # plain text file
        filename = os.path.join(tempdir, "foo.txt")
        with open(filename, "w") as fout:
            fout.write(TEXT)
        with utils.open_read_text(filename) as fin:
            in_text = fin.read()
            assert TEXT == in_text

        assert fin.closed

        # an exception raised inside the context must propagate out of it
        # (pytest.raises fails the test if it is swallowed) and the file
        # must still end up closed
        with pytest.raises(ValueError):
            with utils.open_read_text(filename) as finex:
                assert not finex.closed
                raise ValueError("unban mox opal!")
        assert finex.closed

        # xz-compressed file
        filename = os.path.join(tempdir, "foo.txt.xz")
        with lzma.open(filename, "wt") as fout:
            fout.write(TEXT)
        with utils.open_read_text(filename) as finxz:
            in_text = finxz.read()
            assert TEXT == in_text

        assert finxz.closed

        # same exception check for the xz path
        with pytest.raises(ValueError):
            with utils.open_read_text(filename) as finexxz:
                assert not finexxz.closed
                raise ValueError("unban mox opal!")
        assert finexxz.closed
|
|
|
|
|
|
|
|
def test_checkpoint_name():
    """
    Check checkpoint_name against a handful of save name / explicit checkpoint name combinations
    """
    save_dir = "saved_models"
    cases = [
        # (save name, explicit checkpoint name, expected file name)
        ("kk_oscar_forward_charlm.pt", None, "kk_oscar_forward_charlm_checkpoint.pt"),
        ("kk_oscar_forward_charlm", None, "kk_oscar_forward_charlm_checkpoint"),
        ("kk_oscar_forward_charlm", "othername.pt", "othername.pt"),
    ]
    for save_name, explicit_name, expected_file in cases:
        checkpoint = utils.checkpoint_name(save_dir, save_name, explicit_name)
        # split so the check does not depend on the path separator
        assert os.path.split(checkpoint) == (save_dir, expected_file)
|
|
|
|
|
|