ERAV2_S13 / utils.py
Vasudevakrishna's picture
S13 added.
d250771
raw
history blame contribute delete
934 Bytes
# utils file
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose image transform uses the albumentations calling convention.

    ``__getitem__`` passes the raw ``self.data[index]`` entry to the transform as
    the keyword argument ``image=`` and uses the ``"image"`` entry of the returned
    mapping, rather than calling the transform positionally.
    NOTE(review): ``self.data`` entries are presumably HWC uint8 numpy arrays as
    stored by torchvision -- confirm before changing the transform pipeline.
    """

    def __init__(self, root="./data", train=True, download=True, transform=None,
                 target_transform=None):
        """Create the dataset.

        Args:
            root: Directory the dataset is stored in / downloaded to.
            train: Select the training split when True, the test split otherwise.
            download: Download the data if missing (defaults to True here).
            transform: Callable invoked as ``transform(image=...)`` that returns a
                mapping with an ``"image"`` key, or None to return the raw entry.
            target_transform: Optional callable applied to the label. Previously
                this could not be set and was ignored; it is now forwarded to the
                base class and honoured by ``__getitem__``.
        """
        super().__init__(root=root, train=train, download=download,
                         transform=transform, target_transform=target_transform)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``.

        With no transform, the image is the raw ``self.data[index]`` entry.
        """
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Keyword call + dict result: the albumentations Compose convention.
            image = self.transform(image=image)["image"]
        if self.target_transform is not None:
            # Honour the base-class target_transform contract that the original
            # override silently dropped.
            label = self.target_transform(label)
        return image, label
def augmentation_custom_resnet(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616), pad=4):
    """Build the albumentations pipeline applied to CIFAR-10 images.

    The pipeline normalizes the image with the given per-channel ``mean`` and
    ``std`` and then converts it to a tensor with ``ToTensorV2``.

    Args:
        mean: Per-channel means passed to ``A.Normalize``. The defaults appear
            to be CIFAR-10 training-set statistics.
        std: Per-channel standard deviations passed to ``A.Normalize``.
        pad: Accepted but currently unused. NOTE(review): presumably intended
            for a padding/crop augmentation -- confirm before relying on it.

    Returns:
        An ``A.Compose`` pipeline. Call it with the keyword ``image=`` and read
        the ``"image"`` key of the result.
    """
    # Order matters: normalize the raw image first, then convert to a tensor.
    steps = [
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]
    pipeline = A.Compose(steps)
    return pipeline