"""
RefCOCO, RefCOCO+ and RefCOCOg referring image detection and segmentation PyTorch dataset.
"""

import copy
import os
import os.path as osp
import random
import sys
from collections import defaultdict

import cv2
import numpy as np
import torch
import torch.utils.data as data

import clip

sys.path.append('.')
import utils
from utils.transforms import letterbox, random_affine, random_copy, random_crop, random_erase

# Expose the local `utils` package at the top level so torch.load can
# unpickle split files that reference it by that name.
sys.modules['utils'] = utils

# Keep OpenCV single-threaded; DataLoader workers provide the parallelism.
cv2.setNumThreads(0)


class ReferDataset(data.Dataset):
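    """RefCOCO-family referring expression dataset.

    Each item pairs a COCO train2014 image with one referring expression,
    its bounding box, and a binary segmentation mask.
    """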

    SUPPORTED_DATASETS = {
        'refcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco', 'split_by': 'unc'}
        },
        'refcoco+': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
        },
        'refcocog': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'unc'}
        },
        'refcocog_g': {
            'splits': ('train', 'val'),
            'params': {'dataset': 'refcocog', 'split_by': 'google'}
        },
        'refcocog_u': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'unc'}
        },
        'grefcoco': {
            'splits': ('train', 'val', 'testA', 'testB'),
            'params': {'dataset': 'grefcoco', 'split_by': 'unc'}
        }
    }

    def __init__(self, data_root, split_root='data', dataset='refcoco', imsize=256, splitby='umd',
                 transform=None, augment=False, split='train', max_query_len=128, metric_learning=None):
        images_tmp = []
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.imsize = imsize
        self.query_len = max_query_len
        self.transform = transform
        self.word_len = 17
        self.emb_size = 384
        self.split = split
        self.augment = augment

        valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']
        if split not in valid_splits:
            raise ValueError(
                'Dataset {0} does not have split {1}'.format(self.dataset, split))
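
        # Annotation list and mask folders; refcocog masks are stored per
        # annotation variant (splitby), the other datasets share one folder.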
        self.anns_root = osp.join(self.data_root, 'anns', self.dataset, self.split + '.txt')
        if self.dataset == 'refcocog':
            mask_anno_str = '{0}_{1}'.format(self.dataset, splitby)
            self.mask_root = osp.join(self.data_root, 'masks', mask_anno_str)
        else:
            self.mask_root = osp.join(self.data_root, 'masks', self.dataset)

        self.im_dir = osp.join(self.data_root, 'images', 'train2014')
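
        # Load the preprocessed referring annotations for this split; each
        # entry is an (img_file, seg_id, bbox, phrase) tuple.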
        dataset_path = osp.join(self.split_root, self.dataset)
        splits = [split]
        for split in splits:
            imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
            imgset_path = osp.join(dataset_path, imgset_file)
            images_tmp += torch.load(imgset_path)
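
        # Machine-specific roots for the precomputed SBERT hard-positive
        # sentence embeddings used by metric learning.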
        self.ROOT = '/data2/dataset/RefCOCO/VRIS'
        if self.dataset == 'refcoco':
            self.all_hp_root = '/data2/dataset/RefCOCO/refcoco/SBERT_rcc_unc'
        elif self.dataset == 'refcoco+':
            self.all_hp_root = '/data2/dataset/RefCOCO/refcoco+/SBERT_rccp_unc'

        self.metric_learning = metric_learning
        if self.metric_learning:
            self.exclude_position = True
            self.exclude_multiobj = True
            self.hp_selection = 'strict'
            self.multi_obj_ref_ids = None
            self.hardpos_meta = None

        # Count how many sentences refer to each segment so that, for metric
        # learning, each training item can carry its sentence index within the
        # referent (used to look up the matching hard-positive embedding).
        ref_sentence_counts = defaultdict(int)
        for item in images_tmp:
            ref_sentence_counts[item[1]] += 1

        if self.metric_learning and self.split == 'train':
            images = []
            ref_sentence_indices = defaultdict(int)
            for item in images_tmp:
                img_name, seg_id, box, sentence = item
                sent_index = ref_sentence_indices[seg_id]
                total_sentences = ref_sentence_counts[seg_id]
                images.append((img_name, seg_id, box, sentence, sent_index, total_sentences))
                ref_sentence_indices[seg_id] += 1
            self.images = images
        else:
            self.images = images_tmp

    def exists_dataset(self):
        return osp.exists(osp.join(self.split_root, self.dataset))
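
    # Hard positives are per-referent SBERT sentence embeddings stored as
    # hp_*.npy files under all_hp_root/<seg_id>/.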
    def _get_hardpos_verb_rcc(self, seg_id, sent_idx):
        emb_folder = os.path.join(self.all_hp_root, str(seg_id))
        emb_files = sorted(f for f in os.listdir(emb_folder)
                           if f.startswith('hp_') and f.endswith('.npy'))
        if self.hp_selection == 'strict':
            # 'strict' pairs each sentence with its own hard-positive file.
            emb_file = emb_files[sent_idx]
        else:
            # Otherwise sample any hard positive for this referent.
            emb_file = random.choice(emb_files)
        selected_emb = np.load(os.path.join(emb_folder, emb_file))
        return torch.from_numpy(selected_emb)
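
    # Load one raw sample: RGB image, phrase, box, and float32 mask.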
    def pull_item(self, idx):
        if self.metric_learning and self.augment:
            img_file, seg_id, bbox, phrase, sent_idx, sent_num = self.images[idx]
        else:
            img_file, seg_id, bbox, phrase = self.images[idx]
        bbox = np.array(bbox, dtype=int)

        img_path = osp.join(self.im_dir, img_file)
        img = cv2.imread(img_path)
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            # Replicate a single-channel image into three channels.
            img = np.stack([img] * 3, axis=-1)

        seg_map = np.load(osp.join(self.mask_root, str(seg_id) + '.npy'))
        seg_map = seg_map.astype(np.float32)

        if self.metric_learning and self.split == 'train':
            return img, phrase, bbox, seg_map, seg_id, sent_idx
        else:
            return img, phrase, bbox, seg_map, seg_id

    def __len__(self):
        return len(self.images)
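
    # Training (augment=True) returns (img, word_id, word_mask, bbox, mask)
    # and, with metric learning, a params dict holding the hard-positive
    # embedding; evaluation additionally returns the letterbox ratio/offsets
    # and the raw letterboxed image.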
    def __getitem__(self, idx):
        if self.metric_learning and self.augment:
            img, phrase, bbox, seg_map, seg_id, sent_idx = self.pull_item(idx)
        else:
            img, phrase, bbox, seg_map, seg_id = self.pull_item(idx)

        phrase = phrase.lower()
        if self.augment:
            # Flip, HSV jitter, and affine are enabled; copy, crop, and erase are off.
            augment_flip, augment_hsv, augment_affine, augment_crop, augment_copy, augment_erase = \
                True, True, True, False, False, False

        h, w = img.shape[0], img.shape[1]

        if self.augment:
            # Horizontal flip mirrors the image, mask, and box, and swaps
            # 'left'/'right' in the phrase to keep it consistent.
            if augment_flip and random.random() > 0.5:
                img = cv2.flip(img, 1)
                seg_map = cv2.flip(seg_map, 1)
                bbox[0], bbox[2] = w - bbox[2] - 1, w - bbox[0] - 1
                phrase = phrase.replace('right', '*&^special^&*').replace('left', 'right').replace('*&^special^&*', 'left')

            if augment_copy:
                img, seg_map, phrase, bbox = random_copy(img, seg_map, phrase, bbox)

            if augment_erase:
                img, seg_map = random_erase(img, seg_map)

            if augment_crop:
                img, seg_map = random_crop(img, seg_map, 40, h, w)

            if augment_hsv:
                # Random saturation/value jitter in HSV space.
                fraction = 0.50
                img_hsv = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2HSV)
                S = img_hsv[:, :, 1].astype(np.float32)
                V = img_hsv[:, :, 2].astype(np.float32)
                a = (random.random() * 2 - 1) * fraction + 1
                S *= a
                if a > 1:
                    np.clip(S, a_min=0, a_max=255, out=S)
                a = (random.random() * 2 - 1) * fraction + 1
                V *= a
                if a > 1:
                    np.clip(V, a_min=0, a_max=255, out=V)

                img_hsv[:, :, 1] = S.astype(np.uint8)
                img_hsv[:, :, 2] = V.astype(np.uint8)
                img = cv2.cvtColor(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR), cv2.COLOR_BGR2RGB)

            # Letterbox to a square canvas and map the box through the same
            # scale and padding.
            img, seg_map, ratio, dw, dh = letterbox(img, seg_map, self.imsize)
            bbox[0], bbox[2] = bbox[0] * ratio + dw, bbox[2] * ratio + dw
            bbox[1], bbox[3] = bbox[1] * ratio + dh, bbox[3] * ratio + dh

            if augment_affine:
                img, seg_map, bbox, M = random_affine(img, seg_map, bbox,
                                                      degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))

        else:
            img, _, ratio, dw, dh = letterbox(img, None, self.imsize)
            bbox[0], bbox[2] = bbox[0] * ratio + dw, bbox[2] * ratio + dw
            bbox[1], bbox[3] = bbox[1] * ratio + dh, bbox[3] * ratio + dh
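
        # Keep an untransformed copy of the letterboxed image (returned for
        # visualization in the eval branch).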
        draw_img = copy.deepcopy(img)

        if self.transform is not None:
            img = self.transform(img)

        # Tokenize with CLIP's tokenizer; id 0 is padding, so nonzero
        # positions form the word mask.
        word_id = clip.tokenize(phrase, self.word_len, truncate=True)
        word_mask = ~(word_id == 0)

        orig_word_id = np.array(word_id, dtype=int)
        orig_word_mask = np.array(word_mask, dtype=int)

        if self.metric_learning and self.augment:
            original_emb = self._get_hardpos_verb_rcc(seg_id, sent_idx)

        if self.augment:
            # Downsample the mask to half the input size (nearest neighbor)
            # and add a channel dimension.
            seg_map = cv2.resize(seg_map, (self.imsize // 2, self.imsize // 2), interpolation=cv2.INTER_NEAREST)
            seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
            if self.metric_learning:
                params = {
                    'seg_id': seg_id,
                    'sent': phrase,
                    'hardpos_emb': original_emb.unsqueeze(0)
                }
                return img, orig_word_id, orig_word_mask, np.array(bbox, dtype=np.float32), \
                    np.array(seg_map, dtype=np.float32), params
            else:
                return img, orig_word_id, orig_word_mask, \
                    np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32)
        else:
            seg_map = np.reshape(seg_map, [1, np.shape(seg_map)[0], np.shape(seg_map)[1]])
            return img, orig_word_id, orig_word_mask, \
                np.array(bbox, dtype=np.float32), np.array(seg_map, dtype=np.float32), np.array(ratio, dtype=np.float32), \
                np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0], self.images[idx][3], np.array(draw_img, dtype=np.uint8)