| ''' Datasets | |
| This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets | |
| ''' | |
| import os | |
| import os.path | |
| import sys | |
| from PIL import Image | |
| import numpy as np | |
| from tqdm import tqdm, trange | |
| import torchvision.datasets as dset | |
| import torchvision.transforms as transforms | |
| from torchvision.datasets.utils import download_url, check_integrity | |
| import torch.utils.data as data | |
| from torch.utils.data import DataLoader | |
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
    """Check whether a path points to an image file.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    # str.endswith accepts a tuple of suffixes, so one call covers them all.
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))
def find_classes(dir):
    """List the class subdirectories of *dir* and index them.

    Args:
        dir (string): root directory whose immediate subdirectories are classes

    Returns:
        tuple: (classes, class_to_idx) where classes is the sorted list of
            subdirectory names and class_to_idx maps each name to its index.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(dir, class_to_idx):
    """Walk *dir* and collect (image_path, class_index) samples.

    Args:
        dir (string): root directory with one subdirectory per class
        class_to_idx (dict): mapping from class name to integer label

    Returns:
        list: (path, class_index) tuples, one per image file found.
    """
    dir = os.path.expanduser(dir)
    images = []
    # Progress bar over class folders, visited in deterministic sorted order.
    for target in tqdm(sorted(os.listdir(dir))):
        class_dir = os.path.join(dir, target)
        if not os.path.isdir(class_dir):
            continue
        for root, _, fnames in sorted(os.walk(class_dir)):
            images.extend(
                (os.path.join(root, fname), class_to_idx[target])
                for fname in sorted(fnames)
                if is_image_file(fname)
            )
    return images
def pil_loader(path):
    """Load the image at *path* with PIL and return it converted to RGB.

    The file handle is opened explicitly (rather than letting PIL open the
    path) so it is closed promptly and no ResourceWarning is emitted
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle:
        image = Image.open(handle)
        return image.convert('RGB')
def accimage_loader(path):
    """Load *path* with accimage, falling back to PIL on decode failure."""
    import accimage
    try:
        img = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem; fall back to the PIL loader.
        img = pil_loader(path)
    return img
def default_loader(path):
    """Dispatch to the loader selected by torchvision's image backend."""
    from torchvision import get_image_backend
    loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return loader(path)
class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dogball/xxx.png
        root/dogball/xxy.png
        root/dogball/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        load_in_mem (bool): If True, load (and transform) every image into RAM
            once at construction time instead of reading from disk per item.
        index_filename (string): Path of an ``.npz`` cache of the directory
            walk; created on first use and reused on later runs.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, load_in_mem=False,
                 index_filename='imagenet_imgs.npz', **kwargs):
        classes, class_to_idx = find_classes(root)
        # Load pre-computed image directory walk if a cache file exists...
        if os.path.exists(index_filename):
            print('Loading pre-saved Index file %s...' % index_filename)
            imgs = np.load(index_filename)['imgs']
        # ...otherwise walk the folder directory and save the results.
        else:
            print('Generating Index file %s...' % index_filename)
            imgs = make_dataset(root, class_to_idx)
            np.savez_compressed(index_filename, **{'imgs': imgs})
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.load_in_mem = load_in_mem

        if self.load_in_mem:
            print('Loading all images into memory...')
            self.data, self.labels = [], []
            for index in tqdm(range(len(self.imgs))):
                path, target = imgs[index][0], imgs[index][1]
                # str() because paths loaded from the .npz cache come back as
                # numpy strings (matches the str() call in __getitem__).
                img = self.loader(str(path))
                # Bug fix: only apply the transform when one was supplied.
                # Previously transform=None crashed here with a TypeError,
                # unlike the guarded path in __getitem__.
                if self.transform is not None:
                    img = self.transform(img)
                self.data.append(img)
                self.labels.append(target)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        if self.load_in_mem:
            img = self.data[index]
            target = self.labels[index]
        else:
            path, target = self.imgs[index]
            img = self.loader(str(path))
            if self.transform is not None:
                img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # print(img.size(), target)
        return img, int(target)

    def __len__(self):
        return len(self.imgs)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
| ''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid | |
| having to load individual images all the time. ''' | |
| import h5py as h5 | |
| import torch | |
class ILSVRC_HDF5(data.Dataset):
    """ILSVRC_HDF5: serves images from a single HDF5 file to avoid having to
    load individual image files all the time.

    Args:
        root (string): path to the HDF5 file (expects 'imgs' and 'labels' keys).
        transform (callable, optional): stored but NOT applied in __getitem__;
            a fixed [-1, 1] normalization is applied instead (see below).
        target_transform (callable, optional): applied to each label.
        load_in_mem (bool): if True, read the whole 'imgs'/'labels' arrays
            into RAM at construction time.
        train, download, validate_seed, val_split: dummies kept for API
            compatibility with the other dataset classes.
    """
    def __init__(self, root, transform=None, target_transform=None,
                 load_in_mem=False, train=True, download=False, validate_seed=0,
                 val_split=0, **kwargs):  # last four are dummies
        self.root = root
        # Bug fix: open the file in a context manager so the handle used only
        # to count the labels is closed instead of leaked.
        with h5.File(root, 'r') as f:
            self.num_imgs = len(f['labels'])

        self.target_transform = target_transform
        # Set the transform here (stored for compatibility; not applied).
        self.transform = transform

        # load the entire dataset into memory?
        self.load_in_mem = load_in_mem

        # If loading into memory, do so now
        if self.load_in_mem:
            print('Loading %s into memory...' % root)
            with h5.File(root, 'r') as f:
                self.data = f['imgs'][:]
                self.labels = f['labels'][:]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        # If loaded the entire dataset in RAM, get image from memory
        if self.load_in_mem:
            img = self.data[index]
            target = self.labels[index]
        # Else load it from disk
        else:
            with h5.File(self.root, 'r') as f:
                img = f['imgs'][index]
                target = f['labels'][index]

        # Fixed normalization applied instead of self.transform: scale uint8
        # pixel values [0, 255] to floats in [-1, 1].
        img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, int(target)

    def __len__(self):
        return self.num_imgs
| import pickle | |
class CIFAR10(dset.CIFAR10):
    """CIFAR-10 with an optional class-balanced held-out validation split.

    Extends torchvision's CIFAR10 so that ``val_split`` of the training data
    can be carved off as a validation set (seeded by ``validate_seed``).
    ``train`` may be True (training split minus validation), False (test
    split), or the string 'validate' (the held-out indices).

    Args:
        root (string): directory containing the CIFAR batch folder.
        train (bool or str): which split to expose (see above).
        transform (callable, optional): transform applied to each PIL image.
        target_transform (callable, optional): transform applied to each label.
        download (bool): download the archive if not already present.
        validate_seed (int): seed for choosing validation indices.
        val_split (float): fraction of the training set held out.
        load_in_mem (bool): accepted for API compatibility; unused here.
    """
    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=True, validate_seed=0,
                 val_split=0, load_in_mem=True, **kwargs):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.val_split = val_split

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the picked numpy arrays
        self.data = []
        self.labels = []
        for fentry in self.train_list:
            f = fentry[0]
            file = os.path.join(self.root, self.base_folder, f)
            # Bug fix: context manager guarantees the batch file is closed
            # even if unpickling raises (bare open/close could leak).
            with open(file, 'rb') as fo:
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
            self.data.append(entry['data'])
            if 'labels' in entry:
                self.labels += entry['labels']
            else:
                # CIFAR-100 batches store their labels under 'fine_labels'.
                self.labels += entry['fine_labels']
        self.data = np.concatenate(self.data)

        # Randomly select indices for validation
        if self.val_split > 0:
            label_indices = [[] for _ in range(max(self.labels) + 1)]
            for i, l in enumerate(self.labels):
                label_indices[l] += [i]
            # NOTE(review): with equal per-class counts this forms a regular
            # 2-D array; newer NumPy rejects ragged lists here -- confirm.
            label_indices = np.asarray(label_indices)

            # randomly grab val_split's worth of elements of each class
            np.random.seed(validate_seed)
            self.val_indices = []
            for l_i in label_indices:
                self.val_indices += list(l_i[np.random.choice(
                    len(l_i),
                    int(len(self.data) * val_split) // (max(self.labels) + 1),
                    replace=False)])

        if self.train == 'validate':
            self.data = self.data[self.val_indices]
            self.labels = list(np.asarray(self.labels)[self.val_indices])

            self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32))
            self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        elif self.train:
            print(np.shape(self.data))
            if self.val_split > 0:
                self.data = np.delete(self.data, self.val_indices, axis=0)
                self.labels = list(np.delete(np.asarray(self.labels),
                                             self.val_indices, axis=0))

            self.data = self.data.reshape((int(50e3 * (1. - self.val_split)), 3, 32, 32))
            self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        else:
            f = self.test_list[0][0]
            file = os.path.join(self.root, self.base_folder, f)
            with open(file, 'rb') as fo:
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
            self.data = entry['data']
            if 'labels' in entry:
                self.labels = entry['labels']
            else:
                self.labels = entry['fine_labels']
            self.data = self.data.reshape((10000, 3, 32, 32))
            self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)
class CIFAR100(CIFAR10):
    """CIFAR-100 variant: reuses the CIFAR10 pipeline above, pointed at the
    cifar-100-python archive (its batches carry 'fine_labels')."""
    # Archive location / integrity metadata consumed by the inherited
    # download() and _check_integrity() machinery.
    base_folder = 'cifar-100-python'
    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    # (batch filename, md5 checksum) pairs for the train/test splits.
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]
    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]