"""This file implements dataset wrappers and batch samplers for TorchTask."""

import itertools

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler


class _TorchTaskDatasetWrapper(Dataset):
    """Superclass of the TorchTask dataset wrappers."""

    def __init__(self):
        super().__init__()
        self.labeled_idxs = []      # indices of the labeled data
        self.additional_idxs = []   # indices of the additional data


class SplitUnlabeledWrapper(_TorchTaskDatasetWrapper):
    """Split a fully labeled dataset into a labeled subset and an
    additional subset based on a given sublabeled prefix list.

    For a fully labeled dataset, a common operation is to remove the
    labels of some samples and treat them as additional (unlabeled)
    samples. This wrapper implements that split using the given
    sublabeled prefix list: samples whose name starts with one of the
    prefixes are treated as labeled samples, while all other samples
    are treated as additional samples.
    """

    def __init__(self, dataset, sublabeled_prefix, ignore_additional=False):
        super().__init__()
        self.dataset = dataset
        self.sublabeled_prefix = sublabeled_prefix
        self.ignore_additional = ignore_additional
        self._split_labeled()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def _split_labeled(self):
        # Partition the sample list into labeled and additional samples
        # according to the sublabeled prefix list.
        labeled_list, additional_list = [], []
        for img in self.dataset.sample_list:
            is_labeled = False
            for prefix in self.sublabeled_prefix:
                if img.startswith(prefix):
                    labeled_list.append(img)
                    is_labeled = True
                    break
            if not is_labeled:
                additional_list.append(img)
        labeled_size, additional_size = len(labeled_list), len(additional_list)
        assert labeled_size + additional_size == len(self.dataset.sample_list)
        if self.ignore_additional:
            # Keep only the labeled samples.
            self.dataset.sample_list = labeled_list
            self.dataset.idxs = list(range(len(self.dataset.sample_list)))
            self.labeled_idxs = self.dataset.idxs
            self.additional_idxs = []
        else:
            # Reorder the sample list so that labeled samples come first.
            self.dataset.sample_list = labeled_list + additional_list
            self.dataset.idxs = list(range(len(self.dataset.sample_list)))
            self.labeled_idxs = list(range(labeled_size))
            self.additional_idxs = list(range(labeled_size, labeled_size + additional_size))
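

# A minimal usage sketch for 'SplitUnlabeledWrapper'. 'FileListDataset' and the
# prefixes below are hypothetical; they only assume what the wrapper itself
# requires, namely a dataset exposing a 'sample_list' of name strings:
#
#     base = FileListDataset(...)   # sample_list = ['caseA_01.png', 'caseB_01.png', ...]
#     wrapped = SplitUnlabeledWrapper(base, sublabeled_prefix=['caseA'])
#     # wrapped.labeled_idxs     -> indices of samples whose name starts with 'caseA'
#     # wrapped.additional_idxs  -> indices of all remaining samples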


class JointDatasetsWrapper(_TorchTaskDatasetWrapper):
    """Combine several datasets (labeled or additional) into one dataset.

    This wrapper concatenates multiple given datasets into one large
    dataset that consists of a labeled subset followed by an additional
    subset.
    """

    def __init__(self, labeled_datasets, additional_datasets, ignore_additional=False):
        super().__init__()
        self.labeled_datasets = labeled_datasets
        self.additional_datasets = additional_datasets
        self.ignore_additional = ignore_additional
        self.labeled_datasets_size = [len(d) for d in self.labeled_datasets]
        self.additional_datasets_size = [len(d) for d in self.additional_datasets]
        self.labeled_size = sum(self.labeled_datasets_size)
        self.labeled_idxs = list(range(self.labeled_size))
        self.additional_size = 0
        if not self.ignore_additional:
            self.additional_size = sum(self.additional_datasets_size)
            self.additional_idxs = list(range(self.labeled_size, self.labeled_size + self.additional_size))

    def __len__(self):
        return self.labeled_size + self.additional_size

    def __getitem__(self, idx):
        assert 0 <= idx < len(self)
        if idx >= self.labeled_size:
            # The index falls into the additional subset.
            idx -= self.labeled_size
            datasets = self.additional_datasets
            datasets_size = self.additional_datasets_size
        else:
            datasets = self.labeled_datasets
            datasets_size = self.labeled_datasets_size
        # Locate the sub-dataset that contains the given index.
        accumulated_idxs = 0
        for ddx, dsize in enumerate(datasets_size):
            accumulated_idxs += dsize
            if idx < accumulated_idxs:
                return datasets[ddx][idx - (accumulated_idxs - dsize)]
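

# A minimal usage sketch for 'JointDatasetsWrapper'. 'ds_a', 'ds_b' and 'ds_u'
# stand for hypothetical 'Dataset' instances used only for illustration:
#
#     joint = JointDatasetsWrapper(labeled_datasets=[ds_a, ds_b],
#                                  additional_datasets=[ds_u])
#     # len(joint) == len(ds_a) + len(ds_b) + len(ds_u)
#     # Indices [0, len(ds_a) + len(ds_b)) map to the labeled subset in order;
#     # the remaining indices map to the additional subset.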


class TwoStreamBatchSampler(Sampler):
    """A two-stream batch sampler used to read data from a
    '_TorchTaskDatasetWrapper'.

    It iterates over two sets of indices simultaneously to build
    mini-batches for TorchTask:
        labeled_idxs, additional_idxs
    By default, an 'epoch' is defined by going through the longer set of
    indices once; within each 'epoch', the shorter set is re-shuffled and
    iterated through as many times as needed. If 'short_ep' is True, an
    'epoch' is instead one pass through the shorter set.
    """

    def __init__(self, labeled_idxs, additional_idxs, labeled_batch_size,
                 additional_batch_size, short_ep=False):
        self.labeled_idxs = labeled_idxs
        self.additional_idxs = additional_idxs
        self.labeled_batch_size = labeled_batch_size
        self.additional_batch_size = additional_batch_size
        assert len(self.labeled_idxs) >= self.labeled_batch_size > 0
        assert len(self.additional_idxs) >= self.additional_batch_size > 0
        self.additional_batchs = len(self.additional_idxs) // self.additional_batch_size
        self.labeled_batchs = len(self.labeled_idxs) // self.labeled_batch_size
        self.short_ep = short_ep

    def __iter__(self):
        # The stream that defines the epoch is iterated once; the other
        # stream is re-shuffled and iterated eternally, so that 'zip'
        # stops exactly when the epoch-defining stream is exhausted.
        if not self.short_ep:
            if self.additional_batchs >= self.labeled_batchs:
                additional_iter = self.iterate_once(self.additional_idxs)
                labeled_iter = self.iterate_eternally(self.labeled_idxs)
            else:
                additional_iter = self.iterate_eternally(self.additional_idxs)
                labeled_iter = self.iterate_once(self.labeled_idxs)
        else:
            if self.additional_batchs >= self.labeled_batchs:
                additional_iter = self.iterate_eternally(self.additional_idxs)
                labeled_iter = self.iterate_once(self.labeled_idxs)
            else:
                additional_iter = self.iterate_once(self.additional_idxs)
                labeled_iter = self.iterate_eternally(self.labeled_idxs)
        return (labeled_batch + additional_batch
                for (labeled_batch, additional_batch) in zip(
                    self.grouper(labeled_iter, self.labeled_batch_size),
                    self.grouper(additional_iter, self.additional_batch_size)))

    def __len__(self):
        if self.short_ep:
            return min(self.additional_batchs, self.labeled_batchs)
        else:
            return max(self.additional_batchs, self.labeled_batchs)

    def iterate_once(self, iterable):
        # A single random permutation of the given indices.
        return np.random.permutation(iterable)

    def iterate_eternally(self, indices):
        # An infinite index stream: repeatedly re-shuffle and chain.
        def infinite_shuffles():
            while True:
                yield np.random.permutation(indices)
        return itertools.chain.from_iterable(infinite_shuffles())

    def grouper(self, iterable, n):
        # Collect data into fixed-length chunks, dropping any remainder.
        # e.g., grouper('ABCDEFG', 3) --> ABC DEF
        args = [iter(iterable)] * n
        return zip(*args)
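

if __name__ == "__main__":
    # A small self-contained sanity check (a sketch added for illustration,
    # not part of the original API): build a sampler over two made-up index
    # sets and print the mixed batches it yields.
    sampler = TwoStreamBatchSampler(
        labeled_idxs=list(range(0, 6)),       # 6 labeled indices
        additional_idxs=list(range(6, 18)),   # 12 additional indices
        labeled_batch_size=2,
        additional_batch_size=4,
    )
    # labeled_batchs = 6 // 2 = 3 and additional_batchs = 12 // 4 = 3, so
    # len(sampler) == 3 and each batch holds 2 labeled + 4 additional indices.
    print(len(sampler))
    for batch in sampler:
        print(batch)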