"""Tests for loading a small BERT model and extracting embeddings with it."""

import pytest
import torch

from stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings

pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

BERT_MODEL = "hf-internal-testing/tiny-bert"


@pytest.fixture(scope="module")
def tiny_bert():
    """Load the model and tokenizer for BERT_MODEL once per test module."""
    return load_bert(BERT_MODEL)


def _extract(tiny_bert, sentences):
    """
    Run extract_bert_embeddings on `sentences` using the fixture's model.

    The device is taken from the model's own parameters so the call works
    wherever load_bert happened to place the model.  The trailing True is
    the same final positional argument every test in this file passes.
    """
    model, tokenizer = tiny_bert
    device = next(model.parameters()).device
    return extract_bert_embeddings(BERT_MODEL, tokenizer, model, sentences, device, True)


def test_load_bert(tiny_bert):
    """
    Check that loading the bert produces both a model and a tokenizer
    """
    model, tokenizer = tiny_bert
    assert model is not None
    assert tokenizer is not None


def test_run_bert(tiny_bert):
    """
    Check that extraction runs on one sentence and yields one result for it
    """
    embeddings = _extract(tiny_bert, [["This", "is", "a", "test"]])
    assert len(embeddings) == 1


def test_run_bert_empty_word(tiny_bert):
    """
    Check that an empty word gets the same embeddings as a "-" in its place
    """
    dash_embeddings = _extract(tiny_bert, [["This", "is", "-", "a", "test"]])
    empty_embeddings = _extract(tiny_bert, [["This", "is", "", "a", "test"]])

    assert len(dash_embeddings) == 1
    assert len(empty_embeddings) == 1
    # torch.allclose broadcasts, so compare shapes explicitly first; this also
    # gives a readable failure if the empty word is dropped from the output
    assert dash_embeddings[0].shape == empty_embeddings[0].shape
    assert torch.allclose(dash_embeddings[0], empty_embeddings[0])