| __author__ = 'licheng' |
|
|
| """ |
| This interface provides access to four datasets: |
| 1) refclef |
| 2) refcoco |
| 3) refcoco+ |
| 4) refcocog |
| split by unc and google |
| |
| The following API functions are defined: |
| REFER - REFER api class |
| getRefIds - get ref ids that satisfy given filter conditions. |
| getAnnIds - get ann ids that satisfy given filter conditions. |
| getImgIds - get image ids that satisfy given filter conditions. |
| getCatIds - get all category ids (takes no filter conditions). |
| loadRefs - load refs with the specified ref ids. |
| loadAnns - load anns with the specified ann ids. |
| loadImgs - load images with the specified image ids. |
| loadCats - load category names with the specified category ids. |
| getRefBox - get ref's bounding box [x, y, w, h] given the ref_id |
| showRef - show image, segmentation or box of the referred object with the ref |
| getMask - get mask and area of the referred object given ref |
| showMask - show mask of the referred object given ref |
| """ |
|
|
| from doctest import REPORT_ONLY_FIRST_FAILURE |
| import sys |
| import os.path as osp |
| import json |
| import pickle |
| import time |
| import itertools |
| import skimage.io as io |
| import matplotlib.pyplot as plt |
| from matplotlib.collections import PatchCollection |
| from matplotlib.patches import Polygon, Rectangle |
| from pprint import pprint |
| import numpy as np |
| from pycocotools import mask |
| |
| |
|
|
|
|
class REFER:
    """In-memory index over one referring-expression dataset.

    Construction loads the ``refs(<splitBy>).p`` pickle and the
    ``instances.json`` file found under ``<data_root>/<dataset>`` and builds
    the lookup tables listed in :meth:`createIndex`. The query methods
    (``getRefIds``, ``getAnnIds``, ``loadRefs`` ...) only read those tables.
    """

    def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
        """Load the dataset files and build the lookup tables.

        Args:
            data_root: directory holding ``<dataset>/refs(<splitBy>).p``,
                ``<dataset>/instances.json`` and the ``images/`` tree.
            dataset: one of 'refcoco', 'refcoco+', 'refcocog' or 'refclef'.
                Any other name prints a message and calls ``sys.exit()``.
            splitBy: token inserted into the refs file name,
                i.e. ``refs(<splitBy>).p``.
        """
        print('loading dataset {} into memory...'.format(dataset))
        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
        self.DATA_DIR = osp.join(data_root, dataset)
        if dataset in ['refcoco', 'refcoco+', 'refcocog']:
            self.IMAGE_DIR = osp.join(
                data_root, 'images/mscoco/images/train2014')
        elif dataset == 'refclef':
            self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
        else:
            print('No refer dataset is called [{}]'.format(dataset))
            sys.exit()

        tic = time.time()
        ref_file = osp.join(self.DATA_DIR, 'refs(' + splitBy + ').p')
        self.data = {}
        self.data['dataset'] = dataset
        # NOTE(security): unpickling can execute arbitrary code; only point
        # data_root at dataset files from a trusted source.
        with open(ref_file, 'rb') as f:
            self.data['refs'] = pickle.load(f)

        instances_file = osp.join(self.DATA_DIR, 'instances.json')
        with open(instances_file, 'r') as f:
            instances = json.load(f)
        self.data['images'] = instances['images']
        self.data['annotations'] = instances['annotations']
        self.data['categories'] = instances['categories']

        self.createIndex()
        print('DONE (t={:.2f}s)'.format(time.time() - tic))

    def createIndex(self):
        """Build every lookup table from ``self.data``.

        Tables stored on the instance:
            Refs         {ref_id: ref}
            Anns         {ann_id: ann}
            Imgs         {image_id: image}
            Cats         {category_id: category name}
            Sents        {sent_id: sent}
            imgToRefs    {image_id: [ref, ...]}
            imgToAnns    {image_id: [ann, ...]}
            refToAnn     {ref_id: ann}
            annToRef     {ann_id: ref}
            catToRefs    {category_id: [ref, ...]}
            sentToRef    {sent_id: ref}
            sentToTokens {sent_id: tokens}

        Raises:
            KeyError: if a ref points at an ``ann_id`` that is not present in
                the annotations.
        """
        print('creating index...')

        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data['annotations']:
            Anns[ann['id']] = ann
            # append in place; rebuilding the list per insert is quadratic
            imgToAnns.setdefault(ann['image_id'], []).append(ann)
        for img in self.data['images']:
            Imgs[img['id']] = img
        for cat in self.data['categories']:
            Cats[cat['id']] = cat['name']

        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data['refs']:
            ref_id = ref['ref_id']
            ann_id = ref['ann_id']
            category_id = ref['category_id']
            image_id = ref['image_id']

            Refs[ref_id] = ref
            imgToRefs.setdefault(image_id, []).append(ref)
            catToRefs.setdefault(category_id, []).append(ref)
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            for sent in ref['sentences']:
                Sents[sent['sent_id']] = sent
                sentToRef[sent['sent_id']] = ref
                sentToTokens[sent['sent_id']] = sent['tokens']

        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print('index created.')

    @staticmethod
    def _as_list(value):
        """Return ``value`` unchanged if it is a list, else wrap it in one."""
        return value if isinstance(value, list) else [value]

    @staticmethod
    def _lookup(table, ids):
        """Return ``[table[i] ...]`` for a list of ids, or ``[table[ids]]``.

        Anything that is not a list is treated as one single key, so int,
        str and numpy integer ids all work.
        """
        if isinstance(ids, list):
            return [table[key] for key in ids]
        return [table[ids]]

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
        """Return the ref ids that satisfy every given filter.

        Args:
            image_ids: image id or list of ids; only refs on these images are
                kept. Ids without any ref are skipped.
            cat_ids: category id or list of ids to keep.
            ref_ids: ref id or list of ids to keep.
            split: '' (no filter), 'train', 'val', 'test',
                'testA'/'testB'/'testC' or 'testAB'/'testBC'/'testAC'.
                Any other value prints a message and calls ``sys.exit()``.

        Returns:
            list of ref ids; with no filter at all, every ref id.
        """
        image_ids = self._as_list(image_ids)
        cat_ids = self._as_list(cat_ids)
        ref_ids = self._as_list(ref_ids)

        if image_ids:
            # imgToRefs maps to lists of refs, so flatten them into one list.
            refs = list(itertools.chain.from_iterable(
                self.imgToRefs[image_id]
                for image_id in image_ids if image_id in self.imgToRefs))
        else:
            refs = self.data['refs']
        if cat_ids:
            refs = [ref for ref in refs if ref['category_id'] in cat_ids]
        if ref_ids:
            refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
        if split:
            if split in ['testA', 'testB', 'testC']:
                # match on the trailing letter, so 'testA' also selects refs
                # whose split is e.g. 'testAB'
                refs = [ref for ref in refs if split[-1] in ref['split']]
            elif split in ['testAB', 'testBC', 'testAC']:
                refs = [ref for ref in refs if ref['split'] == split]
            elif split == 'test':
                refs = [ref for ref in refs if 'test' in ref['split']]
            elif split == 'train' or split == 'val':
                refs = [ref for ref in refs if ref['split'] == split]
            else:
                print('No such split [{}]'.format(split))
                sys.exit()
        return [ref['ref_id'] for ref in refs]

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        """Return the annotation ids that satisfy every given filter.

        Args:
            image_ids: image id or list of ids; ids without annotations are
                skipped.
            cat_ids: category id or list of ids to keep.
            ref_ids: ref id or list of ids; only annotations referred to by
                one of these refs are kept.

        Returns:
            list of annotation ids, in annotation order.

        Raises:
            KeyError: if a given ref id is unknown.
        """
        image_ids = self._as_list(image_ids)
        cat_ids = self._as_list(cat_ids)
        ref_ids = self._as_list(ref_ids)

        if image_ids:
            anns = list(itertools.chain.from_iterable(
                self.imgToAnns[image_id]
                for image_id in image_ids if image_id in self.imgToAnns))
        else:
            anns = self.data['annotations']
        if cat_ids:
            anns = [ann for ann in anns if ann['category_id'] in cat_ids]
        ann_ids = [ann['id'] for ann in anns]
        if ref_ids:
            # keep only the annotations the requested refs point at
            ref_ann_ids = set(self.Refs[ref_id]['ann_id']
                              for ref_id in ref_ids)
            ann_ids = [ann_id for ann_id in ann_ids if ann_id in ref_ann_ids]
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        """Return image ids as a list.

        Args:
            ref_ids: ref id or list of ids. When given, the result is the
                de-duplicated image ids of those refs (order unspecified);
                when empty, every image id in the index.
        """
        ref_ids = self._as_list(ref_ids)
        if ref_ids:
            return list(set(self.Refs[ref_id]['image_id']
                            for ref_id in ref_ids))
        return list(self.Imgs.keys())

    def getCatIds(self):
        """Return every category id in the index as a list."""
        return list(self.Cats.keys())

    def loadRefs(self, ref_ids=[]):
        """Return the list of refs for a ref id or list of ref ids."""
        return self._lookup(self.Refs, ref_ids)

    def loadAnns(self, ann_ids=[]):
        """Return the list of annotations for an ann id or list of ann ids."""
        return self._lookup(self.Anns, ann_ids)

    def loadImgs(self, image_ids=[]):
        """Return the list of image records for an image id or list of ids."""
        return self._lookup(self.Imgs, image_ids)

    def loadCats(self, cat_ids=[]):
        """Return the list of category names for a category id or list."""
        return self._lookup(self.Cats, cat_ids)

    def getRefBox(self, ref_id):
        """Return the ``bbox`` entry of the annotation referred to by ref_id."""
        return self.refToAnn[ref_id]['bbox']

    def showRef(self, ref, seg_box='seg'):
        """Draw the ref's image on the current axes and overlay its object.

        Also prints the ref's sentences, numbered from 1.

        Args:
            ref: ref record (needs 'image_id', 'ann_id', 'ref_id',
                'sentences').
            seg_box: 'seg' overlays the segmentation, 'box' overlays the
                bounding box; any other value draws only the image.
        """
        ax = plt.gca()

        image = self.Imgs[ref['image_id']]
        pixels = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
        ax.imshow(pixels)

        for sid, sent in enumerate(ref['sentences']):
            print('{}. {}'.format(sid + 1, sent['sent']))

        if seg_box == 'seg':
            ann = self.Anns[ref['ann_id']]
            if isinstance(ann['segmentation'][0], list):
                # polygon format: each entry is a flat [x0, y0, x1, y1, ...]
                polygons = []
                color = []
                for seg in ann['segmentation']:
                    # integer division: reshape needs an int row count
                    poly = np.array(seg).reshape((len(seg) // 2, 2))
                    # `closed` passed by keyword; newer matplotlib releases
                    # reject it as a positional argument
                    polygons.append(Polygon(poly, closed=True, alpha=0.4))
                    color.append('none')
                # outline drawn twice: a thick stroke with a thin one on top
                p = PatchCollection(polygons, facecolors=color, edgecolors=(
                    1, 1, 0, 0), linewidths=3, alpha=1)
                ax.add_collection(p)
                p = PatchCollection(polygons, facecolors=color, edgecolors=(
                    1, 0, 0, 0), linewidths=1, alpha=1)
                ax.add_collection(p)
            else:
                # non-polygon segmentation is handed to mask.decode as-is
                rle = ann['segmentation']
                m = mask.decode(rle)
                img = np.ones((m.shape[0], m.shape[1], 3))
                color_mask = np.array([2.0, 166.0, 101.0]) / 255
                for i in range(3):
                    img[:, :, i] = color_mask[i]
                # decoded mask scaled by 0.5 is stacked on as the last channel
                ax.imshow(np.dstack((img, m * 0.5)))

        elif seg_box == 'box':
            bbox = self.getRefBox(ref['ref_id'])
            box_plot = Rectangle(
                (bbox[0], bbox[1]), bbox[2], bbox[3],
                fill=False, edgecolor='green', linewidth=3)
            ax.add_patch(box_plot)

    def getMask(self, ref):
        """Return the mask and area of the object referred to by ``ref``.

        Polygon segmentations are first converted with ``mask.frPyObjects``
        using the image's height and width; anything else is passed to
        ``mask.decode`` unchanged.

        Returns:
            dict with 'mask' (uint8 array: the decoded output summed over
            axis 2) and 'area' (sum of ``mask.area`` over the encoded
            segments).
        """
        ann = self.refToAnn[ref['ref_id']]
        image = self.Imgs[ref['image_id']]
        if isinstance(ann['segmentation'][0], list):
            rle = mask.frPyObjects(
                ann['segmentation'], image['height'], image['width'])
        else:
            rle = ann['segmentation']
        m = mask.decode(rle)
        # presumably one channel per segment; collapse them into a single map
        m = np.sum(m, axis=2)
        m = m.astype(np.uint8)
        area = sum(mask.area(rle))
        return {'mask': m, 'area': area}

    def showMask(self, ref):
        """Display the mask of the object referred to by ``ref``."""
        M = self.getMask(ref)
        msk = M['mask']
        ax = plt.gca()
        ax.imshow(msk)
|
|
|
|
if __name__ == '__main__':
    # Demo / smoke test: load refcocog (google split) and print some stats.
    # The data root may be given as the first command-line argument; without
    # one the original hard-coded location is used.
    data_root = (sys.argv[1] if len(sys.argv) > 1
                 else '/home/xueyanz/code/dataset/refcocoseg')
    refer = REFER(data_root=data_root, dataset='refcocog', splitBy='google')
    ref_ids = refer.getRefIds()
    print(len(ref_ids))

    # number of indexed images vs. number of images that have a ref
    print(len(refer.Imgs))
    print(len(refer.imgToRefs))

    ref_ids = refer.getRefIds(split='train')
    print('There are {} training referred objects.'.format(len(ref_ids)))

    # dump every training ref that has at least two sentences
    for ref_id in ref_ids:
        ref = refer.loadRefs(ref_id)[0]
        if len(ref['sentences']) < 2:
            continue

        pprint(ref)
        print('The label is {}.'.format(refer.Cats[ref['category_id']]))
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|