Spaces:
Sleeping
Sleeping
import random

import deepchem
import gradio as gr
import numpy as np
import regex as re
import torch
import transformers
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import QED
from transformers import AutoTokenizer
from transformers import pipeline
# Hugging Face Hub checkpoint used for both tokenisation and mask filling.
model_name = "cafierom/bert-base-cased-ChemTok-ZN250K-V1"

# Prefer a GPU when torch can see one; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): padding/truncation are also supplied on every tokenizer call
# further down; whether they have any effect at load time should be confirmed.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True, truncation=True)

# Hand the selected device to the pipeline so that `device` above is actually
# used for inference instead of being computed and then ignored.
mask_filler = pipeline("fill-mask", model=model_name, device=device)
def tokenize(batch):
    """
    Encode the ``"text"`` entry of a batch with the module-level tokenizer.

    Args:
        batch: a mapping whose ``"text"`` value is passed to the tokenizer.
    Returns:
        The tokenizer output, padded, truncated to at most 250 tokens and
        including the special-tokens mask.
    """
    encode_options = {
        "padding": True,
        "truncation": True,
        "max_length": 250,
        "return_special_tokens_mask": True,
    }
    texts = batch["text"]
    return tokenizer(texts, **encode_options)
def _select_mask_positions(length_count, mask_flag, percent):
    """
    Choose which token indices to mask for one masking strategy.

    Assumes the encoded sequence is laid out as ``[CLS] tok ... tok [SEP]``
    (the caller strips exactly those markers from the decoded text), so the
    maskable SMILES tokens occupy indices ``1 .. length_count - 2``. The two
    special tokens are never selected.

    Args:
        length_count: number of non-padding token ids, special tokens included.
        mask_flag: "first", "last" or "random".
        percent: fraction of the sequence to mask.
    Returns:
        A set of token indices to replace with the mask token.
    Raises:
        ValueError: if ``mask_flag`` is not one of the supported strategies.
    """
    first_real = 1
    stop_real = length_count - 1  # index of [SEP]; exclusive upper bound
    n_real = max(stop_real - first_real, 0)
    n_requested = min(int(length_count * percent), n_real)

    if mask_flag == "last":
        # Clamp so a large `percent` can never reach back onto [CLS].
        start = max(first_real, int(length_count * (1.0 - percent)) - 1)
        return set(range(start, stop_real))
    if mask_flag == "first":
        return set(range(first_real, first_real + n_requested))
    if mask_flag == "random":
        return set(random.sample(range(first_real, stop_real), n_requested))
    raise ValueError(
        f"Unknown mask_flag {mask_flag!r}; expected 'first', 'last' or 'random'"
    )


def gen_from_multimask(text, print_flag=True, mask_flag="random", percent = 0.10, top_k = 3):
    """
    Takes a SMILES string and tokenizes it. Depending on the mask flag, it then masks the
    requested percentage of tokens in the string either randomly, at the beginning (first) or at
    the end (last). Only SMILES tokens are masked; the [CLS]/[SEP] markers are left alone.
    The masked string is then sent to the mask filler, and the result is expanded
    into all possible new strings where the top k candidates of every mask are used if their
    probability is greater than 0.1. An entropy term, -p*log(p), is accumulated for every
    candidate that goes into a generated string.

    Args:
        text: The SMILES string of the original molecule.
        print_flag: if True, print the original and the generated strings. Printing
            never alters the returned lists.
        mask_flag: "random", "first" or "last"; where the masks are placed.
        percent: fraction of the tokens to mask.
        top_k: number of candidates requested from the mask filler per mask.
    Returns:
        final_smiles: a list of all the generated molecules.
        total_entropy: a list of the entropy of each generated molecule, aligned
            index-for-index with final_smiles.
    Raises:
        ValueError: if mask_flag is unknown, or if the requested fraction selects
            no token at all (the sequence is too short for this percent).
    """
    min_score = 0.1
    mask_id = tokenizer.mask_token_id
    pad_id = tokenizer.pad_token_id

    single_tok = tokenizer(text, padding=True, truncation=True, max_length=250, return_special_tokens_mask=True)
    input_ids = single_tok["input_ids"]
    length_count = sum(1 for token in input_ids if token != pad_id)

    masked_positions = _select_mask_positions(length_count, mask_flag, percent)
    if not masked_positions:
        # Without any [MASK] the fill-mask pipeline has nothing to predict.
        raise ValueError(
            f"No token selected for masking (length {length_count}, percent {percent})"
        )

    new_tok_list = [
        mask_id if j in masked_positions else token
        for j, token in enumerate(input_ids)
        if token != pad_id
    ]

    masked_smile = tokenizer.decode(new_tok_list, skip_special_tokens=False)
    for marker in ("[PAD]", "[SEP]", "[CLS]", " "):
        masked_smile = masked_smile.replace(marker, "")

    result = mask_filler(masked_smile, top_k=top_k)
    # With a single mask the pipeline returns a flat list of candidates instead
    # of one list per mask; normalise so the loop below always sees the latter.
    if result and isinstance(result[0], dict):
        result = [result]

    beam_smiles = []
    beam_entropy = []
    for position, candidates in enumerate(result):
        next_smiles = []
        next_entropy = []
        for candidate in candidates:
            p = candidate["score"]
            if not p > min_score:
                continue
            if position == 0:
                # The first mask's "sequence" is the full string with that mask
                # filled in; it seeds the beams, later masks are still [MASK].
                seeded = candidate["sequence"].replace(" ","").replace("[SEP]","").replace("[CLS]","")
                next_smiles.append(seeded)
                next_entropy.append(-p*np.log(p))
            else:
                # Fill the next open [MASK] of every beam built so far.
                for smile, entropy in zip(beam_smiles, beam_entropy):
                    next_smiles.append(smile.replace("[MASK]", candidate["token_str"], 1))
                    next_entropy.append(entropy - p*np.log(p))
        beam_smiles = next_smiles
        beam_entropy = next_entropy

    # Drop the word-piece continuation marker left by sub-word candidates.
    final_smiles = [smile.replace("##","") for smile in beam_smiles]

    if print_flag:
        print(f"original: {text}")
        for smile in final_smiles:
            print(f"generated: {smile}")

    return final_smiles, beam_entropy
def validate_smiles(in_smiles, in_entropy):
    """
    Takes a list of SMILES strings and checks whether they compile to valid MOL objects.
    Valid molecules are then converted to canonical SMILES strings and duplicates are
    dropped, keeping the first occurrence (and its entropy) of every canonical form.

    Args:
        in_smiles: the SMILES strings to validate.
        in_entropy: the entropy of each string, aligned with in_smiles.
    Returns:
        unique_smiles: a list of all the unique, valid molecules as canonical SMILES.
        unique_entropies: a list of the entropy of each molecule in unique_smiles.
    """
    unique_smiles = []
    unique_entropies = []
    seen = set()

    for smile, entropy in zip(in_smiles, in_entropy):
        try:
            mol = Chem.MolFromSmiles(smile)
        except Exception:
            print("Could not convert to mol")
            continue
        if mol is None:
            # The string did not parse into a molecule; skip it.
            continue

        canonical = Chem.CanonSmiles(smile)
        if canonical in seen:
            continue
        seen.add(canonical)
        unique_smiles.append(canonical)
        unique_entropies.append(entropy)

    print(f"Total unique SMILES generated: {len(unique_smiles)}")
    # Nothing to average when no molecule survived validation.
    if unique_entropies:
        print(f"Average entropy: {sum(unique_entropies)/len(unique_entropies)}")
    return unique_smiles, unique_entropies
def calc_qed(smiles):
    """
    Compute the QED (quantitative estimate of drug-likeness) of each SMILES string.

    Uses the explicitly imported ``rdkit.Chem.QED`` module rather than the
    ``Chem.QED`` attribute, which only exists once that submodule has been
    imported somewhere.

    Args:
        smiles: an iterable of SMILES strings; they are expected to be valid,
            since an unparsable string yields a None molecule.
    Returns:
        qed: a list with the QED value of each molecule.
        mols: a list with the MOL object built from each SMILES string.
    """
    mols = [Chem.MolFromSmiles(smile) for smile in smiles]
    qed = [QED.default(mol) for mol in mols]
    return qed, mols
def _merge_unique(candidates, entropies, main_smiles, main_entropy, seen):
    """Append not-yet-seen SMILES (with their entropy) to the main lists; return how many were added."""
    added = 0
    for smile, entropy in zip(candidates, entropies):
        if smile in seen:
            continue
        seen.add(smile)
        main_smiles.append(smile)
        main_entropy.append(entropy)
        added += 1
    return added


def gen_mask(smile_in: str, percent_mask: float) -> tuple:
    """
    Generate Analogues of a hit for hit expansion using generative mask-filling.
    The molecule corresponding to the input smiles is masked in different ways
    (once at the start, once at the end and four times at random positions),
    creating various masked versions of the molecule.
    A model, cafierom/bert-base-cased-ChemTok-ZN250K-V1,
    is used to generate SMILES strings for analogue molecules by unmasking the
    masked versions. All possibilities created by the generative mask-filling
    are kept as long as the probability is greater than a cut-off, which is set
    to 0.1. The QED value, or quantitative estimate of druglikeness, a weighted average of
    various ADME properties is also calculated. A value of 1.0 is perfect
    drug-likeness, and a value of 0.0 is not drug-like. A value of 0.5 is average for many drugs.

    Args:
        smile_in: The SMILES string of the original molecule.
        percent_mask: The fraction of the molecule's tokens to mask.
    Returns:
        out_text: a string with all of the SMILES for the generated molecules
                  and their QED values, or "Invalid SMILES string" if any step fails.
        pic: An image of the molecules with QED values, or None when there is
             nothing to draw.
    """
    # (label used in the log line, mask_flag, number of generation rounds)
    strategies = (("First", "first", 1), ("Last", "last", 1), ("Random", "random", 4))

    try:
        main_smiles = []
        main_entropy = []
        seen = set()

        for label, mask_flag, rounds in strategies:
            for _ in range(rounds):
                result, calc_entropy = gen_from_multimask(
                    smile_in, print_flag=False, mask_flag=mask_flag, percent=percent_mask
                )
                added = _merge_unique(result, calc_entropy, main_smiles, main_entropy, seen)
                print(f"{label} masking generated {added} SMILES")
        print(f"Total SMILES generated: {len(main_smiles)}")

        final_smiles, _ = validate_smiles(main_smiles, main_entropy)
        qeds, mols = calc_qed(final_smiles)

        out_text = f"Total SMILES generated for hit: {len(final_smiles)}\n"
        out_text += "===================================================\n"
        out_text += "".join(
            f"analogue {i}: {smile} with QED: {qed:.3f}\n"
            for i, (smile, qed) in enumerate(zip(final_smiles, qeds), start=1)
        )

        legends = [f"QED = {qed:.3f}" for qed in qeds]
        # Only draw a grid when there is at least one molecule to show.
        img = (
            Draw.MolsToGridImage(mols, legends=legends, molsPerRow=3, subImgSize=(200,200), useSVG=False, returnPNG=False)
            if mols
            else None
        )
    except Exception as exc:
        # Keep the user-facing message, but leave the real cause in the log.
        print(f"Analogue generation failed: {exc!r}")
        out_text = "Invalid SMILES string"
        img = None
    return out_text, img
with gr.Blocks() as gradio_app:
    gr.Markdown(
        """
        # Generate Analogues of a hit for hit expansion using generative mask-filling.
        - The hit molecule is input by the user; this molecule is then masked in different,
        random ways. A model, cafierom/bert-base-cased-ChemTok-ZN250K-V1,
        is used to generate SMILES strings for analogue molecules by unmasking the
        hit molecule. All possibilities created by the generative mask-filling
        are kept as long as the probability is greater than a cut-off, which is set
        to 0.1 but which may be changed.
        - The QED value, or quantitative estimate of druglikeness, a weighted average of
        various ADME properties is also calculated. A value of 1.0 is perfect
        drug-likeness, and a value of 0.0 is not drug-like. A value of about 0.5
        is average for many drugs.
        """)

    # Inputs: the hit molecule and how much of it to mask.
    smile = gr.Textbox(label="SMILES for hit expansion")
    percent_mask = gr.Radio(choices = [0.10, 0.15, 0.20],
                            label="Fraction of hit molecule to mask.", value = 0.15, interactive=True)
    mask_btn = gr.Button("Generate analogues with Mask-filling.")

    # Outputs: the text report next to the grid image of the analogues.
    with gr.Row():
        results = gr.Textbox(label="New Molecules: ")
        mol_pic = gr.Image(label="Molecule Images:")

    def do_genmask(smile, percent_mask):
        """
        Generate Analogues of a hit for hit expansion using generative mask-filling.
        The molecule corresponding to the input smiles is masked in different,
        random ways, creating various masked versions of the molecule.
        A model, cafierom/bert-base-cased-ChemTok-ZN250K-V1,
        is used to generate SMILES strings for analogue molecules by unmasking the
        masked versions. All possibilities created by the generative mask-filling
        are kept as long as the probability is greater than a cut-off, which is set
        to 0.1 but which may be changed. The QED value, or quantitative estimate of druglikeness, a weighted average of
        various ADME properties is also calculated. A value of 1.0 is perfect
        drug-likeness, and a value of 0.0 is not drug-like. A value of 0.5 is average for many drugs.
        Args:
            smile: The SMILES string of the original molecule.
            percent_mask: The fraction of the molecule to mask.
        Returns:
            out_text: a string with all of the SMILES for the generated molecules
                      and their QED values.
            pic: An image of the molecules with QED values.
        """
        return gen_mask(smile, percent_mask)

    # Connect the button to the handler; without this registration a click
    # does nothing and there is no event for the app to serve.
    mask_btn.click(fn=do_genmask, inputs=[smile, percent_mask], outputs=[results, mol_pic])

if __name__ == "__main__":
    gradio_app.launch(mcp_server=True)