# PICS / datasets/base.py
# Hang Zhou -- "Upload folder using huggingface_hub" (commit 0103f17, verified)
import numpy as np
import cv2
import albumentations as A
from torch.utils.data import Dataset
from .data_utils import *
class BaseDataset(Dataset):
    """Shared augmentation and sample-assembly utilities for PICS datasets.

    Subclasses fill ``self.data`` and implement ``_get_sample`` and
    ``__len__``; ``__getitem__`` simply delegates to ``_get_sample``.
    """

    def __init__(self):
        super().__init__()
        # Sample records; populated by concrete dataset subclasses.
        self.data = []

    def __getitem__(self, idx):
        return self._get_sample(idx)

    def _get_sample(self, idx):
        """Build the sample at ``idx``; implemented by each specific dataset."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_sample()"
        )

    def __len__(self):
        """Return the dataset length; implemented by each specific dataset.

        The length is set deliberately per dataset: it is how the sampling
        ratio between different datasets is adjusted.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __len__()"
        )

    def aug_data_mask(self, image, mask):
        """Jointly augment an image and its mask.

        Applies a random brightness/contrast change (p=0.5, image only) and a
        random rotation of up to +/-30 degrees with constant-fill borders,
        applied identically to image and mask so they stay aligned.

        Returns:
            (transformed_image, transformed_mask)
        """
        transform = A.Compose([
            A.RandomBrightnessContrast(p=0.5),
            A.Rotate(limit=30, border_mode=cv2.BORDER_CONSTANT),
        ])
        transformed = transform(image=image.astype(np.uint8), mask=mask)
        return transformed["image"], transformed["mask"]

    def aug_patch(self, patch):
        """Lightly augment an object patch while keeping its background white.

        A foreground mask is estimated from the patch itself (gray < 250, i.e.
        anything that is not near-white). Patch and mask are flipped/rotated
        together; afterwards every pixel outside the transformed mask is reset
        to white, so colour jitter and replicated rotation borders do not leak
        into the background.

        Args:
            patch: H x W x 3 RGB patch with values in [0, 255].

        Returns:
            The augmented patch as ``uint8``.
        """
        # Cast once up front: cv2.cvtColor rejects float64 input, and the
        # transform below needs uint8 anyway.
        patch = patch.astype(np.uint8)
        gray = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
        mask = (gray < 250).astype(np.float32)[:, :, None]
        transform = A.Compose([
            A.HorizontalFlip(p=0.2),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
            A.Rotate(limit=15, border_mode=cv2.BORDER_REPLICATE, p=0.5),
        ])
        transformed = transform(image=patch, mask=mask)
        aug_img = transformed["image"]
        aug_mask = transformed["mask"]
        if aug_mask.ndim == 2:
            # Restore the channel axis if a transform squeezed it, so the mask
            # broadcasts over the RGB channels below.
            aug_mask = aug_mask[:, :, None]
        final_img = aug_img * aug_mask + 255 * (1 - aug_mask)
        return final_img.astype(np.uint8)

    def sample_timestep(self, max_step=1000):
        """Sample one timestep index, biased towards the first half of the range.

        With probability 0.3 the step is uniform over ``[0, max_step)``;
        otherwise it is uniform over ``[0, max_step // 2)``.

        Returns:
            A 1-element integer array ``[step]``.
        """
        if np.random.rand() < 0.3:
            step = np.random.randint(0, max_step)
        else:
            # Clamp so max_step < 2 (where max_step // 2 == 0) does not make
            # randint raise "low >= high".
            step = np.random.randint(0, max(1, max_step // 2))
        return np.array([step])

    def get_patch(self, ref_image, ref_mask):
        """Extract a compact object patch and convert it to 224x224 RGBA.

        Args:
            ref_image: H x W x 3 image containing the object.
            ref_mask: H x W object mask with values in [0, 1].

        Returns:
            224 x 224 x 4 ``uint8`` array: RGB is the object on a white
            background, alpha is the mask scaled by 255.
        """
        # 1. Bounding box (y1, y2, x1, x2) of the object.
        y1, y2, x1, x2 = get_bbox_from_mask(ref_mask)

        # 2. Paint everything outside the mask white (255).
        ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
        masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)

        # 3. Crop image and mask to the box (tight mask).
        masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
        ref_mask_crop = ref_mask[y1:y2, x1:x2]

        # 4. Enlarge the canvas by a random factor in {1.1, 1.2, 1.3, 1.4}.
        ratio = np.random.randint(11, 15) / 10
        masked_ref_image, ref_mask_crop = expand_image_mask(masked_ref_image, ref_mask_crop, ratio=ratio)
        # NOTE: joint augmentation via self.aug_data_mask is currently not applied here.

        # 5. Pad to square and resize. Image padding is white, mask padding is
        #    0, and the mask uses nearest-neighbour so it is not blurred.
        masked_ref_image = pad_to_square(masked_ref_image, pad_value=255)
        masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), (224, 224))
        m_local = ref_mask_crop[:, :, None] * 255
        m_local = pad_to_square(m_local, pad_value=0)
        m_local = cv2.resize(m_local.astype(np.uint8), (224, 224), interpolation=cv2.INTER_NEAREST)

        return np.dstack((masked_ref_image.astype(np.uint8), m_local))

    def _collage_reference(self, patch):
        """Turn one object crop into a 224x224x3 reference image in [0, 1]."""
        # Enlarge the canvas by a random factor in {1.1, ..., 1.4} (see expand_image).
        ratio = np.random.randint(11, 15) / 10
        patch = expand_image(patch, ratio=ratio)
        patch = self.aug_patch(patch)
        patch = pad_to_square(patch, pad_value=255, random=False)
        patch = cv2.resize(patch.astype(np.uint8), (224, 224)).astype(np.uint8)
        return patch / 255

    @staticmethod
    def _collage_box(mask):
        """Return the (y1, y2, x1, x2) box of ``mask`` enlarged by expand_bbox."""
        box_yyxx = get_bbox_from_mask(mask)
        return expand_bbox(mask, box_yyxx, ratio=[1.1, 1.2])  # previously [1.1, 1.3]

    @staticmethod
    def _collage_mask_512(region):
        """Square-pad a {0, 1} region mask, resize it to 512x512, mark padding as -1.

        Padding is first filled with the sentinel 2, because -1 cannot survive
        the uint8 cast needed by cv2.resize. Nearest-neighbour keeps values
        exact, and the sentinel is then mapped to -1 so padded pixels are
        distinguishable from background (0) and boxes (1).
        """
        square = pad_to_square(region, pad_value=2, random=False).astype(np.uint8)
        square = cv2.resize(square, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
        square[square == 2] = -1
        return square

    def _construct_collage(self, image, object_0, object_1, mask_0, mask_1):
        """Assemble a two-object training sample.

        Args:
            image: source image, H x W x 3 with values in [0, 255].
            object_0, object_1: object crops used as reference patches.
            mask_0, mask_1: masks locating each object in ``image``.

        Returns:
            dict with keys
              'jpg':          source image, 512x512x3 float32 in [-1, 1]
              'ref0', 'ref1': reference patches, 224x224x3 in [0, 1]
              'hint':         source with both expanded object boxes blacked
                              out, normalised to [-1, 1], plus a 4th channel:
                              1 inside boxes, 0 elsewhere, -1 in square padding
              'mask0', 'mask1': per-object box masks, 512x512 in {0, 1}
              'time_steps':   1-element timestep array
              'object_num':   2
        """
        background = image.copy()

        # Target image: black square padding, 512x512, normalised to [-1, 1].
        image = pad_to_square(image, pad_value=0, random=False).astype(np.uint8)
        image = cv2.resize(image, (512, 512)).astype(np.float32)
        item = {'jpg': image / 127.5 - 1.0}

        # Reference patches; object_0 is processed before object_1 so the
        # random draws happen in a fixed order.
        item['ref0'] = self._collage_reference(object_0)
        item['ref1'] = self._collage_reference(object_1)

        # Black out each object's expanded box in the background and record
        # per-object and union box masks.
        union_region = np.zeros(background.shape, dtype=np.float32)
        object_regions = []
        for mask in (mask_0, mask_1):
            y1, y2, x1, x2 = self._collage_box(mask)
            region = np.zeros(background.shape, dtype=np.float32)
            region[y1:y2, x1:x2, :] = 1.0
            union_region[y1:y2, x1:x2, :] = 1.0
            background[y1:y2, x1:x2, :] = 0
            object_regions.append(region)

        background = pad_to_square(background, pad_value=0, random=False).astype(np.uint8)
        background = cv2.resize(background, (512, 512)).astype(np.float32)
        background = background / 127.5 - 1.0
        union_mask = self._collage_mask_512(union_region)
        item['hint'] = np.concatenate([background, union_mask[:, :, :1]], -1)

        # Per-object masks treat square padding as plain background (0).
        for key, region in zip(('mask0', 'mask1'), object_regions):
            obj_mask = self._collage_mask_512(region)[:, :, 0].copy()
            obj_mask[obj_mask == -1] = 0
            item[key] = obj_mask

        item['time_steps'] = self.sample_timestep()
        item['object_num'] = 2
        return item