Spaces:
Running
Running
| import os | |
| from pathlib import Path | |
| import pyarrow.parquet as pq | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from catboost import CatBoostClassifier, cv, Pool | |
| from sklearn.model_selection import StratifiedKFold | |
| from model_utils import get_model, save_model, save_model_ensemble, ensemble_predict | |
| from preprocess_utils import load_train_features | |
| from preprocess_utils import preprocess_pipeline as preprocess | |
| from inference_utils import compute_metrics, st_shap | |
| from sidebar import sidebar | |
| import shap | |
| import lime | |
| import lime.lime_tabular | |
# Deployment switch: when False, the Hugging Face Hub client is imported so
# that saved models can be fetched back for download later in the script.
LOCAL = False
if not LOCAL:
    from huggingface_hub import hf_hub_download

# Directory named in the "Models saved to" message shown after saving.
# parents=True also creates a missing "src/" folder, so start-up no longer
# fails with FileNotFoundError when the app runs from a fresh directory.
SAVED_MODELS_DIR = Path("src/saved_models")
SAVED_MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Initialize sidebar
sidebar()

st.title("🧪 Preprocessing & Training")
uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
# Upload step: show the raw CSV, let the user choose the prediction target,
# and run the preprocessing pipeline when the button is pressed.
if uploaded_file:
    # header=1: column names are read from the second row of the file.
    df = pd.read_csv(uploaded_file, header=1)
    st.write("Raw Data:")
    st.dataframe(df)

    target_options = [
        "GVHD",
        "Acute GVHD(<100 days)",
        "Chronic GVHD>100 days",
    ]
    st.session_state.target_col = st.selectbox(
        "Select target column to predict:",
        options=target_options,
        index=0,
    )

    preprocess_clicked = st.button("Preprocess")
    if preprocess_clicked:
        # Stored in session state so the result survives Streamlit reruns.
        st.session_state.edited_df = preprocess(df)
# Whenever a preprocessed table is stored, show it in an editable grid and
# write the editor's output back so later steps see the user's changes.
if "edited_df" in st.session_state:
    current_df = st.session_state.edited_df
    st.session_state.edited_df = st.data_editor(current_df, num_rows="dynamic")
# Re-train: 5-fold stratified cross-validation (one model per fold, with SHAP
# values computed on each validation split), followed by a final fit on all
# rows. Results are stored in st.session_state for the sections that follow.
if st.button("Re-train"):
    if "edited_df" not in st.session_state:
        st.warning("Please preprocess and edit data first.")
    else:
        # Model selection
        model_type = "CatBoost"  # Fixed to CatBoost
        df = st.session_state.edited_df.copy()
        target_col = st.session_state.target_col
        # For the acute/chronic targets, rows whose target value is 3 are dropped.
        if target_col in ["Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
            df = df[df[target_col] != 3]
        y = df[target_col]
        st.dataframe(df[target_col].value_counts(), width=250)

        train_features, cat_features = load_train_features()
        # Take an explicit copy: df[train_features] is a slice of df, and
        # assigning columns into a slice is chained assignment
        # (SettingWithCopyWarning, and not guaranteed to take effect).
        X = df[train_features].copy()
        for col in cat_features:
            X[col] = X[col].astype(str)

        st.info("Running 5-Fold cross-validation with model saving...")
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
        fold_models = []
        fold_scores = []
        best_iterations = []
        all_shap_values = []

        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), start=1):
            st.write(f"Training Fold {fold}...")
            X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
            y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
            val_pool = Pool(X_val, y_val, cat_features=cat_features)

            model = get_model(model_type, mode="ensemble", target=target_col)
            if model_type == "CatBoost":
                model.fit(
                    X_train, y_train,
                    eval_set=(X_val, y_val),
                    cat_features=cat_features,
                    use_best_model=True,
                )
            else:
                model.fit(X_train, y_train)

            best_iter = model.get_best_iteration()
            best_iterations.append(best_iter)
            fold_models.append(model)
            fold_scores.append(
                model.eval_metrics(
                    val_pool,
                    ["AUC", "F1", "Accuracy", "Precision", "Recall", "BrierScore", "Logloss"],
                    best_iter,
                )
            )

            # Calculate SHAP values for the validation set
            explainer = shap.TreeExplainer(model)
            all_shap_values.append(explainer(X_val))
            st.success(f"Fold {fold} trained. Best iteration: {best_iter}")

        st.session_state.trained_models = fold_models
        st.session_state.fold_scores = fold_scores
        st.session_state.best_iterations = best_iterations

        # ---- Aggregate SHAP data across folds ----
        # Rows are stacked fold by fold, so their order follows the validation
        # splits rather than the original row order of X.
        st.session_state.all_shap_values_array = np.vstack([sv.values for sv in all_shap_values])
        st.session_state.all_data_array = np.vstack([sv.data for sv in all_shap_values])
        st.session_state.all_base_values_array = np.hstack([sv.base_values for sv in all_shap_values])
        st.session_state.expected_value = np.mean(st.session_state.all_base_values_array)

        ### TURN OFF SINGLE MODEL TRAINING ####
        # Single model training: fit on all rows, passing the largest best
        # iteration seen across the folds to get_model.
        st.session_state.best_iteration = np.max(st.session_state.best_iterations)
        final_model = get_model(
            model_type,
            mode="ensemble",
            target=target_col,
            best_iter=st.session_state.best_iteration,
        )
        if model_type == "CatBoost":
            final_model.fit(
                X, y,
                cat_features=cat_features,
            )
        else:
            final_model.fit(X, y)
        st.session_state.trained_model = final_model
        st.success("TRAINING OF ALL FOLDS COMPLETED.")
# CV summary metrics: mean ± standard deviation of each metric over the folds.
if "fold_scores" in st.session_state:
    st.subheader("Cross-Validation Metrics (5-Fold)")
    fold_results = st.session_state.fold_scores
    for metric_name in ("AUC", "F1", "Accuracy", "Precision", "Recall", "BrierScore", "Logloss"):
        # Each fold holds a list of values per metric; the last one is used
        # (noted by the original author as the value at the best iteration).
        per_fold = [fold[metric_name][-1] for fold in fold_results]
        mean_val = sum(per_fold) / len(per_fold)
        std_val = pd.Series(per_fold).std()
        st.write(f"**{metric_name}**: {mean_val:.3f} ± {std_val:.3f}")
# Single & ensemble evaluation on the (edited) training data.
# Both the final single model and the list of fold models are read below, so
# both keys must be present; the previous `or` let the body run with only one
# of them stored and then fail on the missing key.
if "trained_model" in st.session_state and "trained_models" in st.session_state:
    # st.subheader("🔮 Ensemble Evaluation (on Training Data)")
    models = st.session_state.trained_models
    ### TURN OFF SINGLE MODEL EVALUATION ###
    single_model = st.session_state.trained_model

    df = st.session_state.edited_df.copy()
    target_col = st.session_state.target_col
    # For the acute/chronic targets, rows whose target value is 3 are dropped.
    if target_col in ["Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
        df = df[df[target_col] != 3]
    y = df[target_col]
    st.session_state.targets_df = y

    train_features, cat_features = load_train_features()
    # Explicit copy so the string casts below are not chained assignment on a
    # slice of df (SettingWithCopyWarning).
    X = df[train_features].copy()
    for col in cat_features:
        X[col] = X[col].astype(str)

    ### TURN OFF SINGLE MODEL EVALUATION ###
    y_pred_prob_single = single_model.predict_proba(X)[:, 1]
    metrics_result_single = compute_metrics(y, y_pred_prob_single)
    y_pred_prob_ensemble = ensemble_predict(models, X, cat_features)
    metrics_result_ensemble = compute_metrics(y, y_pred_prob_ensemble)

    # ### TURN OFF SINGLE MODEL EVALUATION ###
    # st.write("Single Model Predictions:")
    # for metric, value in metrics_result_single.items():
    #     st.write(f"**{metric}**: {value:.3f}")
    # st.write("Ensemble Predictions:")
    # for metric, value in metrics_result_ensemble.items():
    #     st.write(f"**{metric}**: {value:.3f}")
| # Display SHAP explainability | |
| with st.expander("Show SHAP Explainability", expanded=True): | |
| # ---- Determine top features ---- | |
def get_top_features(shap_values_array, feature_names, n=20):
    """Return the names of the ``n`` features with the largest mean |SHAP|.

    Args:
        shap_values_array: 2-D array-like of SHAP values, one row per sample
            and one column per feature.
        feature_names: Iterable of feature names, one per column.
        n: Maximum number of names to return.

    Returns:
        A list of feature names ordered from most to least important.
        Features with equal importance keep their original column order.

    Raises:
        ValueError: If the number of names does not match the number of
            SHAP columns.
    """
    mean_abs_shap = np.abs(np.asarray(shap_values_array)).mean(axis=0)
    names = list(feature_names)
    if np.shape(mean_abs_shap) != (len(names),):
        raise ValueError(
            "feature_names must contain exactly one name per SHAP column"
        )
    # Stable sort on the negated importances: descending order with ties left
    # in column order, so the result is fully determined by the input.
    order = np.argsort(-mean_abs_shap, kind="stable")
    return [names[i] for i in order[:n]]
# SHAP explainability plots, built from the SHAP values gathered on the
# cross-validation validation splits and stored in st.session_state.
# Features are ranked by mean |SHAP|; the top ones are the default selection.
top_features = get_top_features(st.session_state.all_shap_values_array, X.columns)

# ---- Let user pick which features to visualize ----
selected_features = st.multiselect(
    "Select features to display in plots",
    options=X.columns.tolist(),
    default=top_features,
)

if not selected_features:
    st.warning("Please select at least one feature to display.")
else:
    # Build the column list once rather than once per selected feature.
    all_columns = list(X.columns)
    feature_indices = [all_columns.index(f) for f in selected_features]

    # Build filtered SHAP explanation
    shap_values_selected = shap.Explanation(
        values=st.session_state.all_shap_values_array[:, feature_indices],
        base_values=st.session_state.expected_value,
        data=st.session_state.all_data_array[:, feature_indices],
        feature_names=selected_features,
    )

    # ---- Force plot for one sample ----
    st.subheader("SHAP Force Plot (Single Prediction)")
    # Create a DataFrame version of the data for easier display
    X_force = pd.DataFrame(
        shap_values_selected.data,
        columns=shap_values_selected.feature_names,
    )
    sample_idx = st.slider("Select sample index", 0, len(shap_values_selected.values) - 1, 0)

    # Display SHAP force plot
    st_shap(
        shap.force_plot(
            st.session_state.expected_value,
            shap_values_selected.values[sample_idx, :],
            X_force.iloc[sample_idx, :],
        ),
        height=200,
    )

    # ---- Display feature + SHAP values for selected single-sample ----
    st.markdown("**Feature values and SHAP contributions for this prediction:**")
    actual_values = X_force.iloc[sample_idx, :].to_frame().T
    shap_values_row = pd.DataFrame(
        [shap_values_selected.values[sample_idx, :]],
        columns=shap_values_selected.feature_names,
    )
    single_row_df = pd.concat(
        [actual_values, shap_values_row.round(4)],
        keys=["Actual Value", "SHAP Value"],
    )
    st.dataframe(single_row_df, use_container_width=True)

    # ---- Download single sample ----
    # NOTE(review): index=False drops the "Actual Value"/"SHAP Value" row
    # labels from the CSV; confirm whether the download should keep them.
    csv_data = single_row_df.to_csv(index=False).encode('utf-8')
    st.download_button(
        label="⬇️ Download single-sample SHAP CSV",
        data=csv_data,
        file_name=f"sample_{sample_idx}_features.csv",
        mime="text/csv",
    )

    # ---- Force plot for all samples ----
    st.subheader("SHAP Force Plot (All Predictions)")
    all_actual_df = X_force.copy()
    all_shap_df = pd.DataFrame(
        shap_values_selected.values,
        columns=[str(col) for col in shap_values_selected.feature_names],
    )
    # Create merged DataFrame with suffixes
    all_combined_df = pd.concat(
        [all_actual_df.add_suffix("_actual"), all_shap_df.add_suffix("_shap")],
        axis=1,
    )
    st_shap(
        shap.force_plot(
            st.session_state.expected_value,
            shap_values_selected.values,
            X_force,
        ),
        height=400,
    )
    # st.dataframe(all_combined_df.head(20), use_container_width=True)
    csv_download = all_combined_df.to_csv(index=False).encode("utf-8")
    filename = "all_SHAP_Values_5CV.csv"
    st.download_button(
        label="⬇️ Download 5-fold cross-validation SHAP CSV",
        data=csv_download,
        file_name=filename,
        mime="text/csv",
    )

    # Figures are closed after rendering: the script reruns on every widget
    # interaction, and plt.clf() alone leaves each figure open, so they would
    # pile up in memory across reruns.

    # ---- Beeswarm: overall feature impact ----
    # st.subheader("SHAP Feature Importance (Beeswarm)")
    st.subheader("SHAP Feature Importance")
    plt.figure(figsize=(10, 6))
    shap.plots.beeswarm(shap_values_selected, max_display=20, show=False)
    st.pyplot(plt.gcf(), bbox_inches='tight')
    plt.close("all")

    # ---- Mean absolute SHAP bar chart ----
    st.subheader("Mean(|SHAP value|) per Feature")
    plt.figure(figsize=(10, 6))
    shap.plots.bar(shap_values_selected, max_display=20, show=False)
    st.pyplot(plt.gcf(), bbox_inches='tight')
    plt.close("all")

    # ---- Dependence plot ----
    st.subheader("SHAP Dependence Plot")
    feature = st.selectbox("Select main feature", selected_features)
    interaction_feature = st.selectbox(
        "Select interaction feature (optional)",
        ["None"] + selected_features,
    )
    plt.figure(figsize=(10, 6))
    shap.dependence_plot(
        feature,
        shap_values_selected.values,
        pd.DataFrame(shap_values_selected.data, columns=selected_features),
        interaction_index=None if interaction_feature == "None" else interaction_feature,
        show=False,
    )
    st.pyplot(plt.gcf(), bbox_inches='tight')
    plt.close("all")
| # # Display LIME explainability | |
| # if st.button("Show LIME Explainability"): | |
| # # Load the trained model from session state | |
| # model = st.session_state.trained_model | |
| # df = st.session_state.edited_df.copy() | |
| # target_col = st.session_state.target_col | |
| # train_features, cat_features = load_train_features() | |
| # X = df[train_features] | |
| # for col in cat_features: | |
| # X[col] = X[col].astype(str) | |
| # # Prepare the LIME explainer | |
| # explainer = lime.lime_tabular.LimeTabularExplainer( | |
| # training_data=X.values, | |
| # feature_names=X.columns, | |
| # class_names=[target_col], | |
| # mode='classification', | |
| # categorical_features=[i for i, c in enumerate(X.columns) if c in cat_features], | |
| # ) | |
| # # Pick a random sample to explain | |
| # idx = np.random.randint(0, len(X)) | |
| # explanation = explainer.explain_instance(X.iloc[idx].values, model.predict_proba, num_features=10) | |
| # # Show explanation in Streamlit | |
| # st.subheader("LIME Explanation for Random Sample") | |
| # explanation.show_in_notebook() | |
| # st.pyplot(bbox_inches="tight") | |
# Save the trained single model and the fold ensemble under a user-supplied
# name; when not running locally, also offer both files for download.
# NOTE(review): this runs on every Streamlit rerun while the text box is
# non-empty, so models are re-saved on each interaction — confirm intended.
user_model_name = st.text_input("Enter model name to be saved:")
if user_model_name:
    if "trained_model" not in st.session_state or "trained_models" not in st.session_state:
        # Nothing to save yet; without this guard the save calls below would
        # fail on the missing session-state entries.
        st.warning("Please train a model before saving.")
    else:
        ### single model saving
        single_filename = save_model(st.session_state.trained_model, user_model_name, metrics_result_single)
        # ensemble model saving
        ensemble_filename = save_model_ensemble(
            st.session_state.trained_models,
            user_model_name,
            best_iterations=st.session_state.best_iterations,
            fold_scores=st.session_state.fold_scores,
            metrics_result_ensemble=metrics_result_ensemble,
        )
        st.success(f"{ensemble_filename} and {single_filename} is successfully saved!")

        if not LOCAL:
            def get_model_bytes(parquet_filename):
                """Download and extract model bytes from Hugging Face parquet.

                Args:
                    parquet_filename: Parquet file name; the text before the
                        first underscore is used as the folder name under
                        ``models/`` in the dataset repository.

                Returns:
                    The ``bytes`` entry of the ``model_file`` field in the
                    first row of the downloaded parquet table.

                Raises:
                    KeyError: If ``HF_REPO_ID`` or ``HF_TOKEN`` is not set.
                """
                date_folder = parquet_filename.split('_')[0]
                # exist_ok avoids the check-then-create race of the previous
                # os.path.exists() + os.makedirs() pair.
                # NOTE(review): this local folder is not passed to
                # hf_hub_download below — confirm it is still needed.
                os.makedirs(date_folder, exist_ok=True)
                path = hf_hub_download(
                    repo_id=os.environ["HF_REPO_ID"],
                    repo_type="dataset",
                    filename=f"models/{date_folder}/{parquet_filename}",
                    token=os.environ["HF_TOKEN"],
                )
                table = pq.read_table(path)
                row = table.to_pylist()[0]
                return row["model_file"]["bytes"]

            single_parquet = single_filename + ".parquet"
            ensemble_parquet = ensemble_filename + ".parquet"

            # Get model bytes
            single_bytes = get_model_bytes(single_parquet)
            ensemble_bytes = get_model_bytes(ensemble_parquet)

            # Show download buttons
            st.download_button(
                label="⬇️ Download Single Model",
                data=single_bytes,
                file_name=single_parquet,
                mime="application/octet-stream",
            )
            st.download_button(
                label="⬇️ Download Ensemble Model",
                data=ensemble_bytes,
                file_name=ensemble_parquet,
                mime="application/octet-stream",
            )

        # Show saved model paths
        # NOTE(review): the flattened source does not show whether this message
        # sat inside the `if not LOCAL` branch; with LOCAL = False the behavior
        # is the same either way — confirm placement before enabling LOCAL.
        st.success(
            f"Models saved to:\n- Single model: {SAVED_MODELS_DIR / (single_filename + '.pkl')}"
            f"\n- Ensemble model: {SAVED_MODELS_DIR / (ensemble_filename + '.pkl')}"
        )