|
|
|
|
|
|
|
|
|
|
| import json
|
| import os
|
| import random
|
| from enum import Enum
|
| from typing import Callable, Dict, List, Optional, Union
|
|
|
| from .decoders import ImageDataDecoder, TargetDecoder
|
| from .extended import ExtendedVisionDataset
|
|
|
|
|
|
|
|
|
| class _Split(Enum):
|
| TRAIN = "train"
|
| VAL = "val"
|
|
|
|
|
| def read_images_and_captions(root: str, split: _Split) -> List[Dict]:
|
| image_dir = None
|
| if _Split(split) == _Split.TRAIN:
|
| annotations_full_path = os.path.join(
|
| root, "annotations_trainval2014/annotations/captions_train2014.json"
|
| )
|
| image_dir = os.path.join(root, "train2014/train2014")
|
| else:
|
| annotations_full_path = os.path.join(
|
| root, "annotations_trainval2017/annotations/captions_train2017.json"
|
| )
|
| image_dir = os.path.join(root, "val2017/val2017")
|
| with open(annotations_full_path) as f:
|
| all_annotations = json.load(f)
|
| data = {}
|
| for item in all_annotations["images"]:
|
| id = item["id"]
|
| data[id] = {
|
| "id": None,
|
| "image": os.path.join(image_dir, item["file_name"]),
|
| "captions": [],
|
| }
|
| for item in all_annotations["annotations"]:
|
| data[item["image_id"]]["id"] = item["image_id"]
|
| data[item["image_id"]]["captions"].append(item["caption"])
|
| return list(data.values())
|
|
|
|
|
| class CocoCaptions(ExtendedVisionDataset):
|
| Split = Union[_Split]
|
|
|
| def __init__(
|
| self,
|
| *,
|
| split: "CocoCaptions.Split",
|
| root: Optional[str] = None,
|
| transforms: Optional[Callable] = None,
|
| transform: Optional[Callable] = None,
|
| target_transform: Optional[Callable] = None,
|
| ) -> None:
|
| super().__init__(
|
| root=root,
|
| transforms=transforms,
|
| transform=transform,
|
| target_transform=target_transform,
|
| image_decoder=ImageDataDecoder,
|
| target_decoder=TargetDecoder,
|
| )
|
|
|
| self.image_captions = read_images_and_captions(root, split)
|
|
|
| def get_image_relpath(self, index: int) -> str:
|
| image_path = self.image_captions[index]["image"]
|
| return image_path
|
|
|
| def get_image_data(self, index: int) -> bytes:
|
| image_path = self.get_image_relpath(index)
|
| with open(image_path, mode="rb") as f:
|
| image_data = f.read()
|
| return image_data
|
|
|
| def get_target(self, index: int) -> str:
|
| return random.choice(self.image_captions[index]["captions"])
|
|
|
| def __len__(self) -> int:
|
| return len(self.image_captions)
|
|
|