# MRaCL / CGFormer / utils / dataset_sbert.py
# Uploaded by dianecy with huggingface_hub ("Upload folder using huggingface_hub", commit ea1014e, verified).
import itertools
import json
import os
import random
import sys
from typing import List, Tuple, Union

import albumentations as A
import cv2
import lmdb
import numpy as np
import pyarrow as pa
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F

from bert.tokenization_bert import BertTokenizer
# Per-split sample counts for each supported referring-expression dataset,
# looked up as info[dataset][split].  Every 'val-test' entry repeats the
# matching 'val' count.
# NOTE(review): the refcocog_u 'testA'/'testB' value of 100 looks like a
# placeholder rather than a real split size — confirm.
info = {
    'refcoco': {
        'train': 42404, 'val': 3811, 'val-test': 3811,
        'testA': 1975, 'testB': 1810,
    },
    'refcoco+': {
        'train': 42278, 'val': 3805, 'val-test': 3805,
        'testA': 1975, 'testB': 1798,
    },
    'refcocog_u': {
        'train': 42226, 'val': 2573, 'val-test': 2573,
        'test': 5023, 'testA': 100, 'testB': 100,
    },
    'refcocog_g': {
        'train': 44822, 'val': 5000, 'val-test': 5000,
    },
}
# Shared BERT tokenizer ('bert-base-uncased' vocabulary) used by `tokenize`.
# It is created once, when this module is imported.
_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode text with the module-level BERT tokenizer and pad it to a fixed length.

    Parameters
    ----------
    texts : Union[str, List[str]]
        Text passed unchanged to ``_tokenizer.encode`` (special tokens are added).
    context_length : int
        Number of token slots in the output.
    truncate : bool
        Accepted only for compatibility with the CLIP-style signature and
        ignored.  Encodings longer than ``context_length`` are always cut to
        that length, which can drop the trailing special token added by
        ``encode``.

    Returns
    -------
    (token_ids, attention_mask)
        Two tensors of shape [1, context_length].  ``token_ids`` holds the
        encoded ids followed by zero padding.  ``attention_mask`` is 1 at the
        encoded positions and 0 at the padding.
    """
    tokens = list(_tokenizer.encode(text=texts, add_special_tokens=True))[:context_length]
    n_pad = context_length - len(tokens)
    token_ids = torch.tensor(tokens + [0] * n_pad).unsqueeze(0)
    attention_mask = torch.tensor([1] * len(tokens) + [0] * n_pad).unsqueeze(0)
    return token_ids, attention_mask
def loads_pyarrow(buf):
    """
    Deserialize a value read from the LMDB database.

    Args:
        buf: the output of `dumps`, i.e. the raw bytes returned by an LMDB
            ``txn.get`` call.

    Returns:
        The Python object reconstructed by ``pa.deserialize``.
    """
    # NOTE(review): pyarrow deprecated serialize/deserialize in 2.0 and later
    # releases dropped them.  This needs a pyarrow version that still ships
    # deserialize, because the stored values were presumably written with
    # pyarrow.serialize.  Confirm the pinned version.
    return pa.deserialize(buf)
class RefDataset_gref(Dataset):
    """LMDB-backed referring-segmentation dataset with optional hard-positive phrases.

    Every LMDB record is decoded with ``loads_pyarrow`` and read as a dict with
    (at least) the keys ``'img'`` and ``'mask'`` (encoded image bytes),
    ``'seg_id'``, ``'num_sents'`` and ``'sents'``.

    ``mode`` selects what ``__getitem__`` returns:

    * ``'train'``: one randomly sampled sentence.  When metric learning is on
      and a hard-positive phrase is found, the phrase is stacked as a second
      row, so image, tokens, mask and padding mask get a leading dim of 2
      instead of 1.
    * ``'val'``: the last sentence of the record.
    * anything else (test): every sentence, plus the original OpenCV (BGR)
      image and the mask at its original resolution.
    """

    # Machine-specific default locations of the hard-positive resources.
    # NOTE(review): absolute paths; override these in a subclass if needed.
    _DEFAULT_ROOT = '/data2/dataset/RefCOCO/VRIS'
    _DEFAULT_HP_EMB_ROOT = "/data2/dataset/RefCOCO/refcocog/SBERT_gref_umd"
    _HARDPOS_META_FILE = 'hardpos_verdict_gref_v4.json'
    _HP_SELECTION_MODES = ('strict', 'base')

    def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
                 word_length, args):
        """
        Args:
            lmdb_dir: path of the LMDB database (file or directory).
            mask_dir: mask directory (stored on the instance).
            dataset, split: keys into ``info``, used for the provisional length.
            mode: 'train', 'val', or anything else for test behaviour.
            input_size: side length of the square network input.
            word_length: number of token slots per sentence.
            args: config object providing ``metric_learning``,
                ``exclude_multiobj``, ``metric_mode`` and ``hp_selection``.

        Raises:
            ValueError: if ``args.hp_selection`` is not 'strict' or 'base'.
        """
        super().__init__()
        self.lmdb_dir = lmdb_dir
        self.mask_dir = mask_dir
        # Hard-positive related resources.
        self.ROOT = self._DEFAULT_ROOT
        self.all_hp_root = self._DEFAULT_HP_EMB_ROOT
        # Hard-coded on: position-dependent expressions are always excluded,
        # which selects multiobj_ov2_nopos.txt in _load_multi_obj_ref_ids.
        self.exclude_position = True
        self.metric_learning = args.metric_learning
        self.exclude_multiobj = args.exclude_multiobj
        self.metric_mode = args.metric_mode
        self.hp_selection = args.hp_selection
        # Explicit raise instead of assert so validation survives `python -O`.
        if self.hp_selection not in self._HP_SELECTION_MODES:
            raise ValueError(
                f"Invalid hard positive selection mode: {self.hp_selection!r}")
        self.dataset = dataset
        self.split = split
        self.mode = mode
        self.input_size = (input_size, input_size)
        self.word_length = word_length
        # Length of the zero embedding returned when no hard positive is used.
        self.emb_size = 384
        # Per-channel normalisation statistics (the standard ImageNet values).
        self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        # Provisional length; _init_db() replaces it with the LMDB '__len__'.
        self.length = info[dataset][split]
        self.env = None
        if self.metric_learning:
            self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
            self.hardpos_meta = self._load_metadata()
        else:
            self.multi_obj_ref_ids = None
            self.hardpos_meta = None

    def _load_multi_obj_ref_ids(self):
        """Return the ids listed in the configured multi-object file, or None.

        ``exclude_position`` takes precedence and selects
        ``multiobj_ov2_nopos.txt``.  Otherwise ``exclude_multiobj`` selects
        ``multiobj_ov3.txt``.  When both flags are off, None is returned.
        Each non-blank line is parsed as an int.  The ids are returned as a set
        because they are only used for membership tests against ``seg_id``.
        """
        if self.exclude_position:
            filename = 'multiobj_ov2_nopos.txt'
        elif self.exclude_multiobj:
            filename = 'multiobj_ov3.txt'
        else:
            return None
        multiobj_path = os.path.join(self.ROOT, filename)
        with open(multiobj_path, 'r') as f:
            return {int(line.strip()) for line in f if line.strip()}

    def _load_metadata(self):
        """Load the hard-positive metadata JSON stored under ``self.ROOT``."""
        hardpos_path = os.path.join(self.ROOT, self._HARDPOS_META_FILE)
        with open(hardpos_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _init_db(self):
        """Open the LMDB environment read-only and read its length and key list."""
        self.env = lmdb.open(self.lmdb_dir,
                             subdir=os.path.isdir(self.lmdb_dir),
                             readonly=True,
                             lock=False,
                             readahead=False,
                             meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Load record ``index`` and build the sample for the current ``mode``."""
        # Delay opening LMDB until first access, i.e. after initialization:
        # https://github.com/chainer/chainermn/issues/129
        if self.env is None:
            self._init_db()
        with self.env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        ref = loads_pyarrow(byteflow)

        # OpenCV decodes to BGR; the transforms work on RGB.
        ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
                               cv2.IMREAD_COLOR)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        seg_id = ref['seg_id']
        sents = ref['sents']
        # 8-bit mask scaled to [0, 1]; convert() casts it back to uint8.
        mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
                            cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.

        if self.mode == 'train':
            idx = np.random.choice(ref['num_sents'])
            sent = sents[idx]
            raw_hardpos, hardpos, hp_pad_mask, hardpos_emb = \
                self._get_hardpos_verb(ref, seg_id, idx)
            img, mask, sent = self.convert(img, mask, sent, inference=False)
            word_vec, pad_mask = tokenize(sent, self.word_length, True)
            # Add a leading dim to the hard-positive tokens: (1, word_length).
            hardpos = hardpos.unsqueeze(0)
            params = {
                'seg_id': seg_id,
                'sent': sent,
                'hardpos': raw_hardpos,
                'hardpos_emb': hardpos_emb
            }
            img = img.unsqueeze(0)
            mask = mask.unsqueeze(0)
            # All-zero token ids mean "no hard positive" (see _empty_hardpos).
            if torch.sum(hardpos) > 0:
                img = torch.cat([img, img], dim=0)
                word_vec = torch.cat([word_vec, hardpos], dim=0)
                mask = torch.cat([mask, mask], dim=0)
                pad_mask = torch.cat([pad_mask, hp_pad_mask], dim=0)
            return img, word_vec, mask, pad_mask, params

        if self.mode == 'val':
            sent = sents[-1]
            word_vec, pad_mask = tokenize(sent, self.word_length, True)
            img, mask, sent = self.convert(img, mask, sent, inference=False)
            return img.unsqueeze(0), word_vec, mask.unsqueeze(0), pad_mask, None

        # Test: tokenize every sentence of the record.
        word_vecs = []
        pad_masks = []
        for sent in sents:
            word_vec, pad_mask = tokenize(sent, self.word_length, True)
            word_vecs.append(word_vec)
            pad_masks.append(pad_mask)
        # convert() only passes the sentence through and it is unused here.
        # Passing None also avoids a NameError when the record has no sentences.
        img, mask, _ = self.convert(img, mask, None, inference=True)
        return ori_img, img, word_vecs, mask, pad_masks, seg_id, sents

    def _empty_hardpos(self):
        """Return the all-zero placeholders used when no hard positive is selected."""
        return ('',
                torch.zeros(self.word_length, dtype=torch.long),
                torch.zeros(1, self.word_length, dtype=torch.long),
                torch.zeros(self.emb_size, dtype=torch.float32))

    @staticmethod
    def _strict_hardpos_candidates(hardpos_dict, sent_idx):
        """Return (sent_id, phrase_index, phrase) for the sent_idx-th sentence id only."""
        sent_ids = list(hardpos_dict)
        if not 0 <= sent_idx < len(sent_ids):
            return []
        sent_id = sent_ids[sent_idx]
        phrases = (hardpos_dict.get(sent_id) or {}).get('phrases') or []
        return [(sent_id, i, phrase) for i, phrase in enumerate(phrases)]

    @staticmethod
    def _pooled_hardpos_candidates(hardpos_dict):
        """Return (sent_id, phrase_index, phrase) for every sentence of the object."""
        return [(sent_id, i, phrase)
                for sent_id, entry in hardpos_dict.items()
                for i, phrase in enumerate((entry or {}).get('phrases') or [])]

    def _get_hardpos_verb(self, ref, seg_id, sent_idx):
        """Pick a hard-positive phrase for sentence ``sent_idx`` of object ``seg_id``.

        ``self.hardpos_meta[str(seg_id)]`` is read as a mapping from sentence id
        to a dict with a ``'phrases'`` list.

        * ``hp_selection == 'strict'``: only the phrases of the
          ``sent_idx``-th sentence id are candidates.  Sentence ids follow the
          JSON key order, which is assumed to match the LMDB sentence order
          (TODO confirm).
        * ``hp_selection == 'base'``: the phrases of all sentences of the
          object are pooled.

        ``ref`` is currently unused.

        Returns:
            ``(raw_phrase, token_ids, pad_mask, embedding)``.  ``token_ids`` has
            shape (word_length,) and ``pad_mask`` has shape (1, word_length).
            ``embedding`` is the tensor loaded by ``_get_hardpos_embed``.
            All-zero placeholders (from ``_empty_hardpos``) are returned when
            metric learning is off, ``seg_id`` is in ``multi_obj_ref_ids``, or
            there is no candidate phrase.
        """
        # Metadata is only loaded when metric learning is enabled.
        if not self.metric_learning or self.hardpos_meta is None:
            return self._empty_hardpos()
        # Objects listed in the multi-object file never get a hard positive.
        if self.multi_obj_ref_ids is not None and seg_id in self.multi_obj_ref_ids:
            return self._empty_hardpos()

        hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
        if self.hp_selection == 'strict':
            candidates = self._strict_hardpos_candidates(hardpos_dict, sent_idx)
        else:
            # NOTE(review): 'base' previously had no implementation and crashed on
            # an undefined variable.  Pooling every sentence's phrases is the
            # assumed intent; confirm.
            candidates = self._pooled_hardpos_candidates(hardpos_dict)
        if not candidates:
            return self._empty_hardpos()

        # randint keeps the same random draw as the original selection code.
        rand_index = random.randint(0, len(candidates) - 1)
        sent_id, phrase_idx, raw_verb = candidates[rand_index]
        verb_hardpos, verb_pad_mask = tokenize(raw_verb, self.word_length, True)
        verb_embed = torch.from_numpy(
            self._get_hardpos_embed(seg_id, sent_id, phrase_idx))
        return raw_verb, verb_hardpos.squeeze(0), verb_pad_mask, verb_embed

    def _get_hardpos_embed(self, seg_id, sent_id, rand_index):
        """Load the stored embedding of hard-positive phrase ``rand_index``.

        Candidate files are ``<all_hp_root>/<seg_id>/hp_<sent_id>_*.npy``, and
        the ``rand_index``-th one in sorted order is loaded with ``np.load``.
        Files whose name continues with a decimal number after the prefix are
        ordered numerically, so ``..._10.npy`` comes after ``..._9.npy``.  This
        assumes that number is the phrase index (TODO confirm against the
        embedding export script).  Any other names keep plain string order
        after the numbered ones.

        Raises:
            IndexError: if fewer than ``rand_index + 1`` matching files exist.
        """
        emb_folder = os.path.join(self.all_hp_root, str(seg_id))
        prefix = f"hp_{sent_id}_"
        suffix = ".npy"

        def sort_key(name):
            head = name[len(prefix):-len(suffix)].split("_", 1)[0]
            if head.isdecimal():
                return (0, int(head), name)
            return (1, 0, name)

        emb_files = sorted(
            (f for f in os.listdir(emb_folder)
             if f.startswith(prefix) and f.endswith(suffix)),
            key=sort_key)
        if rand_index >= len(emb_files):
            raise IndexError(
                f"hard-positive embedding #{rand_index} for sentence {sent_id!r} "
                f"not found in {emb_folder} ({len(emb_files)} matching file(s))")
        return np.load(os.path.join(emb_folder, emb_files[rand_index]))

    def convert(self, img, mask, sent, inference=False):
        """Resize and tensorize an image/mask pair; ``sent`` is passed through.

        The image is resized to ``input_size``, converted with ``F.to_tensor``
        and normalized with ``self.mean``/``self.std``.  The mask is resized
        only when ``inference`` is False, and is returned as an int64 tensor.
        """
        img = Image.fromarray(np.uint8(img))
        mask = Image.fromarray(np.uint8(mask), mode="P")
        img = F.resize(img, self.input_size)
        if not inference:
            # Nearest-neighbour keeps the mask values discrete.
            mask = F.resize(mask, self.input_size, interpolation=Image.NEAREST)
        img = F.to_tensor(img)
        mask = torch.as_tensor(np.asarray(mask).copy(), dtype=torch.int64)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img, mask, sent

    def __repr__(self):
        return (f"{self.__class__.__name__}("
                f"db_path={self.lmdb_dir}, "
                f"dataset={self.dataset}, "
                f"split={self.split}, "
                f"mode={self.mode}, "
                f"input_size={self.input_size}, "
                f"word_length={self.word_length})")