|
|
|
|
|
|
|
|
from collections.abc import Iterable |
|
|
from itertools import chain |
|
|
from parser.utils.alg import kmeans |
|
|
|
|
|
import torch |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
from torch.utils.data import DataLoader, Dataset, Sampler |
|
|
|
|
|
|
|
|
class TextDataLoader(DataLoader):
    """A DataLoader that pads each field's raw data into tensors and moves
    them to the available device (CUDA if present, else CPU).

    The wrapped dataset must expose a ``fields`` attribute whose entries
    provide a ``pad_index`` used as the padding value.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # mirror the dataset's fields so padding can look up pad_index per field
        self.fields = self.dataset.fields

    def __iter__(self):
        # device is invariant across batches; resolve it once up front
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for raw_batch in super().__iter__():
            processed = []
            for data, field in zip(raw_batch, self.fields):
                sample = data[0]
                if isinstance(sample, torch.Tensor):
                    # flat sequence of tensors: pad to a single batch tensor
                    data = pad_sequence(data, True, field.pad_index).to(device)
                elif isinstance(sample, Iterable):
                    # nested sequences: transpose, then pad each column separately
                    padded = []
                    for column in zip(*data):
                        padded.append(
                            pad_sequence(column, True, field.pad_index).to(device))
                    data = padded
                processed.append(data)
            yield processed
|
|
|
|
|
|
|
|
class TextDataset(Dataset):
    """A Dataset that numericalizes a corpus per field and clusters
    sentences into length-homogeneous buckets via k-means.

    Args:
        corpus: the corpus to wrap; iterating it yields sentences, and the
            raw data for each field is read via ``getattr(corpus, field.name)``.
        fields: fields (or nested iterables of fields) used to numericalize
            the corpus; ``None`` entries are skipped.
        n_buckets (int): number of length buckets to cluster sentences into.
    """

    def __init__(self, corpus, fields, n_buckets=1):
        super(TextDataset, self).__init__()

        self.corpus = corpus
        # flatten possibly nested fields, dropping None placeholders
        self.fields = list(chain(*[
            field if isinstance(field, Iterable) else [field]
            for field in fields if field is not None
        ]))
        for field in self.fields:
            value = field.numericalize(getattr(corpus, field.name))
            setattr(self, field.name, value)

        # Sentence lengths including any bos/eos tokens the fields add.
        # FIX: the original summed bool(field.bos) twice, so eos tokens were
        # never counted and bucket lengths were skewed for eos-using fields.
        # NOTE(review): `field` is the last field of the loop above — assumed
        # representative of bos/eos usage for all fields; confirm upstream.
        self.lengths = [len(i) + sum([bool(field.bos), bool(field.eos)])
                        for i in corpus]
        # presumably maps a centroid length to the sentence indices in that
        # bucket, as consumed by TextSampler
        self.buckets = dict(zip(*kmeans(self.lengths, n_buckets)))

    def __getitem__(self, index):
        # yield the numericalized value of each field for the given sentence
        for field in self.fields:
            yield getattr(self, field.name)[index]

    def __len__(self):
        return len(self.corpus)

    @property
    def loader(self):
        """The DataLoader attached via the setter.

        Raises:
            AttributeError: if no loader has been attached yet.
        """
        if hasattr(self, 'data_loader'):
            return self.data_loader
        else:
            raise AttributeError("no data loader has been attached to this dataset")

    @loader.setter
    def loader(self, data_loader):
        self.data_loader = data_loader

    @classmethod
    def collate_fn(cls, batch):
        # transpose a batch of per-sentence field tuples into per-field groups
        return (field for field in zip(*batch))
|
|
|
|
|
|
|
|
class TextSampler(Sampler):
    """Yields batches of dataset indices drawn from length-homogeneous
    buckets, so each batch holds roughly ``batch_size`` tokens.

    Args:
        buckets (dict): maps a bucket's sentence length to the list of
            dataset indices it contains.
        batch_size (int): approximate token budget per batch.
        shuffle (bool): randomize bucket order and in-bucket order.
    """

    def __init__(self, buckets, batch_size, shuffle=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sizes, self.buckets = zip(*buckets.items())

        # number of chunks each bucket is split into: target size*len/batch_size
        # batches, clamped to at least 1 and at most one item per batch
        self.chunks = []
        for size, bucket in zip(self.sizes, self.buckets):
            n_chunks = max(round(size * len(bucket) / batch_size), 1)
            self.chunks.append(min(len(bucket), n_chunks))

    def __iter__(self):
        # randperm shuffles both bucket order and in-bucket order when enabled
        order_fn = torch.randperm if self.shuffle else torch.arange
        for i in order_fn(len(self.buckets)).tolist():
            bucket, n_chunks = self.buckets[i], self.chunks[i]
            # spread the bucket across n_chunks near-equal splits
            split_sizes = [(len(bucket) - j - 1) // n_chunks + 1
                           for j in range(n_chunks)]
            for positions in order_fn(len(bucket)).split(split_sizes):
                yield [bucket[k] for k in positions.tolist()]

    def __len__(self):
        # total number of batches across all buckets
        return sum(self.chunks)
|
|
|
|
|
|
|
|
def batchify(dataset, batch_size, shuffle=False):
    """Wrap ``dataset`` in a TextDataLoader driven by a bucketed batch sampler.

    Args:
        dataset (TextDataset): dataset exposing ``buckets``, ``fields`` and
            ``collate_fn``.
        batch_size (int): approximate token budget per batch (as interpreted
            by TextSampler).
        shuffle (bool): randomize bucket and in-bucket order each epoch.

    Returns:
        TextDataLoader: yields padded, device-placed batches.
    """
    sampler = TextSampler(buckets=dataset.buckets,
                          batch_size=batch_size,
                          shuffle=shuffle)
    return TextDataLoader(dataset=dataset,
                          batch_sampler=sampler,
                          collate_fn=dataset.collate_fn)
|
|
|