# zaixizhang
# renew
# 10efe81
import os
import pickle
import lmdb
import torch
from torch.utils.data import Dataset
from tqdm.auto import tqdm
import numpy as np
from ..protein_ligand import PDBProtein, parse_sdf_file
from ..data import ProteinLigandData, torchify_dict
from ..mol_tree import MolTree
def reset_moltree_root(moltree, ligand_pos, protein_pos):
    """Move the tree node that lies closest to the protein into slot 0.

    Args:
        moltree: Object exposing ``nodes``, a mutable sequence whose items
            each carry a ``clique`` usable as a row index into the ligand
            atom arrays.  The sequence is modified in place.
        ligand_pos: 2-D array of ligand coordinates, one row per atom.
        protein_pos: 2-D array of protein coordinates, one row per atom.

    Returns:
        A tuple ``(moltree, contact_protein, contact_idx)`` where
        ``contact_protein`` is a boolean tensor with one entry per protein
        atom (True when its squared distance to the nearest ligand atom is
        below ``4 ** 2``) and ``contact_idx`` is a one-element tensor holding
        the index of the protein atom nearest to the new root node.
    """
    # Pairwise squared distances through |a|^2 - 2 a.b + |b|^2; rows index
    # ligand atoms, columns index protein atoms.
    lig_sq = np.sum(np.square(ligand_pos), 1, keepdims=True)
    prot_sq = np.sum(np.square(protein_pos), 1, keepdims=True)
    sq_dist = -2 * np.dot(ligand_pos, protein_pos.T) + lig_sq + prot_sq.T

    # Score every node by the closest approach of any of its atoms.
    nearest_per_atom = sq_dist.min(axis=1)
    node_scores = [np.min(nearest_per_atom[node.clique]) for node in moltree.nodes]

    root = np.argmin(node_scores)
    if root > 0:
        nodes = moltree.nodes
        nodes[0], nodes[root] = nodes[root], nodes[0]

    # Protein atom closest to any atom of the (possibly new) root node.
    root_rows = sq_dist[moltree.nodes[0].clique]
    contact_idx = np.argmin(root_rows.min(axis=0))

    # Threshold is applied on squared distances, hence 4 ** 2.
    contact_protein = torch.tensor(sq_dist.min(axis=0) < 4 ** 2)
    return moltree, contact_protein, torch.tensor([contact_idx])
def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None):
    """Merge a protein dict and a ligand dict into one flat, prefixed dict.

    Protein entries are stored under ``'protein_' + key``.  Ligand entries
    are stored under ``'ligand_' + key``, except the ``'moltree'`` entry,
    which keeps its bare name.  Either argument may be ``None``, in which
    case it contributes nothing.

    Returns:
        A new dict; protein entries come first, then ligand entries, each in
        the iteration order of its source dict.
    """
    merged = {}
    if protein_dict is not None:
        merged.update(
            ('protein_' + name, value) for name, value in protein_dict.items()
        )
    if ligand_dict is not None:
        merged.update(
            ('moltree' if name == 'moltree' else 'ligand_' + name, value)
            for name, value in ligand_dict.items()
        )
    return merged
class PocketLigandPairDataset(Dataset):
    """Pocket/ligand pairs served from an LMDB cache built from a raw directory.

    For a raw directory ``<parent>/<name>`` the dataset uses three files:

    * ``<parent>/<name>/index.pt``        -- iterable of PDB ids (read-only)
    * ``<parent>/<name>_processed.lmdb``  -- pickled per-complex records
    * ``<parent>/<name>_name2id.pt``      -- dict mapping ``pdbid`` to index

    The two derived files are created on construction when they are missing.
    """

    # Map size used for both the writer and the read-only connection.
    _MAP_SIZE = 10 * (1024 * 1024 * 1024)  # 10GB

    def __init__(self, raw_path, transform=None):
        """Set up paths and build the cache files when they do not exist.

        Args:
            raw_path: Directory holding ``index.pt`` and one sub-directory
                per PDB id.  A trailing ``/`` is ignored.
            transform: Optional callable applied to every record returned by
                ``__getitem__``.
        """
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.index_path = os.path.join(self.raw_path, 'index.pt')
        parent = os.path.dirname(self.raw_path)
        stem = os.path.basename(self.raw_path)
        self.processed_path = os.path.join(parent, stem + '_processed.lmdb')
        self.name2id_path = os.path.join(parent, stem + '_name2id.pt')
        self.transform = transform
        self.db = None
        self.keys = None

        if not os.path.exists(self.processed_path):
            self._process()
            self._precompute_name2id()
        elif not os.path.exists(self.name2id_path):
            # The LMDB cache is present but the name map is not (e.g. a
            # previous run stopped in between): rebuild only the map instead
            # of failing in torch.load below.
            self._precompute_name2id()

        self.name2id = torch.load(self.name2id_path)

    def _connect_db(self):
        """
        Establish read-only database connection
        """
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_path,
            map_size=self._MAP_SIZE,
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db.begin() as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        """Close the read connection, if any, and forget the cached keys."""
        if self.db is not None:
            self.db.close()
        self.db = None
        self.keys = None

    def _process(self):
        """Parse every indexed complex and write it to the LMDB cache.

        Records are stored under the key ``str(i).encode()`` where ``i`` is
        the position of the PDB id in ``index.pt``.  Complexes that raise
        during parsing are reported and skipped.
        """
        # Read all inputs before the LMDB file is created, so that a missing
        # index or vocabulary does not leave an empty cache behind that later
        # runs would treat as already processed.
        index = torch.load(self.index_path)
        with open('./vocab.txt') as vocab_file:
            # Each line is '<smiles>:<rest>'; only the part before ':' is used.
            vocab = {line.partition(':')[0] for line in vocab_file}

        db = lmdb.open(
            self.processed_path,
            map_size=self._MAP_SIZE,
            create=True,
            subdir=False,
            readonly=False,  # Writable
        )
        num_skipped = 0
        try:
            with db.begin(write=True, buffers=True) as txn:
                for i, pdbid in enumerate(tqdm(index)):
                    if pdbid is None:
                        continue
                    ligand_fn = None
                    try:
                        ligand_fn = os.path.join(pdbid, pdbid + '_ligand.sdf')
                        pocket_fn = os.path.join(pdbid, pdbid + '_pocket.pdb')
                        pocket_dict = PDBProtein(
                            os.path.join(self.raw_path, pocket_fn)
                        ).to_dict_atom()
                        ligand_dict = parse_sdf_file(
                            os.path.join(self.raw_path, ligand_fn)
                        )
                        (
                            ligand_dict['moltree'],
                            pocket_dict['contact'],
                            pocket_dict['contact_idx'],
                        ) = reset_moltree_root(
                            ligand_dict['moltree'],
                            ligand_dict['pos'],
                            pocket_dict['pos'],
                        )
                        data = from_protein_ligand_dicts(
                            protein_dict=torchify_dict(pocket_dict),
                            ligand_dict=torchify_dict(ligand_dict),
                        )
                        data['protein_filename'] = pocket_fn
                        data['ligand_filename'] = ligand_fn
                        data['pdbid'] = pdbid
                        txn.put(
                            key=str(i).encode(),
                            value=pickle.dumps(data),
                        )
                        # NOTE(review): this check runs after txn.put, so a
                        # record with an out-of-vocabulary fragment is
                        # reported as skipped but is still written.  Order
                        # kept as-is to leave the cache contents unchanged;
                        # confirm whether such records should be dropped.
                        for node in ligand_dict['moltree'].nodes:
                            if node.smiles not in vocab:
                                # Explicit raise: a bare assert would vanish
                                # under ``python -O``.
                                raise ValueError(
                                    'fragment not in vocabulary: %r' % (node.smiles,)
                                )
                    except Exception:
                        # ``Exception`` rather than a bare except, so that
                        # Ctrl-C / SystemExit still stop the run.
                        num_skipped += 1
                        shown = ligand_fn if ligand_fn is not None else pdbid
                        print('Skipping (%d) %s' % (num_skipped, shown,))
                        continue
        finally:
            db.close()

    def _precompute_name2id(self):
        """Build the ``pdbid -> dataset index`` map and save it to disk.

        Items whose ``__getitem__`` raises ``AssertionError`` are reported
        and left out of the map.
        """
        name2id = {}
        for i in tqdm(range(len(self)), 'Indexing'):
            try:
                data = self[i]
            except AssertionError as e:
                print(i, e)
                continue
            name2id[data['pdbid']] = i
        torch.save(name2id, self.name2id_path)

    def __len__(self):
        """Return the number of records in the cache (connects lazily)."""
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        """Load record ``idx``, tag it with ``'id'`` and apply the transform.

        Raises:
            AssertionError: if the record's ``'protein_pos'`` has no rows
                (``_precompute_name2id`` relies on this exception type).
        """
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        # Scope the read transaction so it is released right after the read.
        with self.db.begin() as txn:
            # NOTE: pickle is only acceptable here because the cache is
            # written by ``_process``; never point this at an untrusted file.
            data = pickle.loads(txn.get(key))
        data['id'] = idx
        assert data['protein_pos'].size(0) > 0
        if self.transform is not None:
            data = self.transform(data)
        return data
if __name__ == '__main__':
    # Command-line entry: constructing the dataset for the given raw
    # directory is enough to trigger cache creation when it is missing.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('path', type=str)
    PocketLigandPairDataset(cli.parse_args().path)