"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Dataloader from PDB files.
"""
import logging
import hydra
import numpy as np
import torch
import torch.utils
import torch.utils.data
import tree
from omegaconf import DictConfig
import pandas as pd
from openfold_data import data_transforms
import utils.openfold_rigid_utils as rigid_utils
from utils.pdbUtils import read_pkl, parse_chain_feats
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split


def get_dataloaders(cfg):
    """Build the training and validation loaders from the metadata CSV.

    The CSV at ``cfg.dataset.csv_path`` is filtered to rows whose
    ``modeled_seq_len`` lies in ``[min_num_res, max_num_res]`` and then split
    80/20 into train and validation subsets.

    Args:
        cfg: Config node exposing ``dataset`` (``csv_path``, ``max_num_res``,
            ``min_num_res``, ``seed``) and ``loader`` (``num_workers``,
            ``batch_size``, ``prefetch_factor``).

    Returns:
        A ``(train_loader, val_loader)`` tuple of ``DataLoader`` objects.
    """
    dataset_cfg = cfg.dataset
    loader_cfg = cfg.loader
    num_workers = loader_cfg.num_workers

    pdb_csv = pd.read_csv(dataset_cfg.csv_path)
    pdb_csv = pdb_csv[pdb_csv.modeled_seq_len <= dataset_cfg.max_num_res]
    pdb_csv = pdb_csv[pdb_csv.modeled_seq_len >= dataset_cfg.min_num_res]
    print(pdb_csv["class"].value_counts())

    # Seed the split so train/validation membership is identical across runs;
    # an unseeded split makes validation metrics incomparable between runs and
    # leaks validation rows into training when a run is resumed.
    train_data, test_data = train_test_split(
        pdb_csv,
        test_size=0.2,
        shuffle=True,
        random_state=dataset_cfg.seed,
    )

    train = PdbDataset(train_data, dataset_cfg, is_training=True)
    test = PdbDataset(test_data, dataset_cfg, is_training=False)

    # NOTE(review): no `shuffle=True` here, so batches are drawn in the
    # dataset's length-sorted order every epoch -- confirm this is intended.
    train_loader = DataLoader(
        train,
        batch_size=loader_cfg.batch_size,
        num_workers=num_workers,
        # prefetch_factor must be None when loading in the main process.
        prefetch_factor=None if num_workers == 0 else loader_cfg.prefetch_factor,
        pin_memory=False,
        persistent_workers=num_workers > 0,
    )
    # NOTE(review): the validation loader uses the DataLoader default batch
    # size and hard-coded worker settings rather than `loader_cfg`.
    val_loader = DataLoader(
        test,
        shuffle=False,
        num_workers=2,
        prefetch_factor=2,
        persistent_workers=True,
    )
    return train_loader, val_loader


class PdbDataset(Dataset):
    """Map-style dataset yielding per-chain features from processed pickles.

    Each row of the metadata frame must provide ``modeled_seq_len``,
    ``processed_path`` and ``class`` columns.
    """

    def __init__(self, dataset, dataset_cfg, is_training):
        """Store the metadata frame and config, then sort the metadata.

        Args:
            dataset: ``pandas.DataFrame`` of metadata rows for this split.
            dataset_cfg: Dataset config; ``seed`` is read here.
            is_training: Flag exposed through the ``is_training`` property.
        """
        self.pdb_csv = dataset
        self._log = logging.getLogger(__name__)
        self._is_training = is_training
        self._dataset_cfg = dataset_cfg
        self._init_metadata()
        self._rng = np.random.default_rng(seed=self._dataset_cfg.seed)

    @property
    def is_training(self):
        """Return the training flag passed at construction."""
        return self._is_training

    @property
    def dataset_cfg(self):
        """Return the dataset config passed at construction."""
        return self._dataset_cfg

    def _init_metadata(self):
        """Sort the metadata by ``modeled_seq_len`` (longest first)."""
        self.pdb_csv = self.pdb_csv.sort_values('modeled_seq_len', ascending=False)
        self.csv = self.pdb_csv
        self._log.info(f'Dataset: {len(self.csv)} examples.')

    def _process_csv_row(self, processed_file_path):
        """Load one processed pickle and convert it to model features.

        Args:
            processed_file_path: Path handed to ``read_pkl``.

        Returns:
            Dict with keys ``aatype``, ``res_idx`` (re-based to start at 1),
            ``rotmats_1``, ``trans_1`` and ``res_mask``.
        """
        processed_features = read_pkl(processed_file_path)
        processed_features = parse_chain_feats(processed_features)

        # Crop every feature to the span between the first and last modeled
        # positions.
        modeled_idx = processed_features['modeled_idx']
        min_idx, max_idx = np.min(modeled_idx), np.max(modeled_idx)
        del processed_features['modeled_idx']
        # assumes every remaining feature is sliceable along its first axis
        # -- TODO confirm against parse_chain_feats
        processed_features = tree.map_structure(
            lambda x: x[min_idx:(max_idx + 1)], processed_features
        )

        chain_features = {
            'aatype': torch.tensor(processed_features['aatype']).long(),
            'all_atom_positions': torch.tensor(
                processed_features['atom_positions']
            ).double(),
            'all_atom_mask': torch.tensor(processed_features['atom_mask']).double(),
        }
        chain_features = data_transforms.atom37_to_frames(chain_features)

        # Keep only the frame at index 0 of the second axis
        # (presumably the backbone frame -- verify against atom37_to_frames).
        rigids_1 = rigid_utils.Rigid.from_tensor_4x4(
            chain_features['rigidgroups_gt_frames']
        )[:, 0]
        rotmats_1 = rigids_1.get_rots().get_rot_mats()
        trans_1 = rigids_1.get_trans()

        res_idx = processed_features['residue_index']
        return {
            'aatype': chain_features['aatype'],
            'res_idx': res_idx - np.min(res_idx) + 1,
            'rotmats_1': rotmats_1,
            'trans_1': trans_1,
            'res_mask': torch.tensor(processed_features['bb_mask']).int(),
        }

    def __len__(self):
        """Return the number of metadata rows."""
        return len(self.csv)

    def __getitem__(self, idx):
        """Return the features of the ``idx``-th row of the sorted metadata.

        Besides the keys produced by ``_process_csv_row``, the result carries
        ``csv_idx`` and ``class`` as one-element long tensors.
        """
        example_idx = idx
        csv_row = self.csv.iloc[example_idx]
        class_idx = csv_row["class"]
        processed_file_path = csv_row['processed_path']

        chain_features = self._process_csv_row(processed_file_path)
        chain_features['csv_idx'] = torch.ones(1, dtype=torch.long) * idx
        chain_features["class"] = torch.ones(1, dtype=torch.long) * class_idx
        return chain_features


@hydra.main(version_base=None, config_path="./configs", config_name="classifier.yaml")
def my_app(cfg: DictConfig) -> None:
    """Smoke-test the loaders: print one training batch and the validation size."""
    train, test = get_dataloaders(cfg.data)
    print(next(iter(train)))
    # A DataLoader has no `.shape`; report the number of validation examples.
    print(len(test.dataset))


if __name__ == '__main__':
    my_app()