from rdkit import RDLogger

# Wildcard import kept from the original script; later explicit imports take
# precedence over any clashing names, as before.
from augmentation import *

# Silence RDKit warnings during SMILES parsing/validation.
RDLogger.DisableLog('rdApp.*')

import argparse
import math
import pathlib
import random
import re
from collections import OrderedDict, defaultdict
from itertools import combinations

import numpy as np
import pandas as pd
import partialsmiles as ps
import torch
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol
from SmilesPE.pretokenizer import atomwise_tokenizer

class AtomwiseTokenizer():
    """Thin wrapper around SmilesPE's atomwise tokenizer with a fixed vocab."""

    def __init__(self, str_bos="<can>", str_eos="<eos>"):
        self.bos_token = str_bos
        self.eos_token = str_eos

    def tokenize(self, smiles):
        return atomwise_tokenizer(smiles)

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def assign_vocab(self, vocab):
        self.vocab = vocab
        self.vocab_inv = {v: k for k, v in vocab.items()}
        self.eos_token_id = vocab[self.eos_token]
        self.bos_token_id = vocab[self.bos_token]

    def decode(self, ids, skip_special_tokens=True):
        # Accept both tensors and plain sequences of ids.
        if isinstance(ids, torch.Tensor):
            ids = ids.cpu().tolist()
        tokens = [self.vocab_inv[i] for i in ids]
        # The original implementation ignored this flag; honor it here.
        if skip_special_tokens:
            tokens = [t for t in tokens if t not in (self.bos_token, self.eos_token)]
        return "".join(tokens)
|
|
|
|
|
|
|
|
def gen_psv_table(partial_smiles, vocab, eos_str, sep_str, partial_valid):
    """For each vocab token, check whether appending it to `partial_smiles`
    still yields a syntactically valid partial SMILES (via partialsmiles).
    EOS/separator tokens get the caller-supplied `partial_valid` flag."""
    psv_table = []
    for token in vocab.keys():
        if token == eos_str or token == sep_str:
            psv_table.append(partial_valid)
        else:
            try:
                mol = ps.ParseSmiles(partial_smiles + token, partial=True)
                assert mol is not None
                psv_table.append(True)
            except Exception:
                psv_table.append(False)
    return psv_table
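# Example usage (illustrative; `tokenizer.vocab` maps token -> id and the
# special-token strings match those assigned below):
#   mask = gen_psv_table("Cc1cccc", tokenizer.vocab, "<eos>", "|", partial_valid=False)
#   # mask[i] is True iff appending the i-th vocab token (in vocab.keys() order)
#   # keeps the prefix parsable as a partial SMILES.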
|
|
|
|
|
def calculate_bm_scaffold(smiles):
    """Return the Bemis-Murcko scaffold SMILES, or None if parsing fails."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Chem.MolToSmiles(GetScaffoldForMol(mol))
    except Exception:
        return None

def get_morgan_fp(smiles, radius=2, n_bits=2048):
    """Morgan (ECFP-like) bit-vector fingerprint, or None for unparsable SMILES."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
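# Example (illustrative):
#   fp1, fp2 = get_morgan_fp("CCO"), get_morgan_fp("CCN")
#   DataStructs.TanimotoSimilarity(fp1, fp2)  # Tanimoto similarity in [0, 1]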
|
|
|
|
|
def compute_internal_diversity(smiles_list):
    """Internal diversity = 1 - mean pairwise Tanimoto similarity."""
    fps = [get_morgan_fp(sm) for sm in smiles_list]
    fps = [fp for fp in fps if fp is not None]
    if len(fps) < 2:
        return 0.0
    similarities = []
    for fp1, fp2 in combinations(fps, 2):
        sim = DataStructs.TanimotoSimilarity(fp1, fp2)
        similarities.append(sim)
    mean_sim = np.mean(similarities)
    int_div = 1 - mean_sim
    return int_div
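# Example (illustrative):
#   compute_internal_diversity(["CCO", "c1ccccc1", "CC(=O)O"])
# Higher values indicate a more structurally diverse set (1 - mean pairwise
# Tanimoto similarity over Morgan fingerprints).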
|
|
|
|
|
def atomwise_tokenizer_fixed(x):
    """Tokenize a multi-fragment string, preserving '|' as its own separator
    token (SmilesPE's tokenizer does not know about the fragment separator)."""
    list_subSMILES = [atomwise_tokenizer(subSMILES) for subSMILES in x.split("|")]
    y_in = list_subSMILES[0]
    for sub_tokens in list_subSMILES[1:]:
        y_in += ["|"] + sub_tokens
    return y_in
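# Example (illustrative): '|' is preserved as its own token.
#   atomwise_tokenizer_fixed("CCO|c1ccccc1")
#   -> ['C', 'C', 'O', '|', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1']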
|
|
|
|
|
|
|
|
|
|
|
def customized_forward(model, x_in, y_in, y_out=None, boundary=None, return_last_hidden_state=False):
    """Run the encoder-decoder forward pass manually so the decoder's last
    hidden state can be exposed. `boundary` is accepted for interface
    compatibility but is not used here."""
    # Token + positional embeddings for encoder and decoder inputs.
    x_in = model.drop(model.tok_emb(x_in) + model.pos_emb[:, :x_in.size()[1], :])
    y_in = model.drop(model.tok_emb(y_in) + model.pos_emb[:, :y_in.size()[1], :])

    # Encode the scaffold sequence.
    for encoder_block in model.encoder_blocks:
        x_in = encoder_block(x_in)
    x_in = model.ln_f(x_in)

    # Decode with cross-attention to the encoder output.
    for decoder_block in model.decoder_blocks:
        y_in = decoder_block(x_in, y_in)
    y_in = model.ln_f(y_in)
    logits = model.head(y_in)

    loss = None
    if y_out is not None:
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y_out.view(-1))
    if return_last_hidden_state:
        return logits, y_in
    return logits, loss
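# Example calls (illustrative; assumes `model` exposes tok_emb, pos_emb, drop,
# encoder_blocks, decoder_blocks, ln_f and head, as used above):
#   logits, loss = customized_forward(model, scaffold_ids, dec_ids[:, :-1],
#                                     y_out=dec_ids[:, 1:])  # teacher-forced loss
#   logits, h = customized_forward(model, scaffold_ids, dec_ids,
#                                  return_last_hidden_state=True)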
|
|
|
|
|
def path_aligned_generation(
    model,
    tokenizer,
    max_length=256,
    batch_size=128,
    device="cuda:0",
    budget_generation=10,
    sample_suffix="Cc1ccccc1",
    tensor_scaffold=None,
    boundary=None,
    n_generation=100,
    suppress_eos=False,
    max_molwt=1000,
    max_clogp=10,
    max_rotatable_bond=10,
    use_merge=True,
    top_k=0,
    top_p=1.0,
    min_prefix_length=4,
    typical_sampling=False,
    contrastive_search=False,
    pre_check_merge=False,
):
    """Sample decorations token by token until `n_generation` unique valid
    SMILES have been collected. Several parameters (budget_generation,
    sample_suffix, suppress_eos, the property caps, use_merge,
    min_prefix_length, typical_sampling, contrastive_search, pre_check_merge)
    are accepted for interface compatibility but unused here."""
    model.to(device)
    model.eval()

    # Bookkeeping: unique generated SMILES plus counters for the log line.
    # n_repeated / n_invalid / n_suppressed_eos are reserved counters and are
    # not updated in this implementation.
    generated_smiles = OrderedDict()
    dict_inchikey_count = defaultdict(int)
    dict_inchikey_merged_path = defaultdict(OrderedDict)
    dict_path_inchikey = {}
    iteration_counter = 0
    total_merge_count = 0
    n_calls = 0
    n_repeated = 0
    n_suppressed_eos = 0
    n_invalid = 0
    count_merged = 0

    with torch.no_grad():
        while len(generated_smiles) < n_generation:
            # Every sample starts as "<bos> [*]": the decoration grows from
            # the attachment point.
            tensor_generation = torch.zeros(batch_size, 2, dtype=torch.long, device=device)
            tensor_generation[:, 0] = tokenizer.bos_token_id
            tensor_generation[:, 1] = tokenizer.vocab["[*]"]

            for step_idx in range(1, max_length - 1):
                inputs = tensor_generation[:, :step_idx + 1].to(device)

                if tensor_scaffold is not None:
                    # Scaffold-conditioned decoding through the custom forward pass.
                    logits, _ = customized_forward(model, tensor_scaffold[:inputs.shape[0]], inputs, None, boundary, return_last_hidden_state=True)
                    logits = logits[:, -1, :]
                else:
                    outputs = model.forward(inputs)
                    logits = outputs.logits[:, -1, :]
                n_calls += inputs.shape[0]

                list_finished_idx = []
                filter_value = -float('Inf')

                # Top-k filtering: drop everything below the k-th largest logit.
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][:, [-1]]
                    logits[indices_to_remove] = filter_value

                # Nucleus (top-p) filtering: keep the smallest set of tokens
                # whose cumulative probability exceeds top_p, always retaining
                # the argmax.
                if top_p < 1.:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    sorted_logits[sorted_indices_to_remove] = filter_value
                    # Scatter the filtered logits back to their original positions.
                    logits = torch.gather(sorted_logits, -1, sorted_indices.argsort(-1))

                next_token_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

                # Decode each prefix (excluding <bos>) and retire samples whose
                # prefix already parses as a new, valid molecule.
                current_prefix = [tokenizer.decode(tensor_generation[sample_idx, 1:step_idx + 1])
                                  for sample_idx in range(tensor_generation.shape[0])]
                for sample_idx, current_decoded in enumerate(current_prefix):
                    try:
                        mol = Chem.MolFromSmiles(current_decoded)
                    except Exception:
                        mol = None
                    if mol is not None and current_decoded not in generated_smiles:
                        generated_smiles[current_decoded] = 1
                        list_finished_idx.append(sample_idx)

                # Keep the unfinished samples and append their sampled next token.
                keep_mask = torch.ones(tensor_generation.shape[0], dtype=torch.bool, device=tensor_generation.device)
                keep_mask[list_finished_idx] = False
                tensor_generation = torch.cat([tensor_generation[keep_mask], next_token_id[keep_mask]], dim=1)

                if tensor_generation.shape[0] == 0:
                    break

                str_print = f"Iteration {iteration_counter:05d}"
                str_print += f" step {step_idx:05d}"
                str_print += f" merged_t {total_merge_count:05d}"
                str_print += f" merged_c {count_merged:05d}"
                str_print += f" dict_prefix {len(dict_path_inchikey):05d}"
                str_print += f" dict_inch {len(dict_inchikey_merged_path):05d}"
                str_print += f" gen_c {tensor_generation.shape[0]:05d}"
                str_print += f" gen_t {len(generated_smiles):08d}"
                str_print += f" n_calls {n_calls:08d}"
                str_print += f" n_repeated {n_repeated:05d}"
                str_print += f" n_invalid {n_invalid:05d}"
                print(str_print)

            iteration_counter += 1
            total_merge_count += count_merged

    return generated_smiles, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated
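# Sanity sketch of the logit filtering used above (illustrative):
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#   top_k=2 masks everything below the 2nd-largest logit to -inf;
#   top_p=0.9 sorts the probabilities, masks the tail once cumulative mass
#   exceeds 0.9 (the right-shift keeps the first token past the threshold),
#   and gather(sorted_logits, -1, sorted_indices.argsort(-1)) restores the
#   original vocabulary order before torch.multinomial samples.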
|
|
|
|
|
|
|
|
|
|
|
ATTACHMENT_POINT_TOKEN = "*"
ATTACHMENT_POINT_NUM_REGEXP = r"\[{}:(\d+)\]".format(re.escape(ATTACHMENT_POINT_TOKEN))
ATTACHMENT_POINT_REGEXP = r"(?:{0}|\[{0}[^\]]*\])".format(re.escape(ATTACHMENT_POINT_TOKEN))
ATTACHMENT_POINT_NO_BRACKETS_REGEXP = r"(?<!\[){}".format(re.escape(ATTACHMENT_POINT_TOKEN))
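# What these patterns match (illustrative):
#   ATTACHMENT_POINT_NUM_REGEXP          "[*:1]" (captures the map number "1")
#   ATTACHMENT_POINT_REGEXP              "*", "[*]", or "[*:1]"
#   ATTACHMENT_POINT_NO_BRACKETS_REGEXP  a bare "*" not directly preceded by "["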
|
|
|
|
|
parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=str, default="entropy/gpt2_zinc_87m")
parser.add_argument("--model_name", type=str, default="gpt2_zinc_87m")
parser.add_argument("--generate_mode", type=str, default="scaffold_decorator")
parser.add_argument("--filepath_scaffold", type=str, default="/shared/healthinfolab/xiw14035/TF_debug/SCMG/SCMG/20250505/scaf_5.smi")
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--n_to_gen", type=int, default=100)
parser.add_argument("--max_length", type=int, default=30)
parser.add_argument("--max_molwt", type=float, default=500)
parser.add_argument("--max_clogp", type=float, default=4.5)
parser.add_argument("--max_rotatable_bond", type=int, default=8)
parser.add_argument("--min_prefix_length", type=int, default=4)
parser.add_argument("--top_p", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=10)
parser.add_argument("--scaffold", type=str, default="[*]c1ccccc1")
parser.add_argument("--decode_methods", type=str, default="Structure-Aware_Decoding")
args = parser.parse_args()

pathlib.Path(args.save_dir).mkdir(parents=True, exist_ok=True)

device = torch.device("cpu")

# Note: the checkpoint path is hardcoded here; args.model_path is not used.
model = torch.load("src/clm/model_new_torch.pt", weights_only=False, map_location="cpu")
vocab = model.vocab_encoder

tokenizer = AtomwiseTokenizer(str_bos="<scmg_char_cano>", str_eos="<eos>")
tokenizer.assign_vocab(vocab)
tokenizer.sep_token = "|"
tokenizer.sep_token_id = vocab[tokenizer.sep_token]

def path_aligned_generation_suppress_eos(model, tokenizer, max_length=256, n_generation=100, batch_size=128,
                                         device="cuda:0", tensor_scaffold=None, boundary=None, budget_generation=10,
                                         max_molwt=1000, max_clogp=10, max_rotatable_bond=10):
    """Convenience wrapper that forces suppress_eos=True."""
    return path_aligned_generation(model, tokenizer, max_length=max_length, n_generation=n_generation,
                                   batch_size=batch_size, device=device, tensor_scaffold=tensor_scaffold,
                                   boundary=boundary, budget_generation=budget_generation, suppress_eos=True,
                                   max_molwt=max_molwt, max_clogp=max_clogp, max_rotatable_bond=max_rotatable_bond)

model.to(device)
model.eval()

budget_generation = 10
batch_size = 512

scaf_smi = args.scaffold

if len(scaf_smi) > 0:
    if "[*]" not in scaf_smi:
        raise ValueError("Scaffold does not contain attachment point")
    # Encode the scaffold as <bos> + tokens + <eos> and tile it over the batch.
    sequence_scaffold = [tokenizer.bos_token_id] + [vocab[a] for a in tokenizer.tokenize(scaf_smi)] + [tokenizer.eos_token_id]
    tensor_scaffold = torch.tensor(sequence_scaffold).unsqueeze(0).to(device).repeat(batch_size, 1)
    boundary = torch.zeros(batch_size, 1).long().to(device) + tensor_scaffold.shape[1] + 1
else:
    tensor_scaffold = None
    boundary = None
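# Shape check (illustrative): the default scaffold "[*]c1ccccc1" tokenizes to
# 9 tokens, so with <bos>/<eos> tensor_scaffold has shape (batch_size, 11).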
|
|
|
|
|
# Results table for generation metrics (declared here but not populated below).
df_result = pd.DataFrame(columns=["n_to_gen", "gen_func_name", "internal_diversity", "n_bm_scaffold"])

# Fix all RNG seeds for reproducible sampling.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

n_to_gen = args.n_to_gen

generated_smiles_raw, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated = path_aligned_generation(
    model,
    tokenizer=tokenizer,
    max_length=args.max_length,
    n_generation=n_to_gen,
    batch_size=batch_size,
    device=device,
    tensor_scaffold=tensor_scaffold,
    boundary=boundary,
    budget_generation=budget_generation,
    max_molwt=args.max_molwt,
    max_clogp=args.max_clogp,
    max_rotatable_bond=args.max_rotatable_bond,
    use_merge=True,
    # top_k/top_p were parsed from the CLI but never forwarded in the original
    # call; pass them through so the flags take effect.
    top_k=args.top_k,
    top_p=args.top_p,
    min_prefix_length=args.min_prefix_length,
)

# Drop any leading "<can>" marker before saving.
generated_smiles = dict([(smiles.split("<can>")[-1], freq) for smiles, freq in generated_smiles_raw.items()])

pd.DataFrame({
    "smiles": list(generated_smiles.keys()),
    "count": list(generated_smiles.values()),
}).to_csv("generated_molecules.csv", index=False)
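# Example invocation (illustrative; the file name depends on how this script
# is saved):
#   python generate.py --scaffold "[*]c1ccccc1" --n_to_gen 100 --max_length 30 \
#       --top_k 10 --top_p 1.0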
|
|
|