# Scraped page header: file size 2,319 bytes, revision 9774d79.
import glob
import os
import cv2
import numpy as np
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
class MnistDataset(Dataset):
    r"""
    Simple dataset over a folder of MNIST-style images.

    Written as a custom dataset class rather than using the one shipped with
    torchvision so that it can be pointed at any other image dataset laid out
    as ``<im_path>/<sub_directory>/<name>.<im_ext>``.
    """

    def __init__(self, split, im_path, im_ext="png", im_size=28, return_hints=False):
        r"""
        Record the dataset properties and index the image files.

        :param split: train/test. Only used in the summary message printed
            while indexing; the files themselves are located via ``im_path``.
        :param im_path: root folder of images
        :param im_ext: image extension. assumes all
            images would be this type.
        :param im_size: stored as ``self.im_size`` but not otherwise used by
            this class; images are returned at their on-disk size.
        :param return_hints: when True, ``__getitem__`` additionally returns
            a Canny edge map of the image.
        """
        self.split = split
        self.im_ext = im_ext
        self.im_size = im_size
        self.return_hints = return_hints
        self.images = self.load_images(im_path)

    def load_images(self, im_path):
        r"""
        Collect the paths of all images below ``im_path``.

        Every entry of ``im_path`` is treated as a sub-directory and globbed
        for ``*.<im_ext>`` files. Entries and files are visited in sorted
        order so that a given index refers to the same file on every run.

        :param im_path: root folder of images
        :return: list of image file paths
        :raises AssertionError: if ``im_path`` does not exist
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept unchanged.
        if not os.path.exists(im_path):
            raise AssertionError("images path {} does not exist".format(im_path))

        pattern = "*.{}".format(self.im_ext)
        ims = []
        for d_name in tqdm(sorted(os.listdir(im_path))):
            ims.extend(sorted(glob.glob(os.path.join(im_path, d_name, pattern))))
        print("Found {} images for split {}".format(len(ims), self.split))
        return ims

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, index):
        r"""
        Load the image at ``index`` as a tensor.

        :param index: position in the indexed image list
        :return: the image tensor, or ``(image tensor, edge-map tensor)``
            when ``return_hints`` is set
        """
        to_tensor = torchvision.transforms.ToTensor()
        # Open the file once and release it as soon as the pixel data has
        # been copied out.
        with Image.open(self.images[index]) as im:
            im_tensor = to_tensor(im)
            pixels = np.array(im) if self.return_hints else None

        # Convert input to -1 to 1 range.
        im_tensor = (2 * im_tensor) - 1
        if not self.return_hints:
            return im_tensor

        # Canny edge map (thresholds 100/200), replicated to three channels.
        edges = cv2.Canny(pixels, 100, 200)
        edges = np.repeat(edges[:, :, None], 3, axis=2)
        return im_tensor, to_tensor(edges)
|