| from types import SimpleNamespace | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision.datasets import ImageFolder | |
| from torchvision.transforms import v2 | |
| DATASET_MEAN = 0.5077385902404785 | |
| DATASET_STD = 0.255077600479126 | |
| class PreprocessedImageFolder(ImageFolder): | |
| def __init__(self, root, transform=None): | |
| super().__init__(root, transform=transform) | |
| self.preprocess = v2.Compose( | |
| [ | |
| v2.Grayscale(), | |
| v2.PILToTensor(), | |
| v2.ToDtype(torch.float32, scale=True), | |
| v2.Normalize(mean=(DATASET_MEAN,), std=(DATASET_STD,)), | |
| ] | |
| ) | |
| processed_samples = [] | |
| for path, target in self.samples: | |
| sample = self.loader(path) | |
| processed_sample = self.preprocess(sample) | |
| processed_samples.append((processed_sample, target)) | |
| self.samples = processed_samples | |
| def __getitem__(self, index): | |
| sample, target = self.samples[index] | |
| if self.transform is not None: | |
| sample = self.transform(sample) | |
| return sample, target | |
| augmentations = v2.Compose( | |
| [ | |
| v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), | |
| v2.RandomHorizontalFlip(), | |
| v2.RandomResizedCrop(size=48, scale=(0.9, 1.1), antialias=True), | |
| v2.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1)), | |
| v2.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)), | |
| ] | |
| ) | |
| def make_dls(train_ds, valid_ds, batch_size=64, num_workers=2): | |
| train_dl = DataLoader( | |
| train_ds, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| num_workers=num_workers, | |
| pin_memory=True, | |
| persistent_workers=True, | |
| ) | |
| valid_dl = DataLoader( | |
| valid_ds, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| num_workers=num_workers, | |
| pin_memory=True, | |
| persistent_workers=True, | |
| ) | |
| dls = SimpleNamespace(**{"train": train_dl, "valid": valid_dl}) | |
| return dls | |