# utils file
"""Data utilities for training a custom ResNet on CIFAR-10.

Provides a CIFAR-10 dataset wrapper that calls its transform with the
Albumentations keyword convention, and a factory for the transform pipeline.
"""
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose ``transform`` is called Albumentations-style.

    Unlike the torchvision base class, this subclass calls
    ``transform(image=img)`` on the raw array stored in ``self.data`` and
    takes the ``"image"`` entry of the returned mapping. That is the calling
    convention of ``albumentations`` pipelines such as the one built by
    ``augmentation_custom_resnet``. A torchvision-style transform that only
    takes a positional argument will not work here.
    """

    def __init__(self, root="./data", train=True, download=True,
                 transform=None, target_transform=None):
        """Create the dataset.

        Args:
            root: Directory where the CIFAR-10 files live / are downloaded to.
            train: If True use the training split, otherwise the test split.
            download: If True, download the data when it is not present.
            transform: Callable invoked as ``transform(image=img)`` that
                returns a mapping with an ``"image"`` key (e.g. an
                ``A.Compose`` pipeline), or None to return raw images.
            target_transform: Optional callable applied to the label.
                The default of None leaves labels unchanged.
        """
        super().__init__(root=root, train=train, download=download,
                         transform=transform,
                         target_transform=target_transform)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample *index*.

        With no transform set, the image is returned exactly as stored in
        ``self.data``. NOTE(review): torchvision presumably keeps this as an
        HWC uint8 NumPy array; confirm against the installed version.
        """
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Albumentations transforms take keyword inputs and return a dict.
            image = self.transform(image=image)["image"]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


def augmentation_custom_resnet(mean=(0.4914, 0.4822, 0.4465),
                               std=(0.2470, 0.2435, 0.2616), pad=4):
    """Build the Albumentations transform pipeline for the custom ResNet.

    The pipeline normalizes the image and then converts it to a torch tensor.

    Args:
        mean: Per-channel means passed to ``A.Normalize``. The defaults appear
            to be the commonly used CIFAR-10 channel statistics.
        std: Per-channel standard deviations passed to ``A.Normalize``.
        pad: Currently unused; kept so existing callers that pass it still
            work. NOTE(review): presumably intended for a pad/random-crop
            augmentation that was never added. Confirm intent.

    Returns:
        An ``A.Compose`` pipeline, to be called as ``pipeline(image=img)``.
    """
    # A.Normalize scales mean/std by max_pixel_value (255.0 by default), so
    # statistics expressed on a 0-1 scale apply directly to uint8 images.
    return A.Compose([A.Normalize(mean=mean, std=std), ToTensorV2()])