import os
import sys
import json
import random
import itertools
from typing import List, Union

import cv2
import lmdb
import numpy as np
import pyarrow as pa
import torch
import albumentations as A
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F

from bert.tokenization_bert import BertTokenizer

# Sample counts per dataset/split. len(dataset) reports these until the LMDB
# environment is opened lazily on the first __getitem__ (see _init_db), after
# which the stored '__len__' key takes over.
info = {
    'refcoco': {
        'train': 42404, 'val': 3811, 'val-test': 3811,
        'testA': 1975, 'testB': 1810,
    },
    'refcoco+': {
        'train': 42278, 'val': 3805, 'val-test': 3805,
        'testA': 1975, 'testB': 1798,
    },
    'refcocog_u': {
        'train': 42226, 'val': 2573, 'val-test': 2573,
        'test': 5023, 'testA': 100, 'testB': 100,
    },
    'refcocog_g': {
        'train': 44822, 'val': 5000, 'val-test': 5000,
    },
}

_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def tokenize(texts: Union[str, List[str]], context_length: int = 77,
             truncate: bool = False):
    """Tokenize input text with the BERT tokenizer into fixed-length tensors.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string (or list of strings) passed to the BERT tokenizer.
    context_length : int
        Fixed length of the returned token tensor; longer encodings are
        clipped to this length.
    truncate : bool
        Unused; kept for interface compatibility. The encoding is clipped
        to ``context_length`` unconditionally.

    Returns
    -------
    tuple(torch.LongTensor, torch.LongTensor)
        ``(tokens, mask)``, each of shape ``[1, context_length]``. ``mask``
        holds 1 for real token positions and 0 for padding.
    """
    token_ids = _tokenizer.encode(text=texts, add_special_tokens=True)
    token_ids = token_ids[:context_length]

    result = [0] * context_length
    l_mask = [0] * context_length
    result[:len(token_ids)] = token_ids
    l_mask[:len(token_ids)] = [1] * len(token_ids)

    return (torch.tensor(result).unsqueeze(0),
            torch.tensor(l_mask).unsqueeze(0))


def loads_pyarrow(buf):
    """Deserialize a pyarrow buffer (the output of a previous `dumps`).

    NOTE(review): pa.serialize/deserialize are deprecated since pyarrow 2.0;
    consider migrating the LMDB payloads to pickle protocol 5.
    """
    return pa.deserialize(buf)


class RefDataset_gref(Dataset):
    """RefCOCOg (gref) referring-segmentation dataset backed by an LMDB store.

    In ``train`` mode each item is ``(img, word_vec, mask, pad_mask, params)``.
    When metric learning is on and a hard-positive verb phrase exists for the
    sampled sentence, image/mask/text tensors are stacked along dim 0 so the
    anchor and its hard positive travel in a single sample.
    """

    def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
                 word_length, args):
        super(RefDataset_gref, self).__init__()
        self.lmdb_dir = lmdb_dir
        self.mask_dir = mask_dir

        # Hard-positive (metric learning) related paths and options.
        self.ROOT = '/data2/dataset/RefCOCO/VRIS'
        self.all_hp_root = "/data2/dataset/RefCOCO/refcocog/SBERT_gref_umd"
        self.exclude_position = True
        self.metric_learning = args.metric_learning
        self.exclude_multiobj = args.exclude_multiobj
        self.metric_mode = args.metric_mode
        self.hp_selection = args.hp_selection
        assert self.hp_selection in ['strict', 'base'], "Invalid hard positive selection mode"

        self.dataset = dataset
        self.split = split
        self.mode = mode
        self.input_size = (input_size, input_size)
        self.word_length = word_length
        self.emb_size = 384  # size of the precomputed SBERT phrase embeddings
        # ImageNet normalization statistics applied in convert().
        self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        self.length = info[dataset][split]
        self.env = None  # LMDB env is opened lazily per worker (see _init_db)

        if self.metric_learning:
            self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
            self.hardpos_meta = self._load_metadata()
        else:
            self.multi_obj_ref_ids = None
            self.hardpos_meta = None

    def _load_multi_obj_ref_ids(self):
        """Load seg ids of objects that appear multiple times in their image.

        Returns a set for O(1) membership tests in _get_hardpos_verb, or
        None when neither exclusion option is enabled.
        """
        if not self.exclude_multiobj and not self.exclude_position:
            return None
        elif self.exclude_position:
            multiobj_path = os.path.join(self.ROOT, 'multiobj_ov2_nopos.txt')
        elif self.exclude_multiobj:
            multiobj_path = os.path.join(self.ROOT, 'multiobj_ov3.txt')
        with open(multiobj_path, 'r') as f:
            return {int(line.strip()) for line in f}

    def _load_metadata(self):
        """Load hard-positive verb-phrase metadata, keyed by str(seg_id)."""
        hardpos_path = '/data2/dataset/RefCOCO/VRIS/hardpos_verdict_gref_v4.json'
        with open(hardpos_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _init_db(self):
        """Open the LMDB environment read-only and cache its keys/length.

        Deferred until the first __getitem__ so that each DataLoader worker
        opens its own handle: https://github.com/chainer/chainermn/issues/129
        """
        self.env = lmdb.open(self.lmdb_dir,
                             subdir=os.path.isdir(self.lmdb_dir),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Delay opening LMDB until after worker initialization.
        if self.env is None:
            self._init_db()
        with self.env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        ref = loads_pyarrow(byteflow)

        # Image: stored encoded bytes -> BGR -> RGB.
        ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
                               cv2.IMREAD_COLOR)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)

        seg_id = ref['seg_id']
        sents = ref['sents']
        idx = np.random.choice(ref['num_sents'])  # random sentence (train only)

        # Mask: stored encoded bytes -> {0, 1} float map.
        mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
                            cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.

        if self.mode == 'train':
            sent = sents[idx]
            if self.metric_learning:
                raw_hardpos, hardpos, hp_pad_mask, hardpos_emb = \
                    self._get_hardpos_verb(ref, seg_id, idx)
            else:
                # BUG FIX: the original called _get_hardpos_verb even when
                # metric learning was off, crashing on the None metadata.
                raw_hardpos, hardpos, hp_pad_mask, hardpos_emb = \
                    self._empty_hardpos()
            img, mask, sent = self.convert(img, mask, sent, inference=False)
            word_vec, pad_mask = tokenize(sent, self.word_length, True)
            hardpos = hardpos.unsqueeze(0)  # [1, word_length], as word_vec
            params = {
                'seg_id': seg_id,
                'sent': sent,
                'hardpos': raw_hardpos,
                'hardpos_emb': hardpos_emb,
            }
            is_there_hp = torch.sum(hardpos) > 0
            img = img.unsqueeze(0)
            mask = mask.unsqueeze(0)
            if is_there_hp:
                # Stack anchor and hard positive along the leading dim.
                img = torch.cat([img, img], dim=0)
                word_vec = torch.cat([word_vec, hardpos], dim=0)
                mask = torch.cat([mask, mask], dim=0)
                pad_mask = torch.cat([pad_mask, hp_pad_mask], dim=0)
            return img, word_vec, mask, pad_mask, params
        elif self.mode == 'val':
            # Validation uses the last sentence only.
            sent = sents[-1]
            word_vec, pad_mask = tokenize(sent, self.word_length, True)
            img, mask, sent = self.convert(img, mask, sent, inference=False)
            img = img.unsqueeze(0)
            mask = mask.unsqueeze(0)
            return img, word_vec, mask, pad_mask, None
        else:
            # Test / inference: tokenize every sentence, keep the full-size
            # mask (inference=True skips the mask resize in convert()).
            word_vecs = []
            pad_masks = []
            for sent in sents:
                word_vec, pad_mask = tokenize(sent, self.word_length, True)
                word_vecs.append(word_vec)
                pad_masks.append(pad_mask)
            img, mask, sent = self.convert(img, mask, sent, inference=True)
            return ori_img, img, word_vecs, mask, pad_masks, seg_id, sents

    def _empty_hardpos(self):
        """Zero-filled placeholders used when no hard positive is available."""
        tokens = torch.zeros(self.word_length, dtype=torch.long)
        pad_mask = torch.zeros(self.word_length, dtype=torch.long).unsqueeze(0)
        embed = torch.zeros(self.emb_size, dtype=torch.float32)
        return '', tokens, pad_mask, embed

    def _get_hardpos_verb(self, ref, seg_id, sent_idx):
        """Select a hard-positive verb phrase for (seg_id, sent_idx).

        Returns ``(raw_phrase, token_tensor, pad_mask, sbert_embedding)``;
        zero placeholders when the object appears multiple times in the
        image or no phrase is available.
        """
        # If the object appears multiple times, no hard positive is used.
        if self.multi_obj_ref_ids and seg_id in self.multi_obj_ref_ids:
            return self._empty_hardpos()

        # Extract metadata for hard positives if present.
        hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
        if self.hp_selection == 'strict':
            sent_id_list = list(hardpos_dict.keys())
            cur_sent_id = sent_id_list[sent_idx]
            cur_hardpos = hardpos_dict.get(cur_sent_id, {}).get('phrases', [])
            if cur_hardpos:
                # Pick a random phrase; keep the index so the matching
                # precomputed embedding file can be loaded.
                rand_index = random.randint(0, len(cur_hardpos) - 1)
                raw_verb = cur_hardpos[rand_index]
                verb_hardpos, verb_pad_mask = tokenize(raw_verb,
                                                       self.word_length, True)
                verb_hardpos = verb_hardpos.squeeze(0)
                verb_embed = torch.from_numpy(
                    self._get_hardpos_embed(seg_id, cur_sent_id, rand_index))
                return raw_verb, verb_hardpos, verb_pad_mask, verb_embed
        # NOTE(review): the 'base' selection mode always falls through to the
        # empty placeholders here — confirm this is intended.
        return self._empty_hardpos()

    def _get_hardpos_embed(self, seg_id, sent_id, rand_index):
        """Load the precomputed SBERT embedding matching the chosen phrase.

        Relies on sorted() ordering the .npy files the same way the phrase
        list is ordered, so rand_index addresses the matching embedding.
        """
        emb_folder = os.path.join(self.all_hp_root, str(seg_id))
        emb_files = sorted(
            f for f in os.listdir(emb_folder)
            if f.startswith(f"hp_{sent_id}_") and f.endswith(".npy"))
        return np.load(os.path.join(emb_folder, emb_files[rand_index]))

    def convert(self, img, mask, sent, inference=False):
        """Resize and normalize ``img``; resize ``mask`` only when training.

        At inference the mask keeps its original resolution — presumably so
        evaluation runs against the full-size ground truth (verify with the
        eval loop).
        """
        img = Image.fromarray(np.uint8(img))
        mask = Image.fromarray(np.uint8(mask), mode="P")
        img = F.resize(img, self.input_size)
        if not inference:
            mask = F.resize(mask, self.input_size, interpolation=Image.NEAREST)
        img = F.to_tensor(img)
        mask = torch.as_tensor(np.asarray(mask).copy(), dtype=torch.int64)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img, mask, sent

    def __repr__(self):
        # BUG FIX: the original f-string chain opened '(' but never closed it.
        return (self.__class__.__name__ + "("
                + f"db_path={self.lmdb_dir}, "
                + f"dataset={self.dataset}, "
                + f"split={self.split}, "
                + f"mode={self.mode}, "
                + f"input_size={self.input_size}, "
                + f"word_length={self.word_length})")