Spaces:
Sleeping
Sleeping
move st_shap to inference_utils
Browse files
src/inference_utils.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
import streamlit as st
|
|
|
|
| 3 |
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, brier_score_loss, log_loss
|
|
|
|
| 4 |
|
| 5 |
def compute_metrics(y_true, y_pred_proba, threshold=0.5):
|
| 6 |
y_pred = (y_pred_proba >= threshold).astype(int)
|
|
@@ -37,3 +39,51 @@ def add_predictions(df, probs):
|
|
| 37 |
)
|
| 38 |
|
| 39 |
return df_styled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import streamlit as st
|
| 3 |
+
import shap
|
| 4 |
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, brier_score_loss, log_loss
|
| 5 |
+
import streamlit.components.v1 as components
|
| 6 |
|
| 7 |
def compute_metrics(y_true, y_pred_proba, threshold=0.5):
|
| 8 |
y_pred = (y_pred_proba >= threshold).astype(int)
|
|
|
|
| 39 |
)
|
| 40 |
|
| 41 |
return df_styled
|
| 42 |
+
|
| 43 |
+
def st_shap(plot, height=None):
    """Embed a SHAP plot in the Streamlit page as an HTML component.

    The SHAP JavaScript returned by ``shap.getjs()`` is placed in the
    document ``<head>`` and the markup returned by ``plot.html()`` in the
    ``<body>``; the combined document is handed to ``components.html``.

    Args:
        plot: Object exposing an ``html()`` method that returns markup.
        height: Passed through unchanged as the ``height`` of the component.
    """
    js_bundle = shap.getjs()
    plot_markup = plot.html()
    components.html(
        f"<head>{js_bundle}</head><body>{plot_markup}</body>",
        height=height,
    )
|
| 46 |
+
|
| 47 |
+
def ensemble_shap(models, X, model_weights=None):
    """
    Compute ensemble SHAP values for a list of tree-based models.

    Every model is explained with ``shap.TreeExplainer`` on ``X`` and the
    per-model SHAP values and base values are combined with a weighted
    average across models.

    Args:
        models: Non-empty sized collection of fitted tree-based models.
        X: Feature matrix to explain. Its ``.columns`` attribute is used for
            the feature names, so a DataFrame-like object is expected.
        model_weights: Optional sequence with one weight per model. The
            weights are normalised to sum to 1. Defaults to equal weights.

    Returns:
        A ``shap.Explanation`` holding the weighted-mean SHAP values and
        base values across models.

    Raises:
        ValueError: If ``models`` is empty, if the number of weights does
            not match the number of models, if the weights do not sum to a
            finite non-zero value, or if a multi-class explanation comes
            from a model whose classes do not contain the label ``1``.
    """
    import numpy as np

    n_models = len(models)
    if n_models == 0:
        raise ValueError("ensemble_shap requires at least one model")

    # Validate and normalise the weights up front so bad input fails before
    # the expensive SHAP computation rather than after it.
    if model_weights is None:
        weights = np.ones(n_models)
    else:
        weights = np.asarray(model_weights, dtype=float)
    if weights.shape != (n_models,):
        raise ValueError(
            f"model_weights must contain exactly {n_models} values, "
            f"got shape {weights.shape}"
        )
    weight_total = weights.sum()
    if not np.isfinite(weight_total) or weight_total == 0:
        raise ValueError("model_weights must sum to a finite, non-zero value")
    weights = weights / weight_total

    all_values = []
    all_base_values = []

    for model in models:
        explanation = shap.TreeExplainer(model)(X)
        values = explanation.values
        base_values = explanation.base_values

        # Handle binary classification: a 3-D array carries one slice per
        # class on the last axis, so keep only the slice for label 1.
        if values.ndim == 3:
            # classes_ may be a NumPy array (as in scikit-learn), which has
            # no .index() method; convert to a list before the lookup.
            classes = list(getattr(model, "classes_", [0, 1]))
            class_index = classes.index(1)
            values = values[:, :, class_index]
            base_values = base_values[:, class_index]

        all_values.append(values)
        all_base_values.append(base_values)

    return shap.Explanation(
        values=np.average(all_values, axis=0, weights=weights),
        base_values=np.average(all_base_values, axis=0, weights=weights),
        data=X,
        feature_names=X.columns,
    )
|
src/pages/3_Preprocessing_and_Training.py
CHANGED
|
@@ -1,26 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
from catboost import CatBoostClassifier, cv, Pool
|
|
|
|
|
|
|
| 6 |
from model_utils import get_model, save_model, save_model_ensemble, ensemble_predict
|
| 7 |
from preprocess_utils import load_train_features
|
| 8 |
from preprocess_utils import preprocess_pipeline as preprocess
|
| 9 |
-
from inference_utils import compute_metrics
|
| 10 |
from sidebar import sidebar
|
| 11 |
-
|
| 12 |
-
import os
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
import pyarrow.parquet as pq
|
| 15 |
import shap
|
| 16 |
import lime
|
| 17 |
import lime.lime_tabular
|
| 18 |
-
# Add this helper function at the top of the file
|
| 19 |
-
import streamlit.components.v1 as components
|
| 20 |
-
|
| 21 |
-
def st_shap(plot, height=None):
    """Embed a SHAP plot in the Streamlit page as an HTML component.

    The SHAP JavaScript returned by ``shap.getjs()`` is placed in the
    document ``<head>`` and the markup returned by ``plot.html()`` in the
    ``<body>``; the combined document is handed to ``components.html``.

    Args:
        plot: Object exposing an ``html()`` method that returns markup.
        height: Passed through unchanged as the ``height`` of the component.
    """
    js_bundle = shap.getjs()
    plot_markup = plot.html()
    components.html(
        f"<head>{js_bundle}</head><body>{plot_markup}</body>",
        height=height,
    )
|
| 24 |
|
| 25 |
LOCAL = False
|
| 26 |
|
|
@@ -255,7 +252,7 @@ if "trained_model" in st.session_state or "trained_models" in st.session_state:
|
|
| 255 |
shap_values_selected.values[sample_idx, :],
|
| 256 |
X_force.iloc[sample_idx, :]
|
| 257 |
),
|
| 258 |
-
height=
|
| 259 |
)
|
| 260 |
|
| 261 |
# ---- Display feature + SHAP values for selected single-sample ----
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import pyarrow.parquet as pq
|
| 4 |
+
|
| 5 |
import streamlit as st
|
| 6 |
import pandas as pd
|
| 7 |
import numpy as np
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from catboost import CatBoostClassifier, cv, Pool
|
| 10 |
+
from sklearn.model_selection import StratifiedKFold
|
| 11 |
+
|
| 12 |
from model_utils import get_model, save_model, save_model_ensemble, ensemble_predict
|
| 13 |
from preprocess_utils import load_train_features
|
| 14 |
from preprocess_utils import preprocess_pipeline as preprocess
|
| 15 |
+
from inference_utils import compute_metrics, st_shap
|
| 16 |
from sidebar import sidebar
|
| 17 |
+
|
|
|
|
|
|
|
|
|
|
| 18 |
import shap
|
| 19 |
import lime
|
| 20 |
import lime.lime_tabular
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
LOCAL = False
|
| 23 |
|
|
|
|
| 252 |
shap_values_selected.values[sample_idx, :],
|
| 253 |
X_force.iloc[sample_idx, :]
|
| 254 |
),
|
| 255 |
+
height=200
|
| 256 |
)
|
| 257 |
|
| 258 |
# ---- Display feature + SHAP values for selected single-sample ----
|