|
|
|
|
|
|
|
|
|
|
| import csv
|
| import logging
|
| import os
|
| from enum import Enum
|
| from typing import Callable, List, Optional, Tuple, Union
|
|
|
| import numpy as np
|
|
|
| from .decoders import ImageDataDecoder, TargetDecoder
|
| from .extended import ExtendedVisionDataset
|
|
|
# Module logger; every message in this file is emitted under the "dinov2" logger name.
logger = logging.getLogger("dinov2")
# Type of a classification target: the integer class index of a sample.
_Target = int
|
|
|
|
|
class _Split(Enum):
    """Dataset split of ImageNet, with helpers for the on-disk file layout.

    The enum value is the name of the split's directory below the dataset root.
    """

    TRAIN = "train"
    VAL = "val"
    TEST = "test"

    @property
    def length(self) -> int:
        """Number of samples expected in this split."""
        split_lengths = {
            _Split.TRAIN: 1_281_167,
            _Split.VAL: 50_000,
            _Split.TEST: 100_000,
        }
        return split_lengths[self]

    def get_dirname(self, class_id: Optional[str] = None) -> str:
        """Return the split directory, or the class sub-directory when ``class_id`` is given."""
        if class_id is None:
            return self.value
        return os.path.join(self.value, class_id)

    def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
        """Build the path (relative to the dataset root) of one image file.

        Args:
            actual_index: Numeric index embedded in the image file name.
            class_id: Class identifier; selects the sub-directory and, for the
                TRAIN split, is also the prefix of the file name.

        Returns:
            Relative path ending in ``.JPEG``.
        """
        dirname = self.get_dirname(class_id)
        if self == _Split.TRAIN:
            # TRAIN files are named "<class_id>_<index>.JPEG".
            basename = f"{class_id}_{actual_index}"
        else:
            # Other splits use "ILSVRC2012_<split>_<8-digit index>.JPEG".
            basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
        return os.path.join(dirname, basename + ".JPEG")

    def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
        """Extract ``(class_id, actual_index)`` from an image path.

        The class ID is the name of the directory containing the file; the index
        is the integer after the last underscore of the file name (extension
        removed).

        Raises:
            AssertionError: If called on the TEST split.
        """
        # Raised explicitly (not via ``assert``) so the guard survives ``python -O``.
        if self == _Split.TEST:
            raise AssertionError("Image paths cannot be parsed for the TEST split")
        dirname, filename = os.path.split(image_relpath)
        class_id = os.path.split(dirname)[-1]
        basename, _ = os.path.splitext(filename)
        actual_index = int(basename.split("_")[-1])
        return class_id, actual_index
|
|
|
|
|
class ImageNet(ExtendedVisionDataset):
    """ImageNet dataset that reads per-sample metadata from pre-computed ``.npy`` files.

    Three arrays are kept in the ``extra`` directory, one set per split: the
    per-sample "entries" (a structured array with the fields ``actual_index``,
    ``class_index``, ``class_id`` and ``class_name``) and the per-class ID and
    name tables. They are loaded lazily on first use and can be generated from
    the raw dataset with :meth:`dump_extra`.
    """

    Target = Union[_Target]
    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "ImageNet.Split",
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Create the dataset without touching the metadata files.

        Args:
            split: Split to expose.
            root: Dataset root directory, forwarded to the base class.
            extra: Directory holding (or receiving) the metadata ``.npy`` files.
            transforms: Forwarded to the base class.
            transform: Forwarded to the base class.
            target_transform: Forwarded to the base class.
        """
        super().__init__(
            root=root,
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
            image_decoder=ImageDataDecoder,
            target_decoder=TargetDecoder,
        )
        self._extra_root = extra
        self._split = split

        # Metadata arrays are loaded on first access (see the _get_* methods).
        self._entries = None
        self._class_ids = None
        self._class_names = None

    @property
    def split(self) -> "ImageNet.Split":
        """Split this dataset instance was created for."""
        return self._split

    def _get_extra_full_path(self, extra_path: str) -> str:
        """Resolve a metadata file name against the ``extra`` directory."""
        return os.path.join(self._extra_root, extra_path)

    def _load_extra(self, extra_path: str) -> np.ndarray:
        """Open a metadata array from the ``extra`` directory as a read-only memory map."""
        extra_full_path = self._get_extra_full_path(extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        """Write a metadata array to the ``extra`` directory, creating the directory if needed."""
        extra_full_path = self._get_extra_full_path(extra_path)
        os.makedirs(self._extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    @property
    def _entries_path(self) -> str:
        """File name of the per-sample entries array for the current split."""
        return f"entries-{self._split.value.upper()}.npy"

    @property
    def _class_ids_path(self) -> str:
        """File name of the class-ID table for the current split."""
        return f"class-ids-{self._split.value.upper()}.npy"

    @property
    def _class_names_path(self) -> str:
        """File name of the class-name table for the current split."""
        return f"class-names-{self._split.value.upper()}.npy"

    def _get_entries(self) -> np.ndarray:
        """Return the per-sample entries array, loading it on first use."""
        if self._entries is None:
            self._entries = self._load_extra(self._entries_path)
        assert self._entries is not None
        return self._entries

    def _get_class_ids(self) -> np.ndarray:
        """Return the class-ID table (indexed by class index), loading it on first use.

        Raises:
            AssertionError: If the dataset uses the TEST split.
        """
        if self._split == _Split.TEST:
            raise AssertionError("Class IDs are not available in TEST split")
        if self._class_ids is None:
            self._class_ids = self._load_extra(self._class_ids_path)
        assert self._class_ids is not None
        return self._class_ids

    def _get_class_names(self) -> np.ndarray:
        """Return the class-name table (indexed by class index), loading it on first use.

        Raises:
            AssertionError: If the dataset uses the TEST split.
        """
        if self._split == _Split.TEST:
            raise AssertionError("Class names are not available in TEST split")
        if self._class_names is None:
            self._class_names = self._load_extra(self._class_names_path)
        assert self._class_names is not None
        return self._class_names

    def find_class_id(self, class_index: int) -> str:
        """Return the class ID stored for ``class_index``."""
        class_ids = self._get_class_ids()
        return str(class_ids[class_index])

    def find_class_name(self, class_index: int) -> str:
        """Return the class name stored for ``class_index``."""
        class_names = self._get_class_names()
        return str(class_names[class_index])

    def get_image_data(self, index: int) -> bytes:
        """Read and return the raw (undecoded) bytes of the image file for sample ``index``."""
        entries = self._get_entries()
        actual_index = entries[index]["actual_index"]

        class_id = self.get_class_id(index)

        image_relpath = self.split.get_image_relpath(actual_index, class_id)
        image_full_path = os.path.join(self.root, image_relpath)
        with open(image_full_path, mode="rb") as f:
            image_data = f.read()
        return image_data

    def get_target(self, index: int) -> Optional[Target]:
        """Return the class index of sample ``index``, or ``None`` for the TEST split."""
        entries = self._get_entries()
        class_index = entries[index]["class_index"]
        return None if self.split == _Split.TEST else int(class_index)

    def get_targets(self) -> Optional[np.ndarray]:
        """Return the class indices of all samples, or ``None`` for the TEST split."""
        entries = self._get_entries()
        return None if self.split == _Split.TEST else entries["class_index"]

    def get_class_id(self, index: int) -> Optional[str]:
        """Return the class ID of sample ``index``, or ``None`` for the TEST split."""
        entries = self._get_entries()
        class_id = entries[index]["class_id"]
        return None if self.split == _Split.TEST else str(class_id)

    def get_class_name(self, index: int) -> Optional[str]:
        """Return the class name of sample ``index``, or ``None`` for the TEST split."""
        entries = self._get_entries()
        class_name = entries[index]["class_name"]
        return None if self.split == _Split.TEST else str(class_name)

    def __len__(self) -> int:
        """Return the number of entries, checking it against the expected split length."""
        entries = self._get_entries()
        assert len(entries) == self.split.length
        return len(entries)

    def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
        """Read ``(class_id, class_name)`` pairs from a two-column CSV file under the root.

        Args:
            labels_path: Path of the labels file, relative to the dataset root.

        Returns:
            The pairs in file order.

        Raises:
            RuntimeError: If the file cannot be opened or read.
        """
        labels_full_path = os.path.join(self.root, labels_path)
        labels = []

        try:
            # newline="" lets the csv module do its own newline handling, as its docs require.
            with open(labels_full_path, "r", newline="") as f:
                reader = csv.reader(f)
                for row in reader:
                    class_id, class_name = row
                    labels.append((class_id, class_name))
        except OSError as e:
            raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e

        return labels

    def _dump_entries(self) -> None:
        """Build the per-sample entries array for the current split and save it.

        For the TEST split the entries are synthesized: indices run from 1 to the
        split length, with empty class strings and a sentinel class index. For
        the other splits the samples are enumerated with torchvision's
        ``ImageFolder`` and class information comes from ``labels.txt``.
        """
        split = self.split
        if split == ImageNet.Split.TEST:
            dataset = None
            sample_count = split.length
            max_class_id_length, max_class_name_length = 0, 0
        else:
            labels_path = "labels.txt"
            logger.info(f'loading labels from "{labels_path}"')
            labels = self._load_labels(labels_path)

            # Local import: torchvision is only needed when generating the metadata.
            from torchvision.datasets import ImageFolder

            dataset_root = os.path.join(self.root, split.get_dirname())
            dataset = ImageFolder(dataset_root)
            sample_count = len(dataset)
            # Longest class ID / name determine the width of the fixed-size string fields.
            max_class_id_length, max_class_name_length = -1, -1
            for sample in dataset.samples:
                _, class_index = sample
                class_id, class_name = labels[class_index]
                max_class_id_length = max(len(class_id), max_class_id_length)
                max_class_name_length = max(len(class_name), max_class_name_length)

        dtype = np.dtype(
            [
                ("actual_index", "<u4"),
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("class_name", f"U{max_class_name_length}"),
            ]
        )
        entries_array = np.empty(sample_count, dtype=dtype)

        if split == ImageNet.Split.TEST:
            # Sentinel for "no label": the largest uint32. The same value was previously
            # obtained with np.uint32(-1), which raises OverflowError on NumPy >= 2.0.
            missing_class_index = np.iinfo(np.uint32).max
            old_percent = -1
            for index in range(sample_count):
                percent = 100 * (index + 1) // sample_count
                if percent > old_percent:
                    logger.info(f"creating entries: {percent}%")
                    old_percent = percent

                # TEST entries are numbered from 1.
                actual_index = index + 1
                class_index = missing_class_index
                class_id, class_name = "", ""
                entries_array[index] = (actual_index, class_index, class_id, class_name)
        else:
            class_names = dict(labels)

            assert dataset
            old_percent = -1
            for index in range(sample_count):
                percent = 100 * (index + 1) // sample_count
                if percent > old_percent:
                    logger.info(f"creating entries: {percent}%")
                    old_percent = percent

                image_full_path, class_index = dataset.samples[index]
                image_relpath = os.path.relpath(image_full_path, self.root)
                class_id, actual_index = split.parse_image_relpath(image_relpath)
                class_name = class_names[class_id]
                entries_array[index] = (actual_index, class_index, class_id, class_name)

        logger.info(f'saving entries to "{self._entries_path}"')
        self._save_extra(entries_array, self._entries_path)

    def _dump_class_ids_and_names(self) -> None:
        """Derive the per-class ID and name tables from the saved entries and save them.

        Does nothing for the TEST split. The entries file must already exist.
        """
        split = self.split
        if split == ImageNet.Split.TEST:
            return

        entries_array = self._load_extra(self._entries_path)

        # First pass: find the table size and the string widths.
        max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
        for entry in entries_array:
            class_index, class_id, class_name = (
                entry["class_index"],
                entry["class_id"],
                entry["class_name"],
            )
            max_class_index = max(int(class_index), max_class_index)
            max_class_id_length = max(len(str(class_id)), max_class_id_length)
            max_class_name_length = max(len(str(class_name)), max_class_name_length)

        # Second pass: fill both tables, indexed by class index.
        class_count = max_class_index + 1
        class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
        class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
        for entry in entries_array:
            class_index, class_id, class_name = (
                entry["class_index"],
                entry["class_id"],
                entry["class_name"],
            )
            class_ids_array[class_index] = class_id
            class_names_array[class_index] = class_name

        logger.info(f'saving class IDs to "{self._class_ids_path}"')
        self._save_extra(class_ids_array, self._class_ids_path)

        logger.info(f'saving class names to "{self._class_names_path}"')
        self._save_extra(class_names_array, self._class_names_path)

    def dump_extra(self) -> None:
        """Generate and save all metadata files for the current split."""
        self._dump_entries()
        self._dump_class_ids_and_names()
|
|
|