| import torch | |
| import torchvision | |
| import torchvision.transforms as transforms | |
def get_dataloader(
    train,
    *,
    root='./data',
    batch_size=128,
    num_workers=2,
    download=True,
    mean=(0.5071, 0.4865, 0.4409),
    std=(0.2673, 0.2564, 0.2762),
):
    """Build a CIFAR-100 DataLoader for the training or test split.

    The training split is augmented with RandomCrop(32, padding=4),
    RandomHorizontalFlip, RandAugment(num_ops=2, magnitude=9) and
    RandomErasing(p=0.25), and is shuffled. The test split is only converted
    to a tensor and normalized, and is not shuffled. Both splits use the same
    ``mean``/``std``.

    Args:
        train: If truthy, load the training split with augmentation and
            shuffling; otherwise load the test split without either.
        root: Directory the dataset is read from / downloaded into.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        download: Download the dataset into ``root`` if it is not present.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel standard deviation passed to ``transforms.Normalize``.
            The defaults for ``mean``/``std`` are the commonly cited CIFAR-100
            training-set statistics. Earlier versions of this function
            normalized with the CIFAR-10 values
            ``(0.4914, 0.4822, 0.4465)`` / ``(0.2023, 0.1994, 0.2010)``; pass
            those explicitly to reproduce inputs for models trained with them.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``torchvision.datasets.CIFAR100``.
    """
    normalize = transforms.Normalize(mean, std)

    # Only build the pipeline that will actually be used.
    if train:
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandAugment(num_ops=2, magnitude=9),
            transforms.ToTensor(),
            normalize,
            # RandomErasing operates on tensors, so it must come after ToTensor().
            transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3)),
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    dataset = torchvision.datasets.CIFAR100(
        root=root, train=train, download=download, transform=transform)
    # Shuffle only the training split; evaluation order stays deterministic.
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=train, num_workers=num_workers)