# b3clf / app.py
# Last commit ec17199 by legend1234: "Add working webserver"
import itertools as it
import os
import tempfile
from io import StringIO
import joblib
import numpy as np
import pandas as pd
import pkg_resources
# page set up
import streamlit as st
from b3clf.descriptor_padel import compute_descriptors
from b3clf.geometry_opt import geometry_optimize
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors
# from PIL import Image
from streamlit_extras.let_it_rain import rain
from streamlit_ketcher import st_ketcher
from utils import generate_predictions, load_all_models
# NOTE(review): this drops every ``st.cache_data`` entry on *every* rerun of
# the script. Streamlit's data cache is shared app-wide, so other sessions are
# probably affected too; confirm this is intended.
st.cache_data.clear()

# Browser-tab title and full-width layout for the whole page.
_page_config = {
    "page_title": "BBB Permeability Prediction with Imbalanced Learning",
    "layout": "wide",
}
st.set_page_config(**_page_config)
# Legacy option flags; not read anywhere in the visible code.
keep_features = keep_sdf = "no"

# Select-box label -> classifier code passed to ``generate_predictions(clf=...)``.
classifiers_dict = {
    "decision tree": "dtree",
    "kNN": "knn",
    "logistic regression": "logreg",
    "XGBoost": "xgb",
}

# Select-box label -> resampling code passed to
# ``generate_predictions(sampling=...)``; "no resampling" maps to "common".
resample_methods_dict = {
    "random undersampling": "classic_RandUndersampling",
    "SMOTE": "classic_SMOTE",
    "Borderline SMOTE": "borderline_SMOTE",
    "k-means SMOTE": "kmeans_SMOTE",
    "ADASYN": "classic_ADASYN",
    "no resampling": "common",
}

# Maximum number of table rows rendered on the page.
pandas_display_options = {"line_limit": 50}

# Per-run placeholders, reset to None each time Streamlit re-executes the script.
mol_features = info_df = results = temp_file_path = None
all_models = load_all_models()

# Session-state defaults, written only when the key is still missing:
# ``temp_dir`` holds the scratch directory used for uploads, ``processing``
# flags a job in progress.
_session_defaults = {"temp_dir": None, "processing": False}
for _name, _value in _session_defaults.items():
    if _name not in st.session_state:
        st.session_state[_name] = _value
def cleanup_temp_files():
    """Delete the session's temporary upload directory, if there is one.

    Reads the path from ``st.session_state.temp_dir``. When it is set and
    exists on disk, the whole tree is removed and the state entry is reset to
    ``None``. If removal fails, the error is shown with ``st.error``, the
    stored path is left unchanged, and the exception is not re-raised.
    """
    target = st.session_state.temp_dir
    if not (target and os.path.exists(target)):
        return
    try:
        import shutil

        shutil.rmtree(target)
        st.session_state.temp_dir = None
    except Exception as exc:
        st.error(f"Error cleaning up temporary files: {exc}")
def clear_cache():
    """Reset Streamlit caches and this session's stored results.

    Clears both ``st.cache_data`` and ``st.cache_resource``, sets the
    ``mol_features`` and ``info_df`` session-state entries to ``None`` (only
    when they exist), then removes the temporary upload directory via
    ``cleanup_temp_files``.
    """
    # NOTE(review): if load_all_models() is cached with st.cache_resource,
    # clearing it here forces a reload on its next call; confirm intended.
    for cache in (st.cache_data, st.cache_resource):
        cache.clear()
    for key in ("mol_features", "info_df"):
        if key in st.session_state:
            st.session_state[key] = None
    cleanup_temp_files()
# Page title plus the two top-level columns: project info (left) and the
# model selection / upload controls (right).
st.title(":blue[BBB Permeability Prediction with Imbalanced Learning]")
info_column, upload_column = st.columns(2)

# Session-state defaults, written only when a key is missing, so values kept
# from earlier reruns (including select-box choices bound via ``key=``) survive.
_app_state_defaults = {
    "mol_features": None,
    "info_df": None,
    "classifier": "XGBoost",
    "resampler": "ADASYN",
    "historical_data": [],
}
for _name, _value in _app_state_defaults.items():
    if _name not in st.session_state:
        st.session_state[_name] = _value
# Left column: description of B3clf and downloadable sample input files.
with info_column:
    st.subheader("About `B3clf`")
    # Keep this text consistent with the selectors in the right column: it
    # names all five resampling strategies offered there (random undersampling
    # was previously missing) and uses "logistic regression" as the selector does.
    st.markdown(
        "`B3clf` is a Python package for predicting the blood-brain barrier (BBB) "
        "permeability of small molecules using imbalanced learning. It supports "
        "decision tree, XGBoost, kNN, logistic regression and 5 resampling "
        "strategies (random undersampling, SMOTE, Borderline SMOTE, k-means SMOTE "
        "and ADASYN). The source code and more details are available at "
        "https://github.com/theochem/B3clf. This project is supported by the "
        "Digital Research Alliance of Canada (originally known as Compute Canada) "
        "and NSERC. This project is maintained by the QC-Dev community. For "
        "further information and inquiries, please contact us at "
        "qcdevs@gmail.com."
    )
    st.text(" \n")
    # TODO: the workflow figure (images/b3clf_workflow.png) is currently
    # disabled. When it is restored, add the sentence "The workflow of `B3clf`
    # is summarized below." back to the text so it never points at a missing figure.

    # Two sub-columns, each offering one sample input file for download.
    sdf_col, smi_col = st.columns(2)
    with sdf_col:
        with open("sample_input.sdf", "r", encoding="utf-8") as file_sdf:
            st.download_button(
                label="Download SDF sample file",
                data=file_sdf,
                file_name="sample_input.sdf",
            )
    with smi_col:
        with open("sample_input_smiles.csv", "r", encoding="utf-8") as file_smi:
            st.download_button(
                label="Download SMILES sample file",
                data=file_smi,
                file_name="sample_input_smiles.csv",
            )
# Right column: model selection, file upload and the job button.
with upload_column:
    st.subheader("Model Selection")
    with st.container():
        algorithm_col, resampler_col = st.columns(2)
        # ``key=`` binds each select box to st.session_state.<key>, which is
        # what the job code reads. The defaults were stored in session state
        # earlier in the script, so no further syncing is needed here (the
        # old "update session state" checks below the widgets could never fire).
        with algorithm_col:
            classifier = st.selectbox(
                label="Classification Algorithm:",
                options=("XGBoost", "kNN", "decision tree", "logistic regression"),
                key="classifier",
                help="Select the classification algorithm to use",
            )
        with resampler_col:
            resampler = st.selectbox(
                label="Resampling Method:",
                options=(
                    "ADASYN",
                    "random undersampling",
                    "Borderline SMOTE",
                    "k-means SMOTE",
                    "SMOTE",
                    "no resampling",
                ),
                key="resampler",
                help="Select the resampling method to handle imbalanced data",
            )

    # horizontal line
    st.divider()

    # Wide upload area, thin spacer, narrow job-button column, thin spacer.
    upload_col, _, submit_job_col, _ = st.columns((4, 0.05, 1, 0.05))

    with upload_col:
        # Remember the last uploaded file and whether the current one differs.
        if "uploaded_file" not in st.session_state:
            st.session_state.uploaded_file = None
        if "uploaded_file_changed" not in st.session_state:
            st.session_state.uploaded_file_changed = False

        uploaded_file = st.file_uploader(
            label="Upload a CSV, SDF, TXT or SMI file",
            type=["csv", "sdf", "txt", "smi"],
            help="Input molecule file only supports *.csv, *.sdf, *.txt and *.smi.",
            accept_multiple_files=False,
        )
        if uploaded_file:
            if st.session_state.uploaded_file != uploaded_file:
                st.session_state.uploaded_file_changed = True
            else:
                st.session_state.uploaded_file_changed = False
            st.session_state.uploaded_file = uploaded_file

    with submit_job_col:
        # Blank lines push the button down to line up with the uploader.
        st.text(" \n")
        st.text(" \n")
        # NOTE(review): this unclosed <div> is rendered as its own element and
        # does not wrap the button, so it cannot center it. It is kept only
        # because it adds vertical spacing the layout may rely on.
        st.markdown(
            "<div style='display: flex; justify-content: center;'>",
            unsafe_allow_html=True,
        )
        submit_job_button = st.button(
            label="Submit Job",
            type="secondary",
            key="job_button",
            help="Click to start calculations with current configuration",
        )
# Streamlit re-executes this whole script on every interaction, so these
# module-level names were already set to None near the top of this run.
# Rebind them instead of ``del``-ing them: at module level ``locals()`` is the
# module namespace, so the old ``del`` always fired and the later
# ``results is not None`` check then raised NameError.
if not submit_job_button:
    results = None
    mol_features = None
    info_df = None
# Output area: computed descriptors on the left, predictions on the right.
feature_column, prediction_column = st.columns(2)
feature_column.subheader("Molecular Features")
# Reserved slot in the features panel; not filled in the visible code.
placeholder_features = feature_column.empty()
prediction_column.subheader("Predictions")
# Delimiter pattern passed to generate_predictions. The value is unchanged
# (regex ``\s+`` or a run of literal tab characters); the backslash is now
# escaped so Python no longer warns about the invalid "\s" escape sequence.
_INPUT_SEP = "\\s+|\t+"

# Run a job only when "Submit Job" was clicked in this run.
if submit_job_button:
    # mol_features holds a DataFrame after a job, and ``not DataFrame`` raises
    # "truth value ... is ambiguous", so compare with None explicitly.
    has_previous_data = st.session_state.mol_features is not None
    if not uploaded_file and not has_previous_data:
        st.warning("Please upload a file first or select data from history to process.")
    elif st.session_state.processing:
        st.warning("A job is already running. Please wait for it to complete.")
    else:
        try:
            st.session_state.processing = True
            with st.spinner('Processing... Please wait.'):
                # Snapshot the current data before clear_cache(): it resets the
                # session-state copies to None, and both the history archive
                # (case 1) and the recalculation (case 2) need them.
                previous_features = st.session_state.mol_features
                previous_info = st.session_state.info_df

                # Clean up previous files and cache
                cleanup_temp_files()
                clear_cache()

                if uploaded_file:
                    # Case 1: new file uploaded. Archive the current data,
                    # then compute features and predictions from the file.
                    if previous_features is not None and previous_info is not None:
                        st.session_state.historical_data.append({
                            'mol_features': previous_features.copy(),
                            'info_df': previous_info.copy(),
                        })

                    st.session_state.temp_dir = tempfile.mkdtemp()
                    # The file name comes from the browser; basename() keeps the
                    # write inside the temporary directory.
                    temp_file_path = os.path.join(
                        st.session_state.temp_dir,
                        os.path.basename(uploaded_file.name),
                    )
                    with open(temp_file_path, "wb") as temp_file:
                        # getvalue() returns the full upload regardless of the
                        # stream position left by any earlier read().
                        temp_file.write(uploaded_file.getvalue())
                    try:
                        mol_features, info_df, results = generate_predictions(
                            input_fname=temp_file_path,
                            sep=_INPUT_SEP,
                            clf=classifiers_dict[st.session_state.classifier],
                            _models_dict=all_models,
                            sampling=resample_methods_dict[st.session_state.resampler],
                            time_per_mol=120,
                            mol_features=None,
                            info_df=None,
                        )
                    finally:
                        # Remove the uploaded copy whether or not prediction succeeded.
                        cleanup_temp_files()
                else:
                    # Case 2: recalculate from the data snapshotted above
                    # (non-None because of the guard at the top).
                    mol_features, info_df, results = generate_predictions(
                        input_fname=None,
                        sep=_INPUT_SEP,
                        clf=classifiers_dict[st.session_state.classifier],
                        _models_dict=all_models,
                        sampling=resample_methods_dict[st.session_state.resampler],
                        time_per_mol=120,
                        mol_features=previous_features,
                        info_df=previous_info,
                    )

                # Update session state with new results
                if mol_features is not None and info_df is not None:
                    st.session_state.mol_features = mol_features
                    st.session_state.info_df = info_df
        except Exception as e:
            st.error(f"Error during processing: {str(e)}")
        finally:
            st.session_state.processing = False
# Display results.
# Base name for downloaded files. Features stay in session state across
# reruns, so the uploader may be empty here (e.g. after the file was removed).
# Fall back to a generic stem instead of dereferencing None. splitext() strips
# only the final extension, so "a.b.sdf" keeps "a.b" rather than "a".
if uploaded_file:
    output_stem = os.path.splitext(os.path.basename(uploaded_file.name))[0]
else:
    output_stem = "molecules"

# feature table
with feature_column:
    if st.session_state.mol_features is not None:
        features_df = st.session_state.mol_features
        # Show at most ``line_limit`` rows; the CSV download has the full table.
        st.dataframe(
            features_df.head(pandas_display_options["line_limit"]),
            hide_index=False,
        )
        feature_file_name = f"{output_stem}_b3clf_features.csv"
        features_csv = features_df.to_csv(index=True)
        st.download_button(
            "Download features as CSV",
            data=features_csv,
            file_name=feature_file_name,
        )

# prediction table (results is only non-None on the run that executed a job)
with prediction_column:
    if results is not None:
        results_df_display = results.head(
            pandas_display_options["line_limit"]
        ).style.format({"B3clf_predicted_probability": "{:.6f}".format})
        st.dataframe(results_df_display, hide_index=True)

        predictions_csv = results.to_csv(index=True)
        results_file_name = f"{output_stem}_b3clf_predictions.csv"
        st.download_button(
            "Download predictions as CSV",
            data=predictions_csv,
            file_name=results_file_name,
        )
        # indicate the success of the job
        st.balloons()
# CSS that hides Streamlit's main menu and footer.
# Workaround from https://github.com/streamlit/streamlit/issues/892
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""

# Google Analytics (gtag.js) snippet.
# NOTE(review): Streamlit may not execute <script> tags injected through
# st.markdown, even with unsafe_allow_html=True. Verify that page views are
# actually recorded.
google_analytics_snippet = """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-WG8QYRELP9"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag("js", new Date());
gtag("config", "G-WG8QYRELP9");
</script>
"""

# Inject both raw-HTML snippets, in this order.
for _html in (hide_streamlit_style, google_analytics_snippet):
    st.markdown(_html, unsafe_allow_html=True)