# Source: FlowProt/model/data_test.py (Hugging Face Space, uploaded by alibtsd)
# Commit f34af6f (verified): "Deploy FlowProt Docker Space" — 4.93 kB
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu
Dataloader from PDB files.
"""
import logging
import hydra
import numpy as np
import torch
import torch.utils
import torch.utils.data
import tree
from omegaconf import DictConfig
import pandas as pd
from openfold_data import data_transforms
import utils.openfold_rigid_utils as rigid_utils
from utils.pdbUtils import read_pkl, parse_chain_feats
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
def get_dataloaders(cfg):
    """Build train/validation DataLoaders from a CSV index of processed PDB files.

    The CSV at ``cfg.dataset.csv_path`` is filtered to rows whose
    ``modeled_seq_len`` lies in ``[min_num_res, max_num_res]`` and then split
    80/20 into training and validation sets.

    Args:
        cfg: Config with a ``dataset`` section (``csv_path``, ``min_num_res``,
            ``max_num_res``, ``seed`` and whatever ``PdbDataset`` reads) and a
            ``loader`` section (``num_workers``, ``batch_size``,
            ``prefetch_factor``).

    Returns:
        Tuple ``(train_loader, val_loader)`` of ``torch.utils.data.DataLoader``.
    """
    dataset_cfg = cfg.dataset
    loader_cfg = cfg.loader
    num_workers = loader_cfg.num_workers

    pdb_csv = pd.read_csv(dataset_cfg.csv_path)
    pdb_csv = pdb_csv[
        (pdb_csv.modeled_seq_len >= dataset_cfg.min_num_res)
        & (pdb_csv.modeled_seq_len <= dataset_cfg.max_num_res)
    ]
    # Class balance of the filtered data, as a quick sanity check.
    print(pdb_csv["class"].value_counts())

    # Seed the split so train/val membership is reproducible across runs;
    # without random_state a different split was drawn on every call.
    train_data, test_data = train_test_split(
        pdb_csv, test_size=0.2, shuffle=True, random_state=dataset_cfg.seed
    )
    train = PdbDataset(train_data, dataset_cfg, is_training=True)
    test = PdbDataset(test_data, dataset_cfg, is_training=False)

    train_loader = DataLoader(
        train,
        batch_size=loader_cfg.batch_size,
        # PdbDataset orders rows longest-first; shuffle so training does not
        # see the same length-sorted sequence every epoch.
        shuffle=True,
        num_workers=num_workers,
        # prefetch_factor must be None when loading in the main process.
        prefetch_factor=None if num_workers == 0 else loader_cfg.prefetch_factor,
        pin_memory=False,
        persistent_workers=num_workers > 0,
    )
    val_loader = DataLoader(
        test,
        shuffle=False,
        num_workers=2,
        prefetch_factor=2,
        persistent_workers=True,
    )
    return train_loader, val_loader
class PdbDataset(Dataset):
    """Map-style dataset over the rows of a PDB metadata DataFrame.

    Each row must provide ``modeled_seq_len`` (used for ordering), ``class``
    (the label) and ``processed_path`` (a pickled feature file that
    ``_process_csv_row`` turns into per-residue frame features).
    """

    def __init__(self, dataset, dataset_cfg, is_training):
        """Store rows and config, order rows by length and seed an RNG.

        Args:
            dataset: ``pandas.DataFrame`` with one row per example.
            dataset_cfg: Dataset config; only ``seed`` is read here.
            is_training: Whether this instance serves the training split;
                exposed through the ``is_training`` property.
        """
        self._is_training = is_training
        self._dataset_cfg = dataset_cfg
        self._log = logging.getLogger(__name__)
        self.pdb_csv = dataset
        self._init_metadata()
        self._rng = np.random.default_rng(seed=dataset_cfg.seed)

    @property
    def is_training(self):
        """Flag passed at construction marking the training split."""
        return self._is_training

    @property
    def dataset_cfg(self):
        """Dataset config passed at construction."""
        return self._dataset_cfg

    def _init_metadata(self):
        """Sort rows by ``modeled_seq_len`` (longest first) and log the count."""
        by_length = self.pdb_csv.sort_values('modeled_seq_len', ascending=False)
        # ``pdb_csv`` and ``csv`` both refer to the same sorted frame.
        self.pdb_csv = self.csv = by_length
        self._log.info(f'Dataset: {len(self.csv)} examples.')

    def _process_csv_row(self, processed_file_path):
        """Load one processed pickle and convert it into frame features.

        Every feature is cropped to the span from the first to the last
        modeled residue. Sequence and atom37 data are converted to torch
        tensors and passed through ``atom37_to_frames``. Rigid group 0 of each
        residue is then split into rotation matrices and translations.

        Returns:
            dict with keys ``aatype``, ``res_idx`` (re-based so the smallest
            index is 1; not converted to a tensor here), ``rotmats_1``,
            ``trans_1`` and ``res_mask``.
        """
        feats = parse_chain_feats(read_pkl(processed_file_path))

        # Crop every feature to [first modeled residue, last modeled residue].
        modeled = feats['modeled_idx']
        crop = slice(np.min(modeled), np.max(modeled) + 1)
        del feats['modeled_idx']
        feats = tree.map_structure(lambda arr: arr[crop], feats)

        frames = data_transforms.atom37_to_frames({
            'aatype': torch.tensor(feats['aatype']).long(),
            'all_atom_positions': torch.tensor(feats['atom_positions']).double(),
            'all_atom_mask': torch.tensor(feats['atom_mask']).double(),
        })
        # Keep rigid group 0 for every residue (presumably the backbone frame
        # under OpenFold's atom37_to_frames convention — confirm).
        backbone = rigid_utils.Rigid.from_tensor_4x4(frames['rigidgroups_gt_frames'])[:, 0]

        residue_index = feats['residue_index']
        return {
            'aatype': frames['aatype'],
            'res_idx': residue_index - np.min(residue_index) + 1,
            'rotmats_1': backbone.get_rots().get_rot_mats(),
            'trans_1': backbone.get_trans(),
            'res_mask': torch.tensor(feats['bb_mask']).int(),
        }

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        """Return features for row ``idx`` plus its row index and class label.

        ``csv_idx`` and ``class`` are 1-element tensors built by multiplying a
        long tensor of ones. NOTE(review): a float ``class`` value would
        promote the label tensor to a float dtype.
        """
        row = self.csv.iloc[idx]
        label = row["class"]
        sample = self._process_csv_row(row['processed_path'])
        sample['csv_idx'] = torch.ones(1, dtype=torch.long) * idx
        sample["class"] = torch.ones(1, dtype=torch.long) * label
        return sample
@hydra.main(version_base=None, config_path="./configs", config_name="classifier.yaml")
def my_app(cfg: DictConfig) -> None:
    """Smoke-test the data pipeline: build both loaders and print a sample.

    Prints the first training batch and the number of validation examples.
    """
    train, test = get_dataloaders(cfg.data)
    print(next(iter(train)))
    # A DataLoader has no ``.shape``; report the validation dataset size instead.
    print(len(test.dataset))


if __name__ == '__main__':
    my_app()