import random
import logging

import torch

from stanza.models.common.bert_embedding import filter_data, needs_length_filter
from stanza.models.common.data import map_to_ids, get_long_tensor, sort_all
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX
from stanza.models.pos.vocab import CharVocab, CompositeVocab, WordVocab
from stanza.models.ner.vocab import MultiVocab
from stanza.models.common.doc import *
from stanza.models.ner.utils import process_tags, normalize_empty_tags

logger = logging.getLogger('stanza')

class DataLoader:
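    """Batches the NER data in a stanza Document for training or evaluation.

    The loader pulls (word, tag) pairs out of the document, builds or reuses
    the vocabularies, maps words, characters, and tags to ids, and chunks the
    result into batches; __getitem__ returns one batch as a tuple of tensors
    and bookkeeping lists.

    Rough usage sketch (illustrative only; the args dict must supply the keys
    this class reads, e.g. 'shorthand', 'charlm', 'lowercase',
    'emb_finetune_known_only', and is normally assembled by the NER training
    script):

        loader = DataLoader(train_doc, batch_size=32, args=args,
                            pretrain=pretrain, evaluation=False)
        for batch in loader:
            words, wordchars, wordchars_mask, chars, tags, *rest = batch
    """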
    def __init__(self, doc, batch_size, args, pretrain=None, vocab=None, evaluation=False, preprocess_tags=True, bert_tokenizer=None, scheme=None, max_batch_words=None):
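        """Build batches from ``doc``.

        If ``vocab`` is None (training), vocabularies are built from the data;
        for evaluation a prebuilt vocab must be passed in.  ``max_batch_words``
        optionally closes a batch early once its word count reaches that limit,
        in addition to the ``batch_size`` cap on sentences.
        """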
        self.max_batch_words = max_batch_words
        self.batch_size = batch_size
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.doc = doc
        self.preprocess_tags = preprocess_tags

        data = self._load_doc(self.doc, scheme)

        # drop sentences that are too long for the transformer, if the bert model needs it
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)

        self.tags = [[w[1] for w in sent] for sent in data]

        # handle vocab
        self.pretrain = pretrain
        if vocab is None:
            self.vocab = self.init_vocab(data)
        else:
            self.vocab = vocab

        # subsample the training data if requested
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, args)

        # shuffle for training
        if self.shuffled:
            random.shuffle(data)
        self.num_examples = len(data)

        # chunk into batches
        self.data = self.chunk_batches(data)
        logger.debug("{} batches created.".format(len(self.data)))

    def init_vocab(self, data):
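        """Build the MultiVocab (char, delta, tag, and optionally word vocabs).

        The char vocab is loaded from the forward charlm file when a charlm is
        used; the word vocab comes from the pretrained embeddings; the delta
        vocab collects words seen in training, which may be fine-tuned.
        """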
        def from_model(model_filename):
            """ Try loading vocab from charLM model file. """
            state_dict = torch.load(model_filename, lambda storage, loc: storage, weights_only=True)
            if 'vocab' in state_dict:
                return state_dict['vocab']
            if 'model' in state_dict and 'vocab' in state_dict['model']:
                return state_dict['model']['vocab']
            raise ValueError("Cannot find vocab in charLM model file %s" % model_filename)

        if self.eval:
            raise AssertionError("Vocab must exist for evaluation.")
        if self.args['charlm']:
            charvocab = CharVocab.load_state_dict(from_model(self.args['charlm_forward_file']))
        else:
            charvocab = CharVocab(data, self.args['shorthand'])
        wordvocab = self.pretrain.vocab if self.pretrain is not None else None
        tag_data = [[(x[1],) for x in sentence] for sentence in data]
        tagvocab = CompositeVocab(tag_data, self.args['shorthand'], idx=0, sep=None)
        ignore = None
        if self.args['emb_finetune_known_only']:
            if self.pretrain is None:
                raise ValueError("Cannot train emb_finetune_known_only with no pretrain of known words")
            if self.args['lowercase']:
                ignore = set([w[0].lower() for sent in data for w in sent if w[0] not in wordvocab and w[0].lower() not in wordvocab])
            else:
                ignore = set([w[0] for sent in data for w in sent if w[0] not in wordvocab])
            logger.debug("Ignoring %d in the delta vocab as they did not appear in the original embedding", len(ignore))
        deltavocab = WordVocab(data, self.args['shorthand'], cutoff=1, lower=self.args['lowercase'], ignore=ignore)
        logger.debug("Creating delta vocab of size %s", len(deltavocab))
        vocabs = {'char': charvocab,
                  'delta': deltavocab,
                  'tag': tagvocab}
        if wordvocab is not None:
            vocabs['word'] = wordvocab
        vocab = MultiVocab(vocabs)
        return vocab

    def preprocess(self, data, vocab, args):
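        """Map each sentence to [words, char ids per word, tag ids]."""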
        processed = []
        if args.get('char_lowercase', False):
            char_case = lambda x: x.lower()
        else:
            char_case = lambda x: x
        for sent_idx, sent in enumerate(data):
            processed_sent = [[w[0] for w in sent]]
            processed_sent += [[vocab['char'].map([char_case(x) for x in w[0]]) for w in sent]]
            processed_sent += [vocab['tag'].map([w[1] for w in sent])]
            processed.append(processed_sent)
        return processed

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        assert len(batch) == 3

        # sort sentences by length for easy RNN operations
        sentlens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, sentlens)
        sentlens = [len(x) for x in batch[0]]

        # sort character sequences by length for the character language models
        chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens = self.process_chars(batch[1])
        chars_sorted, char_orig_idx = sort_all([chars_forward, chars_backward, charoffsets_forward, charoffsets_backward], charlens)
        chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = chars_sorted
        charlens = [len(sent) for sent in chars_forward]

        # sort words by length for the char-RNN
        batch_words = [w for sent in batch[1] for w in sent]
        wordlens = [len(x) for x in batch_words]
        batch_words, word_orig_idx = sort_all([batch_words], wordlens)
        batch_words = batch_words[0]
        wordlens = [len(x) for x in batch_words]

        # convert to tensors
        words = batch[0]
        wordchars = get_long_tensor(batch_words, len(wordlens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)
        chars_forward = get_long_tensor(chars_forward, batch_size, pad_id=self.vocab['char'].unit2id(' '))
        chars_backward = get_long_tensor(chars_backward, batch_size, pad_id=self.vocab['char'].unit2id(' '))
        chars = torch.cat([chars_forward.unsqueeze(0), chars_backward.unsqueeze(0)])
        charoffsets = [charoffsets_forward, charoffsets_backward]
        tags = get_long_tensor(batch[2], batch_size)

        return words, wordchars, wordchars_mask, chars, tags, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def _load_doc(self, doc, scheme):
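        """Extract [word, tags] pairs per sentence from the document.

        The tag entry is the MULTI_NER tuple when present, otherwise a
        one-element tuple built from the NER field.
        """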
        data = doc.get([TEXT, NER, MULTI_NER], as_sentences=True, from_token=True)
        data = [[[token[0], token[2]] if token[2] else [token[0], (token[1],)] for token in sentence] for sentence in data]
        if self.preprocess_tags:
            # fall back to the scheme from args when none is passed in explicitly
            if scheme is None:
                scheme = self.args.get('scheme', 'bio')
            data = process_tags(data, scheme)
        data = normalize_empty_tags(data)
        return data

    def process_chars(self, sents):
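        """Build forward and backward character sequences for the charlms.

        Returns the forward chars, backward chars, the offsets of each word
        end in both directions, and the per-sentence character lengths.
        """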
        start_id, end_id = self.vocab['char'].unit2id('\n'), self.vocab['char'].unit2id(' ')
        start_offset, end_offset = 1, 1
        chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = [], [], [], []

        # get the character representation for each sentence
        for sent in sents:
            chars_forward_sent, chars_backward_sent, charoffsets_forward_sent, charoffsets_backward_sent = [start_id], [start_id], [], []

            # forward lm: append each word's chars and record the offset of the word end
            for word in sent:
                chars_forward_sent += word
                charoffsets_forward_sent = charoffsets_forward_sent + [len(chars_forward_sent)]
                chars_forward_sent += [end_id]

            # backward lm: same, with words and characters reversed
            for word in sent[::-1]:
                chars_backward_sent += word[::-1]
                charoffsets_backward_sent = [len(chars_backward_sent)] + charoffsets_backward_sent
                chars_backward_sent += [end_id]

            chars_forward.append(chars_forward_sent)
            chars_backward.append(chars_backward_sent)
            charoffsets_forward.append(charoffsets_forward_sent)
            charoffsets_backward.append(charoffsets_backward_sent)

        # forward and backward sequences have the same length per sentence
        charlens = [len(sent) for sent in chars_forward]
        return chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens

    def reshuffle(self):
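        """Flatten the current batches, reshuffle, and rebuild the batches."""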
        data = [y for x in self.data for y in x]
        random.shuffle(data)
        self.data = self.chunk_batches(data)

    def chunk_batches(self, data):
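        """Split the processed sentences into batches.

        Without max_batch_words, this is a plain split into groups of
        batch_size sentences (e.g. 5 sentences with batch_size=2 give batches
        of 2, 2, 1).  With max_batch_words set, a batch is also closed as soon
        as its word count reaches that limit.
        """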
        if self.max_batch_words is None:
            return [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
        batches = []
        next_batch = []
        for item in data:
            next_batch.append(item)
            if len(next_batch) >= self.batch_size:
                batches.append(next_batch)
                next_batch = []
            if sum(len(x[0]) for x in next_batch) >= self.max_batch_words:
                batches.append(next_batch)
                next_batch = []
        if len(next_batch) > 0:
            batches.append(next_batch)
        return batches