import copy
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from pycocotools import mask
from torch.utils.data import Dataset

from xtuner.dataset.huggingface import build_origin_dataset, process_hf_dataset
from xtuner.dataset.utils import encode_fn
from xtuner.registry import BUILDER

from .utils import expand2square, expand2square_mask
from .utils.refcoco_refer import REFER


class RefcocoReferringSegDataset(Dataset):
    """Referring-expression segmentation dataset built on the REFER API.

    For every training image with at least one referring sentence, up to
    three (sentence, annotation id) pairs are sampled once, at construction
    time, using the global NumPy RNG. The sentences are fed through the
    xtuner text pipeline (``dataset_map_fn`` -> ``template_map_fn`` ->
    ``encode_fn``); the paired annotation ids are decoded into masks in
    ``__getitem__``.

    Subclasses pick another REFER dataset / split by overriding
    ``_set_attribute``.
    """

    # Column linking each processed text row back to its entry in
    # ``self.json_datas``. ``process_hf_dataset`` may subsample
    # (``max_dataset_length``) or filter rows, so the position of a row in
    # ``self.text_data`` is not guaranteed to match ``self.json_datas``.
    _JSON_INDEX_KEY = 'json_index'

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,):
        """Build the dataset.

        Args:
            image_folder: Directory joined with each image's ``file_name``.
            image_processor: Image processor instance, or a dict / ``Config``
                / ``ConfigDict`` that is built through ``BUILDER``.
            data_path: Root directory handed to ``REFER``.
            tokenizer: Tokenizer used by the text pipeline.
            offline_processed_text_folder: Not supported; passing it raises
                ``NotImplementedError``.
            max_dataset_length, dataset_map_fn, template_map_fn, max_length,
            num_proc: Forwarded to ``process_hf_dataset`` (non-lazy mode).
                In lazy mode ``dataset_map_fn``/``template_map_fn`` are
                applied per item in ``__getitem__``.
            pad_image_to_square: Pad images (and masks) to square first.
            lazy: Defer text processing to ``__getitem__``.
            repeats: Virtual repetition factor for ``__len__``.
        """
        self._set_attribute()
        self.tokenizer = tokenizer
        assert offline_processed_text_folder or (data_path and tokenizer)
        self.lazy = lazy
        self.max_length = max_length
        self.dataset_map_fn = dataset_map_fn
        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            # Work on a copy so the caller's config keeps its 'type' key and
            # can be reused (e.g. for another dataset instance).
            template_cfg = dict(self.template_map_fn)
            _type = template_cfg.pop('type')
            self.template_map_fn = _type(**template_cfg)

        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from'
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            raise NotImplementedError

        self.json_datas = self.json_file_preprocess(data_path)
        hf_infos = self.only_get_hf_map_infos()
        json_data = DatasetDict({'train': HFDataset.from_list(hf_infos)})
        if self.lazy:
            self.text_data = build_origin_dataset(json_data, 'train')
        else:
            self.text_data = process_hf_dataset(
                dataset=json_data,
                tokenizer=tokenizer,
                max_length=max_length,
                dataset_map_fn=dataset_map_fn,
                template_map_fn=template_map_fn,
                split='train',
                max_dataset_length=max_dataset_length,
                remove_unused_columns=False,
                pack_to_max_length=False,
                with_image_token=True,
                map_num_proc=num_proc,  # because limited mem
            )

        self.image_folder = image_folder
        # Build the processor before reading its size: a config dict need not
        # expose ``crop_size`` the same way the built processor does.
        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        self.image_h, self.image_w = self._get_processor_hw(
            self.image_processor)

        self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    @staticmethod
    def _get_processor_hw(image_processor):
        """Return ``(height, width)`` of the images *image_processor* emits.

        Reads ``crop_size`` (falling back to ``size``). The value may be an
        int (square output), a dict with ``'height'``/``'width'`` keys, or a
        ``(width, height)`` sequence.
        """
        size = getattr(image_processor, 'crop_size', None)
        if size is None:
            size = image_processor.size
        if isinstance(size, int):
            return size, size
        if isinstance(size, dict):
            return size['height'], size['width']
        width, height = size
        return height, width

    def _set_attribute(self):
        """Select the REFER split scheme and dataset name."""
        self.splitBy = "unc"
        self.dataset_name = 'refcoco'

    def only_get_hf_map_infos(self):
        """Return the per-sample rows handed to the HF text pipeline.

        Each row holds the sampled sentences plus the index of its entry in
        ``self.json_datas`` so ``__getitem__`` can re-join them.
        """
        return [
            {'sampled_sents': json_data['selected_labels'],
             self._JSON_INDEX_KEY: idx}
            for idx, json_data in enumerate(self.json_datas)
        ]

    def __len__(self):
        return len(self.text_data) * self.repeats

    @property
    def modality_length(self):
        """Per-sample lengths; negative values mark rows without an image.

        NOTE(review): rows built by ``only_get_hf_map_infos`` carry no
        'image' key, so unless ``dataset_map_fn`` adds one every length is
        negative here -- confirm this is intended for grouped sampling.
        """
        length_list = []
        for data_dict in self.text_data:
            if self.lazy:
                cur_len = 100
            else:
                cur_len = len(data_dict['input_ids'])
            if data_dict.get('image', None) is None:
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list

    def real_len(self):
        """Number of distinct samples (ignoring ``repeats``)."""
        return len(self.text_data)

    def json_file_preprocess(self, data_path):
        """Load REFER train refs and sample up to 3 sentences per image.

        Side effects: sets ``self.img2refs`` and ``self.annotations``.

        Returns:
            list[dict]: one entry per image that has at least one sentence,
            with keys ``image_info``, ``sampled_ann_id``,
            ``selected_labels`` and ``image``.
        """
        refer_api = REFER(data_path, self.dataset_name, self.splitBy)
        ref_ids_train = refer_api.getRefIds(split='train')
        images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
        refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
        self.img2refs = self.create_img_to_refs_mapping(refs_train)

        image_infos = [
            item.copy()
            for item in refer_api.loadImgs(image_ids=images_ids_train)
        ]
        self.annotations = refer_api.Anns

        ret = []
        for image_info in image_infos:
            refs = self.img2refs.get(image_info['id'], [])
            if len(refs) == 0:
                continue

            # Flatten (sentence, ann_id) pairs over all refs of this image.
            sents = []
            ann_ids = []
            for ref in refs:
                for sent in ref["sentences"]:
                    sents.append(sent["sent"])
                    ann_ids.append(ref["ann_id"])
            if not sents:
                # Nothing to supervise; also avoids an empty mask stack later.
                continue

            if len(sents) >= 3:
                sampled_inds = np.random.choice(
                    len(sents), size=3, replace=False)
            else:
                sampled_inds = range(len(sents))
            sampled_sents = [sents[ind] for ind in sampled_inds]
            sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]

            ret.append({
                'image_info': image_info,
                'sampled_ann_id': sampled_ann_ids,
                'selected_labels': sampled_sents,
                'image': image_info['file_name'],
            })
        return ret

    def create_img_to_refs_mapping(self, refs_train):
        """Group refs by their ``image_id``."""
        img2refs = {}
        for ref in refs_train:
            img2refs.setdefault(ref["image_id"], []).append(ref)
        return img2refs

    def _decode_annotation(self, ann_id, height, width):
        """Decode one annotation into a binary ``(height, width)`` uint8 mask.

        An annotation with an empty segmentation yields an all-zero mask.
        """
        segmentation = self.annotations[ann_id]["segmentation"]
        if len(segmentation) == 0:
            return np.zeros((height, width), dtype=np.uint8)
        if isinstance(segmentation[0], list):  # polygon(s)
            rle = mask.frPyObjects(segmentation, height, width)
        else:
            rle = segmentation
            # pycocotools expects bytes counts; converted in place (cached on
            # self.annotations, the conversion is idempotent).
            for r in rle:
                if not isinstance(r["counts"], bytes):
                    r["counts"] = r["counts"].encode()
        m = mask.decode(rle)
        # One channel per polygon/RLE; union them so overlaps stay binary.
        return m.any(axis=2).astype(np.uint8)

    def decode_mask(self, annotations_ids, image_info):
        """Decode the sampled annotations into a stacked mask tensor.

        Args:
            annotations_ids: items are either one annotation id or a list of
                ids (merged by union); a list containing -1 must be exactly
                ``[-1]`` and yields an empty mask.
            image_info: dict with the image ``height`` and ``width``.

        Returns:
            uint8 tensor of shape
            ``(N, image_h // down_ratio, image_w // down_ratio)``.

        NOTE(review): masks are resized by plain interpolation; if the image
        processor center-crops non-square images (pad_image_to_square=False)
        they may not align with ``pixel_values`` -- confirm.
        """
        height, width = image_info["height"], image_info["width"]
        masks = []
        for ann_id in annotations_ids:
            if isinstance(ann_id, list):
                m = np.zeros((height, width), dtype=np.uint8)
                if -1 in ann_id:
                    assert len(ann_id) == 1
                else:
                    for ann_id_i in ann_id:
                        m |= self._decode_annotation(ann_id_i, height, width)
            else:
                m = self._decode_annotation(ann_id, height, width)
            masks.append(m)

        masks = np.stack(masks, axis=0)
        if self.pad_image_to_square:
            masks = expand2square_mask(masks)
        masks = torch.from_numpy(masks)
        masks = F.interpolate(
            masks.unsqueeze(0),
            size=(self.image_h // self.down_ratio,
                  self.image_w // self.down_ratio),
            mode='nearest').squeeze(0)
        return masks

    def __getitem__(self, index):
        """Return the text fields, ``pixel_values`` and ``masks`` of a sample."""
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.text_data[index])
        # Re-join with the matching JSON entry; fall back to the position for
        # rows that lack the index column.
        json_index = data_dict.get(self._JSON_INDEX_KEY, index)
        data_dict.update(self.json_datas[json_index])

        if self.lazy:
            result = self.dataset_map_fn(data_dict)
            data_dict.update(result)

            result = self.template_map_fn(data_dict)
            data_dict.update(result)

            result = encode_fn(data_dict, tokenizer=self.tokenizer,
                               max_length=self.max_length,
                               with_image_token=True)
            data_dict.update(result)

        assert 'image' in data_dict.keys()
        if data_dict.get('image', None) is not None:
            image_file = os.path.join(self.image_folder, data_dict['image'])
            with Image.open(image_file) as img:
                image = img.convert('RGB')
            if self.pad_image_to_square:
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image

            data_dict['masks'] = self.decode_mask(
                data_dict['sampled_ann_id'], data_dict['image_info'])
        else:
            data_dict['pixel_values'] = torch.zeros(
                3, self.image_h, self.image_w)
            data_dict['masks'] = None
        return data_dict


class Refcoco_plus_ReferringSegDataset(RefcocoReferringSegDataset):
    """RefCOCO+ (``unc`` split) variant; constructor is inherited."""

    def _set_attribute(self):
        self.splitBy = "unc"
        self.dataset_name = 'refcoco+'


class Refcocog_ReferringSegDataset(RefcocoReferringSegDataset):
    """RefCOCOg (``umd`` split) variant; constructor is inherited."""

    def _set_attribute(self):
        self.splitBy = "umd"
        self.dataset_name = 'refcocog'


class Refclef_ReferringSegDataset(RefcocoReferringSegDataset):
    """RefCLEF (``unc`` split) variant; constructor is inherited."""

    def _set_attribute(self):
        self.splitBy = "unc"
        self.dataset_name = 'refclef'