# Source: DenseLabelDev/projects/omg_llava/dataset/ReferringSegDataset.py
# (uploaded by zhouyik with huggingface_hub, commit 032e687)
import logging
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
from pycocotools import mask
import numpy as np
import torch.nn.functional as F
from xtuner.registry import BUILDER
from .utils import expand2square, expand2square_mask
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
from .utils.refcoco_refer import REFER
import copy
from xtuner.dataset.utils import encode_fn
class RefcocoReferringSegDataset(Dataset):
    """Referring-expression segmentation dataset built on the REFER API.

    For every image of the ``train`` split of ``self.dataset_name`` /
    ``self.splitBy`` (set by :meth:`_set_attribute`, which subclasses
    override) up to three referring sentences are sampled. Each item holds
    the text fields produced by ``dataset_map_fn`` / ``template_map_fn``,
    the preprocessed image under ``pixel_values`` and one mask per sampled
    sentence under ``masks``.

    Args:
        image_folder: Directory that each image ``file_name`` is joined to.
        image_processor: Image processor instance, or a config that is
            built with ``BUILDER``.
        data_path: Passed to ``REFER`` to load the annotations.
        tokenizer: Tokenizer, or a tokenizer config.
        offline_processed_text_folder: Not supported; raises
            ``NotImplementedError`` when set.
        max_dataset_length: Forwarded to ``process_hf_dataset``.
        dataset_map_fn: Maps a raw sample to conversation fields.
        template_map_fn: Applies the prompt template; a config dict is
            built here in lazy mode.
        max_length: Maximum token length.
        pad_image_to_square: Pad images (and masks) to a square first.
        num_proc: Number of map processes for ``process_hf_dataset``.
        lazy: If True, text is mapped and tokenized in ``__getitem__``
            instead of up front.
        repeats: How many times the dataset is repeated by ``__len__``.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,):
        self._set_attribute()
        assert offline_processed_text_folder or (data_path and tokenizer)
        self.lazy = lazy
        self.max_length = max_length

        self.tokenizer = tokenizer
        if self.lazy and isinstance(tokenizer, (dict, Config, ConfigDict)):
            # In lazy mode ``encode_fn`` receives ``self.tokenizer`` inside
            # ``__getitem__``, so a tokenizer config must be built here (the
            # same treatment ``image_processor`` gets below).
            self.tokenizer = BUILDER.build(tokenizer)

        self.dataset_map_fn = dataset_map_fn
        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            # Build from a shallow copy: deleting 'type' from the caller's
            # config would break every other dataset that shares it.
            template_cfg = dict(self.template_map_fn)
            template_type = template_cfg.pop('type')
            self.template_map_fn = template_type(**template_cfg)

        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from '
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)
        if offline_processed_text_folder is not None:
            raise NotImplementedError

        self.json_datas = self.json_file_preprocess(data_path)
        json_data = DatasetDict(
            {'train': HFDataset.from_list(self.only_get_hf_map_infos())})
        if self.lazy:
            self.text_data = build_origin_dataset(json_data, 'train')
        else:
            self.text_data = process_hf_dataset(
                dataset=json_data,
                tokenizer=tokenizer,
                max_length=max_length,
                dataset_map_fn=dataset_map_fn,
                template_map_fn=template_map_fn,
                split='train',
                max_dataset_length=max_dataset_length,
                remove_unused_columns=False,
                pack_to_max_length=False,
                with_image_token=True,
                map_num_proc=num_proc,  # because limited mem
            )

        self.image_folder = image_folder
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        # Read the crop size from the *built* processor, because a raw config
        # need not carry ``crop_size``. The processor may store it as a
        # {'height', 'width'} dict, the layout ``__getitem__`` also indexes.
        crop_size = getattr(self.image_processor, 'crop_size', None)
        if crop_size is None:
            crop_size = self.image_processor.size
        self.image_h, self.image_w = self._crop_hw(crop_size)
        self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    @staticmethod
    def _crop_hw(size):
        """Return ``(height, width)`` for an int, a ``{'height', 'width'}``
        dict, or a ``(width, height)`` sequence."""
        if isinstance(size, int):
            return size, size
        if isinstance(size, dict):
            return size['height'], size['width']
        width, height = size
        return height, width

    def _set_attribute(self):
        """Select the REFER split/dataset; subclasses override this."""
        self.splitBy = "unc"
        self.dataset_name = 'refcoco'

    def only_get_hf_map_infos(self):
        """Return the text-only view of ``self.json_datas``, used to build
        the HF dataset (one ``{'sampled_sents': [...]}`` per entry)."""
        return [{'sampled_sents': json_data['selected_labels']}
                for json_data in self.json_datas]

    def __len__(self):
        return len(self.text_data) * self.repeats

    @property
    def modality_length(self):
        """Per-sample lengths for length-grouped sampling.

        Positive values mark samples with an image and negative values mark
        text-only samples. The list has ``len(self)`` entries: it is repeated
        ``self.repeats`` times to match ``__len__``. In lazy mode the text is
        not tokenized yet, so a constant 100 is used.
        """
        length_list = []
        for text_item, json_item in zip(self.text_data, self.json_datas):
            cur_len = 100 if self.lazy else len(text_item['input_ids'])
            # ``__getitem__`` takes the image from ``json_datas``, so check
            # there too; the text rows alone need not carry it.
            has_image = (text_item.get('image', None) is not None
                         or json_item.get('image', None) is not None)
            if not has_image:
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list * self.repeats

    def real_len(self):
        """Number of distinct samples (ignores ``repeats``)."""
        return len(self.text_data)

    def json_file_preprocess(self, data_path):
        """Load the REFER train split and sample referring sentences.

        Side effects: sets ``self.img2refs`` (image id -> list of refs) and
        ``self.annotations`` (``refer_api.Anns``); ``decode_mask`` reads the
        latter.

        Returns:
            list[dict]: one entry per image that has at least one sentence,
            with keys ``image_info``, ``sampled_ann_id`` (the ann id of each
            sampled sentence), ``selected_labels`` (the sampled sentences)
            and ``image`` (the image ``file_name``).
        """
        refer_api = REFER(data_path, self.dataset_name, self.splitBy)
        ref_ids_train = refer_api.getRefIds(split='train')
        images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
        refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
        self.img2refs = self.create_img_to_refs_mapping(refs_train)
        image_infos = [item.copy() for item in
                       refer_api.loadImgs(image_ids=images_ids_train)]
        self.annotations = refer_api.Anns

        ret = []
        for image_info in image_infos:
            sents = []
            ann_ids = []
            for ref in self.img2refs[image_info['id']]:
                for sent in ref["sentences"]:
                    sents.append(sent["sent"])
                    ann_ids.append(ref["ann_id"])
            if not sents:
                # Nothing to sample (also covers images without refs).
                continue
            if len(sents) >= 3:
                # NOTE(review): uses the global NumPy RNG. If the text data is
                # built on one process and shared with others, every process
                # must seed NumPy identically, or the ann ids sampled here
                # will not match the sampled sentences. Confirm.
                sampled_inds = np.random.choice(
                    list(range(len(sents))), size=3, replace=False
                )
            else:
                sampled_inds = list(range(len(sents)))
            sampled_sents = [sents[ind] for ind in sampled_inds]
            sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
            ret.append(
                {'image_info': image_info,
                 'sampled_ann_id': sampled_ann_ids,
                 'selected_labels': sampled_sents,
                 'image': image_info['file_name']
                 }
            )
        return ret

    def create_img_to_refs_mapping(self, refs_train):
        """Group refs by ``image_id``, keeping their original order."""
        img2refs = {}
        for ref in refs_train:
            img2refs.setdefault(ref["image_id"], []).append(ref)
        return img2refs

    def _decode_annotation(self, ann_id, height, width):
        """Decode one annotation into a binary ``(height, width)`` uint8 mask.

        An empty ``segmentation`` yields an all-zero mask. A segmentation
        whose first element is a list is treated as polygons and converted
        with ``mask.frPyObjects``; otherwise it is treated as a list of RLE
        dicts.
        """
        segmentation = self.annotations[ann_id]["segmentation"]
        if len(segmentation) == 0:
            return np.zeros((height, width), dtype=np.uint8)
        if isinstance(segmentation[0], list):  # polygon
            rle = mask.frPyObjects(segmentation, height, width)
        else:
            # Make byte-encoded copies instead of rewriting ``counts`` in
            # place, so the cached annotations stay untouched.
            rle = [seg if isinstance(seg["counts"], bytes)
                   else dict(seg, counts=seg["counts"].encode())
                   for seg in segmentation]
        decoded = mask.decode(rle)
        # A segmentation can have several (possibly overlapping) parts along
        # the last axis; OR them so the mask stays binary.
        return np.any(decoded, axis=2).astype(np.uint8)

    def decode_mask(self, annotations_ids, image_info):
        """Build one mask per entry of ``annotations_ids``.

        An entry is either a single annotation id or a list of ids. A list
        gives the union of its masks; a list containing -1 must contain only
        -1 and gives an all-zero mask.

        Returns:
            torch.Tensor: ``(len(annotations_ids), image_h // down_ratio,
            image_w // down_ratio)`` uint8 masks, square-padded first when
            ``pad_image_to_square`` is set.
        """
        height, width = image_info["height"], image_info["width"]
        masks = []
        for ann_id in annotations_ids:
            if isinstance(ann_id, list):
                merged = np.zeros((height, width), dtype=np.uint8)
                if -1 in ann_id:
                    assert len(ann_id) == 1
                else:
                    for sub_id in ann_id:
                        merged |= self._decode_annotation(sub_id, height, width)
                masks.append(merged)
            else:
                masks.append(self._decode_annotation(ann_id, height, width))
        masks = np.stack(masks, axis=0)
        if self.pad_image_to_square:
            masks = expand2square_mask(masks)
        masks = torch.from_numpy(masks)
        masks = F.interpolate(
            masks.unsqueeze(0),
            size=(self.image_h // self.down_ratio,
                  self.image_w // self.down_ratio),
            mode='nearest').squeeze(0)
        return masks

    def __getitem__(self, index):
        """Return sample ``index`` (taken modulo ``real_len()``).

        The dict merges the text fields with the REFER metadata and adds
        ``pixel_values`` and ``masks``.
        """
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.text_data[index])
        # NOTE(review): pairs text_data[i] with json_datas[i] by position.
        # This assumes text processing neither reorders nor drops rows (e.g.
        # when max_dataset_length is set). Confirm.
        data_dict.update(self.json_datas[index])
        if self.lazy:
            data_dict.update(self.dataset_map_fn(data_dict))
            data_dict.update(self.template_map_fn(data_dict))
            data_dict.update(encode_fn(
                data_dict, tokenizer=self.tokenizer,
                max_length=self.max_length, with_image_token=True))

        assert 'image' in data_dict.keys()
        if data_dict.get('image', None) is not None:
            image_path = os.path.join(self.image_folder, data_dict['image'])
            # Release the file handle immediately; convert() produces a
            # separate, fully loaded image.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.pad_image_to_square:
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image
            # process and get masks
            data_dict['masks'] = self.decode_mask(
                data_dict['sampled_ann_id'], data_dict['image_info'])
        else:
            if hasattr(self.image_processor, 'crop_size'):
                crop_size = self.image_processor.crop_size
            else:
                crop_size = self.image_processor.size
            data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
                                                    crop_size['width'])
            data_dict['masks'] = None
        return data_dict
class Refcoco_plus_ReferringSegDataset(RefcocoReferringSegDataset):
    """RefCOCO+ variant (``unc`` split) of ``RefcocoReferringSegDataset``.

    Every constructor argument is forwarded unchanged to the base class,
    which also stores ``tokenizer``; only :meth:`_set_attribute` differs.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,):
        super().__init__(
            image_folder=image_folder,
            image_processor=image_processor,
            data_path=data_path,
            tokenizer=tokenizer,
            offline_processed_text_folder=offline_processed_text_folder,
            max_dataset_length=max_dataset_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            max_length=max_length,
            pad_image_to_square=pad_image_to_square,
            num_proc=num_proc,
            lazy=lazy,
            repeats=repeats,
        )

    def _set_attribute(self):
        """Select the REFER ``refcoco+`` dataset with the ``unc`` split."""
        self.splitBy = "unc"
        self.dataset_name = 'refcoco+'
class Refcocog_ReferringSegDataset(RefcocoReferringSegDataset):
    """RefCOCOg variant (``umd`` split) of ``RefcocoReferringSegDataset``.

    The constructor is a plain pass-through; the only difference from the
    base class is the REFER dataset/split chosen in :meth:`_set_attribute`.
    """

    def __init__(self, image_folder, image_processor, data_path=None,
                 tokenizer=None, offline_processed_text_folder=None,
                 max_dataset_length=None, dataset_map_fn=None,
                 template_map_fn=None, max_length=2048,
                 pad_image_to_square=False, num_proc=8, lazy=False,
                 repeats=1,):
        super().__init__(
            image_folder=image_folder, image_processor=image_processor,
            data_path=data_path, tokenizer=tokenizer,
            offline_processed_text_folder=offline_processed_text_folder,
            max_dataset_length=max_dataset_length,
            dataset_map_fn=dataset_map_fn, template_map_fn=template_map_fn,
            max_length=max_length, pad_image_to_square=pad_image_to_square,
            num_proc=num_proc, lazy=lazy, repeats=repeats)

    def _set_attribute(self):
        """Pick the ``refcocog`` dataset with the ``umd`` split."""
        self.splitBy, self.dataset_name = "umd", 'refcocog'
class Refclef_ReferringSegDataset(RefcocoReferringSegDataset):
    """RefCLEF variant (``unc`` split) of ``RefcocoReferringSegDataset``.

    The constructor is a plain pass-through; the only difference from the
    base class is the REFER dataset/split chosen in :meth:`_set_attribute`.
    """

    def __init__(self, image_folder, image_processor, data_path=None,
                 tokenizer=None, offline_processed_text_folder=None,
                 max_dataset_length=None, dataset_map_fn=None,
                 template_map_fn=None, max_length=2048,
                 pad_image_to_square=False, num_proc=8, lazy=False,
                 repeats=1,):
        super().__init__(
            image_folder=image_folder, image_processor=image_processor,
            data_path=data_path, tokenizer=tokenizer,
            offline_processed_text_folder=offline_processed_text_folder,
            max_dataset_length=max_dataset_length,
            dataset_map_fn=dataset_map_fn, template_map_fn=template_map_fn,
            max_length=max_length, pad_image_to_square=pad_image_to_square,
            num_proc=num_proc, lazy=lazy, repeats=repeats)

    def _set_attribute(self):
        """Pick the ``refclef`` dataset with the ``unc`` split."""
        self.splitBy, self.dataset_name = "unc", 'refclef'