from bisect import bisect_right
from copy import copy
import logging
import random
import re

import numpy as np
import torch
from torch.utils.data import Dataset

from stanza.models.common.utils import sort_with_indices, unsort
from stanza.models.tokenization.vocab import Vocab

logger = logging.getLogger('stanza')

def filter_consecutive_whitespaces(para):
    """
    Remove the second and later spaces of any run of consecutive spaces,
    keeping each remaining character's label.
    """
    filtered = []
    for i, (char, label) in enumerate(para):
        if i > 0:
            if char == ' ' and para[i-1][0] == ' ':
                continue

        filtered.append((char, label))

    return filtered

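# Illustrative sketch of the helper above (hypothetical input, not used by the module):
#
#   para = [('a', 0), (' ', 0), (' ', 0), ('b', 1)]
#   filter_consecutive_whitespaces(para)
#   # -> [('a', 0), (' ', 0), ('b', 1)]
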
# a paragraph break: a newline, optionally more whitespace, then another newline
NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
# a number, possibly with ',' or '.' used as grouping or decimal separators
NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
WHITESPACE_RE = re.compile(r'\s')

class TokenizationDataset:
    def __init__(self, tokenizer_args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = tokenizer_args
        self.eval = evaluation
        self.dictionary = dictionary
        self.vocab = vocab

        txt_file = input_files['txt']
        label_file = input_files['label']

        # load the raw text: either a text file or an already loaded string must be given
        assert txt_file is not None or input_text is not None
        if input_text is None:
            with open(txt_file, encoding="utf-8") as f:
                text = ''.join(f.readlines()).rstrip()
        else:
            text = input_text

        # split the text into paragraphs on blank lines
        text_chunks = NEWLINE_WHITESPACE_RE.split(text)
        text_chunks = [pt.rstrip() for pt in text_chunks]
        text_chunks = [pt for pt in text_chunks if pt]
        if label_file is not None:
            with open(label_file, encoding="utf-8") as f:
                labels = ''.join(f.readlines()).rstrip()
            labels = NEWLINE_WHITESPACE_RE.split(labels)
            labels = [pt.rstrip() for pt in labels]
            labels = [list(map(int, pt)) for pt in labels if pt]
        else:
            labels = [[0 for _ in pt] for pt in text_chunks]

        # normalize all whitespace to ' ', optionally dropping newline characters entirely
        skip_newline = self.args.get('skip_newline', False)
        self.data = [[(WHITESPACE_RE.sub(' ', char), label)
                      for char, label in zip(pt, pc) if not (skip_newline and char == '\n')]
                     for pt, pc in zip(text_chunks, labels)]

        # collapse runs of consecutive whitespace so the model never sees more than one space
        self.data = [filter_consecutive_whitespaces(x) for x in self.data]
    def labels(self):
        """
        Returns a list of the labels for all of the sentences in this DataLoader

        Used at eval time to compare to the results, for example
        """
        return [np.array(list(x[1] for x in sent)) for sent in self.data]

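    # Shape sketch for labels() (hypothetical dataset with two paragraphs):
    #
    #   dataset.labels()
    #   # -> [np.array([0, 1, 0, 0, 2]), np.array([0, 0, 1, ...])]
    #
    # one array of per-character labels for each paragraph in self.data
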
    def extract_dict_feat(self, para, idx):
        """
        Extract dictionary features for the character at position idx of para.

        For each window size up to num_dict_feat, checks whether the string growing
        forward (or backward) from idx is a word in the dictionary, stopping early
        once the string is no longer a known prefix (or suffix).
        """
        length = len(para)

        dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]
        dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]
        forward_word = para[idx][0]
        backward_word = para[idx][0]
        prefix = True
        suffix = True
        for window in range(1, self.args['num_dict_feat']+1):
            # grow the forward string by one character while it is still a known prefix
            if (idx + window) <= length-1 and prefix:
                forward_word += para[idx+window][0].lower()
                feat = 1 if forward_word in self.dictionary["words"] else 0
                dict_forward_feats[window-1] = feat
                if forward_word not in self.dictionary["prefixes"]:
                    prefix = False

            # grow the backward string by one character while it is still a known suffix
            if (idx - window) >= 0 and suffix:
                backward_word = para[idx-window][0].lower() + backward_word
                feat = 1 if backward_word in self.dictionary["words"] else 0
                dict_backward_feats[window-1] = feat
                if backward_word not in self.dictionary["suffixes"]:
                    suffix = False

            # stop early once neither direction can match anything longer
            if not prefix and not suffix:
                break

        return dict_forward_feats + dict_backward_feats

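    # Illustrative sketch of the dictionary features (hypothetical dictionary and paragraph):
    # with num_dict_feat = 2, a dictionary whose "words" contain "ab" and whose "prefixes"
    # contain only "a", and para = [('a', 0), ('b', 1), ('c', 2)],
    #   extract_dict_feat(para, 0)  # -> [1, 0, 0, 0]
    # the forward string "ab" is a known word, the search then stops because "ab" is not
    # a known prefix, and there are no characters before position 0 for the backward half
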
    def para_to_sentences(self, para):
        """ Convert a paragraph to a list of processed sentences. """
        res = []
        funcs = []
        for feat_func in self.args['feat_funcs']:
            if feat_func == 'end_of_para' or feat_func == 'start_of_para':
                # handled separately below, as they need the position in the paragraph
                continue
            if feat_func == 'space_before':
                func = lambda x: 1 if x.startswith(' ') else 0
            elif feat_func == 'capitalized':
                func = lambda x: 1 if x[0].isupper() else 0
            elif feat_func == 'numeric':
                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
            else:
                raise ValueError('Feature function "{}" is undefined.'.format(feat_func))

            funcs.append(func)

        # stack all of the featurize functions into one
        composite_func = lambda x: [f(x) for f in funcs]

        def process_sentence(sent_units, sent_labels, sent_feats):
            return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                    np.array(sent_labels),
                    np.array(sent_feats),
                    list(sent_units))

        use_end_of_para = 'end_of_para' in self.args['feat_funcs']
        use_start_of_para = 'start_of_para' in self.args['feat_funcs']
        use_dictionary = self.args['use_dictionary']
        current_units = []
        current_labels = []
        current_feats = []
        for i, (unit, label) in enumerate(para):
            feats = composite_func(unit)

            if use_end_of_para:
                f = 1 if i == len(para)-1 else 0
                feats.append(f)
            if use_start_of_para:
                f = 1 if i == 0 else 0
                feats.append(f)

            if use_dictionary:
                dict_feats = self.extract_dict_feat(para, i)
                feats = feats + dict_feats

            current_units.append(unit)
            current_labels.append(label)
            current_feats.append(feats)
            # labels 2 and 4 mark the end of a sentence, so at training time
            # the paragraph is cut into separate sentences at those positions
            if not self.eval and (label == 2 or label == 4):
                if len(current_units) <= self.args['max_seqlen']:
                    # skip sentences which are too long to fit in one sequence
                    res.append(process_sentence(current_units, current_labels, current_feats))
                current_units.clear()
                current_labels.clear()
                current_feats.clear()

        if len(current_units) > 0:
            if self.eval or len(current_units) <= self.args['max_seqlen']:
                res.append(process_sentence(current_units, current_labels, current_feats))

        return res

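    # Sketch of the para_to_sentences output: each processed sentence is a 4-tuple of
    #   (np.array of vocab unit ids, np.array of labels, np.array of per-unit features, list of raw units)
    # where the labels follow the scheme used throughout this file:
    #   0 = inside a token, 1 = end of token, 2 = end of sentence,
    #   3 = end of a multi-word token, 4 = end of a multi-word token which also ends a sentence
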
    def advance_old_batch(self, eval_offsets, old_batch):
        """
        Advance to a new position in a batch which we have already partially processed.

        If we have previously built a batch of data and made predictions on it, then when we want to make
        predictions on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
        In this case, eval_offsets indexes into the rows of old_batch to pick where the strings to process start.
        """
        unkid = self.vocab.unit2id('<UNK>')
        padid = self.vocab.unit2id('<PAD>')

        ounits, olabels, ofeatures, oraw = old_batch
        feat_size = ofeatures.shape[-1]
        lens = (ounits != padid).sum(1).tolist()
        pad_len = max(l-i for i, l in zip(eval_offsets, lens))

        units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)
        raw_units = []

        for i in range(len(ounits)):
            eval_offsets[i] = min(eval_offsets[i], lens[i])
            units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]]
            labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]]
            features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]]
            raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + ['<PAD>'] * (pad_len - lens[i] + eval_offsets[i]))

        return units, labels, features, raw_units

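    # Sketch of how advance_old_batch is meant to be used (names are illustrative):
    #
    #   old_batch = data.next(eval_offsets=offsets)   # units, labels, features, raw_units
    #   ...                                           # run the model, decide how far each row advanced
    #   new_batch = data.advance_old_batch(new_offsets, old_batch)
    #
    # row i of the new batch is row i of the old batch with its first new_offsets[i]
    # positions removed, re-padded to a common length
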
def build_move_punct_set(data, move_back_prob):
    """
    Build the set of punctuation which can have a preceding space removed as a
    training augmentation (see DataLoader.move_punct_back).

    Starts from a set of common punctuation and removes any punct which the
    training data shows in a configuration that disqualifies it.
    """
    move_punct = {',', ':', '!', '.', '?', '"', '(', ')'}
    for chunk in data:
        # check each character of the paragraph, skipping the first and last
        for idx in range(1, len(chunk)-1):
            if chunk[idx][0] not in move_punct:
                continue
            if chunk[idx][1] == 0:
                if chunk[idx+1][0].isspace() and not chunk[idx-1][0].isdigit():
                    # the punct occurs in the middle of a token, followed by a space
                    # and not preceded by a digit; rearranging the spaces around such
                    # a punct would contradict the training data
                    move_punct.remove(chunk[idx][0])
                    continue

            # the punct appears directly attached to a preceding word character
            # (not a space, another candidate punct, or a digit), which also
            # disqualifies it
            if not chunk[idx-1][0].isspace() and chunk[idx-1][0] not in move_punct and not chunk[idx-1][0].isdigit():
                move_punct.remove(chunk[idx][0])
                continue
    return move_punct

def build_known_mwt(data, mwt_expansions):
    """
    Collect the MWTs which occur in the training data and have a known expansion
    of at most two words, making them candidates for the split_mwt augmentation.
    """
    known_mwts = set()
    for chunk in data:
        for idx, unit in enumerate(chunk):
            if unit[1] != 3:
                continue

            # walk backwards to the start of the MWT: the first non-whitespace
            # character after the previous token boundary
            prev_idx = idx - 1
            while prev_idx >= 0 and chunk[prev_idx][1] == 0:
                prev_idx -= 1
            prev_idx += 1
            while chunk[prev_idx][0].isspace():
                prev_idx += 1
            if prev_idx == idx:
                continue
            mwt = "".join(x[0] for x in chunk[prev_idx:idx+1])
            if mwt not in mwt_expansions:
                continue
            if len(mwt_expansions[mwt]) > 2:
                # split_mwt only handles expansions of at most two words
                continue
            known_mwts.add(mwt)
    return known_mwts

class DataLoader(TokenizationDataset):
    """
    This is the training version of the dataset.

    In addition to the preprocessing done by TokenizationDataset, it builds the
    vocab if one is not given, cuts each paragraph into sentences, and produces
    randomly composed (and optionally augmented) training batches via next().
    """
    def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None):
        super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)

        self.vocab = vocab if vocab is not None else self.init_vocab()

        # convert each paragraph of (char, label) pairs into processed sentences
        self.sentences = [self.para_to_sentences(para) for para in self.data]

        self.init_sent_ids()
        logger.debug(f"{len(self.sentence_ids)} sentences loaded.")

        punct_move_back_prob = args.get('punct_move_back_prob', 0.0)
        if punct_move_back_prob > 0.0:
            self.move_punct = build_move_punct_set(self.data, punct_move_back_prob)
            if len(self.move_punct) > 0:
                logger.debug('Based on the training data, will augment space/punct combinations {}'.format(self.move_punct))
            else:
                logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace')

        split_mwt_prob = args.get('split_mwt_prob', 0.0)
        if split_mwt_prob > 0.0 and not evaluation:
            self.mwt_expansions = mwt_expansions
            self.known_mwt = build_known_mwt(self.data, mwt_expansions)
            if len(self.known_mwt) > 0:
                logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt))
            else:
                logger.debug('Based on the training data, there are NO MWT to split at training time')

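    # Sketch of constructing the training dataset (argument values are illustrative;
    # in the real pipeline the args dict comes from the tokenizer's argument parser):
    #
    #   args = {'lang': 'en', 'feat_funcs': ['space_before', 'capitalized'],
    #           'max_seqlen': 100, 'batch_size': 32, 'use_dictionary': False,
    #           'skip_newline': False}
    #   train_data = DataLoader(args, input_files={'txt': 'train.txt', 'label': 'train.label'})
    #   train_data.shuffle()
    #   units, labels, features, raw_units = train_data.next()
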
    def __len__(self):
        return len(self.sentence_ids)

    def init_vocab(self):
        vocab = Vocab(self.data, self.args['lang'])
        return vocab

    def init_sent_ids(self):
        # sentence_ids maps a flat sentence index to a (paragraph, sentence) pair;
        # cumlen[k] is the number of characters before the k-th sentence, which lets
        # next() map global character offsets back to individual sentences
        self.sentence_ids = []
        self.cumlen = [0]
        for i, para in enumerate(self.sentences):
            for j in range(len(para)):
                self.sentence_ids += [(i, j)]
                self.cumlen += [self.cumlen[-1] + len(self.sentences[i][j][0])]

    def has_mwt(self):
        """
        Returns True if any character in the data is labeled as ending a
        multi-word token (label 3 or 4).
        """
        for sentence in self.data:
            for word in sentence:
                if word[1] > 2:
                    return True
        return False

    def shuffle(self):
        for para in self.sentences:
            random.shuffle(para)
        self.init_sent_ids()

    def move_last_char(self, sentence):
        """
        Augmentation: insert a space before the final character of the sentence,
        detaching what is typically sentence-final punctuation from the previous word.

        Only applies when the final character is already its own token (the last
        label is 2 and the one before it is not 0) and the longer result still
        fits within max_seqlen.
        """
        if len(sentence[3]) > 1 and len(sentence[3]) < self.args['max_seqlen'] and sentence[1][-1] == 2 and sentence[1][-2] != 0:
            new_units = [(x, int(y)) for x, y in zip(sentence[3][:-1], sentence[1][:-1])]
            new_units.extend([(' ', 0), (sentence[3][-1], int(sentence[1][-1]))])
            encoded = self.para_to_sentences(new_units)
            return encoded
        return None

    def split_mwt(self, sentence):
        """
        Augmentation: randomly pick an MWT in the sentence and replace it with its
        two-word expansion, separated by a space.

        Returns the re-encoded sentence(s), or None if the sentence has no suitable
        MWT or is already at the maximum length.
        """
        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
            return None

        # label 3 marks the last character of a multi-word token
        mwt_ends = [idx for idx, label in enumerate(sentence[1]) if label == 3]
        if len(mwt_ends) == 0:
            return None
        random_end = random.randint(0, len(mwt_ends)-1)
        mwt_end = mwt_ends[random_end]
        # walk backwards to the first character of the MWT
        mwt_start = mwt_end - 1
        while mwt_start >= 0 and sentence[1][mwt_start] == 0:
            mwt_start -= 1
        mwt_start += 1
        while sentence[3][mwt_start].isspace():
            mwt_start += 1
        if mwt_start == mwt_end:
            return None
        mwt = "".join(x for x in sentence[3][mwt_start:mwt_end+1])
        if mwt not in self.mwt_expansions:
            return None

        # rebuild the character/label sequence with the MWT replaced by its expansion,
        # each expanded word ending with a token-end label
        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
        w0_units = [(x, 0) for x in self.mwt_expansions[mwt][0]]
        w0_units[-1] = (w0_units[-1][0], 1)
        w1_units = [(x, 0) for x in self.mwt_expansions[mwt][1]]
        w1_units[-1] = (w1_units[-1][0], 1)
        split_units = w0_units + [(' ', 0)] + w1_units
        new_units = all_units[:mwt_start] + split_units + all_units[mwt_end+1:]
        encoded = self.para_to_sentences(new_units)
        return encoded

    def move_punct_back(self, sentence):
        """
        Augmentation: for punctuation in the move_punct set which is preceded by a
        space, drop that space so the punct is attached to the previous word.
        """
        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
            return None

        # candidate positions: a punct from the move_punct set with a space immediately
        # before it and no digit immediately before that space
        commas = [idx for idx, c in enumerate(sentence[3])
                  if c in self.move_punct and idx > 0 and sentence[3][idx-1].isspace() and (idx == 1 or not sentence[3][idx-2].isdigit())]
        if len(commas) == 0:
            return None

        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
        new_units = []

        # copy everything except the space right before each candidate punct
        span_start = 0
        for span_end in commas:
            new_units.extend(all_units[span_start:span_end-1])
            span_start = span_end
        if span_end < len(sentence[3]):
            new_units.extend(all_units[span_end:])

        encoded = self.para_to_sentences(new_units)
        return encoded

    def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):
        ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. '''
        feat_size = len(self.sentences[0][0][2][0])
        unkid = self.vocab.unit2id('<UNK>')
        padid = self.vocab.unit2id('<PAD>')

        def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
            # the various augmentations only apply at training time, each with its
            # own probability taken from the args
            drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))
            drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
            move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0)
            move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0)
            split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0)

            pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
            sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]
            total_len = len(sentences[0][0])

            assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! {}'.format(self.args['max_seqlen'], total_len, ' '.join(["{}/{}".format(*x) for x in zip(*self.sentences[pid][sid])]))
            if self.eval:
                # at eval time, keep appending the following sentences of the same
                # paragraph until the sequence is long enough
                for sid1 in range(sid+1, len(self.sentences[pid])):
                    total_len += len(self.sentences[pid][sid1][0])
                    sentences.append(self.sentences[pid][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break
            else:
                # at training time, pad out the sequence with randomly chosen sentences
                while True:
                    pid1, sid1 = random.choice(self.sentence_ids)
                    total_len += len(self.sentences[pid1][sid1][0])
                    sentences.append(self.sentences[pid1][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break

            if move_last_char_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < move_last_char_prob:
                        # insert a space before the final character of the sentence
                        new_sentence = self.move_last_char(sentence)
                        if new_sentence is not None:
                            sentences[sentence_idx] = new_sentence[0]
                            total_len += 1

            if move_punct_back_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < move_punct_back_prob:
                        # remove the space before eligible punctuation
                        new_sentence = self.move_punct_back(sentence)
                        if new_sentence is not None:
                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                            sentences[sentence_idx] = new_sentence[0]

            if split_mwt_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < split_mwt_prob:
                        new_sentence = self.split_mwt(sentence)
                        if new_sentence is not None:
                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                            sentences[sentence_idx] = new_sentence[0]

            if drop_sents and len(sentences) > 1:
                if total_len > self.args['max_seqlen']:
                    sentences = sentences[:-1]
                if len(sentences) > 1:
                    # keep a random prefix of the sentences; dropping more sentences
                    # is exponentially less likely
                    p = [.5 ** i for i in range(1, len(sentences) + 1)]
                    cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]
                    sentences = sentences[:cutoff+1]

            units = np.concatenate([s[0] for s in sentences])
            labels = np.concatenate([s[1] for s in sentences])
            feats = np.concatenate([s[2] for s in sentences])
            raw_units = [x for s in sentences for x in s[3]]

            if not self.eval:
                cutoff = self.args['max_seqlen']
                units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]

            if drop_last_char:
                if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):
                    # drop the final character (typically sentence-final punctuation);
                    # the previous token becomes the new end of sentence, so promote
                    # its label from token end (1) / MWT end (3) to the sentence-ending
                    # variant (2 / 4)
                    units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]

                    labels[-1] = labels[-1] + 1

            return units, labels, feats, raw_units

        if eval_offsets is not None:
            # find how long the batch needs to be: the longest remaining string
            # starting at any of the requested offsets, plus one
            pad_len = 0
            for eval_offset in eval_offsets:
                if eval_offset < self.cumlen[-1]:
                    pair_id = bisect_right(self.cumlen, eval_offset) - 1
                    pair = self.sentence_ids[pair_id]
                    pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0]))

            pad_len += 1
            id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]
            pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]
            offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]

            offsets_pairs = list(zip(offsets, pairs))
        else:
            id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))
            offsets_pairs = [(0, x) for x in id_pairs]
            pad_len = self.args['max_seqlen']

        # build the padded batch, one row per (offset, sentence id) pair
        units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)
        labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)
        features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)
        raw_units = []
        for i, (offset, pair) in enumerate(offsets_pairs):
            u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)
            units[i, :len(u_)] = u_
            labels[i, :len(l_)] = l_
            features[i, :len(f_), :] = f_
            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

        if unit_dropout > 0 and not self.eval:
            # randomly replace a fraction of the units with <UNK> so the model
            # becomes more robust to unknown characters at test time
            mask = np.random.random_sample(units.shape) < unit_dropout
            mask[units == padid] = 0
            units[mask] = unkid
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask[i, j]:
                        raw_units[i][j] = '<UNK>'

        # when dictionary features are in use, randomly zero out the feature vectors
        # of some characters so the model does not rely too heavily on the dictionary
        if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
            mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
            mask_feat[units == padid] = 0
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask_feat[i, j]:
                        features[i, j, :] = 0

        units = torch.from_numpy(units)
        labels = torch.from_numpy(labels)
        features = torch.from_numpy(features)

        return units, labels, features, raw_units

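    # Sketch of the two ways next() is called (illustrative):
    #
    #   units, labels, feats, raw = data.next()                      # training: batch_size random sentences
    #   units, labels, feats, raw = data.next(eval_offsets=[0, 57])  # prediction (evaluation=True): one row
    #                                                                # per requested character offset
    #
    # eval_offsets are global character offsets; init_sent_ids() builds self.cumlen so that
    # bisect_right(self.cumlen, offset) - 1 recovers which sentence an offset falls in
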
class SortedDataset(Dataset):
    """
    Holds a TokenizationDataset for use in a torch DataLoader.

    The torch DataLoader is different from the DataLoader defined here
    and allows for cpu & gpu parallelism.  Updating output_predictions
    to use this class as a wrapper to a TokenizationDataset means the
    calculation of features can happen in parallel, saving quite a
    bit of time.
    """
    def __init__(self, dataset):
        super().__init__()

        self.dataset = dataset
        # sort paragraphs by length (longest first) so that batches group paragraphs
        # of similar length; keep the indices so results can be restored to input order
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # featurization happens here, inside the torch DataLoader worker,
        # so it can run in parallel with the model's forward pass
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        return unsort(arr, self.indices)

    def collate(self, samples):
        if any(len(x) > 1 for x in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        feat_size = samples[0][0][2].shape[-1]
        padid = self.dataset.vocab.unit2id('<PAD>')

        # pad to the longest sample in the batch, plus one extra position
        pad_len = max(len(x[0][3]) for x in samples) + 1

        units = torch.full((len(samples), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32)
        raw_units = []
        for i, sample in enumerate(samples):
            u_, l_, f_, r_ = sample[0]
            units[i, :len(u_)] = torch.from_numpy(u_)
            labels[i, :len(l_)] = torch.from_numpy(l_)
            features[i, :len(f_), :] = torch.from_numpy(f_)
            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

        return units, labels, features, raw_units
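
# Sketch of wrapping a TokenizationDataset in a torch DataLoader via SortedDataset
# (names and parameter values are illustrative):
#
#   import torch.utils.data
#
#   eval_data = TokenizationDataset(args, input_text=text, vocab=vocab, evaluation=True)
#   sorted_data = SortedDataset(eval_data)
#   loader = torch.utils.data.DataLoader(sorted_data, batch_size=32, num_workers=2,
#                                        collate_fn=sorted_data.collate)
#   for units, labels, features, raw_units in loader:
#       ...  # run the tokenizer model on each padded batch
#
# per-paragraph results can then be restored to input order with sorted_data.unsort(...)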