# detect/TruFor_train_test/dataset/AbstractDataset.py
# Uploaded by DaniilOr using huggingface_hub (commit 5f0437a, verified)
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
from abc import ABC, abstractmethod
from PIL import Image
import numpy as np
import torch
import random
import cv2
class AbstractDataset(ABC):
    """Base class for datasets that turn an RGB image (plus optional mask) into tensors.

    Subclasses must implement ``get_img`` and populate ``self.img_list``.
    ``_create_tensor`` loads the image, fixes mismatched mask sizes, applies the
    optional augmentation, pads/crops to ``crop_size`` and centre-crops to
    ``max_dim``.
    """

    def __init__(self, crop_size, grid_crop: bool, max_dim=None, aug=None):
        """
        :param crop_size: (H, W) or None. H and W must be multiples of 8 if grid_crop==True.
        :param grid_crop: True: crop offsets snapped to the 8x8 grid. False: crop anywhere.
        :param max_dim: if the image is bigger than this size, it is centre-cropped
        :param aug: augmentation callable, invoked as ``aug(image=..., mask=...)``
        """
        self._crop_size = crop_size
        self._max_dim = max_dim
        self._grid_crop = grid_crop
        if grid_crop and crop_size is not None:
            assert crop_size[0] % 8 == 0 and crop_size[1] % 8 == 0
        self.img_list = None  # must be filled in by subclasses
        self.aug = aug

    @staticmethod
    def _load_rgb(rgb_path):
        """Read ``rgb_path`` and return it as a numpy array converted to RGB.

        The file handle is closed before returning.
        Raises ValueError (chained to the underlying error) if it cannot be read.
        """
        try:
            with Image.open(rgb_path) as img:
                return np.array(img.convert("RGB"))
        except Exception as err:
            raise ValueError(f'error path: {rgb_path}') from err

    @staticmethod
    def _fit_mask_to_image(mask, h, w, rgb_path=None):
        """Return ``mask`` adjusted so its first two dimensions are (h, w).

        A small number of images have a mask whose size mismatches the image.
        A 90-degree rotation is tried first. If that still does not match, the
        ORIGINAL (unrotated) mask is resized with nearest-neighbour
        interpolation and binarised with ``> 0``.
        """
        print(f'MASK MISMATCH: {rgb_path} \n h:{h}, w:{w}, mask: {mask.shape}', flush=True)
        rotated = np.ascontiguousarray(np.rot90(mask))
        if rotated.shape[0] == h and rotated.shape[1] == w:
            return rotated
        # cv2.resize expects dsize as (width, height), not (rows, cols).
        return cv2.resize(np.uint8(mask), (w, h), interpolation=cv2.INTER_NEAREST) > 0

    def _pad_and_random_crop(self, img_RGB, mask, crop_size, ignore_index=-1, rgb_path=None):
        """Pad (if needed) and randomly crop ``img_RGB``/``mask`` to ``crop_size``.

        The image is padded with 127.5 and the mask with ``ignore_index``, both
        at the bottom/right. Crop offsets are drawn with ``random.randint``
        (rows first, then columns) and snapped to multiples of 8 when
        ``self._grid_crop`` is set.
        """
        h, w = img_RGB.shape[0], img_RGB.shape[1]
        if h < crop_size[0] or w < crop_size[1]:
            out_h, out_w = max(h, crop_size[0]), max(w, crop_size[1])
            padded_img = np.full((out_h, out_w, 3), 127.5)
            padded_img[:h, :w, :] = img_RGB
            img_RGB = padded_img
            temp = np.full((out_h, out_w), ignore_index)
            try:
                temp[:mask.shape[0], :mask.shape[1]] = mask
            except Exception as err:
                raise ValueError(f'{rgb_path}\nh:{h}, w:{w}, temp:{temp.shape}, mask: {mask.shape}') from err
            mask = temp

        # Offsets are drawn over the unpadded extent; padding only occurs along
        # axes where the image was smaller than the crop, so those give offset 0.
        s_r = random.randint(0, max(h - crop_size[0], 0))
        s_c = random.randint(0, max(w - crop_size[1], 0))
        if self._grid_crop:
            s_r = (s_r // 8) * 8
            s_c = (s_c // 8) * 8

        mask = mask[s_r:s_r + crop_size[0], s_c:s_c + crop_size[1]]
        img_RGB = img_RGB[s_r:s_r + crop_size[0], s_c:s_c + crop_size[1], :]
        return img_RGB, mask

    @staticmethod
    def _center_crop(img_RGB, mask, max_dim, h, w):
        """Crop ``img_RGB``/``mask`` to at most ``max_dim`` x ``max_dim`` around the centre.

        ``h``/``w`` are the dimensions of the real image content (before any
        padding). They are clamped to the current array size: if an earlier
        random crop already shrank the arrays, centring on the pre-crop size
        could start the window past the end and produce an empty or truncated
        result. Offsets are snapped to multiples of 8.
        """
        eff_h = min(h, img_RGB.shape[0])
        eff_w = min(w, img_RGB.shape[1])
        s_r = (max((eff_h - max_dim) // 2, 0) // 8) * 8
        s_c = (max((eff_w - max_dim) // 2, 0) // 8) * 8
        mask = mask[s_r:s_r + max_dim, s_c:s_c + max_dim]
        img_RGB = img_RGB[s_r:s_r + max_dim, s_c:s_c + max_dim, :]
        return img_RGB, mask

    def _create_tensor(self, mask=None, rgb_path=None):
        """Load ``rgb_path`` and return ``(t_RGB, t_mask)`` tensors.

        :param mask: array-like mask for the image, or None for an all-zero mask
        :param rgb_path: path of the RGB image to load
        :return: t_RGB as float tensor (C, H, W) scaled by 1/256, and t_mask as a
            long tensor (H, W); padded mask pixels hold -1 (ignore index).
        :raises ValueError: if the image cannot be read or the mask cannot be padded.
        """
        ignore_index = -1
        img_RGB = self._load_rgb(rgb_path)
        h, w = img_RGB.shape[0], img_RGB.shape[1]

        if mask is None:
            mask = np.zeros((h, w))
        elif mask.shape[0] != h or mask.shape[1] != w:
            mask = self._fit_mask_to_image(mask, h, w, rgb_path)

        # augmentation
        if self.aug is not None:
            mask = np.uint8(mask)
            dat = self.aug(image=img_RGB, mask=mask)
            assert dat['image'].dtype == img_RGB.dtype
            assert dat['mask'].dtype == mask.dtype
            img_RGB = dat['image']
            mask = dat['mask'] > 0
            h, w = img_RGB.shape[0], img_RGB.shape[1]  # aug may change the size
            del dat

        # cropping
        if self._crop_size is not None:
            crop_size = self._crop_size
        elif self._grid_crop:
            # smallest 8x8-grid-aligned size that contains the image (ceil to multiple of 8)
            crop_size = (-(-h // 8) * 8, -(-w // 8) * 8)
        else:
            crop_size = None  # use the entire image: no crop, no pad

        if crop_size is not None:
            img_RGB, mask = self._pad_and_random_crop(img_RGB, mask, crop_size, ignore_index, rgb_path)

        # cropping big images
        if self._max_dim is not None:
            img_RGB, mask = self._center_crop(img_RGB, mask, self._max_dim, h, w)

        t_mask = torch.tensor(mask, dtype=torch.long)
        t_RGB = torch.tensor(img_RGB.transpose(2, 0, 1), dtype=torch.float) / 256.0
        return t_RGB, t_mask

    @abstractmethod
    def get_img(self, index):
        pass

    def get_img_name(self, index):
        """Return the image name for ``index``: the first element if the entry is a list."""
        item = self.img_list[index]
        if isinstance(item, list):
            return item[0]
        return item

    def __len__(self):
        return len(self.img_list)