import itertools as it
import os
import shutil
import tempfile
from io import StringIO

import joblib
import numpy as np
import pandas as pd
import pkg_resources
import streamlit as st
from b3clf.descriptor_padel import compute_descriptors
from b3clf.geometry_opt import geometry_optimize
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors
from streamlit_extras.let_it_rain import rain
from streamlit_ketcher import st_ketcher

from utils import generate_predictions, load_all_models
| |
|
# Start every script run with a cold data cache so results from a previous
# configuration are never served stale.
st.cache_data.clear()

st.set_page_config(
    page_title="BBB Permeability Prediction with Imbalanced Learning",
    layout="wide",
)

# Flags mirroring the b3clf CLI options; intermediates are not kept here.
keep_features = "no"
keep_sdf = "no"
# Human-readable algorithm label -> b3clf internal model identifier.
classifiers_dict = {
    "decision tree": "dtree",
    "kNN": "knn",
    "logistic regression": "logreg",
    "XGBoost": "xgb",
}
# Human-readable resampler label -> b3clf internal sampling identifier.
resample_methods_dict = {
    "random undersampling": "classic_RandUndersampling",
    "SMOTE": "classic_SMOTE",
    "Borderline SMOTE": "borderline_SMOTE",
    "k-means SMOTE": "kmeans_SMOTE",
    "ADASYN": "classic_ADASYN",
    "no resampling": "common",
}

# Cap on how many rows of any result table are rendered on screen.
pandas_display_options = {
    "line_limit": 50,
}
# Per-run working variables, reset at the top of every Streamlit rerun.
mol_features = None
info_df = None
results = None
temp_file_path = None
# Load all (classifier, resampler) model combinations up front.
all_models = load_all_models()

# Persistent session state: temp-dir bookkeeping and a re-entrancy guard
# that prevents two jobs from running concurrently.
if 'temp_dir' not in st.session_state:
    st.session_state.temp_dir = None
if 'processing' not in st.session_state:
    st.session_state.processing = False
| |
|
def cleanup_temp_files():
    """Remove the session's temporary working directory, if one exists.

    On success the ``temp_dir`` entry in ``st.session_state`` is reset to
    ``None``; on failure the path is left in place so a later call can retry,
    and the error is surfaced in the UI instead of crashing the app.
    """
    # Bind once: the value is read twice below.
    temp_dir = st.session_state.temp_dir
    if temp_dir and os.path.exists(temp_dir):
        try:
            # shutil is imported at module level rather than per call.
            shutil.rmtree(temp_dir)
            st.session_state.temp_dir = None
        except Exception as e:
            st.error(f"Error cleaning up temporary files: {e}")
| |
|
def clear_cache():
    """Reset Streamlit caches and drop cached molecule data from the session."""
    st.cache_data.clear()
    st.cache_resource.clear()
    # Null out any previously computed molecule data held in session state.
    for state_key in ("mol_features", "info_df"):
        if state_key in st.session_state:
            st.session_state[state_key] = None
    cleanup_temp_files()
| |
|
| | |
# Page title and the two top-level layout columns (about | model/upload).
st.title(":blue[BBB Permeability Prediction with Imbalanced Learning]")
info_column, upload_column = st.columns(2)

# First-run defaults for everything that must survive Streamlit reruns.
if "mol_features" not in st.session_state:
    st.session_state.mol_features = None
if "info_df" not in st.session_state:
    st.session_state.info_df = None
if "classifier" not in st.session_state:
    st.session_state.classifier = "XGBoost"
if "resampler" not in st.session_state:
    st.session_state.resampler = "ADASYN"
if "historical_data" not in st.session_state:
    st.session_state.historical_data = []
|
| | |
# Left column: project description. Text fixes: "logistic regression" (to
# match the selectbox option), the full list of 5 resampling strategies
# (random undersampling was omitted), and typo "comminity" -> "community".
with info_column:
    st.subheader("About `B3clf`")

    st.markdown(
        """
    `B3clf` is a Python package for predicting the blood-brain barrier (BBB) permeability of small molecules using imbalanced learning. It supports decision tree, XGBoost, kNN, logistic regression and 5 resampling strategies (random undersampling, SMOTE, Borderline SMOTE, k-means SMOTE and ADASYN). The workflow of `B3clf` is summarized as below. The source code and more details are available at https://github.com/theochem/B3clf. This project is supported by Digital Research Alliance of Canada (originally known as Compute Canada) and NSERC. This project is maintained by QC-Dev community. For further information and inquiries please contact us at qcdevs@gmail.com."""
    )
    st.text(" \n")
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
# Offer the two bundled example input files for download, side by side.
sdf_col, smi_col = st.columns(2)
for sample_col, sample_label, sample_fname in (
    (sdf_col, "Download SDF sample file", "sample_input.sdf"),
    (smi_col, "Download SMILES sample file", "sample_input_smiles.csv"),
):
    with sample_col:
        with open(sample_fname, "r") as sample_file:
            btn = st.download_button(
                label=sample_label,
                data=sample_file,
                file_name=sample_fname,
            )
| |
|
| | |
# Right column: model configuration widgets.
with upload_column:
    st.subheader("Model Selection")
    with st.container():
        algorithm_col, resampler_col = st.columns(2)

        with algorithm_col:
            # key="classifier" stores the choice directly in st.session_state.
            classifier = st.selectbox(
                label="Classification Algorithm:",
                options=("XGBoost", "kNN", "decision tree", "logistic regression"),
                key="classifier",
                help="Select the classification algorithm to use"
            )
        with resampler_col:
            # key="resampler" stores the choice directly in st.session_state.
            resampler = st.selectbox(
                label="Resampling Method:",
                options=(
                    "ADASYN",
                    "random undersampling",
                    "Borderline SMOTE",
                    "k-means SMOTE",
                    "SMOTE",
                    "no resampling",
                ),
                key="resampler",
                help="Select the resampling method to handle imbalanced data"
            )

# NOTE(review): both keys are guaranteed to exist at this point -- they are
# seeded at startup and (re)written by the keyed selectboxes above -- so
# these guards never fire; kept as defensive no-ops.
if "classifier" not in st.session_state:
    st.session_state.classifier = classifier
if "resampler" not in st.session_state:
    st.session_state.resampler = resampler
| |
|
| | |
st.divider()

# Wide uploader column and a narrow submit-button column, with thin spacers.
upload_col, _, submit_job_col, _ = st.columns((4, 0.05, 1, 0.05))

with upload_col:
    # Track the uploaded file across reruns so a change can be detected.
    if "uploaded_file" not in st.session_state:
        st.session_state.uploaded_file = None
    if "uploaded_file_changed" not in st.session_state:
        st.session_state.uploaded_file_changed = False

    uploaded_file = st.file_uploader(
        label="Upload a CSV, SDF, TXT or SMI file",
        type=["csv", "sdf", "txt", "smi"],
        help="Input molecule file only supports *.csv, *.sdf, *.txt and *.smi.",
        accept_multiple_files=False,
    )

    if uploaded_file:
        # Flag whether this upload differs from the one seen on the last run.
        if st.session_state.uploaded_file != uploaded_file:
            st.session_state.uploaded_file_changed = True
        else:
            st.session_state.uploaded_file_changed = False
        st.session_state.uploaded_file = uploaded_file
| |
|
| | |
| | |
| | |
| |
|
| | |
with submit_job_col:
    # Blank lines push the button down to align with the uploader widget.
    st.text(" \n")
    st.text(" \n")
    st.markdown(
        "<div style='display: flex; justify-content: center;'>",
        unsafe_allow_html=True,
    )
    submit_job_button = st.button(
        label="Submit Job",
        type="secondary",
        key="job_button",
        help="Click to start calculations with current configuration"
    )

# On reruns where the button was not pressed, discard stale per-run results.
# Previously this block `del`-eted the module-level names, which left a later
# `results is not None` check to raise NameError; resetting them to None
# drops the stale values while keeping the names defined.
if not submit_job_button:
    results = None
    mol_features = None
    info_df = None
| |
|
| | |
# Bottom half of the page: computed descriptors (left), predictions (right).
feature_column, prediction_column = st.columns(2)
with feature_column:
    st.subheader("Molecular Features")
    # Placeholder slot; feature output is rendered later in the script.
    placeholder_features = st.empty()

with prediction_column:
    st.subheader("Predictions")
| |
|
| | |
# Run the prediction pipeline when the user presses "Submit Job".
if submit_job_button:
    # mol_features may hold a DataFrame, whose truth value is ambiguous
    # (raises ValueError), so compare against None explicitly.
    if not uploaded_file and st.session_state.mol_features is None:
        st.warning("Please upload a file first or select data from history to process.")
    else:
        if st.session_state.processing:
            st.warning("A job is already running. Please wait for it to complete.")
        else:
            try:
                # Re-entrancy guard for the duration of the job.
                st.session_state.processing = True
                with st.spinner('Processing... Please wait.'):
                    # clear_cache() also removes any leftover temp directory,
                    # so no separate cleanup_temp_files() call is needed here.
                    clear_cache()

                    if uploaded_file:
                        # Persist the upload to disk: the descriptor pipeline
                        # reads its input from a file path.
                        st.session_state.temp_dir = tempfile.mkdtemp()
                        temp_file_path = os.path.join(st.session_state.temp_dir, uploaded_file.name)
                        with open(temp_file_path, "wb") as temp_file:
                            temp_file.write(uploaded_file.read())

                        # Archive the previous run before overwriting it.
                        if st.session_state.mol_features is not None and st.session_state.info_df is not None:
                            st.session_state.historical_data.append({
                                'mol_features': st.session_state.mol_features.copy(),
                                'info_df': st.session_state.info_df.copy()
                            })

                        st.session_state.mol_features = None
                        st.session_state.info_df = None

                        try:
                            mol_features, info_df, results = generate_predictions(
                                input_fname=temp_file_path,
                                # Raw string: "\s" is an invalid escape in a
                                # plain literal; the regex is unchanged.
                                sep=r"\s+|\t+",
                                clf=classifiers_dict[st.session_state.classifier],
                                _models_dict=all_models,
                                sampling=resample_methods_dict[st.session_state.resampler],
                                time_per_mol=120,
                                mol_features=None,
                                info_df=None,
                            )
                        finally:
                            # Always remove the temp copy, even on failure.
                            cleanup_temp_files()

                    else:
                        # No new upload: re-predict from features cached in
                        # the session (e.g. with a different model choice).
                        mol_features, info_df, results = generate_predictions(
                            input_fname=None,
                            sep=r"\s+|\t+",
                            clf=classifiers_dict[st.session_state.classifier],
                            _models_dict=all_models,
                            sampling=resample_methods_dict[st.session_state.resampler],
                            time_per_mol=120,
                            mol_features=st.session_state.mol_features,
                            info_df=st.session_state.info_df,
                        )

                    # Cache fresh results for later reruns / history.
                    if mol_features is not None and info_df is not None:
                        st.session_state.mol_features = mol_features
                        st.session_state.info_df = info_df

            except Exception as e:
                st.error(f"Error during processing: {str(e)}")
            finally:
                st.session_state.processing = False
| |
|
| | |
| | |
# Render cached molecular features (left) and fresh predictions (right).
with feature_column:
    if st.session_state.mol_features is not None:
        # Cap the on-screen table at the configured row limit.
        selected_feature_rows = np.min(
            [st.session_state.mol_features.shape[0], pandas_display_options["line_limit"]]
        )
        st.dataframe(st.session_state.mol_features.iloc[:selected_feature_rows, :], hide_index=False)

        # Derive the download name from the upload; fall back to a generic
        # stem when data came from session history (uploaded_file is None,
        # which previously raised AttributeError here).
        file_stem = uploaded_file.name.split(".")[0] if uploaded_file else "b3clf"
        feature_file_name = file_stem + "_b3clf_features.csv"
        features_csv = st.session_state.mol_features.to_csv(index=True)
        st.download_button(
            "Download features as CSV",
            data=features_csv,
            file_name=feature_file_name,
        )

with prediction_column:
    if results is not None:
        selected_result_rows = np.min(
            [results.shape[0], pandas_display_options["line_limit"]]
        )
        # Show predicted probabilities with fixed 6-decimal formatting.
        results_df_display = results.iloc[:selected_result_rows, :].style.format(
            {"B3clf_predicted_probability": "{:.6f}".format}
        )
        st.dataframe(results_df_display, hide_index=True)

        predictions_csv = results.to_csv(index=True)
        # Same fallback as the features download above.
        file_stem = uploaded_file.name.split(".")[0] if uploaded_file else "b3clf"
        results_file_name = file_stem + "_b3clf_predictions.csv"
        st.download_button(
            "Download predictions as CSV",
            data=predictions_csv,
            file_name=results_file_name,
        )

        # Celebrate a successful prediction run.
        st.balloons()
| |
|
| |
|
| | |
| | |
# Hide Streamlit's default hamburger menu and footer by injecting CSS.
_hide_chrome_css = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
st.markdown(_hide_chrome_css, unsafe_allow_html=True)
| |
|
| | |
# Inject the Google Analytics (gtag.js) snippet for usage tracking.
# NOTE(review): recent Streamlit versions sanitize <script> tags out of
# st.markdown output -- confirm this actually executes when deployed.
st.markdown(
    """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-WG8QYRELP9"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag("js", new Date());

gtag("config", "G-WG8QYRELP9");
</script>
""",
    unsafe_allow_html=True,
)
| |
|