# stanza-digphil / stanza/tests/depparse/test_depparse_data.py
# Author: Albin Thörn Cleland
# Commit: 19b8775 — "Clean initial commit with LFS"
"""
Test some pieces of the depparse dataloader
"""
import pytest
from stanza.models.depparse.data import data_to_batches
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def make_fake_data(*lengths):
    """Build fake depparse-style data: one chunk per length.

    Each chunk is a list containing a single sentence, and each sentence
    is a list of identical one-letter words ('A' for the first length,
    'B' for the second, and so on), repeated `length` times.
    """
    return [[[chr(ord('A') + idx)] * length]
            for idx, length in enumerate(lengths)]
def check_batches(batched_data, expected_sizes, expected_order):
    """Assert that batched data has the expected batch sizes and word order.

    `expected_sizes` is the total word count expected in each batch;
    `expected_order` is the flattened sequence of first-words across all
    sentences in all batches.
    """
    for batch, expected_size in zip(batched_data, expected_sizes):
        total_words = sum(len(sentence[0]) for sentence in batch)
        assert total_words == expected_size
    observed_order = [sentence[0][0]
                      for batch in batched_data
                      for sentence in batch]
    assert observed_order == expected_order
def test_data_to_batches_eval_mode():
    """
    Tests the chunking of batches in eval_mode

    A few options are tested, such as whether or not to sort and the maximum sentence size
    """
    # (sentence lengths, sort_during_eval, min_length_to_batch_separately,
    #  expected batch sizes, expected word order)
    cases = [
        ((1, 2, 3), True,  None, [5, 1],     ['C', 'B', 'A']),
        ((1, 2, 6), True,  None, [6, 3],     ['C', 'B', 'A']),
        ((3, 2, 1), True,  None, [5, 1],     ['A', 'B', 'C']),
        ((3, 5, 2), True,  None, [5, 5],     ['B', 'A', 'C']),
        ((3, 5, 2), False, 3,    [3, 5, 2],  ['A', 'B', 'C']),
        ((4, 1, 1), False, 3,    [4, 2],     ['A', 'B', 'C']),
        ((1, 4, 1), False, 3,    [1, 4, 1],  ['A', 'B', 'C']),
    ]
    for lengths, sort_eval, min_separate, sizes, order in cases:
        data = make_fake_data(*lengths)
        batched_data = data_to_batches(data, batch_size=5, eval_mode=True,
                                       sort_during_eval=sort_eval,
                                       min_length_to_batch_separately=min_separate)
        check_batches(batched_data[0], sizes, order)
if __name__ == '__main__':
    # Bug fix: previously called test_data_to_batches(), which is not
    # defined anywhere in this module and raised NameError when the file
    # was run as a script.  The actual test function is below.
    test_data_to_batches_eval_mode()