import os

import PIL
import torch
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
from torchvision.datasets import VisionDataset

# Class names for CIFAR-10.
cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


class CIFAR10:
    """CIFAR-10 train and test splits, each wrapped in a ``DataLoader``."""

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):
        """Create the datasets and their loaders.

        Args:
            preprocess: Passed as ``transform`` to both dataset splits.
            location: Root directory handed to torchvision's ``CIFAR10``;
                both splits are created with ``download=True``.
            batch_size: Batch size used by both loaders.
            num_workers: Number of worker processes used by both loaders.
        """
        self.train_dataset = PyTorchCIFAR10(
            root=location, download=True, train=True, transform=preprocess
        )

        # Only the training loader shuffles.
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )

        self.test_dataset = PyTorchCIFAR10(
            root=location, download=True, train=False, transform=preprocess
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        # Class names as reported by the test split.
        self.classnames = self.test_dataset.classes


def convert(x):
    """Return ``x`` as a PIL image if it is a NumPy array, else unchanged."""
    if isinstance(x, np.ndarray):
        return torchvision.transforms.functional.to_pil_image(x)
    return x


class BasicVisionDataset(VisionDataset):
    """In-memory dataset over parallel ``images`` and ``targets`` sequences."""

    def __init__(self, images, targets, transform=None, target_transform=None):
        """Store the samples and set up the transform pipeline.

        Args:
            images: Indexable collection of images. NumPy arrays are converted
                to PIL images before ``transform`` runs.
            targets: Indexable collection of labels, same length as ``images``.
            transform: Optional callable applied to each image.
            target_transform: Optional callable applied to each target.

        Raises:
            AssertionError: If ``images`` and ``targets`` differ in length.
        """
        if transform is not None:
            # Wrap instead of inserting into ``transform.transforms``: the
            # caller's pipeline is often shared between datasets and must not
            # be mutated, and a plain callable has no ``.transforms`` list.
            transform = transforms.Compose([convert, transform])
        super().__init__(root=None, transform=transform, target_transform=target_transform)
        assert len(images) == len(targets), (
            'images and targets must have the same length'
        )
        self.images = images
        self.targets = targets

    def __getitem__(self, index):
        """Return ``(image, target)`` at ``index`` with transforms applied.

        When no ``transform`` was given the stored image is returned as-is.
        """
        image = self.images[index]
        target = self.targets[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.targets)