# encoding=utf8
"""Export a REFER-style referring-expression dataset to flat annotation files.

For every split of the chosen dataset one text file is written to
``<output_dir>/anns/<dataset>/<split>.txt``.  Each line describes one
reference::

    <image file> <x_min>,<y_min>,<x_max>,<y_max>,<category>,<ref id> ~ <sent> ~ <sent> ...

With ``--generate_mask`` the segmentation mask of every reference is also
saved as ``<output_dir>/masks/<dataset>/<ref id>.npy``.

The functions below read the module-level names ``refer`` (the loaded REFER
API object) and ``args`` (the parsed command line); both are created in the
``__main__`` section at the bottom of the file.
"""
import argparse
import os
import os.path as osp

import numpy as np

parser = argparse.ArgumentParser(description='Data preparation')
# data_root contains refclef, refcoco, refcoco+, refcocog and images
parser.add_argument('--data_root', type=str, required=True)
parser.add_argument('--output_dir', type=str, required=True)
parser.add_argument('--dataset', type=str,
                    choices=['refcoco', 'refcoco+', 'refcocog', 'grefcoco', 'refclef'],
                    default='refcoco')
parser.add_argument('--split', type=str, default='umd',
                    help='annotation variant passed to REFER as its third argument')
parser.add_argument('--generate_mask', action='store_true')

# Data splits exported for each dataset (refclef depends on --split, see get_splits).
SPLITS_BY_DATASET = {
    'refcoco': ['train', 'val', 'testA', 'testB'],
    'refcoco+': ['train', 'val', 'testA', 'testB'],
    'grefcoco': ['train', 'val', 'testA', 'testB'],
    # NOTE(review): the original comment said "we don't have test split for
    # refcocog right now" although 'test' is listed -- the exported file may
    # simply be empty for some annotation variants; confirm.
    'refcocog': ['train', 'val', 'test'],
}

# (first id, last id, shift) ranges used by cat_process.
_CATEGORY_ID_SHIFTS = (
    (1, 11, 1),
    (13, 25, 2),
    (27, 28, 3),
    (31, 44, 5),
    (46, 65, 6),
    (67, 67, 7),
    (70, 70, 9),
    (72, 82, 10),
    (84, 90, 11),
)


def get_splits(dataset, split_by):
    """Return the list of data splits to export for ``dataset``.

    Args:
        dataset: dataset name as given to ``--dataset``.
        split_by: annotation variant as given to ``--split``; only consulted
            for ``refclef``.

    Raises:
        ValueError: if ``dataset`` is not one of the known datasets.
    """
    if dataset == 'refclef':
        if split_by == 'unc':
            return ['train', 'val', 'testA', 'testB', 'testC']
        return ['train', 'val', 'test']
    if dataset not in SPLITS_BY_DATASET:
        raise ValueError('unknown dataset: %r' % (dataset,))
    return list(SPLITS_BY_DATASET[dataset])


def show_a_batch(batch_size):
    """Dump a random batch of training references to ``./debug_vis`` for inspection.

    For each sampled reference the bounding box, sentences and categories are
    printed, and two images are written: the source image with the box and
    the last sentence drawn on it, and the reference mask scaled by 255.

    The box is indexed as a flat ``[x, y, w, h]`` sequence, so this helper
    presumably only works for datasets with one box per reference.

    Raises:
        FileNotFoundError: if a source image cannot be read.
    """
    # Imported here so the rest of the module can be used without OpenCV.
    import cv2

    split = 'train'
    ref_ids = refer.getRefIds(split=split)
    print(split+'_size:',len(ref_ids))
    picks = np.random.choice(len(ref_ids), batch_size)

    # Create the output directory once; the original called os.mkdir inside
    # the loop, which raised FileExistsError on the second image.
    os.makedirs('debug_vis', exist_ok=True)

    for n, pick in enumerate(picks):
        ref_id = ref_ids[pick]
        ref = refer.Refs[ref_id]
        bbox = refer.getRefBox(ref_id)
        sentences = ref['sentences']
        image_info = refer.loadImgs(image_ids=ref['image_id'])[0]
        cats = refer.loadCats(cat_ids=ref['category_id'])

        print('bbox for batch[{}]:'.format(n), bbox)
        print('sentences for batch[{}]:'.format(n))
        for sid, sent in enumerate(sentences):
            print('%s. %s' % (sid+1, sent['sent']))
        print('cats for batch[{}]:'.format(n), cats)

        image_path = osp.join(refer.IMAGE_DIR, image_info['file_name'])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('cannot read image: %s' % image_path)
        print(image.shape)

        top_left = (int(bbox[0]), int(bbox[1]))
        bottom_right = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
        cv2.rectangle(image, top_left, bottom_right, 255, 3)
        # Caption with the last sentence (what the original's leaked loop
        # variable held), or nothing when the reference has no sentences.
        caption = str(sentences[-1]['sent']) if sentences else ''
        cv2.putText(image, caption, (20, 20), cv2.FONT_HERSHEY_SIMPLEX, .9,(0,255,0), 2)

        cv2.imwrite('./debug_vis/'+image_info['file_name'], image)
        cv2.imwrite('./debug_vis/mask'+image_info['file_name'], refer.getMask(ref)['mask']*255)


def cat_process(cat):
    """Shift a category id down by the offset of the id range it falls in.

    Ids in 1-11 map to 0-10, 13-25 to 11-23, and so on up to 84-90 mapping to
    73-79, so the listed ranges together produce the contiguous indices 0-79.
    An id outside every listed range (e.g. 12 or -1) is returned unchanged.
    """
    for first, last, shift in _CATEGORY_ID_SHIFTS:
        if first <= cat <= last:
            return cat - shift
    return cat


def bbox_process(bbox,cat,segement_id):
    """Format one box as the `` x_min,y_min,x_max,y_max,cat,id`` annotation field.

    Args:
        bbox: ``[x, y, width, height]``; every component is truncated to int
            before the corner coordinates are computed.
        cat: category value written as the fifth number.
        segement_id: id written as the sixth number (the caller passes the
            reference id, which is also the mask file name).

    Returns:
        The field as a string with a single leading space.
    """
    x_min = int(bbox[0])
    y_min = int(bbox[1])
    x_max = x_min + int(bbox[2])
    y_max = y_min + int(bbox[3])
    return " %d,%d,%d,%d,%d,%d" % (x_min, y_min, x_max, y_max, int(cat), int(segement_id))


def prepare_dataset(dataset,splits,output_dir,generate_mask=False):
    """Write the annotation file (and optionally the masks) for every split.

    Args:
        dataset: dataset name; ``refcocog`` is stored under
            ``refcocog_<args.split>`` so different annotation variants do not
            overwrite each other.
        splits: names of the data splits to export.
        output_dir: root directory that receives ``anns/`` and ``masks/``.
        generate_mask: when true, save each reference mask as ``<ref id>.npy``.
    """
    if dataset == 'refcocog':
        dataset = 'refcocog_' + args.split
    ann_dir = os.path.join(output_dir, 'anns', dataset)
    mask_dir = os.path.join(output_dir, 'masks', dataset)
    os.makedirs(ann_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    # Images that are skipped when exporting refclef.
    skipped_refclef_images = ('19579.jpg', '17975.jpg', '19575.jpg')

    for split in splits:
        ref_ids = refer.getRefIds(split=split)
        print(split+'_size:',len(ref_ids))
        split_num = 0  # records written for this split
        ll = 0         # longest sentence seen, in whitespace-separated words

        with open(os.path.join(ann_dir, split + '.txt'), 'w', encoding='utf-8') as f:
            for ref_id in ref_ids:
                ref = refer.Refs[ref_id]
                bboxs = refer.getRefBox(ref_id)
                print("bboxs", bboxs)
                sentences = ref['sentences']
                file_name = refer.loadImgs(image_ids=ref['image_id'])[0]['file_name']

                # category_id is a list in grefcoco.  Build a new list instead
                # of remapping in place so refer.Refs is left untouched.
                cat = ref['category_id']
                if isinstance(cat, list):
                    cat = [cat_process(c) for c in cat]
                else:
                    cat = cat_process(cat)

                if dataset == 'refclef' and file_name in skipped_refclef_images:
                    continue

                # Box and category info are not used for RES; only the first
                # box/category of a multi-target reference is written.
                box = bboxs[0] if isinstance(bboxs[0], list) else bboxs
                first_cat = cat[0] if isinstance(cat, list) else cat
                f.write(file_name)
                f.write(bbox_process(box, first_cat, ref_id))  # ref id doubles as segment id

                if generate_mask:
                    mask_path = os.path.join(mask_dir, str(ref_id) + '.npy')
                    if dataset == 'grefcoco':
                        np.save(mask_path, refer.getMaskByRef(ref, merge=True)['mask'])
                    else:
                        np.save(mask_path, refer.getMask(ref)['mask'])

                for sentence in sentences:
                    f.write(' ~ ')
                    f.write(sentence['sent'])
                    # NOTE(review): the source text after "if ll<" was lost;
                    # tracking the word count is a reconstruction -- confirm.
                    ll = max(ll, len(sentence['sent'].split()))
                # NOTE(review): record terminator and counters reconstructed
                # from the surviving ``split_num``/``ll`` variables -- confirm.
                f.write('\n')
                split_num += 1

        print('%s: %d records written, longest sentence has %d words' % (split, split_num, ll))


def sentence_length_histogram(ann_path, max_words=50):
    """Count how many sentences of each word length an annotation file holds.

    NOTE(review): the original statistics code was only partially
    recoverable (it kept a ``word_l_count`` table of 50 entries and printed
    the non-empty ones); this helper is a reconstruction of that intent and
    is not called automatically.

    Args:
        ann_path: annotation file written by ``prepare_dataset``.
        max_words: number of histogram buckets; longer sentences are counted
            in the last bucket.

    Returns:
        Integer numpy array where entry ``i`` is the number of sentences with
        ``i`` words.
    """
    word_l_count = np.zeros([max_words], dtype=int)
    with open(ann_path, encoding='utf-8') as f:
        for line in f:
            # The first field holds the image name and box, the rest are sentences.
            for sent in line.rstrip('\n').split(' ~ ')[1:]:
                word_l_count[min(len(sent.split()), max_words - 1)] += 1
    for i in range(max_words):
        if word_l_count[i]>0:
            print('length:%d'%i,',count:%d'%word_l_count[i])
    return word_l_count


if __name__ == '__main__':
    # Imported here so the helpers above stay importable without the REFER API.
    from refer import REFER

    args = parser.parse_args()
    refer = REFER(args.data_root, args.dataset, args.split)
    print ('dataset [%s_%s] contains: ' % (args.dataset, args.split))
    ref_ids = refer.getRefIds()
    image_ids = refer.getImgIds()
    print ('%s expressions for %s refs in %s images.' % (len(refer.Sents), len(ref_ids), len(image_ids)))

    print('\nAmong them:')
    splits = get_splits(args.dataset, args.split)
    for split in splits:
        ref_ids = refer.getRefIds(split=split)
        print('%s refs are in split [%s].' % (len(ref_ids), split))

    prepare_dataset(args.dataset,splits,args.output_dir,args.generate_mask)