Spaces:
Build error
| from pathlib import Path | |
| import torchvision | |
| from torchvision.transforms import v2 | |
| import torch | |
# Dataset root: the ``data`` folder two directories above the one that holds
# this file (e.g. /proj/pkg/sub/this.py -> /proj/data). Resolving first makes
# the location absolute and independent of the current working directory.
data_path = Path(__file__).resolve().parent.parent.parent.joinpath('data')
# NOTE: printed at import time (module-level side effect, kept as-is).
print(data_path)
| def transform(): | |
| return v2.Compose([ | |
| v2.ToImage(), | |
| v2.ToDtype(torch.float32, scale=True), | |
| v2.Pad(2), | |
| v2.Normalize((0.1307,), (0.3081,)) | |
| ]) | |
def add_noise(noise_factor=0.5, *, clip_range=(0.0, 1.0)):
    """Return a transform that adds Gaussian noise to a tensor image.

    The noise is ``torch.randn_like(x) * noise_factor``, i.e. zero-mean with
    standard deviation ``noise_factor``. A fresh sample is drawn on every call.

    Args:
        noise_factor: Standard deviation of the additive noise.
        clip_range: ``(low, high)`` bounds the noisy image is clamped to.
            The default ``(0.0, 1.0)`` matches images scaled to [0, 1] but not
            yet normalized. The ``transform()`` / ``train_transform()``
            pipelines in this module normalize with mean 0.1307 / std 0.3081,
            which maps [0, 1] to roughly [-0.42, 2.82]. If this noise is
            applied after ``Normalize``, pass matching bounds, or ``None`` to
            skip clamping. Either bound may individually be ``None``.

    Returns:
        A ``v2.Lambda`` wrapping the noise function.

    NOTE(review): ``randn_like`` needs a floating-point tensor, so this is
    presumably placed after ``ToDtype(torch.float32, ...)``. Confirm at the
    call site.
    """
    if clip_range is None:
        low = high = None
    else:
        low, high = clip_range

    def add_noise_to_image(x):
        noisy = x + torch.randn_like(x) * noise_factor
        # torch.clamp raises if both bounds are None, so skip it explicitly.
        if low is None and high is None:
            return noisy
        return torch.clamp(noisy, low, high)

    return v2.Lambda(add_noise_to_image)
| def train_transform(): | |
| return v2.Compose([ | |
| v2.ToImage(), | |
| v2.ToDtype(torch.float32, scale=True), | |
| v2.Pad(2), | |
| v2.RandomAffine(degrees=5, translate=(0.2, 0.2), scale=(0.5, 1.2)), | |
| v2.Normalize((0.1307,), (0.3081,)) | |
| ]) | |
def get_dataset(val_split=0.2, *, return_val=False, seed=0):
    """Load MNIST train/test datasets, optionally carving out a validation split.

    Both splits are stored under ``data_path`` and downloaded there if they
    are not already present (``download=True``). The training split uses the
    augmenting ``train_transform()``; the test split uses ``transform()``.

    Args:
        val_split: Fraction of the MNIST training split reserved for
            validation. Previously accepted but ignored; it is now used only
            when ``return_val=True``, so default behavior is unchanged.
        return_val: If True, also return a validation subset (see Returns).
        seed: Seed for the permutation that assigns training indices to the
            train/validation subsets, making the split reproducible.

    Returns:
        ``(train_dataset, test_dataset)`` when ``return_val`` is False
        (the original behavior). Otherwise ``(train_subset, val_subset,
        test_dataset)``, where the validation subset reads from a second view
        of the training split that uses the deterministic ``transform()``,
        so validation images are not randomly augmented.

    Raises:
        ValueError: if ``return_val`` is True and ``val_split`` does not leave
            at least one sample in each of the train and validation subsets.
    """
    train_dataset = torchvision.datasets.MNIST(
        root=str(data_path), train=True, transform=train_transform(), download=True
    )
    test_dataset = torchvision.datasets.MNIST(
        root=str(data_path), train=False, transform=transform(), download=True
    )
    if not return_val:
        return train_dataset, test_dataset

    from torch.utils.data import Subset

    n_total = len(train_dataset)
    n_val = round(n_total * val_split)
    if not 0 < n_val < n_total:
        raise ValueError(
            f"val_split={val_split!r} yields {n_val} validation samples out of {n_total}"
        )

    # Same underlying training images, but with the non-augmenting transform,
    # so validation metrics are not affected by random affine jitter.
    val_source = torchvision.datasets.MNIST(
        root=str(data_path), train=True, transform=transform(), download=True
    )

    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(n_total, generator=generator).tolist()
    val_indices, train_indices = order[:n_val], order[n_val:]
    return (
        Subset(train_dataset, train_indices),
        Subset(val_source, val_indices),
        test_dataset,
    )