Albin Thörn Cleland
Clean initial commit with LFS
19b8775
import random
import numpy as np
import os
from collections import Counter, namedtuple
import logging
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader as DL
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
from stanza.models.common.vocab import DeltaVocab
from stanza.models.mwt.vocab import Vocab
from stanza.models.common.doc import Document
logger = logging.getLogger('stanza')
DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text")
DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx")
# enforce that the MWT splitter knows about a couple different alternate apostrophes
# including covering some potential " typos
# setting the augmentation to a very low value should be enough to teach it
# about the unknown characters without messing up the predictions for other text
#
# 0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07
APOS = ('"', "'", 'ʼ', 'ˊ', '՚', 'ߴ', '’', ''')
class DataLoader:
def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
self.batch_size = batch_size
self.args = args
self.augment_apos = args.get('augment_apos', 0.0)
self.evaluation = evaluation
self.doc = doc
data = self.load_doc(self.doc, evaluation=self.evaluation)
# handle vocab
if vocab is None:
assert self.evaluation == False # for eval vocab must exist
self.vocab = self.init_vocab(data)
if self.augment_apos > 0 and any(x in self.vocab for x in APOS):
for apos in APOS:
self.vocab.add_unit(apos)
elif expand_unk_vocab:
self.vocab = DeltaVocab(data, vocab)
else:
self.vocab = vocab
# filter and sample data
if args.get('sample_train', 1.0) < 1.0 and not self.evaluation:
keep = int(args['sample_train'] * len(data))
data = random.sample(data, keep)
logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))
# shuffle for training
if not self.evaluation:
indices = list(range(len(data)))
random.shuffle(indices)
data = [data[i] for i in indices]
self.data = data
self.num_examples = len(data)
def init_vocab(self, data):
assert self.evaluation == False # for eval vocab must exist
vocab = Vocab(data, self.args['shorthand'])
return vocab
def maybe_augment_apos(self, datum):
for original in APOS:
if original in datum[0]:
if random.uniform(0,1) < self.augment_apos:
replacement = random.choice(APOS)
datum = (datum[0].replace(original, replacement), datum[1].replace(original, replacement))
break
return datum
def process(self, sample):
if not self.evaluation and self.augment_apos > 0:
sample = self.maybe_augment_apos(sample)
src = list(sample[0])
src = [constant.SOS] + src + [constant.EOS]
tgt_in, tgt_out = self.prepare_target(self.vocab, sample)
src = self.vocab.map(src)
processed = [src, tgt_in, tgt_out, sample[0]]
return processed
def prepare_target(self, vocab, datum):
if self.evaluation:
tgt = list(datum[0]) # as a placeholder
else:
tgt = list(datum[1])
tgt_in = vocab.map([constant.SOS] + tgt)
tgt_out = vocab.map(tgt + [constant.EOS])
return tgt_in, tgt_out
def __len__(self):
return len(self.data)
def __getitem__(self, key):
""" Get a batch with index. """
if not isinstance(key, int):
raise TypeError
if key < 0 or key >= len(self.data):
raise IndexError
sample = self.data[key]
sample = self.process(sample)
assert len(sample) == 4
src = torch.tensor(sample[0])
tgt_in = torch.tensor(sample[1])
tgt_out = torch.tensor(sample[2])
orig_text = sample[3]
result = DataSample(src, tgt_in, tgt_out, orig_text), key
return result
@staticmethod
def __collate_fn(data):
(data, idx) = zip(*data)
(src, tgt_in, tgt_out, orig_text) = zip(*data)
# collate_fn is given a list of length batch size
batch_size = len(data)
# need to sort by length of src to properly handle
# the batching in the model itself
lens = [len(x) for x in src]
(src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens)
lens = [len(x) for x in src]
# convert to tensors
src = pad_sequence(src, True, constant.PAD_ID)
src_mask = torch.eq(src, constant.PAD_ID)
tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID)
tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID)
assert tgt_in.size(1) == tgt_out.size(1), \
"Target input and output sequence sizes do not match."
return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)
def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)
def to_loader(self):
"""Converts self to a DataLoader """
batch_size = self.batch_size
shuffle = not self.evaluation
return DL(self,
collate_fn=self.__collate_fn,
batch_size=batch_size,
shuffle=shuffle)
def load_doc(self, doc, evaluation=False):
data = doc.get_mwt_expansions(evaluation)
if evaluation: data = [[e] for e in data]
return data
class BinaryDataLoader(DataLoader):
"""
This version of the DataLoader performs the same tasks as the regular DataLoader,
except the targets are arrays of 0/1 indicating if the character is the location
of an MWT split
"""
def prepare_target(self, vocab, datum):
src = datum[0] if self.evaluation else datum[1]
binary = [0]
has_space = False
for char in src:
if char == ' ':
has_space = True
elif has_space:
has_space = False
binary.append(1)
else:
binary.append(0)
binary.append(0)
return binary, binary