Spaces:
Runtime error
Runtime error
| # utils file | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from torchvision import transforms | |
| import torchvision | |
| import numpy as np | |
| import albumentations as A | |
| from albumentations.pytorch.transforms import ToTensorV2 | |
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is called keyword-style.

    Unlike a positional ``transform(img)`` call, the transform here is invoked
    as ``transform(image=...)`` and the result is read back from the
    ``"image"`` key of whatever mapping it returns.
    """

    def __init__(self, root="./data", train=True, download=True, transform=None):
        """Forward every argument unchanged to ``torchvision.datasets.CIFAR10``."""
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The image is taken from ``self.data`` and the label from
        ``self.targets``. When no transform is set the stored image is
        returned as-is.
        """
        sample = self.data[index]
        target = self.targets[index]
        # Guard clause: nothing to apply, hand back the stored entry untouched.
        if self.transform is None:
            return sample, target
        return self.transform(image=sample)["image"], target
def augmentation_custom_resnet(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616), pad=4):
    """Build a two-step pipeline: normalize, then convert to a tensor.

    Args:
        mean: Mean values passed to ``A.Normalize``.
        std: Standard-deviation values passed to ``A.Normalize``.
        pad: Accepted for interface compatibility; not used by this function.

    Returns:
        An ``A.Compose`` that applies ``A.Normalize(mean, std)`` followed by
        ``ToTensorV2()``.
    """
    steps = [
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]
    return A.Compose(steps)