Spaces:
Runtime error
Runtime error
from abc import abstractmethod, ABCMeta
from argparse import ArgumentParser
from typing import Callable, Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import (
    DataLoader,
    Dataset,
    default_collate,
    RandomSampler,
    SequentialSampler,
)
from torchvision import transforms

from .transformations import AddGaussianNoise
class ImageDataModule(LightningDataModule, metaclass=ABCMeta):
    """Abstract PyTorch Lightning DataModule for image datasets.

    The base class stores the common hyperparameters, builds the transform
    pipeline (optional feature extractor followed by optional augmentations)
    and provides the train/val/test dataloaders. Child classes are expected
    to override ``prepare_data`` and ``setup`` and to assign ``train_data``,
    ``val_data`` and ``test_data`` before the dataloaders are requested.
    """

    # Declare variables that will be initialized later
    train_data: Dataset
    val_data: Dataset
    test_data: Dataset

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Register the data-module command line options on ``parent_parser``.

        Declared as a static method because it takes no ``self``; without the
        decorator, calling it on an instance would pass the instance in place
        of the parser.

        Args:
            parent_parser (ArgumentParser): parser to extend with a
                "Data Modules" argument group

        Returns:
            ArgumentParser: the same ``parent_parser``, now carrying the
            additional arguments
        """
        parser = parent_parser.add_argument_group("Data Modules")
        parser.add_argument(
            "--data_dir",
            type=str,
            default="data/",
            help="The directory where the data is stored.",
        )
        parser.add_argument(
            "--batch_size",
            type=int,
            default=32,
            help="The batch size to use.",
        )
        parser.add_argument(
            "--add_noise",
            action="store_true",
            help="Use gaussian noise augmentation.",
        )
        parser.add_argument(
            "--add_rotation",
            action="store_true",
            help="Use rotation augmentation.",
        )
        parser.add_argument(
            "--add_blur",
            action="store_true",
            help="Use blur augmentation.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=4,
            help="Number of workers to use for data loading.",
        )
        return parent_parser

    def __init__(
        self,
        feature_extractor: Optional[Callable] = None,
        data_dir: str = "data/",
        batch_size: int = 32,
        add_noise: bool = False,
        add_rotation: bool = False,
        add_blur: bool = False,
        num_workers: int = 4,
    ):
        """Store the hyperparameters and build the transform pipeline.

        Args:
            feature_extractor (callable): feature extractor instance; when
                ``None`` no feature-extraction step is added to the transforms
            data_dir (str): directory to store the dataset
            batch_size (int): batch size for the train/val/test dataloaders
            add_noise (bool): whether to add noise to the images
            add_rotation (bool): whether to add random rotation to the images
            add_blur (bool): whether to add blur to the images
            num_workers (int): number of workers for train/val/test dataloaders
        """
        super().__init__()

        # Store hyperparameters
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.feature_extractor = feature_extractor
        self.num_workers = num_workers

        # Set the transforms
        # If the feature_extractor is None, then we do not split the images into features.
        # Compare against None explicitly: an extractor object that happens to be
        # falsy (e.g. defines __len__ or __bool__) must still be used.
        init_transforms = [feature_extractor] if feature_extractor is not None else []
        self.transform = transforms.Compose(init_transforms)
        self._add_transforms(add_noise, add_rotation, add_blur)

        # Set the collate function and the samplers
        # These can be adapted in a child datamodule class to have a different behavior
        self.collate_fn = default_collate
        self.shuffled_sampler = RandomSampler
        self.sequential_sampler = SequentialSampler

    def _add_transforms(self, noise: bool, rotation: bool, blur: bool):
        """Add transforms to the module's transformations list.

        The enabled augmentations are appended to ``self.transform`` in the
        fixed order noise -> rotation -> blur, after any feature extractor.

        Args:
            noise (bool): whether to add noise to the images
            rotation (bool): whether to add random rotation to the images
            blur (bool): whether to add blur to the images
        """
        # TODO:
        # - Which order to add the transforms in?
        # - Applied in both train and test or just test?
        # - Check what transforms are applied by the model
        pipeline = self.transform.transforms
        if noise:
            # NOTE(review): arguments are presumably (mean, std); confirm in .transformations
            pipeline.append(AddGaussianNoise(0.0, 1.0))
        if rotation:
            pipeline.append(transforms.RandomRotation(20))
        if blur:
            pipeline.append(transforms.GaussianBlur(3))

    def prepare_data(self):
        """Hook to be overridden by child classes; the base version always raises.

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def setup(self, stage: Optional[str] = None):
        """Hook to be overridden by child classes; the base version always raises.

        Args:
            stage (Optional[str]): stage identifier passed by the caller;
                unused in this base class

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    # noinspection PyTypeChecker
    def _build_dataloader(self, dataset: Dataset, sampler_cls) -> DataLoader:
        """Create a DataLoader over ``dataset`` using the module's shared settings.

        Args:
            dataset (Dataset): dataset to iterate over
            sampler_cls: callable that receives the dataset and returns the
                sampler to use

        Returns:
            DataLoader: loader configured with the module's batch size,
            worker count and collate function
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            sampler=sampler_cls(dataset),
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training dataloader, sampled with ``self.shuffled_sampler``."""
        return self._build_dataloader(self.train_data, self.shuffled_sampler)

    def val_dataloader(self) -> DataLoader:
        """Return the validation dataloader, sampled with ``self.sequential_sampler``."""
        return self._build_dataloader(self.val_data, self.sequential_sampler)

    def test_dataloader(self) -> DataLoader:
        """Return the test dataloader, sampled with ``self.sequential_sampler``."""
        return self._build_dataloader(self.test_data, self.sequential_sampler)