"""MNIST dataset loading together with its preprocessing pipelines."""
from functools import partial
from pathlib import Path

import torch
import torchvision
from torchvision.transforms import v2

# Mean / std handed to ``v2.Normalize`` by both the train and eval pipelines.
# Kept in one place so the two pipelines always normalise identically.
_NORM_MEAN = (0.1307,)
_NORM_STD = (0.3081,)

# ``<dir of this file>/../../data`` -- used as the dataset root (download
# target and cache) by ``get_dataset``.
data_path = Path(__file__).resolve().parent.parent.parent / 'data'
print(data_path)


def _base_steps():
    """Return fresh instances of the steps shared by every pipeline.

    The steps convert the input to an image tensor, cast it to ``float32``
    (``scale=True`` asks ``ToDtype`` to rescale the value range as well),
    and pad 2 pixels on every border.
    """
    return [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Pad(2),
    ]


def transform():
    """Build the deterministic pipeline (used for the test split).

    Returns:
        A ``v2.Compose`` applying the shared base steps followed by
        normalisation. No random augmentation is applied.
    """
    return v2.Compose([
        *_base_steps(),
        v2.Normalize(_NORM_MEAN, _NORM_STD),
    ])


def _add_gaussian_noise(x, noise_factor):
    """Add ``noise_factor``-scaled standard-normal noise to *x*, clamped to [0, 1]."""
    noise = torch.randn_like(x) * noise_factor
    return torch.clamp(x + noise, 0., 1.)


def add_noise(noise_factor=0.5):
    """Build a transform that adds clamped Gaussian noise to its input.

    Args:
        noise_factor: Multiplier applied to the standard-normal noise
            before it is added to the input.

    Returns:
        A ``v2.Lambda`` whose output is ``clamp(x + noise, 0, 1)``.
    """
    # A module-level function bound with ``partial`` is used instead of a
    # nested closure: local functions cannot be pickled, which would break
    # any DataLoader that has to ship this transform to worker processes.
    return v2.Lambda(partial(_add_gaussian_noise, noise_factor=noise_factor))


def train_transform():
    """Build the training pipeline: base steps, random affine, normalisation.

    Returns:
        A ``v2.Compose`` that applies a random affine augmentation between
        padding and normalisation.
    """
    return v2.Compose([
        *_base_steps(),
        # Augmentation runs before Normalize, so the regions exposed by the
        # affine warp are filled in the un-normalised value range.
        v2.RandomAffine(degrees=5, translate=(0.2, 0.2), scale=(0.5, 1.2)),
        v2.Normalize(_NORM_MEAN, _NORM_STD),
    ])


def get_dataset(val_split=0.2):
    """Load (downloading if necessary) the MNIST train and test datasets.

    Args:
        val_split: Accepted for interface compatibility. NOTE(review): this
            value is currently not used -- no validation subset is carved
            out of the training data here; callers must split it themselves.

    Returns:
        ``(train_dataset, test_dataset)``: the training split wrapped with
        ``train_transform()`` and the test split wrapped with ``transform()``.
    """
    root = str(data_path)
    train_dataset = torchvision.datasets.MNIST(
        root=root,
        train=True,
        transform=train_transform(),
        download=True,
    )
    test_dataset = torchvision.datasets.MNIST(
        root=root,
        train=False,
        transform=transform(),
        download=True,
    )
    return train_dataset, test_dataset