Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Module containing wrapper classes for PyTorch Datasets | |
| Author: Shilpaj Bhalerao | |
| Date: Jun 25, 2023 | |
| """ | |
# Standard Library Imports
from typing import Callable, Optional, Tuple

# Third-Party Imports
from torchvision import datasets, transforms
class AlbumDataset(datasets.CIFAR10):
    """
    Wrapper class to use albumentations library with PyTorch Dataset

    Overrides ``__getitem__`` so that the transform is called with the image
    passed as the ``image`` keyword argument and the transformed image is read
    back from the ``"image"`` key of the returned mapping.
    """

    def __init__(self, root: str = "./data", train: bool = True, download: bool = True,
                 transform: Optional[Callable] = None):
        """
        Constructor
        :param root: Directory at which data is stored
        :param train: Param to distinguish if data is training or test
        :param download: Param to download the dataset from source
        :param transform: Callable applied to every image; it is invoked as
            ``transform(image=image)`` and must return a mapping that has an
            ``"image"`` key (presumably an albumentations pipeline).
            ``None`` leaves the images untouched.
        """
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index: int) -> Tuple:
        """
        Method to return image and its label
        :param index: Index of image and label in the dataset
        :return: Tuple of (image, label); the image is passed through the
            transform when one is configured, otherwise returned as stored
        """
        image, label = self.data[index], self.targets[index]

        # Compare against None instead of relying on truthiness: a transform
        # object that defines __len__ or __bool__ could evaluate as falsy and
        # would then be skipped silently.
        if self.transform is not None:
            image = self.transform(image=image)["image"]

        return image, label