File size: 3,852 Bytes
366b225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# -*- coding: utf-8 -*-

from collections.abc import Iterable
from itertools import chain
from parser.utils.alg import kmeans

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, Sampler


class TextDataLoader(DataLoader):
    """DataLoader that pads every field of a batch and moves it to a device.

    The wrapped dataset must expose a ``fields`` sequence whose order matches
    the per-field values produced by the collate function; each field must
    provide a ``pad_index`` used as the padding value.
    """

    def __init__(self, *args, **kwargs):
        super(TextDataLoader, self).__init__(*args, **kwargs)

        # Fields aligned with the entries of each collated raw batch.
        self.fields = self.dataset.fields

    def __iter__(self):
        """Yield each batch as a list with one entry per field.

        For each field, if the first sample value is a tensor, the values are
        right-padded into one ``(batch, max_len, ...)`` tensor. If it is any
        other iterable, each sample is treated as a group of tensors: the
        groups are transposed and every position is padded separately,
        giving a list of tensors. Any other value (e.g. plain ints) is passed
        through unchanged. Tensors are moved to CUDA when available, else CPU.
        """
        # The target device cannot change between batches, so resolve it once
        # instead of querying CUDA availability for every batch.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for raw_batch in super(TextDataLoader, self).__iter__():
            batch = []
            for data, field in zip(raw_batch, self.fields):
                if isinstance(data[0], torch.Tensor):
                    data = pad_sequence(data, batch_first=True,
                                        padding_value=field.pad_index)
                    data = data.to(device)
                elif isinstance(data[0], Iterable):
                    # Transpose per-sample groups into per-position sequences.
                    data = [pad_sequence(f, batch_first=True,
                                         padding_value=field.pad_index)
                            .to(device)
                            for f in zip(*data)]
                batch.append(data)
            yield batch


class TextDataset(Dataset):
    """Numericalized view of a corpus, with sentences grouped into length buckets.

    Args:
        corpus: sentence collection; iterating it yields objects supporting
            ``len()``, and ``getattr(corpus, field.name)`` must return the raw
            values for each field.
        fields: fields to numericalize. ``None`` entries are skipped and any
            iterable entry is flattened into its member fields.
        n_buckets: requested number of length buckets, forwarded to ``kmeans``.
    """

    def __init__(self, corpus, fields, n_buckets=1):
        super(TextDataset, self).__init__()

        self.corpus = corpus
        # Flatten one level, e.g. [a, (b, c), None, d] -> [a, b, c, d].
        self.fields = list(chain.from_iterable(
            field if isinstance(field, Iterable) else [field]
            for field in fields if field is not None
        ))
        # Store each field's numericalized values as an attribute named after
        # the field, so __getitem__ can look them up by name.
        for field in self.fields:
            value = field.numericalize(getattr(corpus, field.name))
            setattr(self, field.name, value)
        # Count one extra position for each of bos/eos that is set (presumably
        # special tokens wrapped around every sentence). Count bos AND eos,
        # not bos twice. The last field's settings are used; presumably all
        # fields agree on bos/eos (TODO confirm). getattr guards field types
        # that lack an `eos` attribute.
        last_field = self.fields[-1]
        n_special = (bool(last_field.bos)
                     + bool(getattr(last_field, 'eos', None)))
        self.lengths = [len(sentence) + n_special for sentence in corpus]
        # NOTE: the final bucket count is roughly equal to n_buckets.
        # kmeans presumably returns (centroid lengths, clusters of sentence
        # indices), making this a {centroid: indices} mapping (TODO confirm).
        self.buckets = dict(zip(*kmeans(self.lengths, n_buckets)))

    def __getitem__(self, index):
        """Yield the numericalized value of every field at ``index``.

        Values are produced lazily (this is a generator), in ``self.fields``
        order.
        """
        for field in self.fields:
            yield getattr(self, field.name)[index]

    def __len__(self):
        """Return the number of sentences in the corpus."""
        return len(self.corpus)

    @property
    def loader(self):
        """The data loader previously attached through the setter.

        Raises:
            AttributeError: if no loader has been attached yet.
        """
        if hasattr(self, 'data_loader'):
            return self.data_loader
        raise AttributeError("no data loader has been attached to this dataset")

    @loader.setter
    def loader(self, data_loader):
        self.data_loader = data_loader

    @classmethod
    def collate_fn(cls, batch):
        """Transpose per-sample field values into one tuple per field.

        ``batch`` is a list of per-sample iterables (as produced by
        ``__getitem__``). The result is materialized as a tuple, not a lazy
        generator: generators cannot be pickled, which breaks DataLoader
        worker processes, and can only be iterated once.
        """
        return tuple(zip(*batch))


class TextSampler(Sampler):
    """Batch sampler that emits batches of indices one bucket at a time.

    Args:
        buckets: mapping from a bucket's representative size to the list of
            items (dataset indices) in that bucket.
        batch_size: budget used to decide how many batches each bucket is cut
            into, roughly ``size * len(bucket) / batch_size``.
        shuffle: if True, randomize both the bucket order and the order of
            items within each bucket via ``torch.randperm``.
    """

    def __init__(self, buckets, batch_size, shuffle=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sizes, self.buckets = zip(*buckets.items())
        # Number of chunks per bucket: the rounded size-based estimate,
        # clipped to the range [1, len(bucket)].
        self.chunks = []
        for size, bucket in zip(self.sizes, self.buckets):
            estimate = round(size * len(bucket) / batch_size)
            self.chunks.append(min(len(bucket), max(estimate, 1)))

    def __iter__(self):
        # randperm -> random order, arange -> natural order; the same choice
        # drives both the bucket order and the item order within a bucket.
        make_order = torch.randperm if self.shuffle else torch.arange
        for b in make_order(len(self.buckets)).tolist():
            bucket, n_chunks = self.buckets[b], self.chunks[b]
            # Near-equal chunk sizes (differing by at most one) that sum to
            # len(bucket). DON'T use `torch.chunk`: it may return fewer chunks.
            sizes = [(len(bucket) - k - 1) // n_chunks + 1
                     for k in range(n_chunks)]
            for positions in make_order(len(bucket)).split(sizes):
                yield [bucket[p] for p in positions.tolist()]

    def __len__(self):
        """Return the total number of batches across all buckets."""
        return sum(self.chunks)


def batchify(dataset, batch_size, shuffle=False):
    """Wrap ``dataset`` in a TextDataLoader driven by a bucketed TextSampler.

    Args:
        dataset: dataset exposing ``buckets`` (consumed by TextSampler),
            ``collate_fn`` and ``fields`` (read by TextDataLoader).
        batch_size: forwarded to TextSampler.
        shuffle: whether TextSampler randomizes bucket and item order.

    Returns:
        TextDataLoader: the configured loader.
    """
    return TextDataLoader(
        dataset=dataset,
        batch_sampler=TextSampler(buckets=dataset.buckets,
                                  batch_size=batch_size,
                                  shuffle=shuffle),
        collate_fn=dataset.collate_fn,
    )