import os
import torch
import torchvision.datasets as datasets


class SUN397:
    """Image-folder datasets and data loaders for the SUN397 train/val splits.

    Attributes set by the constructor:
        train_dataset: ``ImageFolder`` over ``<location>/sun397/train``.
        train_loader: ``DataLoader`` over ``train_dataset`` with shuffling on.
        test_dataset: ``ImageFolder`` over ``<location>/sun397/val``.
        test_loader: ``DataLoader`` over ``test_dataset``; no ``shuffle``
            argument is passed, so the loader's default ordering applies.
        classnames: list of display names, one per class index, derived
            from the training split's folder names.
    """

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=32,
                 num_workers=16):
        """Build both datasets, both loaders and the class-name list.

        Args:
            preprocess: transform handed to both ``ImageFolder`` datasets.
            location: directory that contains the ``sun397`` folder.
            batch_size: batch size used by both loaders.
            num_workers: worker count used by both loaders.
        """
        dataset_root = os.path.join(location, 'sun397')
        # Settings common to the train and test loaders.
        loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}

        # Training split (shuffled).
        self.train_dataset = datasets.ImageFolder(
            os.path.join(dataset_root, 'train'),
            transform=preprocess,
        )
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            **loader_kwargs,
        )

        # Validation split, exposed under the "test" names.
        self.test_dataset = datasets.ImageFolder(
            os.path.join(dataset_root, 'val'),
            transform=preprocess,
        )
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            **loader_kwargs,
        )

        # Invert folder-name -> index so names can be listed in index order.
        folder_by_index = {
            index: folder
            for folder, index in self.train_dataset.class_to_idx.items()
        }
        # The first two characters of each folder name are dropped and
        # underscores become spaces.
        # NOTE(review): presumably the folders carry a two-character prefix
        # (e.g. a letter plus '_') -- confirm against the on-disk layout.
        self.classnames = [
            folder_by_index[index][2:].replace('_', ' ')
            for index in range(len(folder_by_index))
        ]