"""Streamlit front end for B3clf blood-brain barrier permeability prediction.

The script is re-executed top to bottom by Streamlit on every interaction.
Data that must survive a rerun (computed features, the last uploaded file,
the temporary working directory) is kept in ``st.session_state``.
"""

import itertools as it
import os
import shutil
import tempfile
from io import StringIO

import joblib
import numpy as np
import pandas as pd
import pkg_resources
import streamlit as st
from b3clf.descriptor_padel import compute_descriptors
from b3clf.geometry_opt import geometry_optimize
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors
from streamlit_extras.let_it_rain import rain
from streamlit_ketcher import st_ketcher

from utils import generate_predictions, load_all_models

# page set up
st.cache_data.clear()
st.set_page_config(
    page_title="BBB Permeability Prediction with Imbalanced Learning",
    layout="wide",
)

keep_features = "no"
keep_sdf = "no"

# Display label -> identifier passed to generate_predictions(clf=...).
classifiers_dict = {
    "decision tree": "dtree",
    "kNN": "knn",
    "logistic regression": "logreg",
    "XGBoost": "xgb",
}
# Display label -> identifier passed to generate_predictions(sampling=...).
resample_methods_dict = {
    "random undersampling": "classic_RandUndersampling",
    "SMOTE": "classic_SMOTE",
    "Borderline SMOTE": "borderline_SMOTE",
    "k-means SMOTE": "kmeans_SMOTE",
    "ADASYN": "classic_ADASYN",
    "no resampling": "common",
}
# Maximum number of table rows rendered in the browser.
pandas_display_options = {
    "line_limit": 50,
}

# Per-run results; reset on every rerun and only filled by a submitted job.
mol_features = None
info_df = None
results = None
temp_file_path = None
all_models = load_all_models()

# Initialize every session-state key once per browser session. "classifier"
# and "resampler" double as the widget keys of the two select boxes below, so
# these values are also the initial selections.
for _key, _default in (
    ("temp_dir", None),
    ("processing", False),
    ("mol_features", None),
    ("info_df", None),
    ("classifier", "XGBoost"),
    ("resampler", "ADASYN"),
    ("historical_data", []),
    ("uploaded_file", None),
    ("uploaded_file_changed", False),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# NOTE(review): text reproduced verbatim from the original literal (including
# its wording); any paragraph breaks it once had are not recoverable here.
ABOUT_TEXT = (
    "`B3clf` is a Python package for predicting the blood-brain barrier "
    "(BBB) permeability of small molecules using imbalanced learning. It "
    "supports decision tree, XGBoost, kNN, logistical regression and 5 "
    "resampling strategies (SMOTE, Borderline SMOTE, k-means SMOTE and "
    "ADASYN). The workflow of `B3clf` is summarized as below. The Source "
    "code and more details are available at "
    "https://github.com/theochem/B3clf. This project is supported by "
    "Digital Research Alliance of Canada (originally known as Compute "
    "Canada) and NSERC. This project is maintained by QC-Dev comminity. "
    "For further information and inquiries please contact us at "
    "qcdevs@gmail.com."
)


def cleanup_temp_files():
    """Remove the session's temporary directory, if one is recorded.

    The directory path is read from ``st.session_state.temp_dir`` and the
    entry is reset to ``None`` once the directory is gone. Failures are
    reported in the UI instead of being raised.
    """
    temp_dir = st.session_state.temp_dir
    if not temp_dir:
        return
    try:
        shutil.rmtree(temp_dir)
    except FileNotFoundError:
        # Already removed: only the stale reference needs clearing.
        pass
    except OSError as e:
        st.error(f"Error cleaning up temporary files: {e}")
        return
    st.session_state.temp_dir = None


def clear_cache(*, reset_session_data=True):
    """Clear Streamlit caches and temporary files.

    Parameters
    ----------
    reset_session_data : bool, optional
        When True (the default, matching the historical behavior) the stored
        ``mol_features`` and ``info_df`` are reset to ``None`` as well. Pass
        False to keep them, e.g. when they are about to be archived or reused.
    """
    st.cache_data.clear()
    st.cache_resource.clear()
    if reset_session_data:
        if "mol_features" in st.session_state:
            st.session_state.mol_features = None
        if "info_df" in st.session_state:
            st.session_state.info_df = None
    cleanup_temp_files()


def _download_stem(upload, fallback="molecules"):
    """Return the name prefix for download files derived from *upload*.

    Uses the part of ``upload.name`` before the first dot. Falls back to
    *fallback* when there is no upload (results computed from stored data).
    """
    name = getattr(upload, "name", None)
    if not name:
        return fallback
    return name.split(".")[0]


def _run_prediction_job(upload):
    """Run the prediction for *upload*, or for the stored features if None.

    Returns the ``(mol_features, info_df, results)`` triple produced by
    ``generate_predictions``.
    """
    common_kwargs = {
        # Raw string: avoids the invalid "\s" escape of a plain literal.
        "sep": r"\s+|\t+",
        "clf": classifiers_dict[st.session_state.classifier],
        "_models_dict": all_models,
        "sampling": resample_methods_dict[st.session_state.resampler],
        "time_per_mol": 120,
    }

    # Case 2: recalculate with the features already held in session state.
    if not upload:
        return generate_predictions(
            input_fname=None,
            mol_features=st.session_state.mol_features,
            info_df=st.session_state.info_df,
            **common_kwargs,
        )

    # Case 1: a file is present in the uploader.
    st.session_state.temp_dir = tempfile.mkdtemp()
    # The file name comes from the browser; keep only its final component so
    # it cannot point outside the temporary directory.
    input_path = os.path.join(
        st.session_state.temp_dir, os.path.basename(upload.name)
    )
    with open(input_path, "wb") as temp_file:
        # getvalue() is independent of the current read position.
        temp_file.write(upload.getvalue())

    # Store current data in history before processing new data.
    previous_features = st.session_state.mol_features
    previous_info = st.session_state.info_df
    if previous_features is not None and previous_info is not None:
        st.session_state.historical_data.append(
            {
                "mol_features": previous_features.copy(),
                "info_df": previous_info.copy(),
            }
        )

    # Clear current data.
    st.session_state.mol_features = None
    st.session_state.info_df = None

    try:
        return generate_predictions(
            input_fname=input_path,
            mol_features=None,
            info_df=None,
            **common_kwargs,
        )
    finally:
        # Clean up temporary files after processing.
        cleanup_temp_files()


# Create the Streamlit app
st.title(":blue[BBB Permeability Prediction with Imbalanced Learning]")
info_column, upload_column = st.columns(2)

# about text and sample file downloads
with info_column:
    st.subheader("About `B3clf`")
    # fmt: off
    st.markdown(ABOUT_TEXT)
    st.text(" \n")
    # fmt: on

    sdf_col, smi_col = st.columns(2)
    with sdf_col:
        with open("sample_input.sdf", "r") as file_sdf:
            st.download_button(
                label="Download SDF sample file",
                data=file_sdf,
                file_name="sample_input.sdf",
            )
    with smi_col:
        with open("sample_input_smiles.csv", "r") as file_smi:
            st.download_button(
                label="Download SMILES sample file",
                data=file_smi,
                file_name="sample_input_smiles.csv",
            )

# model selection, file upload and job submission
with upload_column:
    st.subheader("Model Selection")
    with st.container():
        algorithm_col, resampler_col = st.columns(2)
        with algorithm_col:
            classifier = st.selectbox(
                label="Classification Algorithm:",
                options=("XGBoost", "kNN", "decision tree", "logistic regression"),
                key="classifier",
                help="Select the classification algorithm to use",
            )
        with resampler_col:
            resampler = st.selectbox(
                label="Resampling Method:",
                options=(
                    "ADASYN",
                    "random undersampling",
                    "Borderline SMOTE",
                    "k-means SMOTE",
                    "SMOTE",
                    "no resampling",
                ),
                key="resampler",
                help="Select the resampling method to handle imbalanced data",
            )

    # horizontal line
    st.divider()
    upload_col, _, submit_job_col, _ = st.columns((4, 0.05, 1, 0.05))

    # upload file column
    with upload_col:
        uploaded_file = st.file_uploader(
            label="Upload a CSV, SDF, TXT or SMI file",
            type=["csv", "sdf", "txt", "smi"],
            help="Input molecule file only supports *.csv, *.sdf, *.txt and *.smi.",
            accept_multiple_files=False,
        )
        if uploaded_file:
            # Record whether this upload differs from the one seen last run.
            st.session_state.uploaded_file_changed = (
                st.session_state.uploaded_file != uploaded_file
            )
            st.session_state.uploaded_file = uploaded_file

    # submit job column
    with submit_job_col:
        st.text(" \n")
        st.text(" \n")
        # NOTE(review): this literal appears to have lost its HTML markup;
        # only a newline remains in the visible source - confirm and restore.
        st.markdown(
            "\n",
            unsafe_allow_html=True,
        )
        submit_job_button = st.button(
            label="Submit Job",
            type="secondary",
            key="job_button",
            help="Click to start calculations with current configuration",
        )

# The per-run names (results, mol_features, info_df) are already None here
# unless a job fills them below, so nothing has to be reset when the button
# was not clicked. Deleting them instead would make the display code below
# fail with a NameError.

# Display sections
feature_column, prediction_column = st.columns(2)
with feature_column:
    st.subheader("Molecular Features")
    placeholder_features = st.empty()
with prediction_column:
    st.subheader("Predictions")

# Only process when Submit Job is clicked
if submit_job_button:
    # Explicit None check: the truth value of a DataFrame is ambiguous.
    has_stored_features = st.session_state.mol_features is not None
    if not uploaded_file and not has_stored_features:
        st.warning("Please upload a file first or select data from history to process.")
    elif st.session_state.processing:
        st.warning("A job is already running. Please wait for it to complete.")
    else:
        try:
            st.session_state.processing = True
            with st.spinner('Processing... Please wait.'):
                # Clean up previous files and caches, but keep the stored
                # features: they are archived or reused by the job itself.
                clear_cache(reset_session_data=False)
                mol_features, info_df, results = _run_prediction_job(uploaded_file)

                # Update session state with new results
                if mol_features is not None and info_df is not None:
                    st.session_state.mol_features = mol_features
                    st.session_state.info_df = info_df
        except Exception as e:
            # UI boundary: report the failure instead of crashing the page.
            st.error(f"Error during processing: {e}")
        finally:
            st.session_state.processing = False

# Display results
row_limit = pandas_display_options["line_limit"]
# Name downloads after the current upload, else after the last one remembered.
download_stem = _download_stem(uploaded_file or st.session_state.uploaded_file)

# feature table
with feature_column:
    stored_features = st.session_state.mol_features
    if stored_features is not None:
        st.dataframe(stored_features.head(row_limit), hide_index=False)
        st.download_button(
            "Download features as CSV",
            data=stored_features.to_csv(index=True),
            file_name=download_stem + "_b3clf_features.csv",
        )

# prediction table
with prediction_column:
    if results is not None:
        # Display the predictions in a table
        results_df_display = results.head(row_limit).style.format(
            {"B3clf_predicted_probability": "{:.6f}".format}
        )
        st.dataframe(results_df_display, hide_index=True)
        # Add a button to download the predictions as a CSV file
        st.download_button(
            "Download predictions as CSV",
            data=results.to_csv(index=True),
            file_name=download_stem + "_b3clf_predictions.csv",
        )
        # indicate the success of the job
        st.balloons()

# hide footer
# https://github.com/streamlit/streamlit/issues/892
# NOTE(review): the style markup of this literal appears to be missing from
# the visible source - confirm and restore.
hide_streamlit_style = """ """
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

# add google analytics
# NOTE(review): the analytics snippet appears to be missing from the visible
# source - confirm and restore.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)