| |
| import os |
| import json |
| import argparse |
| from PIL import Image |
| import numpy as np |
|
|
| if __name__ == '__main__': |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--ann', default='datasets/cc3m/Train_GCC-training.tsv') |
| parser.add_argument('--save_image_path', default='datasets/cc3m/training/') |
| parser.add_argument('--cat_info', default='datasets/lvis/lvis_v1_val.json') |
| parser.add_argument('--out_path', default='datasets/cc3m/train_image_info.json') |
| parser.add_argument('--not_download_image', action='store_true') |
| args = parser.parse_args() |
| categories = json.load(open(args.cat_info, 'r'))['categories'] |
| images = [] |
| if not os.path.exists(args.save_image_path): |
| os.makedirs(args.save_image_path) |
| f = open(args.ann) |
| for i, line in enumerate(f): |
| cap, path = line[:-1].split('\t') |
| print(i, cap, path) |
| if not args.not_download_image: |
| os.system( |
| 'wget {} -O {}/{}.jpg'.format( |
| path, args.save_image_path, i + 1)) |
| try: |
| img = Image.open( |
| open('{}/{}.jpg'.format(args.save_image_path, i + 1), "rb")) |
| img = np.asarray(img.convert("RGB")) |
| h, w = img.shape[:2] |
| except: |
| continue |
| image_info = { |
| 'id': i + 1, |
| 'file_name': '{}.jpg'.format(i + 1), |
| 'height': h, |
| 'width': w, |
| 'captions': [cap], |
| } |
| images.append(image_info) |
| data = {'categories': categories, 'images': images, 'annotations': []} |
| for k, v in data.items(): |
| print(k, len(v)) |
| print('Saving to', args.out_path) |
| json.dump(data, open(args.out_path, 'w')) |
|
|