Spaces:
Sleeping
Sleeping
| # from lightning.pytorch.utilities.types import TRAIN_DATALOADERS,EVAL_DATALOADERS | |
| # def calculate_mean_std_mnist(datamodule:pl.LightningDataModule): | |
| # data_loader:TRAIN_DATALOADERS; | |
| # mean = torch.zeros(1); | |
| # std = torch.zeros(1) | |
| # num_samples = 0 | |
| # for img in data_loader: | |
| # image = img[0] | |
| # image = image.squeeze() | |
| # mean += image.mean() # mean across channel sum for all pics | |
| # std += image.std() | |
| # num_samples += 1 | |
| # mean /= num_samples | |
| # std /= num_samples | |
| # return (mean.item(),std.item()) | |
| from torchvision import transforms | |
# Per-channel normalization statistics shared by the train and test pipelines
# so the two can never drift apart. One entry per channel (single grayscale
# channel here); applied after ToTensor(), i.e. to pixel values in [0, 1].
# These are the commonly used MNIST mean / std values.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Training pipeline: light random geometric augmentation, then conversion to a
# tensor and normalization.
TRAIN_TRANSFORMS = transforms.Compose([
    # With probability 0.1, keep only the central 22x22 region (a zoom-in);
    # the Resize below brings the image back to a fixed size.
    transforms.RandomApply([transforms.CenterCrop(22)], p=0.1),
    # Random rotation in [-7, 7] degrees, x-shear in [-10, 10] degrees,
    # translation up to 10% of width/height, and scaling by 0.8x-1.2x.
    transforms.RandomAffine(degrees=7, shear=10, translate=(0.1, 0.1), scale=(0.8, 1.2)),
    # Force a 28x28 output whether or not the optional crop above fired.
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Evaluation pipeline: no augmentation and no Resize, only tensor conversion
# and the same normalization as training.
# NOTE(review): assumes eval images already arrive at 28x28 — confirm against
# the dataset this is paired with.
TEST_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])