File size: 12,688 Bytes
8e6512c
 
b5cbaa6
8e6512c
b5cbaa6
8e6512c
 
b5cbaa6
 
 
8e6512c
 
b5cbaa6
8e6512c
 
b5cbaa6
 
8e6512c
b5cbaa6
 
 
8e6512c
b5cbaa6
 
 
 
 
 
 
 
 
 
8e6512c
b5cbaa6
 
 
8e6512c
 
 
b5cbaa6
 
 
 
8e6512c
b5cbaa6
8e6512c
 
 
 
b5cbaa6
 
 
 
 
 
 
 
 
 
8e6512c
b5cbaa6
 
 
 
 
 
 
 
8e6512c
b5cbaa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6512c
b5cbaa6
 
8e6512c
 
 
 
 
 
 
 
 
b5cbaa6
 
8e6512c
 
 
 
b5cbaa6
 
 
 
 
8e6512c
 
b5cbaa6
 
8e6512c
b5cbaa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6512c
b5cbaa6
 
8e6512c
 
 
 
 
 
 
 
 
 
 
 
 
b5cbaa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6512c
b5cbaa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6512c
 
 
b5cbaa6
 
 
 
8e6512c
b5cbaa6
8e6512c
b5cbaa6
 
 
 
8e6512c
 
 
b5cbaa6
 
 
 
 
 
 
 
8e6512c
 
 
b5cbaa6
 
8e6512c
b5cbaa6
 
8e6512c
 
 
 
b5cbaa6
 
 
 
 
 
8e6512c
 
 
b5cbaa6
 
 
8e6512c
 
b5cbaa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6512c
b5cbaa6
8e6512c
b5cbaa6
 
 
 
 
 
8e6512c
 
 
 
 
 
 
b5cbaa6
8e6512c
 
b5cbaa6
 
 
 
 
 
 
 
8e6512c
 
 
b5cbaa6
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6512c
b5cbaa6
 
 
 
 
8e6512c
b5cbaa6
 
 
8e6512c
 
 
 
 
 
 
b5cbaa6
 
 
 
8e6512c
 
b5cbaa6
 
 
 
8e6512c
b5cbaa6
 
 
 
 
 
 
8e6512c
b5cbaa6
 
8e6512c
 
 
b5cbaa6
8e6512c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import logging
import os
from glob import glob
from os import listdir
from os.path import splitext
from typing import Dict, List

import cv2
import numpy
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset

import utils.utils
from utils.augment import *

r"""
Defines the `BasicSegmentationDataset` and `CoronaryDataset` classes, which extend the `Dataset` and `BasicSegmentationDataset` \
classes, respectively. Each class defines the specific methods needed for data processing and a method :func:`__getitem__` to return samples.
"""


class BasicSegmentationDataset(Dataset):
    r"""
    Implements a basic dataset for segmentation tasks, with methods for image and mask scaling and normalization. \
    The filenames of the segmentation ground truths must be equal to the filenames of the images to be segmented, \
    except for a possible suffix.

    Sample IDs are the extension-less names of the non-hidden entries of `imgs_dir`; for each ID exactly one
    image (`<id>.*`) and exactly one mask (`<id><mask_suffix>.*`) must exist.

    Args:
        imgs_dir (str): path to the directory containing the images to be segmented.
        masks_dir (str): path to the directory containing the segmentation ground truths.
        scale (float, optional): image scale, between 0 (exclusive) and 1 (inclusive), to be used in the segmentation.
        mask_suffix (str, optional): suffix to be added to an image's filename to obtain its
            ground truth filename.
    """

    def __init__(
        self, imgs_dir: str, masks_dir: str, scale: float = 1, mask_suffix: str = ""
    ):
        # Validate first so that an invalid scale never leaves a half-initialised object behind.
        assert 0 < scale <= 1, "Scale must be between 0 and 1"

        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        self.mask_suffix = mask_suffix

        # One ID per non-hidden directory entry, with the file extension stripped.
        self.ids = [
            splitext(file)[0] for file in listdir(imgs_dir) if not file.startswith(".")
        ]
        logging.info(f"Creating dataset with {len(self.ids)} examples")

    def __len__(self) -> int:
        r"""
        Returns the size of the dataset.
        """
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img: "Image.Image", scale: float) -> numpy.ndarray:
        r"""
        Preprocesses an `Image`, rescaling it and returning it as a NumPy array in
        the CHW format.

        Args:
            pil_img (Image): object of class `Image` to be preprocessed.
            scale (float): image scale, between 0 and 1.

        Returns:
            The resized image as a CHW array. If the largest value is greater than 1 the
            array is divided by 255; otherwise it is returned unscaled.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, "Scale is too small"
        pil_img = pil_img.resize((newW, newH))

        img_nd = numpy.array(pil_img)

        # Single-channel images come back as HW; add a trailing channel axis.
        if img_nd.ndim == 2:
            img_nd = numpy.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        # Heuristic: values above 1 are taken to be 8-bit intensities.
        if img_trans.max() > 1:
            img_trans = img_trans / 255

        return img_trans

    def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
        r"""
        Returns a dictionary with two single-element lists of tensors: the image
        (key ``"image"``) and the corresponding mask (key ``"mask"``).
        """
        idx = self.ids[i]
        # Join with the OS separator so the directories work with or without a trailing slash.
        mask_file = glob(os.path.join(self.masks_dir, idx + self.mask_suffix + ".*"))
        img_file = glob(os.path.join(self.imgs_dir, idx + ".*"))

        assert (
            len(mask_file) == 1
        ), f"Either no mask or multiple masks found for the ID {idx}: {mask_file}"
        assert (
            len(img_file) == 1
        ), f"Either no image or multiple images found for the ID {idx}: {img_file}"
        mask = Image.open(mask_file[0])
        img = Image.open(img_file[0])

        assert (
            img.size == mask.size
        ), f"Image and mask {idx} should be the same size, but are {img.size} and {mask.size}"

        img = self.preprocess(img, self.scale)
        mask = self.preprocess(mask, self.scale)

        return {
            "image": [torch.from_numpy(img).type(torch.FloatTensor)],
            "mask": [torch.from_numpy(mask).type(torch.FloatTensor)],
        }


class CoronaryDataset(BasicSegmentationDataset):
    r"""
    Implements a segmentation dataset that returns, for every sample, a randomly cropped image/mask pair plus
    `augmentation_ratio` augmented pairs. Before cropping, the last channel of the input image is replaced by a
    crude mask derived from the ground truth (see :func:`__getitem__`).

    NOTE(review): this docstring previously described the "Retinal Vessel Segmentation task" and the default
    augmentation policy is still called "retina"; presumably the class was adapted from a retinal dataset - confirm.

    Args:
        imgs_dir (str): path to the directory containing the images to be segmented.
        masks_dir (str): path to the directory containing the segmentation ground truths.
        scale (float, optional): image scale, between 0 and 1, to be used in the segmentation.
        augmentation_ratio (int, optional): number of augmentations to generate per image.
        crop_size (int, optional): size of the square image to be fed to the model.
        aug_policy (str, optional): data augmentation policy. Only "retina" is implemented.
    """
    # Number of classes, including the background class
    n_classes = 2

    # Maps mask grayscale value to mask class index
    gray2class_mapping = {0: 0, 255: 1}

    # Maps mask grayscale value to mask RGB value
    gray2rgb_mapping = {0: (0, 0, 0), 255: (255, 255, 255)}

    # Maps mask RGB value to mask class index
    rgb2class_mapping = {(0, 0, 0): 0, (255, 255, 255): 1}

    def __init__(
        self,
        imgs_dir: str,
        masks_dir: str,
        scale: float = 1,
        augmentation_ratio: int = 0,
        crop_size: int = 512,
        aug_policy: str = "retina",
    ):
        super().__init__(imgs_dir, masks_dir, scale)
        self.augmentation_ratio = augmentation_ratio
        self.policy = aug_policy
        self.crop_size = crop_size

    @classmethod
    def mask_img2class_mask(cls, pil_mask: "Image.Image", scale: float) -> numpy.ndarray:
        r"""
        Preprocesses an `Image` containing a segmentation mask: rescales it, converts it to 8-bit grayscale and \
        returns it as a NumPy array in the CHW format with values divided by 255.

        Args:
            pil_mask (Image): object of class `Image` to be preprocessed.
            scale (float): image scale, between 0 and 1.
        """
        w, h = pil_mask.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, "Scale is too small"
        # NOTE(review): the default resampling filter is used; for scale < 1 this may yield
        # intermediate gray levels, i.e. fractional values after the division below - confirm
        # whether nearest-neighbour resampling is wanted for masks.
        pil_mask = pil_mask.resize((newW, newH))

        if pil_mask.mode != "L":
            pil_mask = pil_mask.convert(mode="L")
        mask_nd = numpy.array(pil_mask)

        if mask_nd.ndim == 2:
            mask_nd = numpy.expand_dims(mask_nd, axis=2)

        # HWC to CHW, then map gray level 255 to 1 (and 0 to 0)
        mask = mask_nd.transpose((2, 0, 1))
        mask = mask / 255

        return mask

    @classmethod
    def one_hot2mask(
        cls, one_hot_mask: torch.FloatTensor, shape: str = "CHW"
    ) -> numpy.ndarray:
        r"""
        Returns the class-index mask corresponding to a one-hot encoded (or per-class score) tensor, \
        taking the argmax over the channel axis. The channel axis is removed from the result.

        Args:
            one_hot_mask (FloatTensor): tensor in the CHW or NCHW shape.
            shape (str): "CHW" (default) or "NCHW"; any other value is treated as "CHW".
        """
        channel_axis = 1 if shape == "NCHW" else 0
        # `.cpu()` is a no-op for CPU tensors and makes the conversion work for CUDA tensors too.
        return numpy.argmax(one_hot_mask.detach().cpu().numpy(), axis=channel_axis)

    @classmethod
    def mask2one_hot(
        cls, mask_tensor: torch.FloatTensor, output_shape: str = "NHWC"
    ) -> torch.Tensor:
        r"""
        Converts a tensor of class indices into a one-hot encoded `LongTensor` in the NHWC shape.\
            Can return the NCHW shape if specified.

        Args:
            mask_tensor (Tensor): class indices in the N1HW (or NHW) shape. Floating point tensors
                are cast to integers first, since `F.one_hot` only accepts integer tensors.
            output_shape (str): NHWC or NCHW.
        """
        assert output_shape in ("NHWC", "NCHW"), "Invalid output shape specified"

        indices = mask_tensor.long()
        # Drop the singleton channel axis of N1HW inputs; NHW inputs are used as they are.
        if indices.dim() > 3 and indices.shape[1] == 1:
            indices = indices.squeeze(1)

        one_hot = F.one_hot(indices, cls.n_classes)  # NHWC
        if output_shape == "NHWC":
            return one_hot
        # NHWC to NCHW
        return one_hot.permute(0, 3, 1, 2)

    @classmethod
    def class2gray(cls, mask: numpy.ndarray) -> numpy.ndarray:
        r"""
        Returns a copy of the mask with its class labels replaced by their grayscale values, according to \
        `gray2class_mapping`. The input array is left untouched.
        """
        assert (
            len(cls.gray2class_mapping) == cls.n_classes
        ), f"Number of class mappings - {len(cls.gray2class_mapping)} - should be the same as the number of classes - {cls.n_classes}"
        gray = mask.copy()
        # Compare against the untouched input so an already replaced pixel is never
        # re-mapped when a gray value coincides with another class label.
        for color, label in cls.gray2class_mapping.items():
            gray[mask == label] = color
        return gray

    @classmethod
    def gray2rgb(cls, img: "Image.Image") -> "Image.Image":
        r"""
        Converts a grayscale image into an RGB one, according to gray2rgb_mapping.

        Raises:
            KeyError: if the image contains a value that is not a key of `gray2rgb_mapping`.
        """
        width, height = img.size

        # Fast path: 8-bit grayscale images are converted with whole-array operations
        # instead of one Python call per pixel.
        if img.mode == "L" and width > 0 and height > 0:
            gray = numpy.asarray(img)
            rgb = numpy.zeros((height, width, 3), dtype=numpy.uint8)
            mapped = numpy.zeros((height, width), dtype=bool)
            for value, color in cls.gray2rgb_mapping.items():
                hits = gray == value
                rgb[hits] = color
                mapped |= hits
            if not mapped.all():
                raise KeyError(int(gray[~mapped][0]))
            return Image.fromarray(rgb)

        # Generic fallback for any other image mode: per-pixel lookup.
        rgb_img = Image.new("RGB", img.size)
        for x in range(width):
            for y in range(height):
                rgb_img.putpixel((x, y), cls.gray2rgb_mapping[img.getpixel((x, y))])
        return rgb_img

    @classmethod
    def mask2image(cls, mask: numpy.ndarray) -> "Image.Image":
        r"""
        Converts a mask with class indices into an RGB image, according to gray2class_mapping and gray2rgb_mapping.

        NOTE(review): the array is handed to `Image.fromarray` as is, so it presumably has to be
        two-dimensional (HW) - confirm against the callers.
        """
        return cls.gray2rgb(Image.fromarray(cls.class2gray(mask).astype(numpy.uint8)))

    def augment(self, image, mask, policy="retina", augmentation_ratio=0):
        """
        Returns two lists, one with image tensors and one with mask tensors. The first element of each
        list is a random `crop_size` x `crop_size` crop of the received image/mask; it is followed by
        `augmentation_ratio` versions produced by the augmentation policy.

        Args:
            image (Image): image to be cropped and augmented.
            mask (Image): segmentation mask matching `image`.
            policy (str): name of the augmentation policy. Only "retina" is implemented.
            augmentation_ratio (int): number of augmented pairs to generate.

        Raises:
            ValueError: if augmentations are requested with an unknown policy.
        """
        tf_imgs = []
        tf_masks = []
        # Data Augmentation
        for _ in range(augmentation_ratio):
            # Select the policy. It is rebuilt on every iteration, as before, because it is not
            # visible here whether its constructor draws random parameters.
            if policy == "retina":
                aug_policy = RetinaPolicy(
                    crop_dims=[self.crop_size, self.crop_size], brightness=[0.9, 1.1]
                )
            else:
                raise ValueError(f"Unknown augmentation policy: {policy!r}")

            # Apply the transformation
            tf_image, tf_mask = aug_policy(image, mask)

            # Further process the images and masks
            tf_image = self.preprocess(tf_image, self.scale)
            tf_mask = self.mask_img2class_mask(tf_mask, self.scale)
            tf_imgs.append(torch.from_numpy(tf_image).type(torch.FloatTensor))
            tf_masks.append(torch.from_numpy(tf_mask).type(torch.FloatTensor))

        # The non-augmented sample: the same random crop window for image and mask.
        top, left, height, width = transforms.RandomCrop.get_params(
            image, [self.crop_size, self.crop_size]
        )
        image = transforms.functional.crop(image, top, left, height, width)
        mask = transforms.functional.crop(mask, top, left, height, width)
        image = self.preprocess(image, self.scale)
        mask = self.mask_img2class_mask(mask, self.scale)
        image = torch.from_numpy(image).type(torch.FloatTensor)
        mask = torch.from_numpy(mask).type(torch.FloatTensor)

        tf_imgs.insert(0, image)
        tf_masks.insert(0, mask)

        return (tf_imgs, tf_masks)

    def __getitem__(self, i) -> Dict[str, List[torch.FloatTensor]]:
        r"""
        Returns a dictionary with two lists of tensors: the cropped and augmented images
        (key ``"image"``) and the corresponding masks (key ``"mask"``).
        """
        idx = self.ids[i]
        # Join with the OS separator so the directories work with or without a trailing slash.
        mask_file = glob(os.path.join(self.masks_dir, f"{idx}.*"))
        img_file = glob(os.path.join(self.imgs_dir, f"{idx}.*"))

        assert (
            len(mask_file) == 1
        ), f"Either no mask or multiple masks found for the ID {idx}: {mask_file}"
        assert (
            len(img_file) == 1
        ), f"Either no image or multiple images found for the ID {idx}: {img_file}"

        mask = Image.open(mask_file[0])
        image = Image.open(img_file[0])

        # Checked before any array work so a size mismatch gives this message
        # instead of a broadcasting error further down.
        assert (
            image.size == mask.size
        ), f"Image and mask {idx} should be the same size, but are {image.size} and {mask.size}"

        # Here we apply any changes to the image that we want for our specific prediction task
        mask_array = numpy.array(mask).astype("uint8")
        image_array = numpy.array(image).astype("uint8")

        # Overwrite the last image channel with the crude mask computed from the ground truth.
        # NOTE(review): assumes the image array is HWC (three-dimensional) - confirm.
        crude_mask = utils.utils.crudeMaskGenerator(mask_array)
        image_array[:, :, -1] = crude_mask.astype(numpy.uint8)

        # Reconvert to PIL image object
        image = Image.fromarray(image_array)

        images, masks = self.augment(
            image, mask, policy=self.policy, augmentation_ratio=self.augmentation_ratio
        )
        return {"image": images, "mask": masks}