S13 / datasets.py
Shivdutta's picture
Upload 3 files
4e9cc3f verified
from typing import Any, Callable, Optional, Tuple

from torchvision import datasets, transforms
class TransformedDataset(datasets.CIFAR10):
    """
    CIFAR10 dataset whose transform uses a keyword-call / dict-return convention.

    Unlike a plain torchvision transform (called positionally and returning the
    image directly), the transform given here is invoked as
    ``transform(image=raw_image)`` and must return a mapping whose ``"image"``
    entry holds the transformed image. This matches Albumentations-style
    pipelines.

    Args:
        root (str, optional): Root directory where the dataset is stored.
            Default is "./data".
        train (bool, optional): Specifies if the dataset is for training or
            testing. Default is True.
        download (bool, optional): If True, downloads the dataset from the
            internet and places it in the root directory. Default is True.
        transform (callable, optional): Callable accepting the raw image via
            the ``image`` keyword and returning a mapping with an ``"image"``
            key. When None, images are returned untransformed. Default is None.
    """

    def __init__(
        self,
        root: str = "./data",
        train: bool = True,
        download: bool = True,
        transform: Optional[Callable[..., Any]] = None,
    ):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Retrieves the item at the specified index.

        The image is read straight from ``self.data`` and handed to the
        transform as-is; this override does not go through the parent class's
        ``__getitem__``.

        Args:
            index (int): Index of the item to retrieve.

        Returns:
            Tuple: ``(image, label)`` where ``image`` is the ``"image"`` entry
            of the transform's output (or the raw entry of ``self.data`` when
            no transform is set) and ``label`` is ``self.targets[index]``.
        """
        image, label = self.data[index], self.targets[index]

        # Compare against None explicitly: a transform object that happens to
        # be falsy (e.g. one defining __len__ and holding no steps) is still a
        # transform the caller asked for and must not be skipped silently.
        if self.transform is not None:
            image = self.transform(image=image)["image"]

        return image, label