# Downstream datasets: ImageNet / iNaturalist 2018 / VOC 2007 classification,
# plus a generic image-directory dataset for captioning inference.
| from collections import defaultdict | |
| import glob | |
| import json | |
| import os | |
| from typing import Callable, Dict, List, Tuple | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torchvision.datasets import ImageNet | |
| from virtex.data import transforms as T | |
class ImageNetDataset(ImageNet):
    r"""
    Simple wrapper over torchvision's ImageNet dataset with a feature to support
    restricting dataset size for semi-supervised learning setup (data-efficiency
    ablations).

    We also handle image transform here instead of passing to super class.

    Parameters
    ----------
    data_root: str, optional (default = "datasets/imagenet")
        Path to the dataset root directory. This must contain directories
        ``train``, ``val`` with per-category sub-directories.
    split: str, optional (default = "train")
        Which split to read from. One of ``{"train", "val"}``.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`
        to be applied on the image.
    percentage: float, optional (default = 100)
        Percentage of dataset to keep. This dataset retains first K% of images
        per class to retain same class label distribution. This is 100% by
        default, and will be ignored if ``split`` is ``val``.
    """

    def __init__(
        self,
        data_root: str = "datasets/imagenet",
        split: str = "train",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        percentage: float = 100,
    ):
        super().__init__(data_root, split)
        assert percentage > 0, "Cannot load dataset with 0 percent original size."
        self.image_transform = image_transform

        # Super class has `imgs` list and `targets` list. Make a dict of
        # class ID to index of instances in these lists and pick first K%.
        if split == "train" and percentage < 100:
            label_to_indices: Dict[int, List[int]] = defaultdict(list)
            for index, target in enumerate(self.targets):
                label_to_indices[target].append(index)

            # Trim list of indices per label.
            # NOTE(review): `int(...)` truncates, so a very small percentage can
            # drop a class entirely — presumably acceptable for these ablations.
            for label in label_to_indices:
                retain = int(len(label_to_indices[label]) * (percentage / 100))
                label_to_indices[label] = label_to_indices[label][:retain]

            # Flatten retained per-label indices into a single list.
            retained_indices: List[int] = [
                index
                for indices_per_label in label_to_indices.values()
                for index in indices_per_label
            ]
            # Shorter dataset with size K% of original dataset, but almost same
            # class label distribution. super class will handle the rest.
            self.imgs = [self.imgs[i] for i in retained_indices]
            self.targets = [self.targets[i] for i in retained_indices]
            self.samples = self.imgs

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Super class returns a (PIL image, int label) pair.
        image, label = super().__getitem__(idx)

        # Apply transformation to image and convert to CHW format.
        image = self.image_transform(image=np.array(image))["image"]
        image = np.transpose(image, (2, 0, 1))
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Static: takes no `self` — usable as `DataLoader(collate_fn=...)`
        # from either the class or an instance.
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }
class INaturalist2018Dataset(Dataset):
    r"""
    A dataset which provides image-label pairs from the iNaturalist 2018 dataset.

    Parameters
    ----------
    data_root: str, optional (default = "datasets/inaturalist")
        Path to the dataset root directory. This must contain images and
        annotations (``train2018``, ``val2018`` and ``annotations`` directories).
    split: str, optional (default = "train")
        Which split to read from. One of ``{"train", "val"}``.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`
        to be applied on the image.
    """

    def __init__(
        self,
        data_root: str = "datasets/inaturalist",
        split: str = "train",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        self.split = split
        self.image_transform = image_transform

        # Use a context manager so the annotation file handle is closed.
        ann_path = os.path.join(data_root, "annotations", f"{split}2018.json")
        with open(ann_path) as fopen:
            annotations = json.load(fopen)

        # Make a map of image IDs to file paths.
        self.image_id_to_file_path = {
            ann["id"]: os.path.join(data_root, ann["file_name"])
            for ann in annotations["images"]
        }
        # Form a list of instances: (image_id, category_id) tuples.
        self.instances = [
            (ann["image_id"], ann["category_id"])
            for ann in annotations["annotations"]
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_id, label = self.instances[idx]
        image_path = self.image_id_to_file_path[image_id]

        # Open image from path and apply transformation, convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Static: takes no `self` — usable as `DataLoader(collate_fn=...)`
        # from either the class or an instance.
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }
class VOC07ClassificationDataset(Dataset):
    r"""
    A dataset which provides image-label pairs from the PASCAL VOC 2007 dataset.

    Parameters
    ----------
    data_root: str, optional (default = "datasets/VOC2007")
        Path to the dataset root directory. This must contain directories
        ``Annotations``, ``ImageSets`` and ``JPEGImages``.
    split: str, optional (default = "trainval")
        Which split to read from. One of ``{"trainval", "test"}``.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`
        to be applied on the image.
    """

    def __init__(
        self,
        data_root: str = "datasets/VOC2007",
        split: str = "trainval",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        self.split = split
        self.image_transform = image_transform
        ann_paths = sorted(
            glob.glob(os.path.join(data_root, "ImageSets", "Main", f"*_{split}.txt"))
        )
        # A list like; ["aeroplane", "bicycle", "bird", ...]
        self.class_names = [
            os.path.basename(path).split("_")[0] for path in ann_paths
        ]

        # We will construct a map for image name to a list of
        # shape: (num_classes, ) and values as one of {-1, 0, 1}.
        # 1: present, -1: not present, 0: ignore.
        image_names_to_labels: Dict[str, torch.Tensor] = defaultdict(
            lambda: -torch.ones(len(self.class_names), dtype=torch.int32)
        )
        for cls_num, ann_path in enumerate(ann_paths):
            with open(ann_path, "r") as fopen:
                for line in fopen:
                    img_name, orig_label_str = line.strip().split()
                    orig_label = int(orig_label_str)

                    # In VOC data, -1 (not present): set to 0 as train target
                    # In VOC data, 0 (ignore): set to -1 as train target.
                    orig_label = (
                        0 if orig_label == -1 else -1 if orig_label == 0 else 1
                    )
                    image_names_to_labels[img_name][cls_num] = orig_label

        # Convert the dict to a list of tuples for easy indexing.
        # Replace image name with full image path. Labels are plain int
        # lists here (``tolist``), not tensors.
        self.instances: List[Tuple[str, List[int]]] = [
            (
                os.path.join(data_root, "JPEGImages", f"{image_name}.jpg"),
                label.tolist(),
            )
            for image_name, label in image_names_to_labels.items()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_path, label = self.instances[idx]

        # Open image from path and apply transformation, convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Static: takes no `self` — usable as `DataLoader(collate_fn=...)`
        # from either the class or an instance.
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }
class ImageDirectoryDataset(Dataset):
    r"""
    A dataset which reads images from any directory. This class is useful to
    run image captioning inference on our models with any arbitrary images.

    Parameters
    ----------
    data_root: str
        Path to a directory containing images.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`
        to be applied on the image.
    """

    def __init__(
        self, data_root: str, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM
    ):
        self.image_transform = image_transform
        # Every entry in the directory is treated as an image.
        self.image_paths = glob.glob(os.path.join(data_root, "*"))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        path = self.image_paths[idx]
        # The file name without its extension doubles as the image ID.
        name, _ = os.path.splitext(os.path.basename(path))

        # Read image (OpenCV yields BGR), convert to RGB, apply the
        # transform pipeline, then reorder HWC -> CHW.
        bgr = cv2.imread(path)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        transformed = self.image_transform(image=rgb)["image"]
        chw = np.transpose(transformed, (2, 0, 1))

        # Keep `image_id` a string so a collate step does not cast it to tensor.
        return {"image_id": str(name), "image": torch.tensor(chw)}