Albin Thörn Cleland
Clean initial commit with LFS
19b8775
import random
import logging
import torch
from stanza.models.common.bert_embedding import filter_data, needs_length_filter
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX, ROOT_ID, CompositeVocab, CharVocab
from stanza.models.pos.vocab import WordVocab, XPOSVocab, FeatureVocab, MultiVocab
from stanza.models.pos.xpos_vocab_factory import xpos_vocab_factory
from stanza.models.common.doc import *
# Module logger; the 'stanza' logger name is shared with the rest of the package.
logger = logging.getLogger('stanza')
def data_to_batches(data, batch_size, eval_mode, sort_during_eval, min_length_to_batch_separately):
    """
    Group sentences into batches whose summed lengths stay within batch_size.

    ``data`` is a list of lists where the first element of each sublist
    represents the sentence; ``len(item[0])`` is used as its length.

    Ordering before batching:
      * training (``eval_mode`` false): sorted by length, ascending or
        descending chosen at random;
      * evaluation with ``sort_during_eval``: sorted by length via
        ``sort_all``, which also returns the original positions;
      * evaluation without sorting: input order is kept.

    A sentence longer than ``min_length_to_batch_separately`` (when that is
    not None) always gets a batch of its own. Other sentences are appended to
    the current batch until adding the next one would push the summed length
    past ``batch_size``; a single sentence longer than ``batch_size`` still
    forms a (singleton) batch. Zero-length sentences are kept, never dropped.

    Refactored out of the DataLoader in case other models could use it and
    for ease of testing.

    Returns:
        (batches, original_order): ``batches`` is a list of lists of the
        input items. ``original_order`` is None in train mode or when
        unsorted; otherwise it is the index list from ``sort_all`` giving the
        original location of each sentence in the sort.
    """
    if not eval_mode:
        # sort sentences (roughly) by length for better memory utilization;
        # the direction is picked at random
        data = sorted(data, key=lambda x: len(x[0]), reverse=random.random() > .5)
        data_orig_idx = None
    elif sort_during_eval:
        (data, ), data_orig_idx = sort_all([data], [len(x[0]) for x in data])
    else:
        data_orig_idx = None

    res = []
    current = []
    currentlen = 0
    for x in data:
        length = len(x[0])
        if min_length_to_batch_separately is not None and length > min_length_to_batch_separately:
            # very long sentences always go into a batch of their own
            if current:
                res.append(current)
                current = []
                currentlen = 0
            res.append([x])
            continue
        # Close the current batch if this sentence would overflow it. An empty
        # batch is never closed, so an over-long sentence still gets a batch.
        if current and currentlen + length > batch_size:
            res.append(current)
            current = []
            currentlen = 0
        current.append(x)
        currentlen += length
    # Test `current`, not `currentlen`, so zero-length sentences are not lost.
    if current:
        res.append(current)
    return res, data_orig_idx
class DataLoader:
    """
    Turn a stanza Document into id-mapped, batched parser input.

    On construction the document is read (``load_doc``), the vocabularies are
    built unless one is supplied, optional BERT length filtering and
    training-set subsampling are applied, sentences are mapped to ids
    (``preprocess``) and grouped into batches (``chunk_batches``). Indexing
    or iterating the loader yields one tensor-converted batch at a time.
    """

    def __init__(self, doc, batch_size, args, pretrain, vocab=None, evaluation=False,
                 sort_during_eval=False, min_length_to_batch_separately=None, bert_tokenizer=None):
        """
        Args:
            doc: the Document whose sentences are loaded.
            batch_size: length budget per batch, passed to ``data_to_batches``.
            args: config dict; reads 'shorthand' and 'pretrain', and
                optionally 'bert_model' and 'sample_train'.
            pretrain: object exposing ``.vocab``, or None.
            vocab: MultiVocab to reuse; built from ``doc`` when None (only
                allowed when not evaluating).
            evaluation: True for evaluation data. This disables shuffling and
                subsampling, and makes non-integer heads map to 0 instead of
                raising.
            sort_during_eval: sort evaluation sentences by length; the
                original order is then kept in ``self.data_orig_idx``.
            min_length_to_batch_separately: sentences longer than this get a
                batch of their own.
            bert_tokenizer: passed to ``filter_data`` when the configured BERT
                model needs length filtering.
        """
        self.batch_size = batch_size
        self.min_length_to_batch_separately = min_length_to_batch_separately
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.sort_during_eval = sort_during_eval
        self.doc = doc
        data = self.load_doc(doc)

        # handle vocab
        if vocab is None:
            self.vocab = self.init_vocab(data)
        else:
            self.vocab = vocab

        # filter out the long sentences if bert is used
        # (a vocab built above has already seen the unfiltered sentences)
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        # pretrain vocab is used only when args['pretrain'] is truthy and pretrain is not None
        self.pretrain_vocab = None
        if pretrain is not None and args['pretrain']:
            self.pretrain_vocab = pretrain.vocab

        # optionally subsample the training set
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, self.pretrain_vocab, args)

        # shuffle for training
        if self.shuffled:
            random.shuffle(data)
        self.num_examples = len(data)

        # chunk into batches
        self.data = self.chunk_batches(data)
        logger.debug("{} batches created.".format(len(self.data)))

    def init_vocab(self, data):
        """
        Build the MultiVocab from training data.

        ``data`` comes from ``load_doc``, so token field indices are
        0 text, 1 upos, 2 xpos, 3 feats, 4 lemma, 5 head, 6 deprel.
        """
        # for eval the vocab must already exist
        assert not self.eval, "vocab must be provided when evaluating"
        charvocab = CharVocab(data, self.args['shorthand'])
        wordvocab = WordVocab(data, self.args['shorthand'], cutoff=7, lower=True)
        uposvocab = WordVocab(data, self.args['shorthand'], idx=1)
        xposvocab = xpos_vocab_factory(data, self.args['shorthand'])
        featsvocab = FeatureVocab(data, self.args['shorthand'], idx=3)
        lemmavocab = WordVocab(data, self.args['shorthand'], cutoff=7, idx=4, lower=True)
        deprelvocab = WordVocab(data, self.args['shorthand'], idx=6)
        vocab = MultiVocab({'char': charvocab,
                            'word': wordvocab,
                            'upos': uposvocab,
                            'xpos': xposvocab,
                            'feats': featsvocab,
                            'lemma': lemmavocab,
                            'deprel': deprelvocab})
        return vocab

    def preprocess(self, data, vocab, pretrain_vocab, args):
        """
        Map every sentence's token fields to ids.

        Each processed sentence is a 10-element list:
            0 word ids, 1 per-word char ids, 2 upos ids, 3 xpos ids,
            4 feats ids, 5 pretrained-embedding ids (PAD_ID if no pretrain
            vocab), 6 lemma ids, 7 heads (ints), 8 deprel ids, 9 raw words.
        Fields 0-6 start with a ROOT entry; fields 7-9 do not, so they are
        one element shorter.

        ``args`` is accepted for interface compatibility but not used.
        """
        processed = []
        # ROOT entries: a CompositeVocab gets len(vocab) ROOT_IDs per token
        # (presumably one per sub-field), a plain vocab a single ROOT_ID
        xpos_replacement = [[ROOT_ID] * len(vocab['xpos'])] if isinstance(vocab['xpos'], CompositeVocab) else [ROOT_ID]
        feats_replacement = [[ROOT_ID] * len(vocab['feats'])]
        for sent in data:
            processed_sent = [[ROOT_ID] + vocab['word'].map([w[0] for w in sent])]
            processed_sent += [[[ROOT_ID]] + [vocab['char'].map(list(w[0])) for w in sent]]
            processed_sent += [[ROOT_ID] + vocab['upos'].map([w[1] for w in sent])]
            processed_sent += [xpos_replacement + vocab['xpos'].map([w[2] for w in sent])]
            processed_sent += [feats_replacement + vocab['feats'].map([w[3] for w in sent])]
            if pretrain_vocab is not None:
                # always use lowercase lookup in pretrained vocab
                processed_sent += [[ROOT_ID] + pretrain_vocab.map([w[0].lower() for w in sent])]
            else:
                processed_sent += [[ROOT_ID] + [PAD_ID] * len(sent)]
            processed_sent += [[ROOT_ID] + vocab['lemma'].map([w[4] for w in sent])]
            # heads: unparsable values become 0 only in evaluation mode
            processed_sent += [[to_int(w[5], ignore_error=self.eval) for w in sent]]
            processed_sent += [vocab['deprel'].map([w[6] for w in sent])]
            processed_sent.append([w[0] for w in sent])
            processed.append(processed_sent)
        return processed

    def __len__(self):
        """Return the number of batches."""
        return len(self.data)

    def __getitem__(self, key):
        """
        Get the batch at index ``key``, converted to tensors.

        Sentences in the batch are sorted by length with ``sort_all`` (for
        RNN packing), and all words' character sequences are sorted
        separately. ``orig_idx`` and ``word_orig_idx`` record the positions
        before those sorts.

        Returns:
            (words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
             pretrained, lemma, head, deprel, orig_idx, word_orig_idx,
             sentlens, word_lens, text). The ``*_mask`` tensors are True where
             the ids equal PAD_ID, ``sentlens`` counts the ROOT entry, and
             ``text`` holds the raw word strings.

        Raises:
            TypeError: if ``key`` is not an int.
            IndexError: if ``key`` is outside ``[0, len(self))``. Negative
                indices are not supported.
        """
        if not isinstance(key, int):
            raise TypeError("batch index must be an int, not {}".format(type(key).__name__))
        if key < 0 or key >= len(self.data):
            raise IndexError("batch index {} out of range for {} batches".format(key, len(self.data)))
        batch = self.data[key]
        batch_size = len(batch)
        # transpose: one tuple per field instead of one list per sentence
        batch = list(zip(*batch))
        assert len(batch) == 10

        # sort sentences by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # sort words by lens for easy char-RNN operations
        batch_words = [w for sent in batch[1] for w in sent]
        word_lens = [len(x) for x in batch_words]
        batch_words, word_orig_idx = sort_all([batch_words], word_lens)
        batch_words = batch_words[0]
        word_lens = [len(x) for x in batch_words]

        # convert to tensors
        words = batch[0]
        words = get_long_tensor(words, batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(batch_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos = get_long_tensor(batch[2], batch_size)
        xpos = get_long_tensor(batch[3], batch_size)
        ufeats = get_long_tensor(batch[4], batch_size)
        pretrained = get_long_tensor(batch[5], batch_size)
        sentlens = [len(x) for x in batch[0]]
        lemma = get_long_tensor(batch[6], batch_size)
        head = get_long_tensor(batch[7], batch_size)
        deprel = get_long_tensor(batch[8], batch_size)
        text = batch[9]
        return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, orig_idx, word_orig_idx, sentlens, word_lens, text

    def load_doc(self, doc):
        """
        Read the token fields TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL
        (in that order) per sentence, replacing None values with '_'.
        """
        data = doc.get([TEXT, UPOS, XPOS, FEATS, LEMMA, HEAD, DEPREL], as_sentences=True)
        data = self.resolve_none(data)
        return data

    def resolve_none(self, data):
        """
        Replace every None token field with '_', in place.

        The token entries must be mutable lists. Returns ``data`` itself.
        """
        for sent in data:
            for token in sent:
                for feat_idx, value in enumerate(token):
                    if value is None:
                        token[feat_idx] = '_'
        return data

    def __iter__(self):
        """Yield every batch in order, as returned by ``__getitem__``."""
        for i in range(len(self)):
            yield self[i]

    def set_batch_size(self, batch_size):
        """
        Record a new batch size.

        Existing batches are not re-chunked. The new value is used the next
        time ``chunk_batches`` runs, for example via ``reshuffle``.
        """
        self.batch_size = batch_size

    def reshuffle(self):
        """
        Flatten the current batches, re-chunk them, and shuffle the batch order.
        """
        data = [y for x in self.data for y in x]
        self.data = self.chunk_batches(data)
        random.shuffle(self.data)

    def chunk_batches(self, data):
        """
        Split processed sentences into batches with ``data_to_batches``.

        Also stores the returned original order in ``self.data_orig_idx``.
        """
        batches, data_orig_idx = data_to_batches(data=data, batch_size=self.batch_size,
                                                 eval_mode=self.eval, sort_during_eval=self.sort_during_eval,
                                                 min_length_to_batch_separately=self.min_length_to_batch_separately)
        # data_orig_idx might be None at train time, since we don't anticipate unsorting
        self.data_orig_idx = data_orig_idx
        return batches
def to_int(string, ignore_error=False):
    """
    Convert ``string`` to an int with ``int()``.

    Args:
        string: value to convert (used here for the HEAD field).
        ignore_error: when True, a value that ``int()`` rejects with
            ValueError yields 0 instead of raising.

    Returns:
        The parsed integer, or 0 when parsing fails and ``ignore_error`` is set.

    Raises:
        ValueError: if ``string`` cannot be parsed and ``ignore_error`` is False.
    """
    try:
        return int(string)
    except ValueError:
        if ignore_error:
            return 0
        # bare raise re-raises the active exception without an extra frame
        raise