# Source: Prior2DSM/src/dinov3/data/datasets/coco_captions.py
# (uploaded by osherr as part of "Upload 222 files", commit bc90483)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import json
import os
import random
from enum import Enum
from typing import Callable, Dict, List, Optional, Union
from .decoders import ImageDataDecoder, TargetDecoder
from .extended import ExtendedVisionDataset
# Dataset: https://www.kaggle.com/datasets/nikhil7280/coco-image-caption
class _Split(Enum):
    """Dataset split selector for COCO captions.

    The member values are the plain strings accepted in place of the enum,
    since callers construct ``_Split(split)`` from either form.
    """

    TRAIN = "train"
    VAL = "val"
def read_images_and_captions(root: str, split: _Split) -> List[Dict]:
    """Load the COCO caption annotations for ``split`` and group them per image.

    The train split reads ``captions_train2014.json`` with images under
    ``train2014/train2014``. The val split reads ``captions_val2017.json`` with
    images under ``val2017/val2017``. All paths are relative to ``root``.

    Args:
        root: Dataset root directory containing the extracted archives.
        split: A ``_Split`` member or its string value (``"train"`` / ``"val"``).

    Returns:
        One dict per image that has at least one caption, in the order the
        images appear in the annotation file, with keys:
        ``"id"`` (the image id from the annotation file), ``"image"``
        (``image_dir`` joined with the image's ``file_name``) and
        ``"captions"`` (list of caption strings for that image).

    Raises:
        ValueError: if ``split`` is not a valid ``_Split`` value.
    """
    if _Split(split) == _Split.TRAIN:
        annotations_relpath = "annotations_trainval2014/annotations/captions_train2014.json"
        image_dir = os.path.join(root, "train2014/train2014")
    else:
        # The file names listed in this annotation file are resolved against
        # val2017/, so it must be the *val* 2017 captions. The train 2017 file
        # would point at images that are not in that directory and would leak
        # training captions into validation.
        annotations_relpath = "annotations_trainval2017/annotations/captions_val2017.json"
        image_dir = os.path.join(root, "val2017/val2017")
    annotations_full_path = os.path.join(root, annotations_relpath)

    with open(annotations_full_path, encoding="utf-8") as f:
        all_annotations = json.load(f)

    # Index images by id so each caption annotation can be attached in O(1).
    data = {
        image["id"]: {
            "id": image["id"],
            "image": os.path.join(image_dir, image["file_name"]),
            "captions": [],
        }
        for image in all_annotations["images"]
    }
    for annotation in all_annotations["annotations"]:
        data[annotation["image_id"]]["captions"].append(annotation["caption"])

    # Drop images with no captions: sampling a target from an empty list
    # (random.choice) would raise IndexError.
    return [record for record in data.values() if record["captions"]]
class CocoCaptions(ExtendedVisionDataset):
    """COCO captions dataset yielding raw image bytes and one caption per sample.

    The records come from ``read_images_and_captions(root, split)``. The target
    for a sample is drawn with ``random.choice`` from that image's captions, so
    reading the same index twice may return different captions.
    """

    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "CocoCaptions.Split",
        root: Optional[str] = None,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            root=root,
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
            image_decoder=ImageDataDecoder,
            target_decoder=TargetDecoder,
        )
        # List of per-image records with "id", "image" (path) and "captions".
        self.image_captions = read_images_and_captions(root, split)

    def get_image_relpath(self, index: int) -> str:
        """Return the image path stored in the record for sample ``index``.

        NOTE: despite the method name, the stored path is the one built by
        ``read_images_and_captions``, which already joins ``root``.
        """
        record = self.image_captions[index]
        return record["image"]

    def get_image_data(self, index: int) -> bytes:
        """Read and return the undecoded image file contents for sample ``index``."""
        with open(self.get_image_relpath(index), mode="rb") as handle:
            return handle.read()

    def get_target(self, index: int) -> str:
        """Return one caption for sample ``index``, chosen uniformly at random."""
        candidates = self.image_captions[index]["captions"]
        return random.choice(candidates)

    def __len__(self) -> int:
        """Return the number of image records in the dataset."""
        return len(self.image_captions)