File size: 2,319 Bytes
9774d79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import glob
import os

import cv2
import numpy as np
import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
from tqdm import tqdm


class MnistDataset(Dataset):
    r"""
    Simple image-folder dataset, written for MNIST images.

    Images are discovered on disk as ``im_path/<sub-folder>/*.<im_ext>``.
    Only the file paths are kept in memory; each image is read from disk
    when it is indexed. The sub-folder names are not returned as labels.
    This is a hand-written dataset class (instead of the torchvision one)
    so that it can be pointed at any other image folder.
    """

    # Thresholds passed to cv2.Canny (threshold1, threshold2) when edge
    # hints are requested. Subclasses may override them.
    CANNY_LOW_THRESHOLD = 100
    CANNY_HIGH_THRESHOLD = 200

    def __init__(self, split, im_path, im_ext="png", im_size=28, return_hints=False):
        r"""
        Init method for initializing the dataset properties
        :param split: train/test; only used in the "Found ... images" message
        :param im_path: root folder of images (one sub-folder per group)
        :param im_ext: image extension. assumes all
        images would be this type.
        :param im_size: accepted for interface compatibility; this class
        does not use it (images are not resized here).
        :param return_hints: when True, ``__getitem__`` also returns a
        Canny edge map of the image.
        """
        super().__init__()
        self.split = split
        self.im_ext = im_ext
        self.return_hints = return_hints
        self.images = self.load_images(im_path)

    def load_images(self, im_path):
        r"""
        Collects the paths of all images below the path specified
        :param im_path: root folder; every entry directly inside it is
        searched for files named ``*.<im_ext>``
        :return: list of image file paths, sorted by sub-folder name and
        then by file path so that indices are stable across runs
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        file_pattern = "*.{}".format(self.im_ext)
        ims = []
        # os.listdir / glob.glob give no ordering guarantee, so sort both
        # levels to keep the index -> image mapping reproducible.
        for d_name in tqdm(sorted(os.listdir(im_path))):
            # Escape the directory part so characters such as "[" or "*"
            # in folder names are matched literally, not as wildcards.
            sub_dir = glob.escape(os.path.join(im_path, d_name))
            ims.extend(sorted(glob.glob(os.path.join(sub_dir, file_pattern))))
        print("Found {} images for split {}".format(len(ims), self.split))
        return ims

    def __len__(self):
        r"""
        :return: number of image files found by ``load_images``
        """
        return len(self.images)

    def __getitem__(self, index):
        r"""
        Loads one image from disk
        :param index: position in ``self.images``
        :return: the image tensor mapped through ``2 * x - 1``; when
        ``return_hints`` is set, a tuple ``(image tensor, edge tensor)``
        where the edge tensor is the 3-channel Canny edge map of the image
        """
        to_tensor = torchvision.transforms.ToTensor()

        # Open the file once and release the handle as soon as the pixel
        # data needed below has been copied out.
        with Image.open(self.images[index]) as im:
            im_tensor = to_tensor(im)
            hint_source = np.array(im) if self.return_hints else None

        # Convert input to -1 to 1 range.
        im_tensor = (2 * im_tensor) - 1

        if not self.return_hints:
            return im_tensor

        edges = cv2.Canny(
            hint_source, self.CANNY_LOW_THRESHOLD, self.CANNY_HIGH_THRESHOLD
        )
        # cv2.Canny returns a single-channel map; replicate it to three
        # channels (H, W, 3) before converting to a tensor.
        edges = edges[:, :, None]
        canny_image = np.concatenate([edges, edges, edges], axis=2)
        canny_image_tensor = to_tensor(canny_image)
        return im_tensor, canny_image_tensor