|
|
import json |
|
|
import os |
|
|
from collections import Counter |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision.datasets import ImageFolder |
|
|
|
|
|
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL |
|
|
|
|
|
|
|
|
class CaptionDataset(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
image_train_dir_path, |
|
|
annotations_path, |
|
|
is_train, |
|
|
dataset_name, |
|
|
image_val_dir_path=None, |
|
|
which_gt=None, |
|
|
best_gt_caption_path=None, |
|
|
): |
|
|
self.image_train_dir_path = image_train_dir_path |
|
|
self.image_val_dir_path = image_val_dir_path |
|
|
self.annotations = [] |
|
|
self.is_train = is_train |
|
|
self.dataset_name = dataset_name |
|
|
|
|
|
full_annotations = json.load(open(annotations_path))["images"] |
|
|
|
|
|
for i in range(len(full_annotations)): |
|
|
if self.is_train and full_annotations[i]["split"] != "train": |
|
|
continue |
|
|
elif not self.is_train and full_annotations[i]["split"] != "test": |
|
|
continue |
|
|
|
|
|
self.annotations.append(full_annotations[i]) |
|
|
|
|
|
if isinstance(which_gt, str): |
|
|
self.which_gt = int(which_gt) if which_gt.isdigit() else which_gt |
|
|
else: |
|
|
self.which_gt = which_gt |
|
|
|
|
|
if best_gt_caption_path is not None: |
|
|
with open(best_gt_caption_path, 'r') as f: |
|
|
self.best_gt_captions = json.load(f) |
|
|
else: |
|
|
self.best_gt_captions = None |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.annotations) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
if self.dataset_name == "coco": |
|
|
image = Image.open( |
|
|
os.path.join( |
|
|
self.image_train_dir_path, self.annotations[idx]["filename"] |
|
|
) |
|
|
if self.annotations[idx]["filepath"] == "train2014" |
|
|
else os.path.join( |
|
|
self.image_val_dir_path, self.annotations[idx]["filename"] |
|
|
) |
|
|
) |
|
|
elif self.dataset_name == "flickr": |
|
|
image = Image.open( |
|
|
os.path.join( |
|
|
self.image_train_dir_path, self.annotations[idx]["filename"] |
|
|
) |
|
|
) |
|
|
image.load() |
|
|
|
|
|
image_id = self.annotations[idx]["cocoid"] if self.dataset_name == "coco" else self.annotations[idx]["filename"].split(".")[0] |
|
|
|
|
|
if isinstance(self.which_gt, int): |
|
|
cpt_idx = self.which_gt |
|
|
elif isinstance(self.which_gt, dict): |
|
|
cpt_idx = self.which_gt[image_id] |
|
|
elif self.which_gt == "best": |
|
|
cpt_idx = self.best_gt_captions[str(image_id)] |
|
|
else: |
|
|
assert self.which_gt is None |
|
|
cpt_idx = 0 |
|
|
|
|
|
caption = self.annotations[idx]["sentences"][cpt_idx]["raw"] |
|
|
return { |
|
|
"image": image, |
|
|
"caption": caption, |
|
|
"image_id": image_id, |
|
|
} |
|
|
|
|
|
|
|
|
class VQADataset(Dataset): |
|
|
def __init__( |
|
|
self, image_dir_path, question_path, annotations_path, is_train, dataset_name, which_gt='all', is_tensor=False |
|
|
): |
|
|
self.questions = json.load(open(question_path, "r"))["questions"] |
|
|
if annotations_path is not None: |
|
|
self.answers = json.load(open(annotations_path, "r"))["annotations"] |
|
|
else: |
|
|
self.answers = None |
|
|
self.image_dir_path = image_dir_path |
|
|
self.is_train = is_train |
|
|
self.dataset_name = dataset_name |
|
|
if self.dataset_name in {"vqav2", "ok_vqa"}: |
|
|
self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1] |
|
|
assert self.img_coco_split in {"train2014", "val2014", "test2015"} |
|
|
self.which_gt = which_gt |
|
|
self.is_tensor = is_tensor |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.questions) |
|
|
|
|
|
def get_img_path(self, question): |
|
|
if self.dataset_name in {"vqav2", "ok_vqa"}: |
|
|
return os.path.join( |
|
|
self.image_dir_path, |
|
|
f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg" |
|
|
if self.is_train |
|
|
else f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg", |
|
|
) |
|
|
elif self.dataset_name == "vizwiz": |
|
|
return os.path.join(self.image_dir_path, question["image_id"]) |
|
|
elif self.dataset_name == "textvqa": |
|
|
return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg") |
|
|
else: |
|
|
raise Exception(f"Unknown VQA dataset {self.dataset_name}") |
|
|
|
|
|
def get_from_id(self, question_id): |
|
|
assert not self.is_train |
|
|
assert self.dataset_name == "textvqa" |
|
|
prefix = '' |
|
|
image_path = f"{self.image_dir_path}/{prefix}{str(question_id).zfill(12)}.pt" |
|
|
image = torch.load(image_path) |
|
|
return image |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
question = self.questions[idx] |
|
|
img_path = self.get_img_path(question) |
|
|
if self.is_tensor: |
|
|
image_path = img_path.replace("jpg", "pt") |
|
|
image = torch.load(image_path) |
|
|
else: |
|
|
image = Image.open(img_path) |
|
|
image.load() |
|
|
results = { |
|
|
"image": image, |
|
|
"question": question["question"], |
|
|
"question_id": question["question_id"], |
|
|
} |
|
|
if self.answers is not None: |
|
|
answers = self.answers[idx] |
|
|
answers = [a["answer"] for a in answers["answers"]] |
|
|
if self.which_gt in ["all", None]: |
|
|
results["answers"] = answers |
|
|
elif isinstance(self.which_gt, int) or isinstance(self.which_gt, dict): |
|
|
which_gt = self.which_gt[question["question_id"]] if isinstance(self.which_gt, dict) else self.which_gt |
|
|
|
|
|
counter = Counter(answers) |
|
|
most_common = counter.most_common() |
|
|
if which_gt >= len(most_common): |
|
|
results["answers"] = [] |
|
|
else: |
|
|
results["answers"] = [most_common[which_gt][0]] |
|
|
else: |
|
|
raise ValueError(f"Unknown which_gt: {self.which_gt}") |
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
class ImageNetDataset(ImageFolder): |
|
|
"""Class to represent the ImageNet1k dataset.""" |
|
|
|
|
|
def __init__(self, root, **kwargs): |
|
|
super().__init__(root=root, **kwargs) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
sample, target = super().__getitem__(idx) |
|
|
target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target] |
|
|
return { |
|
|
"id": idx, |
|
|
"image": sample, |
|
|
"class_id": target, |
|
|
"class_name": target_label, |
|
|
} |
|
|
|
|
|
|
|
|
class HatefulMemesDataset(Dataset): |
|
|
def __init__(self, image_dir_path, annotations_path): |
|
|
self.image_dir_path = image_dir_path |
|
|
with open(annotations_path, "r") as f: |
|
|
self.annotations = [json.loads(line) for line in f] |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.annotations) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
annotation = self.annotations[idx] |
|
|
img_path = os.path.join(self.image_dir_path, annotation["img"].split("/")[-1]) |
|
|
image = Image.open(img_path) |
|
|
image.load() |
|
|
return { |
|
|
"id": annotation["id"], |
|
|
"image": image, |
|
|
"ocr": annotation["text"], |
|
|
"class_name": "yes" if annotation["label"] == 1 else "no", |
|
|
"class_id": annotation["label"], |
|
|
} |
|
|
|
|
|
|
|
|
class TensorCaptionDataset(CaptionDataset): |
|
|
def get_from_id(self, image_id): |
|
|
assert self.dataset_name == "coco" |
|
|
assert not self.is_train |
|
|
|
|
|
prefix = '' |
|
|
image_path = f"{self.image_val_dir_path}/{prefix}{str(image_id).zfill(12)}.pt" |
|
|
image = torch.load(image_path) |
|
|
return image |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
if self.dataset_name == "coco": |
|
|
image_path = os.path.join( |
|
|
self.image_train_dir_path if self.annotations[idx]["filepath"] == "train2014" else self.image_val_dir_path, |
|
|
self.annotations[idx]["filename"] |
|
|
) |
|
|
image_path = image_path.replace("jpg", "pt") |
|
|
image = torch.load(image_path) |
|
|
elif self.dataset_name == "flickr": |
|
|
raise NotImplementedError |
|
|
image = Image.open( |
|
|
os.path.join( |
|
|
self.image_train_dir_path, self.annotations[idx]["filename"] |
|
|
) |
|
|
) |
|
|
caption = self.annotations[idx]["sentences"][0]["raw"] |
|
|
return { |
|
|
"image": image, |
|
|
"caption": caption, |
|
|
"image_id": self.annotations[idx]["cocoid"] |
|
|
if self.dataset_name == "coco" |
|
|
else self.annotations[idx]["filename"].split(".")[0], |
|
|
} |