# bluemellophone's picture
# Linted files
# 2192664 unverified
# -*- coding: utf-8 -*-
'''
Model implementation.
We'll be using a "simple" ResNet-18 for image classification here.
2022 Benjamin Kellenberger
'''
from os.path import abspath
import torch
from torchvision import datasets
from torchvision.transforms import Compose, Resize, ToTensor
def _mnist_loader(root, train, transform, batch_size, num_workers):
    """Build a DataLoader over one MNIST split; only the training split is shuffled."""
    dataset = datasets.MNIST(
        root,
        train=train,
        transform=transform,
        download=True,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,  # shuffle the train split, keep the test split in order
        num_workers=num_workers,
    )


def load(cfg):
    """
    Load the MNIST dataset from torchvision (downloading it if needed) and
    return DataLoaders for its train and test splits.

    MNIST is a sample dataset for machine learning; each image is 28 pixels
    high and 28 pixels wide, with 1 color channel.

    Args:
        cfg (dict): configuration mapping. Keys read:
            - ``image_size`` (required): size passed to ``Resize``.
            - ``batch_size`` (optional, default 1): samples per batch.
            - ``num_workers`` (optional, default 0): DataLoader worker processes.

    Returns:
        tuple: ``(train, test)`` DataLoaders. ``train`` is shuffled, ``test`` is not.
        The data is stored under ``./datasets`` relative to the current
        working directory.
    """
    root = abspath('datasets')
    # Both splits share the same preprocessing pipeline.
    transform = Compose([Resize(cfg['image_size']), ToTensor()])
    # Fall back to DataLoader's own defaults when a key is absent. Forwarding
    # None instead would disable automatic batching (batch_size=None) or be
    # rejected by DataLoader (num_workers=None).
    batch_size = cfg.get('batch_size', 1)
    num_workers = cfg.get('num_workers', 0)

    train = _mnist_loader(root, True, transform, batch_size, num_workers)
    test = _mnist_loader(root, False, transform, batch_size, num_workers)
    return train, test