# OrthoReg / src/datasets/svhn.py
# Uploaded by gezi2333 via huggingface_hub (commit 3589275, verified).
import os
import torch
from torchvision.datasets import SVHN as PyTorchSVHN
import numpy as np
class SVHN:
    """SVHN digit dataset exposed as train/test datasets plus DataLoaders.

    Both splits are built with torchvision's ``SVHN`` class, passing
    ``download=True``. The same ``preprocess`` transform is applied to both
    splits.

    Args:
        preprocess: Transform passed as ``transform=`` to both splits.
        location: Parent data directory. Files are stored under
            ``<location>/svhn``.
        batch_size: Batch size for both DataLoaders.
        num_workers: Worker count for both DataLoaders.

    Attributes:
        train_dataset / train_loader: The 'train' split and a shuffled loader.
        test_dataset / test_loader: The 'test' split and an unshuffled loader.
        classnames: Label names, the digit strings '0' through '9'.
    """

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):
        # Repo convention: each dataset lives in its own sub-folder of `location`.
        root = os.path.join(location, 'svhn')

        def build(split, shuffle):
            # Construct one split and its loader; only the split name and
            # the shuffle flag differ between train and test.
            dataset = PyTorchSVHN(root=root,
                                  download=True,
                                  split=split,
                                  transform=preprocess)
            loader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 shuffle=shuffle,
                                                 num_workers=num_workers)
            return dataset, loader

        # Train is built before test so the download order is unchanged.
        self.train_dataset, self.train_loader = build('train', shuffle=True)
        self.test_dataset, self.test_loader = build('test', shuffle=False)

        self.classnames = [str(digit) for digit in range(10)]