# stanza-digphil/stanza/tests/common/test_common_data.py
# Albin Thörn Cleland
# Clean initial commit with LFS
# 19b8775
import pytest
import stanza
from stanza.tests import *
from stanza.models.common.data import get_augment_ratio, augment_punct
# Module-level pytest markers: every test in this file is tagged with
# both the "travis" and the "pipeline" marks.
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
def test_augment_ratio():
    """
    Check get_augment_ratio on the integers 1 through 10.

    Items >= 3 are treated as needing augmentation and items >= 4 as
    being eligible for it, so 2 items need nothing and 7 are eligible.
    """
    def needs_augment(item):
        return item >= 3

    def is_augmentable(item):
        return item >= 4

    numbers = list(range(1, 11))

    # Two of the ten items already need no augmentation, which covers
    # a desired ratio of 0.1, so no augmentation should be requested.
    no_augment_ratio = get_augment_ratio(numbers, needs_augment, is_augmentable, desired_ratio=0.1)
    assert no_augment_ratio == 0.0

    # Passing the predicates in the wrong order is expected to be rejected
    # (presumably because the item 3 would then be eligible for
    # augmentation without needing it).
    with pytest.raises(AssertionError):
        get_augment_ratio(numbers, is_augmentable, needs_augment)

    # For a desired ratio of 0.4, two items already need no augmentation
    # and seven items are eligible, so 2/7 of the eligible items
    # should be augmented.
    expected_ratio = pytest.approx(2/7)
    partial_ratio = get_augment_ratio(numbers, needs_augment, is_augmentable, desired_ratio=0.4)
    assert partial_ratio == expected_ratio
def test_augment_punct():
    """
    Check that augment_punct strips the final period from a sentence.

    The same predicate (sentence ends with ".") is used both to decide
    whether a sentence should be augmented and whether it can be.
    """
    def ends_with_period(sentence):
        return sentence[-1] == "."

    sentences = [["Simple", "test", "."]]
    expected = [["Simple", "test"]]

    # second argument is 1.0 - presumably the augmentation ratio,
    # so that the single sentence is always augmented
    augmented = augment_punct(sentences, 1.0, ends_with_period, ends_with_period)
    assert augmented == expected