Spaces:
Running
on
Zero
Running
on
Zero
| import inspect | |
| from torch.utils.data import DataLoader | |
| from src.interface.data_interface import DInterface_base | |
| import torch | |
| import os.path as osp | |
| from src.tools.utils import cuda | |
| import pdb | |
| from src.tools.utils import load_yaml_config | |
class MyDataLoader(DataLoader):
    """DataLoader that moves each batch onto the local CUDA device.

    Transfers happen with ``non_blocking=True`` inside a dedicated CUDA
    stream so host-to-device copies can overlap with compute on the
    default stream. The target device is resolved per-iteration from the
    distributed rank when a process group is initialized.
    """

    def __init__(self, dataset, model_name, batch_size=64, num_workers=8, *args, **kwargs):
        super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, *args, **kwargs)
        # Default device; refined per-rank in __iter__ once torch.distributed is up.
        self.pretrain_device = 'cuda:0'
        self.model_name = model_name

    def __iter__(self):
        for batch in super().__iter__():
            # get_rank() raises when no process group is initialized
            # (RuntimeError, or ValueError on older torch versions);
            # fall back to the first GPU in the single-process case.
            try:
                self.pretrain_device = f'cuda:{torch.distributed.get_rank()}'
            except (RuntimeError, ValueError):
                self.pretrain_device = 'cuda:0'
            stream = torch.cuda.Stream(self.pretrain_device)
            with torch.cuda.stream(stream):
                if self.model_name == 'GVP':
                    # GVP batches are graph objects exposing a .cuda() method.
                    batch = batch.cuda(non_blocking=True, device=self.pretrain_device)
                    yield batch
                else:
                    # Dict-style batch: move only the tensor values to the device.
                    for key, val in batch.items():
                        if isinstance(val, torch.Tensor):
                            batch[key] = val.cuda(non_blocking=True, device=self.pretrain_device)
                    yield batch
class DInterface(DInterface_base):
    """Lightning data module wiring datasets, featurizers and dataloaders.

    The concrete dataset class is bound lazily in ``load_data_module`` from
    ``hparams.dataset``; the per-model collate function is selected in
    ``setup``. Dataset imports are deliberately local so only the selected
    dataset's dependencies are loaded.
    """

    # Paths (relative to hparams.data_root) for the Atlas dataset variants,
    # which all share the same AtlasDataset class.
    _ATLAS_PATHS = {
        'ATLAS_DIST_1': 'atlas/distant-frame-pairs_NO_SUPERPOSITION/frames_1',
        'ATLAS_DIST_2': 'atlas/distant-frame-pairs_NO_SUPERPOSITION/frames_2',
        'ATLAS_CLUSTER_1': 'atlas/cluster-representatives/frames_1',
        'ATLAS_CLUSTER_2': 'atlas/cluster-representatives/frames_2',
        'ATLAS_PDB': '../atlas_pdb_inference/',
        'ATLAS_FULL_MINIMIZED': '../atlas_eval_proteinmpnn/atlas_full/minimized_PDBs/pdbs/',
        'ATLAS_FULL_REFOLDED': '../atlas_eval_proteinmpnn/atlas_full/refolded_PDBs/pdbs/',
        'ATLAS_FULL_CRYSTAL': '../atlas_eval_proteinmpnn/atlas_full/crystal_PDBs/pdbs/',
    }

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.load_data_module()

    def setup(self, stage=None):
        """Choose the collate function for the model and build the splits.

        Args:
            stage: Lightning stage ('fit', 'test', 'predict', 'eval' or None).
        """
        from src.datasets.featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
                                             featurize_ProteinMPNN, featurize_Inversefolding)
        if self.hparams.model_name in ['AlphaDesign', 'PiFold', 'KWDesign', 'GraphTrans', 'StructGNN', 'GCA', 'E3PiFold']:
            self.collate_fn = featurize_GTrans
        elif self.hparams.model_name == 'GVP':
            featurizer = featurize_GVP()
            self.collate_fn = featurizer.collate
        elif self.hparams.model_name == 'ProteinMPNN':
            self.collate_fn = featurize_ProteinMPNN
        elif self.hparams.model_name == 'ESMIF':
            self.collate_fn = featurize_Inversefolding

        # Assign train/val datasets for use in dataloaders.
        if stage == 'fit' or stage is None:
            self.trainset = self.instancialize(split='train')
            self.valset = self.instancialize(split='valid')
        # Assign test dataset for use in dataloader(s).
        if stage == 'test' or stage is None:
            self.testset = self.instancialize(split='test')
        if stage in ['predict', 'eval']:
            self.predictset = self.instancialize(split='predict')

    def train_dataloader(self):
        return MyDataLoader(self.trainset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=True, prefetch_factor=8, pin_memory=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return MyDataLoader(self.valset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn)

    def test_dataloader(self):
        return MyDataLoader(self.testset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn)

    def predict_dataloader(self):
        return MyDataLoader(self.predictset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn)

    def load_data_module(self):
        """Bind ``self.data_module`` (a dataset class) from ``hparams.dataset``.

        Also fills in dataset-specific hyperparameters ('path', 'version').

        Raises:
            ValueError: if ``hparams.dataset`` names no known dataset
                (previously this failed silently and surfaced later as an
                AttributeError on ``self.data_module``).
        """
        name = self.hparams.dataset
        if name == 'AF2DB':
            from src.datasets.AF2DB_dataset_lmdb import Af2dbDataset
            self.data_module = Af2dbDataset
        elif name == 'TS':
            from src.datasets.ts_dataset import TSDataset
            self.data_module = TSDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'ts')
        elif name == 'CASP15':
            from src.datasets.casp_dataset import CASPDataset
            self.data_module = CASPDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'casp15')
        elif name == 'CATH4.2':
            from src.datasets.cath_dataset import CATHDataset
            self.data_module = CATHDataset
            self.hparams['version'] = 4.2
            self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.2')
        elif name == 'CATH4.3':
            from src.datasets.cath_dataset import CATHDataset
            self.data_module = CATHDataset
            self.hparams['version'] = 4.3
            self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.3')
        elif name == 'MPNN':
            from src.datasets.mpnn_dataset import MPNNDataset
            self.data_module = MPNNDataset
        elif name == 'FOLDSWITCHERS_1':
            from src.datasets.foldswitchers_dataset import FoldswitchersDataset
            self.data_module = FoldswitchersDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'fold_switchers/fold_1')
        elif name == 'FOLDSWITCHERS_2':
            from src.datasets.foldswitchers_dataset import FoldswitchersDataset
            self.data_module = FoldswitchersDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'fold_switchers/fold_2')
        elif name == 'PDBInference':
            from src.datasets.pdb_inference import PDBInference
            self.data_module = PDBInference
            # osp.join with a single argument is a no-op; use the path directly.
            self.hparams['path'] = self.hparams.infer_path
        elif name in self._ATLAS_PATHS:
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, self._ATLAS_PATHS[name])
        elif name == 'FLEX_CATH4.3':
            from src.datasets.flex_cath_dataset import FlexCATHDataset
            self.data_module = FlexCATHDataset
            self.hparams['version'] = 4.3
            self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.3')
        else:
            raise ValueError(f"Unknown dataset name: {name!r}")

    def instancialize(self, **other_args):
        """Instantiate ``self.data_module`` with matching hyperparameters.

        Only the parameters accepted by the dataset's ``__init__`` are
        forwarded from ``self.hparams``; keyword arguments given here
        override those values.
        """
        class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        inkeys = self.hparams.keys()
        args1 = {arg: self.hparams[arg] for arg in class_args if arg in inkeys}
        args1.update(other_args)
        if self.hparams['use_dynamics']:
            # Dynamics runs read the jsonl path from the flexibility config file.
            args1['data_jsonl_name'] = load_yaml_config('configs/ANMAwareFlexibilityProtTrans.yaml')['data_jsonl_name']
        return self.data_module(**args1)  # Dispatches to the dataset class's __init__.