Spaces:
Paused
Paused
| """ | |
| Copyright (C) 2019 NVIDIA Corporation. All rights reserved. | |
| Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). | |
| """ | |
| ############################################################################### | |
| # Code from | |
| # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py | |
| # Modified the original code so that it also loads images from the current | |
| # directory as well as the subdirectories | |
| ############################################################################### | |
| import torch.utils.data as data | |
| from PIL import Image | |
| import os | |
# File-name suffixes recognised as images by is_image_file(). That check is a
# case-sensitive suffix match, so every extension is listed in both lower and
# upper case (previously '.tiff' and '.webp' lacked their upper-case forms and
# such files were silently skipped).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tiff', '.TIFF', '.webp', '.WEBP',
]
def is_image_file(filename):
    """Return True when *filename* ends with one of the suffixes in IMG_EXTENSIONS.

    The comparison is an exact, case-sensitive suffix match against the
    entries of IMG_EXTENSIONS.
    """
    suffixes = tuple(IMG_EXTENSIONS)
    return filename.endswith(suffixes)
def make_dataset_rec(dir, images):
    """Append the path of every image file under *dir* to *images*.

    The whole tree is walked with ``os.walk(..., followlinks=True)``, so
    symlinked subdirectories are descended into as well. Directories are
    visited in sorted order of their path, and file names inside each
    directory are sorted, so the resulting order is reproducible.

    Args:
        dir: Root directory to scan.
        images: List extended in place with ``os.path.join(root, fname)`` for
            every file name accepted by ``is_image_file``.

    Raises:
        AssertionError: if *dir* is not a directory (the check is an
            ``assert``, so it is skipped when Python runs with ``-O``).
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    # sorted() over the walk only orders the directories; os.walk reports each
    # directory's file names in filesystem-dependent order, so sort them too.
    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
    """Return the list of image file paths found under *dir*.

    Both modes descend into subdirectories via ``os.walk``; ``recursive=True``
    delegates to ``make_dataset_rec``, which additionally follows symlinked
    directories. In the default mode, directories are visited in sorted path
    order and the file names inside each directory are sorted, so the returned
    order is reproducible.

    Args:
        dir: Root directory to scan.
        recursive: Delegate the scan to ``make_dataset_rec``.
        read_cache: If ``<dir>/files.list`` exists, return its lines verbatim
            instead of scanning the filesystem.
        write_cache: After scanning, write one path per line to
            ``<dir>/files.list`` and print where it was written.

    Returns:
        List of paths, each built as ``os.path.join(root, fname)``.

    Raises:
        AssertionError: in the default mode, if *dir* is neither a directory
            nor a symlink (``assert``-based, so skipped under ``python -O``).
    """
    images = []

    if read_cache:
        possible_filelist = os.path.join(dir, 'files.list')
        if os.path.isfile(possible_filelist):
            with open(possible_filelist, 'r') as f:
                images = f.read().splitlines()
            return images

    if recursive:
        make_dataset_rec(dir, images)
    else:
        assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
        # sorted() over the walk only orders directories; sort the file names
        # as well because os.walk yields them in filesystem-dependent order.
        for root, _, fnames in sorted(os.walk(dir)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    images.append(os.path.join(root, fname))

    if write_cache:
        filelist_cache = os.path.join(dir, 'files.list')
        with open(filelist_cache, 'w') as f:
            f.writelines("%s\n" % path for path in images)
        print('wrote filelist cache at %s' % filelist_cache)

    return images
def default_loader(path):
    """Load the image at *path* and return it converted to RGB mode.

    The source image is opened in a ``with`` block so its file handle is
    closed as soon as the RGB copy exists; otherwise Pillow can keep the
    underlying file open until the ``Image`` object is garbage-collected.
    ``convert`` returns a new, fully loaded image, so the result stays valid
    after the source is closed.
    """
    with Image.open(path) as img:
        return img.convert('RGB')
class ImageFolder(data.Dataset):
    """Dataset of the image files that ``make_dataset(root)`` finds under *root*.

    Each item is ``loader(path)``, passed through ``transform`` when one is
    given. With ``return_paths=True`` an item is the pair ``(image, path)``.

    Raises:
        RuntimeError: at construction, if no image files are found.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        found = make_dataset(root)
        if not found:
            message = ("Found 0 images in: " + root + "\n"
                       "Supported image extensions are: " +
                       ",".join(IMG_EXTENSIONS))
            raise RuntimeError(message)
        self.root = root
        self.imgs = found
        self.loader = loader
        self.transform = transform
        self.return_paths = return_paths

    def __getitem__(self, index):
        img_path = self.imgs[index]
        sample = self.loader(img_path)
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, img_path) if self.return_paths else sample

    def __len__(self):
        return len(self.imgs)