File size: 1,378 Bytes
4e9cc3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Any, Callable, Dict, Optional, Tuple

from torchvision import datasets, transforms


class TransformedDataset(datasets.CIFAR10):
    """
    CIFAR10 dataset that applies a keyword-style, dict-returning image transform.

    torchvision's own ``CIFAR10.__getitem__`` converts each sample to a PIL image
    and calls ``transform(img)``. This subclass instead passes the raw entry of
    ``self.data`` as ``transform(image=...)`` and reads the transformed image from
    the ``"image"`` key of the returned mapping. Albumentations pipelines, for
    example, use this calling convention.

    Args:
        root (str, optional): Root directory where the dataset is stored. Default is "./data".
        train (bool, optional): If True, use the training split; otherwise the test split.
            Default is True.
        download (bool, optional): If True, downloads the dataset from the internet and
            places it in the root directory. Default is True.
        transform (callable, optional): Callable invoked as ``transform(image=image)``.
            It must return a mapping with an ``"image"`` entry. If None, the raw image
            from ``self.data`` is returned unchanged. Default is None.
        target_transform (callable, optional): Callable applied to the label
            (``target_transform(label)``). If None, the label is returned unchanged.
            Default is None.
    """

    def __init__(
        self,
        root: str = "./data",
        train: bool = True,
        download: bool = True,
        transform: Optional[Callable[..., Dict[str, Any]]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ):
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Retrieve the (optionally transformed) image and label at ``index``.

        Args:
            index (int): Index of the item to retrieve.

        Returns:
            Tuple[Any, Any]: ``(image, label)``. ``image`` is ``transform(image=...)["image"]``
            when a transform is set, otherwise the raw ``self.data[index]``. ``label`` is
            ``target_transform(label)`` when a target transform is set, otherwise
            ``self.targets[index]``.
        """
        image, label = self.data[index], self.targets[index]

        # Explicit None checks: a transform object may define __len__ (e.g. an
        # empty Compose) and would be wrongly skipped by a truthiness test.
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label