import logging
import os
import random
from collections import Counter, namedtuple

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader as DL

import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
from stanza.models.common.vocab import DeltaVocab
from stanza.models.mwt.vocab import Vocab
from stanza.models.common.doc import Document

logger = logging.getLogger('stanza')

# a single processed example, and a collated batch as produced by DataLoader below
DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text")
DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx")

# characters treated as interchangeable apostrophe variants for data
# augmentation (the first entry covers double-quote typos; the last is
# the fullwidth apostrophe)
APOS = ('"', "'", 'ʼ', 'ˊ', '՚', 'ߴ', '’', '＇')

class DataLoader:
    def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
        self.batch_size = batch_size
        self.args = args
        self.augment_apos = args.get('augment_apos', 0.0)
        self.evaluation = evaluation
        self.doc = doc

        data = self.load_doc(self.doc, evaluation=self.evaluation)

        if vocab is None:
            # a vocab can only be built from training data
            assert not self.evaluation
            self.vocab = self.init_vocab(data)
            if self.augment_apos > 0 and any(x in self.vocab for x in APOS):
                # make sure every apostrophe variant is in the vocab, since
                # augmentation may introduce variants not seen in the data
                for apos in APOS:
                    self.vocab.add_unit(apos)
        elif expand_unk_vocab:
            # extend the given vocab with characters unique to this dataset
            self.vocab = DeltaVocab(data, vocab)
        else:
            self.vocab = vocab

        # optionally subsample the training data
        if args.get('sample_train', 1.0) < 1.0 and not self.evaluation:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        # shuffle for training
        if not self.evaluation:
            indices = list(range(len(data)))
            random.shuffle(indices)
            data = [data[i] for i in indices]

        self.data = data
        self.num_examples = len(data)

    def init_vocab(self, data):
        assert not self.evaluation
        vocab = Vocab(data, self.args['shorthand'])
        return vocab

    def maybe_augment_apos(self, datum):
        """With probability augment_apos, replace the first apostrophe variant
        found in the example with a randomly chosen variant."""
        for original in APOS:
            if original in datum[0]:
                if random.uniform(0, 1) < self.augment_apos:
                    replacement = random.choice(APOS)
                    datum = (datum[0].replace(original, replacement), datum[1].replace(original, replacement))
                break
        return datum

    def process(self, sample):
        """Convert one (src, tgt) pair of strings into id sequences."""
        if not self.evaluation and self.augment_apos > 0:
            sample = self.maybe_augment_apos(sample)
        src = list(sample[0])
        src = [constant.SOS] + src + [constant.EOS]
        tgt_in, tgt_out = self.prepare_target(self.vocab, sample)
        src = self.vocab.map(src)
        processed = [src, tgt_in, tgt_out, sample[0]]
        return processed

    def prepare_target(self, vocab, datum):
        # at evaluation time there is no gold expansion, so the source
        # string stands in as a placeholder target
        if self.evaluation:
            tgt = list(datum[0])
        else:
            tgt = list(datum[1])
        # shift by one position for teacher forcing: the decoder reads
        # tgt_in and is trained to predict tgt_out
        tgt_in = vocab.map([constant.SOS] + tgt)
        tgt_out = vocab.map(tgt + [constant.EOS])
        return tgt_in, tgt_out

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        """ Get one processed example, along with its index. """
        if not isinstance(key, int):
            raise TypeError("DataLoader indices must be integers")
        if key < 0 or key >= len(self.data):
            raise IndexError("DataLoader index out of range: %s" % key)
        sample = self.data[key]
        sample = self.process(sample)
        assert len(sample) == 4

        src = torch.tensor(sample[0])
        tgt_in = torch.tensor(sample[1])
        tgt_out = torch.tensor(sample[2])
        orig_text = sample[3]
        result = DataSample(src, tgt_in, tgt_out, orig_text), key
        return result

    @staticmethod
    def __collate_fn(data):
        (data, idx) = zip(*data)
        (src, tgt_in, tgt_out, orig_text) = zip(*data)

        # sort by source length, descending, so the encoder can pack the
        # sequences; orig_idx records how to undo the sort afterwards
        lens = [len(x) for x in src]
        (src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens)

        # pad everything to the longest sequence in the batch
        src = pad_sequence(src, batch_first=True, padding_value=constant.PAD_ID)
        src_mask = torch.eq(src, constant.PAD_ID)
        tgt_in = pad_sequence(tgt_in, batch_first=True, padding_value=constant.PAD_ID)
        tgt_out = pad_sequence(tgt_out, batch_first=True, padding_value=constant.PAD_ID)
        assert tgt_in.size(1) == tgt_out.size(1), \
            "Target input and output sequence sizes do not match."
        return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def to_loader(self):
        """Convert self to a torch DataLoader, shuffling only at training time."""
        batch_size = self.batch_size
        shuffle = not self.evaluation
        return DL(self,
                  collate_fn=self.__collate_fn,
                  batch_size=batch_size,
                  shuffle=shuffle)

    def load_doc(self, doc, evaluation=False):
        data = doc.get_mwt_expansions(evaluation)
        if evaluation:
            data = [[e] for e in data]
        return data
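
# A minimal usage sketch (illustrative, not from the original source;
# the argument values here are hypothetical, but 'shorthand' and
# 'augment_apos' are the args keys actually read above):
#
#   train_doc = Document(...)  # a stanza Document containing MWT expansions
#   args = {'shorthand': 'es_ancora', 'augment_apos': 0.01}
#   data = DataLoader(train_doc, batch_size=32, args=args)
#   for batch in data.to_loader():
#       ...  # each batch is a DataBatch of padded tensors plus orig_idx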

class BinaryDataLoader(DataLoader):
    """
    This version of the DataLoader performs the same tasks as the regular DataLoader,
    except the targets are arrays of 0/1 indicating whether each character is the
    location of an MWT split.
    """
    def prepare_target(self, vocab, datum):
        src = datum[0] if self.evaluation else datum[1]
        binary = [0]
        has_space = False
        for char in src:
            if char == ' ':
                # spaces themselves get no label; the next character is
                # marked as the split point instead
                has_space = True
            elif has_space:
                has_space = False
                binary.append(1)
            else:
                binary.append(0)
        # the extra 0s at the start and end mirror the SOS/EOS positions
        # of the parent class's seq2seq targets
        binary.append(0)
        return binary, binary