import inspect
import os.path as osp

import torch
from torch.utils.data import DataLoader

from src.interface.data_interface import DInterface_base
from src.tools.utils import load_yaml_config
class MyDataLoader(DataLoader):
    """DataLoader that moves each batch to the GPU on a side CUDA stream."""

    def __init__(self, dataset, model_name, batch_size=64, num_workers=8, *args, **kwargs):
        # Unpack *args before the keyword arguments so positional extras
        # cannot collide with batch_size/num_workers.
        super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs)
        self.pretrain_device = 'cuda:0'  # overwritten per rank in __iter__
        self.model_name = model_name
    def __iter__(self):
        for batch in super().__iter__():
            # Move each batch to the GPU of the current distributed rank,
            # falling back to cuda:0 for single-process runs.
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                self.pretrain_device = f'cuda:{torch.distributed.get_rank()}'
            else:
                self.pretrain_device = 'cuda:0'
            # Copy on a side stream so the host-to-device transfer can overlap
            # with compute running on the default stream.
            stream = torch.cuda.Stream(self.pretrain_device)
            with torch.cuda.stream(stream):
                if self.model_name == 'GVP':
                    # GVP batches are graph objects with their own .cuda().
                    batch = batch.cuda(non_blocking=True, device=self.pretrain_device)
                else:
                    # Dict-style batches: move every tensor field to the GPU.
                    for key, val in batch.items():
                        if isinstance(val, torch.Tensor):
                            batch[key] = val.cuda(non_blocking=True, device=self.pretrain_device)
            # Make the default stream wait for the async copies before the
            # batch is consumed, otherwise the transfer may race with compute.
            torch.cuda.current_stream(self.pretrain_device).wait_stream(stream)
            yield batch
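
# Usage sketch (illustrative only; `my_dataset`, `my_collate`, and `model` are
# placeholder names, and a CUDA device is assumed to be available):
#
#   loader = MyDataLoader(my_dataset, model_name='PiFold', batch_size=8,
#                         num_workers=0, shuffle=True, collate_fn=my_collate)
#   for batch in loader:
#       out = model(batch)  # tensors in `batch` are already on the rank's GPU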
class DInterface(DInterface_base):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.load_data_module()
    def setup(self, stage=None):
        from src.datasets.featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
                                             featurize_ProteinMPNN, featurize_Inversefolding)
        # Select the collate function matching the model's expected batch format.
        if self.hparams.model_name in ['AlphaDesign', 'PiFold', 'KWDesign', 'GraphTrans', 'StructGNN', 'GCA', 'E3PiFold']:
            self.collate_fn = featurize_GTrans
        elif self.hparams.model_name == 'GVP':
            featurizer = featurize_GVP()
            self.collate_fn = featurizer.collate
        elif self.hparams.model_name == 'ProteinMPNN':
            self.collate_fn = featurize_ProteinMPNN
        elif self.hparams.model_name == 'ESMIF':
            self.collate_fn = featurize_Inversefolding
        else:
            raise ValueError(f'No featurizer registered for model {self.hparams.model_name!r}')
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.trainset = self.instancialize(split='train')
            self.valset = self.instancialize(split='valid')
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.testset = self.instancialize(split='test')
        if stage in ['predict', 'eval']:
            self.predictset = self.instancialize(split='predict')
    def _make_loader(self, dataset, shuffle, **extra):
        # Shared construction for all four loaders; only the dataset, shuffle
        # flag, and training-only extras (e.g. prefetch_factor) differ.
        return MyDataLoader(dataset, model_name=self.hparams.model_name,
                            batch_size=self.batch_size, num_workers=self.hparams.num_workers,
                            shuffle=shuffle, pin_memory=True, collate_fn=self.collate_fn, **extra)

    def train_dataloader(self):
        return self._make_loader(self.trainset, shuffle=True, prefetch_factor=8)

    def val_dataloader(self):
        return self._make_loader(self.valset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.testset, shuffle=False)

    def predict_dataloader(self):
        return self._make_loader(self.predictset, shuffle=False)
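
    # Sketch of how this datamodule is typically consumed, assuming
    # DInterface_base is a LightningDataModule (`model` and the hparams
    # values below are placeholders):
    #
    #   dm = DInterface(dataset='CATH4.3', model_name='PiFold', ...)
    #   trainer = pl.Trainer(devices=1)
    #   trainer.fit(model, datamodule=dm)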
    def load_data_module(self):
        name = self.hparams.dataset
        if name == 'AF2DB':
            from src.datasets.AF2DB_dataset_lmdb import Af2dbDataset
            self.data_module = Af2dbDataset
        elif name == 'TS':
            from src.datasets.ts_dataset import TSDataset
            self.data_module = TSDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'ts')
        elif name == 'CASP15':
            from src.datasets.casp_dataset import CASPDataset
            self.data_module = CASPDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'casp15')
        elif name == 'CATH4.2':
            from src.datasets.cath_dataset import CATHDataset
            self.data_module = CATHDataset
            self.hparams['version'] = 4.2
            self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.2')
        elif name == 'CATH4.3':
            from src.datasets.cath_dataset import CATHDataset
            self.data_module = CATHDataset
            self.hparams['version'] = 4.3
            self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.3')
        elif name == 'MPNN':
            from src.datasets.mpnn_dataset import MPNNDataset
            self.data_module = MPNNDataset
        elif name == 'FOLDSWITCHERS_1':
            from src.datasets.foldswitchers_dataset import FoldswitchersDataset
            self.data_module = FoldswitchersDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'fold_switchers/fold_1')
        elif name == 'FOLDSWITCHERS_2':
            from src.datasets.foldswitchers_dataset import FoldswitchersDataset
            self.data_module = FoldswitchersDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'fold_switchers/fold_2')
        elif name == 'PDBInference':
            from src.datasets.pdb_inference import PDBInference
            self.data_module = PDBInference
            # osp.join with a single argument is a no-op, so use the path directly.
            self.hparams['path'] = self.hparams.infer_path
        elif name == 'ATLAS_DIST_1':
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'atlas/distant-frame-pairs_NO_SUPERPOSITION/frames_1')
        elif name == 'ATLAS_DIST_2':
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'atlas/distant-frame-pairs_NO_SUPERPOSITION/frames_2')
        elif name == 'ATLAS_CLUSTER_1':
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'atlas/cluster-representatives/frames_1')
        elif name == 'ATLAS_CLUSTER_2':
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'atlas/cluster-representatives/frames_2')
        elif name == 'ATLAS_PDB':
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, '../atlas_pdb_inference/')
        elif name == 'ATLAS_FULL_MINIMIZED':
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, '../atlas_eval_proteinmpnn/atlas_full/minimized_PDBs/pdbs/')
        elif name == 'ATLAS_FULL_REFOLDED':
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, '../atlas_eval_proteinmpnn/atlas_full/refolded_PDBs/pdbs/')
        elif name == 'ATLAS_FULL_CRYSTAL':
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, '../atlas_eval_proteinmpnn/atlas_full/crystal_PDBs/pdbs/')
        elif name == 'FLEX_CATH4.3':
            from src.datasets.flex_cath_dataset import FlexCATHDataset
            self.data_module = FlexCATHDataset
            self.hparams['version'] = 4.3
            self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.3')
        else:
            raise ValueError(f'Unknown dataset {name!r}')
    def instancialize(self, **other_args):
        """Instantiate the dataset class using the matching parameters
        from the self.hparams dictionary. Keyword arguments passed here
        override the corresponding hparams values.
        """
        # Collect only the hparams that the dataset's __init__ actually accepts.
        class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        inkeys = self.hparams.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = self.hparams[arg]
        args1.update(other_args)
        if self.hparams['use_dynamics']:
            args1['data_jsonl_name'] = load_yaml_config('configs/ANMAwareFlexibilityProtTrans.yaml')['data_jsonl_name']
        # Calls the dataset class's __init__ with the collected arguments.
        return self.data_module(**args1)
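
    # Example of the signature-filtering idiom above (illustrative only;
    # ToyDataset and the hparams values are hypothetical):
    #
    #   class ToyDataset:
    #       def __init__(self, path, split, max_len=500): ...
    #
    #   # With hparams = {'path': 'data/', 'max_len': 1024, 'lr': 1e-3},
    #   # instancialize(split='train') picks up 'path' and 'max_len' from
    #   # hparams (ignoring 'lr', which ToyDataset does not accept), applies
    #   # the split='train' override, and ends up calling
    #   # ToyDataset(path='data/', max_len=1024, split='train').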