# NOTE(review): "Spaces: Sleeping" below is Hugging Face Spaces page-status text
# accidentally captured with the source; commented out so the file parses.
# Spaces: Sleeping / Sleeping
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| import pandas as pd | |
| import mols2grid | |
| from ipywidgets import interact, widgets | |
| import textwrap | |
| import moses | |
| from transformers import EncoderDecoderModel, RobertaTokenizer | |
| from moses.metrics.utils import QED, SA, logP, NP, weight, get_n_rings | |
| from moses.utils import mapper, get_mol | |
| # @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str}) | |
| from typing import List | |
| from util import filter_dataframe | |
def load_models():
    """Load the two WarmMolGen encoder-decoder models, caching them after the first call.

    The pretrained weights are expensive to load from the Hugging Face hub, and
    ``generate_molecules`` calls this on every generation request, so the loaded
    pair is memoized on the function object (the commented-out ``@st.cache``
    above shows caching was intended).

    Returns:
        tuple: ``(model1, model2)`` — WarmMolGenOne and WarmMolGenTwo.
    """
    cached = getattr(load_models, "_model_cache", None)
    if cached is None:
        model1 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenOne")
        model2 = EncoderDecoderModel.from_pretrained("gokceuludogan/WarmMolGenTwo")
        cached = (model1, model2)
        load_models._model_cache = cached  # memoize for subsequent reruns
    return cached  # , protein_tokenizer, mol_tokenizer
def count(smiles_list: List[str]):
    """Return the character length of every SMILES string in *smiles_list*."""
    return [len(smiles) for smiles in smiles_list]
def remove_none_elements(mol_list, smiles_list):
    """Drop invalid (``None``) molecules while keeping both lists aligned.

    ``get_mol`` returns ``None`` for SMILES that fail to parse; this filters the
    molecule list and the parallel SMILES list in one pass (the original scanned
    an index list per SMILES, which was O(n*k) and used ``__getitem__`` directly).

    Args:
        mol_list: RDKit mol objects (or ``None`` for unparsable entries).
        smiles_list: SMILES strings parallel to *mol_list*.

    Returns:
        tuple: ``(filtered_mol_list, filtered_smiles_list, removed_len)`` where
        ``removed_len`` is the number of discarded entries.
    """
    kept = [(mol, smi) for mol, smi in zip(mol_list, smiles_list) if mol is not None]
    removed_len = len(mol_list) - len(kept)
    if kept:
        filtered_mol_list = [mol for mol, _ in kept]
        filtered_smiles_list = [smi for _, smi in kept]
    else:
        filtered_mol_list, filtered_smiles_list = [], []
    return filtered_mol_list, filtered_smiles_list, removed_len
def format_list_numbers(lst):
    """Round each number in *lst* in place to three decimal places.

    Uses the same string round-trip as the original (format to ``.3f`` then
    parse back to ``float``) so the rounding behavior is unchanged.
    """
    for index in range(len(lst)):
        lst[index] = float(f"{lst[index]:.3f}")
def generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool):
    """Generate candidate molecules for a protein target and score them.

    Loads the protein and molecule tokenizers, runs the selected WarmMolGen
    encoder-decoder model on *target*, decodes the generated SMILES, discards
    molecules that fail to parse, computes QED/SA/logP/NP/weight properties,
    and returns everything as a DataFrame (one row per valid molecule).
    """
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/WarmMolGenTwo")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model1, model2 = load_models()
    selected_model = model1 if model_name == 'WarmMolGenOne' else model2

    encoded_target = protein_tokenizer(target, return_tensors="pt")
    generated_ids = selected_model.generate(
        **encoded_target,
        decoder_start_token_id=mol_tokenizer.bos_token_id,
        eos_token_id=mol_tokenizer.eos_token_id,
        pad_token_id=mol_tokenizer.eos_token_id,
        max_length=int(max_new_tokens),
        num_return_sequences=int(num_mols),
        do_sample=do_sample,
        num_beams=num_beams,
    )
    smiles_strings = mol_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    st.write("### Generated Molecules")

    # Parse SMILES into mol objects and drop the unparsable ones, keeping
    # the SMILES list aligned with the molecule list.
    molecules = mapper(pool)(get_mol, smiles_strings)
    molecules, smiles_strings, removed_len = remove_none_elements(molecules, smiles_strings)
    if removed_len != 0:
        st.write(f"#### Note that: {removed_len} numbers of generated invalid molecules are discarded.")

    # Score every valid molecule with the MOSES property functions, rounding
    # each score to three decimals for display.
    score_columns = {}
    for label, metric in (("QED", QED), ("SA", SA), ("logP", logP), ("NP", NP), ("Weight", weight)):
        scores = mapper(pool)(metric, molecules)
        format_list_numbers(scores)
        score_columns[label] = scores

    return pd.DataFrame({'SMILES': smiles_strings, **score_columns})
def warm_molgen_demo():
    """Render the WarmMolGen Streamlit UI: parameter form, generation, grid display, and code snippet."""
    with st.form("my_form"):
        # All configurable parameters live in the sidebar; the form batches
        # them so generation only runs on submit.
        with st.sidebar:
            st.sidebar.subheader("Configurable parameters")
            model_name = st.sidebar.selectbox(
                "Model Selector",
                options=[
                    "WarmMolGenOne",
                    "WarmMolGenTwo",
                ],
                index=0,
            )
            num_mols = st.sidebar.number_input(
                "Number of generated molecules",
                min_value=0,
                max_value=20,
                value=10,
                help="The number of molecules to be generated.",
            )
            max_new_tokens = st.sidebar.number_input(
                "Maximum length",
                min_value=0,
                max_value=1024,
                value=128,
                help="The maximum length of the sequence to be generated.",
            )
            do_sample = st.sidebar.selectbox(
                "Sampling?",
                (True, False),
                help="Whether or not to use sampling; use beam decoding otherwise.",
            )
        # Target protein sequence goes in the main area (default: a caspase fragment —
        # assumed; TODO confirm provenance of the default sequence).
        target = st.text_area(
            "Target Sequence",
            "MENTENSVDSKSIKNLEPKIIHGSESMDSGISLDNSYKMDYPEMGLCIIINNKNFHKSTG",
        )
        generate_new_molecules = st.form_submit_button("Generate Molecules")
        # Beam search only when sampling is off; beams = number of molecules requested.
        num_beams = None if do_sample is True else int(num_mols)
        pool = 1  # single-process mapper (no multiprocessing pool)
        # Regenerate on submit; also generate once on first load so the page
        # is never empty. Results persist across reruns via session_state.
        if generate_new_molecules:
            st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams,
                                                     target, pool)
        if 'df' not in st.session_state:
            st.session_state.df = generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams,
                                                     target, pool)
        df = st.session_state.df
        # Let the user narrow the result set by property values (see util.filter_dataframe).
        filtered_df = filter_dataframe(df)
        if filtered_df.empty:
            st.markdown(
                """
                <span style='color: blue; font-size: 30px;'>No molecules were found with specified properties.</span>
                """,
                unsafe_allow_html=True
            )
        else:
            # Render the molecules as an interactive mols2grid HTML widget.
            raw_html = mols2grid.display(filtered_df, height=1000)._repr_html_()
            components.html(raw_html, width=900, height=450, scrolling=True)
    # Show a copy-pastable snippet reproducing the current generation settings.
    st.markdown("## How to Generate")
    generation_code = f"""
    from transformers import EncoderDecoderModel, RobertaTokenizer
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/{model_name}")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model = EncoderDecoderModel.from_pretrained("gokceuludogan/{model_name}")
    inputs = protein_tokenizer("{target}", return_tensors="pt")
    outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id,
    eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id,
    max_length={max_new_tokens}, num_return_sequences={num_mols}, do_sample={do_sample}, num_beams={num_beams})
    mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    """
    st.code(textwrap.dedent(generation_code))  # textwrap.dedent("".join("Halletcez")))
# Page setup — set_page_config must be the first Streamlit command executed.
st.set_page_config(page_title="WarmMolGen Demo", page_icon="🔥", layout='wide')
st.markdown("# WarmMolGen Demo")
st.sidebar.header("WarmMolGen Demo")
# Intro text shown above the generation UI.
st.markdown(
    """
This demo illustrates WarmMolGen models' generation capabilities.
Given a target sequence and a set of parameters, the models generate molecules targeting the given protein sequence.
Please enter an input sequence below 👇 and configure parameters from the sidebar 👈 to generate molecules!
See below for saving the output molecules and the code snippet generating them!
"""
)
# Render the full demo UI.
warm_molgen_demo()