import os

import torch
from torchvision.datasets import CIFAR100 as PyTorchCIFAR100


class CIFAR100:
    """CIFAR-100 train/test datasets and DataLoaders built on torchvision.

    Both splits are created with ``download=True`` under ``location`` and
    share the same ``preprocess`` transform.

    Attributes:
        train_dataset: torchvision ``CIFAR100`` training split.
        train_loader: ``DataLoader`` over ``train_dataset``; shuffled only
            when ``shuffle_train=True``.
        test_dataset: torchvision ``CIFAR100`` test split.
        test_loader: ``DataLoader`` over ``test_dataset``; never shuffled.
        classnames: class names taken from ``test_dataset.classes``.
    """

    def __init__(self, preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16,
                 *,
                 shuffle_train=False):
        """Build both CIFAR-100 splits and a DataLoader for each.

        Args:
            preprocess: Transform passed as torchvision's ``transform``
                argument, i.e. applied to every image of both splits.
            location: Root directory for the dataset files. Note that the
                default is expanded once, when the class is defined.
            batch_size: Batch size used by both loaders.
            num_workers: Number of loader worker processes for both loaders.
            shuffle_train: Whether ``train_loader`` reshuffles the training
                set every epoch. Defaults to ``False``, which keeps the
                original fixed ordering for existing callers.
        """
        self.train_dataset = PyTorchCIFAR100(
            root=location, download=True, train=True, transform=preprocess
        )
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            # Previously left implicit (DataLoader's default is False);
            # now configurable without changing the default.
            shuffle=shuffle_train,
            num_workers=num_workers,
        )

        self.test_dataset = PyTorchCIFAR100(
            root=location, download=True, train=False, transform=preprocess
        )
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            # Evaluation order is kept deterministic.
            shuffle=False,
            num_workers=num_workers,
        )

        self.classnames = self.test_dataset.classes