| | import numpy as np |
| | import cv2 |
| | import albumentations as A |
| | from torch.utils.data import Dataset |
| | from .data_utils import * |
| |
|
class BaseDataset(Dataset):
    """Base dataset providing shared augmentation and sample-assembly helpers.

    Subclasses must implement ``_get_sample`` and ``__len__``; ``__getitem__``
    simply delegates to ``_get_sample``. The remaining methods build model
    inputs: augmented reference patches, a 512x512 scene with the object
    regions blanked out ("hint"), per-object box masks and a random timestep.

    The helpers ``get_bbox_from_mask``, ``expand_bbox``, ``expand_image``,
    ``expand_image_mask`` and ``pad_to_square`` are not defined here; they are
    expected to come from the module's ``.data_utils`` star import.
    """

    def __init__(self):
        super().__init__()
        # Not populated here; presumably filled in by concrete subclasses.
        self.data = []

    def __getitem__(self, idx):
        """Return the sample at ``idx`` by delegating to ``_get_sample``."""
        return self._get_sample(idx)

    def _get_sample(self, idx):
        """Build the sample at ``idx``. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_sample()"
        )

    def __len__(self):
        """Return the number of samples. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __len__()"
        )

    def aug_data_mask(self, image, mask):
        """Randomly jitter brightness/contrast and rotate an image/mask pair.

        ``image`` is cast to uint8 before augmentation. The rotation (limit
        +/-30 degrees) fills uncovered borders with a constant
        (``cv2.BORDER_CONSTANT``).

        Args:
            image: image array (assumed HxWxC in the 0-255 range -- confirm).
            mask: mask array aligned with ``image``.

        Returns:
            Tuple ``(transformed_image, transformed_mask)``.
        """
        transform = A.Compose([
            A.RandomBrightnessContrast(p=0.5),
            A.Rotate(limit=30, border_mode=cv2.BORDER_CONSTANT),
        ])
        transformed = transform(image=image.astype(np.uint8), mask=mask)
        return transformed["image"], transformed["mask"]

    def aug_patch(self, patch):
        """Lightly augment an object patch while keeping its background white.

        Pixels whose grayscale value is >= 250 are treated as background.
        After a random horizontal flip, brightness/contrast jitter and
        rotation, every pixel outside the transformed foreground mask is
        reset to 255.

        Args:
            patch: RGB image (assumed HxWx3 in the 0-255 range -- confirm).

        Returns:
            The augmented patch as a uint8 array.
        """
        gray = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
        # Foreground = anything that is not (near-)white. Kept as HxWx1 so
        # it broadcasts against the 3-channel image in the composite below.
        mask = (gray < 250).astype(np.float32)[:, :, None]

        transform = A.Compose([
            A.HorizontalFlip(p=0.2),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
            A.Rotate(limit=15, border_mode=cv2.BORDER_REPLICATE, p=0.5),
        ])
        transformed = transform(image=patch.astype(np.uint8), mask=mask)
        aug_img = transformed["image"]
        aug_mask = transformed["mask"]
        # OpenCV-backed ops can drop a trailing singleton channel; restore it
        # so the composite below is per pixel instead of a broadcast error.
        if aug_mask.ndim == 2:
            aug_mask = aug_mask[:, :, None]

        final_img = aug_img * aug_mask + 255 * (1 - aug_mask)
        return final_img.astype(np.uint8)

    def sample_timestep(self, max_step=1000):
        """Sample one integer timestep, biased toward the lower half of the range.

        With probability 0.3 the step is uniform over ``[0, max_step)``;
        otherwise it is uniform over ``[0, max_step // 2)``.

        Args:
            max_step: exclusive upper bound of the timestep range (>= 1).

        Returns:
            A 1-element integer numpy array.
        """
        if np.random.rand() < 0.3:
            step = np.random.randint(0, max_step)
        else:
            # Clamp so max_step == 1 does not call randint(0, 0), which raises.
            step = np.random.randint(0, max(max_step // 2, 1))
        return np.array([step])

    def get_patch(self, ref_image, ref_mask):
        """Extract a compact object patch and convert it to a 224x224 RGBA image.

        Pixels of ``ref_image`` outside ``ref_mask`` are set to white. The
        image is then cropped to the mask's bounding box and enlarged via
        ``expand_image_mask`` with a random ratio in {1.1, 1.2, 1.3, 1.4}.
        The RGB part is square-padded with white and resized to 224x224. The
        alpha channel is the mask scaled to 0/255, square-padded with 0 and
        resized with nearest-neighbour interpolation.

        Args:
            ref_image: reference image (assumed HxWx3 -- confirm).
            ref_mask: HxW mask with values in [0, 1].

        Returns:
            224x224x4 uint8 array (RGB + alpha).
        """
        y1, y2, x1, x2 = get_bbox_from_mask(ref_mask)

        # White-out everything outside the mask.
        ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
        masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)

        # Crop image and mask to the object's bounding box.
        masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
        ref_mask_crop = ref_mask[y1:y2, x1:x2]

        ratio = np.random.randint(11, 15) / 10
        masked_ref_image, ref_mask_crop = expand_image_mask(masked_ref_image, ref_mask_crop, ratio=ratio)

        # RGB part: square-pad with white, resize to 224x224.
        masked_ref_image = pad_to_square(masked_ref_image, pad_value=255)
        masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), (224, 224))

        # Alpha part: 0/255 mask, square-pad with 0, nearest resize keeps it binary.
        m_local = ref_mask_crop[:, :, None] * 255
        m_local = pad_to_square(m_local, pad_value=0)
        m_local = cv2.resize(m_local.astype(np.uint8), (224, 224), interpolation=cv2.INTER_NEAREST)

        return np.dstack((masked_ref_image.astype(np.uint8), m_local))

    def _construct_collage(self, image, object_0, object_1, mask_0, mask_1):
        """Assemble a two-object training sample.

        Args:
            image: target scene (assumed HxWx3 in the 0-255 range -- confirm).
            object_0, object_1: reference patches for the two objects.
            mask_0, mask_1: masks locating each object within ``image``.

        Returns:
            dict with keys, in this order:
              'jpg'        -- ``image`` square-padded, resized to 512x512,
                              scaled to [-1, 1].
              'ref0','ref1'-- augmented reference patches, 224x224, in [0, 1].
              'hint'       -- the scene with both expanded object boxes set
                              to 0, scaled to [-1, 1], plus a 4th channel
                              that is 1 inside the boxes, 0 elsewhere in the
                              scene and -1 in the square padding.
              'mask0','mask1' -- 512x512 single-channel box masks (1 inside
                              the object's box, 0 elsewhere incl. padding).
              'time_steps' -- result of ``sample_timestep()``.
              'object_num' -- always 2.
        """
        background = image.copy()
        jpg = self._collage_scene(image)

        # Random draws happen in a fixed order: ref0, ref1, box0, box1, timestep.
        ref0 = self._collage_reference(object_0)
        ref1 = self._collage_reference(object_1)

        # Box masks at the original resolution: one per object plus their union.
        box_mask0 = np.zeros(background.shape, dtype=np.float32)
        box_mask1 = np.zeros(background.shape, dtype=np.float32)
        box_mask = np.zeros(background.shape, dtype=np.float32)
        for obj_mask, own_mask in ((mask_0, box_mask0), (mask_1, box_mask1)):
            box_yyxx = get_bbox_from_mask(obj_mask)
            y1, y2, x1, x2 = expand_bbox(obj_mask, box_yyxx, ratio=[1.1, 1.2])
            background[y1:y2, x1:x2, :] = 0
            own_mask[y1:y2, x1:x2, :] = 1.0
            box_mask[y1:y2, x1:x2, :] = 1.0

        hint = np.concatenate(
            [self._collage_scene(background), self._collage_mask(box_mask)[:, :, :1]], -1
        )
        # Per-object masks treat the square padding (-1) as background (0).
        mask0 = np.maximum(self._collage_mask(box_mask0)[:, :, 0], 0)
        mask1 = np.maximum(self._collage_mask(box_mask1)[:, :, 0], 0)

        return {
            'jpg': jpg,
            'ref0': ref0,
            'ref1': ref1,
            'hint': hint,
            'mask0': mask0,
            'mask1': mask1,
            'time_steps': self.sample_timestep(),
            'object_num': 2,
        }

    def _collage_scene(self, img):
        """Pad ``img`` to square with 0, resize to 512x512 and scale to [-1, 1] (float32)."""
        img = pad_to_square(img, pad_value=0, random=False).astype(np.uint8)
        img = cv2.resize(img, (512, 512)).astype(np.float32)
        return img / 127.5 - 1.0

    def _collage_reference(self, obj):
        """Expand by a random ratio, augment, pad (white), resize to 224x224, scale to [0, 1]."""
        ratio = np.random.randint(11, 15) / 10
        obj = expand_image(obj, ratio=ratio)
        obj = self.aug_patch(obj)
        obj = pad_to_square(obj, pad_value=255, random=False)
        obj = cv2.resize(obj.astype(np.uint8), (224, 224)).astype(np.uint8)
        return obj / 255

    def _collage_mask(self, mask):
        """Pad a 0/1 box mask to square and resize to 512x512 (nearest); padding becomes -1."""
        # Pad with the sentinel 2 so padding can be told apart from 0/1 after resizing.
        mask = pad_to_square(mask, pad_value=2, random=False).astype(np.uint8)
        mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
        mask[mask == 2] = -1
        return mask
| |
|