from bisect import bisect_right
from copy import copy
import numpy as np
import random
import logging
import re
import torch
from torch.utils.data import Dataset

from stanza.models.common.utils import sort_with_indices, unsort
from stanza.models.tokenization.vocab import Vocab

logger = logging.getLogger('stanza')

def filter_consecutive_whitespaces(para):
    filtered = []
    for i, (char, label) in enumerate(para):
        if i > 0:
            if char == ' ' and para[i-1][0] == ' ':
                continue

        filtered.append((char, label))

    return filtered
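
# An illustrative example, not part of the original module: filter_consecutive_whitespaces
# keeps only the first space of any run of spaces, preserving each unit's label, e.g.
#   >>> filter_consecutive_whitespaces([('a', 0), (' ', 0), (' ', 0), ('b', 1)])
#   [('a', 0), (' ', 0), ('b', 1)]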

NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
# this was (r'^([\d]+[,\.]*)+$')
# but the runtime on that can explode exponentially
# for example, on 111111111111111111111111a
NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
WHITESPACE_RE = re.compile(r'\s')
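
# Illustrative examples, not part of the original module, of what NUMERIC_RE accepts:
#   >>> bool(NUMERIC_RE.match('3,000.5'))
#   True
#   >>> bool(NUMERIC_RE.match('1986.'))
#   True
#   >>> bool(NUMERIC_RE.match('111111111111111111111111a'))
#   False
# Unlike the earlier pattern noted above, this one does not backtrack catastrophically.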

class TokenizationDataset:
    def __init__(self, tokenizer_args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, *args, **kwargs):
        super().__init__(*args, **kwargs)  # forwards all unused arguments
        self.args = tokenizer_args
        self.eval = evaluation
        self.dictionary = dictionary
        self.vocab = vocab

        # get input files
        txt_file = input_files['txt']
        label_file = input_files['label']

        # Load data and process it
        # set up text from file or input string
        assert txt_file is not None or input_text is not None
        if input_text is None:
            with open(txt_file, encoding="utf-8") as f:
                text = ''.join(f.readlines()).rstrip()
        else:
            text = input_text

        text_chunks = NEWLINE_WHITESPACE_RE.split(text)
        text_chunks = [pt.rstrip() for pt in text_chunks]
        text_chunks = [pt for pt in text_chunks if pt]
        if label_file is not None:
            with open(label_file, encoding="utf-8") as f:
                labels = ''.join(f.readlines()).rstrip()
                labels = NEWLINE_WHITESPACE_RE.split(labels)
                labels = [pt.rstrip() for pt in labels]
                labels = [map(int, pt) for pt in labels if pt]
        else:
            labels = [[0 for _ in pt] for pt in text_chunks]

        skip_newline = self.args.get('skip_newline', False)
        self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces
                      for char, label in zip(pt, pc) if not (skip_newline and char == '\n')] # check if newline needs to be eaten
                     for pt, pc in zip(text_chunks, labels)]

        # remove consecutive whitespaces
        self.data = [filter_consecutive_whitespaces(x) for x in self.data]

    def labels(self):
        """
        Returns a list of the labels for all of the sentences in this dataset

        Used at eval time to compare to the results, for example
        """
        return [np.array(list(x[1] for x in sent)) for sent in self.data]

    def extract_dict_feat(self, para, idx):
        """
        Extract dictionary features for the character at position idx in para
        """
        length = len(para)

        dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]
        dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]
        forward_word = para[idx][0]
        backward_word = para[idx][0]
        prefix = True
        suffix = True
        for window in range(1,self.args['num_dict_feat']+1):
            # grow the word one character at a time and check whether it appears in the dict;
            # stop extending in a direction once the prefix (or suffix) is no longer found
            # check that idx+window is in bounds and that the prefix has not already failed
            if (idx + window) <= length-1 and prefix:
                forward_word += para[idx+window][0].lower()
                # check the dictionary to see whether the concatenation is a complete word
                feat = 1 if forward_word in self.dictionary["words"] else 0
                # the feature is 1 if the concatenation is a word in the dict, 0 otherwise
                dict_forward_feats[window-1] = feat
                # if the concatenation is not a known prefix, stop looking forward
                if forward_word not in self.dictionary["prefixes"]:
                    prefix = False
            #backward check: similar to forward
            if (idx - window) >= 0 and suffix:
                backward_word = para[idx-window][0].lower() + backward_word
                feat = 1 if backward_word in self.dictionary["words"] else 0
                dict_backward_feats[window-1] = feat
                if backward_word not in self.dictionary["suffixes"]:
                    suffix = False
            # if neither the prefix nor the suffix can be extended, exit the loop early
            if not prefix and not suffix:
                break

        return dict_forward_feats + dict_backward_feats
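
    # An illustrative sketch under assumed inputs, not taken from the original module:
    # the dictionary is expected to offer membership tests for "words", "prefixes" and
    # "suffixes".  With self.args['num_dict_feat'] == 2 and a hypothetical dictionary
    # {"words": {"ab"}, "prefixes": set(), "suffixes": set()},
    #   extract_dict_feat([('a', 0), ('b', 1), ('c', 0)], idx=0)
    # returns [1, 0, 0, 0]: the forward window of size 1 ("ab") is a word in the dict,
    # and no other forward or backward window matches.  Forward features come first.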

    def para_to_sentences(self, para):
        """ Convert a paragraph to a list of processed sentences. """
        res = []
        funcs = []
        for feat_func in self.args['feat_funcs']:
            if feat_func == 'end_of_para' or feat_func == 'start_of_para':
                # position-dependent features are handled separately below
                continue
            if feat_func == 'space_before':
                func = lambda x: 1 if x.startswith(' ') else 0
            elif feat_func == 'capitalized':
                func = lambda x: 1 if x[0].isupper() else 0
            elif feat_func == 'numeric':
                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
            else:
                raise ValueError('Feature function "{}" is undefined.'.format(feat_func))

            funcs.append(func)

        # stack all feature functions into a single composite function
        composite_func = lambda x: [f(x) for f in funcs]

        def process_sentence(sent_units, sent_labels, sent_feats):
            return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                    np.array(sent_labels),
                    np.array(sent_feats),
                    list(sent_units))

        use_end_of_para = 'end_of_para' in self.args['feat_funcs']
        use_start_of_para = 'start_of_para' in self.args['feat_funcs']
        use_dictionary = self.args['use_dictionary']
        current_units = []
        current_labels = []
        current_feats = []
        for i, (unit, label) in enumerate(para):
            feats = composite_func(unit)
            # position-dependent features
            if use_end_of_para:
                f = 1 if i == len(para)-1 else 0
                feats.append(f)
            if use_start_of_para:
                f = 1 if i == 0 else 0
                feats.append(f)

            #if dictionary feature is selected
            if use_dictionary:
                dict_feats = self.extract_dict_feat(para, i)
                feats = feats + dict_feats

            current_units.append(unit)
            current_labels.append(label)
            current_feats.append(feats)
            if not self.eval and (label == 2 or label == 4): # end of sentence
                if len(current_units) <= self.args['max_seqlen']:
                    # get rid of sentences that are too long during training of the tokenizer
                    res.append(process_sentence(current_units, current_labels, current_feats))
                current_units.clear()
                current_labels.clear()
                current_feats.clear()

        if len(current_units) > 0:
            if self.eval or len(current_units) <= self.args['max_seqlen']:
                res.append(process_sentence(current_units, current_labels, current_feats))

        return res
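
    # A note on the label scheme, inferred from how the labels are used in this file
    # rather than an authoritative definition: 0 = character inside a token, 1 = last
    # character of a token, 2 = last character of a token that also ends a sentence,
    # 3 = last character of a multi-word token (MWT), 4 = last character of an MWT that
    # also ends a sentence.  Each element returned by para_to_sentences is a tuple of
    #   (unit ids from the vocab, labels, per-character feature vectors, raw characters).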

    def advance_old_batch(self, eval_offsets, old_batch):
        """
        Advance to a new position in a batch where we have partially processed the batch

        If we have previously built a batch of data and made predictions on them, then when we are trying to make
        prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
        In this case, eval_offsets are indices into old_batch which mark how far to advance each string being processed.
        """
        unkid = self.vocab.unit2id('<UNK>')
        padid = self.vocab.unit2id('<PAD>')

        ounits, olabels, ofeatures, oraw = old_batch
        feat_size = ofeatures.shape[-1]
        lens = (ounits != padid).sum(1).tolist()
        pad_len = max(l-i for i, l in zip(eval_offsets, lens))

        units = torch.full((len(ounits), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(ounits), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(ounits), pad_len, feat_size), dtype=torch.float32)
        raw_units = []

        for i in range(len(ounits)):
            eval_offsets[i] = min(eval_offsets[i], lens[i])
            units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]]
            labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]]
            features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]]
            raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + ['<PAD>'] * (pad_len - lens[i] + eval_offsets[i]))

        return units, labels, features, raw_units
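
    # An illustrative example, not from the original module: if a row of old_batch has
    # 10 real (non-pad) units and its eval_offset is 4, the new row holds the units,
    # labels and features for positions 4..9 of that row, left-aligned, with everything
    # after that padded (and the raw characters padded with '<PAD>').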

def build_move_punct_set(data, move_back_prob):
    move_punct = {',', ':', '!', '.', '?', '"', '(', ')'}
    for chunk in data:
        # ignore positions at the start and end of a chunk
        for idx in range(1, len(chunk)-1):
            if chunk[idx][0] not in move_punct:
                continue
            if chunk[idx][1] == 0:
                if chunk[idx+1][0].isspace() and not chunk[idx-1][0].isdigit():
                    # this check removes punct which isn't ending a word...
                    # honestly that's a rather unusual situation
                    # VI has |3, 5| as a complete token
                    # so we also eliminate isdigit()
                    move_punct.remove(chunk[idx][0])
                continue
            # the isdigit() case is excluded from this check because the augmentation
            # itself already avoids creating things that look like decimal numbers
            if not chunk[idx-1][0].isspace() and chunk[idx-1][0] not in move_punct and not chunk[idx-1][0].isdigit():
                # this check eliminates things like '.' after 'Mr.'
                move_punct.remove(chunk[idx][0])
                continue
    return move_punct

def build_known_mwt(data, mwt_expansions):
    known_mwts = set()
    for chunk in data:
        for idx, unit in enumerate(chunk):
            if unit[1] != 3:
                continue
            # found an MWT
            prev_idx = idx - 1
            while prev_idx >= 0 and chunk[prev_idx][1] == 0:
                prev_idx -= 1
            prev_idx += 1
            while chunk[prev_idx][0].isspace():
                prev_idx += 1
            if prev_idx == idx:
                continue
            mwt = "".join(x[0] for x in chunk[prev_idx:idx+1])
            if mwt not in mwt_expansions:
                continue
            if len(mwt_expansions[mwt]) > 2:
                # TODO: could split 3 word tokens as well
                continue
            known_mwts.add(mwt)
    return known_mwts

class DataLoader(TokenizationDataset):
    """
    This is the training version of the dataset.
    """
    def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None):
        super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)

        self.vocab = vocab if vocab is not None else self.init_vocab()

        # data comes in a list of paragraphs, where each paragraph is a list of units with unit-level labels.
        # At evaluation time, each paragraph is treated as a single "sentence", as we don't know a priori where
        # sentence breaks occur. We make predictions from left to right for each paragraph and move forward to
        # the last predicted sentence break to start afresh.
        self.sentences = [self.para_to_sentences(para) for para in self.data]

        self.init_sent_ids()
        logger.debug(f"{len(self.sentence_ids)} sentences loaded.")

        punct_move_back_prob = args.get('punct_move_back_prob', 0.0)
        if punct_move_back_prob > 0.0:
            self.move_punct = build_move_punct_set(self.data, punct_move_back_prob)
            if len(self.move_punct) > 0:
                logger.debug('Based on the training data, will augment space/punct combinations {}'.format(self.move_punct))
            else:
                logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace')

        split_mwt_prob = args.get('split_mwt_prob', 0.0)
        if split_mwt_prob > 0.0 and not evaluation:
            self.mwt_expansions = mwt_expansions
            self.known_mwt = build_known_mwt(self.data, mwt_expansions)
            if len(self.known_mwt) > 0:
                logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt))
            else:
                logger.debug('Based on the training data, there are NO MWT to split at training time')

    def __len__(self):
        return len(self.sentence_ids)

    def init_vocab(self):
        vocab = Vocab(self.data, self.args['lang'])
        return vocab

    def init_sent_ids(self):
        self.sentence_ids = []
        self.cumlen = [0]
        for i, para in enumerate(self.sentences):
            for j in range(len(para)):
                self.sentence_ids += [(i, j)]
                self.cumlen += [self.cumlen[-1] + len(self.sentences[i][j][0])]
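
    # An illustrative example, not from the original module: cumlen is a prefix sum of
    # sentence lengths over all (paragraph, sentence) pairs, so that next() can map a
    # flat character offset back to a sentence with bisect_right.  With sentence lengths
    # [5, 3, 7], cumlen == [0, 5, 8, 15]; an eval offset of 6 then falls in the second
    # sentence (bisect_right(self.cumlen, 6) - 1 == 1), one character into it.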

    def has_mwt(self):
        # presumably this only needs to be called either 0 or 1 times,
        # 1 when training and 0 any other time, so no effort is put
        # into caching the result
        for sentence in self.data:
            for word in sentence:
                if word[1] > 2:
                    return True
        return False

    def shuffle(self):
        for para in self.sentences:
            random.shuffle(para)
        self.init_sent_ids()

    def move_last_char(self, sentence):
        if len(sentence[3]) > 1 and len(sentence[3]) < self.args['max_seqlen'] and sentence[1][-1] == 2 and sentence[1][-2] != 0:
            new_units = [(x, int(y)) for x, y in zip(sentence[3][:-1], sentence[1][:-1])]
            new_units.extend([(' ', 0), (sentence[3][-1], int(sentence[1][-1]))])
            encoded = self.para_to_sentences(new_units)
            return encoded
        return None

    def split_mwt(self, sentence):
        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
            return None

        # if we find a token in the sentence which ends with label 3,
        # e.g. it is an MWT,
        # with some probability we split it into two tokens
        # and give both of the split tokens label 1 instead of 3.
        # in this manner, we teach the tokenizer not to treat the
        # entire sequence of characters with added spaces as an MWT,
        # which weirdly can happen in some corner cases

        mwt_ends = [idx for idx, label in enumerate(sentence[1]) if label == 3]
        if len(mwt_ends) == 0:
            return None
        random_end = random.randint(0, len(mwt_ends)-1)
        mwt_end = mwt_ends[random_end]
        mwt_start = mwt_end - 1
        while mwt_start >= 0 and sentence[1][mwt_start] == 0:
            mwt_start -= 1
        mwt_start += 1
        while sentence[3][mwt_start].isspace():
            mwt_start += 1
        if mwt_start == mwt_end:
            return None
        mwt = "".join(x for x in sentence[3][mwt_start:mwt_end+1])
        if mwt not in self.mwt_expansions:
            return None

        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
        w0_units = [(x, 0) for x in self.mwt_expansions[mwt][0]]
        w0_units[-1] = (w0_units[-1][0], 1)
        w1_units = [(x, 0) for x in self.mwt_expansions[mwt][1]]
        w1_units[-1] = (w1_units[-1][0], 1)
        split_units = w0_units + [(' ', 0)] + w1_units
        new_units = all_units[:mwt_start] + split_units + all_units[mwt_end+1:]
        encoded = self.para_to_sentences(new_units)
        return encoded
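
    # An illustrative example with hypothetical data, not from the original module: if a
    # sentence contains the MWT "du", labeled 3 on its last character, and
    # self.mwt_expansions["du"] == ("de", "le"), then split_mwt may rewrite that span as
    # "de le", giving the last character of each expanded word label 1, and re-encode
    # the result with para_to_sentences.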

    def move_punct_back(self, sentence):
        if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
            return None

        # check that we are not accidentally creating decimal numbers
        #   idx == 1 or not sentence[3][idx-2].isdigit()
        # we deliberately do not also require sentence[1][idx] == 0:
        #   that would prevent moving tokens made entirely of punct,
        #   such as '...', which should be eligible to move
        commas = [idx for idx, c in enumerate(sentence[3])
                  if c in self.move_punct and idx > 0 and sentence[3][idx-1].isspace() and (idx == 1 or not sentence[3][idx-2].isdigit())]
        if len(commas) == 0:
            return None

        all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
        new_units = []

        span_start = 0
        for span_end in commas:
            new_units.extend(all_units[span_start:span_end-1])
            span_start = span_end
        if span_end < len(sentence[3]):
            new_units.extend(all_units[span_end:])

        encoded = self.para_to_sentences(new_units)
        return encoded
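
    # An illustrative example with assumed eligible input, not from the original module:
    # for the raw characters "he said , yes", with ',' surviving in self.move_punct and
    # preceded by a space, move_punct_back drops that space so the augmented text reads
    # "he said, yes", and the result is re-encoded with para_to_sentences.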


    def next(self, eval_offsets=None, unit_dropout=0.0, feat_unit_dropout=0.0):
        ''' Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. '''
        feat_size = len(self.sentences[0][0][2][0])
        unkid = self.vocab.unit2id('<UNK>')
        padid = self.vocab.unit2id('<PAD>')

        def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
            # At eval time, this combines the sentences in the paragraph (indexed by id_pair[0]), starting at the
            # sentence indexed by id_pair[1], into a long string for evaluation. At training time, we just select
            # random sentences from the entire dataset until we reach max_seqlen.
            drop_sents = False if self.eval or (self.args.get('sent_drop_prob', 0) == 0) else (random.random() < self.args.get('sent_drop_prob', 0))
            drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
            move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0)
            move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0)
            split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0)

            pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
            sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]
            total_len = len(sentences[0][0])

            assert self.eval or total_len <= self.args['max_seqlen'], 'The maximum sequence length {} is less than the length of the longest sentence ({}) in the data, consider increasing it! {}'.format(self.args['max_seqlen'], total_len, ' '.join("{}/{}".format(unit, label) for unit, label in zip(self.sentences[pid][sid][3], self.sentences[pid][sid][1])))
            if self.eval:
                for sid1 in range(sid+1, len(self.sentences[pid])):
                    total_len += len(self.sentences[pid][sid1][0])
                    sentences.append(self.sentences[pid][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break
            else:
                while True:
                    pid1, sid1 = random.choice(self.sentence_ids)
                    total_len += len(self.sentences[pid1][sid1][0])
                    sentences.append(self.sentences[pid1][sid1])

                    if total_len >= self.args['max_seqlen']:
                        break

            if move_last_char_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < move_last_char_prob:
                        # the sentence might not be eligible, such as
                        # already having a space or not having a sentence final punct,
                        # so we need to do a two step checking process here
                        new_sentence = self.move_last_char(sentence)
                        if new_sentence is not None:
                            sentences[sentence_idx] = new_sentence[0]
                            total_len += 1

            if move_punct_back_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < move_punct_back_prob:
                        # the sentence might not be eligible, such as
                        # not having a space separated punct,
                        # so we need to do a two step checking process here
                        new_sentence = self.move_punct_back(sentence)
                        if new_sentence is not None:
                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                            sentences[sentence_idx] = new_sentence[0]

            if split_mwt_prob > 0.0:
                for sentence_idx, sentence in enumerate(sentences):
                    if random.random() < split_mwt_prob:
                        new_sentence = self.split_mwt(sentence)
                        if new_sentence is not None:
                            total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
                            sentences[sentence_idx] = new_sentence[0]

            if drop_sents and len(sentences) > 1:
                if total_len > self.args['max_seqlen']:
                    sentences = sentences[:-1]
                if len(sentences) > 1:
                    p = [.5 ** i for i in range(1, len(sentences) + 1)] # dropping a larger number of sentences happens with smaller probability
                    cutoff = random.choices(list(range(len(sentences))), weights=list(reversed(p)))[0]
                    sentences = sentences[:cutoff+1]

            units = np.concatenate([s[0] for s in sentences])
            labels = np.concatenate([s[1] for s in sentences])
            feats = np.concatenate([s[2] for s in sentences])
            raw_units = [x for s in sentences for x in s[3]]

            if not self.eval:
                cutoff = self.args['max_seqlen']
                units, labels, feats, raw_units = units[:cutoff], labels[:cutoff], feats[:cutoff], raw_units[:cutoff]

            if drop_last_char:  # can only happen in non-eval mode
                if len(labels) > 1 and labels[-1] == 2 and labels[-2] in (1, 3):
                    # training text ended with a sentence end position
                    # and that word was a single character
                    # and the previous character ended the word
                    units, labels, feats, raw_units = units[:-1], labels[:-1], feats[:-1], raw_units[:-1]
                    # word end -> sentence end, mwt end -> sentence mwt end
                    labels[-1] = labels[-1] + 1

            return units, labels, feats, raw_units

        if eval_offsets is not None:
            # find max padding length
            pad_len = 0
            for eval_offset in eval_offsets:
                if eval_offset < self.cumlen[-1]:
                    pair_id = bisect_right(self.cumlen, eval_offset) - 1
                    pair = self.sentence_ids[pair_id]
                    pad_len = max(pad_len, len(strings_starting(pair, offset=eval_offset-self.cumlen[pair_id])[0]))

            pad_len += 1
            id_pairs = [bisect_right(self.cumlen, eval_offset) - 1 for eval_offset in eval_offsets]
            pairs = [self.sentence_ids[pair_id] for pair_id in id_pairs]
            offsets = [eval_offset - self.cumlen[pair_id] for eval_offset, pair_id in zip(eval_offsets, id_pairs)]

            offsets_pairs = list(zip(offsets, pairs))
        else:
            id_pairs = random.sample(self.sentence_ids, min(len(self.sentence_ids), self.args['batch_size']))
            offsets_pairs = [(0, x) for x in id_pairs]
            pad_len = self.args['max_seqlen']

        # put everything into padded and nicely shaped NumPy arrays and eventually convert to PyTorch tensors
        units = np.full((len(id_pairs), pad_len), padid, dtype=np.int64)
        labels = np.full((len(id_pairs), pad_len), -1, dtype=np.int64)
        features = np.zeros((len(id_pairs), pad_len, feat_size), dtype=np.float32)
        raw_units = []
        for i, (offset, pair) in enumerate(offsets_pairs):
            u_, l_, f_, r_ = strings_starting(pair, offset=offset, pad_len=pad_len)
            units[i, :len(u_)] = u_
            labels[i, :len(l_)] = l_
            features[i, :len(f_), :] = f_
            raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

        if unit_dropout > 0 and not self.eval:
            # dropout characters/units at training time and replace them with UNKs
            mask = np.random.random_sample(units.shape) < unit_dropout
            mask[units == padid] = 0
            units[mask] = unkid
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask[i, j]:
                        raw_units[i][j] = '<UNK>'

        # drop out each unit's dictionary feature vector in addition to the torch.dropout in the model.
        # experiments showed that using only torch.dropout hurts the model.
        # we believe this is because the dict feature vector is mostly sparse, so it makes
        # more sense to drop out the whole vector instead of only a single element.
        if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
            mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
            mask_feat[units == padid] = 0
            for i in range(len(raw_units)):
                for j in range(len(raw_units[i])):
                    if mask_feat[i,j]:
                        features[i,j,:] = 0
                        
        units = torch.from_numpy(units)
        labels = torch.from_numpy(labels)
        features = torch.from_numpy(features)

        return units, labels, features, raw_units
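
    # Shapes of the batch returned by next(), summarizing the code above rather than
    # adding new behavior:
    #   units:     LongTensor  (batch, pad_len)         character ids, padded with <PAD>
    #   labels:    LongTensor  (batch, pad_len)         tokenization labels, padded with -1
    #   features:  FloatTensor (batch, pad_len, feat)   per-character feature vectors
    #   raw_units: list of lists, each of length pad_len, of raw characters padded with '<PAD>'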

class SortedDataset(Dataset):
    """
    Holds a TokenizationDataset for use in a torch DataLoader

    The torch DataLoader is different from the DataLoader defined here
    and allows for cpu & gpu parallelism.  Updating output_predictions
    to use this class as a wrapper to a TokenizationDataset means the
    calculation of features can happen in parallel, saving quite a
    bit of time.
    """
    def __init__(self, dataset):
        super().__init__()

        self.dataset = dataset
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len, reverse=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # This returns a single sample: the list of processed sentences for one paragraph,
        # each a tuple of
        #   np.array: indices in the character vocab
        #   np.array: tokenization labels
        #   np.array: features
        #   list: original text as length-one strings
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        return unsort(arr, self.indices)

    def collate(self, samples):
        if any(len(x) > 1 for x in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        feat_size = samples[0][0][2].shape[-1]
        padid = self.dataset.vocab.unit2id('<PAD>')

        # +1 so that all samples end with at least one pad
        pad_len = max(len(x[0][3]) for x in samples) + 1

        units = torch.full((len(samples), pad_len), padid, dtype=torch.int64)
        labels = torch.full((len(samples), pad_len), -1, dtype=torch.int32)
        features = torch.zeros((len(samples), pad_len, feat_size), dtype=torch.float32)
        raw_units = []
        for i, sample in enumerate(samples):
            u_, l_, f_, r_ = sample[0]
            units[i, :len(u_)] = torch.from_numpy(u_)
            labels[i, :len(l_)] = torch.from_numpy(l_)
            features[i, :len(f_), :] = torch.from_numpy(f_)
            raw_units.append(r_ + ['<PAD>'])

        return units, labels, features, raw_units
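
# An illustrative usage sketch, assuming (not taken from the original module) that the
# wrapped dataset is called tokenization_dataset; it shows how SortedDataset is meant to
# plug into a standard torch.utils.data.DataLoader so feature extraction can run in
# worker processes:
#
#   from torch.utils.data import DataLoader as TorchLoader
#   wrapped = SortedDataset(tokenization_dataset)
#   loader = TorchLoader(wrapped, batch_size=32, num_workers=2, collate_fn=wrapped.collate)
#   for units, labels, features, raw_units in loader:
#       ...  # run the tokenizer model on each padded batch
#   # per-paragraph results can be restored to the original input order with wrapped.unsort(...)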