import os
import sys
import json
import random
import itertools
from typing import List, Tuple, Union

import cv2
import lmdb
import numpy as np
import pyarrow as pa
import albumentations as A
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision.transforms import functional as F

from bert.tokenization_bert import BertTokenizer

info = {
    'refcoco': {
        'train': 42404,
        'val': 3811,
        'val-test': 3811,
        'testA': 1975,
        'testB': 1810
    },
    'refcoco+': {
        'train': 42278,
        'val': 3805,
        'val-test': 3805,
        'testA': 1975,
        'testB': 1798
    },
    'refcocog_u': {
        'train': 42226,
        'val': 2573,
        'val-test': 2573,
        'test': 5023,
        'testA': 100,
        'testB': 100
    },
    'refcocog_g': {
        'train': 44822,
        'val': 5000,
        'val-test': 5000
    }
}

_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of the given input string(s).

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string (or list of strings) to tokenize. The current
        implementation encodes the input as a single sequence, so it is
        intended for one sentence at a time.

    context_length : int
        The context length to use; all CLIP models use 77 as the context length.

    truncate : bool
        Kept for API compatibility; the encoding is always clipped to
        `context_length` regardless of this flag.

    Returns
    -------
    A tuple of two tensors of shape [1, context_length]: the zero-padded token
    ids and the corresponding padding mask (1 for real tokens, 0 for padding).
    """
    l_mask = [0] * context_length
    result = [0] * context_length

    tokens = _tokenizer.encode(text=texts, add_special_tokens=True)
    tokens = tokens[:context_length]
    result[:len(tokens)] = tokens
    l_mask[:len(tokens)] = [1] * len(tokens)

    result = torch.tensor(result).unsqueeze(0)
    l_mask = torch.tensor(l_mask).unsqueeze(0)
    return result, l_mask
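

# Illustrative usage of `tokenize` (a minimal sketch; the sentence and the
# context length below are arbitrary examples, not values used elsewhere):
#
#     word_vec, pad_mask = tokenize("the man holding a red umbrella", context_length=20)
#     word_vec.shape  # torch.Size([1, 20]) -- zero-padded BERT token ids
#     pad_mask.shape  # torch.Size([1, 20]) -- 1 for real tokens, 0 for padding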


def loads_pyarrow(buf):
    """
    Args:
        buf: the output of `dumps` (a pyarrow-serialized buffer).
    """
    # Note: `pa.deserialize` is a legacy pyarrow API and has been removed in
    # recent releases; this code assumes an older pyarrow version.
    return pa.deserialize(buf)


class RefDataset_gref(Dataset):
    def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
                 word_length, args):
        super(RefDataset_gref, self).__init__()
        self.lmdb_dir = lmdb_dir
        self.mask_dir = mask_dir
        self.ROOT = '/data/seunghoon/VerbCentric_CY/datasets/VRIS'
        # The assignment below overrides the path above.
        self.ROOT = '/data2/dataset/RefCOCO/VRIS'
        self.all_hp_root = "/data2/dataset/RefCOCO/refcocog/SBERT_gref_umd"
        self.exclude_position = True
        self.metric_learning = args.metric_learning
        self.exclude_multiobj = args.exclude_multiobj
        self.metric_mode = args.metric_mode
        self.hp_selection = args.hp_selection

        assert self.hp_selection in ['strict', 'base'], "Invalid hard positive selection mode"

        self.dataset = dataset
        self.split = split
        self.mode = mode
        self.input_size = (input_size, input_size)
        self.word_length = word_length
        self.emb_size = 384  # dimension of the precomputed SBERT embeddings
        self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        self.length = info[dataset][split]
        self.env = None  # the LMDB environment is opened lazily in __getitem__

        if self.metric_learning:
            self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
            self.hardpos_meta = self._load_metadata()
        else:
            self.multi_obj_ref_ids = None
            self.hardpos_meta = None

    def _load_multi_obj_ref_ids(self):
        # Load the reference ids that should be excluded from hard-positive sampling.
        if not self.exclude_multiobj and not self.exclude_position:
            return None
        elif self.exclude_position:
            multiobj_path = os.path.join(self.ROOT, 'multiobj_ov2_nopos.txt')
        elif self.exclude_multiobj:
            multiobj_path = os.path.join(self.ROOT, 'multiobj_ov3.txt')
        with open(multiobj_path, 'r') as f:
            return [int(line.strip()) for line in f.readlines()]

    def _load_metadata(self):
        # Load the hard-positive verb phrase metadata keyed by seg_id and sent_id.
        hardpos_path = '/data2/dataset/RefCOCO/VRIS/hardpos_verdict_gref_v4.json'
        with open(hardpos_path, 'r', encoding='utf-8') as f:
            hardpos_json = json.load(f)
        return hardpos_json
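
    # Assumed layout of the hard-positive metadata JSON, inferred from how it is
    # read in `_get_hardpos_verb` (the ids and phrases below are made-up examples):
    #
    #     {
    #         "123456": {                                    # seg_id
    #             "0": {"phrases": ["holding a bat"]},       # sent_id -> verb phrases
    #             "1": {"phrases": ["swinging at the ball"]}
    #         }
    #     }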

    def _init_db(self):
        # Open the LMDB environment lazily so each DataLoader worker gets its own handle.
        self.env = lmdb.open(self.lmdb_dir,
                             subdir=os.path.isdir(self.lmdb_dir),
                             readonly=True,
                             lock=False,
                             readahead=False,
                             meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        if self.env is None:
            self._init_db()
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        ref = loads_pyarrow(byteflow)

        # Decode the stored image bytes and convert BGR -> RGB.
        ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
                               cv2.IMREAD_COLOR)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        img_size = img.shape[:2]

        seg_id = ref['seg_id']
        mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')

        # Pick one of the referring sentences at random.
        idx = np.random.choice(ref['num_sents'])
        sents = ref['sents']

        # Decode the segmentation mask and normalize it to [0, 1].
        mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
                            cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.

        if self.mode == 'train':
            sent = sents[idx]

            # Sample a hard-positive verb phrase (all-zero tensors if none is available).
            raw_hardpos, hardpos, hp_pad_mask, hardpos_emb = self._get_hardpos_verb(ref, seg_id, idx)
            img, mask, sent = self.convert(img, mask, sent, inference=False)
            word_vec, pad_mask = tokenize(sent, self.word_length, True)

            hardpos = hardpos.unsqueeze(0)

            params = {
                'seg_id': seg_id,
                'sent': sent,
                'hardpos': raw_hardpos,
                'hardpos_emb': hardpos_emb
            }

            is_there_hp = torch.sum(hardpos) > 0

            img = img.unsqueeze(0)
            mask = mask.unsqueeze(0)
            # When a hard positive exists, duplicate the image/mask and stack the
            # hard-positive tokens so both expressions share the same visual input.
            if is_there_hp:
                img = torch.cat([img, img], dim=0)
                word_vec = torch.cat([word_vec, hardpos], dim=0)
                mask = torch.cat([mask, mask], dim=0)
                pad_mask = torch.cat([pad_mask, hp_pad_mask], dim=0)

            return img, word_vec, mask, pad_mask, params
        elif self.mode == 'val':
            # Validation uses the last annotated sentence only.
            sent = sents[-1]
            word_vec, pad_mask = tokenize(sent, self.word_length, True)
            img, mask, sent = self.convert(img, mask, sent, inference=False)
            img = img.unsqueeze(0)
            mask = mask.unsqueeze(0)
            return img, word_vec, mask, pad_mask, None
        else:
            # Test mode: tokenize every sentence and also return the original image.
            word_vecs = []
            pad_masks = []
            for sent in sents:
                word_vec, pad_mask = tokenize(sent, self.word_length, True)
                word_vecs.append(word_vec)
                pad_masks.append(pad_mask)
            img, mask, sent = self.convert(img, mask, sent, inference=True)
            return ori_img, img, word_vecs, mask, pad_masks, seg_id, sents

    def _get_hardpos_verb(self, ref, seg_id, sent_idx):
        """
        Handle the logic for selecting hard positive verb phrases during metric learning.
        Returns the raw phrase, its tokenized ids, its padding mask, and its SBERT embedding.
        """
        # Skip references flagged as multi-object / position-based.
        if seg_id in self.multi_obj_ref_ids:
            verb_hardpos = torch.zeros(self.word_length, dtype=torch.long)
            verb_pad_mask = torch.zeros(self.word_length, dtype=torch.long).unsqueeze(0)
            verb_embed = torch.zeros(self.emb_size, dtype=torch.float32)
            return '', verb_hardpos, verb_pad_mask, verb_embed

        hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
        cur_hardpos = []
        if self.hp_selection == 'strict':
            # Use only the hard positives attached to the sampled sentence.
            sent_id_list = list(hardpos_dict.keys())
            cur_sent_id = sent_id_list[sent_idx]
            cur_hardpos = hardpos_dict.get(cur_sent_id, {}).get('phrases', [])

        if cur_hardpos:
            # Sample one phrase at random and fetch its precomputed embedding.
            rand_index = random.randint(0, len(cur_hardpos) - 1)
            raw_verb = cur_hardpos[rand_index]

            verb_hardpos, verb_pad_mask = tokenize(raw_verb, self.word_length, True)
            verb_hardpos = verb_hardpos.squeeze(0)
            verb_embed = torch.from_numpy(self._get_hardpos_embed(seg_id, cur_sent_id, rand_index))

            return raw_verb, verb_hardpos, verb_pad_mask, verb_embed

        # Fall-through (including the 'base' selection mode, which is not handled
        # here): return empty tensors so the caller skips the hard positive.
        verb_hardpos = torch.zeros(self.word_length, dtype=torch.long)
        verb_pad_mask = torch.zeros(self.word_length, dtype=torch.long).unsqueeze(0)
        verb_embed = torch.zeros(self.emb_size, dtype=torch.float32)
        return '', verb_hardpos, verb_pad_mask, verb_embed

    def _get_hardpos_embed(self, seg_id, sent_id, rand_index):
        # Load the precomputed SBERT embedding of the sampled hard-positive phrase.
        emb_folder = os.path.join(self.all_hp_root, str(seg_id))
        emb_files = sorted([f for f in os.listdir(emb_folder)
                            if f.startswith(f"hp_{sent_id}_") and f.endswith(".npy")])
        selected_emb_file = os.path.join(emb_folder, emb_files[rand_index])
        return np.load(selected_emb_file)

    def convert(self, img, mask, sent, inference=False):
        img = Image.fromarray(np.uint8(img))
        mask = Image.fromarray(np.uint8(mask), mode="P")
        img = F.resize(img, self.input_size)
        if not inference:
            # Keep the mask at its original resolution at test time.
            mask = F.resize(mask, self.input_size, interpolation=Image.NEAREST)
        img = F.to_tensor(img)
        mask = torch.as_tensor(np.asarray(mask).copy(), dtype=torch.int64)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img, mask, sent

    def __repr__(self):
        return self.__class__.__name__ + "(" + \
            f"db_path={self.lmdb_dir}, " + \
            f"dataset={self.dataset}, " + \
            f"split={self.split}, " + \
            f"mode={self.mode}, " + \
            f"input_size={self.input_size}, " + \
            f"word_length={self.word_length}" + ")"
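

# Minimal smoke-test sketch (illustrative only): the paths and the attributes on
# `args` below are placeholders, not values shipped with this file. It runs only
# when the module is executed directly.
if __name__ == '__main__':
    from types import SimpleNamespace

    # Hypothetical args namespace mirroring the attributes read in __init__.
    args = SimpleNamespace(metric_learning=True,
                           exclude_multiobj=False,
                           metric_mode='hardpos',      # placeholder value
                           hp_selection='strict')

    dataset = RefDataset_gref(lmdb_dir='/path/to/refcocog_u/train.lmdb',  # placeholder
                              mask_dir='/path/to/refcocog_u/masks',       # placeholder
                              dataset='refcocog_u',
                              split='train',
                              mode='train',
                              input_size=416,
                              word_length=77,
                              args=args)
    print(dataset)       # __repr__ summary
    print(len(dataset))  # sample count taken from the `info` table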