Spaces:
Runtime error
Runtime error
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| import numpy as np | |
class NYU_v2_datset(Dataset):
    """NYUDataset.

    Yields dicts with keys ``'guidance'`` (RGB image), ``'gt'`` (depth map) and
    ``'lr'`` (the depth map bicubic-downsampled by ``scale`` and bicubic-upsampled
    back to its original size).
    """

    def __init__(self, root_dir, scale=8, train=True, transform=None):
        """
        Args:
            root_dir (string): Directory containing the ``.npy`` split files
                (``train_depth_split.npy`` / ``train_images_split.npy`` for the
                train split, ``test_depth.npy`` / ``test_images_v2.npy`` for test).
            scale (int or float): Down-sampling factor used to synthesise the
                low-resolution depth input.
            train (bool): Load the train split if True, otherwise the test split.
            transform (callable, optional): Applied to the image, the depth map
                and the low-resolution depth; its output must provide a
                ``.float()`` method (e.g. a torch tensor).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.scale = scale
        self.train = train
        if train:
            depth_file, image_file = "train_depth_split.npy", "train_images_split.npy"
        else:
            depth_file, image_file = "test_depth.npy", "test_images_v2.npy"
        # Whole arrays are loaded into memory up front; samples index axis 0.
        self.depths = np.load(f"{root_dir}/{depth_file}")
        self.images = np.load(f"{root_dir}/{image_file}")

    def __len__(self):
        """Return the number of samples (length of the depth array's first axis)."""
        return self.depths.shape[0]

    @staticmethod
    def _lr_size(h, w, scale):
        """Return the integer ``(width, height)`` of the down-sampled depth map.

        ``//`` with a float ``scale`` yields a float, which PIL's ``resize``
        rejects, so each side is cast to int. Each side is also clamped to at
        least one pixel so a scale larger than the image does not request an
        empty image.
        """
        return max(1, int(w // scale)), max(1, int(h // scale))

    def __getitem__(self, idx):
        """Return ``{'guidance': image, 'lr': lr, 'gt': depth}`` for sample ``idx``."""
        depth = self.depths[idx]
        image = self.images[idx]
        # NOTE(review): assumes depth is laid out (H, W) or (H, W, 1) — confirm.
        h, w = depth.shape[:2]
        # Synthesise the low-resolution input: bicubic down, then bicubic back up.
        lr = np.array(
            Image.fromarray(depth.squeeze())
            .resize(self._lr_size(h, w, self.scale), Image.BICUBIC)
            .resize((w, h), Image.BICUBIC)
        )
        if self.transform is not None:
            image = self.transform(image).float()
            depth = self.transform(depth).float()
            # Add a trailing channel axis so lr is (H, W, 1) before transforming.
            lr = self.transform(np.expand_dims(lr, 2)).float()
        # NOTE(review): without a transform, lr stays 2-D even if depth carries
        # a trailing channel axis.
        return {'guidance': image, 'lr': lr, 'gt': depth}