import json
import logging
import os.path
import sys
from collections import defaultdict

import contexttimer
import pandas as pd
# Number of records taken from the end of the record list for the validation
# split; everything before them goes to the training split.
VALID_SIZE = 3_001


def group_captions(df):
    """Group captions by image name.

    Every row of *df* must have exactly two columns, in the order
    (caption, image name); the column labels themselves are not used.

    Returns a dict mapping each image name to the list of its captions,
    in row order.
    """
    grouped = defaultdict(list)
    for caption, image_name in df.itertuples(index=False, name=None):
        grouped[image_name].append(caption)
    return dict(grouped)


def build_json_lines(images_dict, images_dir):
    """Serialize one JSON record per image whose file exists on disk.

    *images_dict* maps image names to caption lists. Each image name is
    appended to *images_dir* (joined with "/"); images whose file is missing
    are reported on stdout and skipped.

    Returns a list of JSON strings with the keys "image_path" and "captions".
    """
    lines = []
    for image_name, captions in images_dict.items():
        # Build the path from the current key. The original reused a stale
        # variable from the grouping loop, so every record got the same path.
        full_image_path = images_dir + "/" + image_name
        if os.path.isfile(full_image_path):
            lines.append(
                json.dumps({"image_path": full_image_path, "captions": captions})
            )
        else:
            print(f"{full_image_path} doesn't exist")
    return lines


def split_train_valid(lines, valid_size=VALID_SIZE):
    """Split *lines* into (train, valid), holding out the last *valid_size*.

    When there are no more than *valid_size* lines, the training part is
    empty and every line lands in the validation part.
    """
    # An explicit cut index (rather than a negative slice) also behaves
    # correctly for valid_size == 0.
    cut = max(len(lines) - valid_size, 0)
    return lines[:cut], lines[cut:]


def write_lines(path, lines):
    """Write *lines* to *path*, newline-separated, without a trailing newline."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def main(argv=None):
    """Convert a caption TSV into train/validation JSON-lines files.

    *argv* defaults to ``sys.argv`` and must hold the script name followed by
    the annotation TSV path, the images directory and the output prefix.
    Writes ``<prefix>_train.json`` and ``<prefix>_val.json``. Exits with
    status 1 when the argument count is wrong.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) != 4:
        print(
            "Provide .tsv file name, images dir, output file name prefix. "
            f"e.g. python {argv[0] if argv else 'script.py'} "
            "captions.tsv /path/to/images dataset "
            "(writes dataset_train.json and dataset_val.json)",
            file=sys.stderr,
        )
        sys.exit(1)
    annotation_file, images_dir, output_file = argv[1:]

    # Without a configured handler/level the INFO messages below are dropped.
    logging.basicConfig(level=logging.INFO)
    logging.info("Processing Flicker 30k dataset")

    with contexttimer.Timer(prefix="Loading from tsv"):
        df = pd.read_csv(annotation_file, delimiter="\t")

    lines = build_json_lines(group_captions(df), images_dir)
    train_lines, valid_lines = split_train_valid(lines)
    if not train_lines:
        logging.warning(
            "Only %d records found; the training split is empty.", len(lines)
        )

    write_lines(output_file + "_train.json", train_lines)
    write_lines(output_file + "_val.json", valid_lines)
    logging.info(
        "Processing Flicker 30k dataset done. %d images processed.", len(lines)
    )


if __name__ == "__main__":
    main()