# Scraped from a Hugging Face Space page; the Space status shown was "Runtime error" (listed twice).
import functools
import io
import os
import pickle

import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from rdkit.Chem import rdChemReactions as Reaction
from rdkit.Chem.Draw import rdMolDraw2D
from mhnreact.inspect import list_models, load_clf

from ssretro_template import ssretro, ssretro_custom
def custom_template_file(template: str):
    """Parse comma-separated reaction templates and save them to ``saved_dictionary.pkl``.

    Each comma-separated entry is stripped of surrounding whitespace. Empty
    entries are dropped (e.g. from a trailing comma or an empty textbox).
    The remaining templates are numbered 0..n-1 in input order.

    Limitation: the input is split on every comma. A template that itself
    contains a comma (e.g. a SMARTS atom OR-list such as ``[C,N]``) is
    therefore cut into pieces.

    Args:
        template: the raw textbox contents; ``None`` is treated as empty.

    Returns:
        The ``{index: template}`` dict that was pickled. The UI displays it
        as the list of added templates.
    """
    entries = (part.strip() for part in (template or "").split(","))
    template_dict = dict(enumerate(entry for entry in entries if entry))
    # NOTE(review): a single fixed file in the working directory, so every
    # upload replaces the previous set (and is shared by all users of the app).
    # Presumably read back by ssretro_custom; that code is not visible here.
    with open('saved_dictionary.pkl', 'wb') as f:
        pickle.dump(template_dict, f)
    return template_dict
def get_output(p):
    """Draw the reaction SMARTS *p* on an 800x200 Cairo canvas and return the drawing output.

    The reaction is parsed as SMARTS (``useSmiles=False``) and drawn without
    per-reactant highlighting. The return value is the Cairo drawer's
    ``GetDrawingText()``; the caller writes it to a ``.png`` file in binary mode.
    """
    parsed_reaction = Reaction.ReactionFromSmarts(p, useSmiles=False)
    canvas = rdMolDraw2D.MolDraw2DCairo(800, 200)
    canvas.DrawReaction(parsed_reaction, highlightByReactant=False)
    canvas.FinishDrawing()
    return canvas.GetDrawingText()
@functools.lru_cache(maxsize=1)
def _load_retro_clf():
    """Load the first model from ``list_models()`` once and reuse it on later calls."""
    return load_clf(list_models()[0])


def ssretro_prediction(molecule, custom_template=False):
    """Run single-step retrosynthesis on *molecule* and render every prediction.

    Args:
        molecule: the input molecule as entered in the UI (a SMILES string).
            Surrounding whitespace is stripped. Empty input yields no predictions.
        custom_template: if truthy, predict with ``ssretro_custom``; otherwise
            predict with ``ssretro``.

    Returns:
        ``(predict, txt)``: two parallel lists. ``predict[i]`` is the drawing
        returned by ``get_output`` for the i-th predicted reaction, and
        ``txt[i]`` is its caption (rank, template index, probability).
    """
    if isinstance(molecule, str):
        molecule = molecule.strip()
    if not molecule:
        return [], []

    # The classifier used to be reloaded on every request; it is now loaded
    # once per process and kept in memory.
    retro_clf = _load_retro_clf()
    run_retro = ssretro_custom if custom_template else ssretro
    outputs = run_retro(molecule, retro_clf)

    predict, txt = [], []
    for pred in outputs:
        txt.append(
            f'predicted top-{pred["template_rank"] - 1}, template index: {pred["template_idx"]}, prob: {pred["prob"]: 2.1f}%;')
        predict.append(get_output(pred["reaction"]))
    return predict, txt
def _caption_font(size=20):
    """Return the caption font, falling back to PIL's built-in font if the TTF cannot be loaded."""
    try:
        return ImageFont.truetype(r'tools/arial.ttf', size)
    except OSError:
        # e.g. tools/arial.ttf is missing from the deployment
        return ImageFont.load_default()


def _as_bool(value):
    """Interpret the "use custom templates" choice, which may arrive as a bool, None or a string."""
    # NOTE(review): depending on the Gradio version, gr.Radio may return the
    # selected choice as "True"/"False" strings; "False" would otherwise be truthy.
    if isinstance(value, str):
        return value.strip().lower() == "true"
    return bool(value)


def _annotate(png_bytes, caption, font, left=10, right=10, top=50, bottom=1):
    """Pad the rendered reaction image with a white border and write *caption* in the top margin."""
    with Image.open(io.BytesIO(png_bytes)) as img:
        width, height = img.size
        result = Image.new(img.mode, (width + left + right, height + top + bottom), (255, 255, 255))
        result.paste(img, (left, top))
    ImageDraw.Draw(result).text((20, 20), caption, font=font, fill=(0, 0, 0))
    return result


def mhn_react_backend(mol, use_custom: bool):
    """Gradio callback: predict retrosynthesis steps for *mol* and return captioned images.

    Each annotated image is also saved as ``outputs/000.png``,
    ``outputs/001.png``, ... (the directory is created if needed).

    Args:
        mol: the input molecule text from the UI (SMILES).
        use_custom: the "use custom templates" choice. Bool-like strings are accepted.

    Returns:
        A list of PIL images, one per prediction, for the gallery.
    """
    output_dir = "outputs"
    # Without this, open()/save() fail with FileNotFoundError on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    predictions, comments = ssretro_prediction(mol, _as_bool(use_custom))
    font = _caption_font()  # loaded once instead of once per prediction

    images = []
    for i, (png_bytes, caption) in enumerate(zip(predictions, comments)):
        result = _annotate(png_bytes, caption, font)
        images.append(result)
        result.save(os.path.join(output_dir, f"{i:03d}.png"))
    return images
# Gradio UI. The first tab runs retrosynthesis on a SMILES input. The second tab
# uploads custom templates via custom_template_file; the "use custom templates"
# choice presumably makes the first tab use them.
with gr.Blocks() as demo:
    # Markdown content is kept at column 0 so no line is rendered as an
    # indented code block.
    gr.Markdown(
        """
[![GitHub](https://img.shields.io/badge/github-%20mhn--react-blue)](https://github.com/ml-jku/mhn-react)
[![DOI](https://img.shields.io/badge/DOI-10.1021%2Facs.jcim.1c01065-blue)](https://doi.org/10.1021/acs.jcim.1c01065)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb)

### MHN-react
Adapting modern Hopfield networks (Ramsauer et al., 2021) (MHN) to associate different data modalities,
molecules and reaction templates, to improve predictive performance for rare templates and single-step retrosynthesis.
"""
    )
    with gr.Accordion("Information"):
        gr.Markdown(
            "use one of example molecules <br> CC(=O)NCCC1=CNc2c1cc(OC)cc2, <br> CN1CCC[C@H]1c2cccnc2, <br> OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N"
            # Without this separator, implicit concatenation glued "...nc2N"
            # directly onto "In case ...", corrupting the last example SMILES.
            " <br> "
            "In case the output is empty, no applicable templates were found"
        )
    with gr.Tab("Generate Templates"):
        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Textbox(placeholder="Input molecule in SMILES format", label="input molecule")
                # Start with an explicit selection instead of no choice at all.
                radio = gr.Radio([False, True], value=False, label="use custom templates")
                generate_btn = gr.Button(value="Generate")
            with gr.Column(scale=2):
                out = gr.Gallery(label="retro-synthesis")
        generate_btn.click(mhn_react_backend, [inp, radio], out)
    with gr.Tab("Create custom templates"):
        gr.Markdown(
            """
Input the templates separated by comma. <br> Please do not upload templates one-by-one
"""
        )
        with gr.Column():
            inp_t = gr.Textbox(placeholder="custom template", label="add custom template(s)")
            upload_btn = gr.Button(value="upload")
            out_t = gr.Textbox(label="added templates")
            upload_btn.click(custom_template_file, inp_t, out_t)

# Launch only when run as a script (as Spaces does), not when imported.
if __name__ == "__main__":
    demo.launch(debug=True)