Upload folder using huggingface_hub
- ASDA/dataset/__pycache__/data_loader.cpython-39.pyc +0 -0
- ASDA/dataset/__pycache__/data_loader_gref_sbert.cpython-39.pyc +0 -0
- ASDA/dataset/__pycache__/data_loader_rccp.cpython-39.pyc +0 -0
- ASDA/dataset/__pycache__/data_loader_test.cpython-39.pyc +0 -0
- ASDA/dataset/__pycache__/refer.cpython-39.pyc +0 -0
- ASDA/dataset/data.sh +12 -0
- ASDA/dataset/data_loader.py +314 -0
- ASDA/dataset/data_loader_gref_sbert.py +343 -0
- ASDA/dataset/data_loader_rccp.py +279 -0
- ASDA/dataset/data_loader_test.py +315 -0
- ASDA/dataset/data_process.py +225 -0
- ASDA/dataset/datascript.py +56 -0
- ASDA/dataset/refer.py +485 -0
ASDA/dataset/__pycache__/data_loader.cpython-39.pyc
ADDED
Binary file (7.71 kB)

ASDA/dataset/__pycache__/data_loader_gref_sbert.cpython-39.pyc
ADDED
Binary file (8.75 kB)

ASDA/dataset/__pycache__/data_loader_rccp.cpython-39.pyc
ADDED
Binary file (7 kB)

ASDA/dataset/__pycache__/data_loader_test.cpython-39.pyc
ADDED
Binary file (7.55 kB)

ASDA/dataset/__pycache__/refer.cpython-39.pyc
ADDED
Binary file (11.9 kB)
ASDA/dataset/data.sh
ADDED
@@ -0,0 +1,12 @@
#!/bin/bash
# data process
python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcoco --split unc --generate_mask
python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcoco+ --split unc --generate_mask
python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcocog --split google --generate_mask
python data_process.py --data_root ../ln_data --output_dir ../ln_data --dataset refcocog --split umd --generate_mask

# datascript
python datascript.py --dataset refcoco
python datascript.py --dataset refcoco+
python datascript.py --dataset refcocog_google
python datascript.py --dataset refcocog_umd
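The commit itself does not document where these scripts place their output; the layout below is inferred from the path handling in the loaders that follow (im_dir, mask_root, anns_root, and the .pth split files), not stated in the upload:

ln_data/
    images/train2014/                      # COCO train2014 images
    masks/<dataset>/<seg_id>.npy           # per-reference masks from --generate_mask (refcocog uses masks/refcocog_<splitby>/)
    anns/<dataset>/<split>.txt
data/
    <dataset>/<dataset>_<split>.pth        # lists of (img_file, seg_id, bbox, phrase) tuples for refcoco / refcoco+
    refcocog_<splitby>/refcocog_<splitby>_<split>.pth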
ASDA/dataset/data_loader.py
ADDED
@@ -0,0 +1,314 @@
# -*- coding: utf-8 -*-

"""
refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
"""
import sys
import cv2
import torch
import random
import numpy as np
import os.path as osp
import torch.utils.data as data
sys.path.append('.')
import utils
import re

from pytorch_pretrained_bert.tokenization import BertTokenizer
from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase
import copy

import clip

sys.modules['utils'] = utils
cv2.setNumThreads(0)

def read_examples(input_line, unique_id):
    """Read a list of `InputExample`s from an input file."""
    examples = []
    # unique_id = 0
    line = input_line  # reader.readline()
    # if not line:
    #     break
    line = line.strip()
    text_a = None
    text_b = None
    m = re.match(r"^(.*) \|\|\| (.*)$", line)
    if m is None:
        text_a = line
    else:
        text_a = m.group(1)  # 'man in black'
        text_b = m.group(2)

    examples.append(
        InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
    # unique_id += 1
    return examples

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

## Bert text encoding
class InputExample(object):
    def __init__(self, unique_id, text_a, text_b):
        self.unique_id = unique_id
        self.text_a = text_a
        self.text_b = text_b

class InputFeatures(object):
    """A single set of features of data."""
    def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
        self.unique_id = unique_id
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids

def convert_examples_to_features(examples, seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""
    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)  # ['far', 'left', 'vase']

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > seq_length - 2:
                tokens_a = tokens_a[0:(seq_length - 2)]
        tokens = []
        input_type_ids = []
        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                input_type_ids.append(1)
            tokens.append("[SEP]")
            input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < seq_length:
            input_ids.append(0)
            input_mask.append(0)
            input_type_ids.append(0)

        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length
        assert len(input_type_ids) == seq_length
        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))
    return features

class DatasetNotFoundError(Exception):
    pass

class ReferDataset(data.Dataset):
    SUPPORTED_DATASETS = {
        'refcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco', 'split_by': 'unc'}
        },
        'refcoco+': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
        },
        'refcocog': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'unc'}
        },
        'refcocog_g': {
            'splits': ('train', 'val'),
            'params': {'dataset': 'refcocog', 'split_by': 'google'}
        },
        'refcocog_u': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'unc'}
        },
        'grefcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
        }
    }

    def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
                 transform=None, augment=False, split='train', max_query_len=128,
                 bert_model='bert-base-uncased'):
        self.images = []
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.imsize = imsize
        self.query_len = max_query_len
        self.transform = transform
        self.split = split
        self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)  # should be true for English
        self.augment = augment

        valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']

        if split not in valid_splits:
            raise ValueError(
                'Dataset {0} does not have split {1}'.format(
                    self.dataset, split))

        self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt')
        if self.dataset == 'refcocog':
            mask_anno_str = '{0}_{1}'.format(self.dataset, splitby)
            self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
        else:
            self.mask_root = osp.join(self.data_root, 'masks', self.dataset)

        self.im_dir = osp.join(self.data_root, 'images', 'train2014')


        if self.dataset == 'refcocog':
            dataset_path = osp.join(self.split_root, self.dataset + '_' + splitby)
            splits = [split]
            for split in splits:
                imgset_file = '{0}_{1}_{2}.pth'.format(self.dataset, splitby, split)
                imgset_path = osp.join(dataset_path, imgset_file)
                self.images += torch.load(imgset_path)
        else:
            dataset_path = osp.join(self.split_root, self.dataset)
            splits = [split]
            for split in splits:
                imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
                imgset_path = osp.join(dataset_path, imgset_file)
                self.images += torch.load(imgset_path)

    def exists_dataset(self):
        return osp.exists(osp.join(self.split_root, self.dataset))

    def pull_item(self, idx):
        img_file, seg_id, bbox, phrase = self.images[idx]
        bbox = np.array(bbox, dtype=int)  # x1y1x2y2

        img_path = osp.join(self.im_dir, img_file)
        img = cv2.imread(img_path)  # BGR [512, 640, 3]
        ## duplicate channel if gray image
        if img.shape[-1] > 1:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
        else:
            img = np.stack([img] * 3)

        ## seg map
        seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy'))  # [512, 640]
        seg_map = np.array(seg_map).astype(np.float32)
        return img, phrase, bbox, seg_map

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, phrase, bbox, seg_map = self.pull_item(idx)
        phrase = phrase.lower()
        if self.augment:
            augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
                True, True, True, False, False, False

        ## seems a bug in torch transformation resize, so separate in advance
        h, w = img.shape[0], img.shape[1]
        # print("img.shape", img.shape)
        if self.augment:
            ## random horizontal flip
            if augment_flip and random.random() > 0.5:
                img = cv2.flip(img, 1)
                seg_map = cv2.flip(seg_map, 1)
                bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1
                phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left')

            ## random copy and add left or right
            if augment_copy:
                img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)

            ## random erase for occluded
            if augment_erase:
                img, seg_map = random_erase(img, seg_map)

            ## random padding and crop
            if augment_crop:
                img, seg_map = random_crop(img, seg_map, 40, h, w)

            ## random intensity, saturation change
            if augment_hsv:
                fraction = 0.50
                img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
                S = img_hsv[:, :, 1].astype(np.float32)
                V = img_hsv[:, :, 2].astype(np.float32)
                a = (random.random() * 2 - 1) * fraction + 1
                if a > 1:
                    np.clip(S, a_min=0, a_max=255, out=S)
                a = (random.random() * 2 - 1) * fraction + 1
                V *= a
                if a > 1:
                    np.clip(V, a_min=0, a_max=255, out=V)

                img_hsv[:, :, 1] = S.astype(np.uint8)
                img_hsv[:, :, 2] = V.astype(np.uint8)
                img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)

            img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
            bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
            bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh

            ## random affine transformation
            if augment_affine:
                img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \
                    degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))  # 255 white fill

        else:  ## should be inference, or specified training
            img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
            bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
            bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh

        draw_img = copy.deepcopy(img)
        # Norm, to tensor
        if self.transform is not None:
            img = self.transform(img)

        ## encode phrase to clip input
        word_id = clip.tokenize(phrase, 17, truncate=True)
        word_mask = ~ (word_id == 0)

        if self.augment:  # train
            seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2), interpolation=cv2.INTER_NEAREST)  # (208, 208)
            seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
            return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
                np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
        else:
            seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
            return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
                np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
                np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)
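The file above only defines the dataset class; the commit does not show how it is driven. A minimal usage sketch follows: the import path, image size, and normalization values are assumptions for illustration (imsize=416 is chosen so that the half-resolution mask matches the "(208, 208)" comment in the code), not something stated in the upload.

# Minimal usage sketch (assumed module path, imsize, and transform values).
import torch.utils.data as data
from torchvision import transforms
from ASDA.dataset.data_loader import ReferDataset  # hypothetical import path

input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # assumed ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])

train_set = ReferDataset(data_root='./ln_data', split_root='data', dataset='refcoco',
                         imsize=416, transform=input_transform,
                         augment=True, split='train')
train_loader = data.DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

# In augment/train mode __getitem__ returns five items per sample.
for img, word_id, word_mask, bbox, seg_map in train_loader:
    # img: [B, 3, imsize, imsize]; word_id / word_mask: [B, 1, 17];
    # bbox: [B, 4] in x1y1x2y2; seg_map: [B, 1, imsize//2, imsize//2]
    break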
ASDA/dataset/data_loader_gref_sbert.py
ADDED
@@ -0,0 +1,343 @@
# -*- coding: utf-8 -*-

"""
refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
"""
import sys
import cv2
import os
import torch
import json
import random
import numpy as np
import os.path as osp
import torch.utils.data as data
sys.path.append('.')
import utils
import re

# from pytorch_pretrained_bert.tokenization import BertTokenizer
from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase
import copy

import clip

sys.modules['utils'] = utils
cv2.setNumThreads(0)

class ReferDataset(data.Dataset):
    SUPPORTED_DATASETS = {
        'refcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco', 'split_by': 'unc'}
        },
        'refcoco+': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
        },
        'refcocog': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'unc'}
        },
        'refcocog_g': {
            'splits': ('train', 'val'),
            'params': {'dataset': 'refcocog', 'split_by': 'google'}
        },
        'refcocog_u': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'unc'}
        },
        'grefcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
        }
    }


    def _load_multi_obj_ref_ids(self):
        # Load multi-object reference IDs based on configurations
        if not self.exclude_multiobj and not self.exclude_position:
            return None
        elif self.exclude_position:
            multiobj_path = os.path.join(self.ROOT, 'multiobj_ov2_nopos.txt')
        elif self.exclude_multiobj:
            multiobj_path = os.path.join(self.ROOT, 'multiobj_ov3.txt')
        with open(multiobj_path, 'r') as f:
            return [int(line.strip()) for line in f.readlines()]

    def _load_metadata(self):
        # Load metadata for hard positive verb phrases, hard negative queries
        # we set refined file as default option
        hardpos_path = '/data2/projects/seunghoon/VerbRIS/CrossVLT/hardpos_verdict_gref_v4.json'
        with open(hardpos_path, 'r', encoding='utf-8') as f:
            hardpos_json = json.load(f)
        return hardpos_json

    def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
                 transform=None, augment=False, split='train', max_query_len=128, metric_learning=None):
        images_tmp = []
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.imsize = imsize
        self.query_len = max_query_len
        self.transform = transform
        self.word_len = 17
        self.emb_size = 384
        self.split = split
        self.augment = augment

        valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']

        if split not in valid_splits:
            raise ValueError(
                'Dataset {0} does not have split {1}'.format(
                    self.dataset, split))

        self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt')
        if self.dataset == 'refcocog':
            mask_anno_str = '{0}_{1}'.format(self.dataset, splitby)
            self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
        else:
            self.mask_root = osp.join(self.data_root, 'masks', self.dataset)

        self.im_dir = osp.join(self.data_root, 'images', 'train2014')


        if self.dataset == 'refcocog':
            dataset_path = osp.join(self.split_root, self.dataset + '_' + splitby)
            splits = [split]
            for split in splits:
                imgset_file = '{0}_{1}_{2}.pth'.format(self.dataset, splitby, split)
                imgset_path = osp.join(dataset_path, imgset_file)
                images_tmp += torch.load(imgset_path)

            # metric learning options
            self.ROOT = '/data2/projects/seunghoon/VerbRIS/VerbCentric_CY/'
            self.all_hp_root = "/data2/dataset/RefCOCO/refcocog/SBERT_gref_umd"
            # self.exclude_position = args.exclude_pos
            self.exclude_position = True
            self.exclude_multiobj = True
            self.metric_learning = metric_learning

            # self.metric_mode = args.metric_mode
            self.hp_selection = 'strict'

            # meta data loading
            if self.metric_learning and self.split == 'train':
                self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
                self.hardpos_meta = self._load_metadata()

                # make new self.images file with sentence idx and total sent num (per ref_id)
                from collections import defaultdict
                ref_sentence_counts = defaultdict(int)
                for item in images_tmp:
                    ref_sentence_counts[item[1]] += 1

                self.images = []
                ref_sentence_indices = defaultdict(int)
                for item in images_tmp:
                    img_name, seg_id, box, sentence = item
                    sent_index = ref_sentence_indices[seg_id]
                    total_sentences = ref_sentence_counts[seg_id]
                    self.images.append((img_name, seg_id, box, sentence, sent_index, total_sentences))
                    ref_sentence_indices[seg_id] += 1

            else:
                self.images = images_tmp
                self.multi_obj_ref_ids = None
                self.hardpos_meta = None

        else:
            dataset_path = osp.join(self.split_root, self.dataset)
            splits = [split]
            for split in splits:
                imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
                imgset_path = osp.join(dataset_path, imgset_file)
                self.images += torch.load(imgset_path)

    def exists_dataset(self):
        return osp.exists(osp.join(self.split_root, self.dataset))

    def _get_hardpos_verb(self, seg_id, sent_idx):
        """
        Handle the logic for selecting hard positive verb phrases during metric learning.
        Returns the sentence, raw_verb, and tokenized verb if applicable.
        """
        # If the object appears multiple times, no hard positive is used
        if seg_id in self.multi_obj_ref_ids:
            verb_embed = torch.zeros(self.emb_size, dtype=torch.float32)
            return '', verb_embed

        # Extract metadata for hard positives if present
        hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
        if self.hp_selection == 'strict':
            sent_id_list = list(hardpos_dict.keys())
            cur_sent_id = sent_id_list[sent_idx]
            cur_hardpos = hardpos_dict.get(cur_sent_id, {}).get('phrases', [])

            if cur_hardpos:
                # Assign a hard positive verb phrase if available
                rand_index = random.randint(0, len(cur_hardpos) - 1)
                raw_verb = cur_hardpos[rand_index]
                verb_embed = torch.from_numpy(self._get_hardpos_embed(seg_id, cur_sent_id, rand_index))
                # print("Positive phrase : ", raw_verb)
                return raw_verb, verb_embed

        verb_embed = torch.zeros(self.emb_size, dtype=torch.float32)
        return '', verb_embed


    def _get_hardpos_embed(self, seg_id, sent_id, rand_index):
        emb_folder = os.path.join(self.all_hp_root, str(seg_id))
        emb_files = sorted([f for f in os.listdir(emb_folder) if f.startswith(f"hp_{sent_id}_") and f.endswith(".npy")])
        selected_emb_file = os.path.join(emb_folder, emb_files[rand_index])

        return np.load(selected_emb_file)


    def pull_item(self, idx):
        # if metric learning and in train mode
        if self.metric_learning and self.augment:
            # sent_idx refers to index of sent among sent_num-1
            img_file, seg_id, bbox, phrase, sent_idx, sent_num = self.images[idx]
        else:
            img_file, seg_id, bbox, phrase = self.images[idx]
        bbox = np.array(bbox, dtype=int)  # x1y1x2y2

        img_path = osp.join(self.im_dir, img_file)
        img = cv2.imread(img_path)  # BGR [512, 640, 3]
        ## duplicate channel if gray image
        if img.shape[-1] > 1:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
        else:
            img = np.stack([img] * 3)

        ## seg map
        seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy'))  # [512, 640]
        seg_map = np.array(seg_map).astype(np.float32)

        if self.metric_learning and self.split == 'train':
            return img, phrase, bbox, seg_map, seg_id, sent_idx
        else:
            return img, phrase, bbox, seg_map, seg_id

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.metric_learning and self.augment:
            img, phrase, bbox, seg_map, seg_id, sent_idx = self.pull_item(idx)
        else:
            img, phrase, bbox, seg_map, seg_id = self.pull_item(idx)

        phrase = phrase.lower()
        if self.augment:
            augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
                True, True, True, False, False, False

        ## seems a bug in torch transformation resize, so separate in advance
        h, w = img.shape[0], img.shape[1]
        # print("img.shape", img.shape)
        if self.augment:
            ## random horizontal flip
            if augment_flip and random.random() > 0.5:
                img = cv2.flip(img, 1)
                seg_map = cv2.flip(seg_map, 1)
                bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1
                phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left')

            ## random copy and add left or right
            if augment_copy:
                img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)

            ## random erase for occluded
            if augment_erase:
                img, seg_map = random_erase(img, seg_map)

            ## random padding and crop
            if augment_crop:
                img, seg_map = random_crop(img, seg_map, 40, h, w)

            ## random intensity, saturation change
            if augment_hsv:
                fraction = 0.50
                img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
                S = img_hsv[:, :, 1].astype(np.float32)
                V = img_hsv[:, :, 2].astype(np.float32)
                a = (random.random() * 2 - 1) * fraction + 1
                if a > 1:
                    np.clip(S, a_min=0, a_max=255, out=S)
                a = (random.random() * 2 - 1) * fraction + 1
                V *= a
                if a > 1:
                    np.clip(V, a_min=0, a_max=255, out=V)

                img_hsv[:, :, 1] = S.astype(np.uint8)
                img_hsv[:, :, 2] = V.astype(np.uint8)
                img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)

            img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
            bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
            bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh

            ## random affine transformation
            if augment_affine:
                img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \
                    degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))  # 255 white fill

        else:  ## should be inference, or specified training
            img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
            bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
            bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh

        draw_img = copy.deepcopy(img)
        # Norm, to tensor
        if self.transform is not None:
            img = self.transform(img)


        ## encode phrase to clip input
        word_id = clip.tokenize(phrase, 17, truncate=True)
        word_mask = ~ (word_id == 0)

        orig_word_id = np.array(word_id, dtype=int)
        orig_word_mask = np.array(word_mask, dtype=int)

        # Get hardpos verb phrase
        if self.metric_learning and self.augment:
            raw_hardpos, hardpos_emb = self._get_hardpos_verb(seg_id, sent_idx)
            pos_type = 'nopos'
            if raw_hardpos:
                pos_type = 'hardpos'
                hardpos_id = clip.tokenize(raw_hardpos, self.word_len, truncate=True)
            else:
                # Empty phrase → Create a zero tensor matching shape of tokenized input
                hardpos_id = np.zeros((1, self.word_len), dtype=int)

            # Masking
            hardpos_mask = hardpos_id != 0  # Mask should be boolean

            hp_word_id = np.array(hardpos_id, dtype=int)
            hp_word_mask = np.array(hardpos_mask, dtype=int)

        if self.augment:  # train
            seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2), interpolation=cv2.INTER_NEAREST)  # (208, 208)
            seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
            if self.metric_learning:
                params = {
                    'hp_word_id': hp_word_id,
                    'hp_word_mask': hp_word_mask,
                    'hardpos_emb': hardpos_emb.unsqueeze(0),
                    'pos_type': pos_type
                }
                return img, orig_word_id, orig_word_mask, np.array(bbox, dtype=np.float32), \
                    np.array(seg_map, dtype=np.float32), params
            else:
                return img, orig_word_id, orig_word_mask, \
                    np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
        else:
            seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
            return img, orig_word_id, orig_word_mask, \
                np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
                np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)
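With metric_learning enabled in train mode, __getitem__ above returns an extra params dict per sample. A sketch of how a training loop might unpack the collated batch follows; the loader construction is assumed to mirror the earlier example (dataset='refcocog', splitby='umd', metric_learning=True, augment=True), and the variable names are illustrative only.

# Sketch of consuming the metric-learning batch from data_loader_gref_sbert.ReferDataset.
for img, word_id, word_mask, bbox, seg_map, params in train_loader:
    hp_word_id = params['hp_word_id']      # CLIP-tokenized hard-positive phrase, [B, 1, 17]
    hp_word_mask = params['hp_word_mask']  # non-zero-token mask for that phrase
    hardpos_emb = params['hardpos_emb']    # precomputed SBERT embedding, [B, 1, 384]
    pos_type = params['pos_type']          # list of 'hardpos' / 'nopos' flags, one per sample
    # Samples flagged 'nopos' carry zero tensors and can be masked out of any
    # contrastive / metric objective built on top of these fields.
    break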
ASDA/dataset/data_loader_rccp.py
ADDED
@@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-

"""
refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
"""
import sys
import cv2
import os
import torch
import json
import random
import numpy as np
import os.path as osp
import torch.utils.data as data
sys.path.append('.')
import utils
import re

# from pytorch_pretrained_bert.tokenization import BertTokenizer
from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase
import copy

import clip

sys.modules['utils'] = utils
cv2.setNumThreads(0)

class ReferDataset(data.Dataset):
    SUPPORTED_DATASETS = {
        'refcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco', 'split_by': 'unc'}
        },
        'refcoco+': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
        },
        'refcocog': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'unc'}
        },
        'refcocog_g': {
            'splits': ('train', 'val'),
            'params': {'dataset': 'refcocog', 'split_by': 'google'}
        },
        'refcocog_u': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'unc'}
        },
        'grefcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
        }
    }


    def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
                 transform=None, augment=False, split='train', max_query_len=128, metric_learning=None):
        images_tmp = []
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.imsize = imsize
        self.query_len = max_query_len
        self.transform = transform
        self.word_len = 17
        self.emb_size = 384
        self.split = split
        self.augment = augment

        valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']

        if split not in valid_splits:
            raise ValueError(
                'Dataset {0} does not have split {1}'.format(
                    self.dataset, split))

        self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt')
        if self.dataset == 'refcocog':
            mask_anno_str = '{0}_{1}'.format(self.dataset, splitby)
            self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
        else:
            self.mask_root = osp.join(self.data_root, 'masks', self.dataset)

        self.im_dir = osp.join(self.data_root, 'images', 'train2014')

        # if self.dataset in ['refcoco', 'refcoco+']
        dataset_path = osp.join(self.split_root, self.dataset)
        splits = [split]
        for split in splits:
            imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
            imgset_path = osp.join(dataset_path, imgset_file)
            images_tmp += torch.load(imgset_path)

        # hardpos related
        self.ROOT = '/data2/dataset/RefCOCO/VRIS'
        if self.dataset == 'refcoco':
            self.all_hp_root = '/data2/dataset/RefCOCO/refcoco/SBERT_rcc_unc'
        elif self.dataset == 'refcoco+':
            self.all_hp_root = '/data2/dataset/RefCOCO/refcoco+/SBERT_rccp_unc'

        self.metric_learning = metric_learning
        if self.metric_learning:
            self.exclude_position = True
            self.exclude_multiobj = True
            self.hp_selection = 'strict'
            self.multi_obj_ref_ids = None
            self.hardpos_meta = None

            # make new self.images file with sentence idx and total sent num (per ref_id)
            from collections import defaultdict
            ref_sentence_counts = defaultdict(int)
            for item in images_tmp:
                ref_sentence_counts[item[1]] += 1

            if self.split == 'train':
                images = []
                ref_sentence_indices = defaultdict(int)
                for item in images_tmp:
                    img_name, seg_id, box, sentence = item
                    sent_index = ref_sentence_indices[seg_id]
                    total_sentences = ref_sentence_counts[seg_id]
                    images.append((img_name, seg_id, box, sentence, sent_index, total_sentences))
                    ref_sentence_indices[seg_id] += 1
                self.images = images
            else:
                self.images = images_tmp
        else:
            self.images = images_tmp

    def exists_dataset(self):
        return osp.exists(osp.join(self.split_root, self.dataset))

    def _get_hardpos_verb_rcc(self, seg_id, sent_idx):
        emb_folder = os.path.join(self.all_hp_root, str(seg_id))
        emb_files = sorted([f for f in os.listdir(emb_folder) if f.startswith(f"hp_") and f.endswith(".npy")])
        if self.hp_selection == 'strict':
            # choose only corresponding (selected) sentence embedding
            emb_file = emb_files[sent_idx]
        else:
            # choose any sentence embedding
            emb_files = sorted([f for f in os.listdir(emb_folder) if f.startswith(f"hp_") and f.endswith(".npy")])
            emb_file = random.choice(emb_files)
        selected_emb = np.load(os.path.join(emb_folder, emb_file))
        verb_embed = torch.from_numpy(selected_emb)
        return verb_embed


    def pull_item(self, idx):
        # if metric learning and in train mode
        if self.metric_learning and self.augment:
            # sent_idx refers to index of sent among sent_num-1
            img_file, seg_id, bbox, phrase, sent_idx, sent_num = self.images[idx]
        else:
            img_file, seg_id, bbox, phrase = self.images[idx]
        bbox = np.array(bbox, dtype=int)  # x1y1x2y2

        img_path = osp.join(self.im_dir, img_file)
        img = cv2.imread(img_path)  # BGR [512, 640, 3]
        ## duplicate channel if gray image
        if img.shape[-1] > 1:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
        else:
            img = np.stack([img] * 3)

        ## seg map
        seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy'))  # [512, 640]
        seg_map = np.array(seg_map).astype(np.float32)

        if self.metric_learning and self.split == 'train':
            return img, phrase, bbox, seg_map, seg_id, sent_idx
        else:
            return img, phrase, bbox, seg_map, seg_id

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.metric_learning and self.augment:
            img, phrase, bbox, seg_map, seg_id, sent_idx = self.pull_item(idx)
        else:
            img, phrase, bbox, seg_map, seg_id = self.pull_item(idx)

        phrase = phrase.lower()
        if self.augment:
            augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
                True, True, True, False, False, False

        ## seems a bug in torch transformation resize, so separate in advance
        h, w = img.shape[0], img.shape[1]
        # print("img.shape", img.shape)
        if self.augment:
            ## random horizontal flip
            if augment_flip and random.random() > 0.5:
                img = cv2.flip(img, 1)
                seg_map = cv2.flip(seg_map, 1)
                bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1
                phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left')

            ## random copy and add left or right
            if augment_copy:
                img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)

            ## random erase for occluded
            if augment_erase:
                img, seg_map = random_erase(img, seg_map)

            ## random padding and crop
            if augment_crop:
                img, seg_map = random_crop(img, seg_map, 40, h, w)

            ## random intensity, saturation change
            if augment_hsv:
                fraction = 0.50
                img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
                S = img_hsv[:, :, 1].astype(np.float32)
                V = img_hsv[:, :, 2].astype(np.float32)
                a = (random.random() * 2 - 1) * fraction + 1
                if a > 1:
                    np.clip(S, a_min=0, a_max=255, out=S)
                a = (random.random() * 2 - 1) * fraction + 1
                V *= a
                if a > 1:
                    np.clip(V, a_min=0, a_max=255, out=V)

                img_hsv[:, :, 1] = S.astype(np.uint8)
                img_hsv[:, :, 2] = V.astype(np.uint8)
                img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)

            img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
            bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
            bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh

            ## random affine transformation
            if augment_affine:
                img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \
                    degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))  # 255 white fill

        else:  ## should be inference, or specified training
            img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
            bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
            bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh

        draw_img = copy.deepcopy(img)
        # Norm, to tensor
        if self.transform is not None:
            img = self.transform(img)


        ## encode phrase to clip input
        word_id = clip.tokenize(phrase, 17, truncate=True)
        word_mask = ~ (word_id == 0)

        orig_word_id = np.array(word_id, dtype=int)
        orig_word_mask = np.array(word_mask, dtype=int)

        # Get hardpos verb phrase
        if self.metric_learning and self.augment:
            original_emb = self._get_hardpos_verb_rcc(seg_id, sent_idx)

        if self.augment:  # train
            seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2), interpolation=cv2.INTER_NEAREST)  # (208, 208)
            seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
            if self.metric_learning:
                params = {
                    'seg_id': seg_id,
                    'sent': phrase,
                    'hardpos_emb': original_emb.unsqueeze(0)
                }
                return img, orig_word_id, orig_word_mask, np.array(bbox, dtype=np.float32), \
                    np.array(seg_map, dtype=np.float32), params
            else:
                return img, orig_word_id, orig_word_mask, \
                    np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
        else:
            seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
            return img, orig_word_id, orig_word_mask, \
                np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
                np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)
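This loader reads precomputed hard-positive sentence embeddings (hp_*.npy, 384-dim, one folder per seg_id) from all_hp_root, but the script that produces them is not part of this upload. The following is a hypothetical sketch of how such files could be generated; the sentence-transformers model name is an assumption (any 384-dim encoder matches self.emb_size = 384), and the function and paths are illustrative only.

# Hypothetical sketch (not part of this commit) of producing the hp_*.npy files
# that _get_hardpos_verb_rcc reads.
import os
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')  # assumed 384-dim SBERT encoder

def save_hardpos_embeddings(out_root, seg_id, phrases):
    """Write one hp_<i>.npy per hard-positive phrase under <out_root>/<seg_id>/."""
    folder = os.path.join(out_root, str(seg_id))
    os.makedirs(folder, exist_ok=True)
    for i, phrase in enumerate(phrases):
        emb = model.encode(phrase)  # np.ndarray of shape (384,)
        np.save(os.path.join(folder, f'hp_{i}.npy'), emb.astype(np.float32))

# e.g. save_hardpos_embeddings('/data2/dataset/RefCOCO/refcoco/SBERT_rcc_unc',
#                              12345, ['man running', 'person jogging'])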
ASDA/dataset/data_loader_test.py
ADDED
@@ -0,0 +1,315 @@
# -*- coding: utf-8 -*-

"""
refcoco, refcoco+ and refcocog referring image detection and segmentation PyTorch dataset.
"""
import sys
import cv2
import torch
import random
import numpy as np
import os.path as osp
import torch.utils.data as data
sys.path.append('.')
import utils
import re

from pytorch_pretrained_bert.tokenization import BertTokenizer
from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase
import copy

import clip

sys.modules['utils'] = utils
cv2.setNumThreads(0)

def read_examples(input_line, unique_id):
    """Read a list of `InputExample`s from an input file."""
    examples = []
    # unique_id = 0
    line = input_line  # reader.readline()
    # if not line:
    #     break
    line = line.strip()
    text_a = None
    text_b = None
    m = re.match(r"^(.*) \|\|\| (.*)$", line)
    if m is None:
        text_a = line
    else:
        text_a = m.group(1)  # 'man in black'
        text_b = m.group(2)

    examples.append(
        InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
    # unique_id += 1
    return examples

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

## Bert text encoding
class InputExample(object):
    def __init__(self, unique_id, text_a, text_b):
        self.unique_id = unique_id
        self.text_a = text_a
        self.text_b = text_b

class InputFeatures(object):
    """A single set of features of data."""
    def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
        self.unique_id = unique_id
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids

def convert_examples_to_features(examples, seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""
    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)  # ['far', 'left', 'vase']

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > seq_length - 2:
                tokens_a = tokens_a[0:(seq_length - 2)]
        tokens = []
        input_type_ids = []
        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                input_type_ids.append(1)
            tokens.append("[SEP]")
            input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < seq_length:
            input_ids.append(0)
            input_mask.append(0)
            input_type_ids.append(0)

        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length
        assert len(input_type_ids) == seq_length
        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))
    return features

class DatasetNotFoundError(Exception):
    pass

class ReferDataset(data.Dataset):
    SUPPORTED_DATASETS = {
        'refcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco', 'split_by': 'unc'}
        },
        'refcoco+': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
        },
        'refcocog': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'umd'}
        },
        'refcocog_g': {
            'splits': ('train', 'val'),
            'params': {'dataset': 'refcocog', 'split_by': 'google'}
        },
        'refcocog_u': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'umd'}
        },
        'grefcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
        }
    }

    def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
                 transform=None, augment=False, split='train', max_query_len=128,
                 bert_model='bert-base-uncased'):
        self.images = []
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.imsize = imsize
        self.query_len = max_query_len
        self.transform = transform
        self.split = split
        self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)  # should be true for English
        self.augment = augment

        valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']

        if split not in valid_splits:
            raise ValueError(
                'Dataset {0} does not have split {1}'.format(
                    self.dataset, split))

        self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split+'.txt')
        if self.dataset == 'refcocog_u':
            dataset = 'refcocog'
            mask_anno_str = '{0}_{1}'.format(dataset, splitby)
            self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
        else:
            self.mask_root = osp.join(self.data_root, 'masks', self.dataset)

        self.im_dir = osp.join(self.data_root, 'images', 'train2014')

        if self.dataset == 'refcocog_u':
            dataset = 'refcocog'
            dataset_path = osp.join(self.split_root, dataset + '_' + splitby)
            splits = [split]
            for split in splits:
                imgset_file = '{0}_{1}_{2}.pth'.format(dataset, splitby, split)
                imgset_path = osp.join(dataset_path, imgset_file)
                self.images += torch.load(imgset_path)
        else:
            dataset_path = osp.join(self.split_root, self.dataset)
            splits = [split]
            for split in splits:
                imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
                imgset_path = osp.join(dataset_path, imgset_file)
                self.images += torch.load(imgset_path)

    # def exists_dataset(self):
    #     return osp.exists(osp.join(self.split_root, self.dataset))

    def pull_item(self, idx):
        img_file, seg_id, bbox, phrase = self.images[idx]
        bbox = np.array(bbox, dtype=int)  # x1y1x2y2

        img_path = osp.join(self.im_dir, img_file)
        img = cv2.imread(img_path)  # BGR [512, 640, 3]
        ## duplicate channel if gray image
        if img.shape[-1] > 1:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
        else:
            img = np.stack([img] * 3)

        ## seg map
        seg_map = np.load(osp.join(self.mask_root, str(seg_id)+'.npy'))  # [512, 640]
        seg_map = np.array(seg_map).astype(np.float32)
        return img, phrase, bbox, seg_map

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, phrase, bbox, seg_map = self.pull_item(idx)
        phrase = phrase.lower()
        if self.augment:
            augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
                True, True, True, False, False, False

        ## seems a bug in torch transformation resize, so separate in advance
        h, w = img.shape[0], img.shape[1]
        # print("img.shape", img.shape)
        if self.augment:
            ## random horizontal flip
            if augment_flip and random.random() > 0.5:
                img = cv2.flip(img, 1)
                seg_map = cv2.flip(seg_map, 1)
                bbox[0], bbox[2] = w-bbox[2]-1, w-bbox[0]-1
                phrase = phrase.replace('right','*&^special^&*').replace('left','right').replace('*&^special^&*','left')

            ## random copy and add left or right
            if augment_copy:
                img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)

            ## random erase for occluded
|
| 258 |
+
if augment_erase:
|
| 259 |
+
img, seg_map = random_erase(img, seg_map)
|
| 260 |
+
|
| 261 |
+
## random padding and crop
|
| 262 |
+
if augment_crop:
|
| 263 |
+
img, seg_map = random_crop(img, seg_map, 40, h, w)
|
| 264 |
+
|
| 265 |
+
## random intensity, saturation change
|
| 266 |
+
if augment_hsv:
|
| 267 |
+
fraction = 0.50
|
| 268 |
+
img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
|
| 269 |
+
S = img_hsv[:, :, 1].astype(np.float32)
|
| 270 |
+
V = img_hsv[:, :, 2].astype(np.float32)
|
| 271 |
+
a = (random.random() * 2 - 1) * fraction + 1
|
| 272 |
+
if a > 1:
|
| 273 |
+
np.clip(S, a_min=0, a_max=255, out=S)
|
| 274 |
+
a = (random.random() * 2 - 1) * fraction + 1
|
| 275 |
+
V *= a
|
| 276 |
+
if a > 1:
|
| 277 |
+
np.clip(V, a_min=0, a_max=255, out=V)
|
| 278 |
+
|
| 279 |
+
img_hsv[:, :, 1] = S.astype(np.uint8)
|
| 280 |
+
img_hsv[:, :, 2] = V.astype(np.uint8)
|
| 281 |
+
img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)
|
| 282 |
+
|
| 283 |
+
img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
|
| 284 |
+
bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
|
| 285 |
+
bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
|
| 286 |
+
|
| 287 |
+
## random affine transformation
|
| 288 |
+
if augment_affine:
|
| 289 |
+
img, seg_map, bbox, M = random_affine(img, seg_map, bbox, \
|
| 290 |
+
degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10)) # 255 white fill
|
| 291 |
+
|
| 292 |
+
else: ## should be inference, or specified training
|
| 293 |
+
img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
|
| 294 |
+
bbox[0], bbox[2] = bbox[0]*ratio+dw, bbox[2]*ratio+dw
|
| 295 |
+
bbox[1], bbox[3] = bbox[1]*ratio+dh, bbox[3]*ratio+dh
|
| 296 |
+
|
| 297 |
+
draw_img = copy.deepcopy(img)
|
| 298 |
+
# Norm, to tensor
|
| 299 |
+
if self.transform is not None:
|
| 300 |
+
img = self.transform(img)
|
| 301 |
+
|
| 302 |
+
## encode phrase to clip input
|
| 303 |
+
word_id = clip.tokenize(phrase, 17, truncate=True)
|
| 304 |
+
word_mask = ~ (word_id == 0)
|
| 305 |
+
|
| 306 |
+
if self.augment: # train
|
| 307 |
+
seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2),interpolation=cv2.INTER_NEAREST) # (208, 208)
|
| 308 |
+
seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
|
| 309 |
+
return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
|
| 310 |
+
np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
|
| 311 |
+
else:
|
| 312 |
+
seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
|
| 313 |
+
return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
|
| 314 |
+
np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
|
| 315 |
+
np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)
|
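For orientation, here is a minimal, hypothetical sketch of how ReferDataset might be driven from a training script. It is not part of the uploaded files; the imsize value, the normalization statistics, and the relative paths are assumptions based on data.sh, datascript.py, and the loader defaults above.

# Hypothetical usage sketch (not included in this upload): wrap ReferDataset in a DataLoader.
import torch.utils.data as data
from torchvision import transforms
from data_loader import ReferDataset

# assumed normalization; the training code may use different statistics
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_set = ReferDataset(data_root='../ln_data', split_root='../data', dataset='refcoco',
                         imsize=416, transform=input_transform, augment=True, split='train')
train_loader = data.DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

# each training batch: image tensor, CLIP token ids, token mask, box (x1y1x2y2), half-resolution mask
img, word_id, word_mask, bbox, seg_map = next(iter(train_loader))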
ASDA/dataset/data_process.py
ADDED
@@ -0,0 +1,225 @@
# encoding=utf8
# %matplotlib inline
import numpy as np
import os
from refer import REFER
import os.path as osp
import cv2
import argparse
parser = argparse.ArgumentParser(description='Data preparation')
parser.add_argument('--data_root', type=str)  # contains refclef, refcoco, refcoco+, refcocog and images
parser.add_argument('--output_dir', type=str)
parser.add_argument('--dataset', type=str, choices=['refcoco', 'refcoco+', 'refcocog'], default='refcoco')
parser.add_argument('--split', type=str, default='umd')
parser.add_argument('--generate_mask', action='store_true')
args = parser.parse_args()
# data_root # contains refclef, refcoco, refcoco+, refcocog and images
refer = REFER(args.data_root, args.dataset, args.split)

print('dataset [%s_%s] contains: ' % (args.dataset, args.split))
ref_ids = refer.getRefIds()
image_ids = refer.getImgIds()
print('%s expressions for %s refs in %s images.' % (len(refer.Sents), len(ref_ids), len(image_ids)))

print('\nAmong them:')
if args.dataset == 'refclef':
    if args.split == 'unc':
        splits = ['train', 'val', 'testA', 'testB', 'testC']
    else:
        splits = ['train', 'val', 'test']
elif args.dataset == 'refcoco':
    splits = ['train', 'val', 'testA', 'testB']
elif args.dataset == 'refcoco+':
    splits = ['train', 'val', 'testA', 'testB']
elif args.dataset == 'grefcoco':
    splits = ['train', 'val', 'testA', 'testB']
elif args.dataset == 'refcocog':
    splits = ['train', 'val', 'test']  # we don't have the test split for refcocog right now.


# split data as a type in splits list
for split in splits:
    ref_ids = refer.getRefIds(split=split)
    print('%s refs are in split [%s].' % (len(ref_ids), split))


# show a batch of data with bounding box, cat, sentences
def show_a_batch(batch_size):
    split = 'train'
    # batch_size=32
    ref_ids = refer.getRefIds(split=split)
    print(split + '_size:', len(ref_ids))
    batch_index = list(np.random.choice(len(ref_ids), batch_size))

    # print(refer.Refs)
    ref_id = [ref_ids[i] for i in batch_index]
    refs = [refer.Refs[i] for i in ref_id]
    bboxs = [refer.getRefBox(i) for i in ref_id]
    sentences = [ref['sentences'] for ref in refs]
    image_urls = [refer.loadImgs(image_ids=ref['image_id']) for ref in refs]
    cats = [refer.loadCats(cat_ids=ref['category_id']) for ref in refs]
    # plt.figure()
    # plt.subplot(batch_size)
    grid_width = 2
    grid_height = int(batch_size / grid_width)
    # fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width*10, 10*grid_height))
    for i in range(batch_size):
        print('bbox for batch[{}]:'.format(i), bboxs[i])
        print('sentences for batch[{}]:'.format(i))
        for sid, sent in enumerate(sentences[i]):
            print('%s. %s' % (sid + 1, sent['sent']))
        print('cats for batch[{}]:'.format(i), cats[i])

        image_url = image_urls[i][0]
        image = cv2.imread(osp.join(refer.IMAGE_DIR, image_url['file_name']))
        print(image.shape)
        # print(bboxs[i][0])
        cv2.rectangle(image, (int(bboxs[i][0]), int(bboxs[i][1])), (int(bboxs[i][0] + bboxs[i][2]), int(bboxs[i][1] + bboxs[i][3])), 255, 3)
        cv2.putText(image,
                    str(sent['sent']),
                    (20, 20),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    .9, (0, 255, 0), 2)
        os.mkdir('debug_vis')
        cv2.imwrite('./debug_vis/' + image_url['file_name'], image)
        cv2.imwrite('./debug_vis/mask' + image_url['file_name'], refer.getMask(refs[i])['mask'] * 255)
        # ax.imshow(image)
        # plt.show()

def cat_process(cat):
    if cat >= 1 and cat <= 11:
        cat = cat - 1
    elif cat >= 13 and cat <= 25:
        cat = cat - 2
    elif cat >= 27 and cat <= 28:
        cat = cat - 3
    elif cat >= 31 and cat <= 44:
        cat = cat - 5
    elif cat >= 46 and cat <= 65:
        cat = cat - 6
    elif cat == 67:
        cat = cat - 7
    elif cat == 70:
        cat = cat - 9
    elif cat >= 72 and cat <= 82:
        cat = cat - 10
    elif cat >= 84 and cat <= 90:
        cat = cat - 11
    return cat

def bbox_process(bbox, cat, segement_id):
    x_min = int(bbox[0])
    y_min = int(bbox[1])
    x_max = x_min + int(bbox[2])
    y_max = y_min + int(bbox[3])
    box_info = " %d,%d,%d,%d,%d,%d" % (int(x_min), int(y_min), int(x_max), int(y_max), int(cat), int(segement_id))
    return box_info

def prepare_dataset(dataset, splits, output_dir, generate_mask=False):
    # split_type='train'
    # splits=[split_type]
    # batch_size=32
    if dataset == 'refcocog':
        dataset = 'refcocog_' + args.split
    if not os.path.exists(os.path.join(output_dir, 'anns', dataset)):
        os.makedirs(os.path.join(output_dir, 'anns', dataset))
    if not os.path.exists(os.path.join(output_dir, 'masks', dataset)):
        os.makedirs(os.path.join(output_dir, 'masks', dataset))
    for split in splits:
        f = open(os.path.join(output_dir, 'anns', dataset, split + '.txt'), 'w', encoding='utf-8')
        # print(split)
        split_num = 0
        ll = 0
        ref_ids = refer.getRefIds(split=split)
        print(split + '_size:', len(ref_ids))
        for i in ref_ids:
            # ref_id = ref_ids[i]
            refs = refer.Refs[i]
            bboxs = refer.getRefBox(i)
            print("bboxs", bboxs)
            sentences = refs['sentences']
            image_urls = refer.loadImgs(image_ids=refs['image_id'])[0]

            # in grefcoco, category_id is a list
            cat = refs['category_id']
            if type(cat) == list:
                for j in range(len(cat)):
                    cat[j] = cat_process(cat[j])
            else:
                cat = cat_process(cat)

            image_urls = image_urls['file_name']
            if dataset == 'refclef' and image_urls in ['19579.jpg', '17975.jpg', '19575.jpg']:
                continue
            # the box and cat info are not used for RES
            if type(bboxs[0]) == list:
                box_info = bbox_process(bboxs[0], cat[0], i)  # add segment id
            else:
                box_info = bbox_process(bboxs, cat, i)  # add segment id
            f.write(image_urls)
            f.write(box_info)
            # f.write(' '+str(i))
            if generate_mask:
                if dataset == 'grefcoco':
                    np.save(os.path.join(output_dir, 'masks', dataset, str(i) + '.npy'), refer.getMaskByRef(refs, merge=True)['mask'])
                else:
                    np.save(os.path.join(output_dir, 'masks', dataset, str(i) + '.npy'), refer.getMask(refs)['mask'])  # if seg masks are needed, set it!
            for sentence in sentences:
                f.write(' ~ ')
                # print(sentence['sent'].encode('UTF-8'))
                f.write(sentence['sent'])
                if ll < len(sentence['sent']):
                    ll = len(sentence['sent'])
            f.write('\n')
            split_num += 1
        print('split_num:', split_num)
        print('max_len:', ll)
        f.close()

def prepare_sentences_refcoco():
    splits = ['train', 'val']
    # batch_size=32
    f = open('sentences.txt', 'w')
    for split in splits:
        print(split)
        ref_ids = refer.getRefIds(split=split)
        print(split + '_size:', len(ref_ids))
        for i in range(len(ref_ids)):
            refs = refer.Refs[i]
            sentences = refs['sentences']
            for sentence in sentences:
                f.write(sentence['sent'])
                f.write('\n')
    f.close()

def test_length():
    max_len = 0
    word_l_count = np.zeros([50], dtype=np.int)
    with open('./refcocog/train.txt') as f:
        lines = f.readlines()
        for j in range(len(lines)):
            line = lines[j].split()
            stop = len(line)
            for i in range(1, len(line)):
                if (line[i] == '~'):
                    stop = i
                    break
            sentences = []
            sent_stop = stop + 1
            for i in range(stop + 1, len(line)):
                if line[i] == '~':
                    # sentences.append(line[sent_stop:i])
                    # print(len(line[sent_stop:i]))
                    word_l_count[len(line[sent_stop:i])] += 1
                    # if len(line[sent_stop:i])>max_len:
                    #     max_len=len(line[sent_stop:i])
                    sent_stop = i + 1
        for i in range(50):
            if word_l_count[i] > 0:
                print('length:%d' % i, ',count:%d' % word_l_count[i])
        # print('max_len:', max_len)
        # print(len(lines))


prepare_dataset(args.dataset, splits, args.output_dir, args.generate_mask)
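For reference, prepare_dataset writes one line per ref to anns/<dataset>/<split>.txt: the image file name, then "x_min,y_min,x_max,y_max,category,seg_id", then one " ~ "-separated field per referring expression. The concrete values in the line below are invented purely for illustration of that format:

COCO_train2014_000000000123.jpg 103,195,255,321,0,42 ~ woman on the left ~ lady in red jacket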
ASDA/dataset/datascript.py
ADDED
@@ -0,0 +1,56 @@
# generate **.pth
import os
import sys
import torch
sys.path.append('.')

import argparse
parser = argparse.ArgumentParser(description='Data preparation')
parser.add_argument('--dataset', type=str, choices=['refcoco', 'refcoco+', 'refcocog_google', 'refcocog_umd'], default='refcoco')
args = parser.parse_args()

def main(args):
    dataset = args.dataset
    input_txt_list = os.listdir(f'../ln_data/anns/{dataset}')
    if not os.path.exists(f'../data/{dataset}'):
        os.makedirs(f'../data/{dataset}')
    for input_txt in input_txt_list:
        split = input_txt.split('_')[-1].split('.')[0]
        input_txt = os.path.join('../ln_data/anns', dataset, input_txt)
        res = []
        with open(input_txt, encoding='utf-8') as f:
            lines = f.readlines()
            for line in lines:
                line = line.split()
                stop = len(line)
                img_name = line[0]
                for i in range(1, len(line)):
                    if (line[i] == '~'):
                        stop = i
                        break
                box_ = [list(map(int, box.split(','))) for box in line[1:stop]]
                box = box_[0][:4]
                seg_id = box_[0][-1]

                sent_stop = stop + 1
                for i in range(stop + 1, len(line)):
                    if line[i] == '~':
                        des = ''
                        for word in line[sent_stop:i]:
                            des = des + word + ' '
                        sent_stop = i + 1
                        des = des.rstrip(' ')
                        res.append((img_name, seg_id, box, des))
                des = ''
                for word in line[sent_stop:len(line)]:
                    des = des + word + ' '
                des = des.rstrip(' ')
                res.append((img_name, seg_id, box, des))
            # print(res)

        imgset_path = '{0}_{1}.pth'.format(dataset, split)
        images = torch.save(res, os.path.join("../data", dataset, imgset_path))
    print(dataset, " done")

if __name__ == "__main__":
    main(args)
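A quick, hypothetical sanity check (not part of this upload) of the .pth files this script produces; each record is the (img_name, seg_id, box, expression) tuple appended above, and the path below assumes the refcoco train split:

# Hypothetical check: load one split produced by datascript.py and inspect a record.
import torch

records = torch.load('../data/refcoco/refcoco_train.pth')
img_name, seg_id, box, expression = records[0]
print(img_name)    # COCO image file name under images/train2014
print(seg_id)      # ref id, matching masks/<dataset>/<seg_id>.npy from data_process.py
print(box)         # [x_min, y_min, x_max, y_max]
print(expression)  # one referring expression string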
ASDA/dataset/refer.py
ADDED
@@ -0,0 +1,485 @@
__author__ = 'licheng'

"""
This interface provides access to four datasets:
1) refclef
2) refcoco
3) refcoco+
4) refcocog
split by unc and google
The following API functions are defined:
REFER      - REFER api class
getRefIds  - get ref ids that satisfy given filter conditions.
getAnnIds  - get ann ids that satisfy given filter conditions.
getImgIds  - get image ids that satisfy given filter conditions.
getCatIds  - get category ids that satisfy given filter conditions.
loadRefs   - load refs with the specified ref ids.
loadAnns   - load anns with the specified ann ids.
loadImgs   - load images with the specified image ids.
loadCats   - load category names with the specified category ids.
getRefBox  - get ref's bounding box [x, y, w, h] given the ref_id
showRef    - show image, segmentation or box of the referred object with the ref
getMask    - get mask and area of the referred object given ref
showMask   - show mask of the referred object given ref
"""

import sys
import os.path as osp
import os
import json
# import _pickle as pickle
import pickle
import time
import itertools
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from pprint import pprint
import numpy as np
from pycocotools import mask
import cv2
# from skimage.measure import label, regionprops

class REFER:
    def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
        # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
        # also provide dataset name and splitBy information
        # e.g., dataset = 'refcoco', splitBy = 'unc'
        print('loading dataset %s into memory...' % dataset)
        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
        self.DATA_DIR = osp.join(data_root, dataset)
        if dataset in ['refcoco', 'refcoco+', 'refcocog']:
            self.IMAGE_DIR = osp.join(data_root, 'images/train2014')
        elif dataset == 'refclef':
            self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
        else:
            print('No refer dataset is called [%s]' % dataset)
            sys.exit()

        # load refs from data/dataset/refs(dataset).json
        tic = time.time()
        ref_file = osp.join(self.DATA_DIR, 'refs(' + splitBy + ').p')
        self.data = {}
        self.data['dataset'] = dataset

        self.data['refs'] = pickle.load(open(ref_file, 'rb'), fix_imports=True)

        # load annotations from data/dataset/instances.json
        instances_file = osp.join(self.DATA_DIR, 'instances.json')
        instances = json.load(open(instances_file, 'r'))
        self.data['images'] = instances['images']
        self.data['annotations'] = instances['annotations']
        self.data['categories'] = instances['categories']

        # create index
        self.createIndex()
        print('DONE (t=%.2fs)' % (time.time() - tic))

    def createIndex(self):
        # create sets of mapping
        # 1)  Refs:         {ref_id: ref}
        # 2)  Anns:         {ann_id: ann}
        # 3)  Imgs:         {image_id: image}
        # 4)  Cats:         {category_id: category_name}
        # 5)  Sents:        {sent_id: sent}
        # 6)  imgToRefs:    {image_id: refs}
        # 7)  imgToAnns:    {image_id: anns}
        # 8)  refToAnn:     {ref_id: ann}
        # 9)  annToRef:     {ann_id: ref}
        # 10) catToRefs:    {category_id: refs}
        # 11) sentToRef:    {sent_id: ref}
        # 12) sentToTokens: {sent_id: tokens}
        print('creating index...')
        # fetch info from instances
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data['annotations']:
            Anns[ann['id']] = ann
            imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
        for img in self.data['images']:
            Imgs[img['id']] = img
        for cat in self.data['categories']:
            Cats[cat['id']] = cat['name']

        # fetch info from refs
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data['refs']:
            # ids
            ref_id = ref['ref_id']
            ann_id = ref['ann_id']
            category_id = ref['category_id']
            image_id = ref['image_id']

            # add mapping related to ref
            Refs[ref_id] = ref
            imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
            catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            # add mapping of sent
            for sent in ref['sentences']:
                Sents[sent['sent_id']] = sent
                sentToRef[sent['sent_id']] = ref
                sentToTokens[sent['sent_id']] = sent['tokens']

        # create class members
        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print('index created.')

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
            refs = self.data['refs']
        else:
            if not len(image_ids) == 0:
                refs = [self.imgToRefs[image_id] for image_id in image_ids]
            else:
                refs = self.data['refs']
            if not len(cat_ids) == 0:
                refs = [ref for ref in refs if ref['category_id'] in cat_ids]
            if not len(ref_ids) == 0:
                refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
            if not len(split) == 0:
                if split in ['testA', 'testB', 'testC']:
                    refs = [ref for ref in refs if split[-1] in ref['split']]  # we also consider testAB, testBC, ...
                elif split in ['testAB', 'testBC', 'testAC']:
                    refs = [ref for ref in refs if ref['split'] == split]  # rarely used I guess...
                elif split == 'test':
                    refs = [ref for ref in refs if 'test' in ref['split']]
                elif split == 'train' or split == 'val':
                    refs = [ref for ref in refs if ref['split'] == split]
                else:
                    print('No such split [%s]' % split)
                    sys.exit()
        ref_ids = [ref['ref_id'] for ref in refs]
        return ref_ids

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
            ann_ids = [ann['id'] for ann in self.data['annotations']]
        else:
            if not len(image_ids) == 0:
                lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns]  # list of [anns]
                anns = list(itertools.chain.from_iterable(lists))
            else:
                anns = self.data['annotations']
            if not len(cat_ids) == 0:
                anns = [ann for ann in anns if ann['category_id'] in cat_ids]
            ann_ids = [ann['id'] for ann in anns]
            if not len(ref_ids) == 0:
                ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if not len(ref_ids) == 0:
            image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
        else:
            image_ids = self.Imgs.keys()
        return image_ids

    def getCatIds(self):
        return self.Cats.keys()

    def loadRefs(self, ref_ids=[]):
        if type(ref_ids) == list:
            return [self.Refs[ref_id] for ref_id in ref_ids]
        elif type(ref_ids) == int:
            return [self.Refs[ref_ids]]

    def loadAnns(self, ann_ids=[]):
        if type(ann_ids) == list:
            return [self.Anns[ann_id] for ann_id in ann_ids]
        elif type(ann_ids) == int or type(ann_ids) == unicode:
            return [self.Anns[ann_ids]]

    def loadImgs(self, image_ids=[]):
        if type(image_ids) == list:
            return [self.Imgs[image_id] for image_id in image_ids]
        elif type(image_ids) == int:
            return [self.Imgs[image_ids]]

    def loadCats(self, cat_ids=[]):
        if type(cat_ids) == list:
            return [self.Cats[cat_id] for cat_id in cat_ids]
        elif type(cat_ids) == int:
            return [self.Cats[cat_ids]]

    def getRefBox(self, ref_id):
        ref = self.Refs[ref_id]
        ann = self.refToAnn[ref_id]
        return ann['bbox']  # [x, y, w, h]

    def showRef(self, ref, seg_box='seg'):
        ax = plt.gca()
        # show image
        image = self.Imgs[ref['image_id']]
        I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
        ax.imshow(I)
        # show refer expression
        for sid, sent in enumerate(ref['sentences']):
            print('%s. %s' % (sid + 1, sent['sent']))
        # show segmentations
        if seg_box == 'seg':
            ann_id = ref['ann_id']
            ann = self.Anns[ann_id]
            polygons = []
            color = []
            c = 'none'
            if type(ann['segmentation'][0]) == list:
                # polygon used for refcoco*
                for seg in ann['segmentation']:
                    poly = np.array(seg).reshape((len(seg)//2, 2))
                    polygons.append(Polygon(poly, True, alpha=0.4))
                    color.append(c)
                p = PatchCollection(polygons, facecolors=color, edgecolors=(1,1,0,0), linewidths=3, alpha=1)
                ax.add_collection(p)  # thick yellow polygon
                p = PatchCollection(polygons, facecolors=color, edgecolors=(1,0,0,0), linewidths=1, alpha=1)
                ax.add_collection(p)  # thin red polygon
            else:
                # mask used for refclef
                rle = ann['segmentation']
                m = mask.decode(rle)
                img = np.ones( (m.shape[0], m.shape[1], 3) )
                color_mask = np.array([2.0,166.0,101.0])/255
                for i in range(3):
                    img[:,:,i] = color_mask[i]
                ax.imshow(np.dstack( (img, m*0.5) ))
        # show bounding-box
        elif seg_box == 'box':
            ann_id = ref['ann_id']
            print(ann_id)
            ann = self.Anns[ann_id]
            bbox = self.getRefBox(ref['ref_id'])
            box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
            ax.add_patch(box_plot)

    def getMask(self, ref):
        # return mask, area and mask-center
        ann = self.refToAnn[ref['ref_id']]
        print(ann)
        image = self.Imgs[ref['image_id']]
        if type(ann['segmentation'][0]) == list:  # polygon
            rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width'])
        else:
            rle = ann['segmentation']

        # for i in range(len(rle['counts'])):
        #     print(rle)
        m = mask.decode(rle)
        m = np.sum(m, axis=2)  # sometimes there are multiple binary maps (corresponding to multiple segs)
        m = m.astype(np.uint8)  # convert to np.uint8
        # compute area
        area = sum(mask.area(rle))  # should be close to ann['area']
        return {'mask': m, 'area': area}
        # # position
        # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
        # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
        # # mass position (if there were multiple regions, we use the largest one.)
        # label_m = label(m, connectivity=m.ndim)
        # regions = regionprops(label_m)
        # if len(regions) > 0:
        #     largest_id = np.argmax(np.array([props.filled_area for props in regions]))
        #     largest_props = regions[largest_id]
        #     mass_y, mass_x = largest_props.centroid
        # else:
        #     mass_x, mass_y = position_x, position_y
        # # if centroid is not in mask, we find the closest point to it from mask
        # if m[mass_y, mass_x] != 1:
        #     print 'Finding closes mask point ...'
        #     kernel = np.ones((10, 10),np.uint8)
        #     me = cv2.erode(m, kernel, iterations = 1)
        #     points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
        #     points = np.array(points)
        #     dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
        #     id = np.argsort(dist)[0]
        #     mass_y, mass_x = points[id]
        # # return
        # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
        # # show image and mask
        # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
        # plt.figure()
        # plt.imshow(I)
        # ax = plt.gca()
        # img = np.ones( (m.shape[0], m.shape[1], 3) )
        # color_mask = np.array([2.0,166.0,101.0])/255
        # for i in range(3):
        #     img[:,:,i] = color_mask[i]
        # ax.imshow(np.dstack( (img, m*0.5) ))
        # plt.show()

    def showMask(self, ref):
        M = self.getMask(ref)
        msk = M['mask']
        ax = plt.gca()
        ax.imshow(msk)


if __name__ == '__main__':
    refer = REFER(data_root="/home/ypf/workspace/code/BKINet/ln_data", dataset='refcoco', splitBy='unc')
    save_path = "./visualization/"
    ref_ids = refer.getRefIds()
    print(len(ref_ids))

    print(len(refer.Imgs))
    print(len(refer.imgToRefs))
    print(refer.Cats)

    ref_ids = refer.getRefIds(split='train')
    print('There are %s training referred objects.' % len(ref_ids))

    img_ids = [8936, 52563]
    # ref_ids = refer.getRefIds(image_ids=img_ids)

    # refs = refer.loadRefs(ref_ids)

    def custom_vis1(image, mask_):
        # apply the mask to a blue layer
        # create a blue layer
        blue_layer = np.zeros_like(image)
        blue_layer[:, :, 0] = 255  # in OpenCV the blue channel comes first
        blue_mask = cv2.bitwise_and(blue_layer, blue_layer, mask=mask_)

        # overlay the blue mask onto the original image with some transparency
        alpha = 0.1  # transparency
        cv2.addWeighted(blue_mask, alpha, image, 1 - alpha, 0, image)

    def custom_vis2(image, mask_):
        # create a blue layer
        blue_layer = np.zeros_like(image)
        blue_layer[:, :, 0] = 255  # in OpenCV the blue channel comes first

        # apply the mask to the blue layer
        blue_mask = cv2.bitwise_and(blue_layer, blue_layer, mask=mask_)

        # alpha controls how strongly the mask layer is blended with the original image
        alpha = 0.5  # transparency

        # create a fully transparent layer
        transparent_layer = np.zeros_like(image)

        # apply the blue layer only inside the mask region, using alpha to control transparency
        for i in range(3):  # only the three RGB channels
            transparent_layer[:, :, i] = cv2.addWeighted(blue_mask[:, :, i], alpha, image[:, :, i], 1 - alpha, 0)

        # keep the original image outside the mask region
        transparent_layer[mask_ == 0] = image[mask_ == 0]

        return transparent_layer

    def custom_vis3(image, mask_):
        """
        Directly recolor the masked region of the original image to blue,
        without changing the brightness or color of the other regions.
        """
        image[mask_ != 0] = [255, 0, 0]  # OpenCV uses BGR channel order

    def custom_vis4(image, mask_, alpha=0.4):
        """
        Apply a blue overlay to the original image with the given transparency.
        alpha: overlay transparency, from 0 (fully transparent) to 1 (fully opaque).
        """
        # convert the image from BGR to BGRA to add an alpha channel
        image_rgba = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
        # create a fully blue layer of the same size
        blue_mask = np.zeros_like(image_rgba)
        blue_mask[:, :, 0] = 255  # B
        blue_mask[:, :, 3] = 255  # alpha set to opaque

        # apply the transparency inside the mask region
        blue_mask[mask_ != 0, 3] = int(alpha * 255)

        # blend the blue overlay onto the original image
        image_rgba = cv2.addWeighted(image_rgba, 1, blue_mask, alpha, 0)
        return image_rgba


    for i, img_id in enumerate(img_ids):
        ref = refer.imgToRefs[img_id][0]
        print(ref)
        mask_ = refer.getMask(ref)['mask']
        # sentence = ref['sentences'][0]['sent']

        img = refer.Imgs[img_id]
        # I = io.imread(osp.join(refer.IMAGE_DIR, img['file_name']))
        # `image_path` is the path of the original image, `mask_` is a binary array of the same size as the image
        image_path = osp.join(refer.IMAGE_DIR, img['file_name'])
        image = cv2.imread(image_path)
        # mask = np.zeros(image.shape[:2], dtype=np.uint8)  # an actual mask is needed here

        # custom_vis1(image, mask_)

        image = custom_vis2(image, mask_)

        # custom_vis3(image, mask_)

        # image = custom_vis4(image=image, mask_=mask_, alpha=0.4)

        # save the result image to the target path
        image_dir = osp.join(save_path, str(img_id))
        osp.exists(image_dir) or os.makedirs(image_dir)
        # copy the original image
        I = io.imread(osp.join(refer.IMAGE_DIR, img['file_name']))
        io.imsave(osp.join(image_dir, img['file_name']), I)

        cv2.imwrite(osp.join(image_dir, str(img_id) + ".png"), image)

        # save the ref as json
        with open(osp.join(image_dir, str(img_id) + ".json"), "w") as f:
            json.dump(ref, f)


    # i = 0
    # for ref_id in ref_ids:
    #     i += 1
    #     ref = refer.loadRefs(ref_id)[0]
    #     if len(ref['sentences']) < 2:
    #         continue

    #     print(ref)
    #     print('The label is %s.' % refer.Cats[ref['category_id']])
    #     plt.figure()
    #     # refer.getMask(ref)
    #     refer.showMask(ref)

    #     # refer.showRef(ref, seg_box='seg')

    #     plt.show()
    #     if i == 0:
    #         break
    #     # save
    #     plt.savefig('tmp.png')

    # plt.figure()
    # refer.showMask(ref)
    # plt.show()
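Finally, a minimal usage sketch of the REFER interface defined above; the data_root path is a placeholder and the sketch is not part of the uploaded files:

# Minimal REFER usage sketch; assumes ln_data contains refcoco/refs(unc).p and instances.json.
from refer import REFER

refer = REFER(data_root='../ln_data', dataset='refcoco', splitBy='unc')
ref_ids = refer.getRefIds(split='val')        # ref ids filtered by split
ref = refer.loadRefs(ref_ids[0])[0]           # full ref record: sentences, ann_id, image_id, ...
print([s['sent'] for s in ref['sentences']])  # its referring expressions
print(refer.getRefBox(ref['ref_id']))         # [x, y, w, h]
print(refer.getMask(ref)['mask'].shape)       # H x W uint8 mask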