"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu
Dataloader from PDB files.
"""
import logging
import hydra
import numpy as np
import torch
import torch.utils
import torch.utils.data
import tree
from omegaconf import DictConfig
import pandas as pd
from openfold_data import data_transforms
import utils.openfold_rigid_utils as rigid_utils
from utils.pdbUtils import read_pkl, parse_chain_feats
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
def get_dataloaders(cfg):
    """Build the training and validation dataloaders from a metadata CSV.

    The CSV at ``cfg.dataset.csv_path`` is filtered to rows whose
    ``modeled_seq_len`` lies within ``[min_num_res, max_num_res]``, split
    80/20 into train/validation rows, and each part is wrapped in a
    ``PdbDataset``.

    Args:
        cfg: Config node exposing ``dataset`` (``csv_path``, ``min_num_res``,
            ``max_num_res``, ``seed``) and ``loader`` (``num_workers``,
            ``batch_size``, ``prefetch_factor``).

    Returns:
        A ``(train_loader, val_loader)`` tuple of ``DataLoader`` objects.
    """
    dataset_cfg = cfg.dataset
    loader_cfg = cfg.loader
    num_workers = loader_cfg.num_workers

    pdb_csv = pd.read_csv(dataset_cfg.csv_path)
    seq_len = pdb_csv.modeled_seq_len
    in_range = (
        (seq_len <= dataset_cfg.max_num_res)
        & (seq_len >= dataset_cfg.min_num_res)
    )
    pdb_csv = pdb_csv[in_range]
    print(pdb_csv["class"].value_counts())

    # Seed the split so train/validation membership is reproducible across
    # runs; a seed of None falls back to the previous unseeded behaviour.
    train_data, test_data = train_test_split(
        pdb_csv,
        test_size=0.2,
        shuffle=True,
        random_state=dataset_cfg.seed,
    )

    train = PdbDataset(train_data, dataset_cfg, is_training=True)
    test = PdbDataset(test_data, dataset_cfg, is_training=False)

    # NOTE(review): no ``shuffle`` is requested here, so training iterates the
    # dataset in its stored order every epoch - confirm this is intended.
    train_loader = DataLoader(
        train,
        batch_size=loader_cfg.batch_size,
        num_workers=num_workers,
        # prefetch_factor is only meaningful with worker processes.
        prefetch_factor=None if num_workers == 0 else loader_cfg.prefetch_factor,
        pin_memory=False,
        persistent_workers=num_workers > 0,
    )

    # Validation keeps its cap of two workers but honours a smaller configured
    # value, so ``num_workers=0`` really runs everything in the main process.
    val_workers = min(2, num_workers)
    val_loader = DataLoader(
        test,
        shuffle=False,
        num_workers=val_workers,
        prefetch_factor=2 if val_workers > 0 else None,
        persistent_workers=val_workers > 0,
    )
    return train_loader, val_loader
class PdbDataset(Dataset):
    """Map-style dataset that turns one metadata row into frame features.

    Each row of the supplied table must provide ``modeled_seq_len``,
    ``processed_path`` and ``class`` columns; the pickled features found at
    ``processed_path`` are cropped and converted into per-residue rotations
    and translations.
    """

    def __init__(
        self,
        dataset,
        dataset_cfg,
        is_training
    ):
        """Store the metadata table and config, then sort the table.

        Args:
            dataset: Table of examples (indexed with ``sort_values``/``iloc``).
            dataset_cfg: Config object; only ``seed`` is read in this class.
            is_training: Flag exposed through the ``is_training`` property.
        """
        self._dataset_cfg = dataset_cfg
        self._is_training = is_training
        self._log = logging.getLogger(__name__)
        self.pdb_csv = dataset
        self._init_metadata()
        self._rng = np.random.default_rng(seed=self._dataset_cfg.seed)

    @property
    def is_training(self):
        """Whether this dataset was created for training."""
        return self._is_training

    @property
    def dataset_cfg(self):
        """The dataset config handed to the constructor."""
        return self._dataset_cfg

    def _init_metadata(self):
        """Sort the table by ``modeled_seq_len`` (longest first) and log its size."""
        ordered = self.pdb_csv.sort_values('modeled_seq_len', ascending=False)
        self.pdb_csv = ordered
        self.csv = ordered
        self._log.info(
            f'Dataset: {len(self.csv)} examples.'
        )

    def _process_csv_row(self, processed_file_path):
        """Load one pickled example and convert it to model inputs.

        Args:
            processed_file_path: Path given to ``read_pkl``.

        Returns:
            Dict with keys ``aatype``, ``res_idx``, ``rotmats_1``,
            ``trans_1`` and ``res_mask``.
        """
        feats = parse_chain_feats(read_pkl(processed_file_path))

        # Crop every feature to the inclusive span covered by the modeled indices.
        modeled_idx = feats['modeled_idx']
        first = np.min(modeled_idx)
        last = np.max(modeled_idx)
        del feats['modeled_idx']
        span = slice(first, last + 1)
        feats = tree.map_structure(lambda arr: arr[span], feats)

        frame_inputs = {
            'aatype': torch.tensor(feats['aatype']).long(),
            'all_atom_positions': torch.tensor(feats['atom_positions']).double(),
            'all_atom_mask': torch.tensor(feats['atom_mask']).double()
        }
        frame_inputs = data_transforms.atom37_to_frames(frame_inputs)

        # Index 0 along the second axis is kept - presumably the backbone
        # frame of each residue; confirm against ``atom37_to_frames``.
        frames = rigid_utils.Rigid.from_tensor_4x4(
            frame_inputs['rigidgroups_gt_frames']
        )[:, 0]
        rotations = frames.get_rots().get_rot_mats()
        translations = frames.get_trans()

        residue_index = feats['residue_index']
        return {
            'aatype': frame_inputs['aatype'],
            # Re-base residue numbering so the smallest index becomes 1.
            'res_idx': residue_index - np.min(residue_index) + 1,
            'rotmats_1': rotations,
            'trans_1': translations,
            'res_mask': torch.tensor(feats['bb_mask']).int(),
        }

    def __len__(self):
        """Number of rows in the sorted metadata table."""
        return len(self.csv)

    def __getitem__(self, idx):
        """Return the features of row ``idx`` plus its index and class label."""
        row = self.csv.iloc[idx]
        label = row["class"]
        features = self._process_csv_row(row['processed_path'])
        features['csv_idx'] = torch.ones(1, dtype=torch.long) * idx
        features["class"] = torch.ones(1, dtype=torch.long) * label
        return features
@hydra.main(version_base=None, config_path="./configs", config_name="classifier.yaml")
def my_app(cfg: DictConfig) -> None:
    """Smoke-test the data pipeline by building both loaders and printing a sample.

    Args:
        cfg: Hydra-composed config; ``cfg.data`` is forwarded to
            ``get_dataloaders``.
    """
    train_loader, val_loader = get_dataloaders(cfg.data)
    print(next(iter(train_loader)))
    # A DataLoader has no ``shape`` attribute (the old call raised
    # AttributeError); report its length, i.e. the number of batches, instead.
    print(len(val_loader))
# Script entry point: run the Hydra-wrapped smoke test only on direct execution.
if __name__ == '__main__':
    my_app()