"""Functions related to BERT or similar models"""
import logging
from typing import List, Tuple
import numpy as np # type: ignore
from transformers import AutoModel, AutoTokenizer # type: ignore
from stanza.models.coref.config import Config
from stanza.models.coref.const import Doc
logger = logging.getLogger('stanza')
def get_subwords_batches(doc: Doc,
                         config: Config,
                         tok: AutoTokenizer
                         ) -> np.ndarray:
"""
Turns a list of subwords to a list of lists of subword indices
of max length == batch_size (or shorter, as batch boundaries
should match sentence boundaries). Each batch is enclosed in cls and sep
special tokens.
Returns:
batches of bert tokens [n_batches, batch_size]
"""
    batch_size = config.bert_window_size - 2  # to save space for CLS and SEP

    subwords: List[str] = doc["subwords"]
    subwords_batches = []
    start, end = 0, 0

    while end < len(subwords):
        # to prevent the case where a batch_size step forward
        # doesn't capture more than one sentence, we will just cut
        # that sequence (see the fallback below)
        prev_end = end
        end = min(end + batch_size, len(subwords))

        # Move back till we hit a sentence end
        if end < len(subwords):
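            # doc["word_id"] maps a subword position to its word, and
            # doc["sent_id"] maps that word to its sentence, so this chain
            # recovers the sentence containing the subword at `end`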
            sent_id = doc["sent_id"][doc["word_id"][end]]
            while end and doc["sent_id"][doc["word_id"][end - 1]] == sent_id:
                end -= 1

        # this occurs IFF there was no sentence end found throughout
        # the forward scan; this means that our sentence was way too
        # long (i.e. longer than the max length of the transformer).
        #
        # if so, we give up and just chop the sentence off at the max
        # length that was given
        if end == prev_end:
            end = min(end + batch_size, len(subwords))

        length = end - start
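        # some tokenizers have no CLS/SEP (assumption: e.g. GPT-2-style
        # tokenizers expose only an EOS token); fall back to EOS on both ends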
        if tok.cls_token is None or tok.sep_token is None:
            batch = [tok.eos_token] + subwords[start:end] + [tok.eos_token]
        else:
            batch = [tok.cls_token] + subwords[start:end] + [tok.sep_token]
        # Padding to desired length
        batch += [tok.pad_token] * (batch_size - length)

        subwords_batches.append([tok.convert_tokens_to_ids(token)
                                 for token in batch])
        start += length

    return np.array(subwords_batches)
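
if __name__ == "__main__":
    # Hypothetical smoke test, not part of the original module: it exercises
    # the sentence-aligned windowing on a tiny hand-built Doc. The tokenizer
    # checkpoint and the SimpleNamespace stand-in for Config are assumptions
    # made for illustration only.
    from types import SimpleNamespace

    demo_tok = AutoTokenizer.from_pretrained("bert-base-cased")
    demo_doc = {
        "subwords": ["Hello", ",", "world", "!", "Good", "##bye", "."],
        "word_id": [0, 1, 2, 3, 4, 4, 5],  # subword position -> word index
        "sent_id": [0, 0, 0, 0, 1, 1],     # word index -> sentence index
    }
    demo_config = SimpleNamespace(bert_window_size=8)  # CLS + 6 subwords + SEP

    # Expect two rows of width 8: the first window is cut back to the
    # sentence boundary after "!", the second holds the remaining sentence.
    print(get_subwords_batches(demo_doc, demo_config, demo_tok))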