import random
import logging

import torch

from stanza.models.common.bert_embedding import filter_data, needs_length_filter
from stanza.models.common.data import map_to_ids, get_long_tensor, sort_all
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX
from stanza.models.pos.vocab import CharVocab, CompositeVocab, WordVocab
from stanza.models.ner.vocab import MultiVocab
from stanza.models.common.doc import *
from stanza.models.ner.utils import process_tags, normalize_empty_tags

logger = logging.getLogger('stanza')

class DataLoader:
    def __init__(self, doc, batch_size, args, pretrain=None, vocab=None, evaluation=False,
                 preprocess_tags=True, bert_tokenizer=None, scheme=None, max_batch_words=None):
        self.max_batch_words = max_batch_words
        self.batch_size = batch_size
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.doc = doc
        self.preprocess_tags = preprocess_tags

        data = self._load_doc(self.doc, scheme)
        # filter out the long sentences if bert is used
        if self.args.get('bert_model', None) and needs_length_filter(self.args['bert_model']):
            data = filter_data(self.args['bert_model'], data, bert_tokenizer)
        self.tags = [[w[1] for w in sent] for sent in data]

        # handle vocab
        self.pretrain = pretrain
        if vocab is None:
            self.vocab = self.init_vocab(data)
        else:
            self.vocab = vocab

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab, args)
        # shuffle for training
        if self.shuffled:
            random.shuffle(data)
        self.num_examples = len(data)

        # chunk into batches
        self.data = self.chunk_batches(data)
        logger.debug("{} batches created.".format(len(self.data)))

    def init_vocab(self, data):
        def from_model(model_filename):
            """ Try loading vocab from charLM model file. """
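            # Depending on how the charLM checkpoint was saved, the vocab may sit at the
            # top level of the state dict or be nested under 'model'; both layouts are checked.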
""" state_dict = torch.load(model_filename, lambda storage, loc: storage, weights_only=True) if 'vocab' in state_dict: return state_dict['vocab'] if 'model' in state_dict and 'vocab' in state_dict['model']: return state_dict['model']['vocab'] raise ValueError("Cannot find vocab in charLM model file %s" % model_filename) if self.eval: raise AssertionError("Vocab must exist for evaluation.") if self.args['charlm']: charvocab = CharVocab.load_state_dict(from_model(self.args['charlm_forward_file'])) else: charvocab = CharVocab(data, self.args['shorthand']) wordvocab = self.pretrain.vocab if self.pretrain is not None else None tag_data = [[(x[1],) for x in sentence] for sentence in data] tagvocab = CompositeVocab(tag_data, self.args['shorthand'], idx=0, sep=None) ignore = None if self.args['emb_finetune_known_only']: if self.pretrain is None: raise ValueError("Cannot train emb_finetune_known_only with no pretrain of known words") if self.args['lowercase']: ignore = set([w[0].lower() for sent in data for w in sent if w[0] not in wordvocab and w[0].lower() not in wordvocab]) else: ignore = set([w[0] for sent in data for w in sent if w[0] not in wordvocab]) logger.debug("Ignoring %d in the delta vocab as they did not appear in the original embedding", len(ignore)) deltavocab = WordVocab(data, self.args['shorthand'], cutoff=1, lower=self.args['lowercase'], ignore=ignore) logger.debug("Creating delta vocab of size %s", len(deltavocab)) vocabs = {'char': charvocab, 'delta': deltavocab, 'tag': tagvocab} if wordvocab is not None: vocabs['word'] = wordvocab vocab = MultiVocab(vocabs) return vocab def preprocess(self, data, vocab, args): processed = [] if args.get('char_lowercase', False): # handle character case char_case = lambda x: x.lower() else: char_case = lambda x: x for sent_idx, sent in enumerate(data): processed_sent = [[w[0] for w in sent]] processed_sent += [[vocab['char'].map([char_case(x) for x in w[0]]) for w in sent]] processed_sent += [vocab['tag'].map([w[1] for w in sent])] processed.append(processed_sent) return processed def __len__(self): return len(self.data) def __getitem__(self, key): """ Get a batch with index. 
""" if not isinstance(key, int): raise TypeError if key < 0 or key >= len(self.data): raise IndexError batch = self.data[key] batch_size = len(batch) batch = list(zip(*batch)) assert len(batch) == 3 # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[List[int]]] # sort sentences by lens for easy RNN operations sentlens = [len(x) for x in batch[0]] batch, orig_idx = sort_all(batch, sentlens) sentlens = [len(x) for x in batch[0]] # sort chars by lens for easy char-LM operations chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens = self.process_chars(batch[1]) chars_sorted, char_orig_idx = sort_all([chars_forward, chars_backward, charoffsets_forward, charoffsets_backward], charlens) chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = chars_sorted charlens = [len(sent) for sent in chars_forward] # sort words by lens for easy char-RNN operations batch_words = [w for sent in batch[1] for w in sent] wordlens = [len(x) for x in batch_words] batch_words, word_orig_idx = sort_all([batch_words], wordlens) batch_words = batch_words[0] wordlens = [len(x) for x in batch_words] words = batch[0] wordchars = get_long_tensor(batch_words, len(wordlens)) wordchars_mask = torch.eq(wordchars, PAD_ID) chars_forward = get_long_tensor(chars_forward, batch_size, pad_id=self.vocab['char'].unit2id(' ')) chars_backward = get_long_tensor(chars_backward, batch_size, pad_id=self.vocab['char'].unit2id(' ')) chars = torch.cat([chars_forward.unsqueeze(0), chars_backward.unsqueeze(0)]) # padded forward and backward char idx charoffsets = [charoffsets_forward, charoffsets_backward] # idx for forward and backward lm to get word representation tags = get_long_tensor(batch[2], batch_size) return words, wordchars, wordchars_mask, chars, tags, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets def __iter__(self): for i in range(self.__len__()): yield self.__getitem__(i) def _load_doc(self, doc, scheme): # preferentially load the MULTI_NER in case we are training / # testing a model with multiple layers of tags data = doc.get([TEXT, NER, MULTI_NER], as_sentences=True, from_token=True) data = [[[token[0], token[2]] if token[2] else [token[0], (token[1],)] for token in sentence] for sentence in data] if self.preprocess_tags: # preprocess tags if scheme is None: data = process_tags(data, self.args.get('scheme', 'bio')) data = normalize_empty_tags(data) return data def process_chars(self, sents): start_id, end_id = self.vocab['char'].unit2id('\n'), self.vocab['char'].unit2id(' ') # special token start_offset, end_offset = 1, 1 chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = [], [], [], [] # get char representation for each sentence for sent in sents: chars_forward_sent, chars_backward_sent, charoffsets_forward_sent, charoffsets_backward_sent = [start_id], [start_id], [], [] # forward lm for word in sent: chars_forward_sent += word charoffsets_forward_sent = charoffsets_forward_sent + [len(chars_forward_sent)] # add each token offset in the last for forward lm chars_forward_sent += [end_id] # backward lm for word in sent[::-1]: chars_backward_sent += word[::-1] charoffsets_backward_sent = [len(chars_backward_sent)] + charoffsets_backward_sent # add each offset in the first for backward lm chars_backward_sent += [end_id] # store each sentence chars_forward.append(chars_forward_sent) chars_backward.append(chars_backward_sent) charoffsets_forward.append(charoffsets_forward_sent) 
            charoffsets_backward.append(charoffsets_backward_sent)
        charlens = [len(sent) for sent in chars_forward] # forward lm and backward lm should have the same lengths
        return chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens

    def reshuffle(self):
        data = [y for x in self.data for y in x]
        random.shuffle(data)
        self.data = self.chunk_batches(data)

    def chunk_batches(self, data):
        if self.max_batch_words is None:
            return [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]

        # greedily fill batches, starting a new one whenever the sentence count
        # or the cumulative word count reaches its cap
        batches = []
        next_batch = []
        for item in data:
            next_batch.append(item)
            if len(next_batch) >= self.batch_size:
                batches.append(next_batch)
                next_batch = []
            if sum(len(x[0]) for x in next_batch) >= self.max_batch_words:
                batches.append(next_batch)
                next_batch = []
        if len(next_batch) > 0:
            batches.append(next_batch)
        return batches
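
# Example usage (illustrative sketch only; the args values below are assumptions,
# not a documented configuration, and `doc` / `pretrain` stand in for a loaded
# stanza Document with NER annotations and a word vector Pretrain):
#
#   args = {'charlm': False, 'shorthand': 'en_sample', 'lowercase': False,
#           'emb_finetune_known_only': False, 'char_lowercase': False, 'scheme': 'bio'}
#   loader = DataLoader(doc, batch_size=32, args=args, pretrain=pretrain)
#   for batch in loader:
#       words, wordchars, wordchars_mask, chars, tags, *rest = batch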