"""Convert an image-caption annotation file into JSON Lines train/val splits.

Usage:
    python coco.py ANNOTATIONS_JSON IMAGES_DIR OUTPUT_PREFIX

ANNOTATIONS_JSON is a JSON object mapping image file names (relative to
IMAGES_DIR) to lists of caption strings. The "<start> " and " <end>" markers
are removed from every caption, and captions left empty are dropped. Each
image that exists under IMAGES_DIR and still has at least one caption becomes
one JSON object per line: {"image_path": ..., "captions": [...]}.

The last VAL_SIZE images (in annotation-file order) are written to
OUTPUT_PREFIX_val.json, and the rest to OUTPUT_PREFIX_train.json.
"""
import json
import logging
import os.path
import sys

# Number of images, taken from the end of the processed list, that form the
# validation split.
VAL_SIZE = 801

USAGE = (
    "Provide annotations .json file name, images dir, output file prefix. "
    "e.g. python coco.py coco_captions_train2017.json /path/to/images coco_dataset"
)

logger = logging.getLogger(__name__)


def _clean_captions(captions):
    """Strip the <start>/<end> markers from each caption; drop empty results."""
    stripped = (c.replace("<start> ", "").replace(" <end>", "") for c in captions)
    # Filter *after* stripping so a caption consisting only of markers is dropped.
    return [c for c in stripped if c]


def _write_lines(path, lines):
    """Write *lines* to *path* joined by newlines (no trailing newline)."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def main(argv=None, val_size=VAL_SIZE):
    """Run the conversion.

    Args:
        argv: Command-line arguments including the program name; defaults to
            ``sys.argv``. Must contain exactly three arguments after the name:
            annotation file, images directory, output file prefix.
        val_size: Number of trailing images placed in the validation split.

    Returns:
        Process exit status: 0 on success, 1 on a usage error.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 4:
        print(USAGE, file=sys.stderr)
        return 1
    annotation_file, images_dir, output_file = argv[1:4]

    logger.info("Processing Flicker 8k dataset")
    with open(annotation_file, "r", encoding="utf-8") as f:
        annotations = json.load(f)

    lines = []
    for image_path, captions in annotations.items():
        edited_captions = _clean_captions(captions)
        full_image_path = os.path.join(images_dir, image_path)
        if not os.path.isfile(full_image_path):
            print(f"{full_image_path} doesn't exist")
            continue
        if edited_captions:
            lines.append(
                json.dumps({"image_path": full_image_path, "captions": edited_captions})
            )

    # Index-based split: unlike lines[:-val_size] / lines[-val_size:], this is
    # correct for val_size == 0, and yields an empty train split (rather than
    # misbehaving) when there are fewer images than val_size.
    split = max(len(lines) - max(val_size, 0), 0)
    train_lines, valid_lines = lines[:split], lines[split:]
    if lines and not train_lines:
        logger.warning(
            "Train split is empty: only %d images processed, validation size is %d",
            len(lines),
            val_size,
        )

    _write_lines(output_file + "_train.json", train_lines)
    _write_lines(output_file + "_val.json", valid_lines)
    logger.info("Processing Flicker 8k dataset done. %d images processed.", len(lines))
    return 0


if __name__ == "__main__":
    # Without configuration, logging's default WARNING level hides INFO messages.
    logging.basicConfig(level=logging.INFO)
    sys.exit(main())