Spaces: Runtime error
| from pathlib import Path | |
| import pytest | |
| import torch | |
| from src.data.celeba_datamodule import MNISTDataModule | |
@pytest.mark.parametrize("batch_size", [32, 128])
def test_mnist_datamodule(batch_size: int) -> None:
    """Tests `MNISTDataModule` to verify that it can be downloaded correctly, that the necessary
    attributes were created (e.g., the dataloader objects), and that dtypes and batch sizes
    correctly match.

    :param batch_size: Batch size of the data to be loaded by the dataloader. Supplied by the
        ``pytest.mark.parametrize`` decorator; without it pytest would try to resolve
        ``batch_size`` as a fixture and error out before the test body runs.
    """
    # NOTE(review): MNISTDataModule is imported from `src.data.celeba_datamodule` at the top of
    # this file — confirm that module path is intended.
    data_dir = "data/"

    dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size)
    dm.prepare_data()

    # After prepare_data() the raw files must be on disk, but no split may be populated yet.
    assert not dm.data_train and not dm.data_val and not dm.data_test
    assert Path(data_dir, "MNIST").exists()
    assert Path(data_dir, "MNIST", "raw").exists()

    # setup() is what builds the train/val/test splits and makes the dataloaders usable.
    dm.setup()
    assert dm.data_train and dm.data_val and dm.data_test
    assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader()

    # The three splits together must cover the whole dataset (70,000 samples).
    num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test)
    assert num_datapoints == 70_000

    # A single training batch must honour the requested batch size and the expected dtypes.
    batch = next(iter(dm.train_dataloader()))
    x, y = batch
    assert len(x) == batch_size
    assert len(y) == batch_size
    assert x.dtype == torch.float32
    assert y.dtype == torch.int64