|
|
"""Functions related to BERT or similar models""" |
|
|
|
|
|
import logging |
|
|
from typing import List, Tuple |
|
|
|
|
|
import numpy as np |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
from stanza.models.coref.config import Config |
|
|
from stanza.models.coref.const import Doc |
|
|
|
|
|
|
|
|
logger = logging.getLogger('stanza') |
|
|
|
|
|
def get_subwords_batches(doc: Doc,
                         config: Config,
                         tok: AutoTokenizer
                         ) -> np.ndarray:
    """
    Turns a list of subwords to a list of lists of subword indices
    of max length == batch_size (or shorter, as batch boundaries
    should match sentence boundaries). Each batch is enclosed in cls and sep
    special tokens.

    Args:
        doc: coref document mapping with "subwords" (list of subword
            strings), "word_id" (subword index -> word index) and
            "sent_id" (word index -> sentence id).
        config: coref config; only ``bert_window_size`` is read here.
        tok: tokenizer of the underlying transformer model; supplies the
            special tokens and ``convert_tokens_to_ids``.

    Returns:
        batches of bert tokens [n_batches, batch_size]
    """
    # Reserve two slots per window for the cls/sep boundary tokens.
    batch_size = config.bert_window_size - 2

    subwords: List[str] = doc["subwords"]
    subwords_batches = []
    start, end = 0, 0

    while end < len(subwords):
        prev_end = end
        end = min(end + batch_size, len(subwords))

        # Pull the batch boundary back to the start of the sentence that
        # would otherwise be split, so batches align with sentence breaks.
        if end < len(subwords):
            sent_id = doc["sent_id"][doc["word_id"][end]]
            while end and doc["sent_id"][doc["word_id"][end - 1]] == sent_id:
                end -= 1

        # A single sentence longer than batch_size would leave
        # end == prev_end and loop forever; in that case split the
        # sentence mid-way instead.
        if end == prev_end:
            end = min(end + batch_size, len(subwords))

        length = end - start
        # Some tokenizers (e.g. GPT-style) define no cls/sep tokens;
        # fall back to enclosing the batch in eos tokens.
        if tok.cls_token is None or tok.sep_token is None:
            batch = [tok.eos_token] + subwords[start:end] + [tok.eos_token]
        else:
            batch = [tok.cls_token] + subwords[start:end] + [tok.sep_token]
        # Right-pad every batch to a uniform width so np.array below
        # yields a rectangular [n_batches, batch_size] matrix.
        batch += [tok.pad_token] * (batch_size - length)

        subwords_batches.append([tok.convert_tokens_to_ids(token)
                                 for token in batch])
        start += length

    return np.array(subwords_batches)
|
|
|