File size: 934 Bytes
d250771
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
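# Utilities: a CIFAR-10 dataset wrapper that calls its transform Albumentations-style, and a transform pipeline builder.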

import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import numpy as np 
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is invoked Albumentations-style.

    ``__getitem__`` hands the raw ``self.data[index]`` entry to the transform
    as the keyword argument ``image`` and reads the result back from the
    ``"image"`` key of the mapping the transform returns.
    """

    def __init__(self, root="./data", train=True, download=True, transform=None):
        """Forward every argument unchanged to ``torchvision.datasets.CIFAR10``.

        Args:
            root: Directory passed to the parent class as ``root``.
            train: Passed to the parent class as ``train``.
            download: Passed to the parent class as ``download``.
            transform: Optional callable used by ``__getitem__``; it must
                accept ``image=...`` and return a mapping with an
                ``"image"`` entry.
        """
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        When a transform is configured, the image is replaced by
        ``transform(image=image)["image"]``; otherwise the raw entry from
        ``self.data`` is returned as-is. The label is always
        ``self.targets[index]``.
        """
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            image = self.transform(image=image)["image"]

        # The return sits outside the `if` so that a dataset created without
        # a transform still yields its sample instead of implicitly
        # returning None.
        return image, label
        
def augmentation_custom_resnet(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616), pad=4):
    """Build the Albumentations pipeline: normalize, then convert to a tensor.

    Args:
        mean: Per-channel mean forwarded to ``A.Normalize``.
        std: Per-channel standard deviation forwarded to ``A.Normalize``.
        pad: Accepted for interface compatibility; this pipeline does not
            use it.

    Returns:
        An ``A.Compose`` that applies ``A.Normalize(mean, std)`` followed by
        ``ToTensorV2()``.
    """
    steps = [
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]
    return A.Compose(steps)