# varox34's picture
# Upload 64 files
# 366b225 verified
# -*- coding: utf-8 -*-
from collections.abc import Iterable
from itertools import chain
from parser.utils.alg import kmeans
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, Sampler
class TextDataLoader(DataLoader):
    """A DataLoader that post-processes each raw batch field by field:
    sequences are padded to a common length with the field's pad index
    and moved to the active device (GPU when available).
    """

    def __init__(self, *args, **kwargs):
        super(TextDataLoader, self).__init__(*args, **kwargs)
        # the dataset is expected to expose its fields for padding info
        self.fields = self.dataset.fields

    def __iter__(self):
        # is_available() has no side effects, so compute the target once
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for raw_batch in super(TextDataLoader, self).__iter__():
            processed = []
            for data, field in zip(raw_batch, self.fields):
                first = data[0]
                if isinstance(first, torch.Tensor):
                    # flat field: one tensor per sentence
                    data = pad_sequence(data, True, field.pad_index).to(device)
                elif isinstance(first, Iterable):
                    # nested field: pad each sub-sequence group separately
                    data = [pad_sequence(group, True, field.pad_index).to(device)
                            for group in zip(*data)]
                processed.append(data)
            yield processed
class TextDataset(Dataset):
    """A Dataset over a corpus whose fields are numericalized eagerly and
    whose sentences are grouped into length-homogeneous buckets via kmeans.

    Args:
        corpus: the corpus to wrap; iterating it yields sentences, and
            ``getattr(corpus, field.name)`` yields each field's raw values.
        fields: a (possibly nested) sequence of fields; ``None`` entries
            are dropped and nested groups are flattened.
        n_buckets (int): target number of length buckets (default: 1).
    """

    def __init__(self, corpus, fields, n_buckets=1):
        super(TextDataset, self).__init__()

        self.corpus = corpus
        # flatten nested field groups and drop None placeholders
        self.fields = list(chain(*[
            field if isinstance(field, Iterable) else [field]
            for field in fields if field is not None
        ]))
        for field in self.fields:
            value = field.numericalize(getattr(corpus, field.name))
            setattr(self, field.name, value)
        # count bos AND eos tokens in each length (the original summed
        # bool(field.bos) twice and never looked at eos);
        # `field` here is the last field from the loop above
        self.lengths = [len(i) + sum([bool(field.bos), bool(field.eos)])
                        for i in corpus]
        # NOTE: the final bucket count is roughly equal to n_buckets
        self.buckets = dict(zip(*kmeans(self.lengths, n_buckets)))

    def __getitem__(self, index):
        # yield the numericalized value of each field for this sentence
        for field in self.fields:
            yield getattr(self, field.name)[index]

    def __len__(self):
        return len(self.corpus)

    @property
    def loader(self):
        # the loader must be attached via the setter before use
        if hasattr(self, 'data_loader'):
            return self.data_loader
        else:
            raise AttributeError("loader has not been set; "
                                 "assign one via `dataset.loader = ...`")

    @loader.setter
    def loader(self, data_loader):
        self.data_loader = data_loader

    @classmethod
    def collate_fn(cls, batch):
        # transpose a batch of per-sentence tuples into per-field tuples
        return (field for field in zip(*batch))
class TextSampler(Sampler):
    """Samples batches of indices bucket by bucket, so that every batch
    contains sentences of similar length.

    Args:
        buckets (dict): maps a bucket's representative length to the list
            of sentence indices it contains.
        batch_size (int): approximate number of tokens per batch.
        shuffle (bool): whether to shuffle buckets and in-bucket order.
    """

    def __init__(self, buckets, batch_size, shuffle=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sizes, self.buckets = zip(*buckets.items())
        # number of chunks each bucket is split into,
        # clipped to the range [1, len(bucket)]
        self.chunks = []
        for size, bucket in zip(self.sizes, self.buckets):
            n_chunks = max(round(size * len(bucket) / batch_size), 1)
            self.chunks.append(min(len(bucket), n_chunks))

    def __iter__(self):
        # shuffling permutes both bucket order and samples within buckets
        range_fn = torch.randperm if self.shuffle else torch.arange
        for i in range_fn(len(self.buckets)).tolist():
            bucket, n_chunks = self.buckets[i], self.chunks[i]
            split_sizes = [(len(bucket) - j - 1) // n_chunks + 1
                           for j in range(n_chunks)]
            # DON'T use `torch.chunk`: it may return a wrong chunk count
            for piece in range_fn(len(bucket)).split(split_sizes):
                yield [bucket[j] for j in piece.tolist()]

    def __len__(self):
        return sum(self.chunks)
def batchify(dataset, batch_size, shuffle=False):
    """Build a TextDataLoader over `dataset` driven by a bucketed sampler.

    Args:
        dataset (TextDataset): the dataset to iterate, providing `buckets`
            and `collate_fn`.
        batch_size (int): approximate number of tokens per batch.
        shuffle (bool): whether to shuffle buckets and samples.

    Returns:
        TextDataLoader: yields padded, device-placed batches.
    """
    sampler = TextSampler(buckets=dataset.buckets,
                          batch_size=batch_size,
                          shuffle=shuffle)
    return TextDataLoader(dataset=dataset,
                          batch_sampler=sampler,
                          collate_fn=dataset.collate_fn)