| import copy |
|
|
| from .base_eval_dataset import BaseEvalDataset |
|
|
| import os |
| import numpy as np |
| import torch |
| from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist, |
| master_only) |
| from xtuner.registry import BUILDER |
| from mmengine.config import Config |
| from mmengine.config import ConfigDict |
|
|
| from mmengine import print_log |
| from PIL import Image |
| from pycocotools import mask |
| from projects.omg_llava.dataset.utils import expand2square |
| from projects.omg_llava.dataset.utils.refcoco_refer import REFER |
| from projects.omg_llava.tools.utils_refcoco import AverageMeter, Summary, intersectionAndUnionGPU |
| from pycocotools import mask as _mask |
|
|
# Maps the config-facing dataset key to the arguments expected by the REFER
# API: `splitBy` (annotation split provider) and the REFER-internal dataset
# name. Note that the key 'refcoco_plus' maps to the REFER name 'refcoco+',
# and refcocog uses the 'umd' split rather than 'unc'.
DATASETS_ATTRIBUTES = {
    'refcoco': {'splitBy': "unc", 'dataset_name': 'refcoco'},
    'refcoco_plus': {'splitBy': "unc", 'dataset_name': 'refcoco+'},
    'refcocog': {'splitBy': "umd", 'dataset_name': 'refcocog'},
}
|
|
class RESDataset(BaseEvalDataset):
    """Evaluation dataset for Referring Expression Segmentation (RES).

    Loads RefCOCO-style referring expressions through the REFER API, pairs
    each image with its referring sentences and ground-truth masks, and
    computes CIoU/GIoU during :meth:`evaluate`.
    """

    METAINFO: dict = dict(name='Referring Expression Segmentation')

    def __init__(self,
                 image_folder,
                 dataset_name,
                 image_processor,
                 data_path=None,
                 split='val',
                 pad_image_to_square=True,
                 metainfo=None,
                 ori_image=False,
                 ):
        """
        Args:
            image_folder (str): Root directory containing the COCO images.
            dataset_name (str): Key into ``DATASETS_ATTRIBUTES``
                ('refcoco', 'refcoco_plus' or 'refcocog').
            image_processor: Either an already-built processor or a
                dict/Config/ConfigDict to be built via ``BUILDER``.
            data_path (str, optional): Root path passed to the REFER API.
            split (str): Annotation split, e.g. 'val' or 'testA'.
            pad_image_to_square (bool): Pad images to square (filled with the
                processor's mean color) before preprocessing.
            metainfo (dict, optional): Forwarded to ``BaseEvalDataset``.
            ori_image (bool): If True, keep a copy of the raw PIL image in
                each sample under the key 'ori_image'.
        """
        super().__init__(metainfo)
        self.split = split
        self._set_attribute(dataset_name)

        self.json_datas = self.json_file_preprocess(data_path)

        self.image_folder = image_folder
        # NOTE(review): crop_size is read before BUILDER.build below, so this
        # relies on attribute access working on the raw config object
        # (Config/ConfigDict do support it; a plain dict would not) — confirm.
        size = image_processor.crop_size
        if isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size

        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.ori_image = ori_image

    def _set_attribute(self, dataset_name):
        """Resolve the REFER ``splitBy``/``dataset_name`` pair for this dataset."""
        attr_dict = DATASETS_ATTRIBUTES[dataset_name]
        self.splitBy = attr_dict['splitBy']
        self.dataset_name = attr_dict['dataset_name']

    def __len__(self):
        return len(self.json_datas)

    def real_len(self):
        """Number of distinct images (kept for API parity with siblings)."""
        return len(self.json_datas)

    def json_file_preprocess(self, data_path):
        """Load REFER annotations and flatten them into per-image records.

        Returns:
            list[dict]: One entry per image with a non-empty set of refs,
            carrying the image info, the annotation id per sentence, the
            referring sentences, and the image file name.
        """
        refer_api = REFER(data_path, self.dataset_name, self.splitBy)
        ref_ids_train = refer_api.getRefIds(split=self.split)
        images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
        refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
        self.img2refs = self.create_img_to_refs_mapping(refs_train)

        # Copy each info dict so later per-sample mutation cannot corrupt the
        # REFER API's internal state.
        image_infos = [item.copy()
                       for item in refer_api.loadImgs(image_ids=images_ids_train)]

        self.annotations = refer_api.Anns
        refs = [self.img2refs[image_info['id']] for image_info in image_infos]

        ret = []
        for image_info, ref in zip(image_infos, refs):
            if len(ref) == 0:
                continue

            sents = []
            ann_ids = []
            for _ref in ref:
                for sent in _ref["sentences"]:
                    sents.append(sent["sent"])
                    ann_ids.append(_ref["ann_id"])

            # Evaluation keeps every sentence; the original used
            # np.vectorize as an identity gather, which also failed on
            # empty inputs — a plain copy is equivalent and safe.
            sampled_sents = list(sents)
            sampled_ann_ids = list(ann_ids)
            ret.append(
                {'image_info': image_info,
                 'sampled_ann_id': sampled_ann_ids,
                 'selected_labels': sampled_sents,
                 'image': image_info['file_name']
                 }
            )
        return ret

    def create_img_to_refs_mapping(self, refs_train):
        """Group refs by their image id: {image_id: [ref, ...]}."""
        img2refs = {}
        for ref in refs_train:
            img2refs.setdefault(ref["image_id"], []).append(ref)
        return img2refs

    def _decode_single_ann(self, ann_id, image_info):
        """Decode one COCO annotation id into a binary (H, W) uint8 mask.

        Empty segmentations yield an all-zero mask. Polygon segmentations are
        converted to RLE first; RLE counts are normalized to bytes because
        pycocotools requires it.
        """
        ann = self.annotations[ann_id]
        height, width = image_info["height"], image_info["width"]

        if len(ann["segmentation"]) == 0:
            return np.zeros((height, width), dtype=np.uint8)

        if isinstance(ann["segmentation"][0], list):
            # Polygon format -> RLE.
            rle = mask.frPyObjects(ann["segmentation"], height, width)
        else:
            rle = ann["segmentation"]
        for i in range(len(rle)):
            if not isinstance(rle[i]["counts"], bytes):
                rle[i]["counts"] = rle[i]["counts"].encode()
        m = mask.decode(rle)
        # Multiple polygons decode to (H, W, K); collapse to a single mask.
        m = np.sum(m, axis=2)
        return m.astype(np.uint8)

    def decode_mask(self, annotations_ids, image_info):
        """Decode a list of annotation ids into a stacked mask tensor.

        Args:
            annotations_ids: List whose entries are either a single ann id or
                a list of ann ids (a multi-annotation referent, OR-merged).
                A list ``[-1]`` denotes an empty (all-zero) mask.
            image_info (dict): Must provide 'height' and 'width'.

        Returns:
            torch.Tensor: uint8 masks of shape (N, H, W).
        """
        height, width = image_info["height"], image_info["width"]
        masks = []
        for ann_id in annotations_ids:
            if isinstance(ann_id, list):
                if -1 in ann_id:
                    # Sentinel for "no mask" — must be the only entry.
                    assert len(ann_id) == 1
                    m = np.zeros((height, width), dtype=np.uint8)
                else:
                    # Union of all member annotations.
                    m = np.zeros((height, width), dtype=np.uint8)
                    for ann_id_i in ann_id:
                        m = m | self._decode_single_ann(ann_id_i, image_info)
            else:
                m = self._decode_single_ann(ann_id, image_info)
            masks.append(m)

        masks = np.stack(masks, axis=0)
        return torch.from_numpy(masks)

    def only_get_text_infos(self, json_data):
        """Extract just the referring sentences from a record."""
        return {'sampled_sents': json_data['selected_labels']}

    def get_questions(self, text_require_infos):
        """Wrap each referring sentence in the segmentation prompt template."""
        sampled_sents = text_require_infos['sampled_sents']
        return ["<image>\n Please segment {} in this image.".format(sent)
                for sent in sampled_sents]

    def filter_data_dict(self, data_dict):
        """Keep only the keys consumed by the evaluator.

        'ori_image' is present only when ``self.ori_image`` is True, so
        missing keys are skipped rather than indexed (bug fix: the original
        unconditionally indexed 'ori_image' and raised KeyError whenever
        ``ori_image=False``, which is the default).
        """
        names = ['pixel_values', 'masks', 'ori_image_size', 'text_prompts',
                 'img_id', 'ori_image']
        return {name: data_dict[name] for name in names if name in data_dict}

    def __getitem__(self, index):
        index = index % self.real_len()
        # Shallow-copy so we do not permanently attach pixel tensors / masks
        # to the cached record (the original mutated self.json_datas in
        # place, keeping every decoded sample alive in memory).
        data_dict = copy.copy(self.json_datas[index])
        text_require_infos = self.only_get_text_infos(data_dict)
        questions = self.get_questions(text_require_infos)

        assert data_dict.get('image', None) is not None
        if data_dict.get('image', None) is not None:
            image_file = os.path.join(self.image_folder, data_dict['image'])
            image = Image.open(image_file).convert('RGB')
            if self.ori_image:
                ori_image = copy.deepcopy(image)
            ori_width, ori_height = image.size
            if self.pad_image_to_square:
                # Pad with the processor's mean color so normalization maps
                # the padding to ~0.
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image

        data_dict['masks'] = self.decode_mask(
            data_dict['sampled_ann_id'], data_dict['image_info'])
        data_dict['ori_image_size'] = (ori_width, ori_height)
        data_dict['text_prompts'] = questions
        data_dict['img_id'] = str(index)

        if self.ori_image:
            data_dict['ori_image'] = ori_image

        return self.filter_data_dict(data_dict)

    @master_only
    def evaluate(self, result, work_dir):
        """Compute CIoU and GIoU over the collected predictions.

        Args:
            result (list[dict]): Each entry holds RLE-encoded
                'prediction_masks' (entries may be None for failed
                predictions) and 'gt_masks'.
            work_dir (str): Unused; kept for the evaluator interface.

        Returns:
            dict: {'Acc': cumulative-IoU (CIoU)}.
        """
        trackers = {
            "intersection": AverageMeter("Intersec", ":6.3f", Summary.SUM),
            "union": AverageMeter("Union", ":6.3f", Summary.SUM),
            "gIoU": AverageMeter("gIoU", ":6.3f", Summary.SUM)
        }
        for pred_dict in result:
            intersection, union, accuracy_iou = 0.0, 0.0, 0.0
            # Decode predictions, keeping a None placeholder for failed ones
            # so indices stay aligned with the ground-truth masks (bug fix:
            # the original dropped None entries, shifting every subsequent
            # prediction onto the wrong target). Loop variable renamed so it
            # no longer shadows the imported `pycocotools.mask` module.
            pred_masks = []
            for pred_rle in pred_dict['prediction_masks']:
                pred_masks.append(
                    rle_to_mask(pred_rle) if pred_rle is not None else None)
            _targets = rle_to_mask(pred_dict['gt_masks'])

            for i_item, _mask in enumerate(pred_masks):
                if _mask is None:
                    continue

                _target = _targets[i_item: i_item + 1]
                for prediction, target in zip(_mask, _target):
                    prediction = torch.from_numpy(prediction).int().cuda()
                    target = torch.from_numpy(target).int().cuda()
                    intersect, union_, _ = intersectionAndUnionGPU(
                        prediction.contiguous().clone(), target.contiguous(),
                        2, ignore_index=255
                    )
                    intersection += intersect
                    union += union_
                    accuracy_iou += intersect / (union_ + 1e-5)
                    # An empty-gt / empty-pred pair counts as a perfect match.
                    accuracy_iou[union_ == 0] += 1.0

            intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
            accuracy_iou = accuracy_iou.cpu().numpy() / _targets.shape[0]
            trackers["intersection"].update(intersection)
            trackers["union"].update(union)
            trackers["gIoU"].update(accuracy_iou, n=_targets.shape[0])

        # Index [1] selects the foreground class statistics.
        cur_results = {'pixel_intersection': trackers["intersection"].sum[1],
                       'pixel_union': trackers["union"].sum[1],
                       'gIoU': trackers["gIoU"].avg[1],
                       'mask_counts': trackers["gIoU"].count,
                       }
        class_iou = cur_results['pixel_intersection'] / (cur_results['pixel_union'] + 1e-10)
        global_iou = cur_results['gIoU']

        print_log('============================================', 'current')
        print_log('CIoU: {}, GIoU: {}'.format(class_iou, global_iou), 'current')
        print_log('============================================', 'current')
        print_log('RES_{}_{} successfully finished evaluating'.format(self.dataset_name, self.split),
                  'current')
        return {'Acc': class_iou}
|
|
|
|
def rle_to_mask(rle):
    """Decode a sequence of COCO RLE objects into a stacked uint8 array.

    Args:
        rle: Iterable of RLE dicts understood by ``pycocotools.mask.decode``.

    Returns:
        np.ndarray: uint8 binary masks stacked along a new leading axis.
    """
    decoded = [np.uint8(_mask.decode(encoded)) for encoded in rle]
    return np.stack(decoded, axis=0)