Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| import lmdb | |
| import torch | |
| from torch.utils.data import Dataset | |
| from tqdm.auto import tqdm | |
| import numpy as np | |
| from ..protein_ligand import PDBProtein, parse_sdf_file | |
| from ..data import ProteinLigandData, torchify_dict | |
def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, residue_dict=None, seq=None, full_seq_idx=None, r10_idx=None):
    """Flatten protein, ligand and residue feature dicts into one instance dict.

    Protein keys are stored with a ``'protein_'`` prefix and ligand keys with
    a ``'ligand_'`` prefix. Residue keys are copied unchanged. ``seq``,
    ``full_seq_idx`` and ``r10_idx`` are stored under their own names.
    Arguments that are None are omitted. Sources are applied in argument
    order, so on a key collision the later source wins.
    """
    instance = {}
    for prefix, features in (('protein_', protein_dict), ('ligand_', ligand_dict)):
        if features is not None:
            instance.update((prefix + name, value) for name, value in features.items())
    if residue_dict is not None:
        # Residue features are merged as-is, without any key prefix.
        instance.update(residue_dict.items())
    extras = (('seq', seq), ('full_seq_idx', full_seq_idx), ('r10_idx', r10_idx))
    instance.update((name, value) for name, value in extras if value is not None)
    return instance
class PocketLigandPairDataset(Dataset):
    """Pocket/ligand pair dataset backed by a preprocessed LMDB file.

    The processed database lives next to ``raw_path`` as
    ``<raw_path>_processed.lmdb``. If it does not exist at construction time,
    it is built from ``<raw_path>/index_seq.pkl`` by :meth:`_process`. The
    LMDB environment is opened lazily (read-only) on the first call to
    ``__len__`` or ``__getitem__``.

    Args:
        raw_path: Directory containing the raw pocket/ligand files and
            ``index_seq.pkl``. A trailing ``/`` is ignored.
        transform: Optional callable applied to each item dict returned by
            ``__getitem__``.
    """

    def __init__(self, raw_path, transform=None):
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.index_path = os.path.join(self.raw_path, 'index_seq.pkl')
        self.processed_path = os.path.join(os.path.dirname(self.raw_path),
                                           os.path.basename(self.raw_path) + '_processed.lmdb')
        self.name2id_path = os.path.join(os.path.dirname(self.raw_path),
                                         os.path.basename(self.raw_path) + '_name2id.pt')
        self.transform = transform
        self.db = None
        self.keys = None

        if not os.path.exists(self.processed_path):
            self._process()
        # name2id precomputation is currently disabled:
        # self._precompute_name2id()
        # self.name2id = torch.load(self.name2id_path)

    def _connect_db(self):
        """Open ``processed_path`` read-only and cache all of its keys in ``self.keys``."""
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10GB
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db.begin() as txn:
            # LMDB iterates keys in byte order (b'10' sorts before b'2'), so a
            # dataset position is not necessarily the original index row number.
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        """Close the LMDB connection, if one is open, and drop the cached keys."""
        if self.db is not None:
            self.db.close()
        self.db = None
        self.keys = None

    def _process(self):
        """Build the LMDB file at ``processed_path`` from the pickled index.

        Each index row is unpacked as ``(pocket_fn, ligand_fn, protein_fn,
        rmsd_str, seq, full_seq_idx, r10_idx)``. Rows whose ``pocket_fn`` is
        None are ignored. Rows that fail to parse or validate are reported and
        skipped. Every other row is stored, pickled, under the key ``str(i)``,
        where ``i`` is its position in the index.

        If an exception escapes (interrupt, storage error, malformed row) and
        this call created the database, the database files are removed before
        re-raising. This lets the next construction retry instead of silently
        loading an empty dataset.
        """
        # Load the index before touching LMDB, so that a missing or corrupt
        # index cannot leave an empty database file behind.
        with open(self.index_path, 'rb') as f:
            index = pickle.load(f)

        created = not os.path.exists(self.processed_path)
        db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10GB
            create=True,
            subdir=False,
            readonly=False,  # Writable
        )
        try:
            num_skipped = 0
            with db.begin(write=True, buffers=True) as txn:
                for i, (pocket_fn, ligand_fn, protein_fn, _rmsd_str,
                        seq, full_seq_idx, r10_idx) in enumerate(tqdm(index)):
                    if pocket_fn is None:
                        continue
                    # if len(seq)>500: continue
                    try:
                        data = self._build_entry(pocket_fn, ligand_fn, protein_fn,
                                                 seq, full_seq_idx, r10_idx)
                    except Exception:
                        # Bad samples are skipped. Interrupts and LMDB write
                        # errors are deliberately not caught here.
                        num_skipped += 1
                        print('Skipping (%d) %s' % (num_skipped, ligand_fn,))
                        continue
                    txn.put(
                        key=str(i).encode(),
                        value=pickle.dumps(data)
                    )
        except BaseException:
            db.close()
            if created:
                self._remove_processed_files()
            raise
        db.close()

    def _build_entry(self, pocket_fn, ligand_fn, protein_fn, seq, full_seq_idx, r10_idx):
        """Parse and validate one index row into the dict stored in LMDB.

        Raises:
            ValueError: if the residue mask from ``query_residues_ligand`` is
                empty, disagrees in count with ``full_seq_idx``, or disagrees in
                length with ``r10_idx``. The parsing helpers may raise other
                exceptions. :meth:`_process` skips the row on any exception.
        """
        pdb_data = PDBProtein(os.path.join(self.raw_path, pocket_fn))
        pocket_dict = pdb_data.to_dict_atom()
        residue_dict = pdb_data.to_dict_residue()
        ligand_dict = parse_sdf_file(os.path.join(self.raw_path, ligand_fn))
        _, residue_dict['protein_edit_residue'] = pdb_data.query_residues_ligand(ligand_dict)

        # Explicit checks rather than `assert`, so they still run under `python -O`.
        edit_mask = residue_dict['protein_edit_residue']
        num_edit = edit_mask.sum()
        if not (num_edit > 0 and num_edit == len(full_seq_idx)):
            raise ValueError('protein_edit_residue selects %s residues, expected a '
                             'positive count equal to len(full_seq_idx)=%d'
                             % (num_edit, len(full_seq_idx)))
        if len(edit_mask) != len(r10_idx):
            raise ValueError('protein_edit_residue has length %d but r10_idx has length %d'
                             % (len(edit_mask), len(r10_idx)))

        data = from_protein_ligand_dicts(
            protein_dict=torchify_dict(pocket_dict),
            ligand_dict=torchify_dict(ligand_dict),
            residue_dict=torchify_dict(residue_dict),
            seq=seq,
            # sorted() instead of in-place .sort(), so the loaded index is not mutated.
            full_seq_idx=torch.tensor(sorted(full_seq_idx)),
            r10_idx=torch.tensor(sorted(r10_idx)),
        )
        data['protein_filename'] = pocket_fn
        data['ligand_filename'] = ligand_fn
        data['whole_protein_name'] = protein_fn
        return data

    def _remove_processed_files(self):
        """Delete the LMDB data file and its ``-lock`` companion, if present."""
        # With subdir=False, LMDB uses `path` as the data file and `path-lock`
        # as the lock file.
        for path in (self.processed_path, self.processed_path + '-lock'):
            if os.path.exists(path):
                os.remove(path)

    def _precompute_name2id(self):
        """Map ``(protein_filename, ligand_filename)`` to dataset index and save it.

        Items whose ``__getitem__`` raises AssertionError are printed and
        skipped. The mapping is written to ``name2id_path`` with ``torch.save``.
        """
        name2id = {}
        for i in tqdm(range(self.__len__()), 'Indexing'):
            try:
                data = self.__getitem__(i)
            except AssertionError as e:
                print(i, e)
                continue
            name = (data['protein_filename'], data['ligand_filename'])
            name2id[name] = i
        torch.save(name2id, self.name2id_path)

    def __len__(self):
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        """Return the stored dict at position ``idx``, with ``'id'`` set to ``idx``.

        Raises AssertionError if the stored ``protein_pos`` has no rows.
        ``_precompute_name2id`` relies on that exception type. If a transform
        was given, it is applied to the dict before returning.
        """
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        # End the read transaction promptly instead of leaving it to the GC.
        with self.db.begin() as txn:
            # NOTE: pickle.loads can execute code embedded in the payload; only
            # open LMDB files produced by _process from trusted inputs.
            data = pickle.loads(txn.get(key))
        data['id'] = idx
        assert data['protein_pos'].size(0) > 0
        if self.transform is not None:
            data = self.transform(data)
        return data
if __name__ == '__main__':
    # Command-line entry point: constructing the dataset builds the processed
    # LMDB next to the given raw directory if it does not exist yet.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('path', type=str)
    PocketLigandPairDataset(cli.parse_args().path)