| """ | |
| Cache CLIP features for all images in training split in preparation for RICES | |
| """ | |
import argparse
import os
import sys

import torch

# Make the parent directory importable so the eval package below resolves.
sys.path.append(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
    )
)

from eval.rices import RICES
from eval.eval_datasets import (
    CaptionDataset,
    VQADataset,
    ImageNetDataset,
    HatefulMemesDataset,
)
parser = argparse.ArgumentParser()
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory to save the cached features.",
)
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
parser.add_argument("--batch_size", type=int, default=256)
# Per-dataset flags
parser.add_argument(
    "--eval_coco",
    action="store_true",
    default=False,
    help="Whether to cache COCO.",
)
parser.add_argument(
    "--eval_vqav2",
    action="store_true",
    default=False,
    help="Whether to cache VQAv2.",
)
parser.add_argument(
    "--eval_ok_vqa",
    action="store_true",
    default=False,
    help="Whether to cache OK-VQA.",
)
parser.add_argument(
    "--eval_vizwiz",
    action="store_true",
    default=False,
    help="Whether to cache VizWiz.",
)
parser.add_argument(
    "--eval_textvqa",
    action="store_true",
    default=False,
    help="Whether to cache TextVQA.",
)
parser.add_argument(
    "--eval_imagenet",
    action="store_true",
    default=False,
    help="Whether to cache ImageNet.",
)
parser.add_argument(
    "--eval_flickr30",
    action="store_true",
    default=False,
    help="Whether to cache Flickr30k.",
)
parser.add_argument(
    "--eval_hateful_memes",
    action="store_true",
    default=False,
    help="Whether to cache Hateful Memes.",
)
# Dataset arguments

## Flickr30k Dataset
parser.add_argument(
    "--flickr_image_dir_path",
    type=str,
    help="Path to the flickr30/flickr30k_images directory.",
    default=None,
)
parser.add_argument(
    "--flickr_karpathy_json_path",
    type=str,
    help="Path to the dataset_flickr30k.json file.",
    default=None,
)
parser.add_argument(
    "--flickr_annotations_json_path",
    type=str,
    help="Path to the dataset_flickr30k_coco_style.json file.",
    default=None,
)

## COCO Dataset
parser.add_argument(
    "--coco_train_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--coco_val_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--coco_karpathy_json_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--coco_annotations_json_path",
    type=str,
    default=None,
)

## VQAv2 Dataset
parser.add_argument(
    "--vqav2_train_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--vqav2_train_questions_json_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--vqav2_train_annotations_json_path",
    type=str,
    default=None,
)

## OK-VQA Dataset
parser.add_argument(
    "--ok_vqa_train_image_dir_path",
    type=str,
    help="Path to the vqav2/train2014 directory.",
    default=None,
)
parser.add_argument(
    "--ok_vqa_train_questions_json_path",
    type=str,
    help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
    default=None,
)
parser.add_argument(
    "--ok_vqa_train_annotations_json_path",
    type=str,
    help="Path to the v2_mscoco_train2014_annotations.json file.",
    default=None,
)

## VizWiz Dataset
parser.add_argument(
    "--vizwiz_train_image_dir_path",
    type=str,
    help="Path to the VizWiz train images directory.",
    default=None,
)
parser.add_argument(
    "--vizwiz_train_questions_json_path",
    type=str,
    help="Path to the VizWiz questions json file.",
    default=None,
)
parser.add_argument(
    "--vizwiz_train_annotations_json_path",
    type=str,
    help="Path to the VizWiz annotations json file.",
    default=None,
)

## TextVQA Dataset
parser.add_argument(
    "--textvqa_image_dir_path",
    type=str,
    help="Path to the TextVQA images directory.",
    default=None,
)
parser.add_argument(
    "--textvqa_train_questions_json_path",
    type=str,
    help="Path to the TextVQA questions json file.",
    default=None,
)
parser.add_argument(
    "--textvqa_train_annotations_json_path",
    type=str,
    help="Path to the TextVQA annotations json file.",
    default=None,
)

## ImageNet Dataset
parser.add_argument("--imagenet_root", type=str, default="/tmp")

## Hateful Memes Dataset
parser.add_argument(
    "--hateful_memes_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--hateful_memes_train_annotations_json_path",
    type=str,
    default=None,
)
def main():
    # parse_known_args ignores unrecognized flags instead of raising an error.
    args, leftovers = parser.parse_known_args()
    device_id = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
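    # Ensure the output directory exists; torch.save does not create missing
    # directories.
    os.makedirs(args.output_dir, exist_ok=True)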
    if args.eval_flickr30:
        print("Caching Flickr30k...")
        train_dataset = CaptionDataset(
            image_train_dir_path=args.flickr_image_dir_path,
            image_val_dir_path=None,
            annotations_path=args.flickr_karpathy_json_path,
            is_train=True,
            dataset_name="flickr",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "flickr30.pkl"),
        )

    if args.eval_coco:
        print("Caching COCO...")
        train_dataset = CaptionDataset(
            image_train_dir_path=args.coco_train_image_dir_path,
            image_val_dir_path=args.coco_val_image_dir_path,
            annotations_path=args.coco_karpathy_json_path,
            is_train=True,
            dataset_name="coco",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "coco.pkl"),
        )
    if args.eval_ok_vqa:
        print("Caching OK-VQA...")
        train_dataset = VQADataset(
            image_dir_path=args.ok_vqa_train_image_dir_path,
            question_path=args.ok_vqa_train_questions_json_path,
            annotations_path=args.ok_vqa_train_annotations_json_path,
            is_train=True,
            dataset_name="ok_vqa",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "ok_vqa.pkl"),
        )

    if args.eval_vizwiz:
        print("Caching VizWiz...")
        train_dataset = VQADataset(
            image_dir_path=args.vizwiz_train_image_dir_path,
            question_path=args.vizwiz_train_questions_json_path,
            annotations_path=args.vizwiz_train_annotations_json_path,
            is_train=True,
            dataset_name="vizwiz",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "vizwiz.pkl"),
        )

    if args.eval_vqav2:
        print("Caching VQAv2...")
        train_dataset = VQADataset(
            image_dir_path=args.vqav2_train_image_dir_path,
            question_path=args.vqav2_train_questions_json_path,
            annotations_path=args.vqav2_train_annotations_json_path,
            is_train=True,
            dataset_name="vqav2",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "vqav2.pkl"),
        )
    if args.eval_textvqa:
        print("Caching TextVQA...")
        train_dataset = VQADataset(
            image_dir_path=args.textvqa_image_dir_path,
            question_path=args.textvqa_train_questions_json_path,
            annotations_path=args.textvqa_train_annotations_json_path,
            is_train=True,
            dataset_name="textvqa",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "textvqa.pkl"),
        )

    if args.eval_hateful_memes:
        print("Caching Hateful Memes...")
        train_dataset = HatefulMemesDataset(
            image_dir_path=args.hateful_memes_image_dir_path,
            annotations_path=args.hateful_memes_train_annotations_json_path,
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "hateful_memes.pkl"),
        )
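
    # Sketch: --eval_imagenet is defined above but no handler followed. This
    # block mirrors the other datasets and assumes ImageNetDataset takes the
    # path to the ImageNet train split as its root argument.
    if args.eval_imagenet:
        print("Caching ImageNet...")
        train_dataset = ImageNetDataset(
            root=os.path.join(args.imagenet_root, "train"),
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "imagenet.pkl"),
        )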

if __name__ == "__main__":
    main()