from bisect import bisect_right
from copy import copy
import numpy as np
import random
import logging
import re
import torch
from torch.utils.data import Dataset

from stanza.models.common.utils import sort_with_indices, unsort
from stanza.models.tokenization.vocab import Vocab

logger = logging.getLogger('stanza')

# Surface strings of the special vocabulary entries used for padding and for
# unknown / dropped-out units.  The padding id and the unknown id must differ:
# padding positions are detected by comparing against the padding id.
# NOTE(review): assumed to match the PAD/UNK entries of the project's Vocab
# (stanza convention is '<PAD>' / '<UNK>') - confirm against vocab.py.
PAD_UNIT = '<PAD>'
UNK_UNIT = '<UNK>'


def filter_consecutive_whitespaces(para):
    """Remove every ' ' unit that immediately follows another ' ' unit.

    Args:
        para: list of (char, label) pairs.

    Returns:
        A new list in which each run of spaces is collapsed to its first
        element (that first space keeps its own label).
    """
    filtered = []
    for i, (char, label) in enumerate(para):
        if i > 0 and char == ' ' and para[i - 1][0] == ' ':
            continue
        filtered.append((char, label))
    return filtered


# Paragraph separator: a newline, optional whitespace, then another newline.
NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
# this was (r'^([\d]+[,\.]*)+$')
# but the runtime on that can explode exponentially
# for example, on 111111111111111111111111a
NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
WHITESPACE_RE = re.compile(r'\s')


class TokenizationDataset:
    """Paragraph-level character data for the tokenizer.

    After construction, ``self.data`` is a list of paragraphs; each paragraph
    is a list of (char, label) pairs in which every whitespace character has
    been replaced by ' ' and runs of spaces have been collapsed.
    """

    def __init__(self, tokenizer_args, input_files=None, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):
        """
        Args:
            tokenizer_args: dict of options ('skip_newline', 'num_dict_feat',
                'feat_funcs', 'use_dictionary', 'max_seqlen', ...).
            input_files: dict with optional 'txt' and 'label' file paths.
                None (the default) means neither file is given.
            input_text: raw text, used instead of input_files['txt'] when given.
            vocab: vocabulary used to map units to ids.
            evaluation: if True, paragraphs are not split on gold sentence
                labels and over-long sentences are kept.
            dictionary: mapping with "words", "prefixes" and "suffixes"
                collections, read only when 'use_dictionary' is enabled.
        """
        super().__init__(*args, **kwargs)  # forwards all unused arguments
        self.args = tokenizer_args
        self.eval = evaluation
        self.dictionary = dictionary
        self.vocab = vocab

        # get input files
        if input_files is None:
            input_files = {}
        txt_file = input_files.get('txt')
        label_file = input_files.get('label')

        # Load data and process it
        # set up text from file or input string
        assert txt_file is not None or input_text is not None
        if input_text is None:
            with open(txt_file, encoding="utf-8") as f:
                text = f.read().rstrip()
        else:
            text = input_text

        # paragraphs are separated by blank lines
        text_chunks = NEWLINE_WHITESPACE_RE.split(text)
        text_chunks = [pt.rstrip() for pt in text_chunks]
        text_chunks = [pt for pt in text_chunks if pt]
        if label_file is not None:
            with open(label_file, encoding="utf-8") as f:
                labels = f.read().rstrip()
            labels = NEWLINE_WHITESPACE_RE.split(labels)
            labels = [pt.rstrip() for pt in labels]
            # one digit per character; materialized as a list (not a lazy map)
            labels = [[int(c) for c in pt] for pt in labels if pt]
        else:
            labels = [[0 for _ in pt] for pt in text_chunks]

        skip_newline = self.args.get('skip_newline', False)
        self.data = [[(WHITESPACE_RE.sub(' ', char), label)  # substitute special whitespaces
                      for char, label in zip(pt, pc)
                      if not (skip_newline and char == '\n')]  # check if newline needs to be eaten
                     for pt, pc in zip(text_chunks, labels)]

        # remove consecutive whitespaces
        self.data = [filter_consecutive_whitespaces(x) for x in self.data]

    def labels(self):
        """
        Returns a list of the labels for all of the sentences in this DataLoader

        Used at eval time to compare to the results, for example
        """
        return [np.array([unit[1] for unit in sent]) for sent in self.data]

    def extract_dict_feat(self, para, idx):
        """
        Extract dictionary features for the character at position idx of para.

        Starting from para[idx], the code appends one character at a time going
        forward (and prepends one going backward), for up to
        args['num_dict_feat'] steps.  Feature window-1 is 1 when the string
        built so far is in dictionary["words"].  A direction stops extending
        once its string is not in dictionary["prefixes"] (forward) or
        dictionary["suffixes"] (backward).

        Returns:
            list of 2 * num_dict_feat ints: forward features then backward features.
        """
        length = len(para)
        num_feats = self.args['num_dict_feat']
        dict_forward_feats = [0] * num_feats
        dict_backward_feats = [0] * num_feats
        # NOTE(review): the starting character is not lowercased while each
        # added character is; kept as-is, presumably trained models rely on it.
        forward_word = para[idx][0]
        backward_word = para[idx][0]
        prefix = True
        suffix = True
        for window in range(1, num_feats + 1):
            # forward: extend while inside the paragraph and still a known prefix
            if (idx + window) <= length - 1 and prefix:
                forward_word += para[idx + window][0].lower()
                dict_forward_feats[window - 1] = 1 if forward_word in self.dictionary["words"] else 0
                # no word can start with this string, stop looking forward
                if forward_word not in self.dictionary["prefixes"]:
                    prefix = False
            # backward check: similar to forward
            if (idx - window) >= 0 and suffix:
                backward_word = para[idx - window][0].lower() + backward_word
                dict_backward_feats[window - 1] = 1 if backward_word in self.dictionary["words"] else 0
                if backward_word not in self.dictionary["suffixes"]:
                    suffix = False
            # if cannot find both prefix and suffix, then exit the loop
            if not prefix and not suffix:
                break

        return dict_forward_feats + dict_backward_feats

    def para_to_sentences(self, para):
        """
        Convert a paragraph to a list of processed sentences.

        Args:
            para: list of (unit, label) pairs.

        Returns:
            list of (unit_ids, labels, features, raw_units) tuples; the first
            three are numpy arrays and raw_units is a list of strings.  Outside
            evaluation the paragraph is split after every unit labeled 2 or 4
            and sentences longer than args['max_seqlen'] are dropped.  At
            evaluation time the whole paragraph becomes one sentence.
        """
        res = []
        funcs = []
        for feat_func in self.args['feat_funcs']:
            if feat_func == 'end_of_para' or feat_func == 'start_of_para':
                # skip for position-dependent features
                continue
            if feat_func == 'space_before':
                func = lambda x: 1 if x.startswith(' ') else 0
            elif feat_func == 'capitalized':
                func = lambda x: 1 if x[0].isupper() else 0
            elif feat_func == 'numeric':
                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
            else:
                raise ValueError('Feature function "{}" is undefined.'.format(feat_func))

            funcs.append(func)

        # stacking all featurize functions
        composite_func = lambda x: [f(x) for f in funcs]

        def process_sentence(sent_units, sent_labels, sent_feats):
            return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                    np.array(sent_labels),
                    np.array(sent_feats),
                    list(sent_units))

        use_end_of_para = 'end_of_para' in self.args['feat_funcs']
        use_start_of_para = 'start_of_para' in self.args['feat_funcs']
        use_dictionary = self.args['use_dictionary']
        current_units = []
        current_labels = []
        current_feats = []
        for i, (unit, label) in enumerate(para):
            feats = composite_func(unit)
            # position-dependent features
            if use_end_of_para:
                feats.append(1 if i == len(para) - 1 else 0)
            if use_start_of_para:
                feats.append(1 if i == 0 else 0)

            # if dictionary feature is selected
            if use_dictionary:
                feats = feats + self.extract_dict_feat(para, i)

            current_units.append(unit)
            current_labels.append(label)
            current_feats.append(feats)
            if not self.eval and (label == 2 or label == 4):  # end of sentence
                if len(current_units) <= self.args['max_seqlen']:
                    # get rid of sentences that are too long during training of the tokenizer
                    res.append(process_sentence(current_units, current_labels, current_feats))
                current_units.clear()
                current_labels.clear()
                current_feats.clear()

        if len(current_units) > 0:
            if self.eval or len(current_units) <= self.args['max_seqlen']:
                res.append(process_sentence(current_units, current_labels, current_feats))

        return res

    def advance_old_batch(self, eval_offsets, old_batch):
        """
        Advance to a new position in a batch where we have partially processed the batch

        If we have previously built a batch of data and made predictions on them, then when we are trying to make
        prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
        In this case, eval_offsets index within the old_batch to advance the strings to process.

        Note: eval_offsets is modified in place (each entry is clamped to the
        length of its row).

        Returns:
            (units, labels, features, raw_units) with each row shifted left by
            its offset and re-padded.
        """
        padid = self.vocab.unit2id(PAD_UNIT)

        ounits, olabels, ofeatures, oraw = old_batch
        feat_size = ofeatures.shape[-1]
        # number of non-padding units in each row
        lens = (ounits != padid).sum(1).tolist()
        pad_len = max(l - i for i, l in zip(eval_offsets, lens))

        units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)
        raw_units = []

        for i in range(len(ounits)):
            eval_offsets[i] = min(eval_offsets[i], lens[i])
            remaining = lens[i] - eval_offsets[i]
            units[i, :remaining] = ounits[i, eval_offsets[i]:lens[i]]
            labels[i, :remaining] = olabels[i, eval_offsets[i]:lens[i]]
            features[i, :remaining] = ofeatures[i, eval_offsets[i]:lens[i]]
            raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + [PAD_UNIT] * (pad_len - remaining))

        return units, labels, features, raw_units


def build_move_punct_set(data, move_back_prob):
    """
    Return the punctuation characters that may be moved next to the preceding word.

    Starts from a fixed candidate set and removes any punct character that,
    somewhere in data, appears in a way that suggests it is part of a token
    rather than a standalone token.  move_back_prob is not used here.
    """
    move_punct = {',', ':', '!', '.', '?', '"', '(', ')'}
    for chunk in data:
        # ignore positions at the start and end of a chunk
        for idx in range(1, len(chunk) - 1):
            if chunk[idx][0] not in move_punct:
                continue
            if chunk[idx][1] == 0:
                if chunk[idx + 1][0].isspace() and not chunk[idx - 1][0].isdigit():
                    # this check removes punct which isn't ending a word...
                    # honestly that's a rather unusual situation
                    # VI has |3, 5| as a complete token
                    # so we also eliminate isdigit()
                    move_punct.remove(chunk[idx][0])
                    continue
            # we skip isdigit() because we will intentionally not
            # create things that look like decimal numbers
            if not chunk[idx - 1][0].isspace() and chunk[idx - 1][0] not in move_punct and not chunk[idx - 1][0].isdigit():
                # this check eliminates things like '.' after 'Mr.'
                move_punct.remove(chunk[idx][0])
                continue

    return move_punct


def build_known_mwt(data, mwt_expansions):
    """
    Return the set of MWT surface strings in data that could be split.

    An MWT is a token whose last character has label 3.  Only MWTs present in
    mwt_expansions with at most two expansion words are returned.
    """
    known_mwts = set()
    for chunk in data:
        for idx, unit in enumerate(chunk):
            if unit[1] != 3:
                continue
            # found an MWT; walk back to the first character of this token
            prev_idx = idx - 1
            while prev_idx >= 0 and chunk[prev_idx][1] == 0:
                prev_idx -= 1
            prev_idx += 1
            # skip leading whitespace, never past the MWT's last character
            while prev_idx < idx and chunk[prev_idx][0].isspace():
                prev_idx += 1
            if prev_idx == idx:
                continue
            mwt = "".join(x[0] for x in chunk[prev_idx:idx + 1])
            if mwt not in mwt_expansions:
                continue
            if len(mwt_expansions[mwt]) > 2:
                # TODO: could split 3 word tokens as well
                continue
            known_mwts.add(mwt)
    return known_mwts


class DataLoader(TokenizationDataset):
    """
    This is the training version of the dataset.
    """
    def __init__(self, args, input_files=None, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None):
        super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)

        self.vocab = vocab if vocab is not None else self.init_vocab()

        # data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels.
        # At evaluation time, each paragraph is treated as single "sentence" as we don't know a priori where
        # sentence breaks occur. We make prediction from left to right for each paragraph and move forward to
        # the last predicted sentence break to start afresh.
        self.sentences = [self.para_to_sentences(para) for para in self.data]

        self.init_sent_ids()
        logger.debug(f"{len(self.sentence_ids)} sentences loaded.")

        punct_move_back_prob = args.get('punct_move_back_prob', 0.0)
        if punct_move_back_prob > 0.0:
            self.move_punct = build_move_punct_set(self.data, punct_move_back_prob)
            if len(self.move_punct) > 0:
                logger.debug('Based on the training data, will augment space/punct combinations {}'.format(self.move_punct))
            else:
                logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace')

        split_mwt_prob = args.get('split_mwt_prob', 0.0)
        if split_mwt_prob > 0.0 and not evaluation:
            self.mwt_expansions = mwt_expansions
            self.known_mwt = build_known_mwt(self.data, mwt_expansions)
            if len(self.known_mwt) > 0:
                logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt))
            else:
                logger.debug('Based on the training data, there are NO MWT to split at training time')

    def __len__(self):
        return len(self.sentence_ids)

    def init_vocab(self):
        vocab = Vocab(self.data, self.args['lang'])
        return vocab

    def init_sent_ids(self):
        """Rebuild the flat (paragraph, sentence) index and cumulative unit lengths."""
        self.sentence_ids = []
        self.cumlen = [0]
        for i, para in enumerate(self.sentences):
            for j, sentence in enumerate(para):
                self.sentence_ids.append((i, j))
                self.cumlen.append(self.cumlen[-1] + len(sentence[0]))

    def has_mwt(self):
        # presumably this only needs to be called either 0 or 1 times,
        # 1 when training and 0 any other time, so no effort is put
        # into caching the result
        return any(word[1] > 2 for sentence in self.data for word in sentence)

    def shuffle(self):
        for para in self.sentences:
            random.shuffle(para)
        self.init_sent_ids()

    def move_last_char(self, sentence):
        """
        Insert a space before the sentence-final character, if eligible.

        Eligible when the sentence has more than one unit, is shorter than
        max_seqlen, ends with label 2, and the second-to-last label is nonzero.
        Returns the re-encoded sentence list, or None when not eligible.
        """
        if len(sentence[3]) > 1 and len(sentence[3]) < self.args['max_seqlen'] and sentence[1][-1] == 2 and sentence[1][-2] != 0:
            new_units = [(x, int(y)) for x, y in zip(sentence[3][:-1], sentence[1][:-1])]
            new_units.extend([(' ', 0), (sentence[3][-1], int(sentence[1][-1]))])
            encoded = self.para_to_sentences(new_units)
            return encoded
        return None

    def split_mwt(self, sentence):
        """
        Replace one randomly chosen MWT in sentence with its two-word expansion.

        Returns the re-encoded sentence list, or None if the sentence is not
        eligible or the chosen MWT has no usable two-word expansion.
        """
        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
            return None

        # if we find a token in the sentence which ends with label 3,
        # eg it is an MWT,
        # with some probability we split it into two tokens
        # and treat the split tokens as both label 1 instead of 3
        # in this manner, we teach the tokenizer not to treat the
        # entire sequence of characters with added spaces as an MWT,
        # which weirdly can happen in some corner cases

        mwt_ends = [idx for idx, label in enumerate(sentence[1]) if label == 3]
        if len(mwt_ends) == 0:
            return None
        random_end = random.randint(0, len(mwt_ends) - 1)
        mwt_end = mwt_ends[random_end]
        mwt_start = mwt_end - 1
        while mwt_start >= 0 and sentence[1][mwt_start] == 0:
            mwt_start -= 1
        mwt_start += 1
        # skip leading whitespace, never past the MWT's last character
        while mwt_start < mwt_end and sentence[3][mwt_start].isspace():
            mwt_start += 1
        if mwt_start == mwt_end:
            return None
        mwt = "".join(sentence[3][mwt_start:mwt_end + 1])
        if mwt not in self.mwt_expansions:
            return None
        expansion = self.mwt_expansions[mwt]
        # only exact two-word expansions with non-empty words can be split
        # (same restriction build_known_mwt applies; longer ones would lose words)
        if len(expansion) != 2 or not expansion[0] or not expansion[1]:
            return None

        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
        w0_units = [(x, 0) for x in expansion[0]]
        w0_units[-1] = (w0_units[-1][0], 1)
        w1_units = [(x, 0) for x in expansion[1]]
        w1_units[-1] = (w1_units[-1][0], 1)
        split_units = w0_units + [(' ', 0)] + w1_units
        new_units = all_units[:mwt_start] + split_units + all_units[mwt_end + 1:]
        encoded = self.para_to_sentences(new_units)
        return encoded

    def move_punct_back(self, sentence):
        """
        Remove the space before eligible punctuation (those in self.move_punct).

        Returns the re-encoded sentence list, or None if nothing is eligible.
        """
        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
            return None

        # check that we are not accidentally creating decimal numbers
        #   idx == 1 or not sentence[3][idx-2].isdigit()
        # one disadvantage of checking for sentence[1][idx] == 0
        # would be that tokens of all punct, such as '...',
        # should move but would not move if this is eliminated
        commas = [idx for idx, c in enumerate(sentence[3])
                  if c in self.move_punct and idx > 0 and sentence[3][idx - 1].isspace() and
                  (idx == 1 or not sentence[3][idx - 2].isdigit())]
        if len(commas) == 0:
            return None

        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
        new_units = []
        span_start = 0
        for span_end in commas:
            # copy everything up to (but excluding) the space before the punct
            new_units.extend(all_units[span_start:span_end - 1])
            span_start = span_end
        if span_end < len(sentence[3]):
            new_units.extend(all_units[span_end:])

        encoded = self.para_to_sentences(new_units)
        return encoded

    def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):
        ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. '''
        feat_size = len(self.sentences[0][0][2][0])
        unkid = self.vocab.unit2id(UNK_UNIT)
        padid = self.vocab.unit2id(PAD_UNIT)

        def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
            # At eval time, this combines sentences in paragraph (indexed by id_pair[0]) starting sentence (indexed
            # by id_pair[1]) into a long string for evaluation. At training time, we just select random sentences
            # from the entire dataset until we reach max_seqlen.
            drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))
            drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
            move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0)
            move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0)
            split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0)

            pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
            sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]
            total_len = len(sentences[0][0])

            # the message pairs each raw unit with its label; zipping the
            # sentence tuple alone produced 1-tuples and crashed the formatting
            assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! {}'.format(
                self.args['max_seqlen'], total_len,
                ' '.join(["{}/{}".format(*x) for x in zip(self.sentences[pid][sid][3], self.sentences[pid][sid][1])]))
            if self.eval:
                for sid1 in range(sid + 1, len(self.sentences[pid])):
                    total_len += len(self.sentences[pid][sid1][0])
                    sentences.append(self.sentences[pid][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break
            else:
                while True:
                    pid1, sid1 = random.choice(self.sentence_ids)
                    total_len += len(self.sentences[pid1][sid1][0])
                    sentences.append(self.sentences[pid1][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break

            if move_last_char_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < move_last_char_prob:
                        # the sentence might not be eligible, such as
                        # already having a space or not having a sentence final punct,
                        # so we need to do a two step checking process here
                        new_sentence = self.move_last_char(sentence)
                        if new_sentence is not None:
                            sentences[sentence_idx] = new_sentence[0]
                            total_len += 1

            if move_punct_back_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < move_punct_back_prob:
                        # the sentence might not be eligible, such as
                        # not having a space separated punct,
                        # so we need to do a two step checking process here
                        new_sentence = self.move_punct_back(sentence)
                        if new_sentence is not None:
                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                            sentences[sentence_idx] = new_sentence[0]

            if split_mwt_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < split_mwt_prob:
                        new_sentence = self.split_mwt(sentence)
                        if new_sentence is not None:
                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                            sentences[sentence_idx] = new_sentence[0]

            if drop_sents and len(sentences) > 1:
                if total_len > self.args['max_seqlen']:
                    sentences = sentences[:-1]
                if len(sentences) > 1:
                    p = [.5 ** i for i in range(1, len(sentences) + 1)]  # drop a large number of sentences with smaller probability
                    cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]
                    sentences = sentences[:cutoff + 1]

            units = np.concatenate([s[0] for s in sentences])
            labels = np.concatenate([s[1] for s in sentences])
            feats = np.concatenate([s[2] for s in sentences])
            raw_units = [x for s in sentences for x in s[3]]

            if not self.eval:
                cutoff = self.args['max_seqlen']
                units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]

            if drop_last_char:  # can only happen in non-eval mode
                if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):
                    # training text ended with a sentence end position
                    # and that word was a single character
                    # and the previous character ended the word
                    units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]
                    # word end -> sentence end, mwt end -> sentence mwt end
                    labels[-1] = labels[-1] + 1

            return units, labels, feats, raw_units

        if eval_offsets is not None:
            # find max padding length
            pad_len = 0
            for eval_offset in eval_offsets:
                if eval_offset < self.cumlen[-1]:
                    pair_id = bisect_right(self.cumlen, eval_offset) - 1
                    pair = self.sentence_ids[pair_id]
                    pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset - self.cumlen[pair_id])[0]))

            pad_len += 1
            id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]
            pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]
            offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]

            offsets_pairs = list(zip(offsets, pairs))
        else:
            id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))
            offsets_pairs = [(0, x) for x in id_pairs]
            pad_len = self.args['max_seqlen']

        # put everything into padded and nicely shaped NumPy arrays and eventually convert to PyTorch tensors
        units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)
        labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)
        features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)
        raw_units = []
        for i, (offset, pair) in enumerate(offsets_pairs):
            u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)
            units[i, :len(u_)] = u_
            labels[i, :len(l_)] = l_
            features[i, :len(f_), :] = f_
            raw_units.append(r_ + [PAD_UNIT] * (pad_len - len(r_)))

        if unit_dropout > 0 and not self.eval:
            # dropout characters/units at training time and replace them with UNKs
            mask = np.random.random_sample(units.shape) < unit_dropout
            mask[units == padid] = 0
            units[mask] = unkid
            for i, j in zip(*np.nonzero(mask)):
                raw_units[i][j] = UNK_UNIT

        # dropout unit feature vector in addition to only torch.dropout in the model.
        # experiments showed that only torch.dropout hurts the model
        # we believe it is because the dict feature vector is mostly scarse so it makes
        # more sense to drop out the whole vector instead of only single element.
        if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
            mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
            mask_feat[units == padid] = 0
            # zero the whole feature vector at each masked (row, position)
            features[mask_feat] = 0

        units = torch.from_numpy(units)
        labels = torch.from_numpy(labels)
        features = torch.from_numpy(features)

        return units, labels, features, raw_units


class SortedDataset(Dataset):
    """
    Holds a TokenizationDataset for use in a torch DataLoader

    The torch DataLoader is different from the DataLoader defined here
    and allows for cpu & gpu parallelism.  Updating output_predictions
    to use this class as a wrapper to a TokenizationDataset means the
    calculation of features can happen in parallel, saving quite a
    bit of time.
    """
    def __init__(self, dataset):
        super().__init__()

        self.dataset = dataset
        # paragraphs sorted longest first; self.indices undoes the sort
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # This will return a single sample
        #   np: index in character map
        #   np: tokenization label
        #   np: features
        #   list: original text as one length strings
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        return unsort(arr, self.indices)

    def collate(self, samples):
        """Pad a list of single-sentence samples into batch tensors plus raw units."""
        if any(len(x) > 1 for x in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        feat_size = samples[0][0][2].shape[-1]
        padid = self.dataset.vocab.unit2id(PAD_UNIT)

        # +1 so that all samples end with at least one pad
        pad_len = max(len(x[0][3]) for x in samples) + 1

        units = torch.full((len(samples), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32)
        raw_units = []
        for i, sample in enumerate(samples):
            u_, l_, f_, r_ = sample[0]
            units[i, :len(u_)] = torch.from_numpy(u_)
            labels[i, :len(l_)] = torch.from_numpy(l_)
            features[i, :len(f_), :] = torch.from_numpy(f_)
            raw_units.append(r_ + [PAD_UNIT])

        return units, labels, features, raw_units