| import os | |
| import torch | |
| import torchvision | |
| import torchvision.datasets as datasets | |
def rotate_img(img):
    """Rotate ``img`` by 90 degrees clockwise.

    torchvision's ``rotate`` treats positive angles as counter-clockwise, so an
    angle of -90 turns the image clockwise. The input may be any image type
    accepted by ``torchvision.transforms.functional.rotate``.
    """
    tv_functional = torchvision.transforms.functional
    return tv_functional.rotate(img, angle=-90)
def flip_img(img):
    """Mirror ``img`` left-to-right (horizontal flip).

    This is a thin wrapper around ``torchvision.transforms.functional.hflip``,
    so that it can be used as a named step in a ``Compose`` pipeline.
    """
    mirror = torchvision.transforms.functional.hflip
    return mirror(img)
def emnist_preprocess():
    """Build the transform that fixes EMNIST image orientation.

    The pipeline first rotates the image 90 degrees clockwise
    (``rotate_img``) and then mirrors it horizontally (``flip_img``). The
    combination is equivalent to transposing the image.

    Returns:
        A ``torchvision.transforms.Compose`` applying the two steps in order.
    """
    orientation_steps = [rotate_img, flip_img]
    return torchvision.transforms.Compose(orientation_steps)
class EMNIST:
    """EMNIST dataset wrapper exposing train/test datasets and data loaders.

    Every image is passed through the caller-supplied ``preprocess`` transform
    first. It is then passed through ``emnist_preprocess()``, which rotates it
    90 degrees clockwise and flips it horizontally; together these two steps
    transpose the image.

    Args:
        preprocess: Transform applied to each raw sample before the orientation
            fix. If ``None``, only the orientation fix is applied.
        location: Root directory passed to ``torchvision.datasets.EMNIST``.
            The data is downloaded there if it is not already present.
        batch_size: Batch size of the training loader.
        num_workers: Number of worker processes for both loaders.
        split: EMNIST split to load (``"digits"`` by default).
        test_batch_size: Batch size of the evaluation loader. It defaults to
            32, the value previously hard-coded.

    Attributes:
        train_dataset / test_dataset: The ``torchvision.datasets.EMNIST``
            instances for the chosen split.
        train_loader: Shuffled ``DataLoader`` over ``train_dataset``.
        test_loader: Unshuffled ``DataLoader`` over ``test_dataset``.
        classnames: ``train_dataset.classes`` for the chosen split.
    """

    def __init__(
        self,
        preprocess,
        location,
        batch_size=128,
        num_workers=8,
        split="digits",
        test_batch_size=32,
    ):
        orientation_fix = emnist_preprocess()
        # Compose([None, ...]) would only fail later, on the first sample
        # access. Skipping a missing user transform avoids that crash.
        if preprocess is None:
            transform = orientation_fix
        else:
            transform = torchvision.transforms.Compose(
                [
                    preprocess,
                    orientation_fix,
                ]
            )

        self.train_dataset = datasets.EMNIST(
            root=location,
            download=True,
            split=split,
            transform=transform,
            train=True,
        )
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )

        self.test_dataset = datasets.EMNIST(
            root=location,
            download=True,
            split=split,
            transform=transform,
            train=False,
        )
        # Evaluation order is kept deterministic (no shuffling).
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=test_batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        self.classnames = self.train_dataset.classes