# (extraction artifacts removed: file-size line, commit hash, and line-number gutter)
from bisect import bisect_right
from copy import copy
import numpy as np
import random
import logging
import re
import torch
from torch.utils.data import Dataset
from stanza.models.common.utils import sort_with_indices, unsort
from stanza.models.tokenization.vocab import Vocab
logger = logging.getLogger('stanza')
def filter_consecutive_whitespaces(para):
    """Collapse runs of spaces in a paragraph of (char, label) pairs.

    Only the first space of each consecutive run is kept (with its label);
    every space immediately following another space is dropped.
    """
    kept = []
    previous = None
    for char, label in para:
        if char == ' ' and previous == ' ':
            # second or later space of a run: drop it
            previous = char
            continue
        kept.append((char, label))
        previous = char
    return kept
# a blank (possibly whitespace-only) line: used to split text into paragraphs
NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
# this was (r'^([\d]+[,\.]*)+$')
# but the runtime on that can explode exponentially
# for example, on 111111111111111111111111a
# digits possibly separated/terminated by ',' or '.' (group/decimal separators)
NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
# any single whitespace character; used to normalize special whitespace to ' '
WHITESPACE_RE = re.compile(r'\s')
class TokenizationDataset:
    """Holds tokenizer data as paragraphs of (character, label) pairs.

    Input text is split into paragraphs on blank lines; special whitespace
    characters are normalized to plain spaces and consecutive spaces are
    collapsed.  Labels, when a label file is given, are per-character digits;
    otherwise every character is labeled 0.  From the usage in
    para_to_sentences, labels 2 and 4 mark sentence ends and labels > 2 mark
    MWT-related positions.
    """
    def __init__(self, tokenizer_args, input_files=None, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):
        """Load and preprocess the data.

        tokenizer_args: dict of settings; this class reads 'skip_newline',
            'feat_funcs', 'use_dictionary', 'num_dict_feat', 'max_seqlen'
        input_files: dict with 'txt' and 'label' file paths; defaults to
            {'txt': None, 'label': None}.  (The default used to be a shared
            mutable dict -- replaced with a None sentinel.)
        input_text: raw text to use instead of reading the 'txt' file
        vocab: a Vocab mapping units to ids; may be supplied by a subclass
        evaluation: True at eval time (changes sentence splitting behavior)
        dictionary: lexicon dict with 'words', 'prefixes', 'suffixes'
            collections, used by the optional dictionary features
        """
        super().__init__(*args, **kwargs) # forwards all unused arguments
        if input_files is None:
            input_files = {'txt': None, 'label': None}
        self.args = tokenizer_args
        self.eval = evaluation
        self.dictionary = dictionary
        self.vocab = vocab

        # get input files
        txt_file = input_files['txt']
        label_file = input_files['label']

        # Load data and process it
        # set up text from file or input string
        assert txt_file is not None or input_text is not None
        if input_text is None:
            with open(txt_file, encoding="utf-8") as f:
                text = f.read().rstrip()
        else:
            text = input_text

        # paragraphs are separated by blank lines; drop empty paragraphs
        text_chunks = NEWLINE_WHITESPACE_RE.split(text)
        text_chunks = [pt.rstrip() for pt in text_chunks]
        text_chunks = [pt for pt in text_chunks if pt]

        if label_file is not None:
            with open(label_file, encoding="utf-8") as f:
                labels = f.read().rstrip()
            labels = NEWLINE_WHITESPACE_RE.split(labels)
            labels = [pt.rstrip() for pt in labels]
            # materialize as lists (rather than one-shot map objects) so the
            # label sequences can be safely reused
            labels = [list(map(int, pt)) for pt in labels if pt]
        else:
            labels = [[0 for _ in pt] for pt in text_chunks]

        skip_newline = self.args.get('skip_newline', False)
        self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces
                      for char, label in zip(pt, pc) if not (skip_newline and char == '\n')] # check if newline needs to be eaten
                     for pt, pc in zip(text_chunks, labels)]
        # remove consecutive whitespaces
        self.data = [filter_consecutive_whitespaces(x) for x in self.data]

    def labels(self):
        """
        Returns a list of the labels for all of the sentences in this DataLoader

        Used at eval time to compare to the results, for example
        """
        return [np.array(list(x[1] for x in sent)) for sent in self.data]

    def extract_dict_feat(self, para, idx):
        """Extract 2 * num_dict_feat dictionary features for para[idx].

        For window sizes 1..num_dict_feat, grows a string forward (and
        backward) from the character at idx and records 1 when the grown
        string is a dictionary word.  Growth in a direction stops as soon as
        the string is no longer a known prefix (resp. suffix).
        Returns forward features followed by backward features.
        """
        length = len(para)
        dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]
        dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]
        # note: the seed character itself is not lowercased; only the
        # characters appended afterwards are
        forward_word = para[idx][0]
        backward_word = para[idx][0]
        prefix = True
        suffix = True
        for window in range(1, self.args['num_dict_feat']+1):
            # forward check: concatenate the next character and check if the
            # word is found in the dict; stop once the prefix is not found
            if (idx + window) <= length-1 and prefix:
                forward_word += para[idx+window][0].lower()
                # 1 if the grown string is a complete word in the dictionary
                feat = 1 if forward_word in self.dictionary["words"] else 0
                dict_forward_feats[window-1] = feat
                # no prefixes found: stop looking forward
                if forward_word not in self.dictionary["prefixes"]:
                    prefix = False
            # backward check: similar to forward
            if (idx - window) >= 0 and suffix:
                backward_word = para[idx-window][0].lower() + backward_word
                feat = 1 if backward_word in self.dictionary["words"] else 0
                dict_backward_feats[window-1] = feat
                if backward_word not in self.dictionary["suffixes"]:
                    suffix = False
            # if we can find neither prefix nor suffix, exit the loop early
            if not prefix and not suffix:
                break
        return dict_forward_feats + dict_backward_feats

    def para_to_sentences(self, para):
        """ Convert a paragraph to a list of processed sentences.

        Each sentence is a tuple of
        (np array of unit ids, np array of labels, np array of feature
        vectors, list of raw units).  At training time the paragraph is cut
        at labels 2/4 (sentence ends) and over-long sentences are dropped;
        at eval time the whole paragraph becomes one sentence.
        """
        res = []
        funcs = []
        for feat_func in self.args['feat_funcs']:
            if feat_func == 'end_of_para' or feat_func == 'start_of_para':
                # skip for position-dependent features
                continue
            if feat_func == 'space_before':
                func = lambda x: 1 if x.startswith(' ') else 0
            elif feat_func == 'capitalized':
                func = lambda x: 1 if x[0].isupper() else 0
            elif feat_func == 'numeric':
                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
            else:
                raise ValueError('Feature function "{}" is undefined.'.format(feat_func))
            funcs.append(func)

        # stacking all featurize functions
        composite_func = lambda x: [f(x) for f in funcs]

        def process_sentence(sent_units, sent_labels, sent_feats):
            # copy sent_units so clearing the accumulators is safe
            return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                    np.array(sent_labels),
                    np.array(sent_feats),
                    list(sent_units))

        use_end_of_para = 'end_of_para' in self.args['feat_funcs']
        use_start_of_para = 'start_of_para' in self.args['feat_funcs']
        use_dictionary = self.args['use_dictionary']
        current_units = []
        current_labels = []
        current_feats = []
        for i, (unit, label) in enumerate(para):
            feats = composite_func(unit)
            # position-dependent features
            if use_end_of_para:
                f = 1 if i == len(para)-1 else 0
                feats.append(f)
            if use_start_of_para:
                f = 1 if i == 0 else 0
                feats.append(f)
            # if dictionary feature is selected
            if use_dictionary:
                dict_feats = self.extract_dict_feat(para, i)
                feats = feats + dict_feats
            current_units.append(unit)
            current_labels.append(label)
            current_feats.append(feats)
            if not self.eval and (label == 2 or label == 4): # end of sentence
                if len(current_units) <= self.args['max_seqlen']:
                    # get rid of sentences that are too long during training of the tokenizer
                    res.append(process_sentence(current_units, current_labels, current_feats))
                current_units.clear()
                current_labels.clear()
                current_feats.clear()
        if len(current_units) > 0:
            if self.eval or len(current_units) <= self.args['max_seqlen']:
                res.append(process_sentence(current_units, current_labels, current_feats))
        return res

    def advance_old_batch(self, eval_offsets, old_batch):
        """
        Advance to a new position in a batch where we have partially processed the batch

        If we have previously built a batch of data and made predictions on them, then when we are trying to make
        prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.

        In this case, eval_offsets index within the old_batch to advance the strings to process.

        NOTE: eval_offsets is clamped (mutated) in place.
        Returns (units, labels, features, raw_units) tensors/lists re-padded
        to the new maximum remaining length.
        """
        padid = self.vocab.unit2id('<PAD>')
        ounits, olabels, ofeatures, oraw = old_batch
        feat_size = ofeatures.shape[-1]

        # per-row valid lengths: everything up to the first pad
        lens = (ounits != padid).sum(1).tolist()
        pad_len = max(l-i for i, l in zip(eval_offsets, lens))

        units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)
        raw_units = []

        for i in range(len(ounits)):
            # clamp so a row that is fully consumed contributes nothing
            eval_offsets[i] = min(eval_offsets[i], lens[i])
            units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]]
            labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]]
            features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]]
            raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + ['<PAD>'] * (pad_len - lens[i] + eval_offsets[i]))

        return units, labels, features, raw_units
def build_move_punct_set(data, move_back_prob):
    """Return the punctuation characters safe to rearrange with whitespace.

    Starts from a candidate set and scans the training chunks, discarding any
    punct observed in a position where moving it would create unnatural text.
    (move_back_prob is accepted for interface parity but not consulted here.)
    """
    candidates = {',', ':', '!', '.', '?', '"', '(', ')'}
    for chunk in data:
        # ignore positions at the start and end of a chunk
        for pos in range(1, len(chunk) - 1):
            ch = chunk[pos][0]
            if ch not in candidates:
                continue
            before = chunk[pos - 1][0]
            if chunk[pos][1] == 0:
                # mid-token punct followed by a space: this punct does not
                # reliably end a word.  honestly that's a rather unusual
                # situation, but VI has |3, 5| as a complete token, so the
                # digit check keeps those from disqualifying the punct
                if chunk[pos + 1][0].isspace() and not before.isdigit():
                    candidates.remove(ch)
                continue
            # token-final punct glued to a normal word character, such as the
            # '.' of 'Mr.' -- not safe to move.  digits are skipped because we
            # will intentionally not create things that look like decimals
            if not before.isspace() and before not in candidates and not before.isdigit():
                candidates.remove(ch)
    return candidates
def build_known_mwt(data, mwt_expansions):
    """Collect MWT surface forms in the data which have a known expansion.

    An MWT ends at a character labeled 3; its surface form runs from the
    first non-0 label boundary (skipping leading whitespace) through that
    character.  Only MWTs whose expansion is at most two words are kept.
    """
    found = set()
    for chunk in data:
        for end, unit in enumerate(chunk):
            if unit[1] != 3:
                continue
            # found an MWT: walk back over the 0-labeled interior
            start = end - 1
            while start >= 0 and chunk[start][1] == 0:
                start -= 1
            start += 1
            # the surface form cannot begin with whitespace
            while chunk[start][0].isspace():
                start += 1
            if start == end:
                continue
            surface = "".join(piece[0] for piece in chunk[start:end+1])
            if surface not in mwt_expansions:
                continue
            if len(mwt_expansions[surface]) > 2:
                # TODO: could split 3 word tokens as well
                continue
            found.add(surface)
    return found
class DataLoader(TokenizationDataset):
    """
    This is the training version of the dataset.

    Paragraphs are pre-split into sentences using the gold labels; batches
    are assembled by sampling sentences and optionally applying data
    augmentation (moving the final character, moving punctuation back,
    splitting known MWTs).
    """
    def __init__(self, args, input_files=None, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None):
        """Build the training dataset.

        mwt_expansions: dict of MWT surface form -> list of expansion words,
            consulted when split_mwt_prob > 0.
        NOTE: input_files used to default to a shared mutable dict
        {'txt': None, 'label': None}; a None sentinel avoids cross-call
        sharing of that dict.
        """
        if input_files is None:
            input_files = {'txt': None, 'label': None}
        super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)
        self.vocab = vocab if vocab is not None else self.init_vocab()

        # data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels.
        # At evaluation time, each paragraph is treated as single "sentence" as we don't know a priori where
        # sentence breaks occur. We make prediction from left to right for each paragraph and move forward to
        # the last predicted sentence break to start afresh.
        self.sentences = [self.para_to_sentences(para) for para in self.data]

        self.init_sent_ids()
        logger.debug(f"{len(self.sentence_ids)} sentences loaded.")

        punct_move_back_prob = args.get('punct_move_back_prob', 0.0)
        if punct_move_back_prob > 0.0:
            self.move_punct = build_move_punct_set(self.data, punct_move_back_prob)
            if len(self.move_punct) > 0:
                logger.debug('Based on the training data, will augment space/punct combinations {}'.format(self.move_punct))
            else:
                logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace')

        split_mwt_prob = args.get('split_mwt_prob', 0.0)
        if split_mwt_prob > 0.0 and not evaluation:
            self.mwt_expansions = mwt_expansions
            self.known_mwt = build_known_mwt(self.data, mwt_expansions)
            if len(self.known_mwt) > 0:
                logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt))
            else:
                logger.debug('Based on the training data, there are NO MWT to split at training time')

    def __len__(self):
        return len(self.sentence_ids)

    def init_vocab(self):
        """Build a fresh Vocab from the loaded data."""
        vocab = Vocab(self.data, self.args['lang'])
        return vocab

    def init_sent_ids(self):
        """Rebuild the flat (paragraph, sentence) index and the cumulative
        character-length table used to map character offsets to sentences."""
        self.sentence_ids = []
        self.cumlen = [0]
        for i, para in enumerate(self.sentences):
            for j in range(len(para)):
                self.sentence_ids += [(i, j)]
                self.cumlen += [self.cumlen[-1] + len(self.sentences[i][j][0])]

    def has_mwt(self):
        """Return True if any character carries an MWT label (> 2)."""
        # presumably this only needs to be called either 0 or 1 times,
        # 1 when training and 0 any other time, so no effort is put
        # into caching the result
        for sentence in self.data:
            for word in sentence:
                if word[1] > 2:
                    return True
        return False

    def shuffle(self):
        """Shuffle sentences within each paragraph and rebuild the index."""
        for para in self.sentences:
            random.shuffle(para)
        self.init_sent_ids()

    def move_last_char(self, sentence):
        """Augmentation: detach the final (sentence-ending) character with a
        space, teaching the model about detached final punctuation.

        Returns the re-encoded sentence list, or None if not eligible.
        """
        # eligible: more than one unit, room to grow by one, ends a sentence
        # (label 2) and the unit before the end also closes a token
        if len(sentence[3]) > 1 and len(sentence[3]) < self.args['max_seqlen'] and sentence[1][-1] == 2 and sentence[1][-2] != 0:
            new_units = [(x, int(y)) for x, y in zip(sentence[3][:-1], sentence[1][:-1])]
            new_units.extend([(' ', 0), (sentence[3][-1], int(sentence[1][-1]))])
            encoded = self.para_to_sentences(new_units)
            return encoded
        return None

    def split_mwt(self, sentence):
        """Augmentation: split one random MWT into its two expansion words.

        Returns the re-encoded sentence list, or None if not eligible.
        """
        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
            return None

        # if we find a token in the sentence which ends with label 3,
        # eg it is an MWT,
        # with some probability we split it into two tokens
        # and treat the split tokens as both label 1 instead of 3
        # in this manner, we teach the tokenizer not to treat the
        # entire sequence of characters with added spaces as an MWT,
        # which weirdly can happen in some corner cases
        mwt_ends = [idx for idx, label in enumerate(sentence[1]) if label == 3]
        if len(mwt_ends) == 0:
            return None
        random_end = random.randint(0, len(mwt_ends)-1)
        mwt_end = mwt_ends[random_end]
        mwt_start = mwt_end - 1
        while mwt_start >= 0 and sentence[1][mwt_start] == 0:
            mwt_start -= 1
        mwt_start += 1
        while sentence[3][mwt_start].isspace():
            mwt_start += 1
        if mwt_start == mwt_end:
            return None
        mwt = "".join(x for x in sentence[3][mwt_start:mwt_end+1])
        if mwt not in self.mwt_expansions:
            return None

        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
        # both halves become plain word-final (label 1) tokens
        w0_units = [(x, 0) for x in self.mwt_expansions[mwt][0]]
        w0_units[-1] = (w0_units[-1][0], 1)
        w1_units = [(x, 0) for x in self.mwt_expansions[mwt][1]]
        w1_units[-1] = (w1_units[-1][0], 1)
        split_units = w0_units + [(' ', 0)] + w1_units
        new_units = all_units[:mwt_start] + split_units + all_units[mwt_end+1:]
        encoded = self.para_to_sentences(new_units)
        return encoded

    def move_punct_back(self, sentence):
        """Augmentation: remove the space before eligible punctuation,
        gluing the punct to the previous token.

        Returns the re-encoded sentence list, or None if not eligible.
        """
        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
            return None

        # check that we are not accidentally creating decimal numbers
        #   idx == 1 or not sentence[3][idx-2].isdigit()
        # one disadvantage of checking for sentence[1][idx] == 0
        # would be that tokens of all punct, such as '...',
        # should move but would not move if this is eliminated
        commas = [idx for idx, c in enumerate(sentence[3])
                  if c in self.move_punct and idx > 0 and sentence[3][idx-1].isspace() and (idx == 1 or not sentence[3][idx-2].isdigit())]
        if len(commas) == 0:
            return None

        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
        new_units = []
        span_start = 0
        for span_end in commas:
            # drop the space immediately before the punct
            new_units.extend(all_units[span_start:span_end-1])
            span_start = span_end
        if span_end < len(sentence[3]):
            new_units.extend(all_units[span_end:])
        encoded = self.para_to_sentences(new_units)
        return encoded

    def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):
        ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. '''
        feat_size = len(self.sentences[0][0][2][0])
        unkid = self.vocab.unit2id('<UNK>')
        padid = self.vocab.unit2id('<PAD>')

        def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
            # At eval time, this combines sentences in paragraph (indexed by id_pair[0]) starting sentence (indexed
            # by id_pair[1]) into a long string for evaluation. At training time, we just select random sentences
            # from the entire dataset until we reach max_seqlen.
            drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))
            drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
            move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0)
            move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0)
            split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0)
            pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
            sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]
            total_len = len(sentences[0][0])

            assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! {}'.format(self.args['max_seqlen'], total_len, ' '.join(["{}/{}".format(*x) for x in zip(self.sentences[pid][sid])]))
            if self.eval:
                # gather consecutive sentences from the same paragraph
                for sid1 in range(sid+1, len(self.sentences[pid])):
                    total_len += len(self.sentences[pid][sid1][0])
                    sentences.append(self.sentences[pid][sid1])
                    if total_len >= self.args['max_seqlen']:
                        break
            else:
                # gather random sentences from anywhere in the dataset
                while True:
                    pid1, sid1 = random.choice(self.sentence_ids)
                    total_len += len(self.sentences[pid1][sid1][0])
                    sentences.append(self.sentences[pid1][sid1])
                    if total_len >= self.args['max_seqlen']:
                        break

            if move_last_char_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < move_last_char_prob:
                        # the sentence might not be eligible, such as
                        # already having a space or not having a sentence final punct,
                        # so we need to do a two step checking process here
                        new_sentence = self.move_last_char(sentence)
                        if new_sentence is not None:
                            sentences[sentence_idx] = new_sentence[0]
                            total_len += 1

            if move_punct_back_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < move_punct_back_prob:
                        # the sentence might not be eligible, such as
                        # not having a space separated punct,
                        # so we need to do a two step checking process here
                        new_sentence = self.move_punct_back(sentence)
                        if new_sentence is not None:
                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                            sentences[sentence_idx] = new_sentence[0]

            if split_mwt_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < split_mwt_prob:
                        new_sentence = self.split_mwt(sentence)
                        if new_sentence is not None:
                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                            sentences[sentence_idx] = new_sentence[0]

            if drop_sents and len(sentences) > 1:
                if total_len > self.args['max_seqlen']:
                    sentences = sentences[:-1]
                if len(sentences) > 1:
                    p = [.5 ** i for i in range(1, len(sentences) + 1)] # drop a large number of sentences with smaller probability
                    cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]
                    sentences = sentences[:cutoff+1]

            units = np.concatenate([s[0] for s in sentences])
            labels = np.concatenate([s[1] for s in sentences])
            feats = np.concatenate([s[2] for s in sentences])
            raw_units = [x for s in sentences for x in s[3]]

            if not self.eval:
                cutoff = self.args['max_seqlen']
                units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]

            if drop_last_char: # can only happen in non-eval mode
                if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):
                    # training text ended with a sentence end position
                    # and that word was a single character
                    # and the previous character ended the word
                    units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]
                    # word end -> sentence end, mwt end -> sentence mwt end
                    labels[-1] = labels[-1] + 1

            return units, labels, feats, raw_units

        if eval_offsets is not None:
            # find max padding length
            pad_len = 0
            for eval_offset in eval_offsets:
                if eval_offset < self.cumlen[-1]:
                    pair_id = bisect_right(self.cumlen, eval_offset) - 1
                    pair = self.sentence_ids[pair_id]
                    pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0]))
            pad_len += 1
            id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]
            pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]
            offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]
            offsets_pairs = list(zip(offsets, pairs))
        else:
            id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))
            offsets_pairs = [(0, x) for x in id_pairs]
            pad_len = self.args['max_seqlen']

        # put everything into padded and nicely shaped NumPy arrays and eventually convert to PyTorch tensors
        units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)
        labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)
        features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)
        raw_units = []

        for i, (offset, pair) in enumerate(offsets_pairs):
            u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)
            units[i, :len(u_)] = u_
            labels[i, :len(l_)] = l_
            features[i, :len(f_), :] = f_
            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

        if unit_dropout > 0 and not self.eval:
            # dropout characters/units at training time and replace them with UNKs
            mask = np.random.random_sample(units.shape) < unit_dropout
            mask[units == padid] = 0
            units[mask] = unkid
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask[i, j]:
                        raw_units[i][j] = '<UNK>'

        # dropout unit feature vector in addition to only torch.dropout in the model.
        # experiments showed that only torch.dropout hurts the model
        # we believe it is because the dict feature vector is mostly sparse so it makes
        # more sense to drop out the whole vector instead of only single element.
        if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
            mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
            mask_feat[units == padid] = 0
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask_feat[i,j]:
                        features[i,j,:] = 0

        units = torch.from_numpy(units)
        labels = torch.from_numpy(labels)
        features = torch.from_numpy(features)

        return units, labels, features, raw_units
class SortedDataset(Dataset):
    """
    Wraps a TokenizationDataset so a torch DataLoader can consume it.

    The torch DataLoader is a different class from the DataLoader defined
    in this module; it provides cpu & gpu parallelism.  Using this wrapper
    in output_predictions lets feature calculation happen in parallel,
    which saves quite a bit of time.  Paragraphs are sorted longest-first;
    unsort() restores the original order.
    """
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # a single sample:
        #   np: index in character map
        #   np: tokenization label
        #   np: features
        #   list: original text as one length strings
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        """Restore arr to the dataset's original paragraph order."""
        return unsort(arr, self.indices)

    def collate(self, samples):
        """Pad a batch of single-sentence samples into batch tensors."""
        if any(len(sample) > 1 for sample in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        feat_size = samples[0][0][2].shape[-1]
        padid = self.dataset.vocab.unit2id('<PAD>')
        # +1 so that all samples end with at least one pad
        pad_len = 1 + max(len(sample[0][3]) for sample in samples)
        batch_units = torch.full((len(samples), pad_len), padid, dtype=torch.int64)
        batch_labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32)
        batch_feats = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32)
        batch_raw = []
        for row, sample in enumerate(samples):
            units, labels, feats, raw = sample[0]
            batch_units[row, :len(units)] = torch.from_numpy(units)
            batch_labels[row, :len(labels)] = torch.from_numpy(labels)
            batch_feats[row, :len(feats), :] = torch.from_numpy(feats)
            batch_raw.append(raw + ['<PAD>'])
        return batch_units, batch_labels, batch_feats, batch_raw