File size: 7,596 Bytes
0103f17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import numpy as np
import cv2
import albumentations as A
from torch.utils.data import Dataset
from .data_utils import * 

class BaseDataset(Dataset):
    """Base class for the collage datasets in this package.

    Subclasses must implement ``_get_sample`` (build one item) and ``__len__``.
    The length is also how the relative sampling ratio between different
    datasets is adjusted. The remaining methods are shared augmentation and
    preprocessing helpers.
    """

    def __init__(self):
        super().__init__()
        self.data = []

    def __getitem__(self, idx):
        return self._get_sample(idx)

    def _get_sample(self, idx):
        """Build the item at ``idx``. Must be implemented by each dataset."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_sample()"
        )

    def __len__(self):
        """Return the dataset length. Must be implemented by each dataset.

        The length is used to adjust the ratio between different datasets.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __len__()"
        )

    def aug_data_mask(self, image, mask):
        """Jointly augment an image and its mask.

        Applies random brightness/contrast (p=0.5) and a random rotation of up
        to 30 degrees with a constant border.

        Args:
            image: image array; cast to uint8 before augmentation.
            mask: mask array aligned with ``image``.

        Returns:
            ``(augmented_image, augmented_mask)``.
        """
        transform = A.Compose([
            A.RandomBrightnessContrast(p=0.5),
            A.Rotate(limit=30, border_mode=cv2.BORDER_CONSTANT),
        ])
        out = transform(image=image.astype(np.uint8), mask=mask)
        return out["image"], out["mask"]

    def aug_patch(self, patch):
        """Augment a reference patch that sits on a white background.

        Pixels whose grayscale value is below 250 are treated as foreground.
        The patch and that foreground mask go through the same flip /
        brightness-contrast / rotation pipeline. Afterwards every pixel is
        blended towards white (255) according to the transformed mask, so
        pixels where the mask is 0 become pure white.

        Args:
            patch: RGB patch, presumably HxWx3 with values in [0, 255]
                (TODO confirm against ``expand_image`` output).

        Returns:
            The augmented patch as uint8.
        """
        # Cast once up front: cv2.cvtColor rejects float64 input, and the
        # augmentation pipeline below is fed uint8 anyway.
        patch = patch.astype(np.uint8)
        gray = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
        fg_mask = (gray < 250).astype(np.float32)[:, :, None]

        transform = A.Compose([
            A.HorizontalFlip(p=0.2),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
            A.Rotate(limit=15, border_mode=cv2.BORDER_REPLICATE, p=0.5),
        ])

        out = transform(image=patch, mask=fg_mask)
        aug_img = out["image"]
        aug_mask = out["mask"]
        final_img = aug_img * aug_mask + 255 * (1 - aug_mask)
        return final_img.astype(np.uint8)

    def sample_timestep(self, max_step=1000):
        """Sample one training timestep, biased towards the first half.

        With probability 0.3 the step is uniform over ``[0, max_step)``.
        Otherwise it is uniform over ``[0, max_step // 2)``.

        Returns:
            A 1-element integer numpy array.
        """
        if np.random.rand() < 0.3:
            high = max_step
        else:
            # Clamp so max_step == 1 does not produce an empty range.
            high = max(max_step // 2, 1)
        return np.array([np.random.randint(0, high)])

    def get_patch(self, ref_image, ref_mask):
        """Extract a compact reference patch and convert it to 224x224 RGBA.

        Args:
            ref_image: HxWx3 image containing the object.
            ref_mask: HxW mask of the object with values in [0, 1].

        Returns:
            224x224x4 uint8 array. The RGB channels hold the object on a
            white background. The 4th channel is the cropped mask * 255,
            padded with 0 and resized with nearest-neighbour interpolation.
        """
        # 1. Bounding box of the object (y1, y2, x1, x2).
        y1, y2, x1, x2 = get_bbox_from_mask(ref_mask)

        # 2. Paint everything outside the mask white (255).
        ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
        masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)

        # 3. Crop image and mask to the tight bounding box.
        masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
        ref_mask_crop = ref_mask[y1:y2, x1:x2]

        # 4. Expand patch and mask by a random ratio in {1.1, ..., 1.4}.
        ratio = np.random.randint(11, 15) / 10
        masked_ref_image, ref_mask_crop = expand_image_mask(masked_ref_image, ref_mask_crop, ratio=ratio)

        # 5. Pad to square and resize: RGB padded white, mask padded with 0.
        masked_ref_image = pad_to_square(masked_ref_image, pad_value=255)
        masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), (224, 224))

        m_local = ref_mask_crop[:, :, None] * 255
        m_local = pad_to_square(m_local, pad_value=0)
        m_local = cv2.resize(m_local.astype(np.uint8), (224, 224), interpolation=cv2.INTER_NEAREST)

        return np.dstack((masked_ref_image.astype(np.uint8), m_local))

    def _prepare_collage_reference(self, patch):
        """Expand, augment, square-pad (white) and resize a reference patch
        to 224x224, returning it scaled to [0, 1]."""
        ratio = np.random.randint(11, 15) / 10  # random expansion 1.1 .. 1.4
        patch = expand_image(patch, ratio=ratio)
        patch = self.aug_patch(patch)
        patch = pad_to_square(patch, pad_value=255, random=False)
        patch = cv2.resize(patch.astype(np.uint8), (224, 224))
        return patch / 255

    def _square_box_mask(self, box_mask, pad_fill):
        """Square-pad a 0/1 box mask, resize it to 512x512 (nearest) and set
        the padded area to ``pad_fill``."""
        # Pad with sentinel 2 (outside {0, 1}) so padded pixels are still
        # identifiable after the nearest-neighbour resize.
        box_mask = pad_to_square(box_mask, pad_value=2, random=False).astype(np.uint8)
        box_mask = cv2.resize(box_mask, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
        box_mask[box_mask == 2] = pad_fill
        return box_mask

    def _construct_collage(self, image, object_0, object_1, mask_0, mask_1):
        """Build a two-object collage training item.

        Args:
            image: source image, HxWx3.
            object_0, object_1: reference patches of the two objects.
            mask_0, mask_1: masks locating each object in ``image``.

        Returns:
            dict with, in this order:
              'jpg'          512x512x3 float32 source image in [-1, 1]
              'ref0', 'ref1' 224x224 reference patches in [0, 1]
              'hint'         512x512x4 float32: the source image with both
                             expanded object boxes zeroed, in [-1, 1], plus a
                             channel that is 1 inside the boxes, 0 elsewhere
                             and -1 on the square padding
              'mask0', 'mask1' 512x512 float32 box masks (1 inside, 0 elsewhere)
              'time_steps'   output of ``sample_timestep()``
              'object_num'   2
        """
        background = image.copy()

        image = pad_to_square(image, pad_value=0, random=False).astype(np.uint8)
        image = cv2.resize(image, (512, 512)).astype(np.float32)
        item = {'jpg': image / 127.5 - 1.0}

        # Order matters: each call draws from the global RNG.
        item['ref0'] = self._prepare_collage_reference(object_0)
        item['ref1'] = self._prepare_collage_reference(object_1)

        box_mask0 = np.zeros(background.shape, dtype=np.float64)
        box_mask1 = np.zeros(background.shape, dtype=np.float64)
        box_mask = np.zeros(background.shape, dtype=np.float64)
        for obj_mask, own_box_mask in ((mask_0, box_mask0), (mask_1, box_mask1)):
            box_yyxx = expand_bbox(obj_mask, get_bbox_from_mask(obj_mask), ratio=[1.1, 1.2])
            y1, y2, x1, x2 = box_yyxx
            background[y1:y2, x1:x2, :] = 0
            own_box_mask[y1:y2, x1:x2, :] = 1.0
            box_mask[y1:y2, x1:x2, :] = 1.0

        background = pad_to_square(background, pad_value=0, random=False).astype(np.uint8)
        background = cv2.resize(background, (512, 512)).astype(np.float32)
        background = background / 127.5 - 1.0

        # Per-object masks treat padding as background (0). The combined mask
        # marks padding with -1.
        box_mask0 = self._square_box_mask(box_mask0, pad_fill=0.0)
        box_mask1 = self._square_box_mask(box_mask1, pad_fill=0.0)
        box_mask = self._square_box_mask(box_mask, pad_fill=-1.0)

        item['hint'] = np.concatenate([background, box_mask[:, :, :1]], -1)
        item['mask0'] = box_mask0[:, :, 0].copy()
        item['mask1'] = box_mask1[:, :, 0].copy()
        item['time_steps'] = self.sample_timestep()
        item['object_num'] = 2
        return item