Spaces:
Sleeping
Sleeping
| from typing import List | |
| import __main__ | |
| import rootutils | |
| import torch | |
| from torch_geometric.data import Dataset | |
| # setup root dir and pythonpath | |
| rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) | |
| from src.data.components.prepare_data import CropPairedPDB | |
| setattr(__main__, "CropPairedPDB", CropPairedPDB) | |
class PinderDataset(Dataset):
    """Pinder dataset backed by a list of pre-serialized sample files.

    Each entry of ``file_paths`` points at one file that ``torch.load`` can
    read; samples are loaded lazily, one file per ``get`` call.

    Args:
        Dataset: PyTorch Geometric Dataset.
    """

    def __init__(self, file_paths: List[str]) -> None:
        """Initialize the PinderDataset.

        Args:
            file_paths: List of file paths, one per sample.
        """
        super().__init__()
        self.file_paths = file_paths

    # Must be a property: the base ``Dataset`` declares it as one, and both
    # ``len`` and ``get`` below use it as an attribute (``len(...)`` and
    # ``[...]``). As a plain method those uses raise ``TypeError``.
    @property
    def processed_file_names(self) -> List[str]:
        """Return the processed file names.

        Returns:
            List[str]: List of processed file paths, exactly as passed to
            the constructor.
        """
        return self.file_paths

    def len(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: Number of file paths held by the dataset.
        """
        return len(self.processed_file_names)

    def get(self, idx: int) -> CropPairedPDB:
        """Load the sample stored at the given index.

        Args:
            idx: Index into the list of file paths.

        Returns:
            CropPairedPDB: The object deserialized from the file
            (annotated as a CropPairedPDB; the file content is not checked).
        """
        # NOTE(security): ``weights_only=False`` runs the full pickle
        # unpickler, which can execute arbitrary code. Only load files from
        # trusted sources.
        data = torch.load(self.processed_file_names[idx], weights_only=False)
        return data
if __name__ == "__main__":
    # Manual smoke test: build a one-sample dataset and print its first item.
    sample_paths = [
        "./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt",
    ]
    print(PinderDataset(file_paths=sample_paths)[0])