Spaces:
Sleeping
Sleeping
add LLM files
Browse files- generate.py +286 -0
- utils/__init__.py +0 -0
- utils/chem.py +65 -0
- utils/file.py +29 -0
- utils/log.py +32 -0
- utils/plot.py +84 -0
- utils/torch_util.py +32 -0
generate.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pickle as pkl
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
import utils.chem as uc
|
| 10 |
+
import utils.torch_util as ut
|
| 11 |
+
import utils.log as ul
|
| 12 |
+
import utils.plot as up
|
| 13 |
+
import configuration.config_default as cfgd
|
| 14 |
+
import models.dataset as md
|
| 15 |
+
import preprocess.vocabulary as mv
|
| 16 |
+
import configuration.opts as opts
|
| 17 |
+
from models.transformer.module.decode import decode
|
| 18 |
+
from models.transformer.encode_decode.model import EncoderDecoder
|
| 19 |
+
# from models.seq2seq.model import Model
|
| 20 |
+
|
| 21 |
+
def prepare_input(opt):
    """Validate the test input CSV and return the test file name.

    The original read the CSV into an unused variable and kept a dead
    ``delta_pkis`` list plus a commented-out expansion of every Delta_pki
    bin per molecule; the dead code was removed.  The read is kept as a
    cheap sanity check that the file exists and parses as CSV.

    :param opt: argparse namespace with ``data_path`` and ``test_file_name``.
    :return: ``opt.test_file_name`` unchanged.
    :raises FileNotFoundError / pandas.errors.ParserError: if the CSV is
        missing or malformed.
    """
    # Sanity check only — the frame itself is not needed here.
    pd.read_csv(os.path.join(opt.data_path, opt.test_file_name + '.csv'), sep=",")
    return opt.test_file_name
|
| 37 |
+
|
| 38 |
+
class GenerateRunner():
    """Samples molecules with a trained model for every row of a test CSV.

    Results are written to <save_directory>/generated_molecules.csv with one
    column per requested sample (Predicted_smi_1 .. Predicted_smi_N).
    """

    def __init__(self, opt):
        """Set up the output directory, logging and the SMILES vocabulary.

        :param opt: parsed argparse namespace (see configuration.opts.generate_opts)
        """

        # self.save_path = os.path.join('experiments', opt.save_directory, opt.test_file_name,
        #                               f'evaluation_{opt.epoch}')
        path = Path(os.path.join(opt.save_directory))
        path.mkdir(parents=True, exist_ok=True)
        self.save_path = os.path.join(path)
        # Whether a previous run already produced results in this directory.
        self.exist_flag = Path(f'{self.save_path}/generated_molecules.csv').exists()
        self.overwrite = opt.overwrite
        self.dev_no = opt.dev_no
        # The logger is promoted to module scope so sample() can use it too.
        global LOG
        LOG = ul.get_logger(name="generate",
                            log_path=os.path.join(self.save_path, 'generate.log'))
        LOG.info(opt)
        LOG.info("Save directory: {}".format(self.save_path))

        # Load vocabulary
        with open(os.path.join(opt.vocab_path, 'vocab.pkl'), "rb") as input_file:
            vocab = pkl.load(input_file)
        self.vocab = vocab
        self.tokenizer = mv.SMILESTokenizer()

    def initialize_dataloader(self, opt, vocab, test_file):
        """
        Initialize dataloader
        :param opt: parsed options (uses data_path and batch_size)
        :param vocab: vocabulary
        :param test_file: test_file_name
        :return: a DataLoader over the test CSV in prediction mode
        """

        # Read test
        data = pd.read_csv(os.path.join(opt.data_path, test_file + '.csv'), sep=",")
        dataset = md.Dataset(data=data, vocabulary=vocab, tokenizer=self.tokenizer, prediction_mode=True)
        dataloader = torch.utils.data.DataLoader(dataset, opt.batch_size,
                                                 shuffle=False, collate_fn=md.Dataset.collate_fn)
        return dataloader

    def generate(self, opt):
        """Sample opt.num_samples molecules per input row and save them as CSV."""
        # Skip entirely when results already exist and --overwrite was not given.
        if not self.overwrite and self.exist_flag:
            print('GENERATED MOL EXIST, SKIP GENERATING!')
            return
        # set device
        # device = ut.allocate_gpu()
        # torch.cuda.set_device(1)
        # current_device = torch.cuda.current_device()
        # print("Currently selected CUDA device index:", current_device)
        device = torch.device(f'cuda:{self.dev_no}')
        # Build the test dataloader
        dataloader_test = self.initialize_dataloader(opt, self.vocab, opt.test_file_name)

        # Load model
        file_name = os.path.join(opt.model_path, f'model_{opt.epoch}.pt')
        if opt.model_choice == 'transformer':
            model = EncoderDecoder.load_from_file(file_name)
            model.to(device)
            model.eval()
        elif opt.model_choice == 'seq2seq':
            # NOTE(review): the import of Model is commented out at the top of
            # this file, so this branch would raise NameError — confirm before use.
            model = Model.load_from_file(file_name, evaluation_mode=True)
            # move to GPU
            model.network.encoder.to(device)
            model.network.decoder.to(device)
        # TODO: could inputs exceed this length? The model may need the length
        # adjusted if it breaks (rule of thumb: keep it a multiple of 2).
        max_len = cfgd.DATA_DEFAULT['max_sequence_length']
        df_list = []
        sampled_smiles_list = []
        for j, batch in enumerate(ul.progress_bar(dataloader_test, total=len(dataloader_test))):

            # df holds the raw input rows (a DataFrame slice) for this batch
            src, source_length, _, src_mask, _, _, df = batch

            # Move to GPU
            src = src.to(device)
            src_mask = src_mask.to(device)
            smiles = self.sample(opt.model_choice, model, src, src_mask,
                                 source_length,
                                 opt.decode_type,
                                 num_samples=opt.num_samples,
                                 max_len=max_len,
                                 device=device)

            df_list.append(df)
            sampled_smiles_list.extend(smiles)

        # prepare dataframe
        data_sorted = pd.concat(df_list)
        sampled_smiles_list = np.array(sampled_smiles_list)

        for i in range(opt.num_samples):
            data_sorted['Predicted_smi_{}'.format(i + 1)] = sampled_smiles_list[:, i]

        result_path = os.path.join(self.save_path, "generated_molecules.csv")
        LOG.info("Save to {}".format(result_path))
        data_sorted.to_csv(result_path, index=False)

    def sample(self, model_choice, model, src, src_mask, source_length, decode_type, num_samples=10,
               max_len=cfgd.DATA_DEFAULT['max_sequence_length'],
               device=None):
        """Repeatedly decode until every source has num_samples unique valid SMILES.

        :return: nested list [batch][num_samples] of SMILES strings (possibly
                 padded placeholders for sources that never reached the quota).
        """
        batch_size = src.shape[0]
        num_valid_batch = np.zeros(batch_size)  # current number of unique and valid samples out of total sampled
        num_valid_batch_total = np.zeros(batch_size)  # current number of sampling times no matter unique or valid
        num_valid_batch_desired = np.asarray([num_samples] * batch_size)
        unique_set_num_samples = [set() for i in range(batch_size)]  # for each starting molecule
        batch_index = torch.LongTensor(range(batch_size))
        batch_index_current = torch.LongTensor(range(batch_size)).to(device)
        # TODO(review): start_mols is collected but apparently never used afterwards
        start_mols = []
        # placeholder token rows; the untokenized placeholder is considered
        # valid by RDKit (original comment was garbled here)
        sequences_all = torch.ones((num_samples, batch_size, max_len))
        sequences_all = sequences_all.type(torch.LongTensor)
        max_trials = 100000  # Maximum trials for sampling
        current_trials = 0

        # 'greedy' decoding is deterministic: a single attempt either yields a
        # molecule or never will, so do not retry.
        if decode_type == 'greedy':
            max_trials = 1

        # Set of unique starting molecules
        if src is not None:
            # NOTE(review): needs rework — delta_value is not necessarily at
            # the first position of the encoded sequence.
            start_ind = len(cfgd.PROPERTIES)
            for ibatch in range(batch_size):
                source_smi = self.tokenizer.untokenize(self.vocab.decode(src[ibatch].tolist()[start_ind:]))
                source_smi = uc.get_canonical_smile(source_smi)
                if source_smi:
                    # Seed the per-source set with the source itself for later
                    # de-duplication. TODO(review): may be wrong — the SMILES
                    # was already fragmented by mmpdb and is not a complete
                    # molecule.
                    unique_set_num_samples[ibatch].add(source_smi)
                    start_mols.append(source_smi)

        with torch.no_grad():
            if model_choice == 'seq2seq':
                encoder_outputs, decoder_hidden = model.network.encoder(src, source_length)
            while not all(num_valid_batch >= num_valid_batch_desired) and current_trials < max_trials:
                current_trials += 1

                # batch input for current trial
                if src is not None:
                    # On the first trial this selects the whole batch; later
                    # trials only re-sample the unfinished sources.
                    src_current = src.index_select(0, batch_index_current)
                if src_mask is not None:
                    mask_current = src_mask.index_select(0, batch_index_current)
                batch_size = src_current.shape[0]

                # sample molecule
                if model_choice == 'transformer':
                    sequences = decode(model, src_current, mask_current, max_len, decode_type)
                    padding = (0, max_len-sequences.shape[1],
                               0, 0)
                    sequences = torch.nn.functional.pad(sequences, padding)
                elif model_choice == 'seq2seq':
                    sequences = self.sample_seq2seq(model, mask_current, batch_index_current, decoder_hidden,
                                                    encoder_outputs, max_len, device)
                else:
                    # NOTE(review): execution continues past this log line and
                    # 'sequences' would be undefined — confirm intended behavior.
                    LOG.info('Specify transformer or seq2seq for model_choice')

                # Check valid and unique
                smiles = []
                is_valid_index = []
                batch_index_map = dict(zip(list(range(batch_size)), batch_index_current))
                # Valid, ibatch index is different from original, need map back
                for ibatch in range(batch_size):
                    seq = sequences[ibatch]
                    smi = self.tokenizer.untokenize(self.vocab.decode(seq.cpu().numpy()))
                    smi = uc.get_canonical_smile(smi)
                    smiles.append(smi)
                    # valid and not same as starting molecules
                    if uc.is_valid(smi):
                        is_valid_index.append(ibatch)
                    # total sampled times
                    num_valid_batch_total[batch_index_map[ibatch]] += 1

                # Check if duplicated and update num_valid_batch and unique
                for good_index in is_valid_index:
                    index_in_original_batch = batch_index_map[good_index]
                    if smiles[good_index] not in unique_set_num_samples[index_in_original_batch]:
                        unique_set_num_samples[index_in_original_batch].add(smiles[good_index])
                        num_valid_batch[index_in_original_batch] += 1

                        sequences_all[int(num_valid_batch[index_in_original_batch] - 1), index_in_original_batch, :] = \
                            sequences[good_index]

                not_completed_index = np.where(num_valid_batch < num_valid_batch_desired)[0]
                # Keep sampling only for sources that do not yet have enough
                # unique valid molecules.
                if len(not_completed_index) > 0:
                    batch_index_current = batch_index.index_select(0, torch.LongTensor(not_completed_index)).to(device)

        # Convert to SMILES
        smiles_list = []  # [batch, topk]
        seqs = np.asarray(sequences_all.numpy())
        # [num_sample, batch_size, max_len]
        batch_size = len(seqs[0])
        for ibatch in range(batch_size):
            topk_list = []
            for k in range(num_samples):
                seq = seqs[k, ibatch, :]
                topk_list.extend([self.tokenizer.untokenize(self.vocab.decode(seq))])
            smiles_list.append(topk_list)

        return smiles_list

    def sample_seq2seq(self, model, mask, batch_index_current, decoder_hidden, encoder_outputs, max_len, device):
        """Multinomial sampling with the seq2seq decoder.

        :return: LongTensor of token ids, shape [batch, max_len].
        """
        # batch size will change when some of the generated molecules are valid
        encoder_outputs_current = encoder_outputs.index_select(0, batch_index_current)
        batch_size = encoder_outputs_current.shape[0]

        # start token
        start_token = torch.zeros(batch_size, dtype=torch.long)
        start_token[:] = self.vocab["^"]
        decoder_input = start_token.to(device)
        sequences = []
        mask = torch.squeeze(mask, 1).to(device)

        # initial decoder hidden states
        if isinstance(decoder_hidden, tuple):
            decoder_hidden_current = (decoder_hidden[0].index_select(1, batch_index_current),
                                      decoder_hidden[1].index_select(1, batch_index_current))
        else:
            decoder_hidden_current = decoder_hidden.index_select(1, batch_index_current)
        for i in range(max_len):
            logits, decoder_hidden_current = model.network.decoder(decoder_input.unsqueeze(1),
                                                                   decoder_hidden_current,
                                                                   encoder_outputs_current, mask)
            logits = logits.squeeze(1)
            probabilities = logits.softmax(dim=1)  # torch.Size([batch_size, vocab_size])
            topi = torch.multinomial(probabilities, 1)  # torch.Size([batch_size, 1])
            decoder_input = topi.view(-1).detach()
            sequences.append(decoder_input.view(-1, 1))

        sequences = torch.cat(sequences, 1)
        return sequences
|
| 270 |
+
|
| 271 |
+
def run_main():
    """Command-line entry point: parse generation options and run the sampler."""
    arg_parser = argparse.ArgumentParser(
        description='generate.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    opts.generate_opts(arg_parser)

    options = arg_parser.parse_args()
    options.test_file_name = prepare_input(options)

    GenerateRunner(options).generate(options)


if __name__ == "__main__":
    run_main()
|
utils/__init__.py
ADDED
|
File without changes
|
utils/chem.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RDKit util functions.
|
| 3 |
+
"""
|
| 4 |
+
import rdkit.Chem as rkc
|
| 5 |
+
from rdkit.Chem import AllChem
|
| 6 |
+
from rdkit import DataStructs
|
| 7 |
+
|
| 8 |
+
def disable_rdkit_logging():
    """
    Disables RDKit whiny logging.
    """
    import rdkit.RDLogger as rkl
    import rdkit.rdBase as rkrb
    # Raise the RDKit logger threshold and mute the rdApp.error channel.
    rkl.logger().setLevel(rkl.ERROR)
    rkrb.DisableLog('rdApp.error')


# Silence RDKit as soon as this module is imported.
disable_rdkit_logging()
|
| 21 |
+
|
| 22 |
+
def to_fp_ECFP(smi):
    """Return the Morgan (ECFP-like, radius 2) fingerprint of a SMILES, or None."""
    if not smi:
        return None
    mol = rkc.MolFromSmiles(smi)
    if mol is None:
        return None
    return AllChem.GetMorganFingerprint(mol, 2)
|
| 28 |
+
|
| 29 |
+
def tanimoto_similarity_pool(args):
    """Multiprocessing-pool adapter: unpack an (smi1, smi2) argument tuple."""
    return tanimoto_similarity(*args)
|
| 31 |
+
|
| 32 |
+
def tanimoto_similarity(smi1, smi2):
    """Tanimoto similarity of the ECFP fingerprints of two SMILES strings.

    Returns None when either input is not a non-empty str or fails to parse.
    """
    fp1 = to_fp_ECFP(smi1) if (smi1 and type(smi1) == str and len(smi1) > 0) else None
    fp2 = to_fp_ECFP(smi2) if (smi2 and type(smi2) == str and len(smi2) > 0) else None

    if fp1 is None or fp2 is None:
        return None
    return DataStructs.TanimotoSimilarity(fp1, fp2)
|
| 43 |
+
|
| 44 |
+
def is_valid(smi):
    """Return 1 when *smi* parses into an RDKit Mol, otherwise 0."""
    return int(bool(to_mol(smi)))
|
| 46 |
+
|
| 47 |
+
def to_mol(smi):
    """
    Creates a Mol object from a SMILES string.
    :param smi: SMILES string.
    :return: A Mol object or None if it's not valid.
    """
    # Reject non-strings, empty strings and the literal 'nan' up front.
    if not isinstance(smi, str) or not smi or smi == 'nan':
        return None
    return rkc.MolFromSmiles(smi)
|
| 55 |
+
|
| 56 |
+
def get_canonical_smile(smile):
    """Return the canonical, non-isomeric SMILES for *smile*, or None.

    Bug fix: the original only guarded against the literal string 'None',
    so passing actual None (or an empty string) reached MolFromSmiles and
    raised a TypeError.  Both cases now return None like any other
    unparsable input.

    :param smile: SMILES string (the literal 'None' means "no molecule").
    :return: canonical SMILES string, or None if the input cannot be parsed.
    """
    if not smile or smile == 'None':
        return None
    mol = rkc.MolFromSmiles(smile)
    if mol is None:
        return None
    return rkc.MolToSmiles(mol, canonical=True, doRandom=False, isomericSmiles=False)
|
utils/file.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
def make_directory(file, is_dir=True):
    """Create the directory portion of *file*, including missing parents.

    :param file: a path; when is_dir is False the last component is treated
                 as a file name and only its parent directories are created.
    :param is_dir: whether *file* itself names a directory.
    """
    target = file if is_dir else os.path.dirname(file)
    # os.makedirs(exist_ok=True) replaces the original component-by-component
    # loop; it creates all parents in one call and is race-free when several
    # processes create the same tree concurrently.
    if target:
        os.makedirs(target, exist_ok=True)
|
| 10 |
+
|
| 11 |
+
def get_parent_dir(file):
    """Return the parent directory of *file* ('' when there is none).

    The original rebuilt the parent path component by component; this is
    exactly what os.path.dirname does (including the absolute-path and
    no-slash cases), so delegate to it.

    :param file: a '/'-separated path.
    :return: the directory part of the path, '' for a bare file name.
    """
    return os.path.dirname(file)
|
| 19 |
+
|
| 20 |
+
def chunkIt(seq, num):
    """Split *seq* into *num* roughly equal consecutive chunks.

    Uses a floating-point cursor so chunk boundaries are distributed evenly;
    may return num+1 chunks for some lengths due to float stepping (same as
    the original behavior).
    """
    step = len(seq) / float(num)
    chunks = []
    cursor = 0.0

    while cursor < len(seq):
        chunks.append(seq[int(cursor):int(cursor + step)])
        cursor += step

    return chunks
|
utils/log.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import tqdm
|
| 3 |
+
|
| 4 |
+
import utils.file as uf
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_logger(name, log_path, isMain=False, level=logging.INFO):
    """Return a configured logger, attaching handlers only once per name.

    Bug fix: the original added a fresh StreamHandler (and FileHandler) on
    every call, so repeated calls with the same *name* duplicated every log
    line.  Handlers are now attached only when the named logger has none yet.

    :param name: logger name (loggers are process-wide singletons per name).
    :param log_path: file path for the log; only used when isMain is True.
    :param isMain: when True, also log to *log_path* (parents are created).
    :param level: logging threshold, default logging.INFO.
    :return: the configured logging.Logger instance.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s: %(module)s.%(funcName)s +%(lineno)s: %(levelname)-8s %(message)s",
        datefmt="%H:%M:%S"
    )

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Attach handlers only on the first call for this logger name; otherwise
    # every call would duplicate console (and file) output.
    if not logger.handlers:
        # Logging to console
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

        # Logging to a file
        if isMain:
            uf.make_directory(log_path, is_dir=False)
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def progress_bar(iterable, total, **kwargs):
    """Wrap *iterable* in an ASCII tqdm progress bar; extra kwargs pass through."""
    return tqdm.tqdm(iterable=iterable, total=total, ascii=True, **kwargs)
|
utils/plot.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import matplotlib as mpl
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from scipy.stats import gaussian_kde
|
| 7 |
+
mpl.use('Agg')
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def hist_box(data_frame, field, name="hist_box", path="./", title=None):
    """Save a side-by-side histogram and box plot of data_frame[field] as PNG."""
    plot_title = title if title else field

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    data_frame[field].plot.hist(bins=100, title=plot_title, ax=axes[0])
    data_frame.boxplot(field, ax=axes[1])

    plt.title(plot_title)
    plt.suptitle("")
    plt.savefig(os.path.join(path, '{}.png'.format(name)), bbox_inches='tight')
    plt.close()
|
| 21 |
+
|
| 22 |
+
def hist(data_frame, field, name="hist", path="./", title=None):
    """Save a plain histogram of data_frame[field] to <path>/<name>.png."""
    if not title:
        title = field

    plt.hist(data_frame[field])
    plt.title(title)
    plt.savefig(os.path.join(path, '{}.png'.format(name)), bbox_inches='tight')
    plt.close()
|
| 31 |
+
|
| 32 |
+
def hist_box_list(data_list, name="hist_box", path="./", title=None):
    """Save a side-by-side histogram and box plot of a list of values as PNG."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    axes[0].hist(data_list, bins=100)
    axes[0].set_title(title)

    axes[1].boxplot(data_list)
    axes[1].set_title(title)

    plt.savefig(os.path.join(path, '{}.png'.format(name)), bbox_inches='tight')
    plt.close()
|
| 42 |
+
|
| 43 |
+
def scatter_hist(x, y, name, path, field=None, lims=None):
    """Save a density-coloured scatter of x vs y plus overlaid histograms.

    Bug fix: *field* defaults to None but was used in substring tests
    (``"delta" in field``), raising TypeError whenever it was omitted; None
    is now treated as an empty string (no axis labels).

    :param x, y: equal-length numeric sequences.
    :param name: output file name stem (saved as <path>/<name>.png).
    :param path: output directory.
    :param field: label selector; substrings 'delta'/'single' combined with
                  'data'/'predict' pick the axis labels.
    :param lims: optional [lo, hi] axis limits; derived from the data if None.
    """
    field = field if field is not None else ""
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    n = len(x)
    # Tiny jitter so gaussian_kde does not fail on degenerate (duplicate) data.
    xy = np.vstack([x + 0.00001 * np.random.rand(n), y + 0.00001 * np.random.rand(n)])
    z = gaussian_kde(xy)(xy)
    axs[0].scatter(x, y, c=z, s=3, marker='o', alpha=0.2)
    # Square up the scatter axes and draw the y=x reference line.
    lims = [np.min([axs[0].get_xlim(), axs[0].get_ylim()]), np.max([axs[0].get_xlim(), axs[0].get_ylim()])] if lims is None else lims
    axs[0].plot(lims, lims, 'k-', alpha=0.75)
    axs[0].set_aspect('equal')
    axs[0].set_xlim(lims)
    axs[0].set_ylim(lims)
    xlabel = ""
    ylabel = ""
    if "delta" in field:
        if "data" in field:
            axs[0].set_xlabel(r'$\Delta LogD$ (experimental)')
            axs[0].set_ylabel(r'$\Delta LogD$ (calculated)')
            xlabel = 'Delta LogD (experimental)'
            ylabel = 'Delta LogD (calculated)'
        elif "predict" in field:
            axs[0].set_xlabel(r'$\Delta LogD$ (desirable)')
            axs[0].set_ylabel(r'$\Delta LogD$ (generated)')
            xlabel = 'Delta LogD (desirable)'
            ylabel = 'Delta LogD (generated)'
    if "single" in field:
        if "data" in field:
            xlabel, ylabel = 'LogD (experimental)', 'LogD (calculated)'
            axs[0].set_xlabel(xlabel)
            axs[0].set_ylabel(ylabel)
        elif "predict" in field:
            xlabel, ylabel = 'LogD (desirable)', 'LogD (generated)'
            axs[0].set_xlabel(xlabel)
            axs[0].set_ylabel(ylabel)
    bins = np.histogram(np.hstack((x, y)), bins=100)[1]  # get the bin edges
    kwargs = dict(histtype='stepfilled', alpha=0.3, density=False, bins=bins, stacked=False)
    axs[1].hist(x, **kwargs, color='b', label=xlabel)
    axs[1].hist(y, **kwargs, color='r', label=ylabel)
    plt.ylabel('Frequency')
    plt.legend(loc='upper left')
    plt.savefig(os.path.join(path, '{}.png'.format(name)), bbox_inches='tight')
    plt.close()
|
| 84 |
+
|
utils/torch_util.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch related util functions
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
def allocate_gpu(id=None):
    '''
    Choose a free GPU on the node.

    :param id: explicit device index; when given, no probing is performed.
    :return: a torch.device — the requested device, the first CUDA device a
             tensor could be moved to, or the CPU device as a fallback.
    '''
    probe = torch.empty(1)
    if id is not None:
        return torch.device("cuda:{}".format(str(id)))
    # Probe cuda:0..cuda:7 by moving a tiny tensor onto each until one works.
    for i in range(8):
        try:
            probe = probe.to("cuda:{}".format(str(i)))
            print("Allocating cuda:{}.".format(i))
            return probe.device
        except (RuntimeError, AssertionError):
            # Narrowed from a bare `except Exception: pass`: RuntimeError is
            # raised for busy/unavailable devices, AssertionError when torch
            # was built without CUDA. Anything else should surface.
            continue
    print("CUDA error: all CUDA-capable devices are busy or unavailable")
    # probe was never moved, so this is the CPU device.
    return probe.device
|
| 25 |
+
|
| 26 |
+
def allocate_gpu_multi(id=None):
    """Pin CUDA_VISIBLE_DEVICES and return a device for multi-GPU runs.

    The original set CUDA_VISIBLE_DEVICES twice ('1', then '0') and computed
    the device twice; only the final assignment and the last device mattered
    (torch caches CUDA availability after first query), so the redundant
    first round was removed.

    :param id: accepted but unused; kept for interface compatibility.
    :return: torch.device('cuda:1') when CUDA is available, else the CPU device.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
    return device
|