Spaces:
Runtime error
Runtime error
| import itertools as it | |
| import os | |
| import tempfile | |
| from io import StringIO | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| import pkg_resources | |
| # page set up | |
| import streamlit as st | |
| from b3clf.descriptor_padel import compute_descriptors | |
| from b3clf.geometry_opt import geometry_optimize | |
| from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors | |
| # from PIL import Image | |
| from streamlit_extras.let_it_rain import rain | |
| from streamlit_ketcher import st_ketcher | |
| from utils import generate_predictions, load_all_models | |
# Page setup.
# The data cache is cleared at the top of every script run, before the page
# is configured.
st.cache_data.clear()

# Only the title and the wide layout are configured. The remaining
# set_page_config options are intentionally left disabled:
#   page_icon="🧊"
#   initial_sidebar_state="expanded"
#   menu_items={
#       "Get Help": "https://www.extremelycoolapp.com/help",
#       "Report a bug": "https://www.extremelycoolapp.com/bug",
#       "About": "# This is a header. This is an *extremely* cool app!",
#   }
page_config = {
    "page_title": "BBB Permeability Prediction with Imbalanced Learning",
    "layout": "wide",
}
st.set_page_config(**page_config)
# NOTE(review): keep_features / keep_sdf are not read anywhere else in this
# script as shown; confirm they are unused before removing them.
keep_features = "no"
keep_sdf = "no"

# Maps the classifier label shown in the UI to the short code that is passed
# to generate_predictions as `clf`.
classifiers_dict = {
    "decision tree": "dtree",
    "kNN": "knn",
    "logistic regression": "logreg",
    "XGBoost": "xgb",
}

# Maps the resampling label shown in the UI to the identifier that is passed
# to generate_predictions as `sampling`.
resample_methods_dict = {
    "random undersampling": "classic_RandUndersampling",
    "SMOTE": "classic_SMOTE",
    "Borderline SMOTE": "borderline_SMOTE",
    "k-means SMOTE": "kmeans_SMOTE",
    "ADASYN": "classic_ADASYN",
    "no resampling": "common",
}

# Upper bound on the number of rows rendered in the on-page tables.
pandas_display_options = {"line_limit": 50}

# Per-run placeholders; they stay None until a job has been submitted.
mol_features = info_df = results = temp_file_path = None
# load_all_models comes from the local utils module; its result is handed to
# generate_predictions as `_models_dict` when a job is submitted.
all_models = load_all_models()

# Page title and the two top-level columns (information | model selection).
st.title(":blue[BBB Permeability Prediction with Imbalanced Learning]")
info_column, upload_column = st.columns(2)

# Initialize the molecule-feature and info-dataframe session-state slots so
# that later reads never hit a missing key.
for state_key in ("mol_features", "info_df"):
    if state_key not in st.session_state:
        st.session_state[state_key] = None
# Left column: package description and sample input downloads.
with info_column:
    st.subheader("About `B3clf`")
    # The description lists every resampling strategy offered in the
    # "Resampling Method" selector so the stated count (5) matches the list.
    # NOTE: the workflow figure (images/b3clf_workflow.png) that the text
    # refers to is currently not rendered on the page.
    st.markdown(
        "`B3clf` is a Python package for predicting the blood-brain barrier "
        "(BBB) permeability of small molecules using imbalanced learning. "
        "It supports decision tree, XGBoost, kNN, logistic regression and "
        "5 resampling strategies (random undersampling, SMOTE, Borderline "
        "SMOTE, k-means SMOTE and ADASYN). The workflow of `B3clf` is "
        "summarized as below. The Source code and more details are available "
        "at https://github.com/theochem/B3clf. This project is supported by "
        "Digital Research Alliance of Canada (originally known as Compute "
        "Canada) and NSERC. This project is maintained by QC-Dev community. "
        "For further information and inquiries please contact us at "
        "qcdevs@gmail.com."
    )
    st.text(" \n")

    # Two sub-columns, one download button per sample input format.
    # NOTE(review): the original indentation was lost; these buttons are
    # assumed to live inside the info column - confirm against the app layout.
    sdf_col, smi_col = st.columns(2)
    with sdf_col:
        # The sample files are read relative to the working directory.
        with open("sample_input.sdf", "r", encoding="utf-8") as file_sdf:
            st.download_button(
                label="Download SDF sample file",
                data=file_sdf,
                file_name="sample_input.sdf",
            )
    with smi_col:
        with open("sample_input_smiles.csv", "r", encoding="utf-8") as file_smi:
            st.download_button(
                label="Download SMILES sample file",
                data=file_smi,
                file_name="sample_input_smiles.csv",
            )
# Right column: model selection, file upload and job submission.
algorithm_options = ("XGBoost", "kNN", "decision tree", "logistic regression")
resampler_options = (
    "ADASYN",
    "random undersampling",
    "Borderline SMOTE",
    "k-means SMOTE",
    "SMOTE",
    "no resampling",
)
# Session-state slots used to detect whether a newly submitted file differs
# from the one processed previously.
uploader_state_defaults = {
    "uploaded_file": None,
    "uploaded_file_changed": False,
}

with upload_column:
    st.subheader("Model Selection")

    # Classifier and resampling selectors sit side by side.
    with st.container():
        algorithm_col, resampler_col = st.columns(2)
        with algorithm_col:
            classifier = st.selectbox(
                label="Classification Algorithm:",
                options=algorithm_options,
            )
        with resampler_col:
            resampler = st.selectbox(
                label="Resampling Method:",
                options=resampler_options,
            )

    # Horizontal rule between model selection and the upload row.
    st.divider()

    # Wide uploader, narrow submit button; the two 0.05-wide columns are
    # spacers and are never written to.
    upload_col, _, submit_job_col, _ = st.columns((4, 0.05, 1, 0.05))

    with upload_col:
        for state_key, default_value in uploader_state_defaults.items():
            if state_key not in st.session_state:
                st.session_state[state_key] = default_value

        uploaded_file = st.file_uploader(
            label="Upload a CSV, SDF, TXT or SMI file",
            type=["csv", "sdf", "txt", "smi"],
            help="Input molecule file only supports *.csv, *.sdf, *.txt and *.smi.",
            accept_multiple_files=False,
        )

    with submit_job_col:
        # Two blank text rows push the button down next to the uploader.
        st.text(" \n")
        st.text(" \n")
        # NOTE(review): this opening <div> is never closed by a later
        # st.markdown call; confirm it actually centers the button.
        st.markdown(
            "<div style='display: flex; justify-content: center;'>",
            unsafe_allow_html=True,
        )
        submit_job_button = st.button(
            label="Submit Job", type="secondary", key="job_button"
        )
# Bottom row: molecular features on the left, predictions on the right.
feature_column, prediction_column = st.columns(2)

feature_column.subheader("Molecular Features")
# Empty slot reserved in the feature column; nothing in this script fills it.
placeholder_features = feature_column.empty()

prediction_column.subheader("Predictions")

# Debug output of the upload-tracking flag before the submit handler runs.
# NOTE(review): the original indentation was lost, so it is unclear whether
# this message was rendered inside the prediction column or at page level;
# page level is assumed here.
st.write(
    "the state of uploaded file changed before checking: "
    f"{st.session_state.uploaded_file_changed}"
)
# Run the prediction job when the user presses "Submit Job" with a file
# selected, then render the feature and prediction tables.
if submit_job_button and uploaded_file:
    # Record whether this upload differs from the one processed last time so
    # that cached features can be reused when only the model choice changed.
    st.session_state.uploaded_file_changed = (
        st.session_state.uploaded_file != uploaded_file
    )
    st.session_state.uploaded_file = uploaded_file
    st.write(
        f"the state of uploaded file changed after checking: {st.session_state.uploaded_file_changed}"
    )

    # Arguments shared by both generate_predictions calls. The separator is a
    # raw string: "\s" in a normal string literal is an invalid escape.
    shared_prediction_kwargs = {
        "sep": r"\s+|\t+",
        "clf": classifiers_dict[classifier],
        "_models_dict": all_models,
        "sampling": resample_methods_dict[resampler],
        "time_per_mol": 120,
    }

    # Recompute from the file when it changed, and also when no features are
    # cached (e.g. an earlier run failed after the upload was recorded);
    # otherwise the call below would get neither an input file nor features.
    needs_featurization = (
        st.session_state.uploaded_file_changed
        or st.session_state.mol_features is None
        or st.session_state.info_df is None
    )

    if needs_featurization:
        # The temporary directory is removed as soon as the job returns
        # instead of being leaked on every upload.
        # NOTE(review): assumes generate_predictions does not need the input
        # file after it returns - confirm in utils.
        with tempfile.TemporaryDirectory() as temp_dir:
            # basename() keeps a crafted upload name from escaping temp_dir.
            temp_file_path = os.path.join(
                temp_dir, os.path.basename(uploaded_file.name)
            )
            with open(temp_file_path, "wb") as temp_file:
                # getvalue() returns the whole upload regardless of the
                # current read position of the buffer.
                temp_file.write(uploaded_file.getvalue())
            mol_features, info_df, results = generate_predictions(
                input_fname=temp_file_path,
                **shared_prediction_kwargs,
                mol_features=None,
                info_df=None,
            )
        st.session_state.mol_features = mol_features
        st.session_state.info_df = info_df
    else:
        # Same file as before: reuse the cached features and info dataframe.
        mol_features, info_df, results = generate_predictions(
            input_fname=None,
            **shared_prediction_kwargs,
            mol_features=st.session_state.mol_features,
            info_df=st.session_state.info_df,
        )

    line_limit = pandas_display_options["line_limit"]
    # splitext() strips only the final extension, so "mols.v2.sdf" yields
    # "mols.v2" rather than "mols".
    download_stem = os.path.splitext(os.path.basename(uploaded_file.name))[0]

    # Feature table: show at most `line_limit` rows, download everything.
    with feature_column:
        if mol_features is not None:
            st.dataframe(mol_features.head(line_limit), hide_index=False)
            st.download_button(
                "Download features as CSV",
                data=mol_features.to_csv(index=True),
                file_name=download_stem + "_b3clf_features.csv",
            )

    # Prediction table: show at most `line_limit` rows, download everything.
    with prediction_column:
        if results is not None:
            results_df_display = results.head(line_limit).style.format(
                {"B3clf_predicted_probability": "{:.6f}"}
            )
            st.dataframe(results_df_display, hide_index=True)
            st.download_button(
                "Download predictions as CSV",
                data=results.to_csv(index=True),
                file_name=download_stem + "_b3clf_predictions.csv",
            )
            # Indicate the success of the job.
            st.balloons()
# CSS that hides Streamlit's main menu and footer.
# https://github.com/streamlit/streamlit/issues/892
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""

# Google Analytics tag.
# NOTE(review): confirm that <script> tags injected through st.markdown are
# actually executed by the browser; they may be rendered inert.
google_analytics_tag = """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-WG8QYRELP9"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag("js", new Date());
gtag("config", "G-WG8QYRELP9");
</script>
"""

# Both snippets are raw HTML, so HTML rendering must be enabled explicitly.
for html_snippet in (hide_streamlit_style, google_analytics_tag):
    st.markdown(html_snippet, unsafe_allow_html=True)