File size: 991 Bytes
19b8775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import pytest
import torch

from stanza.models.common.bert_embedding import load_bert, extract_bert_embeddings

# Module-level marks: pytest applies `travis` and `pipeline` to every test in this file.
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

# Model name passed to load_bert and extract_bert_embeddings by the tests below.
# NOTE(review): presumably a very small Hugging Face Hub checkpoint chosen to keep
# the tests fast - confirm.
BERT_MODEL = "hf-internal-testing/tiny-bert"

@pytest.fixture(scope="module")
def tiny_bert():
    """
    Load BERT_MODEL once per test module and share it between the tests

    Returns the two values produced by load_bert as a pair
    (presumably the model followed by its tokenizer).
    """
    bert_model, bert_tokenizer = load_bert(BERT_MODEL)
    return bert_model, bert_tokenizer

def test_load_bert(tiny_bert):
    """
    Empty method that just tests loading the bert
    """
    m, t = tiny_bert

def test_run_bert(tiny_bert):
    """
    Smoke test: run extract_bert_embeddings on one short sentence

    Nothing is asserted about the result; the test passes as long as
    the call does not raise.
    """
    model, tokenizer = tiny_bert
    sentences = [["This", "is", "a", "test"]]
    # run on whichever device the loaded model's parameters live on
    model_device = next(model.parameters()).device
    extract_bert_embeddings(BERT_MODEL, tokenizer, model, sentences, model_device, True)

def test_run_bert_empty_word(tiny_bert):
    """
    An empty-string word should get the same embeddings as a "-" word

    The same sentence is run twice, once with "-" and once with "" in
    the third position.  The "-" run must produce exactly one result,
    and the first result of each run must be numerically close.
    """
    model, tokenizer = tiny_bert
    model_device = next(model.parameters()).device

    def embed(words):
        # wrap the single sentence in a list, as the extractor is given a list of sentences
        return extract_bert_embeddings(BERT_MODEL, tokenizer, model, [words], model_device, True)

    dash_result = embed(["This", "is", "-", "a", "test"])
    empty_result = embed(["This", "is", "", "a", "test"])

    assert len(dash_result) == 1
    assert torch.allclose(dash_result[0], empty_result[0])