|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Created in September 2022 |
|
|
@author: fabrizio.guillaro |
|
|
""" |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import torch |
|
|
import random |
|
|
import cv2 |
|
|
|
|
|
|
|
|
class AbstractDataset(ABC):
    """Abstract base for datasets that yield (RGB tensor, mask tensor) pairs.

    Subclasses are expected to fill ``self.img_list`` and implement
    ``get_img``. ``_create_tensor`` loads an RGB image from disk, reconciles
    it with an optional mask, applies an optional augmentation, then
    pads/crops both arrays and converts them to torch tensors.
    """

    def __init__(self, crop_size, grid_crop: bool, max_dim=None, aug=None):
        """
        :param crop_size: (H, W) or None. H and W must be multiples of 8 if grid_crop is True.
        :param grid_crop: True: crop offsets are snapped to the 8x8 grid. False: crop anywhere.
        :param max_dim: if not None, the output is center-cropped (offsets snapped down to
                        multiples of 8) to at most max_dim x max_dim.
        :param aug: optional augmentation, called as aug(image=..., mask=...) and expected to
                    return a dict with 'image' and 'mask' entries.
        """
        self._crop_size = crop_size
        self._max_dim = max_dim
        self._grid_crop = grid_crop

        if grid_crop and crop_size is not None:
            assert crop_size[0] % 8 == 0 and crop_size[1] % 8 == 0, \
                f'crop_size must be a multiple of 8 when grid_crop is True, got {crop_size}'

        # List of samples; populated by subclasses.
        self.img_list = None
        self.aug = aug

    def _create_tensor(self, mask=None, rgb_path=None):
        """Load the RGB image at ``rgb_path`` and build the (image, mask) tensor pair.

        :param mask: 2-D array aligned with the image, or None for an all-zero mask.
                     If its shape does not match the image, a 90-degree rotation is tried
                     first, then a nearest-neighbour resize (result becomes boolean).
        :param rgb_path: path of the image, opened with PIL and converted to RGB.
        :return: (t_RGB, t_mask): t_RGB is a float tensor of shape (3, H, W) holding
                 pixel values divided by 256; t_mask is a long tensor of shape (H, W).
                 Pixels added by padding are 127.5 in the image and -1 in the mask.
        :raises ValueError: if the image cannot be loaded or the padded mask cannot be filled.
        """
        ignore_index = -1  # mask value written into padded pixels

        try:
            # Context manager guarantees the underlying file handle is closed.
            with Image.open(rgb_path) as pil_img:
                img_RGB = np.array(pil_img.convert("RGB"))
        except Exception as err:
            raise ValueError(f'error path: {rgb_path}') from err

        h, w = img_RGB.shape[0], img_RGB.shape[1]

        # --- reconcile mask with image size ---------------------------------
        if mask is None:
            mask = np.zeros((h, w))
        elif mask.shape[0] != h or mask.shape[1] != w:
            print(f'MASK MISMATCH: {rgb_path} \n h:{h}, w:{w}, mask: {mask.shape}', flush=True)
            rotated = np.ascontiguousarray(np.rot90(mask))
            if rotated.shape[0] == h and rotated.shape[1] == w:
                mask = rotated
            else:
                # Resize the ORIGINAL (unrotated) mask. cv2.resize takes dsize
                # as (width, height), so pass (w, h) to get an (h, w) result.
                mask = cv2.resize(np.uint8(mask), (w, h), interpolation=cv2.INTER_NEAREST) > 0

        # --- optional augmentation --------------------------------------------
        if self.aug is not None:
            mask = np.uint8(mask)
            dat = self.aug(image=img_RGB, mask=mask)
            # The augmentation must not change dtypes.
            assert dat['image'].dtype == img_RGB.dtype
            assert dat['mask'].dtype == mask.dtype
            img_RGB = dat['image']
            mask = dat['mask'] > 0
            h, w = img_RGB.shape[0], img_RGB.shape[1]  # augmentation may change the size
            del dat

        # --- choose the crop size ---------------------------------------------
        if self._crop_size is not None:
            crop_size = self._crop_size
        elif self._grid_crop:
            # No explicit crop: round the full image up to a multiple of 8.
            crop_size = (-(-h // 8) * 8, -(-w // 8) * 8)
        else:
            crop_size = None

        # --- pad (if needed) and random crop ----------------------------------
        if crop_size is not None:
            if h < crop_size[0] or w < crop_size[1]:
                pad_shape = (max(h, crop_size[0]), max(w, crop_size[1]))

                temp = np.full(pad_shape + (3,), 127.5)
                temp[:img_RGB.shape[0], :img_RGB.shape[1], :] = img_RGB
                img_RGB = temp

                temp = np.full(pad_shape, ignore_index)
                try:
                    temp[:mask.shape[0], :mask.shape[1]] = mask
                except Exception as err:
                    raise ValueError(f'{rgb_path}\nh:{h}, w:{w}, temp:{temp.shape}, mask: {mask.shape}') from err
                mask = temp

            # If the image was padded, the max(..., 0) makes the offset 0.
            s_r = random.randint(0, max(h - crop_size[0], 0))
            s_c = random.randint(0, max(w - crop_size[1], 0))
            if self._grid_crop:
                s_r = (s_r // 8) * 8
                s_c = (s_c // 8) * 8

            mask = mask[s_r:s_r + crop_size[0], s_c:s_c + crop_size[1]]
            img_RGB = img_RGB[s_r:s_r + crop_size[0], s_c:s_c + crop_size[1], :]

            # Extent of real (non-padded) content in the cropped arrays. Without
            # this update, h/w stay at the pre-crop size and the max_dim center
            # crop below can index past the end of the cropped arrays.
            h, w = min(h, crop_size[0]), min(w, crop_size[1])

        # --- optional center crop to max_dim ----------------------------------
        if self._max_dim is not None:
            max_dim = self._max_dim
            # Center offsets, snapped down to the 8-pixel grid.
            s_r = (max((h - max_dim) // 2, 0) // 8) * 8
            s_c = (max((w - max_dim) // 2, 0) // 8) * 8
            mask = mask[s_r:s_r + max_dim, s_c:s_c + max_dim]
            img_RGB = img_RGB[s_r:s_r + max_dim, s_c:s_c + max_dim, :]

        t_mask = torch.tensor(mask, dtype=torch.long)
        # HWC -> CHW, scaled by 1/256.
        t_RGB = torch.tensor(img_RGB.transpose(2, 0, 1), dtype=torch.float) / 256.0
        return t_RGB, t_mask

    @abstractmethod
    def get_img(self, index):
        """Return the sample at ``index``; implemented by subclasses."""
        pass

    def get_img_name(self, index):
        """Return the name of the sample at ``index``.

        Entries of ``img_list`` that are lists yield their first element
        (presumably the image path, with further items such as a mask path;
        confirm in subclasses). Any other entry is returned as-is.
        """
        item = self.img_list[index]
        if isinstance(item, list):
            return item[0]
        else:
            return item

    def __len__(self):
        """Number of samples in ``img_list`` (which subclasses must set)."""
        return len(self.img_list)
|
|
|
|
|
|