File size: 5,591 Bytes
6029b11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import json
import os
import random

from torch.utils.data import Dataset
from pycocotools.coco import COCO
from pycocotools import mask as maskUtils

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
from torchvision import transforms
from tqdm import tqdm
import pickle
import cv2
import torch
import numpy as np
import copy
from transformers import AutoProcessor
from nltk.corpus import wordnet
from bg_aug import get_bkgd
import jax
import random

# Image pipeline: PIL/ndarray -> tensor, bicubic resize to 224x224, then
# per-channel normalisation with the statistics listed below.
clip_standard_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224), interpolation=Image.BICUBIC),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
# Stand-alone tensor conversion (no resize, no normalisation).
to_tensor = transforms.ToTensor()

# Stand-alone normalisation using the same statistics as clip_standard_transform.
normalize = transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)

# Mask pipeline: tensor conversion, resize to 224x224 (default interpolation),
# then normalisation with a scalar mean of 0.5 and std of 0.26.
mask_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224)),
        transforms.Normalize(mean=0.5, std=0.26),
    ]
)

# Crop augmentation: take a random 192x192 window (224 minus a 32-pixel
# margin) and scale it back up to 224x224.
crop_aug = transforms.Compose(
    [
        transforms.RandomCrop(size=(224 - 32, 224 - 32)),
        transforms.Resize(size=(224, 224)),
    ]
)

def text_filter(text):
    """Remove "white background" phrasing from a caption and append a period.

    Every phrase in the table below is deleted from ``text``, in the listed
    order. Any "white background" that survives is then shortened to
    "background", and a trailing '.' is always appended (even when the input
    is empty or already ends with a period).

    Args:
        text: Caption string to clean.

    Returns:
        The cleaned caption, terminated with '.'.
    """
    # The removals are applied sequentially, so the order of this table is
    # part of the behaviour and must not be rearranged.
    background_phrases = (
        ' with a white background',
        ' with white background',
        ' next to a white background',
        ' over a white background',
        ' is cut out of a white background',
        ' across a white background',
        ' on a white background',
        ' sticking out of a white background',
        ' in the middle of a white background',
        ' on white background',
        ' in a white background',
        ' and a white background',
        ' and white background',
        ' in front of a white background',
        ' on top of a white background',
        ' against a white background',
        'a white background with ',
        ' and has a white background',
    )
    for phrase in background_phrases:
        text = text.replace(phrase, '')

    # Fallback for wordings not covered by the table: keep the sentence but
    # drop the colour.
    return text.replace('white background', 'background') + '.'

def crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, scale=1.5):
    """Cut a window around a bounding box out of an image and its mask.

    The box is first expanded to a square whose side is the longer of its
    width and height (centred on the original box), then enlarged by
    ``scale``. The resulting window is clipped to the image bounds, so the
    returned crop is smaller than the requested square when the window
    reaches past an image edge.

    Args:
        image: Array indexed as ``[row, col, channel]``.
        bbox_xywh: Box as ``(x, y, width, height)``; values are truncated
            with ``int()``. A width or height below 1 is treated as 1.
        bi_mask: Array indexed as ``[row, col]``; it is sliced with the same
            window as ``image``.
        scale: Side-length multiplier applied to the squared box.

    Returns:
        ``(image_crop, mask_crop)`` - slices of the inputs (NumPy basic
        slicing, so they are views rather than copies).
    """
    tl_x = int(bbox_xywh[0])
    tl_y = int(bbox_xywh[1])
    w = max(int(bbox_xywh[2]), 1)
    h = max(int(bbox_xywh[3]), 1)
    image_h, image_w = image.shape[:2]

    # Grow the shorter side so the window is square, keeping the box centred.
    r = max(h, w)
    tl_x -= (r - w) / 2
    tl_y -= (r - h) / 2
    half_scale = (scale - 1.0) / 2

    def _clamp(value, upper):
        # Slice bounds must stay within [0, upper]; a negative bound would
        # otherwise be interpreted as "count from the end".
        return min(max(int(value), 0), upper)

    # Slice ends are exclusive, so the right/bottom limits are the full image
    # width/height (clamping to size - 1 would drop the last column/row).
    w_l = _clamp(tl_x - half_scale * r, image_w)
    w_r = _clamp(tl_x + (1 + half_scale) * r, image_w)
    h_t = _clamp(tl_y - half_scale * r, image_h)
    h_b = _clamp(tl_y + (1 + half_scale) * r, image_h)

    return image[h_t: h_b, w_l: w_r, :], bi_mask[h_t: h_b, w_l: w_r]

def masked_crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, crop_scale=1.0, masked_color=[255, 255, 255]):
    """Crop around a bounding box and paint everything outside the mask.

    The image and mask are padded by 600 pixels on every side before
    cropping, so that a crop window reaching past the original borders (by up
    to 600 pixels) is not clipped and keeps its shape.

    Args:
        image: Array indexed as ``[row, col, channel]``.
        bbox_xywh: Box as ``(x, y, width, height)`` in the coordinates of the
            unpadded image. It is not modified.
        bi_mask: Array indexed as ``[row, col]``; entries equal to 0 mark
            pixels to be painted over.
        crop_scale: Forwarded to ``crop`` as its ``scale`` argument.
        masked_color: Value written to every cropped pixel whose mask is 0.

    Returns:
        ``(cropped_image, cropped_mask)``.
    """
    pad = 600
    image = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=255)
    bi_mask = np.pad(bi_mask, ((pad, pad), (pad, pad)), "constant", constant_values=0)

    # Shift the box into padded coordinates on a private copy: an in-place
    # `+=` on the argument would silently move the caller's bbox on every
    # call (and fails outright when the caller passes a plain list).
    shifted_bbox = np.array(bbox_xywh)
    shifted_bbox[:2] += pad

    cropped_image, cropped_mask = crop(image, shifted_bbox, bi_mask, crop_scale)
    # cropped_image is a slice of the padded copy made above, so this
    # assignment never touches the caller's image.
    cropped_image[cropped_mask == 0] = masked_color
    return cropped_image, cropped_mask

class ImageNet_Masked(Dataset):
    """Annotation-driven dataset whose item access writes a mask-overlay image.

    Annotations are loaded from a JSON file and shuffled once at
    construction. Each annotation is read positionally: index 2 is the image
    path, index 3 the encoded mask handed to ``maskUtils.decode``, and index
    6 the caption text.

    ``__getitem__`` does not return a sample; it saves a visualisation to
    disk and returns ``None``.
    """

    def __init__(self, ann_file="M_ImageNet_top_460k.json",  masked_color=[255, 255, 255]):
        """Load and shuffle the annotation list.

        Args:
            ann_file: Path to the JSON annotation file.
            masked_color: Stored on the instance; not used by the methods
                defined in this class.
        """
        self.masked_color = masked_color
        # Context manager so the file handle is closed deterministically.
        with open(ann_file, 'r') as ann_fp:
            self.anns_list = json.load(ann_fp)
        random.shuffle(self.anns_list)
        self.crop_scale = 1.5
        self.transform = clip_standard_transform
        self.res = 224
        self.blur = 10.0

    def __len__(self):
        """Return the number of annotations."""
        return len(self.anns_list)

    def __getitem__(self, index):
        """Write a green mask overlay of annotation ``index`` to disk.

        The output goes to the image path with "imagenet-21k/images"
        replaced by "visual_train_c"; the file name is prefixed with the
        sanitised caption (including the first WordNet lemma of the
        category) followed by an underscore.

        Args:
            index: Position in the shuffled annotation list.

        Returns:
            None.

        Raises:
            OSError: If the image cannot be read by ``cv2.imread``.
        """
        cv2.ocl.setUseOpenCL(False)
        cv2.setNumThreads(0)
        ann = self.anns_list[index]
        # TODO: change list to dict key.
        img_pth = ann[2]
        mask = ann[3]
        text = ann[6]

        image = cv2.imread(img_pth)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the path rather than inside cvtColor.
            raise OSError(f"cv2.imread could not read image: {img_pth}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        binary_mask = maskUtils.decode(mask)

        # The fourth '/'-separated path component is taken as the category
        # id: one prefix character followed by the numeric WordNet offset.
        cat_word = img_pth.split("/")[3]
        synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
        first_lemma = synset.lemmas()[0].name()
        # Make the caption usable as part of a file name.
        text = text.replace(".", f", probably {first_lemma}").replace(" ", "_").replace("/", "_").replace("\\", "_")

        # Blend masked pixels 50/50 with pure green.
        foreground = binary_mask == 1
        image[foreground] = (0.5 * image[foreground] + 0.5 * np.array([0, 255, 0])).astype(np.uint8)

        out_dir, file_name = os.path.split(img_pth.replace("imagenet-21k/images", "visual_train_c"))
        os.makedirs(out_dir, exist_ok=True)
        Image.fromarray(image).save(out_dir + f"/{text}_" + file_name)

if __name__ == "__main__":
    # Visit every annotation once; each item access runs __getitem__, whose
    # result is discarded.
    dataset = ImageNet_Masked()
    for idx in tqdm(range(len(dataset))):
        _ = dataset[idx]