chouss committed
Commit 6029b11 · verified · 1 parent: df44f6a

Uploading folder contents

__pycache__/imagenet_s_test.cpython-39.pyc ADDED
Binary file (4.04 kB)

__pycache__/mask_image_test.cpython-310.pyc ADDED
Binary file (12.1 kB)

__pycache__/mask_image_test.cpython-39.pyc ADDED
Binary file (14.2 kB)
alpha_grit.py ADDED
@@ -0,0 +1,130 @@
+ import json
+ import os
+ import sys
+ import copy
+ import shutil
+ import pickle
+ import random
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from pycocotools import mask as maskUtils
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from tqdm import tqdm
+
+ PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]
+
+ def get_file(url):
+     return  # TODO: return the raw file bytes for `url` from the local directory
+
+ clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((224, 224), interpolation=Image.BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ hi_clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336), interpolation=Image.BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ res_clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336), interpolation=Image.BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((224, 224)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ hi_mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ res_mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ def crop_center(img, croph, cropw):
+     h, w = img.shape[:2]
+     starth = h // 2 - (croph // 2)
+     startw = w // 2 - (cropw // 2)
+     return img[starth:starth + croph, startw:startw + cropw, :]
+
+ class Alpha_GRIT(Dataset):
+     def __init__(self, ids_file='grit_1m_ids.pkl', root_pth='grit-1m/', common_pair=0.0, hi_res=False, subnum=None):
+         if subnum is not None:
+             self.ids = pickle.load(open(ids_file, 'rb'))[:subnum]
+         else:
+             self.ids = pickle.load(open(ids_file, 'rb'))
+         self.root_pth = root_pth
+         self.with_common_pair_prop = common_pair
+         if hi_res:
+             self.mask_transform = res_mask_transform
+             self.clip_standard_transform = res_clip_standard_transform
+         else:
+             self.mask_transform = mask_transform
+             self.clip_standard_transform = clip_standard_transform
+
+     def __len__(self):
+         return len(self.ids)
+
+     def __getitem__(self, index):
+         ann_id = self.ids[index]
+         ann = json.loads(get_file(self.root_pth + str(ann_id) + '.json'))
+         image_data = get_file(self.root_pth + str(ann_id) + '.jpg')
+         img = np.frombuffer(image_data, dtype=np.uint8)
+         img = cv2.imdecode(img, cv2.IMREAD_COLOR)
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         ref_exps = ann['ref_exps']
+         # randomly choose a single referring expression and its corresponding mask
+         choice = random.randint(0, len(ref_exps) - 1)
+         ref_exp = ref_exps[choice]
+         text = ann['caption'][int(ref_exp[0]): int(ref_exp[1])]
+         mask = maskUtils.decode(ann['seudo_masks'][choice])
+         if mask.shape != img.shape[:2]:
+             # image/mask shape mismatch: assume the stored image is rotated
+             img = np.rot90(img)
+         rgba = np.concatenate((img, np.expand_dims(mask, axis=-1)), axis=-1)
+         h, w = rgba.shape[:2]
+         choice = 0  # always pad to square; the center-crop branch below is disabled
+         if choice == 0:
+             if max(h, w) == w:
+                 pad = (w - h) // 2
+                 l, r = pad, w - h - pad
+                 rgba = np.pad(rgba, ((l, r), (0, 0), (0, 0)), 'constant', constant_values=0)
+             else:
+                 pad = (h - w) // 2
+                 l, r = pad, h - w - pad
+                 rgba = np.pad(rgba, ((0, 0), (l, r), (0, 0)), 'constant', constant_values=0)
+         else:
+             if min(h, w) == h:
+                 rgba = crop_center(rgba, h, h)
+             else:
+                 rgba = crop_center(rgba, w, w)
+         rgb = rgba[:, :, :-1]
+         mask = rgba[:, :, -1]
+         image_torch = self.clip_standard_transform(rgb)
+
+         choice = random.random()
+         if choice >= self.with_common_pair_prop:
+             mask_torch = self.mask_transform(mask * 255)
+             return image_torch, mask_torch, text
+         else:  # with probability `common_pair`, pair the full caption with an all-ones mask
+             mask_torch = self.mask_transform(np.ones_like(mask) * 255)
+             return image_torch, mask_torch, ann['caption']
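For context, a minimal sketch of how Alpha_GRIT is typically consumed once get_file is implemented; the byte-reading get_file below and the DataLoader settings are assumptions, not part of this commit:

def get_file(url):
    # assumed implementation: `url` is a local path under grit-1m/, returned as raw bytes
    with open(url, 'rb') as f:
        return f.read()

from torch.utils.data import DataLoader
dataset = Alpha_GRIT(ids_file='grit_1m_ids.pkl', root_pth='grit-1m/', common_pair=0.1)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)
images, masks, texts = next(iter(loader))
# images: (64, 3, 224, 224); masks: (64, 1, 224, 224); texts: list of 64 strings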
imagenet_s_test.py ADDED
@@ -0,0 +1,144 @@
+ import json
+ import os
+ import sys
+ import copy
+ import shutil
+ import pickle
+ import random
+
+ import cv2
+ import numpy as np
+ import torch
+ from nltk.corpus import wordnet
+ from PIL import Image
+ from pycocotools import mask as maskUtils
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from tqdm import tqdm
+
+ PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]
+
+ clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((224, 224), interpolation=Image.BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ hi_clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336), interpolation=Image.BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ res_clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336), interpolation=Image.BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((224, 224)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ hi_mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ res_mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ def crop_center(img, croph, cropw):
+     h, w = img.shape[:2]
+     starth = h // 2 - (croph // 2)
+     startw = w // 2 - (cropw // 2)
+     return img[starth:starth + croph, startw:startw + cropw, :]
+
+ class Imagenet_S(Dataset):
+     def __init__(self, ann_file='data/imagenet_s/imagenet_919.json', hi_res=False, all_one=False):
+         self.anns = json.load(open(ann_file, 'r'))
+         self.root_pth = 'data/imagenet_s/'
+         cats = []
+         for ann in self.anns:
+             if ann['category_word'] not in cats:
+                 cats.append(ann['category_word'])
+             ann['cat_index'] = cats.index(ann['category_word'])
+         # map each WordNet offset (e.g. 'n01440764') to its first lemma as the class name
+         self.classes = []
+         for cat_word in cats:
+             synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
+             synonyms = [x.name() for x in synset.lemmas()]
+             self.classes.append(synonyms[0])
+
+         self.choice = "center_crop"
+         if hi_res:
+             self.mask_transform = res_mask_transform
+             self.clip_standard_transform = res_clip_standard_transform
+         else:
+             self.mask_transform = mask_transform
+             self.clip_standard_transform = clip_standard_transform
+
+         self.all_one = all_one
+
+     def __len__(self):
+         return len(self.anns)
+
+     def __getitem__(self, index):
+         ann = self.anns[index]
+         image = cv2.imread(self.root_pth + ann['image_pth'])
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         mask = maskUtils.decode(ann['mask'])
+         rgba = np.concatenate((image, np.expand_dims(mask, axis=-1)), axis=-1)
+         h, w = rgba.shape[:2]
+
+         if self.choice == "padding":
+             if max(h, w) == w:
+                 pad = (w - h) // 2
+                 l, r = pad, w - h - pad
+                 rgba = np.pad(rgba, ((l, r), (0, 0), (0, 0)), 'constant', constant_values=0)
+             else:
+                 pad = (h - w) // 2
+                 l, r = pad, h - w - pad
+                 rgba = np.pad(rgba, ((0, 0), (l, r), (0, 0)), 'constant', constant_values=0)
+         else:
+             if min(h, w) == h:
+                 rgba = crop_center(rgba, h, h)
+             else:
+                 rgba = crop_center(rgba, w, w)
+         rgb = rgba[:, :, :-1]
+         mask = rgba[:, :, -1]
+         image_torch = self.clip_standard_transform(rgb)
+         # tight bounding box of the mask (computed but currently unused)
+         bi_mask = mask == 1
+         h, w = bi_mask.shape[-2:]
+         in_height = np.max(bi_mask, axis=-1)
+         in_height_coords = np.max(bi_mask, axis=-1) * np.arange(h)
+         b_e = in_height_coords.max()
+         in_height_coords = in_height_coords + h * (~in_height)
+         t_e = in_height_coords.min()
+         in_width = np.max(bi_mask, axis=-2)
+         in_width_coords = np.max(bi_mask, axis=-2) * np.arange(w)
+         r_e = in_width_coords.max()
+         in_width_coords = in_width_coords + w * (~in_width)
+         l_e = in_width_coords.min()
+         if self.all_one:
+             mask_torch = self.mask_transform(np.ones_like(mask) * 255)
+         else:
+             mask_torch = self.mask_transform(mask * 255)
+
+         return image_torch, mask_torch, ann['cat_index']
+
+ if __name__ == "__main__":
+     data = Imagenet_S()
+     for i in tqdm(range(len(data))):
+         data[i]
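A quick shape sanity check for the benchmark above; it only assumes the file's default annotation path and image directory exist locally:

data = Imagenet_S(hi_res=False)
image, mask, label = data[0]
print(image.shape)          # torch.Size([3, 224, 224])
print(mask.shape)           # torch.Size([1, 224, 224])
print(data.classes[label])  # WordNet lemma for the class label (919 classes, per imagenet_919.json)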
mask_image.py ADDED
@@ -0,0 +1,136 @@
+ import json
+ import os
+ import copy
+ import pickle
+ import random
+
+ import cv2
+ import numpy as np
+ import torch
+ from nltk.corpus import wordnet
+ from PIL import Image
+ from PIL import ImageFile
+ from pycocotools import mask as maskUtils
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from tqdm import tqdm
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ Image.MAX_IMAGE_PIXELS = None
+
+ clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((224, 224), interpolation=Image.BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+ to_tensor = transforms.ToTensor()
+
+ normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+ mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((224, 224)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ crop_aug = transforms.Compose([
+     transforms.RandomCrop((224 - 32, 224 - 32)),
+     transforms.Resize((224, 224)),
+ ])
+
+ def text_filter(text):
+     # strip boilerplate "white background" phrasing from generated captions
+     text = text.replace(' with a white background', '')
+     text = text.replace(' with white background', '')
+     text = text.replace(' next to a white background', '')
+     text = text.replace(' over a white background', '')
+     text = text.replace(' is cut out of a white background', '')
+     text = text.replace(' across a white background', '')
+     text = text.replace(' on a white background', '')
+     text = text.replace(' sticking out of a white background', '')
+     text = text.replace(' in the middle of a white background', '')
+     text = text.replace(' on white background', '')
+     text = text.replace(' in a white background', '')
+     text = text.replace(' and a white background', '')
+     text = text.replace(' and white background', '')
+     text = text.replace(' in front of a white background', '')
+     text = text.replace(' on top of a white background', '')
+     text = text.replace(' against a white background', '')
+     text = text.replace('a white background with ', '')
+     text = text.replace(' and has a white background', '')
+     text = text.replace('white background', 'background')
+     text = text + '.'
+     return text
+
+ def crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, scale=1.5):
+     tl_x = int(bbox_xywh[0])
+     tl_y = int(bbox_xywh[1])
+     w = int(bbox_xywh[2]) if int(bbox_xywh[2]) > 0 else 1
+     h = int(bbox_xywh[3]) if int(bbox_xywh[3]) > 0 else 1
+     image_h, image_w = image.shape[:2]
+
+     # expand the box to a square of side r, then scale it, clamped to the image bounds
+     r = max(h, w)
+     tl_x -= (r - w) / 2
+     tl_y -= (r - h) / 2
+     half_scale = (scale - 1.0) / 2
+     w_l = int(tl_x - half_scale * r) if (tl_x - half_scale * r) > 0 else 0
+     w_r = int(tl_x + (1 + half_scale) * r) if (tl_x + (1 + half_scale) * r) < image_w else image_w - 1
+     h_t = int(tl_y - half_scale * r) if (tl_y - half_scale * r) > 0 else 0
+     h_b = int(tl_y + (1 + half_scale) * r) if (tl_y + (1 + half_scale) * r) < image_h else image_h - 1
+
+     return image[h_t: h_b, w_l: w_r, :], bi_mask[h_t: h_b, w_l: w_r]
+
+ def masked_crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, crop_scale=1.0, masked_color=[255, 255, 255]):
+     # pad generously so the square crop never leaves the array and the bbox shape is maintained
+     image = np.pad(image, ((600, 600), (600, 600), (0, 0)), 'constant', constant_values=255)
+     bi_mask = np.pad(bi_mask, ((600, 600), (600, 600)), "constant", constant_values=0)
+     bbox_xywh[:2] += 600
+     cropped_image, cropped_mask = crop(image, bbox_xywh, bi_mask, crop_scale)
+     cropped_image[np.nonzero(cropped_mask == 0)] = masked_color
+     return cropped_image, cropped_mask
+
+ class ImageNet_Masked(Dataset):
+     def __init__(self, ann_file="M_ImageNet_top_460k.json", masked_color=[255, 255, 255]):
+         self.masked_color = masked_color
+         self.anns_list = json.load(open(ann_file, 'r'))
+         random.shuffle(self.anns_list)
+         self.crop_scale = 1.5
+         self.transform = clip_standard_transform
+         self.res = 224
+         self.blur = 10.0
+
+     def __len__(self):
+         return len(self.anns_list)
+
+     def __getitem__(self, index):
+         cv2.ocl.setUseOpenCL(False)
+         cv2.setNumThreads(0)
+         ann = self.anns_list[index]
+         # TODO: change list to dict key.
+         img_pth = ann[2]
+         mask = ann[3]
+         bbox = ann[4]
+         text = ann[6]
+         image = cv2.imread(img_pth)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         bbox_xywh = np.copy(np.array(bbox))
+         binary_mask = maskUtils.decode(mask)
+         cat_word = img_pth.split("/")[3]
+         synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
+         synonyms = [x.name() for x in synset.lemmas()]
+         text = text.replace(".", f", probably {synonyms[0]}").replace(" ", "_").replace("/", "_").replace("\\", "_")
+         # overlay the mask in green and save a visualization image
+         image[np.nonzero(binary_mask == 1)] = (0.5 * image[np.nonzero(binary_mask == 1)] + 0.5 * np.array([0, 255, 0])).astype(np.uint8)
+         out_dir, out_name = os.path.split(img_pth.replace("imagenet-21k/images", "visual_train_c"))
+         os.makedirs(out_dir, exist_ok=True)
+         Image.fromarray(image).save(out_dir + f"/{text}_" + out_name)
+
+ if __name__ == "__main__":
+     data = ImageNet_Masked()
+     for i in tqdm(range(len(data))):
+         data[i]
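masked_crop pads both arrays with a 600-pixel border (white image, empty mask) so the scaled square crop around the shifted box stays inside the array, then fills everything outside the mask with masked_color. A self-contained check on synthetic data (the shapes and box here are arbitrary):

import numpy as np
from mask_image import masked_crop  # assumes this file's dependencies are installed

img = np.zeros((100, 120, 3), dtype=np.uint8)   # black 100x120 image
m = np.zeros((100, 120), dtype=np.uint8)
m[30:60, 40:90] = 1                             # rectangular foreground object
bbox = np.array([40, 30, 50, 30])               # x, y, w, h of that object
crop_img, crop_mask = masked_crop(img, bbox, m, crop_scale=1.5)
assert crop_img.shape[:2] == crop_mask.shape    # roughly scale * max(w, h) on a side
assert (crop_img[crop_mask == 0] == 255).all()  # background filled with masked_color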
mask_image_test.py ADDED
@@ -0,0 +1,457 @@
+ import json
+ import os
+ import copy
+ import pickle
+ import random
+ from collections import defaultdict
+
+ import alpha_clip
+ import cv2
+ import numpy as np
+ import torch
+ from lvis import LVIS
+ from PIL import Image
+ from PIL import ImageFile
+ from pycocotools.coco import COCO
+ from pycocotools import mask as maskUtils
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from tqdm import tqdm
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ Image.MAX_IMAGE_PIXELS = None
+
+ try:
+     from torchvision.transforms import InterpolationMode
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+
+ PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]
+
+ def _convert_image_to_rgb(image):
+     return image.convert("RGB")
+
+ clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((224, 224), interpolation=BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ hi_clip_standard_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336), interpolation=BICUBIC),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((224, 224)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ hi_mask_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((336, 336)),
+     transforms.Normalize(0.5, 0.26)
+ ])
+
+ def crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, scale=1.5):
+     tl_x = int(bbox_xywh[0])
+     tl_y = int(bbox_xywh[1])
+     w = int(bbox_xywh[2]) if int(bbox_xywh[2]) > 0 else 1
+     h = int(bbox_xywh[3]) if int(bbox_xywh[3]) > 0 else 1
+     image_h, image_w = image.shape[:2]
+
+     # expand the box to a square of side r, then scale it, clamped to the image bounds
+     r = max(h, w)
+     tl_x -= (r - w) / 2
+     tl_y -= (r - h) / 2
+     half_scale = (scale - 1.0) / 2
+     w_l = int(tl_x - half_scale * r) if (tl_x - half_scale * r) > 0 else 0
+     w_r = int(tl_x + (1 + half_scale) * r) if (tl_x + (1 + half_scale) * r) < image_w else image_w - 1
+     h_t = int(tl_y - half_scale * r) if (tl_y - half_scale * r) > 0 else 0
+     h_b = int(tl_y + (1 + half_scale) * r) if (tl_y + (1 + half_scale) * r) < image_h else image_h - 1
+
+     return image[h_t: h_b, w_l: w_r, :], bi_mask[h_t: h_b, w_l: w_r]
+
+ def masked_crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, crop_scale=1.0, masked_color=[255, 255, 255]):
+     # padding to make sure the bbox shape is maintained
+     image = np.pad(image, ((600, 600), (600, 600), (0, 0)), 'constant', constant_values=255)
+     bi_mask = np.pad(bi_mask, ((600, 600), (600, 600)), "constant", constant_values=0)
+     bbox_xywh[:2] += 600
+     cropped_image, cropped_mask = crop(image, bbox_xywh, bi_mask, crop_scale)
+     # background fill is intentionally disabled in this test variant:
+     # cropped_image[np.nonzero(cropped_mask == 0)] = MASK_FILL
+     return cropped_image, cropped_mask
+
+ class COCO_Masked_Test(Dataset):
+     def __init__(self, ann_file="data/coco/annotations/instances_val2017.json", masked_color=[255, 255, 255], root_directory="data/coco/val2017", hi_res=False):
+         self.masked_color = masked_color
+         self.coco = COCO(annotation_file=ann_file)
+         self.image_directory = root_directory
+         self.crop_scale = 1.5
+         self.anns_list = list(self.coco.anns.keys())
+         self.index2id = [x['id'] for x in self.coco.cats.values()]
+         self.id2index = dict()
+         for i, item in enumerate(self.index2id):
+             self.id2index[item] = i
+         self.class_num = 80
+         self.classes = [x['name'] for x in self.coco.cats.values()]
+
+         if hi_res:
+             self.mask_transform = hi_mask_transform
+             self.clip_standard_transform = hi_clip_standard_transform
+         else:
+             self.mask_transform = mask_transform
+             self.clip_standard_transform = clip_standard_transform
+
+     def __len__(self):
+         return len(self.anns_list)
+
+     def __getitem__(self, index):
+         ann_id = self.anns_list[index]
+         ann = self.coco.anns[ann_id]
+         img_id = ann['image_id']
+         image = np.array(Image.open(os.path.join(self.image_directory, self.coco.imgs[img_id]['file_name'])).convert('RGB'))
+         bbox_xywh = np.copy(np.array(ann['bbox']))
+         binary_mask = self.coco.annToMask(ann)
+         cropped_image, cropped_mask = masked_crop(image, bbox_xywh, binary_mask, crop_scale=self.crop_scale, masked_color=self.masked_color)
+         image = self.clip_standard_transform(cropped_image)
+         mask_torch = self.mask_transform(cropped_mask * 255)
+         return image, mask_torch, self.id2index[ann['category_id']]
+
+ class LVIS_Masked_Test(Dataset):
+     def __init__(self, ann_file="data/lvis/annotations/lvis_v1_val.json", masked_color=[255, 255, 255], hi_res=False):
+         self.masked_color = masked_color
+         self.lvis = LVIS(ann_file)
+         self.crop_scale = 1.5
+         self.anns_list = list(self.lvis.anns.keys())
+         self.index2id = [x['id'] for x in self.lvis.cats.values()]
+         self.id2index = dict()
+         for i, item in enumerate(self.index2id):
+             self.id2index[item] = i
+         self.class_num = 1203
+         self.classes = [x['name'] for x in self.lvis.cats.values()]
+
+         if hi_res:
+             self.mask_transform = hi_mask_transform
+             self.clip_standard_transform = hi_clip_standard_transform
+         else:
+             self.mask_transform = mask_transform
+             self.clip_standard_transform = clip_standard_transform
+
+     def __len__(self):
+         return len(self.anns_list)
+
+     def __getitem__(self, index):
+         ann_id = self.anns_list[index]
+         ann = self.lvis.anns[ann_id]
+         img_id = ann['image_id']
+         image = np.array(Image.open(self.lvis.imgs[img_id]['coco_url'].replace('http://images.cocodataset.org', 'data/coco')).convert('RGB'))
+         binary_mask = self.lvis.ann_to_mask(ann)
+         rgba = np.concatenate((image, np.expand_dims(binary_mask, axis=-1)), axis=-1)
+         h, w = rgba.shape[:2]
+         if max(h, w) == w:
+             pad = (w - h) // 2
+             l, r = pad, w - h - pad
+             rgba = np.pad(rgba, ((l, r), (0, 0), (0, 0)), 'constant', constant_values=0)
+         else:
+             pad = (h - w) // 2
+             l, r = pad, h - w - pad
+             rgba = np.pad(rgba, ((0, 0), (l, r), (0, 0)), 'constant', constant_values=0)
+         rgb = rgba[:, :, :-1]
+         mask = rgba[:, :, -1]
+         image = self.clip_standard_transform(rgb)
+         mask_torch = self.mask_transform(mask * 255)
+         return image, mask_torch, self.id2index[ann['category_id']]
+
+ class RGBD:
+     def __init__(self, annotation_file=None):
+         self.anns, self.imgs, self.answers, self.types = defaultdict(list), dict(), dict(), dict()
+         if annotation_file is not None:
+             with open(annotation_file, 'r') as reader:
+                 datas = json.load(reader)
+             for data in datas:
+                 self.anns[data['id']] = data['captions']
+                 self.imgs[data['id']] = data['image']
+                 self.answers[data['id']] = data['answer']
+                 self.types[data['id']] = data['type']
+
+ class RGBD_Outdoor_Benchmark(Dataset):
+     def __init__(self, root_dir, tasks):
+         self.root_dir = root_dir
+         self.dataset = RGBD(os.path.join(root_dir, tasks))
+         self.image_ids = list(self.dataset.imgs.keys())
+         self.captions = [x for x in self.dataset.anns.values()]
+         self.depth_transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+         ])
+         self.transform = clip_standard_transform
+         # hi-res variant: hi_clip_standard_transform and a (336, 336) depth transform
+
+     def __len__(self):
+         return len(self.image_ids)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         img_ids = self.image_ids[idx]
+         image_path = os.path.join(self.root_dir, 'pic_all', self.dataset.imgs[img_ids])
+         depth_path = os.path.join(self.root_dir, 'pic_depth', self.dataset.imgs[img_ids])
+         image = Image.open(image_path).convert('RGB')
+         depth = Image.open(depth_path).convert('L')
+
+         answer = self.dataset.answers[img_ids]
+
+         if self.transform:
+             image = self.transform(image)
+         if self.depth_transform:
+             depth = self.depth_transform(depth)
+         return image, depth, answer
+
+ class RGBD_Benchmark_Test(Dataset):
+     def __init__(self, root_dir):
+         self.root_dir = root_dir
+         self.dataset = RGBD(os.path.join(root_dir, 'annotations.json'))
+         self.image_ids = list(self.dataset.imgs.keys())
+         self.captions = [x for x in self.dataset.anns.values()]
+         self.transform = clip_standard_transform
+         self.depth_transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.image_ids)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         img_ids = self.image_ids[idx]
+         image_path = os.path.join(self.root_dir, 'all_pic', self.dataset.imgs[img_ids])
+         depth_path = os.path.join(self.root_dir, 'depth-new', self.dataset.imgs[img_ids])
+         image = Image.open(image_path).convert('RGB')
+         depth = Image.open(depth_path).convert('L')
+
+         answer = self.dataset.answers[img_ids]
+
+         if self.transform:
+             image = self.transform(image)
+         if self.depth_transform:
+             depth = self.depth_transform(depth)
+         return image, depth, answer
+
+ class RGBD_Benchmark_Test2(Dataset):
+     def __init__(self, root_dir):
+         self.root_dir = root_dir
+         self.dataset = RGBD(os.path.join(root_dir, 'annotations2.json'))
+         self.image_ids = list(self.dataset.imgs.keys())
+         self.captions = [x for x in self.dataset.anns.values()]
+         self.transform = clip_standard_transform
+         self.depth_transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.image_ids)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         img_ids = self.image_ids[idx]
+         image_path = os.path.join(self.root_dir, 'all_pic', self.dataset.imgs[img_ids])
+         depth_path = os.path.join(self.root_dir, 'depth-new', self.dataset.imgs[img_ids])
+         image = Image.open(image_path).convert('RGB')
+         depth = Image.open(depth_path).convert('L')
+
+         answer = self.dataset.answers[img_ids]
+
+         if self.transform:
+             image = self.transform(image)
+         if self.depth_transform:
+             depth = self.depth_transform(depth)
+         return image, depth, answer
+
+ class ScanRefer:
+     def __init__(self, annotation_file=None):
+         self.anns, self.imgs, self.answers, self.scene_id = defaultdict(list), dict(), dict(), dict()
+         if annotation_file is not None:
+             with open(annotation_file, 'r') as reader:
+                 datas = json.load(reader)
+             for data in datas:
+                 self.anns[data['unique_id']] = data['descriptions']
+                 self.imgs[data['unique_id']] = data['image']
+                 self.answers[data['unique_id']] = data['answer']
+                 self.scene_id[data['unique_id']] = data['scene_id']
+
+ class ScanRefer_Test(Dataset):
+     def __init__(self, root_dir, model):
+         self.root_dir = root_dir
+         self.dataset = ScanRefer(os.path.join(root_dir, 'scanrefer_annotations_all.json'))
+         self.model = model
+         self.image_ids = list(self.dataset.imgs.keys())
+         self.transform = transforms.Compose([
+             transforms.Resize(224, interpolation=BICUBIC),
+             transforms.CenterCrop(224),
+             _convert_image_to_rgb,
+             transforms.ToTensor(),
+             transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+         ])
+         self.depth_transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.image_ids)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         img_ids = self.image_ids[idx]
+         image_path = os.path.join(self.root_dir, self.dataset.scene_id[img_ids], 'color', self.dataset.imgs[img_ids])
+         depth_path = os.path.join(self.root_dir, self.dataset.scene_id[img_ids], 'depth', self.dataset.imgs[img_ids].split('.')[0] + '.png')
+
+         image = Image.open(image_path).convert('RGB')
+         depth = Image.open(depth_path).convert('L')
+
+         if self.transform:
+             image = self.transform(image)
+         if self.depth_transform:
+             depth = self.depth_transform(depth)
+
+         caption = self.dataset.anns[img_ids]
+         texts = alpha_clip.tokenize(caption).cuda()
+         text_embeddings = self.model.encode_text(texts)
+         text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
+
+         answer = self.dataset.answers[img_ids]
+         return image, depth, text_embeddings, answer
+
+ class ScanRefer_Test2(Dataset):
+     def __init__(self, root_dir, model):
+         self.root_dir = root_dir
+         self.dataset = ScanRefer(os.path.join(root_dir, 'annotations_2.json'))
+         self.model = model
+         self.image_ids = list(self.dataset.imgs.keys())
+         self.transform = transforms.Compose([
+             transforms.Resize(224, interpolation=BICUBIC),
+             transforms.CenterCrop(224),
+             _convert_image_to_rgb,
+             transforms.ToTensor(),
+             transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+         ])
+         self.depth_transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.image_ids)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         img_ids = self.image_ids[idx]
+         image_path = os.path.join(self.root_dir, self.dataset.scene_id[img_ids], 'color', self.dataset.imgs[img_ids])
+         depth_path = os.path.join(self.root_dir, self.dataset.scene_id[img_ids], 'depth', self.dataset.imgs[img_ids].split('.')[0] + '.png')
+
+         image = Image.open(image_path).convert('RGB')
+         depth = Image.open(depth_path).convert('L')
+
+         if self.transform:
+             image = self.transform(image)
+         if self.depth_transform:
+             depth = self.depth_transform(depth)
+
+         caption = self.dataset.anns[img_ids]
+         texts = alpha_clip.tokenize(caption).cuda()
+         text_embeddings = self.model.encode_text(texts)
+         text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
+
+         answer = self.dataset.answers[img_ids]
+         return image, depth, text_embeddings, answer
+
+ class ScanRefer_Testnr3d(Dataset):
+     def __init__(self, root_dir, model):
+         self.root_dir = root_dir
+         self.dataset = ScanRefer(os.path.join(root_dir, 'nr3d_annotations.json'))
+         self.model = model
+         self.image_ids = list(self.dataset.imgs.keys())
+         self.transform = transforms.Compose([
+             transforms.Resize(224, interpolation=BICUBIC),
+             transforms.CenterCrop(224),
+             _convert_image_to_rgb,
+             transforms.ToTensor(),
+             transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+         ])
+         self.depth_transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+         ])
+
+     def __len__(self):
+         return len(self.image_ids)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         img_ids = self.image_ids[idx]
+         image_path = os.path.join(self.root_dir, self.dataset.scene_id[img_ids], 'color', self.dataset.imgs[img_ids])
+         depth_path = os.path.join(self.root_dir, self.dataset.scene_id[img_ids], 'depth', self.dataset.imgs[img_ids].split('.')[0] + '.png')
+
+         image = Image.open(image_path).convert('RGB')
+         depth = Image.open(depth_path).convert('L')
+
+         if self.transform:
+             image = self.transform(image)
+         if self.depth_transform:
+             depth = self.depth_transform(depth)
+
+         caption = self.dataset.anns[img_ids]
+         texts = alpha_clip.tokenize(caption).cuda()
+         text_embeddings = self.model.encode_text(texts)
+         text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
+
+         answer = self.dataset.answers[img_ids]
+         return image, depth, text_embeddings, answer
+
+ if __name__ == "__main__":
+     data = LVIS_Masked_Test()
+     for i in tqdm(range(len(data))):
+         data[i]
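Finally, a sketch of the zero-shot loop these test sets feed; the `model` handle and its visual(image, mask) signature are assumptions standing in for an Alpha-CLIP-style encoder, and the prompt template is hypothetical:

import torch
from torch.utils.data import DataLoader

data = COCO_Masked_Test()
loader = DataLoader(data, batch_size=128, num_workers=8)
tokens = alpha_clip.tokenize([f"a photo of a {c}" for c in data.classes]).cuda()
with torch.no_grad():
    text_feat = model.encode_text(tokens)                   # `model` is assumed, not defined in this file
    text_feat /= text_feat.norm(dim=-1, keepdim=True)
    correct = total = 0
    for image, mask, label in loader:
        img_feat = model.visual(image.cuda(), mask.cuda())  # assumed (image, alpha-mask) signature
        img_feat /= img_feat.norm(dim=-1, keepdim=True)
        pred = (img_feat @ text_feat.T).argmax(dim=-1).cpu()
        correct += (pred == label).sum().item()
        total += label.numel()
print(f"top-1 accuracy: {correct / total:.4f}")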