| import torch.utils.data as data |
|
|
| from PIL import Image |
|
|
| import os |
| import os.path |
| from io import BytesIO |
|
|
| import lmdb |
| from torch.utils.data import Dataset |
|
|
class MultiResolutionDataset(Dataset):
    """Dataset that reads encoded images from a read-only LMDB database.

    The number of records is read from the key ``'length'``. Image ``index``
    is stored under the key ``'<resolution>-<index zero-padded to 5 chars>'``.
    """

    def __init__(self, path, transform, resolution=256):
        """
        Args:
            path (string): path to the LMDB environment.
            transform (callable): applied to the PIL image opened from each record.
            resolution (int): resolution prefix used when building record keys.

        Raises:
            IOError: if the environment cannot be opened or has no ``length`` record.
        """
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
        # Without this check a missing record surfaces as an obscure
        # AttributeError on ``None.decode``.
        if raw_length is None:
            raise IOError('LMDB dataset has no length record', path)
        self.length = int(raw_length.decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Open the image stored for ``index`` and return ``transform(image)``.

        Raises:
            IndexError: if no record exists for ``index`` (e.g. out of range).
        """
        with self.env.begin(write=False) as txn:
            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
            img_bytes = txn.get(key)

        if img_bytes is None:
            # ``BytesIO(None)`` would silently create an empty buffer and PIL
            # would then fail to identify the image; report the real cause.
            raise IndexError(
                f'No record {key!r} in LMDB dataset (length {self.length})'
            )

        img = Image.open(BytesIO(img_bytes))
        return self.transform(img)
|
|
|
|
def has_file_allowed_extension(filename, extensions):
    """Return whether ``filename`` ends with one of ``extensions``.

    The comparison lower-cases ``filename`` only, so ``extensions`` are
    expected to be given in lower case (e.g. ``'.jpg'``).

    Args:
        filename (string): path or name of a file.
        extensions (iterable of string): accepted suffixes.

    Returns:
        bool: True if the lower-cased filename ends with any of the suffixes.
    """
    suffixes = tuple(extensions)
    return filename.lower().endswith(suffixes)
|
|
|
|
def find_classes(dir):
    """List the immediate sub-directories of ``dir`` and number them.

    Args:
        dir (string): directory to scan (non-recursive).

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
        of sub-directory names and ``class_to_idx`` maps each name to its
        position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx
|
|
|
|
def make_dataset(dir, extensions):
    """Recursively collect files under ``dir`` that have an allowed extension.

    Directories are visited in sorted order and files within each directory
    are sorted by name, so the result is deterministic.

    Args:
        dir (string): root directory to walk.
        extensions (iterable of string): accepted file suffixes.

    Returns:
        list: ``(path, 0)`` tuples; every sample gets the constant target ``0``.
    """
    samples = []
    for dirpath, _subdirs, filenames in sorted(os.walk(dir)):
        matching = [name for name in sorted(filenames)
                    if has_file_allowed_extension(name, extensions)]
        samples.extend((os.path.join(dirpath, name), 0) for name in matching)
    return samples
|
|
|
|
class DatasetFolder(data.Dataset):
    """Dataset over every file under ``root`` whose name has an allowed extension.

    Samples are gathered with ``make_dataset``; ``__getitem__`` returns only the
    loaded (and optionally transformed) sample, not a ``(sample, target)`` pair.
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        """
        Args:
            root (string): directory searched recursively for samples.
            loader (callable): maps a file path to a loaded sample.
            extensions (list of string): accepted file suffixes.
            transform (callable, optional): applied to each loaded sample.
            target_transform (callable, optional): applied to each target.

        Raises:
            RuntimeError: if no matching file is found under ``root``.
        """
        samples = make_dataset(root, extensions)
        if not samples:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Args:
            index (int): Index into ``self.samples``.

        Returns:
            The loaded sample, passed through ``transform`` when it is set.
            Only the sample is returned; the target is not.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            # NOTE(review): the transformed target is discarded (only the sample
            # is returned); the call is kept so target_transform still runs.
            target = self.target_transform(target)
        return sample

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        lines = [
            'Dataset ' + type(self).__name__,
            ' Number of datapoints: {}'.format(len(self)),
            ' Root Location: {}'.format(self.root),
        ]
        for label, value in ((' Transforms (if any): ', self.transform),
                             (' Target Transforms (if any): ', self.target_transform)):
            # Indent continuation lines of multi-line reprs under the label.
            lines.append(label + repr(value).replace('\n', '\n' + ' ' * len(label)))
        return '\n'.join(lines)
|
|
|
|
# Lower-case file-name suffixes that ImageFolder treats as images; they are
# compared against the lower-cased file name.
IMG_EXTENSIONS = [
    '.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
]
|
|
|
|
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB mode."""
    # convert() runs while the file is still open, because PIL reads pixel
    # data lazily and the handle is closed when the ``with`` block exits.
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image
|
|
|
|
def default_loader(path):
    """Loader used by ``ImageFolder`` unless one is given; delegates to ``pil_loader``."""
    image = pil_loader(path)
    return image
|
|
|
|
class ImageFolder(DatasetFolder):
    """``DatasetFolder`` over ``IMG_EXTENSIONS`` files with a switchable transform.

    ``transform1`` is active initially; ``set_stage('last')`` replaces it with
    ``transform2``.
    """

    def __init__(self, root, transform1=None, transform2=None, target_transform=None,
                 loader=default_loader):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS,
            transform=transform1,
            target_transform=target_transform,
        )
        # ``imgs`` is an alias of ``samples`` (same list object).
        self.imgs = self.samples
        self.transform2 = transform2

    def set_stage(self, stage):
        """Activate ``transform2`` when ``stage == 'last'``; other stages are no-ops."""
        if stage != 'last':
            return
        self.transform = self.transform2
|
|
class ListFolder(Dataset):
    """Dataset of images whose paths are listed, one per line, in a text file."""

    def __init__(self, txt, transform):
        """
        Args:
            txt (string): path to a text file containing one image path per line.
            transform (callable): applied to each opened image in ``__getitem__``.
        """
        with open(txt) as f:
            stripped = (line.strip() for line in f)
            # Skip blank lines: they would otherwise be counted by __len__ and
            # then fail to open as an empty path.
            self.imgpaths = [path for path in stripped if path]
        self.transform = transform

    def __getitem__(self, idx):
        """Open the image listed at position ``idx`` and return ``transform(image)``."""
        # NOTE(review): unlike pil_loader, the image is not converted to RGB
        # here, so its mode is whatever the file stores — confirm intended.
        path = self.imgpaths[idx]
        image = Image.open(path)
        return self.transform(image)

    def __len__(self):
        return len(self.imgpaths)
|
|
|
|