|
|
import pytest |
|
|
import stanza |
|
|
|
|
|
from stanza.tests import * |
|
|
from stanza.models.common.data import get_augment_ratio, augment_punct |
|
|
|
|
|
# Module-level pytest marks: pytest applies every mark listed in `pytestmark`
# to all tests collected from this module.
pytestmark = [
    pytest.mark.travis,
    pytest.mark.pipeline,
]
|
|
|
|
|
def test_augment_ratio():
    """Check get_augment_ratio on the integers 1..10 with threshold predicates.

    Three calls are checked, in order:

    * a desired ratio of 0.1 must return exactly 0.0;
    * passing the two predicates in swapped positions (and no desired
      ratio) must raise AssertionError;
    * a desired ratio of 0.4 must return approximately 2/7.
    """
    samples = list(range(1, 11))

    def needs_augmenting(x):
        return x >= 3

    def is_augmentable(x):
        return x >= 4

    # Only 1 and 2 fail needs_augmenting, i.e. 2 of the 10 samples.
    # Presumably that already satisfies a 0.1 ratio, hence the expected 0.0
    # -- confirm against get_augment_ratio.
    ratio_when_satisfied = get_augment_ratio(
        samples, needs_augmenting, is_augmentable, desired_ratio=0.1
    )
    assert ratio_when_satisfied == 0.0

    # The predicates are deliberately passed in the opposite positions here.
    with pytest.raises(AssertionError):
        get_augment_ratio(samples, is_augmentable, needs_augmenting)

    # 2 samples fail needs_augmenting and 7 samples (4..10) satisfy
    # is_augmentable, which lines up with the expected value of 2/7.
    ratio_when_short = get_augment_ratio(
        samples, needs_augmenting, is_augmentable, desired_ratio=0.4
    )
    assert ratio_when_short == pytest.approx(2 / 7)
|
|
|
|
|
def test_augment_punct():
    """Check augment_punct on a single three-token sentence ending in ".".

    The same predicate (last token equals ".") is supplied for both
    predicate arguments, and the expected result is the input sentence
    without its final "." token.
    """
    sentences = [["Simple", "test", "."]]

    def ends_with_period(x):
        return x[-1] == "."

    # 1.0 is the second positional argument -- presumably the fraction of
    # sentences to augment; confirm against augment_punct.
    augmented = augment_punct(sentences, 1.0, ends_with_period, ends_with_period)
    assert augmented == [["Simple", "test"]]
|
|
|