File size: 5,288 Bytes
5f0437a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
from abc import ABC, abstractmethod
from PIL import Image
import numpy as np
import torch
import random
import cv2
class AbstractDataset(ABC):
    """Base class for datasets that yield (RGB tensor, mask tensor) pairs.

    Subclasses implement :meth:`get_img` and are expected to fill ``self.img_list``
    (used by :meth:`get_img_name` and ``len()``). :meth:`_create_tensor` performs the
    shared loading, mask fixing, augmentation, cropping and tensor conversion.
    """

    def __init__(self, crop_size, grid_crop: bool, max_dim=None, aug=None):
        """
        :param crop_size: (H, W) or None. H and W must be multiples of 8 if grid_crop is True.
        :param grid_crop: if True, random crop offsets are aligned to the 8x8 grid;
            if False, the crop can start anywhere.
        :param max_dim: if the image is larger than max_dim x max_dim, a roughly
            central max_dim x max_dim region (offsets aligned to 8) is kept.
        :param aug: optional augmentation callable, invoked as ``aug(image=..., mask=...)``
            and expected to return a dict with ``'image'`` and ``'mask'`` entries.
        """
        self._crop_size = crop_size
        self._max_dim = max_dim
        self._grid_crop = grid_crop
        if grid_crop and crop_size is not None:
            assert crop_size[0] % 8 == 0 and crop_size[1] % 8 == 0, \
                'crop_size must be a multiple of 8 when grid_crop is True'
        # List of dataset items; left as None here (subclasses presumably populate it).
        self.img_list = None
        self.aug = aug

    def _create_tensor(self, mask=None, rgb_path=None):
        """Load an image and its mask, augment, crop, and convert both to tensors.

        :param mask: numpy mask matching the image's height/width, or None
            (an all-zero mask is used). A mismatched mask is rotated by 90 degrees
            or resized to fit, see :meth:`_match_mask_to_image`.
        :param rgb_path: path of the image to load (converted to 3-channel RGB).
        :return: ``(t_RGB, t_mask)``: a float tensor of shape (3, H, W) holding the
            pixel values divided by 256, and a long tensor of shape (H, W) where
            pixels introduced by padding are -1.
        :raises ValueError: if the image cannot be loaded or the mask cannot be padded.
        """
        ignore_index = -1

        img_RGB = self._load_rgb(rgb_path)
        h, w = img_RGB.shape[0], img_RGB.shape[1]

        if mask is None:
            mask = np.zeros((h, w))
        else:
            mask = self._match_mask_to_image(mask, h, w, rgb_path)

        if self.aug is not None:
            img_RGB, mask = self._augment(img_RGB, mask)
            h, w = img_RGB.shape[0], img_RGB.shape[1]

        crop_size = self._resolve_crop_size(h, w)
        if crop_size is not None:
            img_RGB, mask = self._pad_and_crop(img_RGB, mask, crop_size, ignore_index, rgb_path)

        if self._max_dim is not None:
            # Uses the current (possibly already cropped) size, not the original one.
            img_RGB, mask = self._center_crop(img_RGB, mask, self._max_dim)

        t_mask = torch.tensor(mask, dtype=torch.long)
        # NOTE(review): scaling is 1/256 (not 1/255); kept as-is since trained
        # models presumably depend on it.
        t_RGB = torch.tensor(img_RGB.transpose(2, 0, 1), dtype=torch.float) / 256.0
        return t_RGB, t_mask

    @staticmethod
    def _load_rgb(rgb_path):
        """Read ``rgb_path`` as an (H, W, 3) RGB numpy array; raise ValueError on failure."""
        try:
            with Image.open(rgb_path) as img:
                return np.array(img.convert("RGB"))
        except Exception as err:
            raise ValueError(f'error path: {rgb_path}') from err

    @staticmethod
    def _match_mask_to_image(mask, h, w, rgb_path=None):
        """Return ``mask`` adjusted so its first two dimensions are (h, w).

        A matching mask is returned unchanged. Otherwise the mask is rotated by
        90 degrees if that makes it fit, else it is resized (nearest neighbour)
        and binarised with ``> 0``.
        """
        if mask.shape[0] == h and mask.shape[1] == w:
            return mask
        # a small number of images have a mask that mismatches the size of the image
        print(f'MASK MISMATCH: {rgb_path} \n h:{h}, w:{w}, mask: {mask.shape}', flush=True)
        rotated = np.rot90(mask)
        if rotated.shape[0] == h and rotated.shape[1] == w:
            return np.ascontiguousarray(rotated)
        # Resize the unrotated mask; cv2 expects dsize as (width, height).
        return cv2.resize(np.uint8(mask), (w, h), interpolation=cv2.INTER_NEAREST) > 0

    def _augment(self, img_RGB, mask):
        """Apply ``self.aug`` jointly to image and mask; return (image, boolean mask)."""
        mask = np.uint8(mask)
        dat = self.aug(image=img_RGB, mask=mask)
        # The augmentation must not change dtypes.
        assert dat['image'].dtype == img_RGB.dtype
        assert dat['mask'].dtype == mask.dtype
        return dat['image'], dat['mask'] > 0

    def _resolve_crop_size(self, h, w):
        """Return the (H, W) crop to apply to an h x w image, or None for no crop/pad."""
        if self._crop_size is not None:
            return self._crop_size
        if self._grid_crop:
            # smallest 8x8-grid-aligned size that contains the image (ceil to multiple of 8)
            return (-(-h // 8) * 8, -(-w // 8) * 8)
        return None  # use entire image: no crop, no pad

    def _pad_and_crop(self, img_RGB, mask, crop_size, ignore_index=-1, rgb_path=None):
        """Pad (when smaller than ``crop_size``) and randomly crop image and mask.

        Padded image pixels are 127.5 (this makes the image a float64 array) and
        padded mask pixels are ``ignore_index``. With grid cropping the crop
        offsets are multiples of 8.

        :raises ValueError: if the mask cannot be copied into the padded buffer.
        """
        h, w = img_RGB.shape[0], img_RGB.shape[1]
        crop_h, crop_w = crop_size[0], crop_size[1]

        if h < crop_h or w < crop_w:
            padded_h, padded_w = max(h, crop_h), max(w, crop_w)
            temp = np.full((padded_h, padded_w, 3), 127.5)
            temp[:h, :w, :] = img_RGB
            img_RGB = temp
            temp = np.full((padded_h, padded_w), ignore_index)
            try:
                temp[:mask.shape[0], :mask.shape[1]] = mask
            except Exception as err:
                raise ValueError(f'{rgb_path}\nh:{h}, w:{w}, temp:{temp.shape}, mask: {mask.shape}') from err
            mask = temp

        # Offsets are drawn from the unpadded size; when padding happened they are 0.
        s_r = random.randint(0, max(h - crop_h, 0))
        s_c = random.randint(0, max(w - crop_w, 0))
        if self._grid_crop:
            s_r, s_c = (s_r // 8) * 8, (s_c // 8) * 8

        img_RGB = img_RGB[s_r:s_r + crop_h, s_c:s_c + crop_w, :]
        mask = mask[s_r:s_r + crop_h, s_c:s_c + crop_w]
        return img_RGB, mask

    @staticmethod
    def _center_crop(img_RGB, mask, max_dim):
        """Keep a roughly central, 8-aligned region of at most max_dim x max_dim."""
        h, w = img_RGB.shape[0], img_RGB.shape[1]
        s_r = (max((h - max_dim) // 2, 0) // 8) * 8
        s_c = (max((w - max_dim) // 2, 0) // 8) * 8
        return (img_RGB[s_r:s_r + max_dim, s_c:s_c + max_dim, :],
                mask[s_r:s_r + max_dim, s_c:s_c + max_dim])

    @abstractmethod
    def get_img(self, index):
        """Return the dataset item at ``index`` (implemented by subclasses)."""
        pass

    def get_img_name(self, index):
        """Return the name of item ``index``: the first element if the entry is a list, else the entry."""
        item = self.img_list[index]
        if isinstance(item, list):
            return item[0]
        return item

    def __len__(self):
        return len(self.img_list)
|