Update src/streamlit_app.py
Browse files- src/streamlit_app.py +42 -7
src/streamlit_app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
import json
|
| 4 |
import time
|
|
@@ -654,7 +654,7 @@ with tabs[4]:
|
|
| 654 |
result = tune_family(fam, X, y, n_trials=max_trials)
|
| 655 |
model_obj = result.get("model_obj")
|
| 656 |
|
| 657 |
-
#
|
| 658 |
if hasattr(model_obj, "estimators_"):
|
| 659 |
delattr(model_obj, "estimators_") # clear stale ref if any
|
| 660 |
result["model_obj"] = model_obj
|
|
@@ -768,11 +768,39 @@ with tabs[4]:
|
|
| 768 |
try:
|
| 769 |
top_base = next((b for b in base_models if b["family"] == selected[0]), None)
|
| 770 |
if top_base and hasattr(top_base["model"], "predict"):
|
| 771 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 772 |
model = top_base["model"]
|
| 773 |
expl = shap.TreeExplainer(model)
|
| 774 |
shap_vals = expl.shap_values(sample_X)
|
| 775 |
-
if isinstance(shap_vals, list):
|
|
|
|
| 776 |
shap_vals = np.array(shap_vals)
|
| 777 |
mean_abs = np.abs(shap_vals).mean(axis=0)
|
| 778 |
mean_sign = np.sign(shap_vals).mean(axis=0)
|
|
@@ -781,17 +809,24 @@ with tabs[4]:
|
|
| 781 |
"Mean |SHAP|": mean_abs,
|
| 782 |
"Mean SHAP Sign": mean_sign
|
| 783 |
}).sort_values("Mean |SHAP|", ascending=False)
|
|
|
|
| 784 |
st.markdown("### Top 5 Operational Drivers")
|
| 785 |
st.dataframe(importance.head(5))
|
|
|
|
| 786 |
recommendations = []
|
| 787 |
for _, row in importance.head(5).iterrows():
|
| 788 |
f, s = row["Feature"], row["Mean SHAP Sign"]
|
| 789 |
-
if s > 0.05:
|
| 790 |
-
|
| 791 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
st.markdown("### Suggested Operator Adjustments")
|
| 793 |
st.write("\n".join(recommendations))
|
| 794 |
|
|
|
|
| 795 |
# --- Call HF Llama-3-70B-Instruct API for summary ---
|
| 796 |
import requests
|
| 797 |
HF_TOKEN = st.secrets.get("HF_TOKEN", os.getenv("HF_TOKEN"))
|
|
|
|
| 1 |
+
|
| 2 |
import os
|
| 3 |
import json
|
| 4 |
import time
|
|
|
|
| 654 |
result = tune_family(fam, X, y, n_trials=max_trials)
|
| 655 |
model_obj = result.get("model_obj")
|
| 656 |
|
| 657 |
+
# Fix: ensure model is safe to access before fitting
|
| 658 |
if hasattr(model_obj, "estimators_"):
|
| 659 |
delattr(model_obj, "estimators_") # clear stale ref if any
|
| 660 |
result["model_obj"] = model_obj
|
|
|
|
| 768 |
try:
|
| 769 |
top_base = next((b for b in base_models if b["family"] == selected[0]), None)
|
| 770 |
if top_base and hasattr(top_base["model"], "predict"):
|
| 771 |
+
# --- Ensure numeric dtypes for SHAP ---
|
| 772 |
+
sample_X = X_val.sample(min(300, len(X_val)), random_state=42).copy()
|
| 773 |
+
for col in sample_X.columns:
|
| 774 |
+
if sample_X[col].dtype == object:
|
| 775 |
+
# Clean any bracketed, comma, or sci-notation strings
|
| 776 |
+
sample_X[col] = (
|
| 777 |
+
sample_X[col]
|
| 778 |
+
.astype(str)
|
| 779 |
+
.str.replace("[", "", regex=False)
|
| 780 |
+
.str.replace("]", "", regex=False)
|
| 781 |
+
.str.replace(",", "", regex=False)
|
| 782 |
+
.str.replace("E", "e", regex=False)
|
| 783 |
+
.str.replace("nan", "0", regex=False)
|
| 784 |
+
.str.strip()
|
| 785 |
+
)
|
| 786 |
+
# Force numeric conversion for all columns
|
| 787 |
+
sample_X[col] = pd.to_numeric(sample_X[col], errors="coerce")
|
| 788 |
+
|
| 789 |
+
# Replace NaN with 0 for SHAP stability
|
| 790 |
+
sample_X = sample_X.fillna(0)
|
| 791 |
+
|
| 792 |
+
# Optional: show columns that were coerced
|
| 793 |
+
non_numeric_cols = [c for c in sample_X.columns if not np.issubdtype(sample_X[c].dtype, np.number)]
|
| 794 |
+
if non_numeric_cols:
|
| 795 |
+
st.warning(f"Non-numeric columns coerced: {non_numeric_cols}")
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
# --- SHAP computation ---
|
| 799 |
model = top_base["model"]
|
| 800 |
expl = shap.TreeExplainer(model)
|
| 801 |
shap_vals = expl.shap_values(sample_X)
|
| 802 |
+
if isinstance(shap_vals, list):
|
| 803 |
+
shap_vals = shap_vals[0]
|
| 804 |
shap_vals = np.array(shap_vals)
|
| 805 |
mean_abs = np.abs(shap_vals).mean(axis=0)
|
| 806 |
mean_sign = np.sign(shap_vals).mean(axis=0)
|
|
|
|
| 809 |
"Mean |SHAP|": mean_abs,
|
| 810 |
"Mean SHAP Sign": mean_sign
|
| 811 |
}).sort_values("Mean |SHAP|", ascending=False)
|
| 812 |
+
|
| 813 |
st.markdown("### Top 5 Operational Drivers")
|
| 814 |
st.dataframe(importance.head(5))
|
| 815 |
+
|
| 816 |
recommendations = []
|
| 817 |
for _, row in importance.head(5).iterrows():
|
| 818 |
f, s = row["Feature"], row["Mean SHAP Sign"]
|
| 819 |
+
if s > 0.05:
|
| 820 |
+
recommendations.append(f"Increase `{f}` likely increases `{target}`")
|
| 821 |
+
elif s < -0.05:
|
| 822 |
+
recommendations.append(f"Decrease `{f}` likely increases `{target}`")
|
| 823 |
+
else:
|
| 824 |
+
recommendations.append(f"`{f}` neutral for `{target}`")
|
| 825 |
+
|
| 826 |
st.markdown("### Suggested Operator Adjustments")
|
| 827 |
st.write("\n".join(recommendations))
|
| 828 |
|
| 829 |
+
|
| 830 |
# --- Call HF Llama-3-70B-Instruct API for summary ---
|
| 831 |
import requests
|
| 832 |
HF_TOKEN = st.secrets.get("HF_TOKEN", os.getenv("HF_TOKEN"))
|