| | import bisect |
| | import warnings |
| |
|
| | from torch._utils import _accumulate |
| | from torch import randperm |
| |
|
| |
|
class Dataset(object):
    """Abstract base class for all datasets.

    Subclasses must implement ``__len__``, reporting the number of
    samples, and ``__getitem__``, supporting integer indexing in the
    range ``[0, len(self))``.
    """

    def __getitem__(self, index):
        # Subclasses provide per-sample access.
        raise NotImplementedError

    def __len__(self):
        # Subclasses report their size.
        raise NotImplementedError

    def __add__(self, other):
        # ``ds_a + ds_b`` yields the concatenation of the two datasets.
        return ConcatDataset([self, other])
| |
|
| |
|
class TensorDataset(Dataset):
    """Dataset wrapping data and target tensors.

    Each sample is the pair obtained by indexing both tensors along
    their first dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).
    """

    def __init__(self, data_tensor, target_tensor):
        # Both tensors must hold the same number of samples.
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        sample = self.data_tensor[index]
        target = self.target_tensor[index]
        return sample, target

    def __len__(self):
        return self.data_tensor.size(0)
| |
|
| |
|
class ConcatDataset(Dataset):
    """Dataset that is the concatenation of multiple datasets.

    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets, as the concatenation operation is done in an
    on-the-fly manner (no data is copied).

    Arguments:
        datasets (iterable): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        # Running total of dataset lengths, e.g. sizes [3, 5] -> [3, 8].
        r, s = [], 0
        for e in sequence:
            length = len(e)
            r.append(length + s)
            s += length
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        # cumulative_sizes[i] == total number of samples in datasets[:i + 1].
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Normalize negative indices the way lists do; previously a
        # negative idx fell through bisect_right unchanged and was
        # forwarded (still negative) to the first sub-dataset, silently
        # returning a sample from the wrong position.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Find which dataset holds idx, then the offset within it.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Backward-compatible alias for the old, misspelled attribute name.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
| |
|
| |
|
class Subset(Dataset):
    """View over a dataset restricted to the given indices.

    Arguments:
        dataset (Dataset): the underlying dataset.
        indices (sequence): indices selecting the samples of this subset.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Translate the subset-local index into an index of the full dataset.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
| |
|
| |
|
def random_split(dataset, lengths):
    """Randomly split a dataset into non-overlapping new datasets of
    given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (iterable): lengths of splits to be produced

    Returns:
        list of Subset, one per requested length.

    Raises:
        ValueError: if ``sum(lengths) != len(dataset)``.
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    # randperm returns a LongTensor; convert to plain Python ints so the
    # resulting Subsets index arbitrary datasets correctly (a 0-dim tensor
    # index only works for tensor-backed datasets).
    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
| |
|