Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import albumentations | |
| from torchvision import datasets | |
| from albumentations.pytorch import ToTensorV2 | |
| from torch.utils.data import Dataset, DataLoader | |
class CIFAR10Data(Dataset):
    """Dataset wrapper that yields NumPy images and applies an optional transform.

    The wrapped dataset must support ``len()`` and integer indexing, and each
    item must unpack into an ``(image, label)`` pair.

    Args:
        dataset: Underlying indexable dataset of ``(image, label)`` pairs.
        transforms: Optional callable invoked as ``transforms(image=array)``
            that returns a mapping holding the result under the ``'image'``
            key. ``None`` disables transformation.
    """

    def __init__(self, dataset, transforms=None) -> None:
        self.dataset = dataset
        self.transforms = transforms

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``.

        The image is converted with ``np.array`` before the transform (if
        any) is applied; the label is passed through untouched.
        """
        image, label = self.dataset[index]
        image = np.array(image)
        # Compare against None explicitly: a pipeline object may define
        # __len__/__bool__ and evaluate as falsy while still being a valid
        # callable, and a truthiness test would silently skip it.
        if self.transforms is not None:
            image = self.transforms(image=image)['image']
        return image, label
def _get_test_transforms():
    """Build the albumentations pipeline applied to test images.

    The pipeline normalizes each channel with fixed mean/std values and then
    applies ``ToTensorV2``.

    Returns:
        object: The composed albumentations transform.
    """
    # Per-channel statistics; presumably computed on the CIFAR-10 training
    # set -- TODO confirm the source of these values.
    channel_mean = [0.49139968, 0.48215841, 0.44653091]
    channel_std = [0.24703223, 0.24348513, 0.26158784]

    pipeline_steps = [
        albumentations.Normalize(channel_mean, channel_std),
        ToTensorV2(),
    ]
    return albumentations.Compose(pipeline_steps)
def _get_data(is_train, is_download):
    """Create the torchvision CIFAR-10 dataset for the requested split.

    The dataset root is the relative path ``'../data'``.

    Args:
        is_train (bool): True for the training split, False for the test split.
        is_download (bool): True to download the dataset from the internet.

    Returns:
        object: The ``datasets.CIFAR10`` dataset instance.
    """
    return datasets.CIFAR10('../data', train=is_train, download=is_download)
def _get_data_loader(data, **kwargs):
    """Wrap a dataset in a ``DataLoader``.

    Args:
        data (object): Dataset to iterate over.
        **kwargs: Extra keyword arguments handed to ``DataLoader`` unchanged
            (for example ``batch_size``).

    Returns:
        object: ``DataLoader`` used to feed data to a neural network model.
    """
    return DataLoader(data, **kwargs)
def get_test_data_loader(**kwargs):
    """Build a data loader over the CIFAR-10 test split.

    The test split is fetched with downloading enabled, wrapped in
    ``CIFAR10Data`` together with the test transforms, and handed to the
    data-loader factory.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to the data loader
            (for example ``batch_size``, the number of images in a batch).

    Returns:
        object: ``DataLoader`` used to feed test data to a neural network model.
    """
    eval_transforms = _get_test_transforms()
    raw_test_set = _get_data(is_train=False, is_download=True)
    wrapped_test_set = CIFAR10Data(raw_test_set, eval_transforms)
    return _get_data_loader(data=wrapped_test_set, **kwargs)