Update src/streamlit_app.py
Browse files- src/streamlit_app.py +20 -0
src/streamlit_app.py
CHANGED
|
@@ -870,13 +870,33 @@ with tabs[4]:
|
|
| 870 |
# SHAP direction analysis
|
| 871 |
expl = shap.TreeExplainer(model)
|
| 872 |
shap_vals = expl.shap_values(sample_X)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
mean_abs = np.abs(shap_vals).mean(axis=0)
|
| 874 |
mean_sign = np.sign(shap_vals).mean(axis=0)
|
|
|
|
| 875 |
importance = pd.DataFrame({
|
| 876 |
"Feature": sample_X.columns,
|
| 877 |
"Mean |SHAP|": mean_abs,
|
| 878 |
"Mean SHAP Sign": mean_sign
|
| 879 |
}).sort_values("Mean |SHAP|", ascending=False)
|
|
|
|
| 880 |
|
| 881 |
# Display Top 5 Drivers
|
| 882 |
st.markdown("### Top 5 Operational Drivers Influencing Target")
|
|
|
|
| 870 |
# SHAP direction analysis
|
| 871 |
expl = shap.TreeExplainer(model)
|
| 872 |
shap_vals = expl.shap_values(sample_X)
|
| 873 |
+
|
| 874 |
+
# --- Normalize SHAP output structure (handles list, ndarray, or multi-dim cases) ---
|
| 875 |
+
if isinstance(shap_vals, list): # e.g., for multiclass models
|
| 876 |
+
shap_vals = shap_vals[0]
|
| 877 |
+
|
| 878 |
+
shap_vals = np.array(shap_vals)
|
| 879 |
+
|
| 880 |
+
# If SHAP output has >2 dims, reduce to (n_samples, n_features)
|
| 881 |
+
if shap_vals.ndim > 2:
|
| 882 |
+
shap_vals = shap_vals.reshape(shap_vals.shape[0], -1)
|
| 883 |
+
|
| 884 |
+
# Align SHAP features to DataFrame
|
| 885 |
+
if shap_vals.shape[1] != sample_X.shape[1]:
|
| 886 |
+
min_feats = min(shap_vals.shape[1], sample_X.shape[1])
|
| 887 |
+
shap_vals = shap_vals[:, :min_feats]
|
| 888 |
+
sample_X = sample_X.iloc[:, :min_feats]
|
| 889 |
+
|
| 890 |
+
# Compute robust means
|
| 891 |
mean_abs = np.abs(shap_vals).mean(axis=0)
|
| 892 |
mean_sign = np.sign(shap_vals).mean(axis=0)
|
| 893 |
+
|
| 894 |
importance = pd.DataFrame({
|
| 895 |
"Feature": sample_X.columns,
|
| 896 |
"Mean |SHAP|": mean_abs,
|
| 897 |
"Mean SHAP Sign": mean_sign
|
| 898 |
}).sort_values("Mean |SHAP|", ascending=False)
|
| 899 |
+
|
| 900 |
|
| 901 |
# Display Top 5 Drivers
|
| 902 |
st.markdown("### Top 5 Operational Drivers Influencing Target")
|