Spaces:
Runtime error
Runtime error
| from loguru import logger | |
| import torch | |
| from torch.utils.data import DataLoader, Subset | |
| from torchvision import datasets, transforms | |
| import lightning as pl | |
| from typing import Optional | |
| from multiprocessing import cpu_count | |
| from sklearn.model_selection import train_test_split | |
# Configure Loguru to write this module's logs under logs/; the file rotates
# once it reaches 1 MB, and only INFO-and-above records are kept.
logger.add("logs/dataloader.log", rotation="1 MB", level="INFO")
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module for MNIST with an optional training subset.

    Downloads MNIST into ``data_dir``, normalizes images from [0, 1] to
    [-1, 1], and serves train/val/test dataloaders. Only
    ``train_subset_fraction`` of the official training split is used,
    sampled reproducibly with a fixed seed.

    NOTE(review): the validation and test dataloaders both serve the official
    MNIST *test* split — confirm this is intentional before reporting results,
    since validation metrics are then computed on the test set.
    """

    def __init__(
        self,
        batch_size: int = 64,
        data_dir: str = "./data",
        num_workers: int = int(cpu_count()),
        train_subset_fraction: float = 0.25,  # Fraction of training data to use
    ):
        """
        Initializes the MNIST Data Module with configurations for dataloaders.

        Args:
            batch_size (int): Batch size for training, validation, and testing.
            data_dir (str): Directory to download and store the dataset.
            num_workers (int): Number of workers for data loading.
            train_subset_fraction (float): Fraction of training data to use
                (0.0 < fraction <= 1.0).

        Raises:
            ValueError: If ``train_subset_fraction`` is outside (0.0, 1.0].
        """
        super().__init__()
        # Fail fast on a bad fraction instead of letting train_test_split
        # raise a confusing error much later, inside setup().
        if not 0.0 < train_subset_fraction <= 1.0:
            raise ValueError(
                f"train_subset_fraction must be in (0.0, 1.0], "
                f"got {train_subset_fraction}"
            )
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.train_subset_fraction = train_subset_fraction
        # Normalize(0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1].
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        logger.info(f"MNIST DataModule initialized with batch size {self.batch_size}")

    def prepare_data(self):
        """
        Downloads the MNIST dataset if not already downloaded.

        Called once per node by Lightning; setup() then loads without
        downloading.
        """
        datasets.MNIST(root=self.data_dir, train=True, download=True)
        datasets.MNIST(root=self.data_dir, train=False, download=True)
        logger.info("MNIST dataset downloaded.")

    def setup(self, stage: Optional[str] = None):
        """
        Set up the dataset for different stages.

        Args:
            stage (str, optional): One of "fit", "validate", "test", or
                "predict". ``None`` sets up both fit and test datasets.
        """
        logger.info(f"Setting up data for stage: {stage}")
        if stage == "fit" or stage is None:
            full_train_dataset = datasets.MNIST(
                root=self.data_dir, train=True, transform=self.transform
            )
            if self.train_subset_fraction >= 1.0:
                # Bug fix: train_test_split rejects train_size=1.0 even though
                # the docstring allows it — use the full training set directly.
                self.mnist_train = full_train_dataset
            else:
                train_indices, _ = train_test_split(
                    range(len(full_train_dataset)),
                    train_size=self.train_subset_fraction,
                    random_state=42,  # fixed seed for a reproducible subset
                )
                self.mnist_train = Subset(full_train_dataset, train_indices)
            self.mnist_val = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded training subset: {len(self.mnist_train)} samples.")
            logger.info(f"Loaded validation data: {len(self.mnist_val)} samples.")
        if stage == "test" or stage is None:
            self.mnist_test = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded test data: {len(self.mnist_test)} samples.")

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training DataLoader (shuffled each epoch).

        Returns:
            DataLoader: Training data loader.
        """
        logger.info("Creating training DataLoader...")
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation DataLoader (deterministic order).

        Returns:
            DataLoader: Validation data loader.
        """
        logger.info("Creating validation DataLoader...")
        return DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """
        Returns the test DataLoader (deterministic order).

        Returns:
            DataLoader: Test data loader.
        """
        logger.info("Creating test DataLoader...")
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )