File size: 5,288 Bytes
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt

"""
Created in September 2022
@author: fabrizio.guillaro
"""

from abc import ABC, abstractmethod
from PIL import Image
import numpy as np
import torch
import random
import cv2


class AbstractDataset(ABC):
    """Base class for datasets that pair an RGB image with a per-pixel mask.

    Subclasses implement :meth:`get_img` and set ``self.img_list`` (read by
    :meth:`get_img_name` and :meth:`__len__`). :meth:`_create_tensor` turns an
    image path plus an optional mask into an ``(rgb, mask)`` tensor pair,
    applying augmentation, padding and cropping as configured in ``__init__``.
    """

    def __init__(self, crop_size, grid_crop: bool, max_dim=None, aug=None):
        """
        :param crop_size: (H, W) or None. H and W must be multiples of 8 if grid_crop==True.
        :param grid_crop: True: crop offsets are aligned to the 8x8 grid. False: crop anywhere.
        :param max_dim: if the image is bigger than this size, it is center-cropped
            to at most max_dim x max_dim.
        :param aug: optional augmentation callable, invoked as ``aug(image=..., mask=...)``
            and expected to return a dict with 'image' and 'mask' entries.
        """
        self._crop_size = crop_size
        self._max_dim = max_dim
        self._grid_crop = grid_crop

        if grid_crop and crop_size is not None:
            assert crop_size[0] % 8 == 0 and crop_size[1] % 8 == 0

        # Not populated here; presumably set by subclasses (read by get_img_name / __len__).
        self.img_list = None
        self.aug = aug

    def _create_tensor(self, mask=None, rgb_path=None):
        """Load the image at ``rgb_path`` and build the (RGB, mask) tensor pair.

        Pipeline: load RGB -> reconcile mask size -> optional augmentation ->
        optional pad + random crop -> optional centered crop to ``max_dim``.

        :param mask: 2-D array aligned with the image, or None for an all-zero mask.
        :param rgb_path: path of the image, opened with PIL and converted to RGB.
        :return: ``(t_RGB, t_mask)``. ``t_RGB`` is a float tensor of shape (3, H, W)
            holding pixel values divided by 256 (padding is 127.5/256).
            ``t_mask`` is a long tensor of shape (H, W); padded pixels are -1.
        :raises ValueError: if the image cannot be opened, or the mask cannot be
            copied into the padded buffer.
        """
        ignore_index = -1

        try:
            # Context manager closes the underlying file; convert() loads the pixels first.
            with Image.open(rgb_path) as pil_img:
                img_RGB = np.array(pil_img.convert("RGB"))
        except Exception as err:
            raise ValueError(f'error path: {rgb_path}') from err

        h, w = img_RGB.shape[0], img_RGB.shape[1]

        if mask is None:
            mask = np.zeros((h, w))
        elif mask.shape[0] != h or mask.shape[1] != w:
            # a small number of images have a mask that mismatches the size of the image
            print(f'MASK MISMATCH: {rgb_path} \n h:{h}, w:{w}, mask: {mask.shape}', flush=True)
            rotated = np.ascontiguousarray(np.rot90(mask))
            if rotated.shape[0] == h and rotated.shape[1] == w:
                mask = rotated
            else:
                # Resize the ORIGINAL (un-rotated) mask; cv2.resize takes dsize as (width, height).
                mask = cv2.resize(np.uint8(mask), (w, h), interpolation=cv2.INTER_NEAREST) > 0

        # augmentation
        if self.aug is not None:
            mask = np.uint8(mask)
            dat = self.aug(image=img_RGB, mask=mask)
            assert dat['image'].dtype == img_RGB.dtype
            assert dat['mask'].dtype == mask.dtype
            img_RGB = dat['image']
            mask = dat['mask'] > 0
            h, w = img_RGB.shape[0], img_RGB.shape[1]
            del dat

        # choose the crop size
        if self._crop_size is not None:
            crop_size = self._crop_size
        elif self._grid_crop:
            crop_size = (-(-h // 8) * 8, -(-w // 8) * 8)  # smallest 8x8 grid crop that contains image
        else:
            crop_size = None  # use entire image! no crop, no pad

        if crop_size is not None:
            # Pad if crop_size is larger than image size
            if h < crop_size[0] or w < crop_size[1]:
                padded_shape = (max(h, crop_size[0]), max(w, crop_size[1]))

                # pad RGB with mid-gray
                temp = np.full(padded_shape + (3,), 127.5)
                temp[:h, :w, :] = img_RGB
                img_RGB = temp

                # pad mask with ignore_index (-1)
                temp = np.full(padded_shape, ignore_index)
                try:
                    temp[:mask.shape[0], :mask.shape[1]] = mask
                except (ValueError, TypeError) as err:
                    raise ValueError(
                        f'{rgb_path}\nh:{h}, w:{w}, temp:{temp.shape}, mask: {mask.shape}'
                    ) from err
                mask = temp

            # Determine where to crop. h/w are the pre-padding sizes; when a
            # dimension was padded, max(..., 0) pins its offset to 0.
            if self._grid_crop:
                s_r = (random.randint(0, max(h - crop_size[0], 0)) // 8) * 8
                s_c = (random.randint(0, max(w - crop_size[1], 0)) // 8) * 8
            else:
                s_r = random.randint(0, max(h - crop_size[0], 0))
                s_c = random.randint(0, max(w - crop_size[1], 0))

            # crop
            mask = mask[s_r:s_r + crop_size[0], s_c:s_c + crop_size[1]]
            img_RGB = img_RGB[s_r:s_r + crop_size[0], s_c:s_c + crop_size[1], :]

            # Track the extent of real (non-padding) content after the crop, so
            # the max_dim centering below uses the current image, not the original.
            h, w = min(h, img_RGB.shape[0]), min(w, img_RGB.shape[1])

        # cropping big images (centered, grid-aligned offsets)
        if self._max_dim is not None:
            max_dim = self._max_dim
            s_r = (max((h - max_dim) // 2, 0) // 8) * 8
            s_c = (max((w - max_dim) // 2, 0) // 8) * 8

            mask = mask[s_r:s_r + max_dim, s_c:s_c + max_dim]
            img_RGB = img_RGB[s_r:s_r + max_dim, s_c:s_c + max_dim, :]

        t_mask = torch.tensor(mask, dtype=torch.long)
        # HWC -> CHW; note the scaling is by 256, not 255.
        t_RGB = torch.tensor(img_RGB.transpose(2, 0, 1), dtype=torch.float) / 256.0
        return t_RGB, t_mask

    @abstractmethod
    def get_img(self, index):
        """Return the sample at ``index``; implemented by subclasses."""
        pass

    def get_img_name(self, index):
        """Return the image name for ``index``: the first element if the entry is a list."""
        item = self.img_list[index]
        if isinstance(item, list):
            return item[0]
        else:
            return item

    def __len__(self):
        """Number of entries in ``img_list``."""
        return len(self.img_list)