| import os | |
| import torch | |
| from torchvision.datasets import SVHN as PyTorchSVHN | |
| import numpy as np | |
| class SVHN: | |
| def __init__(self, | |
| preprocess, | |
| location=os.path.expanduser('~/data'), | |
| batch_size=128, | |
| num_workers=16): | |
| # to fit with repo conventions for location | |
| modified_location = os.path.join(location, 'svhn') | |
| self.train_dataset = PyTorchSVHN( | |
| root=modified_location, | |
| download=True, | |
| split='train', | |
| transform=preprocess | |
| ) | |
| self.train_loader = torch.utils.data.DataLoader( | |
| self.train_dataset, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| num_workers=num_workers | |
| ) | |
| self.test_dataset = PyTorchSVHN( | |
| root=modified_location, | |
| download=True, | |
| split='test', | |
| transform=preprocess | |
| ) | |
| self.test_loader = torch.utils.data.DataLoader( | |
| self.test_dataset, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| num_workers=num_workers | |
| ) | |
| self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] | |