| import os |
| import PIL |
| import torch |
| import numpy as np |
| import torchvision |
| from torchvision import transforms |
| from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 |
| from torchvision.datasets import VisionDataset |
|
|
# Human-readable CIFAR-10 class names, in label-index order (index i -> name).
cifar_classnames = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
|
|
class CIFAR10:
    """CIFAR-10 train/test datasets plus matching DataLoaders.

    Both splits are built with ``download=True`` under ``location`` and share
    the same ``preprocess`` transform. The train loader shuffles; the test
    loader does not.

    Attributes:
        train_dataset, train_loader: training split and its DataLoader.
        test_dataset, test_loader: test split and its DataLoader.
        classnames: the ``classes`` attribute of the test dataset.
    """

    def __init__(self, preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):
        def build_split(is_train):
            # One dataset/loader pair; only shuffling differs between splits.
            split = PyTorchCIFAR10(
                root=location, download=True, train=is_train, transform=preprocess
            )
            split_loader = torch.utils.data.DataLoader(
                split,
                batch_size=batch_size,
                shuffle=is_train,
                num_workers=num_workers,
            )
            return split, split_loader

        self.train_dataset, self.train_loader = build_split(True)
        self.test_dataset, self.test_loader = build_split(False)

        self.classnames = self.test_dataset.classes
|
|
def convert(x):
    """Return ``x`` as a PIL image if it is a numpy array, else unchanged.

    Used ahead of PIL-based transforms so that raw ndarray samples are
    accepted; any non-ndarray input (PIL image, tensor, ...) passes through.
    """
    if not isinstance(x, np.ndarray):
        return x
    return torchvision.transforms.functional.to_pil_image(x)
|
|
class BasicVisionDataset(VisionDataset):
    """In-memory vision dataset over parallel sequences of images and targets.

    When ``transform`` is given, each image is first passed through
    ``convert`` (numpy array -> PIL image; other values pass through) and then
    through ``transform``. When ``target_transform`` is given, it is applied
    to each target. A ``None`` transform skips that step, and the stored value
    is returned unchanged.

    Args:
        images: indexable sequence of samples.
        targets: indexable sequence of labels, same length as ``images``.
        transform: optional callable applied to each converted image. It is
            used as-is and never modified, so it may be shared between
            datasets and need not be a ``Compose``.
        target_transform: optional callable applied to each target.

    Raises:
        ValueError: if ``images`` and ``targets`` differ in length.
    """

    def __init__(self, images, targets, transform=None, target_transform=None):
        super().__init__(root=None, transform=transform, target_transform=target_transform)
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if len(images) != len(targets):
            raise ValueError(
                f"images and targets must have the same length, "
                f"got {len(images)} and {len(targets)}"
            )

        self.images = images
        self.targets = targets

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            # Convert first so PIL-based transforms also accept raw numpy arrays.
            image = self.transform(convert(image))

        target = self.targets[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.targets)
|
|