| import glob |
| import os |
|
|
| import cv2 |
| import numpy as np |
| import torchvision |
| from PIL import Image |
| from torch.utils.data.dataset import Dataset |
| from tqdm import tqdm |
|
|
|
|
class MnistDataset(Dataset):
    r"""
    Simple image-folder dataset for MNIST-style images.

    A custom dataset class is used instead of the torchvision one so that it
    can be pointed at any other image dataset laid out the same way:
    ``<im_path>/<sub_directory>/<image>.<im_ext>``.

    Only the file paths are collected up front; images are decoded lazily
    in ``__getitem__``.
    """

    def __init__(self, split, im_path, im_ext="png", im_size=28, return_hints=False):
        r"""
        Init method for initializing the dataset properties
        :param split: train/test; only used in the summary message printed
            after the image files have been located
        :param im_path: root folder of images
        :param im_ext: image extension. assumes all
            images would be this type.
        :param im_size: accepted for interface compatibility; currently
            unused, images are returned at their stored resolution
        :param return_hints: when True, ``__getitem__`` additionally returns
            a Canny edge map of the image
        """
        self.split = split
        self.im_ext = im_ext
        self.return_hints = return_hints
        self.images = self.load_images(im_path)

    def load_images(self, im_path):
        r"""
        Collects the paths of all images found one directory level below
        the path specified
        :param im_path: root folder containing one sub-directory per group
            of images
        :return: list of image file paths, sorted by sub-directory name and
            then by file path
        :raises AssertionError: if ``im_path`` does not exist
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not os.path.exists(im_path):
            raise AssertionError("images path {} does not exist".format(im_path))

        ims = []
        # os.listdir and glob.glob return entries in arbitrary order; sorting
        # makes the index -> image mapping reproducible across runs/machines.
        for d_name in tqdm(sorted(os.listdir(im_path))):
            pattern = os.path.join(im_path, d_name, "*.{}".format(self.im_ext))
            ims.extend(sorted(glob.glob(pattern)))
        print("Found {} images for split {}".format(len(ims), self.split))
        return ims

    def __len__(self):
        r"""
        :return: number of image files found by ``load_images``
        """
        return len(self.images)

    def __getitem__(self, index):
        r"""
        Loads the image at the given index
        :param index: position in the list of discovered image paths
        :return: the image tensor rescaled to [-1, 1]; when ``return_hints``
            is set, a tuple ``(image tensor, 3-channel Canny edge tensor)``
        """
        to_tensor = torchvision.transforms.ToTensor()

        # Decode the file once and release the handle deterministically; the
        # same decoded image feeds both the tensor and the edge hint.
        with Image.open(self.images[index]) as im:
            im_tensor = to_tensor(im)
            pixels = np.array(im) if self.return_hints else None

        # ToTensor scales 8-bit images to [0, 1]; shift/scale that to [-1, 1].
        im_tensor = (2 * im_tensor) - 1

        if not self.return_hints:
            return im_tensor

        # Canny yields a single-channel edge map; replicate it across three
        # channels (H x W x 3) before converting it to a tensor.
        edges = cv2.Canny(pixels, 100, 200)[:, :, None]
        canny_image = np.concatenate([edges, edges, edges], axis=2)
        canny_image_tensor = to_tensor(canny_image)
        return im_tensor, canny_image_tensor
|
|