Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Any, Dict, Optional | |
| import pandas as pd | |
| import rootutils | |
| from lightning import LightningDataModule | |
| from torch_geometric.data import Dataset | |
| from torch_geometric.loader import DataLoader | |
| rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) | |
| from src.data.components.pinder_dataset import PinderDataset | |
class PINDERDataModule(LightningDataModule):
    """`LightningDataModule` for the PINDER dataset.

    File lists for the train/val/test splits are built eagerly in ``__init__`` from
    ``<data_dir>/metadata.csv``; the corresponding ``PinderDataset`` objects are created
    in :meth:`setup`.
    """

    def __init__(
        self,
        data_dir: str = "data/processed",
        predicted_structures: bool = False,
        high_quality: bool = False,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = True,
    ) -> None:
        """Initialize the `PINDERDataModule`.

        :param data_dir: Directory containing ``metadata.csv`` and the data files, laid out as
            ``<data_dir>/<complex>/<split>/<file>``. Defaults to ``"data/processed"``.
        :param predicted_structures: If ``True``, include ``"predicted"`` complexes in addition
            to ``"apo"`` ones. Defaults to ``False``.
        :param high_quality: Saved to ``self.hparams`` but not otherwise used by this class.
            Defaults to ``False``.
        :param batch_size: Global batch size; divided by the number of devices in
            :meth:`setup`. Defaults to ``1``.
        :param num_workers: Number of workers passed to each DataLoader. Defaults to ``0``.
        :param pin_memory: Whether the DataLoaders pin memory. Defaults to ``True``.
        """
        super().__init__()

        # Makes init params accessible via `self.hparams` and stores them in checkpoints.
        self.save_hyperparameters(logger=False)

        data_dir = self.hparams.data_dir
        metadata = pd.read_csv(os.path.join(data_dir, "metadata.csv"))
        complex_types = ["apo", "predicted"] if self.hparams.predicted_structures else ["apo"]

        self.train_files = self._get_split_files(metadata, data_dir, "train", complex_types)
        self.val_files = self._get_split_files(metadata, data_dir, "val", complex_types)
        self.test_files = self._get_split_files(metadata, data_dir, "test", complex_types)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        # Replaced by the per-device value in `setup()` once a trainer is attached.
        self.batch_size_per_device = batch_size

    @staticmethod
    def _get_split_files(
        metadata: pd.DataFrame, data_dir: str, split: str, complex_types: list
    ) -> list:
        """Return the data-file paths for `split` restricted to the given complex types.

        Each path is ``os.path.join(data_dir, complex, split, file_paths)`` built from the
        matching rows of `metadata` (columns ``split``, ``complex`` and ``file_paths``).
        Returns an empty list when no rows match.

        :param metadata: Table loaded from ``metadata.csv``.
        :param data_dir: Root directory the paths are built under.
        :param split: Split name to select, e.g. ``"train"``.
        :param complex_types: Values of the ``complex`` column to keep.
        :return: List of file paths, in metadata row order.
        """
        mask = (metadata["split"] == split) & (metadata["complex"].isin(complex_types))
        selected = metadata.loc[mask]
        # Build paths directly instead of writing a new column onto a filtered slice:
        # avoids pandas' SettingWithCopyWarning and works for empty selections.
        return [
            os.path.join(data_dir, complex_type, row_split, file_name)
            for complex_type, row_split, file_name in zip(
                selected["complex"], selected["split"], selected["file_paths"]
            )
        ]

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        :raises RuntimeError: If the global batch size is not divisible by the trainer's world size.
        """
        # Divide the global batch size evenly across devices.
        if self.trainer is not None:
            world_size = self.trainer.world_size
            if self.hparams.batch_size % world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // world_size

        # Build datasets only once; `setup()` may be called for several stages. Compare with
        # `None` explicitly so that an (empty) dataset is not mistaken for "not loaded".
        if self.data_train is None and self.data_val is None and self.data_test is None:
            self.data_train = PinderDataset(self.train_files)
            self.data_val = PinderDataset(self.val_files)
            self.data_test = PinderDataset(self.test_files)

    def _make_dataloader(
        self, dataset: Optional[Dataset], shuffle: bool, drop_last: bool = False
    ) -> DataLoader:
        """Wrap `dataset` in a DataLoader using this module's batch/worker/pin-memory settings.

        :param dataset: Dataset to wrap; must have been created by :meth:`setup`.
        :param shuffle: Whether to shuffle every epoch.
        :param drop_last: Whether to drop the last incomplete batch.
        :return: The configured DataLoader.
        :raises RuntimeError: If `dataset` is ``None`` (i.e. :meth:`setup` has not run).
        """
        if dataset is None:
            raise RuntimeError("Datasets are not initialized; call `setup()` first.")
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
            drop_last=drop_last,
        )

    def train_dataloader(self) -> DataLoader:
        """Create and return the train dataloader (shuffled, incomplete last batch dropped).

        :return: The train dataloader.
        """
        return self._make_dataloader(self.data_train, shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Create and return the validation dataloader (not shuffled).

        :return: The validation dataloader.
        """
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Create and return the test dataloader (not shuffled).

        :return: The test dataloader.
        """
        return self._make_dataloader(self.data_test, shuffle=False)

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`. Currently a no-op.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. This datamodule has no state to save.

        :return: An empty dictionary.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. This datamodule has no state, so this is a no-op.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass
if __name__ == "__main__":
    # Smoke test: build the datamodule with default settings and report batch counts.
    datamodule = PINDERDataModule()
    datamodule.setup()

    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()
    test_loader = datamodule.test_dataloader()

    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of validation batches: {len(val_loader)}")
    print(f"Number of test batches: {len(test_loader)}")

    # Default of `None` avoids a bare StopIteration when the train loader yields no batches.
    first_batch = next(iter(train_loader), None)
    print(first_batch if first_batch is not None else "Training loader is empty.")