File size: 978 Bytes
6ab5efc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
import torchvision
import torchvision.transforms as transforms
def get_dataloader(train, *, root='./data', batch_size=128, num_workers=2,
                   download=True, mean=(0.4914, 0.4822, 0.4465),
                   std=(0.2023, 0.1994, 0.2010)):
    """Build a CIFAR-100 DataLoader for the training or the test split.

    Args:
        train: If true, load the training split with data augmentation
            (random crop with padding, horizontal flip, RandAugment and
            random erasing) and shuffle the samples; otherwise load the test
            split with only tensor conversion and normalization, unshuffled.
        root: Directory where the dataset is stored or downloaded to.
        batch_size: Number of samples per batch.
        num_workers: Number of worker subprocesses used for data loading.
        download: Whether to download the dataset into ``root`` if missing.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel standard deviation passed to ``transforms.Normalize``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split.
    """
    # NOTE(review): the default mean/std look like the widely circulated
    # CIFAR-10 statistics, not CIFAR-100's (roughly (0.5071, 0.4865, 0.4409) /
    # (0.2673, 0.2564, 0.2762)). They are kept as defaults so models trained
    # with this loader see the same inputs; pass CIFAR-100 statistics
    # explicitly for new training runs.
    # One shared instance guarantees train and test use identical normalization.
    normalize = transforms.Normalize(mean, std)

    if train:
        transform = transforms.Compose([
            # PIL-level augmentations must run before ToTensor.
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandAugment(num_ops=2, magnitude=9),
            transforms.ToTensor(),
            normalize,
            # RandomErasing operates on tensors, so it goes after ToTensor.
            transforms.RandomErasing(p=0.25, scale=(0.02, 0.2),
                                     ratio=(0.3, 3.3)),
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    dataset = torchvision.datasets.CIFAR100(
        root=root, train=train, download=download, transform=transform)
    # Shuffle only the training split; evaluation order stays deterministic.
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=train,
        num_workers=num_workers)
|