Spaces:
Runtime error
Runtime error
| import random | |
| import sys | |
| import os.path | |
| from PIL import Image | |
| from swapae.data.base_dataset import BaseDataset, get_transform | |
| import cv2 | |
| import numpy as np | |
| if sys.version_info[0] == 2: | |
| import cPickle as pickle | |
| else: | |
| import pickle | |
| import torchvision.transforms as transforms | |
class LMDBDataset(BaseDataset):
    """Image dataset backed by an LMDB environment.

    Every key in the environment is one sample. Its value is decoded with
    ``cv2.imdecode`` and passed through the transform returned by
    ``get_transform(opt, grayscale=False)``.

    The list of keys is pickled to ``<dataroot>/_cache_`` so later runs can
    skip the full cursor scan. The list is shuffled with a fixed seed (0), so
    the index -> key mapping is reproducible across runs.
    """

    def __init__(self, opt):
        """Open the LMDB environment at ``opt.dataroot`` and index its keys.

        Args:
            opt: options object. ``opt.dataroot`` is read here, and the whole
                object is forwarded to ``get_transform``.
        """
        import lmdb
        self.opt = opt
        # Whether to build (and try to persist) the key list when no cache exists.
        write_cache = True
        self.root = os.path.expanduser(opt.dataroot)
        # Use the expanded path everywhere; the raw dataroot breaks on "~/...".
        self.env = lmdb.open(self.root, readonly=True, lock=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']
        print('lmdb file at %s opened.' % self.root)

        cache_file = os.path.join(self.root, '_cache_')
        if os.path.isfile(cache_file):
            # NOTE(review): unpickling can execute arbitrary code; dataroot
            # must point at trusted data.
            with open(cache_file, "rb") as f:
                self.keys = pickle.load(f)
        elif write_cache:
            print('generating keys')
            with self.env.begin(write=False) as txn:
                self.keys = [key for key, _ in txn.cursor()]
            # The cache is only a speed-up; a read-only dataroot must not
            # prevent the dataset from loading.
            try:
                with open(cache_file, "wb") as f:
                    pickle.dump(self.keys, f)
            except OSError as e:
                print('could not write cache file %s: %s' % (cache_file, e))
            else:
                print('cache file generated at %s' % cache_file)
        else:
            self.keys = []
        random.Random(0).shuffle(self.keys)
        self.transform = get_transform(self.opt, grayscale=False)
        if "lsun" in self.opt.dataroot.lower():
            print("Seems like a LSUN dataset, so we will apply BGR->RGB conversion")

    def __getitem__(self, index):
        """Return the sample stored under the ``index``-th (shuffled) key."""
        path = self.keys[index]
        return self.getitem_by_path(path)

    def getitem_by_path(self, path):
        """Load and transform the image stored under LMDB key ``path``.

        Args:
            path: LMDB key as stored in ``self.keys`` (bytes, since it is
                decoded as UTF-8 for the returned ``"path_A"``).

        Returns:
            dict with ``"real_A"`` (the transformed PIL image) and
            ``"path_A"`` (the key decoded as UTF-8).

        If the key is missing or its value cannot be decoded, the error is
        printed and a uniformly random other sample is returned instead.
        """
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(path)
        try:
            if imgbuf is None:
                raise ValueError('key not found in lmdb')
            # frombuffer replaces np.fromstring, whose binary mode is deprecated.
            img = cv2.imdecode(np.frombuffer(imgbuf, dtype=np.uint8), 1)
            # imdecode usually reports a corrupt buffer by returning None
            # rather than raising.
            if img is None:
                raise ValueError('image could not be decoded')
        except (cv2.error, ValueError) as e:
            print(path, e)
            return self.__getitem__(random.randint(0, self.length - 1))
        # cv2 decodes colour images in BGR order.
        # NOTE(review): only LSUN data is converted to RGB; other datasets are
        # passed on in BGR order. Confirm this is intended.
        if "lsun" in self.opt.dataroot.lower():
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        return {"real_A": self.transform(img), "path_A": path.decode("utf-8")}

    def set_phase(self, phase):
        """Forward the phase change to ``BaseDataset``; nothing else to update."""
        super().set_phase(phase)

    def __len__(self):
        """Number of entries reported by the LMDB environment at load time."""
        return self.length

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.root})"