# roboflamingo-demo/open_flamingo/scripts/cache_rices_features.py
"""
Cache CLIP features for all images in training split in preparation for RICES
"""
import argparse
import os
import sys

# Make the parent directory (open_flamingo/) importable so that the `eval`
# package resolves when this script is run directly.
sys.path.append(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
    )
)

import torch

from eval.rices import RICES
from eval.eval_datasets import (
    CaptionDataset,
    VQADataset,
    ImageNetDataset,
    HatefulMemesDataset,
)
parser = argparse.ArgumentParser()
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory to save the cached features.",
)
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
# Without an explicit type, a value passed on the command line would be
# parsed as a string; the encoder expects an integer batch size.
parser.add_argument("--batch_size", type=int, default=256)
# Per-dataset flags
parser.add_argument(
    "--eval_coco",
    action="store_true",
    default=False,
    help="Whether to cache COCO.",
)
parser.add_argument(
    "--eval_vqav2",
    action="store_true",
    default=False,
    help="Whether to cache VQAv2.",
)
parser.add_argument(
    "--eval_ok_vqa",
    action="store_true",
    default=False,
    help="Whether to cache OK-VQA.",
)
parser.add_argument(
    "--eval_vizwiz",
    action="store_true",
    default=False,
    help="Whether to cache VizWiz.",
)
parser.add_argument(
    "--eval_textvqa",
    action="store_true",
    default=False,
    help="Whether to cache TextVQA.",
)
parser.add_argument(
    "--eval_imagenet",
    action="store_true",
    default=False,
    help="Whether to cache ImageNet.",
)
parser.add_argument(
    "--eval_flickr30",
    action="store_true",
    default=False,
    help="Whether to cache Flickr30k.",
)
parser.add_argument(
    "--eval_hateful_memes",
    action="store_true",
    default=False,
    help="Whether to cache Hateful Memes.",
)
# Dataset arguments

## Flickr30k Dataset
parser.add_argument(
    "--flickr_image_dir_path",
    type=str,
    help="Path to the flickr30/flickr30k_images directory.",
    default=None,
)
parser.add_argument(
    "--flickr_karpathy_json_path",
    type=str,
    help="Path to the dataset_flickr30k.json file.",
    default=None,
)
parser.add_argument(
    "--flickr_annotations_json_path",
    type=str,
    help="Path to the dataset_flickr30k_coco_style.json file.",
    default=None,
)
## COCO Dataset
parser.add_argument(
    "--coco_train_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--coco_val_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--coco_karpathy_json_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--coco_annotations_json_path",
    type=str,
    default=None,
)
## VQAv2 Dataset
parser.add_argument(
    "--vqav2_train_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--vqav2_train_questions_json_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--vqav2_train_annotations_json_path",
    type=str,
    default=None,
)
## OK-VQA Dataset
parser.add_argument(
    "--ok_vqa_train_image_dir_path",
    type=str,
    help="Path to the vqav2/train2014 directory.",
    default=None,
)
parser.add_argument(
    "--ok_vqa_train_questions_json_path",
    type=str,
    help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
    default=None,
)
parser.add_argument(
    "--ok_vqa_train_annotations_json_path",
    type=str,
    help="Path to the v2_mscoco_train2014_annotations.json file.",
    default=None,
)
## VizWiz Dataset
parser.add_argument(
    "--vizwiz_train_image_dir_path",
    type=str,
    help="Path to the vizwiz train images directory.",
    default=None,
)
parser.add_argument(
    "--vizwiz_train_questions_json_path",
    type=str,
    help="Path to the vizwiz questions json file.",
    default=None,
)
parser.add_argument(
    "--vizwiz_train_annotations_json_path",
    type=str,
    help="Path to the vizwiz annotations json file.",
    default=None,
)
## TextVQA Dataset
parser.add_argument(
    "--textvqa_image_dir_path",
    type=str,
    help="Path to the textvqa images directory.",
    default=None,
)
parser.add_argument(
    "--textvqa_train_questions_json_path",
    type=str,
    help="Path to the textvqa questions json file.",
    default=None,
)
parser.add_argument(
    "--textvqa_train_annotations_json_path",
    type=str,
    help="Path to the textvqa annotations json file.",
    default=None,
)
## ImageNet Dataset
parser.add_argument("--imagenet_root", type=str, default="/tmp")

## Hateful Memes Dataset
parser.add_argument(
    "--hateful_memes_image_dir_path",
    type=str,
    default=None,
)
parser.add_argument(
    "--hateful_memes_train_annotations_json_path",
    type=str,
    default=None,
)

def main():
    # parse_known_args lets flags shared with other eval scripts pass through
    # without erroring on unrecognized arguments.
    args, leftovers = parser.parse_known_args()
    device_id = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

    # Ensure the output directory exists before any torch.save call below.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.eval_flickr30:
        print("Caching Flickr30k...")
        train_dataset = CaptionDataset(
            image_train_dir_path=args.flickr_image_dir_path,
            image_val_dir_path=None,
            annotations_path=args.flickr_karpathy_json_path,
            is_train=True,
            dataset_name="flickr",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "flickr30.pkl"),
        )
    if args.eval_coco:
        print("Caching COCO...")
        train_dataset = CaptionDataset(
            image_train_dir_path=args.coco_train_image_dir_path,
            image_val_dir_path=args.coco_val_image_dir_path,
            annotations_path=args.coco_karpathy_json_path,
            is_train=True,
            dataset_name="coco",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "coco.pkl"),
        )
    if args.eval_ok_vqa:
        print("Caching OK-VQA...")
        train_dataset = VQADataset(
            image_dir_path=args.ok_vqa_train_image_dir_path,
            question_path=args.ok_vqa_train_questions_json_path,
            annotations_path=args.ok_vqa_train_annotations_json_path,
            is_train=True,
            dataset_name="ok_vqa",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "ok_vqa.pkl"),
        )
    if args.eval_vizwiz:
        print("Caching VizWiz...")
        train_dataset = VQADataset(
            image_dir_path=args.vizwiz_train_image_dir_path,
            question_path=args.vizwiz_train_questions_json_path,
            annotations_path=args.vizwiz_train_annotations_json_path,
            is_train=True,
            dataset_name="vizwiz",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "vizwiz.pkl"),
        )
    if args.eval_vqav2:
        print("Caching VQAv2...")
        train_dataset = VQADataset(
            image_dir_path=args.vqav2_train_image_dir_path,
            question_path=args.vqav2_train_questions_json_path,
            annotations_path=args.vqav2_train_annotations_json_path,
            is_train=True,
            dataset_name="vqav2",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "vqav2.pkl"),
        )
    if args.eval_textvqa:
        print("Caching TextVQA...")
        train_dataset = VQADataset(
            image_dir_path=args.textvqa_image_dir_path,
            question_path=args.textvqa_train_questions_json_path,
            annotations_path=args.textvqa_train_annotations_json_path,
            is_train=True,
            dataset_name="textvqa",
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "textvqa.pkl"),
        )
    if args.eval_hateful_memes:
        print("Caching Hateful Memes...")
        train_dataset = HatefulMemesDataset(
            image_dir_path=args.hateful_memes_image_dir_path,
            annotations_path=args.hateful_memes_train_annotations_json_path,
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "hateful_memes.pkl"),
        )
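    # The --eval_imagenet flag and --imagenet_root argument are defined above
    # but had no corresponding branch; this block mirrors the other datasets
    # and assumes ImageNetDataset wraps torchvision's ImageFolder, taking the
    # split directory as its root.
    if args.eval_imagenet:
        print("Caching ImageNet...")
        train_dataset = ImageNetDataset(
            root=os.path.join(args.imagenet_root, "train"),
        )
        rices_dataset = RICES(
            train_dataset,
            device_id,
            args.batch_size,
            vision_encoder_path=args.vision_encoder_path,
            vision_encoder_pretrained=args.vision_encoder_pretrained,
        )
        torch.save(
            rices_dataset.features,
            os.path.join(args.output_dir, "imagenet.pkl"),
        )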

if __name__ == "__main__":
    main()
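
# Downstream usage (a sketch, not part of this script): the cached features
# can be reloaded and handed to RICES so the encoder pass is skipped. This
# assumes RICES accepts a `cached_features` argument and exposes a `find`
# method, as in the open_flamingo eval code.
#
#   cached_features = torch.load(
#       os.path.join(args.output_dir, "coco.pkl"), map_location="cpu"
#   )
#   rices = RICES(
#       train_dataset,
#       device_id,
#       args.batch_size,
#       cached_features=cached_features,
#   )
#   in_context_examples = rices.find(batch_images, num_examples=4)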