import glob
import os

import cv2
import numpy as np
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
from tqdm import tqdm


class MnistDataset(Dataset):
    r"""
    Nothing special here. Just a simple dataset class for mnist images.
    Created a dataset class rather than using torchvision to allow
    replacement with any other image dataset.
    """

    def __init__(self, split, im_path, im_ext="png", im_size=28, return_hints=False):
        r"""
        Init method for initializing the dataset properties.

        :param split: name of the split (e.g. train/test). In this class it is
            only used in the message printed by ``load_images``.
        :param im_path: root folder of images. Images are collected from the
            immediate sub-directories of this folder.
        :param im_ext: image extension. Assumes all images are of this type.
        :param im_size: expected image size. Stored on the instance as
            ``self.im_size``; images are NOT resized by this class.
        :param return_hints: if True, ``__getitem__`` additionally returns a
            3-channel Canny edge map of the image.
        """
        self.split = split
        self.im_ext = im_ext
        self.im_size = im_size
        self.return_hints = return_hints
        self.images = self.load_images(im_path)

    def load_images(self, im_path):
        r"""
        Collect the file paths of all images under ``im_path``.

        Only files located one directory level below ``im_path`` and ending in
        ``.<im_ext>`` are collected. Images are not read here; they are opened
        lazily in ``__getitem__``.

        :param im_path: root folder whose sub-directories contain the images.
        :return: list of image file paths, sorted by sub-directory name and
            then by path within each sub-directory.
        :raises FileNotFoundError: if ``im_path`` does not exist.
        """
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if not os.path.exists(im_path):
            raise FileNotFoundError("images path {} does not exist".format(im_path))
        ims = []
        # os.listdir and glob return entries in arbitrary, filesystem-dependent
        # order; sorting keeps the index -> image mapping reproducible.
        for d_name in tqdm(sorted(os.listdir(im_path))):
            pattern = os.path.join(im_path, d_name, "*.{}".format(self.im_ext))
            ims.extend(sorted(glob.glob(pattern)))
        print("Found {} images for split {}".format(len(ims), self.split))
        return ims

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        r"""
        Load the image at ``index``.

        :param index: position in ``self.images``.
        :return: the image tensor rescaled to [-1, 1]. If ``return_hints`` is
            set, a tuple ``(image_tensor, canny_tensor)`` where the Canny edge
            map is replicated to 3 channels and left in [0, 1].
        """
        # Open the file once and close it deterministically. Both ToTensor and
        # np.array force PIL to read the pixel data while the file is open.
        with Image.open(self.images[index]) as im:
            im_tensor = torchvision.transforms.ToTensor()(im)
            canny_source = np.array(im) if self.return_hints else None

        # ToTensor maps 8-bit pixels to [0, 1]; convert input to -1 to 1 range.
        im_tensor = (2 * im_tensor) - 1

        if not self.return_hints:
            return im_tensor

        # 100 / 200 are the lower / upper hysteresis thresholds of cv2.Canny.
        edges = cv2.Canny(canny_source, 100, 200)
        # Add a trailing channel axis and replicate it to 3 channels (H, W, 3).
        edges = np.concatenate([edges[:, :, None]] * 3, axis=2)
        canny_image_tensor = torchvision.transforms.ToTensor()(edges)
        return im_tensor, canny_image_tensor