# -*- coding: utf-8 -*-
'''
    MNIST data loading.

    Builds PyTorch DataLoaders for the training and test splits of the MNIST
    dataset, resized to the image size given in the configuration.

    2022 Benjamin Kellenberger
'''

from os.path import abspath

import torch
from torchvision import datasets
from torchvision.transforms import Compose, Resize, ToTensor


def _make_loader(root, train, transform, cfg):
    """Build a DataLoader over one MNIST split; only the training split is shuffled."""
    dataset = datasets.MNIST(
        root,
        train=train,
        transform=transform,
        download=True,
    )
    return torch.utils.data.DataLoader(
        dataset,
        # Fall back to DataLoader's own defaults when a key is missing:
        # num_workers=None makes DataLoader raise a TypeError, and
        # batch_size=None silently disables automatic batching.
        batch_size=cfg.get('batch_size', 1),
        shuffle=train,
        num_workers=cfg.get('num_workers', 0),
    )


def load(cfg):
    """
        Load the MNIST dataset from PyTorch (downloading it if needed) and
        return DataLoaders for its training and test splits.

        MNIST is a sample dataset for machine learning; each image is
        28 pixels high and 28 pixels wide (1 color channel).

        Args:
            cfg: mapping with the required key 'image_size' (passed to
                torchvision's Resize) and the optional keys 'batch_size'
                (default 1) and 'num_workers' (default 0).

        Returns:
            (train, test): a shuffled DataLoader over the training split and
            an unshuffled DataLoader over the test split.
    """
    # Resolved against the current working directory, not this file's location.
    root = abspath('datasets')
    # The same preprocessing is applied to both splits.
    transform = Compose([Resize(cfg['image_size']), ToTensor()])
    train = _make_loader(root, True, transform, cfg)
    test = _make_loader(root, False, transform, cfg)
    return train, test