dianecy committed
Commit 7e3a804 · verified · 1 Parent(s): e290a7d

Upload folder using huggingface_hub

ASDA/dataset/__pycache__/data_loader.cpython-39.pyc ADDED
Binary file (7.71 kB).
 
ASDA/dataset/__pycache__/data_loader_gref_sbert.cpython-39.pyc ADDED
Binary file (8.75 kB).
 
ASDA/dataset/__pycache__/data_loader_rccp.cpython-39.pyc ADDED
Binary file (7 kB).
 
ASDA/dataset/__pycache__/data_loader_test.cpython-39.pyc ADDED
Binary file (7.55 kB).
 
ASDA/dataset/__pycache__/refer.cpython-39.pyc ADDED
Binary file (11.9 kB).
 
ASDA/dataset/data.sh ADDED
@@ -0,0 +1,12 @@
1
+ #!/bin/bash
2
+ # data process
3
+ python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcoco --split unc --generate_mask
4
+ python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcoco+ --split unc --generate_mask
5
+ python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcocog --split google --generate_mask
6
+ python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcocog --split umd --generate_mask
7
+
8
+ # datascript
9
+ python datascript.py --dataset refcoco
10
+ python datascript.py --dataset refcoco+
11
+ python datascript.py --dataset refcocog_google
12
+ python datascript.py --dataset refcocog_umd
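
The commands above assume they are run from ASDA/dataset: data_process.py writes annotations to ../ln_data/anns/<dataset> and masks to ../ln_data/masks/<dataset>, and datascript.py then packs each split into .pth files under ../data/<dataset>. A minimal sanity check after running the script (a sketch; the paths simply mirror the defaults used above):

import os.path as osp

for ds in ['refcoco', 'refcoco+', 'refcocog_google', 'refcocog_umd']:
    print(ds,
          'anns:', osp.isdir(osp.join('../ln_data', 'anns', ds)),
          'masks:', osp.isdir(osp.join('../ln_data', 'masks', ds)),
          'splits:', osp.isdir(osp.join('../data', ds)))
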
ASDA/dataset/data_loader.py ADDED
@@ -0,0 +1,314 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
5
+ """
6
+ import sys
7
+ import cv2
8
+ import torch
9
+ import random
10
+ import numpy as np
11
+ import os.path as osp
12
+ import torch.utils.data as data
13
+ sys.path.append('.')
14
+ import utils
15
+ import re
16
+
17
+ from pytorch_pretrained_bert.tokenization import BertTokenizer
18
+ from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase
19
+ import copy
20
+
21
+ import clip
22
+
23
+ sys.modules['utils'] = utils
24
+ cv2.setNumThreads(0)
25
+
26
+ def read_examples(input_line, unique_id):
27
+ """Read a list of `InputExample`s from an input file."""
28
+ examples = []
29
+ # unique_id = 0
30
+ line = input_line #reader.readline()
31
+ # if not line:
32
+ # break
33
+ line = line.strip()
34
+ text_a = None
35
+ text_b = None
36
+ m = re.match(r"^(.*) \|\|\| (.*)$", line)
37
+ if m is None:
38
+ text_a = line
39
+ else:
40
+ text_a = m.group(1) #'man in black'
41
+ text_b = m.group(2)
42
+
43
+ examples.append(
44
+ InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
45
+ # unique_id += 1
46
+ return examples
47
+
48
+ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
49
+ while True:
50
+ total_length = len(tokens_a) + len(tokens_b)
51
+ if total_length <= max_length:
52
+ break
53
+ if len(tokens_a) > len(tokens_b):
54
+ tokens_a.pop()
55
+ else:
56
+ tokens_b.pop()
57
+
58
+ ## Bert text encoding
59
+ class InputExample(object):
60
+ def __init__(self, unique_id, text_a, text_b):
61
+ self.unique_id = unique_id
62
+ self.text_a = text_a
63
+ self.text_b = text_b
64
+
65
+ class InputFeatures(object):
66
+ """A single set of features of data."""
67
+ def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
68
+ self.unique_id = unique_id
69
+ self.tokens = tokens
70
+ self.input_ids = input_ids
71
+ self.input_mask = input_mask
72
+ self.input_type_ids = input_type_ids
73
+
74
+ def convert_examples_to_features(examples, seq_length, tokenizer):
75
+ """Loads a data file into a list of `InputBatch`s."""
76
+ features = []
77
+ for (ex_index, example) in enumerate(examples):
78
+ tokens_a = tokenizer.tokenize(example.text_a) # ['far', 'left', 'vase']
79
+
80
+ tokens_b = None
81
+ if example.text_b:
82
+ tokens_b = tokenizer.tokenize(example.text_b)
83
+
84
+ if tokens_b:
85
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
86
+ # length is less than the specified length.
87
+ # Account for [CLS], [SEP], [SEP] with "- 3"
88
+ _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
89
+ else:
90
+ # Account for [CLS] and [SEP] with "- 2"
91
+ if len(tokens_a) > seq_length - 2:
92
+ tokens_a = tokens_a[0:(seq_length - 2)]
93
+ tokens = []
94
+ input_type_ids = []
95
+ tokens.append("[CLS]")
96
+ input_type_ids.append(0)
97
+ for token in tokens_a:
98
+ tokens.append(token)
99
+ input_type_ids.append(0)
100
+ tokens.append("[SEP]")
101
+ input_type_ids.append(0)
102
+
103
+ if tokens_b:
104
+ for token in tokens_b:
105
+ tokens.append(token)
106
+ input_type_ids.append(1)
107
+ tokens.append("[SEP]")
108
+ input_type_ids.append(1)
109
+
110
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
111
+
112
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
113
+ # tokens are attended to.
114
+ input_mask = [1] * len(input_ids)
115
+
116
+ # Zero-pad up to the sequence length.
117
+ while len(input_ids) < seq_length:
118
+ input_ids.append(0)
119
+ input_mask.append(0)
120
+ input_type_ids.append(0)
121
+
122
+ assert len(input_ids) == seq_length
123
+ assert len(input_mask) == seq_length
124
+ assert len(input_type_ids) == seq_length
125
+ features.append(
126
+ InputFeatures(
127
+ unique_id=example.unique_id,
128
+ tokens=tokens,
129
+ input_ids=input_ids,
130
+ input_mask=input_mask,
131
+ input_type_ids=input_type_ids))
132
+ return features
133
+
134
+ class DatasetNotFoundError(Exception):
135
+ pass
136
+
137
+ class ReferDataset(data.Dataset):
138
+ SUPPORTED_DATASETS = {
139
+ 'refcoco': {
140
+ 'splits': ('train', 'val', 'testA', 'testB'),
141
+ 'params': {'dataset': 'refcoco', 'split_by': 'unc'}
142
+ },
143
+ 'refcoco+': {
144
+ 'splits': ('train', 'val', 'testA', 'testB'),
145
+ 'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
146
+ },
147
+ 'refcocog': {
148
+ 'splits': ('train', 'val', 'test'),
149
+ 'params': {'dataset': 'refcocog', 'split_by': 'unc'}
150
+ },
151
+ 'refcocog_g': {
152
+ 'splits': ('train', 'val'),
153
+ 'params': {'dataset': 'refcocog', 'split_by': 'google'}
154
+ },
155
+ 'refcocog_u': {
156
+ 'splits': ('train', 'val', 'test'),
157
+ 'params': {'dataset': 'refcocog', 'split_by': 'unc'}
158
+ },
159
+ 'grefcoco': {
160
+ 'splits': ('train', 'val', 'testA', 'testB'),
161
+ 'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
162
+ }
163
+ }
164
+
165
+ def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
166
+ transform=None, augment=False, split='train', max_query_len=128,
167
+ bert_model='bert-base-uncased'):
168
+ self.images = []
169
+ self.data_root = data_root
170
+ self.split_root = split_root
171
+ self.dataset = dataset
172
+ self.imsize = imsize
173
+ self.query_len = max_query_len
174
+ self.transform = transform
175
+ self.split = split
176
+ self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True) # should be true for English
177
+ self.augment=augment
178
+
179
+ valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']
180
+
181
+ if split not in valid_splits:
182
+ raise ValueError(
183
+ 'Dataset {0} does not have split {1}'.format(
184
+ self.dataset, split))
185
+
186
+ self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt')
187
+ if self.dataset == 'refcocog' :
188
+ mask_anno_str = '{0}_{1}'.format(self.dataset, splitby)
189
+ self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
190
+ else :
191
+ self.mask_root = osp.join(self.data_root, 'masks', self.dataset)
192
+
193
+ self.im_dir = osp.join(self.data_root, 'images', 'train2014')
194
+
195
+
196
+ if self.dataset == 'refcocog' :
197
+ dataset_path = osp.join(self.split_root, self.dataset + '_' + splitby)
198
+ splits = [split]
199
+ for split in splits:
200
+ imgset_file = '{0}_{1}_{2}.pth'.format(self.dataset, splitby, split)
201
+ imgset_path = osp.join(dataset_path, imgset_file)
202
+ self.images += torch.load(imgset_path)
203
+ else :
204
+ dataset_path = osp.join(self.split_root, self.dataset)
205
+ splits = [split]
206
+ for split in splits:
207
+ imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
208
+ imgset_path = osp.join(dataset_path, imgset_file)
209
+ self.images += torch.load(imgset_path)
210
+
211
+ def exists_dataset(self):
212
+ return osp.exists(osp.join(self.split_root, self.dataset))
213
+
214
+ def pull_item(self, idx):
215
+ img_file, seg_id, bbox, phrase = self.images[idx]
216
+ bbox = np.array(bbox, dtype=int) # x1y1x2y2
217
+
218
+ img_path = osp.join(self.im_dir, img_file)
219
+ img = cv2.imread(img_path) # BGR [512, 640, 3]
220
+ ## duplicate channel if gray image
221
+ if img.shape[-1] > 1:
222
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #RGB
223
+ else:
224
+ img = np.stack([img] * 3, axis=-1)
225
+
226
+ ## seg map
227
+ seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy')) # [512, 640]
228
+ seg_map = np.array(seg_map).astype(np.float32)
229
+ return img, phrase, bbox, seg_map
230
+
231
+ def __len__(self):
232
+ return len(self.images)
233
+
234
+ def __getitem__(self, idx):
235
+ img, phrase, bbox, seg_map = self.pull_item(idx)
236
+ phrase = phrase.lower()
237
+ if self.augment:
238
+ augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
239
+ True, True, True, False, False, False
240
+
241
+ ## seems a bug in torch transformation resize, so separate in advance
242
+ h,w = img.shape[0], img.shape[1]
243
+ # print("img.shape", img.shape)
244
+ if self.augment:
245
+ ## random horizontal flip
246
+ if augment_flip and random.random() > 0.5:
247
+ img = cv2.flip(img, 1)
248
+ seg_map = cv2.flip(seg_map, 1)
249
+ bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1
250
+ phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left')
251
+
252
+ ## random copy and add left or right
253
+ if augment_copy:
254
+ img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)
255
+
256
+ ## random erase for occluded
257
+ if augment_erase:
258
+ img, seg_map = random_erase(img, seg_map)
259
+
260
+ ## random padding and crop
261
+ if augment_crop:
262
+ img, seg_map = random_crop(img, seg_map, 40, h, w)
263
+
264
+ ## random intensity, saturation change
265
+ if augment_hsv:
266
+ fraction = 0.50
267
+ img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
268
+ S = img_hsv[:, :, 1].astype(np.float32)
269
+ V = img_hsv[:, :, 2].astype(np.float32)
270
+ a = (random.random() * 2 - 1) * fraction + 1
271
+ if a > 1:
272
+ np.clip(S, a_min=0, a_max=255, out=S)
273
+ a = (random.random() * 2 - 1) * fraction + 1
274
+ V *= a
275
+ if a > 1:
276
+ np.clip(V, a_min=0, a_max=255, out=V)
277
+
278
+ img_hsv[:, :, 1] = S.astype(np.uint8)
279
+ img_hsv[:, :, 2] = V.astype(np.uint8)
280
+ img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)
281
+
282
+ img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
283
+ bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
284
+ bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
285
+
286
+ ## random affine transformation
287
+ if augment_affine:
288
+ img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \
289
+ degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10)) # 255 white fill
290
+
291
+ else: ## should be inference, or specified training
292
+ img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
293
+ bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
294
+ bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
295
+
296
+ draw_img = copy.deepcopy(img)
297
+ # Norm, to tensor
298
+ if self.transform is not None:
299
+ img = self.transform(img)
300
+
301
+ ## encode phrase to clip input
302
+ word_id = clip.tokenize(phrase, 17, truncate=True)
303
+ word_mask = ~ (word_id == 0)
304
+
305
+ if self.augment: # train
306
+ seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2),interpolation=cv2.INTER_NEAREST) # (208, 208)
307
+ seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
308
+ return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
309
+ np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
310
+ else:
311
+ seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
312
+ return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
313
+ np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
314
+ np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)
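
A minimal usage sketch for the loader above in training mode (the transform, imsize, batch size, and import path are illustrative assumptions, not taken from this commit):

import torch.utils.data as data
from torchvision import transforms
from data_loader import ReferDataset  # hypothetical import path, assuming ASDA/dataset is on sys.path

input_transform = transforms.Compose([transforms.ToTensor()])  # normalization omitted for brevity

train_set = ReferDataset(data_root='../ln_data', split_root='data', dataset='refcoco',
                         split='train', imsize=416, transform=input_transform, augment=True)
train_loader = data.DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)

# With augment=True each item is (img, word_id, word_mask, bbox, seg_map): CLIP token ids and
# padding mask, an x1y1x2y2 box in the letterboxed frame, and a 1 x (imsize//2) x (imsize//2) mask.
imgs, word_id, word_mask, bbox, seg_map = next(iter(train_loader))
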
ASDA/dataset/data_loader_gref_sbert.py ADDED
@@ -0,0 +1,343 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
5
+ """
6
+ import sys
7
+ import cv2
8
+ import os
9
+ import torch
10
+ import json
11
+ import random
12
+ import numpy as np
13
+ import os.path as osp
14
+ import torch.utils.data as data
15
+ sys.path.append('.')
16
+ import utils
17
+ import re
18
+
19
+ # from pytorch_pretrained_bert.tokenization import BertTokenizer
20
+ from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase
21
+ import copy
22
+
23
+ import clip
24
+
25
+ sys.modules['utils'] = utils
26
+ cv2.setNumThreads(0)
27
+
28
+ class ReferDataset(data.Dataset):
29
+ SUPPORTED_DATASETS = {
30
+ 'refcoco': {
31
+ 'splits': ('train', 'val', 'testA', 'testB'),
32
+ 'params': {'dataset': 'refcoco', 'split_by': 'unc'}
33
+ },
34
+ 'refcoco+': {
35
+ 'splits': ('train', 'val', 'testA', 'testB'),
36
+ 'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
37
+ },
38
+ 'refcocog': {
39
+ 'splits': ('train', 'val', 'test'),
40
+ 'params': {'dataset': 'refcocog', 'split_by': 'unc'}
41
+ },
42
+ 'refcocog_g': {
43
+ 'splits': ('train', 'val'),
44
+ 'params': {'dataset': 'refcocog', 'split_by': 'google'}
45
+ },
46
+ 'refcocog_u': {
47
+ 'splits': ('train', 'val', 'test'),
48
+ 'params': {'dataset': 'refcocog', 'split_by': 'unc'}
49
+ },
50
+ 'grefcoco': {
51
+ 'splits': ('train', 'val', 'testA', 'testB'),
52
+ 'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
53
+ }
54
+ }
55
+
56
+
57
+ def _load_multi_obj_ref_ids(self):
58
+ # Load multi-object reference IDs based on configurations
59
+ if not self.exclude_multiobj and not self.exclude_position :
60
+ return None
61
+ elif self.exclude_position:
62
+ multiobj_path = os.path.join(self.ROOT, 'multiobj_ov2_nopos.txt')
63
+ elif self.exclude_multiobj :
64
+ multiobj_path = os.path.join(self.ROOT, 'multiobj_ov3.txt')
65
+ with open(multiobj_path, 'r') as f:
66
+ return [int(line.strip()) for line in f.readlines()]
67
+
68
+ def _load_metadata(self):
69
+ # Load metadata for hard positive verb phrases, hard negative queries
70
+ # we set refined file as default option
71
+ hardpos_path = '/data2/projects/seunghoon/VerbRIS/CrossVLT/hardpos_verdict_gref_v4.json'
72
+ with open(hardpos_path, 'r', encoding='utf-8') as f:
73
+ hardpos_json = json.load(f)
74
+ return hardpos_json
75
+
76
+ def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
77
+ transform=None, augment=False, split='train', max_query_len=128, metric_learning=None):
78
+ images_tmp = []
79
+ self.data_root = data_root
80
+ self.split_root = split_root
81
+ self.dataset = dataset
82
+ self.imsize = imsize
83
+ self.query_len = max_query_len
84
+ self.transform = transform
85
+ self.word_len = 17
86
+ self.emb_size = 384
87
+ self.split = split
88
+ self.augment=augment
89
+
90
+ valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']
91
+
92
+ if split not in valid_splits:
93
+ raise ValueError(
94
+ 'Dataset {0} does not have split {1}'.format(
95
+ self.dataset, split))
96
+
97
+ self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt')
98
+ if self.dataset == 'refcocog' :
99
+ mask_anno_str = '{0}_{1}'.format(self.dataset, splitby)
100
+ self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
101
+ else :
102
+ self.mask_root = osp.join(self.data_root, 'masks', self.dataset)
103
+
104
+ self.im_dir = osp.join(self.data_root, 'images', 'train2014')
105
+
106
+
107
+ if self.dataset == 'refcocog' :
108
+ dataset_path = osp.join(self.split_root, self.dataset + '_' + splitby)
109
+ splits = [split]
110
+ for split in splits:
111
+ imgset_file = '{0}_{1}_{2}.pth'.format(self.dataset, splitby, split)
112
+ imgset_path = osp.join(dataset_path, imgset_file)
113
+ images_tmp += torch.load(imgset_path)
114
+
115
+ # metric learning options
116
+ self.ROOT = '/data2/projects/seunghoon/VerbRIS/VerbCentric_CY/'
117
+ self.all_hp_root = "/data2/dataset/RefCOCO/refcocog/SBERT_gref_umd"
118
+ # self.exclude_position = args.exclude_pos
119
+ self.exclude_position = True
120
+ self.exclude_multiobj = True
121
+ self.metric_learning = metric_learning
122
+
123
+ # self.metric_mode = args.metric_mode
124
+ self.hp_selection = 'strict'
125
+
126
+ # meta data loading
127
+ if self.metric_learning and self.split == 'train':
128
+ self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
129
+ self.hardpos_meta = self._load_metadata()
130
+
131
+ # make new self.images file with sentence idx and total sent num (per ref_id)
132
+ from collections import defaultdict
133
+ ref_sentence_counts = defaultdict(int)
134
+ for item in images_tmp:
135
+ ref_sentence_counts[item[1]] += 1
136
+
137
+ self.images = []
138
+ ref_sentence_indices = defaultdict(int)
139
+ for item in images_tmp:
140
+ img_name, seg_id, box, sentence = item
141
+ sent_index = ref_sentence_indices[seg_id]
142
+ total_sentences = ref_sentence_counts[seg_id]
143
+ self.images.append((img_name, seg_id, box, sentence, sent_index, total_sentences))
144
+ ref_sentence_indices[seg_id] += 1
145
+
146
+ else :
147
+ self.images = images_tmp
148
+ self.multi_obj_ref_ids = None
149
+ self.hardpos_meta = None
150
+
151
+ else :
152
+ dataset_path = osp.join(self.split_root, self.dataset)
153
+ splits = [split]
154
+ for split in splits:
155
+ imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
156
+ imgset_path = osp.join(dataset_path, imgset_file)
157
+ self.images += torch.load(imgset_path)
158
+
159
+ def exists_dataset(self):
160
+ return osp.exists(osp.join(self.split_root, self.dataset))
161
+
162
+ def _get_hardpos_verb(self, seg_id, sent_idx):
163
+ """
164
+ Handle the logic for selecting hard positive verb phrases during metric learning.
165
+ Returns the sentence, raw_verb, and tokenized verb if applicable.
166
+ """
167
+ # If the object appears multiple times, no hard positive is used
168
+ if seg_id in self.multi_obj_ref_ids:
169
+ verb_embed = torch.zeros(self.emb_size, dtype=torch.float32)
170
+ return '', verb_embed
171
+
172
+ # Extract metadata for hard positives if present
173
+ hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
174
+ if self.hp_selection == 'strict' :
175
+ sent_id_list = list(hardpos_dict.keys())
176
+ cur_sent_id = sent_id_list[sent_idx]
177
+ cur_hardpos = hardpos_dict.get(cur_sent_id, {}).get('phrases', [])
178
+
179
+ if cur_hardpos:
180
+ # Assign a hard positive verb phrase if available
181
+ rand_index = random.randint(0, len(cur_hardpos) - 1)
182
+ raw_verb = cur_hardpos[rand_index]
183
+ verb_embed = torch.from_numpy(self._get_hardpos_embed(seg_id, cur_sent_id, rand_index))
184
+ # print("Positive phrase : " , raw_verb)
185
+ return raw_verb, verb_embed
186
+
187
+ verb_embed = torch.zeros(self.emb_size, dtype=torch.float32)
188
+ return '', verb_embed
189
+
190
+
191
+ def _get_hardpos_embed(self, seg_id, sent_id, rand_index):
192
+ emb_folder = os.path.join(self.all_hp_root, str(seg_id))
193
+ emb_files = sorted([f for f in os.listdir(emb_folder) if f.startswith(f"hp_{sent_id}_") and f.endswith(".npy")])
194
+ selected_emb_file = os.path.join(emb_folder, emb_files[rand_index])
195
+
196
+ return np.load(selected_emb_file)
197
+
198
+
199
+ def pull_item(self, idx):
200
+ # if metric learning and in train mode
201
+ if self.metric_learning and self.augment :
202
+ # sent_idx refers to index of sent among sent_num-1
203
+ img_file, seg_id, bbox, phrase, sent_idx, sent_num = self.images[idx]
204
+ else :
205
+ img_file, seg_id, bbox, phrase = self.images[idx]
206
+ bbox = np.array(bbox, dtype=int) # x1y1x2y2
207
+
208
+ img_path = osp.join(self.im_dir, img_file)
209
+ img = cv2.imread(img_path) # BGR [512, 640, 3]
210
+ ## duplicate channel if gray image
211
+ if img.shape[-1] > 1:
212
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #RGB
213
+ else:
214
+ img = np.stack([img] * 3, axis=-1)
215
+
216
+ ## seg map
217
+ seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy')) # [512, 640]
218
+ seg_map = np.array(seg_map).astype(np.float32)
219
+
220
+ if self.metric_learning and self.split == 'train' :
221
+ return img, phrase, bbox, seg_map, seg_id, sent_idx
222
+ else :
223
+ return img, phrase, bbox, seg_map, seg_id
224
+
225
+ def __len__(self):
226
+ return len(self.images)
227
+
228
+ def __getitem__(self, idx):
229
+ if self.metric_learning and self.augment :
230
+ img, phrase, bbox, seg_map, seg_id, sent_idx = self.pull_item(idx)
231
+ else :
232
+ img, phrase, bbox, seg_map, seg_id = self.pull_item(idx)
233
+
234
+ phrase = phrase.lower()
235
+ if self.augment:
236
+ augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
237
+ True, True, True, False, False, False
238
+
239
+ ## seems a bug in torch transformation resize, so separate in advance
240
+ h,w = img.shape[0], img.shape[1]
241
+ # print("img.shape", img.shape)
242
+ if self.augment:
243
+ ## random horizontal flip
244
+ if augment_flip and random.random() > 0.5:
245
+ img = cv2.flip(img, 1)
246
+ seg_map = cv2.flip(seg_map, 1)
247
+ bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1
248
+ phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left')
249
+
250
+ ## random copy and add left or right
251
+ if augment_copy:
252
+ img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)
253
+
254
+ ## random erase for occluded
255
+ if augment_erase:
256
+ img, seg_map = random_erase(img, seg_map)
257
+
258
+ ## random padding and crop
259
+ if augment_crop:
260
+ img, seg_map = random_crop(img, seg_map, 40, h, w)
261
+
262
+ ## random intensity, saturation change
263
+ if augment_hsv:
264
+ fraction = 0.50
265
+ img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
266
+ S = img_hsv[:, :, 1].astype(np.float32)
267
+ V = img_hsv[:, :, 2].astype(np.float32)
268
+ a = (random.random() * 2 - 1) * fraction + 1
269
+ if a > 1:
270
+ np.clip(S, a_min=0, a_max=255, out=S)
271
+ a = (random.random() * 2 - 1) * fraction + 1
272
+ V *= a
273
+ if a > 1:
274
+ np.clip(V, a_min=0, a_max=255, out=V)
275
+
276
+ img_hsv[:, :, 1] = S.astype(np.uint8)
277
+ img_hsv[:, :, 2] = V.astype(np.uint8)
278
+ img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)
279
+
280
+ img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
281
+ bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
282
+ bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
283
+
284
+ ## random affine transformation
285
+ if augment_affine:
286
+ img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \
287
+ degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10)) # 255 white fill
288
+
289
+ else: ## should be inference, or specified training
290
+ img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
291
+ bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
292
+ bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
293
+
294
+ draw_img = copy.deepcopy(img)
295
+ # Norm, to tensor
296
+ if self.transform is not None:
297
+ img = self.transform(img)
298
+
299
+
300
+ ## encode phrase to clip input
301
+ word_id = clip.tokenize(phrase, 17, truncate=True)
302
+ word_mask = ~ (word_id == 0)
303
+
304
+ orig_word_id = np.array(word_id, dtype=int)
305
+ orig_word_mask = np.array(word_mask, dtype=int)
306
+
307
+ # Get hardpos verb phrase
308
+ if self.metric_learning and self.augment:
309
+ raw_hardpos, hardpos_emb = self._get_hardpos_verb(seg_id, sent_idx)
310
+ pos_type = 'nopos'
311
+ if raw_hardpos:
312
+ pos_type = 'hardpos'
313
+ hardpos_id = clip.tokenize(raw_hardpos, self.word_len, truncate=True)
314
+ else:
315
+ # Empty phrase → Create a zero tensor matching shape of tokenized input
316
+ hardpos_id = np.zeros((1, self.word_len), dtype=int)
317
+
318
+ # Masking
319
+ hardpos_mask = hardpos_id != 0 # Mask should be boolean
320
+
321
+ hp_word_id = np.array(hardpos_id, dtype=int)
322
+ hp_word_mask = np.array(hardpos_mask, dtype=int)
323
+
324
+ if self.augment: # train
325
+ seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2),interpolation=cv2.INTER_NEAREST) # (208, 208)
326
+ seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
327
+ if self.metric_learning :
328
+ params = {
329
+ 'hp_word_id' : hp_word_id,
330
+ 'hp_word_mask' : hp_word_mask,
331
+ 'hardpos_emb' : hardpos_emb.unsqueeze(0),
332
+ 'pos_type' : pos_type
333
+ }
334
+ return img, orig_word_id, orig_word_mask, np.array(bbox, dtype=np.float32), \
335
+ np.array(seg_map, dtype=np.float32), params
336
+ else :
337
+ return img, orig_word_id, orig_word_mask, \
338
+ np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
339
+ else:
340
+ seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
341
+ return img, orig_word_id, orig_word_mask, \
342
+ np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
343
+ np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)
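
This variant differs from data_loader.py mainly in its metric-learning branch: in train mode it additionally returns a params dict with the hard-positive phrase (CLIP token ids and mask), its SBERT embedding, and a pos_type flag, and it relies on hard-coded /data2/... metadata paths that need adjusting on other machines. A sketch of how a training loop might consume the extra outputs (the metric-learning loss itself is hypothetical and not part of this commit):

for img, word_id, word_mask, bbox, seg_map, params in train_loader:
    hp_word_id = params['hp_word_id']      # (B, 1, 17) CLIP token ids of the hard-positive phrase
    hp_word_mask = params['hp_word_mask']  # (B, 1, 17) non-padding mask
    hardpos_emb = params['hardpos_emb']    # (B, 1, 384) SBERT embedding (zeros when no hard positive)
    pos_type = params['pos_type']          # per-sample 'hardpos' or 'nopos'
    # a metric-learning term would typically be applied only where pos_type == 'hardpos'
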
ASDA/dataset/data_loader_rccp.py ADDED
@@ -0,0 +1,279 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
5
+ """
6
+ import sys
7
+ import cv2
8
+ import os
9
+ import torch
10
+ import json
11
+ import random
12
+ import numpy as np
13
+ import os.path as osp
14
+ import torch.utils.data as data
15
+ sys.path.append('.')
16
+ import utils
17
+ import re
18
+
19
+ # from pytorch_pretrained_bert.tokenization import BertTokenizer
20
+ from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase
21
+ import copy
22
+
23
+ import clip
24
+
25
+ sys.modules['utils'] = utils
26
+ cv2.setNumThreads(0)
27
+
28
+ class ReferDataset(data.Dataset):
29
+ SUPPORTED_DATASETS = {
30
+ 'refcoco': {
31
+ 'splits': ('train', 'val', 'testA', 'testB'),
32
+ 'params': {'dataset': 'refcoco', 'split_by': 'unc'}
33
+ },
34
+ 'refcoco+': {
35
+ 'splits': ('train', 'val', 'testA', 'testB'),
36
+ 'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
37
+ },
38
+ 'refcocog': {
39
+ 'splits': ('train', 'val', 'test'),
40
+ 'params': {'dataset': 'refcocog', 'split_by': 'unc'}
41
+ },
42
+ 'refcocog_g': {
43
+ 'splits': ('train', 'val'),
44
+ 'params': {'dataset': 'refcocog', 'split_by': 'google'}
45
+ },
46
+ 'refcocog_u': {
47
+ 'splits': ('train', 'val', 'test'),
48
+ 'params': {'dataset': 'refcocog', 'split_by': 'unc'}
49
+ },
50
+ 'grefcoco': {
51
+ 'splits': ('train', 'val', 'testA', 'testB'),
52
+ 'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
53
+ }
54
+ }
55
+
56
+
57
+ def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
58
+ transform=None, augment=False, split='train', max_query_len=128, metric_learning=None):
59
+ images_tmp = []
60
+ self.data_root = data_root
61
+ self.split_root = split_root
62
+ self.dataset = dataset
63
+ self.imsize = imsize
64
+ self.query_len = max_query_len
65
+ self.transform = transform
66
+ self.word_len = 17
67
+ self.emb_size = 384
68
+ self.split = split
69
+ self.augment=augment
70
+
71
+ valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']
72
+
73
+ if split not in valid_splits:
74
+ raise ValueError(
75
+ 'Dataset {0} does not have split {1}'.format(
76
+ self.dataset, split))
77
+
78
+ self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt')
79
+ if self.dataset == 'refcocog' :
80
+ mask_anno_str = '{0}_{1}'.format(self.dataset, splitby)
81
+ self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
82
+ else :
83
+ self.mask_root = osp.join(self.data_root, 'masks', self.dataset)
84
+
85
+ self.im_dir = osp.join(self.data_root, 'images', 'train2014')
86
+
87
+ # if self.dataset in ['refcoco', 'refcoco+']
88
+ dataset_path = osp.join(self.split_root, self.dataset)
89
+ splits = [split]
90
+ for split in splits:
91
+ imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
92
+ imgset_path = osp.join(dataset_path, imgset_file)
93
+ images_tmp += torch.load(imgset_path)
94
+
95
+ # hardpos related
96
+ self.ROOT = '/data2/dataset/RefCOCO/VRIS'
97
+ if self.dataset == 'refcoco' :
98
+ self.all_hp_root = '/data2/dataset/RefCOCO/refcoco/SBERT_rcc_unc'
99
+ elif self.dataset == 'refcoco+' :
100
+ self.all_hp_root = '/data2/dataset/RefCOCO/refcoco+/SBERT_rccp_unc'
101
+
102
+ self.metric_learning = metric_learning
103
+ if self.metric_learning :
104
+ self.exclude_position = True
105
+ self.exclude_multiobj = True
106
+ self.hp_selection = 'strict'
107
+ self.multi_obj_ref_ids = None
108
+ self.hardpos_meta = None
109
+
110
+ # make new self.images file with sentence idx and total sent num (per ref_id)
111
+ from collections import defaultdict
112
+ ref_sentence_counts = defaultdict(int)
113
+ for item in images_tmp:
114
+ ref_sentence_counts[item[1]] += 1
115
+
116
+ if self.split == 'train' :
117
+ images = []
118
+ ref_sentence_indices = defaultdict(int)
119
+ for item in images_tmp:
120
+ img_name, seg_id, box, sentence = item
121
+ sent_index = ref_sentence_indices[seg_id]
122
+ total_sentences = ref_sentence_counts[seg_id]
123
+ images.append((img_name, seg_id, box, sentence, sent_index, total_sentences))
124
+ ref_sentence_indices[seg_id] += 1
125
+ self.images = images
126
+ else :
127
+ self.images = images_tmp
128
+ else :
129
+ self.images = images_tmp
130
+
131
+ def exists_dataset(self):
132
+ return osp.exists(osp.join(self.split_root, self.dataset))
133
+
134
+ def _get_hardpos_verb_rcc(self, seg_id, sent_idx):
135
+ emb_folder = os.path.join(self.all_hp_root, str(seg_id))
136
+ emb_files = sorted([f for f in os.listdir(emb_folder) if f.startswith(f"hp_") and f.endswith(".npy")])
137
+ if self.hp_selection == 'strict' :
138
+ # choose only corresponding (selected) sentence embedding
139
+ emb_file = emb_files[sent_idx]
140
+ else :
141
+ # choose any sentence embedding
142
+ emb_files = sorted([f for f in os.listdir(emb_folder) if f.startswith(f"hp_") and f.endswith(".npy")])
143
+ emb_file = random.choice(emb_files)
144
+ selected_emb = np.load(os.path.join(emb_folder, emb_file))
145
+ verb_embed = torch.from_numpy(selected_emb)
146
+ return verb_embed
147
+
148
+
149
+ def pull_item(self, idx):
150
+ # if metric learning and in train mode
151
+ if self.metric_learning and self.augment :
152
+ # sent_idx refers to index of sent among sent_num-1
153
+ img_file, seg_id, bbox, phrase, sent_idx, sent_num = self.images[idx]
154
+ else :
155
+ img_file, seg_id, bbox, phrase = self.images[idx]
156
+ bbox = np.array(bbox, dtype=int) # x1y1x2y2
157
+
158
+ img_path = osp.join(self.im_dir, img_file)
159
+ img = cv2.imread(img_path) # BGR [512, 640, 3]
160
+ ## duplicate channel if gray image
161
+ if img.shape[-1] > 1:
162
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #RGB
163
+ else:
164
+ img = np.stack([img] * 3, axis=-1)
165
+
166
+ ## seg map
167
+ seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy')) # [512, 640]
168
+ seg_map = np.array(seg_map).astype(np.float32)
169
+
170
+ if self.metric_learning and self.split == 'train' :
171
+ return img, phrase, bbox, seg_map, seg_id, sent_idx
172
+ else :
173
+ return img, phrase, bbox, seg_map, seg_id
174
+
175
+ def __len__(self):
176
+ return len(self.images)
177
+
178
+ def __getitem__(self, idx):
179
+ if self.metric_learning and self.augment :
180
+ img, phrase, bbox, seg_map, seg_id, sent_idx = self.pull_item(idx)
181
+ else :
182
+ img, phrase, bbox, seg_map, seg_id = self.pull_item(idx)
183
+
184
+ phrase = phrase.lower()
185
+ if self.augment:
186
+ augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
187
+ True, True, True, False, False, False
188
+
189
+ ## seems a bug in torch transformation resize, so separate in advance
190
+ h,w = img.shape[0], img.shape[1]
191
+ # print("img.shape", img.shape)
192
+ if self.augment:
193
+ ## random horizontal flip
194
+ if augment_flip and random.random() > 0.5:
195
+ img = cv2.flip(img, 1)
196
+ seg_map = cv2.flip(seg_map, 1)
197
+ bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1
198
+ phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left')
199
+
200
+ ## random copy and add left or right
201
+ if augment_copy:
202
+ img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)
203
+
204
+ ## random erase for occluded
205
+ if augment_erase:
206
+ img, seg_map = random_erase(img, seg_map)
207
+
208
+ ## random padding and crop
209
+ if augment_crop:
210
+ img, seg_map = random_crop(img, seg_map, 40, h, w)
211
+
212
+ ## random intensity, saturation change
213
+ if augment_hsv:
214
+ fraction = 0.50
215
+ img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
216
+ S = img_hsv[:, :, 1].astype(np.float32)
217
+ V = img_hsv[:, :, 2].astype(np.float32)
218
+ a = (random.random() * 2 - 1) * fraction + 1
219
+ if a > 1:
220
+ np.clip(S, a_min=0, a_max=255, out=S)
221
+ a = (random.random() * 2 - 1) * fraction + 1
222
+ V *= a
223
+ if a > 1:
224
+ np.clip(V, a_min=0, a_max=255, out=V)
225
+
226
+ img_hsv[:, :, 1] = S.astype(np.uint8)
227
+ img_hsv[:, :, 2] = V.astype(np.uint8)
228
+ img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)
229
+
230
+ img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
231
+ bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
232
+ bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
233
+
234
+ ## random affine transformation
235
+ if augment_affine:
236
+ img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \
237
+ degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10)) # 255 white fill
238
+
239
+ else: ## should be inference, or specified training
240
+ img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
241
+ bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
242
+ bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
243
+
244
+ draw_img = copy.deepcopy(img)
245
+ # Norm, to tensor
246
+ if self.transform is not None:
247
+ img = self.transform(img)
248
+
249
+
250
+ ## encode phrase to clip input
251
+ word_id = clip.tokenize(phrase, 17, truncate=True)
252
+ word_mask = ~ (word_id == 0)
253
+
254
+ orig_word_id = np.array(word_id, dtype=int)
255
+ orig_word_mask = np.array(word_mask, dtype=int)
256
+
257
+ # Get hardpos verb phrase
258
+ if self.metric_learning and self.augment:
259
+ original_emb = self._get_hardpos_verb_rcc(seg_id, sent_idx)
260
+
261
+ if self.augment: # train
262
+ seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2),interpolation=cv2.INTER_NEAREST) # (208, 208)
263
+ seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
264
+ if self.metric_learning :
265
+ params = {
266
+ 'seg_id' : seg_id,
267
+ 'sent' : phrase,
268
+ 'hardpos_emb' : original_emb.unsqueeze(0)
269
+ }
270
+ return img, orig_word_id, orig_word_mask, np.array(bbox, dtype=np.float32), \
271
+ np.array(seg_map, dtype=np.float32), params
272
+ else :
273
+ return img, orig_word_id, orig_word_mask, \
274
+ np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
275
+ else:
276
+ seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
277
+ return img, orig_word_id, orig_word_mask, \
278
+ np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
279
+ np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)
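
The refcoco/refcoco+ variant skips the hard-positive phrase JSON and instead loads one precomputed SBERT embedding per (seg_id, sentence) from hp_*.npy files. A sketch of the layout and the 'strict' selection it performs (the root path mirrors the one hard-coded above; seg_id and sent_idx are hypothetical example values):

import os
import numpy as np

all_hp_root = '/data2/dataset/RefCOCO/refcoco/SBERT_rcc_unc'  # hard-coded in the class above
seg_id, sent_idx = 12345, 0                                   # example values
folder = os.path.join(all_hp_root, str(seg_id))
files = sorted(f for f in os.listdir(folder) if f.startswith('hp_') and f.endswith('.npy'))
emb = np.load(os.path.join(folder, files[sent_idx]))          # shape (384,), matching self.emb_size
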
ASDA/dataset/data_loader_test.py ADDED
@@ -0,0 +1,315 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
5
+ """
6
+ import sys
7
+ import cv2
8
+ import torch
9
+ import random
10
+ import numpy as np
11
+ import os.path as osp
12
+ import torch.utils.data as data
13
+ sys.path.append('.')
14
+ import utils
15
+ import re
16
+
17
+ from pytorch_pretrained_bert.tokenization import BertTokenizer
18
+ from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase
19
+ import copy
20
+
21
+ import clip
22
+
23
+ sys.modules['utils'] = utils
24
+ cv2.setNumThreads(0)
25
+
26
+ def read_examples(input_line, unique_id):
27
+ """Read a list of `InputExample`s from an input file."""
28
+ examples = []
29
+ # unique_id = 0
30
+ line = input_line #reader.readline()
31
+ # if not line:
32
+ # break
33
+ line = line.strip()
34
+ text_a = None
35
+ text_b = None
36
+ m = re.match(r"^(.*) \|\|\| (.*)$", line)
37
+ if m is None:
38
+ text_a = line
39
+ else:
40
+ text_a = m.group(1) #'man in black'
41
+ text_b = m.group(2)
42
+
43
+ examples.append(
44
+ InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
45
+ # unique_id += 1
46
+ return examples
47
+
48
+ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
49
+ while True:
50
+ total_length = len(tokens_a) + len(tokens_b)
51
+ if total_length <= max_length:
52
+ break
53
+ if len(tokens_a) > len(tokens_b):
54
+ tokens_a.pop()
55
+ else:
56
+ tokens_b.pop()
57
+
58
+ ## Bert text encoding
59
+ class InputExample(object):
60
+ def __init__(self, unique_id, text_a, text_b):
61
+ self.unique_id = unique_id
62
+ self.text_a = text_a
63
+ self.text_b = text_b
64
+
65
+ class InputFeatures(object):
66
+ """A single set of features of data."""
67
+ def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
68
+ self.unique_id = unique_id
69
+ self.tokens = tokens
70
+ self.input_ids = input_ids
71
+ self.input_mask = input_mask
72
+ self.input_type_ids = input_type_ids
73
+
74
+ def convert_examples_to_features(examples, seq_length, tokenizer):
75
+ """Loads a data file into a list of `InputBatch`s."""
76
+ features = []
77
+ for (ex_index, example) in enumerate(examples):
78
+ tokens_a = tokenizer.tokenize(example.text_a) # ['far', 'left', 'vase']
79
+
80
+ tokens_b = None
81
+ if example.text_b:
82
+ tokens_b = tokenizer.tokenize(example.text_b)
83
+
84
+ if tokens_b:
85
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
86
+ # length is less than the specified length.
87
+ # Account for [CLS], [SEP], [SEP] with "- 3"
88
+ _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
89
+ else:
90
+ # Account for [CLS] and [SEP] with "- 2"
91
+ if len(tokens_a) > seq_length - 2:
92
+ tokens_a = tokens_a[0:(seq_length - 2)]
93
+ tokens = []
94
+ input_type_ids = []
95
+ tokens.append("[CLS]")
96
+ input_type_ids.append(0)
97
+ for token in tokens_a:
98
+ tokens.append(token)
99
+ input_type_ids.append(0)
100
+ tokens.append("[SEP]")
101
+ input_type_ids.append(0)
102
+
103
+ if tokens_b:
104
+ for token in tokens_b:
105
+ tokens.append(token)
106
+ input_type_ids.append(1)
107
+ tokens.append("[SEP]")
108
+ input_type_ids.append(1)
109
+
110
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
111
+
112
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
113
+ # tokens are attended to.
114
+ input_mask = [1] * len(input_ids)
115
+
116
+ # Zero-pad up to the sequence length.
117
+ while len(input_ids) < seq_length:
118
+ input_ids.append(0)
119
+ input_mask.append(0)
120
+ input_type_ids.append(0)
121
+
122
+ assert len(input_ids) == seq_length
123
+ assert len(input_mask) == seq_length
124
+ assert len(input_type_ids) == seq_length
125
+ features.append(
126
+ InputFeatures(
127
+ unique_id=example.unique_id,
128
+ tokens=tokens,
129
+ input_ids=input_ids,
130
+ input_mask=input_mask,
131
+ input_type_ids=input_type_ids))
132
+ return features
133
+
134
+ class DatasetNotFoundError(Exception):
135
+ pass
136
+
137
+ class ReferDataset(data.Dataset):
138
+ SUPPORTED_DATASETS = {
139
+ 'refcoco': {
140
+ 'splits': ('train', 'val', 'testA', 'testB'),
141
+ 'params': {'dataset': 'refcoco', 'split_by': 'unc'}
142
+ },
143
+ 'refcoco+': {
144
+ 'splits': ('train', 'val', 'testA', 'testB'),
145
+ 'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
146
+ },
147
+ 'refcocog': {
148
+ 'splits': ('train', 'val', 'test'),
149
+ 'params': {'dataset': 'refcocog', 'split_by': 'umd'}
150
+ },
151
+ 'refcocog_g': {
152
+ 'splits': ('train', 'val'),
153
+ 'params': {'dataset': 'refcocog', 'split_by': 'google'}
154
+ },
155
+ 'refcocog_u': {
156
+ 'splits': ('train', 'val', 'test'),
157
+ 'params': {'dataset': 'refcocog', 'split_by': 'umd'}
158
+ },
159
+ 'grefcoco': {
160
+ 'splits': ('train', 'val', 'testA', 'testB'),
161
+ 'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
162
+ }
163
+ }
164
+
165
+ def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
166
+ transform=None, augment=False, split='train', max_query_len=128,
167
+ bert_model='bert-base-uncased'):
168
+ self.images = []
169
+ self.data_root = data_root
170
+ self.split_root = split_root
171
+ self.dataset = dataset
172
+ self.imsize = imsize
173
+ self.query_len = max_query_len
174
+ self.transform = transform
175
+ self.split = split
176
+ self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True) # should be true for English
177
+ self.augment=augment
178
+
179
+ valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']
180
+
181
+ if split not in valid_splits:
182
+ raise ValueError(
183
+ 'Dataset {0} does not have split {1}'.format(
184
+ self.dataset, split))
185
+
186
+ self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt')
187
+ if self.dataset == 'refcocog_u' :
188
+ dataset = 'refcocog'
189
+ mask_anno_str = '{0}_{1}'.format(dataset, splitby)
190
+ self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
191
+ else :
192
+ self.mask_root = osp.join(self.data_root, 'masks', self.dataset)
193
+
194
+ self.im_dir = osp.join(self.data_root, 'images', 'train2014')
195
+
196
+ if self.dataset == 'refcocog_u' :
197
+ dataset = 'refcocog'
198
+ dataset_path = osp.join(self.split_root, dataset + '_' + splitby)
199
+ splits = [split]
200
+ for split in splits:
201
+ imgset_file = '{0}_{1}_{2}.pth'.format(dataset, splitby, split)
202
+ imgset_path = osp.join(dataset_path, imgset_file)
203
+ self.images += torch.load(imgset_path)
204
+ else :
205
+ dataset_path = osp.join(self.split_root, self.dataset)
206
+ splits = [split]
207
+ for split in splits:
208
+ imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
209
+ imgset_path = osp.join(dataset_path, imgset_file)
210
+ self.images += torch.load(imgset_path)
211
+
212
+ # def exists_dataset(self):
213
+ # return osp.exists(osp.join(self.split_root, self.dataset))
214
+
215
+ def pull_item(self, idx):
216
+ img_file, seg_id, bbox, phrase = self.images[idx]
217
+ bbox = np.array(bbox, dtype=int) # x1y1x2y2
218
+
219
+ img_path = osp.join(self.im_dir, img_file)
220
+ img = cv2.imread(img_path) # BGR [512, 640, 3]
221
+ ## duplicate channel if gray image
222
+ if img.shape[-1] > 1:
223
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #RGB
224
+ else:
225
+ img = np.stack([img] * 3, axis=-1)
226
+
227
+ ## seg map
228
+ seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy')) # [512, 640]
229
+ seg_map = np.array(seg_map).astype(np.float32)
230
+ return img, phrase, bbox, seg_map
231
+
232
+ def __len__(self):
233
+ return len(self.images)
234
+
235
+ def __getitem__(self, idx):
236
+ img, phrase, bbox, seg_map = self.pull_item(idx)
237
+ phrase = phrase.lower()
238
+ if self.augment:
239
+ augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
240
+ True, True, True, False, False, False
241
+
242
+ ## seems a bug in torch transformation resize, so separate in advance
243
+ h,w = img.shape[0], img.shape[1]
244
+ # print("img.shape", img.shape)
245
+ if self.augment:
246
+ ## random horizontal flip
247
+ if augment_flip and random.random() > 0.5:
248
+ img = cv2.flip(img, 1)
249
+ seg_map = cv2.flip(seg_map, 1)
250
+ bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1
251
+ phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left')
252
+
253
+ ## random copy and add left or right
254
+ if augment_copy:
255
+ img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)
256
+
257
+ ## random erase for occluded
258
+ if augment_erase:
259
+ img, seg_map = random_erase(img, seg_map)
260
+
261
+ ## random padding and crop
262
+ if augment_crop:
263
+ img, seg_map = random_crop(img, seg_map, 40, h, w)
264
+
265
+ ## random intensity, saturation change
266
+ if augment_hsv:
267
+ fraction = 0.50
268
+ img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
269
+ S = img_hsv[:, :, 1].astype(np.float32)
270
+ V = img_hsv[:, :, 2].astype(np.float32)
271
+ a = (random.random() * 2 - 1) * fraction + 1
272
+ if a > 1:
273
+ np.clip(S, a_min=0, a_max=255, out=S)
274
+ a = (random.random() * 2 - 1) * fraction + 1
275
+ V *= a
276
+ if a > 1:
277
+ np.clip(V, a_min=0, a_max=255, out=V)
278
+
279
+ img_hsv[:, :, 1] = S.astype(np.uint8)
280
+ img_hsv[:, :, 2] = V.astype(np.uint8)
281
+ img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)
282
+
283
+ img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
284
+ bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
285
+ bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
286
+
287
+ ## random affine transformation
288
+ if augment_affine:
289
+ img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \
290
+ degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10)) # 255 white fill
291
+
292
+ else: ## should be inference, or specified training
293
+ img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
294
+ bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
295
+ bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
296
+
297
+ draw_img = copy.deepcopy(img)
298
+ # Norm, to tensor
299
+ if self.transform is not None:
300
+ img = self.transform(img)
301
+
302
+ ## encode phrase to clip input
303
+ word_id = clip.tokenize(phrase, 17, truncate=True)
304
+ word_mask = ~ (word_id == 0)
305
+
306
+ if self.augment: # train
307
+ seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2),interpolation=cv2.INTER_NEAREST) # (208, 208)
308
+ seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
309
+ return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
310
+ np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
311
+ else:
312
+ seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
313
+ return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
314
+ np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
315
+ np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)
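
In evaluation mode (augment=False) this loader also returns the letterbox parameters ratio, dw, dh plus the image name, phrase, and an un-normalized copy of the image, so predictions made in the imsize x imsize frame can be mapped back to the original resolution. A small helper sketch inverting the box transform applied above:

import numpy as np

def unletterbox_box(box, ratio, dw, dh):
    # inverse of: x' = x * ratio + dw, y' = y * ratio + dh
    x1, y1, x2, y2 = box
    return np.array([(x1 - dw) / ratio, (y1 - dh) / ratio,
                     (x2 - dw) / ratio, (y2 - dh) / ratio])
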
ASDA/dataset/data_process.py ADDED
@@ -0,0 +1,225 @@
1
+ # encoding=utf8
2
+ # %matplotlib inline
3
+ import numpy as np
4
+ import os
5
+ from refer import REFER
6
+ import os.path as osp
7
+ import cv2
8
+ import argparse
9
+ parser = argparse.ArgumentParser(description='Data preparation')
10
+ parser.add_argument('--data_root', type=str) # contains refclef, refcoco, refcoco+, refcocog and images
11
+ parser.add_argument('--output_dir', type=str)
12
+ parser.add_argument('--dataset', type=str, choices=['refcoco', 'refcoco+','refcocog'], default='refcoco')
13
+ parser.add_argument('--split', type=str,default='umd')
14
+ parser.add_argument('--generate_mask', action='store_true')
15
+ args = parser.parse_args()
16
+ # data_root # contains refclef, refcoco, refcoco+, refcocog and images
17
+ refer = REFER(args.data_root, args.dataset, args.split)
18
+
19
+ print ('dataset [%s_%s] contains: ' % (args.dataset, args.split))
20
+ ref_ids = refer.getRefIds()
21
+ image_ids = refer.getImgIds()
22
+ print ('%s expressions for %s refs in %s images.' % (len(refer.Sents), len(ref_ids), len(image_ids)))
23
+
24
+ print('\nAmong them:')
25
+ if args.dataset == 'refclef':
26
+ if args.split == 'unc':
27
+ splits = ['train', 'val', 'testA','testB','testC']
28
+ else:
29
+ splits = ['train', 'val', 'test']
30
+ elif args.dataset == 'refcoco':
31
+ splits = ['train', 'val', 'testA', 'testB']
32
+ elif args.dataset == 'refcoco+':
33
+ splits = ['train', 'val', 'testA', 'testB']
34
+ elif args.dataset == 'grefcoco':
35
+ splits = ['train', 'val', 'testA', 'testB']
36
+ elif args.dataset == 'refcocog':
37
+ splits = ['train', 'val', 'test'] # we don't have test split for refcocog right now.
38
+
39
+
40
+
41
+ # split data as a type in splits list
42
+ for split in splits:
43
+ ref_ids = refer.getRefIds(split=split)
44
+ print('%s refs are in split [%s].' % (len(ref_ids), split))
45
+
46
+
47
+ # show a batch data with bounding box,cat,sentences
48
+ def show_a_batch(batch_size):
49
+ split='train'
50
+ # batch_size=32
51
+ ref_ids = refer.getRefIds(split=split)
52
+ print(split+'_size:',len(ref_ids))
53
+ batch_index=list(np.random.choice(len(ref_ids),batch_size))
54
+
55
+ # print(refer.Refs)
56
+ ref_id = [ref_ids[i] for i in batch_index]
57
+ refs = [refer.Refs[i] for i in ref_id]
58
+ bboxs=[refer.getRefBox(i) for i in ref_id]
59
+ sentences=[ref['sentences'] for ref in refs]
60
+ image_urls=[refer.loadImgs(image_ids=ref['image_id']) for ref in refs]
61
+ cats = [refer.loadCats(cat_ids=ref['category_id']) for ref in refs]
62
+ # plt.figure()
63
+ # plt.subplot(batch_size)
64
+ grid_width = 2
65
+ grid_height = int(batch_size / grid_width)
66
+ # fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width*10, 10*grid_height))
67
+ for i in range(batch_size):
68
+ print('bbox for batch[{}]:'.format(i),bboxs[i])
69
+ print('sentences for batch[{}]:'.format(i))
70
+ for sid, sent in enumerate(sentences[i]):
71
+ print('%s. %s' % (sid+1, sent['sent']))
72
+ print('cats for batch[{}]:'.format(i), cats[i])
73
+
74
+ image_url=image_urls[i][0]
75
+ image=cv2.imread(osp.join(refer.IMAGE_DIR, image_url['file_name']))
76
+ print(image.shape)
77
+ # print(bboxs[i][0])
78
+ cv2.rectangle(image,(int(bboxs[i][0]), int(bboxs[i][1])), (int(bboxs[i][0]+bboxs[i][2]),int(bboxs[i][1]+ bboxs[i][3])),255,3)
79
+ cv2.putText(image,
80
+ str(sent['sent']),
81
+ (20, 20),
82
+ cv2.FONT_HERSHEY_SIMPLEX,
83
+ .9,(0,255,0), 2)
84
+ os.makedirs('debug_vis', exist_ok=True)
85
+ cv2.imwrite('./debug_vis/'+image_url['file_name'], image)
86
+ cv2.imwrite('./debug_vis/mask'+image_url['file_name'], refer.getMask(refs[i])['mask']*255)
87
+ # ax.imshow(image)
88
+ # plt.show()
89
+
90
+ def cat_process(cat):
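+ # maps the non-contiguous COCO category ids (1-90, with gaps) onto contiguous 0-79 class indices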
91
+ if cat >= 1 and cat <= 11:
92
+ cat = cat - 1
93
+ elif cat >= 13 and cat <= 25:
94
+ cat = cat - 2
95
+ elif cat >= 27 and cat <= 28:
96
+ cat = cat - 3
97
+ elif cat >= 31 and cat <= 44:
98
+ cat = cat - 5
99
+ elif cat >= 46 and cat <= 65:
100
+ cat = cat - 6
101
+ elif cat == 67:
102
+ cat = cat - 7
103
+ elif cat == 70:
104
+ cat = cat - 9
105
+ elif cat >= 72 and cat <= 82:
106
+ cat = cat - 10
107
+ elif cat >= 84 and cat <= 90:
108
+ cat = cat - 11
109
+ return cat
110
+
111
+ def bbox_process(bbox,cat,segement_id):
112
+ x_min = int(bbox[0])
113
+ y_min = int(bbox[1])
114
+ x_max = x_min + int(bbox[2])
115
+ y_max = y_min + int(bbox[3])
116
+ box_info = " %d,%d,%d,%d,%d,%d" % (int(x_min), int(y_min), int(x_max), int(y_max), int(cat),int(segement_id))
117
+ return box_info
118
+
119
+ def prepare_dataset(dataset,splits,output_dir,generate_mask=False):
120
+ # split_type='train'
121
+ # splits=[split_type]
122
+ # batch_size=32
123
+ if dataset == 'refcocog':
124
+ dataset = 'refcocog_' + args.split
125
+ if not os.path.exists(os.path.join(output_dir,'anns',dataset)):
126
+ os.makedirs(os.path.join(output_dir,'anns',dataset))
127
+ if not os.path.exists(os.path.join(output_dir,'masks',dataset)):
128
+ os.makedirs(os.path.join(output_dir,'masks',dataset))
129
+ for split in splits:
130
+ f = open(os.path.join(output_dir,'anns', dataset, split + '.txt'), 'w', encoding='utf-8')
131
+ # print(split)
132
+ split_num=0
133
+ ll=0
134
+ ref_ids = refer.getRefIds(split=split)
135
+ print(split+'_size:',len(ref_ids))
136
+ for i in ref_ids:
137
+ # ref_id = ref_ids[i]
138
+ refs = refer.Refs[i]
139
+ bboxs=refer.getRefBox(i)
140
+ print("bboxs", bboxs)
141
+ sentences=refs['sentences']
142
+ image_urls=refer.loadImgs(image_ids=refs['image_id'])[0]
143
+
144
+ # grefcoco中的category_id是一个list
145
+ cat = refs['category_id']
146
+ if type(cat) == list:
147
+ for j in range(len(cat)):
148
+ cat[j] = cat_process(cat[j])
149
+ else:
150
+ cat = cat_process(cat)
151
+
152
+ image_urls=image_urls['file_name']
153
+ if dataset=='refclef' and image_urls in ['19579.jpg', '17975.jpg', '19575.jpg']:
154
+ continue
155
+ # RES中box信息和cat信息使用不到
156
+ if type(bboxs[0]) == list:
157
+ box_info = bbox_process(bboxs[0], cat[0], i) # add segment id
158
+ else:
159
+ box_info = bbox_process(bboxs, cat, i) # add segment id
160
+ f.write(image_urls)
161
+ f.write(box_info)
162
+ # f.write(' '+str(i))
163
+ if generate_mask:
164
+ if dataset == 'grefcoco':
165
+ np.save(os.path.join(output_dir,'masks',dataset,str(i)+'.npy'),refer.getMaskByRef(refs, merge=True)['mask'])
166
+ else:
167
+ np.save(os.path.join(output_dir,'masks',dataset,str(i)+'.npy'),refer.getMask(refs)['mask']) #if need seg mask ,set it!
168
+ for sentence in sentences:
169
+ f.write(' ~ ')
170
+ # print(sentence['sent'].encode('UTF-8'))
171
+ f.write(sentence['sent'])
172
+ if ll<len(sentence['sent']):
173
+ ll=len(sentence['sent'])
174
+ f.write('\n')
175
+ split_num+=1
176
+ print('split_num:',split_num)
177
+ print('max_len:',ll)
178
+ f.close()
179
+
180
+ def prepare_sentences_refcoco():
181
+ splits=['train','val']
182
+ # batch_size=32
183
+ f = open('sentences.txt', 'w')
184
+ for split in splits:
185
+ print(split)
186
+ ref_ids = refer.getRefIds(split=split)
187
+ print(split+'_size:',len(ref_ids))
188
+ for i in ref_ids: # iterate over ref ids directly rather than positional indices
189
+ refs = refer.Refs[i]
190
+ sentences=refs['sentences']
191
+ for sentence in sentences:
192
+ f.write(sentence['sent'])
193
+ f.write('\n')
194
+ f.close()
195
+
196
+ def test_length():
197
+ max_len=0
198
+ word_l_count = np.zeros([50], dtype=int) # np.int was removed in recent NumPy releases
199
+ with open('./refcocog/train.txt') as f:
200
+ lines = f.readlines()
201
+ for j in range(len(lines)):
202
+ line=lines[j].split()
203
+ stop = len(line)
204
+ for i in range(1, len(line)):
205
+ if (line[i] == '~'):
206
+ stop = i
207
+ break
208
+ sentences = []
209
+ sent_stop = stop + 1
210
+ for i in range(stop + 1, len(line)):
211
+ if line[i] == '~':
212
+ # sentences.append(line[sent_stop:i])
213
+ # print(len(line[sent_stop:i]))
214
+ word_l_count[len(line[sent_stop:i])]+=1
215
+ # if len(line[sent_stop:i])>max_len:
216
+ # max_len=len(line[sent_stop:i])
217
+ sent_stop = i + 1
218
+ for i in range(50):
219
+ if word_l_count[i]>0:
220
+ print('length:%d'%i,',count:%d'%word_l_count[i])
221
+ # print('max_len:',max_len)
222
+ # print(len(lines))
223
+
224
+
225
+ prepare_dataset(args.dataset,splits,args.output_dir,args.generate_mask)
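For reference, each annotation line written by prepare_dataset above concatenates the image file name, the box string from bbox_process (" x_min,y_min,x_max,y_max,cat,ref_id", with cat already remapped to a contiguous range by cat_process), and one " ~ "-separated block per referring sentence; with --generate_mask, the binary mask of each ref is also saved to masks/<dataset>/<ref_id>.npy. A minimal sketch of how such a line could be parsed back (parse_ann_line is a hypothetical helper, not part of this repository, and assumes the format above):

# sketch only: parse one line of anns/<dataset>/<split>.txt
def parse_ann_line(line):
    # split off the " ~ "-separated referring sentences
    head, *sentences = line.strip().split(' ~ ')
    # head is "<image_file_name> <x_min,y_min,x_max,y_max,cat,ref_id>"
    img_name, box_str = head.split(' ', 1)
    x_min, y_min, x_max, y_max, cat, ref_id = map(int, box_str.split(','))
    return img_name, (x_min, y_min, x_max, y_max), cat, ref_id, sentences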
ASDA/dataset/datascript.py ADDED
@@ -0,0 +1,56 @@
1
+ # generate **.pth
2
+ import os
3
+ import sys
4
+ import torch
5
+ sys.path.append('.')
6
+
7
+ import argparse
8
+ parser = argparse.ArgumentParser(description='Data preparation')
9
+ parser.add_argument('--dataset', type=str, choices=['refcoco', 'refcoco+','refcocog_google', 'refcocog_umd'], default='refcoco')
10
+ args = parser.parse_args()
11
+
12
+ def main(args):
13
+ dataset = args.dataset
14
+ input_txt_list = os.listdir(f'../ln_data/anns/{dataset}')
15
+ if not os.path.exists(f'../data/{dataset}'):
16
+ os.makedirs(f'../data/{dataset}')
17
+ for input_txt in input_txt_list:
18
+ split = input_txt.split('_')[-1].split('.')[0]
19
+ input_txt = os.path.join('../ln_data/anns', dataset, input_txt)
20
+ res = []
21
+ with open(input_txt, encoding='utf-8') as f:
22
+ lines = f.readlines()
23
+ for line in lines:
24
+ line = line.split()
25
+ stop = len(line)
26
+ img_name = line[0]
27
+ for i in range(1,len(line)):
28
+ if (line[i]=='~'):
29
+ stop=i
30
+ break
31
+ box_ = [list(map(int,box.split(','))) for box in line[1:stop]]
32
+ box = box_[0][:4]
33
+ seg_id=box_[0][-1]
34
+
35
+ sent_stop=stop+1
36
+ for i in range(stop+1,len(line)):
37
+ if line[i]=='~':
38
+ des = ''
39
+ for word in line[sent_stop:i]:
40
+ des = des + word + ' '
41
+ sent_stop=i+1
42
+ des = des.rstrip(' ')
43
+ res.append((img_name, seg_id, box, des))
44
+ des = ''
45
+ for word in line[sent_stop:len(line)]:
46
+ des = des + word + ' '
47
+ des = des.rstrip(' ')
48
+ res.append((img_name, seg_id, box, des))
49
+ # print(res)
50
+
51
+ imgset_path = '{0}_{1}.pth'.format(dataset, split)
52
+ torch.save(res, os.path.join("../data", dataset, imgset_path))
53
+ print(dataset, " done")
54
+
55
+ if __name__ == "__main__":
56
+ main(args)
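The generated .pth file is simply a Python list of (image_name, segment_id, box, sentence) tuples serialized with torch.save, one tuple per referring expression, so it can be inspected directly. A quick sketch (the path below is illustrative):

import torch

# illustrative: inspect the samples produced by datascript.py
samples = torch.load('../data/refcoco/refcoco_train.pth')
img_name, seg_id, box, sentence = samples[0]
print(len(samples))                    # number of (image, expression) pairs
print(img_name, seg_id, box, sentence) # e.g. file name, ref id, [x1, y1, x2, y2], expression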
ASDA/dataset/refer.py ADDED
@@ -0,0 +1,485 @@
1
+ __author__ = 'licheng'
2
+
3
+ """
4
+ This interface provides access to four datasets:
5
+ 1) refclef
6
+ 2) refcoco
7
+ 3) refcoco+
8
+ 4) refcocog
9
+ split by unc and google
10
+ The following API functions are defined:
11
+ REFER - REFER api class
12
+ getRefIds - get ref ids that satisfy given filter conditions.
13
+ getAnnIds - get ann ids that satisfy given filter conditions.
14
+ getImgIds - get image ids that satisfy given filter conditions.
15
+ getCatIds - get category ids that satisfy given filter conditions.
16
+ loadRefs - load refs with the specified ref ids.
17
+ loadAnns - load anns with the specified ann ids.
18
+ loadImgs - load images with the specified image ids.
19
+ loadCats - load category names with the specified category ids.
20
+ getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
21
+ showRef - show image, segmentation or box of the referred object with the ref
22
+ getMask - get mask and area of the referred object given ref
23
+ showMask - show mask of the referred object given ref
24
+ """
25
+
26
+ import sys
27
+ import os.path as osp
28
+ import os
29
+ import json
30
+ # import _pickle as pickle
31
+ import pickle
32
+ import time
33
+ import itertools
34
+ import skimage.io as io
35
+ import matplotlib.pyplot as plt
36
+ from matplotlib.collections import PatchCollection
37
+ from matplotlib.patches import Polygon, Rectangle
38
+ from pprint import pprint
39
+ import numpy as np
40
+ from pycocotools import mask
41
+ import cv2
42
+ # from skimage.measure import label, regionprops
43
+
44
+ class REFER:
45
+ def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
46
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
47
+ # also provide dataset name and splitBy information
48
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
49
+ print('loading dataset %s into memory...' % dataset)
50
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
51
+ self.DATA_DIR = osp.join(data_root, dataset)
52
+ if dataset in ['refcoco', 'refcoco+', 'refcocog']:
53
+ self.IMAGE_DIR = osp.join(data_root, 'images/train2014')
54
+ elif dataset == 'refclef':
55
+ self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
56
+ else:
57
+ print('No refer dataset is called [%s]' % dataset)
58
+ sys.exit()
59
+
60
+ # load refs from data/dataset/refs(dataset).json
61
+ tic = time.time()
62
+ ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
63
+ self.data = {}
64
+ self.data['dataset'] = dataset
65
+
66
+ self.data['refs'] = pickle.load(open(ref_file, 'rb'),fix_imports=True)
67
+
68
+ # load annotations from data/dataset/instances.json
69
+ instances_file = osp.join(self.DATA_DIR, 'instances.json')
70
+ instances = json.load(open(instances_file, 'r'))
71
+ self.data['images'] = instances['images']
72
+ self.data['annotations'] = instances['annotations']
73
+ self.data['categories'] = instances['categories']
74
+
75
+ # create index
76
+ self.createIndex()
77
+ print('DONE (t=%.2fs)' % (time.time()-tic))
78
+
79
+ def createIndex(self):
80
+ # create sets of mapping
81
+ # 1) Refs: {ref_id: ref}
82
+ # 2) Anns: {ann_id: ann}
83
+ # 3) Imgs: {image_id: image}
84
+ # 4) Cats: {category_id: category_name}
85
+ # 5) Sents: {sent_id: sent}
86
+ # 6) imgToRefs: {image_id: refs}
87
+ # 7) imgToAnns: {image_id: anns}
88
+ # 8) refToAnn: {ref_id: ann}
89
+ # 9) annToRef: {ann_id: ref}
90
+ # 10) catToRefs: {category_id: refs}
91
+ # 11) sentToRef: {sent_id: ref}
92
+ # 12) sentToTokens: {sent_id: tokens}
93
+ print('creating index...')
94
+ # fetch info from instances
95
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
96
+ for ann in self.data['annotations']:
97
+ Anns[ann['id']] = ann
98
+ imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
99
+ for img in self.data['images']:
100
+ Imgs[img['id']] = img
101
+ for cat in self.data['categories']:
102
+ Cats[cat['id']] = cat['name']
103
+
104
+ # fetch info from refs
105
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
106
+ Sents, sentToRef, sentToTokens = {}, {}, {}
107
+ for ref in self.data['refs']:
108
+ # ids
109
+ ref_id = ref['ref_id']
110
+ ann_id = ref['ann_id']
111
+ category_id = ref['category_id']
112
+ image_id = ref['image_id']
113
+
114
+ # add mapping related to ref
115
+ Refs[ref_id] = ref
116
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
117
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
118
+ refToAnn[ref_id] = Anns[ann_id]
119
+ annToRef[ann_id] = ref
120
+
121
+ # add mapping of sent
122
+ for sent in ref['sentences']:
123
+ Sents[sent['sent_id']] = sent
124
+ sentToRef[sent['sent_id']] = ref
125
+ sentToTokens[sent['sent_id']] = sent['tokens']
126
+
127
+ # create class members
128
+ self.Refs = Refs
129
+ self.Anns = Anns
130
+ self.Imgs = Imgs
131
+ self.Cats = Cats
132
+ self.Sents = Sents
133
+ self.imgToRefs = imgToRefs
134
+ self.imgToAnns = imgToAnns
135
+ self.refToAnn = refToAnn
136
+ self.annToRef = annToRef
137
+ self.catToRefs = catToRefs
138
+ self.sentToRef = sentToRef
139
+ self.sentToTokens = sentToTokens
140
+ print('index created.')
141
+
142
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
143
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
144
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
145
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
146
+
147
+ if len(image_ids)==len(cat_ids)==len(ref_ids)==len(split)==0:
148
+ refs = self.data['refs']
149
+ else:
150
+ if not len(image_ids) == 0:
151
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
152
+ else:
153
+ refs = self.data['refs']
154
+ if not len(cat_ids) == 0:
155
+ refs = [ref for ref in refs if ref['category_id'] in cat_ids]
156
+ if not len(ref_ids) == 0:
157
+ refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
158
+ if not len(split) == 0:
159
+ if split in ['testA', 'testB', 'testC']:
160
+ refs = [ref for ref in refs if split[-1] in ref['split']] # we also consider testAB, testBC, ...
161
+ elif split in ['testAB', 'testBC', 'testAC']:
162
+ refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess...
163
+ elif split == 'test':
164
+ refs = [ref for ref in refs if 'test' in ref['split']]
165
+ elif split == 'train' or split == 'val':
166
+ refs = [ref for ref in refs if ref['split'] == split]
167
+ else:
168
+ print('No such split [%s]' % split)
169
+ sys.exit()
170
+ ref_ids = [ref['ref_id'] for ref in refs]
171
+ return ref_ids
172
+
173
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
174
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
175
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
176
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
177
+
178
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
179
+ ann_ids = [ann['id'] for ann in self.data['annotations']]
180
+ else:
181
+ if not len(image_ids) == 0:
182
+ lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
183
+ anns = list(itertools.chain.from_iterable(lists))
184
+ else:
185
+ anns = self.data['annotations']
186
+ if not len(cat_ids) == 0:
187
+ anns = [ann for ann in anns if ann['category_id'] in cat_ids]
188
+ ann_ids = [ann['id'] for ann in anns]
189
+ if not len(ref_ids) == 0:
190
+ ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
191
+ return ann_ids
192
+
193
+ def getImgIds(self, ref_ids=[]):
194
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
195
+
196
+ if not len(ref_ids) == 0:
197
+ image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
198
+ else:
199
+ image_ids = self.Imgs.keys()
200
+ return image_ids
201
+
202
+ def getCatIds(self):
203
+ return self.Cats.keys()
204
+
205
+ def loadRefs(self, ref_ids=[]):
206
+ if type(ref_ids) == list:
207
+ return [self.Refs[ref_id] for ref_id in ref_ids]
208
+ elif type(ref_ids) == int:
209
+ return [self.Refs[ref_ids]]
210
+
211
+ def loadAnns(self, ann_ids=[]):
212
+ if type(ann_ids) == list:
213
+ return [self.Anns[ann_id] for ann_id in ann_ids]
214
+ elif type(ann_ids) == int or type(ann_ids) == str: # 'unicode' only existed in Python 2
215
+ return [self.Anns[ann_ids]]
216
+
217
+ def loadImgs(self, image_ids=[]):
218
+ if type(image_ids) == list:
219
+ return [self.Imgs[image_id] for image_id in image_ids]
220
+ elif type(image_ids) == int:
221
+ return [self.Imgs[image_ids]]
222
+
223
+ def loadCats(self, cat_ids=[]):
224
+ if type(cat_ids) == list:
225
+ return [self.Cats[cat_id] for cat_id in cat_ids]
226
+ elif type(cat_ids) == int:
227
+ return [self.Cats[cat_ids]]
228
+
229
+ def getRefBox(self, ref_id):
230
+ ref = self.Refs[ref_id]
231
+ ann = self.refToAnn[ref_id]
232
+ return ann['bbox'] # [x, y, w, h]
233
+
234
+ def showRef(self, ref, seg_box='seg'):
235
+ ax = plt.gca()
236
+ # show image
237
+ image = self.Imgs[ref['image_id']]
238
+ I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
239
+ ax.imshow(I)
240
+ # show refer expression
241
+ for sid, sent in enumerate(ref['sentences']):
242
+ print('%s. %s' % (sid+1, sent['sent']))
243
+ # show segmentations
244
+ if seg_box == 'seg':
245
+ ann_id = ref['ann_id']
246
+ ann = self.Anns[ann_id]
247
+ polygons = []
248
+ color = []
249
+ c = 'none'
250
+ if type(ann['segmentation'][0]) == list:
251
+ # polygon used for refcoco*
252
+ for seg in ann['segmentation']:
253
+ poly = np.array(seg).reshape((len(seg)//2, 2))
254
+ polygons.append(Polygon(poly, True, alpha=0.4))
255
+ color.append(c)
256
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(1,1,0,0), linewidths=3, alpha=1)
257
+ ax.add_collection(p) # thick yellow polygon
258
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(1,0,0,0), linewidths=1, alpha=1)
259
+ ax.add_collection(p) # thin red polygon
260
+ else:
261
+ # mask used for refclef
262
+ rle = ann['segmentation']
263
+ m = mask.decode(rle)
264
+ img = np.ones( (m.shape[0], m.shape[1], 3) )
265
+ color_mask = np.array([2.0,166.0,101.0])/255
266
+ for i in range(3):
267
+ img[:,:,i] = color_mask[i]
268
+ ax.imshow(np.dstack( (img, m*0.5) ))
269
+ # show bounding-box
270
+ elif seg_box == 'box':
271
+ ann_id = ref['ann_id']
272
+ print(ann_id)
273
+ ann = self.Anns[ann_id]
274
+ bbox = self.getRefBox(ref['ref_id'])
275
+ box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
276
+ ax.add_patch(box_plot)
277
+
278
+ def getMask(self, ref):
279
+ # return mask, area and mask-center
280
+ ann = self.refToAnn[ref['ref_id']]
281
+ print(ann)
282
+ image = self.Imgs[ref['image_id']]
283
+ if type(ann['segmentation'][0]) == list: # polygon
284
+ rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width'])
285
+ else:
286
+ rle = ann['segmentation']
287
+
288
+ # for i in range(len(rle['counts'])):
289
+ # print(rle)
290
+ m = mask.decode(rle)
291
+ m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs)
292
+ m = m.astype(np.uint8) # convert to np.uint8
293
+ # compute area
294
+ area = sum(mask.area(rle)) # should be close to ann['area']
295
+ return {'mask': m, 'area': area}
296
+ # # position
297
+ # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
298
+ # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
299
+ # # mass position (if there were multiple regions, we use the largest one.)
300
+ # label_m = label(m, connectivity=m.ndim)
301
+ # regions = regionprops(label_m)
302
+ # if len(regions) > 0:
303
+ # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
304
+ # largest_props = regions[largest_id]
305
+ # mass_y, mass_x = largest_props.centroid
306
+ # else:
307
+ # mass_x, mass_y = position_x, position_y
308
+ # # if centroid is not in mask, we find the closest point to it from mask
309
+ # if m[mass_y, mass_x] != 1:
310
+ # print 'Finding closes mask point ...'
311
+ # kernel = np.ones((10, 10),np.uint8)
312
+ # me = cv2.erode(m, kernel, iterations = 1)
313
+ # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
314
+ # points = np.array(points)
315
+ # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
316
+ # id = np.argsort(dist)[0]
317
+ # mass_y, mass_x = points[id]
318
+ # # return
319
+ # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
320
+ # # show image and mask
321
+ # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
322
+ # plt.figure()
323
+ # plt.imshow(I)
324
+ # ax = plt.gca()
325
+ # img = np.ones( (m.shape[0], m.shape[1], 3) )
326
+ # color_mask = np.array([2.0,166.0,101.0])/255
327
+ # for i in range(3):
328
+ # img[:,:,i] = color_mask[i]
329
+ # ax.imshow(np.dstack( (img, m*0.5) ))
330
+ # plt.show()
331
+
332
+ def showMask(self, ref):
333
+ M = self.getMask(ref)
334
+ msk = M['mask']
335
+ ax = plt.gca()
336
+ ax.imshow(msk)
337
+
338
+
339
+ if __name__ == '__main__':
340
+ refer = REFER(data_root="/home/ypf/workspace/code/BKINet/ln_data", dataset='refcoco', splitBy='unc')
341
+ save_path = "./visualization/"
342
+ ref_ids = refer.getRefIds()
343
+ print(len(ref_ids))
344
+
345
+ print(len(refer.Imgs))
346
+ print(len(refer.imgToRefs))
347
+ print(refer.Cats)
348
+
349
+ ref_ids = refer.getRefIds(split='train')
350
+ print('There are %s training referred objects.' % len(ref_ids))
351
+
352
+ img_ids = [8936, 52563]
353
+ # ref_ids = refer.getRefIds(image_ids=img_ids)
354
+
355
+ # refs = refer.loadRefs(ref_ids)
356
+
357
+ def custom_vis1(image, mask_):
358
+ # apply the mask to a blue layer
359
+ # create a blue layer
360
+ blue_layer = np.zeros_like(image)
361
+ blue_layer[:, :, 0] = 255 # OpenCV uses BGR, so the blue channel comes first
362
+ blue_mask = cv2.bitwise_and(blue_layer, blue_layer, mask=mask_)
363
+
364
+ # overlay the blue mask onto the original image with some transparency
365
+ alpha = 0.1 # transparency
366
+ cv2.addWeighted(blue_mask, alpha, image, 1 - alpha, 0, image)
367
+
368
+ def custom_vis2(image, mask_):
369
+ # create a blue layer
370
+ blue_layer = np.zeros_like(image)
371
+ blue_layer[:, :, 0] = 255 # OpenCV uses BGR, so the blue channel comes first
372
+
373
+ # apply the mask to the blue layer
374
+ blue_mask = cv2.bitwise_and(blue_layer, blue_layer, mask=mask_)
375
+
376
+ # alpha controls how strongly the mask layer is blended with the original image
377
+ alpha = 0.5 # transparency
378
+
379
+ # create a fully transparent layer
380
+ transparent_layer = np.zeros_like(image)
381
+
382
+ # apply the blue layer only inside the mask region, using alpha to control transparency
383
+ for i in range(3): # process only the three color channels
384
+ transparent_layer[:, :, i] = cv2.addWeighted(blue_mask[:, :, i], alpha, image[:, :, i], 1 - alpha, 0)
385
+
386
+ # keep the original image outside the mask region
387
+ transparent_layer[mask_ == 0] = image[mask_ == 0]
388
+
389
+ return transparent_layer
390
+
391
+ def custom_vis3(image, mask_):
392
+ """
393
+ Recolor the masked region of the original image to blue in place,
394
+ without changing the brightness or color of other regions.
395
+ """
396
+ image[mask_ != 0] = [255, 0, 0] # OpenCV uses BGR channel order
397
+
398
+ def custom_vis4(image, mask_, alpha=0.4):
399
+ """
400
+ Apply a blue overlay to the image with the given transparency.
401
+ alpha: overlay opacity, from 0 (fully transparent) to 1 (fully opaque).
402
+ """
403
+ # convert the image from BGR to BGRA to add an alpha channel
404
+ image_rgba = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
405
+ # create an all-blue layer of the same size
406
+ blue_mask = np.zeros_like(image_rgba)
407
+ blue_mask[:, :, 0] = 255 # B
408
+ blue_mask[:, :, 3] = 255 # alpha set to fully opaque
409
+
410
+ # apply the transparency to the masked region
411
+ blue_mask[mask_ != 0, 3] = int(alpha * 255)
412
+
413
+ # blend the blue overlay onto the original image
414
+ image_rgba = cv2.addWeighted(image_rgba, 1, blue_mask, alpha, 0)
415
+ return image_rgba
416
+
417
+
418
+
419
+
420
+ for i, img_id in enumerate(img_ids):
421
+ ref = refer.imgToRefs[img_id][0]
422
+ print(ref)
423
+ mask_ = refer.getMask(ref)['mask']
424
+ # sentence = ref['sentences'][0]['sent']
425
+
426
+ img = refer.Imgs[img_id]
427
+ # I = io.imread(osp.join(refer.IMAGE_DIR, img['file_name']))
428
+ # assume image_path is the path to the original image and mask_ is a binary array of the same size as the image
429
+ image_path = osp.join(refer.IMAGE_DIR, img['file_name'])
430
+ image = cv2.imread(image_path)
431
+ # mask = np.zeros(image.shape[:2], dtype=np.uint8) # an actual mask is needed here
432
+
433
+ # custom_vis1(image, mask_)
434
+
435
+ image = custom_vis2(image, mask_)
436
+
437
+ # custom_vis3(image, mask_)
438
+
439
+ # image = custom_vis4(image=image, mask_=mask_, alpha=0.4)
440
+
441
+
442
+
443
+
444
+ # save the result image to the target directory
445
+ image_dir = osp.join(save_path, str(img_id))
446
+ osp.exists(image_dir) or os.makedirs(image_dir)
447
+ # copy the original image alongside the visualization
448
+ I = io.imread(osp.join(refer.IMAGE_DIR, img['file_name']))
449
+ io.imsave(osp.join(image_dir, img['file_name']), I)
450
+
451
+
452
+ cv2.imwrite(osp.join(image_dir, str(img_id)+".png"), image)
453
+
454
+ # save the ref as a JSON file
455
+ with open(osp.join(image_dir, str(img_id)+".json"), "w") as f:
456
+ json.dump(ref, f)
457
+
458
+
459
+
460
+
461
+
462
+ # i = 0
463
+ # for ref_id in ref_ids:
464
+ # i += 1
465
+ # ref = refer.loadRefs(ref_id)[0]
466
+ # if len(ref['sentences']) < 2:
467
+ # continue
468
+
469
+ # print(ref)
470
+ # print('The label is %s.' % refer.Cats[ref['category_id']])
471
+ # plt.figure()
472
+ # # refer.getMask(ref)
473
+ # refer.showMask(ref)
474
+
475
+ # # refer.showRef(ref, seg_box='seg')
476
+
477
+ # plt.show()
478
+ # if i == 0:
479
+ # break
480
+ # # save
481
+ # plt.savefig('tmp.png')
482
+
483
+ # plt.figure()
484
+ # refer.showMask(ref)
485
+ # plt.show()
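Putting the interface together, a typical read-only use of REFER looks like the sketch below; the data_root path is a placeholder and should point at the folder containing the refcoco/refcoco+/refcocog annotation files:

from refer import REFER  # assumes this file is importable as refer.py

# minimal usage sketch of the REFER API defined above
refer = REFER(data_root='../ln_data', dataset='refcoco', splitBy='unc')
train_ids = refer.getRefIds(split='train')
ref = refer.loadRefs(train_ids[0])[0]
print([s['sent'] for s in ref['sentences']])   # referring expressions for this object
print(refer.getRefBox(ref['ref_id']))          # [x, y, w, h]
m = refer.getMask(ref)['mask']                 # HxW uint8 binary mask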