# NOTE: the three lines originally above ("Zaixi's picture", "1", "dcacefd")
# were residue scraped from the file's hosting web page (uploader avatar,
# like count, commit hash) — they are not Python and would be syntax errors.
import os
import pickle
import lmdb
import torch
from torch.utils.data import Dataset
from tqdm.auto import tqdm
import numpy as np
from ..protein_ligand import PDBProtein, parse_sdf_file
from ..data import ProteinLigandData, torchify_dict
def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, residue_dict=None, seq=None, full_seq_idx=None, r10_idx=None):
    """Flatten per-source dictionaries into a single sample dict.

    Keys from ``protein_dict`` and ``ligand_dict`` are stored under
    ``'protein_' + key`` and ``'ligand_' + key`` respectively, while
    ``residue_dict`` keys are copied verbatim.  ``seq``, ``full_seq_idx``
    and ``r10_idx`` are stored under their own names.  Any argument left
    as ``None`` is simply omitted from the result.
    """
    merged = {}
    if protein_dict is not None:
        merged.update({'protein_' + key: value for key, value in protein_dict.items()})
    if ligand_dict is not None:
        merged.update({'ligand_' + key: value for key, value in ligand_dict.items()})
    if residue_dict is not None:
        merged.update(residue_dict)
    # Scalar-style extras keep their argument name as the key.
    for name, value in (('seq', seq), ('full_seq_idx', full_seq_idx), ('r10_idx', r10_idx)):
        if value is not None:
            merged[name] = value
    return merged
class PocketLigandPairDataset(Dataset):
    """Pocket/ligand pair dataset backed by an LMDB key-value store.

    On first construction the raw index (``index_seq.pkl`` inside
    ``raw_path``) is parsed entry by entry and each pocket/ligand pair is
    pickled into ``<raw_path>_processed.lmdb``; subsequent constructions
    reuse that processed database.  Samples are lazily deserialized in
    ``__getitem__``.
    """

    def __init__(self, raw_path, transform=None):
        """
        Args:
            raw_path: directory containing ``index_seq.pkl`` plus the pocket
                PDB and ligand SDF files referenced by the index.
            transform: optional callable applied to every sample returned by
                ``__getitem__``.
        """
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.index_path = os.path.join(self.raw_path, 'index_seq.pkl')
        # Processed artifacts live next to (not inside) the raw directory.
        self.processed_path = os.path.join(os.path.dirname(self.raw_path),
                                           os.path.basename(self.raw_path) + '_processed.lmdb')
        self.name2id_path = os.path.join(os.path.dirname(self.raw_path),
                                         os.path.basename(self.raw_path) + '_name2id.pt')
        self.transform = transform
        # Connection is opened lazily (in __len__/__getitem__) so that the
        # dataset object can be pickled before any LMDB handle exists.
        self.db = None
        self.keys = None
        if not os.path.exists(self.processed_path):
            self._process()
        # self._precompute_name2id()
        # self.name2id = torch.load(self.name2id_path)

    def _connect_db(self):
        """
        Establish read-only database connection
        """
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10GB
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        # Cache all keys up front; __len__ and __getitem__ index into this list.
        with self.db.begin() as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        """Close the LMDB handle and reset the lazy-connection state."""
        self.db.close()
        self.db = None
        self.keys = None

    def _process(self):
        """Build the processed LMDB from the raw index.

        Each index entry is parsed (pocket PDB, ligand SDF, residue masks),
        converted to tensors and pickled under the key ``str(i)``.  Entries
        that fail to parse or violate the residue-count invariants are
        counted and skipped rather than aborting the whole build.
        """
        db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10GB
            create=True,
            subdir=False,
            readonly=False,  # Writable
        )
        with open(self.index_path, 'rb') as f:
            index = pickle.load(f)
        num_skipped = 0
        with db.begin(write=True, buffers=True) as txn:
            for i, (pocket_fn, ligand_fn, protein_fn, rmsd_str, seq, full_seq_idx, r10_idx) in enumerate(tqdm(index)):
                if pocket_fn is None: continue
                # if len(seq)>500: continue
                try:
                    pdb_data = PDBProtein(os.path.join(self.raw_path, pocket_fn))
                    pocket_dict = pdb_data.to_dict_atom()
                    residue_dict = pdb_data.to_dict_residue()
                    ligand_dict = parse_sdf_file(os.path.join(self.raw_path, ligand_fn))
                    _, residue_dict['protein_edit_residue'] = pdb_data.query_residues_ligand(ligand_dict)
                    # Sanity invariants: the editable-residue mask must be
                    # non-empty and agree with the precomputed index lists.
                    assert residue_dict['protein_edit_residue'].sum() > 0 and residue_dict['protein_edit_residue'].sum() == len(full_seq_idx)
                    assert len(residue_dict['protein_edit_residue']) == len(r10_idx)
                    full_seq_idx.sort()
                    r10_idx.sort()
                    data = from_protein_ligand_dicts(
                        protein_dict=torchify_dict(pocket_dict),
                        ligand_dict=torchify_dict(ligand_dict),
                        residue_dict=torchify_dict(residue_dict),
                        seq=seq,
                        full_seq_idx=torch.tensor(full_seq_idx),
                        r10_idx=torch.tensor(r10_idx)
                    )
                    data['protein_filename'] = pocket_fn
                    data['ligand_filename'] = ligand_fn
                    data['whole_protein_name'] = protein_fn
                    txn.put(
                        key=str(i).encode(),
                        value=pickle.dumps(data)
                    )
                except Exception:
                    # Narrowed from a bare ``except:`` — a bare except would
                    # also swallow KeyboardInterrupt/SystemExit, making the
                    # long processing loop effectively uninterruptible.
                    num_skipped += 1
                    print('Skipping (%d) %s' % (num_skipped, ligand_fn,))
                    continue
        db.close()

    def _precompute_name2id(self):
        """Map (protein_filename, ligand_filename) -> dataset index, saved via torch.save."""
        name2id = {}
        for i in tqdm(range(self.__len__()), 'Indexing'):
            try:
                data = self.__getitem__(i)
            except AssertionError as e:
                # Entries with an empty protein_pos fail the assert in
                # __getitem__ and are excluded from the name map.
                print(i, e)
                continue
            name = (data['protein_filename'], data['ligand_filename'])
            name2id[name] = i
        torch.save(name2id, self.name2id_path)

    def __len__(self):
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        data = pickle.loads(self.db.begin().get(key))
        data['id'] = idx
        assert data['protein_pos'].size(0) > 0
        if self.transform is not None:
            data = self.transform(data)
        return data
if __name__ == '__main__':
    import argparse

    # CLI entry point: build (or reuse) the processed LMDB for a raw dataset
    # directory given as the single positional argument.
    cli = argparse.ArgumentParser()
    cli.add_argument('path', type=str)
    parsed_args = cli.parse_args()
    PocketLigandPairDataset(parsed_args.path)