Spaces:
Runtime error
Runtime error
File size: 934 Bytes
d250771 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
# utils file
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose image transform uses the Albumentations calling convention.

    torchvision's ``CIFAR10`` calls ``transform(img)`` on a PIL image. This subclass
    instead passes the raw stored sample as ``transform(image=...)`` and reads the
    ``"image"`` key from the result, which is how ``A.Compose`` pipelines are invoked.
    """

    def __init__(self, root="./data", train=True, download=True, transform=None,
                 target_transform=None):
        """Create the dataset.

        Args:
            root: Directory where the CIFAR-10 files live (or are downloaded to).
            train: Load the training split if True, otherwise the test split.
            download: Download the archive into ``root`` if it is missing.
            transform: Callable invoked as ``transform(image=img)`` and returning a
                mapping with an ``"image"`` entry, or None to return the raw sample.
            target_transform: Optional callable applied to the label, mirroring the
                hook of the torchvision base class. Defaults to None (label unchanged).
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``self.data`` and ``self.targets`` are populated by the torchvision base
        class. NOTE(review): ``self.data`` is presumably a uint8 HWC NumPy array,
        which is what Albumentations expects; confirm against the installed
        torchvision version.
        """
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Albumentations transforms take keyword inputs and return a dict.
            image = self.transform(image=image)["image"]
        # The parent class honours target_transform. Overriding __getitem__ would
        # otherwise drop it silently, so apply it here as well.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
def augmentation_custom_resnet(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616), pad=4):
    """Build an Albumentations pipeline that normalizes an image and converts it to a tensor.

    Args:
        mean: Per-channel means forwarded to ``A.Normalize``. The defaults look like
            CIFAR-10 channel statistics (assumption; confirm against the training setup).
        std: Per-channel standard deviations forwarded to ``A.Normalize``.
        pad: Accepted but currently unused by this pipeline.

    Returns:
        An ``A.Compose`` that runs ``A.Normalize`` and then ``ToTensorV2``.
    """
    # NOTE(review): `pad` has no effect here. It was presumably meant for a
    # padding/cropping augmentation step. Confirm the intent before wiring it in.
    steps = [
        A.Normalize(mean=mean, std=std),  # normalize before converting to a tensor
        ToTensorV2(),
    ]
    return A.Compose(steps)