Update src/streamlit_app.py
Browse files- src/streamlit_app.py +86 -71
src/streamlit_app.py
CHANGED
|
@@ -651,109 +651,124 @@ with tabs[4]:
|
|
| 651 |
meta = Ridge(alpha=1.0)
|
| 652 |
meta.fit(X_stack, y)
|
| 653 |
|
| 654 |
-
#
|
|
|
|
| 655 |
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
|
| 656 |
-
|
| 657 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
base_model_map = {bm["family"]: bm["model"] for bm in base_models}
|
| 659 |
-
|
| 660 |
meta_inputs = []
|
| 661 |
missing_families = []
|
|
|
|
|
|
|
|
|
|
| 662 |
for fam in selected:
|
| 663 |
-
bm = base_model_map.get(fam
|
| 664 |
if bm is None:
|
| 665 |
-
# missing base model: fill with training mean as safe fallback
|
| 666 |
missing_families.append(fam)
|
| 667 |
-
|
|
|
|
| 668 |
continue
|
| 669 |
-
|
| 670 |
try:
|
| 671 |
preds = bm.predict(X_val)
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
if
|
| 675 |
-
|
| 676 |
-
|
|
|
|
|
|
|
| 677 |
meta_inputs.append(preds)
|
| 678 |
-
except Exception:
|
| 679 |
-
|
| 680 |
-
meta_inputs.append(np.full(len(X_val),
|
| 681 |
-
|
| 682 |
-
if
|
| 683 |
-
st.warning(f"
|
| 684 |
-
|
| 685 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 686 |
X_meta_val = np.column_stack(meta_inputs)
|
| 687 |
-
|
| 688 |
-
# Defensive check: ensure X_meta_val has same number of cols as meta was trained on
|
| 689 |
-
n_meta_features_trained = X_stack.shape[1]
|
| 690 |
n_meta_features_val = X_meta_val.shape[1]
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
# final safety assert (will raise an informative error if still wrong)
|
| 703 |
if X_meta_val.shape[1] != n_meta_features_trained:
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
|
|
|
| 707 |
y_meta_pred = meta.predict(X_meta_val)
|
| 708 |
-
|
| 709 |
-
|
| 710 |
final_r2 = r2_score(y_val, y_meta_pred)
|
| 711 |
final_rmse = mean_squared_error(y_val, y_meta_pred, squared=False)
|
| 712 |
-
|
| 713 |
c1, c2 = st.columns(2)
|
| 714 |
c1.metric("Stacked Ensemble R² (holdout)", f"{final_r2:.4f}")
|
| 715 |
c2.metric("Stacked Ensemble RMSE (holdout)", f"{final_rmse:.4f}")
|
| 716 |
-
|
| 717 |
-
#
|
| 718 |
-
fig, ax = plt.subplots(figsize=(7,4))
|
| 719 |
ax.scatter(y_val, y_meta_pred, alpha=0.6)
|
| 720 |
ax.plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--")
|
| 721 |
-
ax.set_xlabel("Actual")
|
|
|
|
| 722 |
st.pyplot(fig)
|
| 723 |
-
|
| 724 |
-
#
|
| 725 |
-
stack_artifact = os.path.join(DATA_DIR, f"stacked_{use_case.replace(' ','_')}.joblib")
|
| 726 |
-
to_save = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
joblib.dump(to_save, stack_artifact)
|
| 728 |
-
st.caption(f"Stacked ensemble saved: {stack_artifact}")
|
| 729 |
-
|
| 730 |
-
#
|
| 731 |
st.markdown("### Explainability (approximate)")
|
| 732 |
try:
|
| 733 |
-
# Prefer SHAP on top base model (tree) for interpretability
|
| 734 |
top_base = next((b for b in base_models if b["family"] == selected[0]), None)
|
| 735 |
-
if top_base
|
| 736 |
-
# sample for speed
|
| 737 |
sample_X = X_val.sample(min(300, len(X_val)), random_state=42)
|
| 738 |
-
if
|
| 739 |
-
expl =
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
fig_sh = plt.figure(figsize=(8,6))
|
| 745 |
-
shap.summary_plot(shap_vals, sample_X, show=False)
|
| 746 |
-
st.pyplot(fig_sh)
|
| 747 |
-
except Exception as e:
|
| 748 |
-
st.warning(f"SHAP tree explainer unavailable: {e}")
|
| 749 |
else:
|
| 750 |
-
st.info("Top
|
| 751 |
else:
|
| 752 |
-
st.info("No suitable base model for SHAP explanation
|
| 753 |
except Exception as e:
|
| 754 |
-
st.warning(f"SHAP
|
|
|
|
|
|
|
|
|
|
| 755 |
|
| 756 |
-
st.success("AutoML + Stacking complete. Review metrics and saved artifacts.")
|
| 757 |
|
| 758 |
# ----- Target & Business Impact tab
|
| 759 |
with tabs[5]:
|
|
|
|
# Fit the stacking meta-learner: ridge regression over the base models'
# out-of-fold predictions (X_stack), targeting the original labels.
meta = Ridge(alpha=1.0)
meta.fit(X_stack, y)

# --- Robust holdout evaluation & SHAP (safe for deployment) ---
# Split for holdout: 20% validation, fixed seed so reruns are reproducible.
# X_val/y_val are consumed by the evaluation code below; y_tr supplies the
# mean-fallback value when a base model is missing or fails.
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
| 658 |
# Helper to always produce a scalar-safe mean.
def scalar_mean(arr):
    """Return the mean of *arr* as a plain Python float.

    Flattens explicitly via ``np.ravel`` so multi-dimensional inputs
    collapse to a single scalar. The previous try/except fallback
    (``np.mean(np.ravel(arr))``) was dead code: ``np.mean`` already
    flattens when ``axis`` is None, so any input that raised in the
    try-branch raised identically in the fallback.
    """
    return float(np.mean(np.ravel(arr)))
# Build family → model map so each selected family resolves to its fitted model.
base_model_map = {bm["family"]: bm["model"] for bm in base_models}

meta_inputs = []
missing_families = []
n_meta_features_trained = X_stack.shape[1]
# Hoisted once: fallback column value used whenever a base model is
# missing or fails to predict (previously recomputed at each use site).
fallback_mean = scalar_mean(y_tr)

# Collect holdout predictions from each selected base model.
for fam in selected:
    bm = base_model_map.get(fam)
    if bm is None:
        missing_families.append(fam)
        meta_inputs.append(np.full(len(X_val), fallback_mean))
        continue

    try:
        preds = np.asarray(bm.predict(X_val))
        # Collapse multi-output predictions to 1-D.
        if preds.ndim == 2:
            preds = preds.mean(axis=1)
        preds = preds.reshape(-1)
        # Length mismatch means the model's output is unusable — fall back.
        if preds.shape[0] != len(X_val):
            preds = np.full(len(X_val), fallback_mean)
        meta_inputs.append(preds)
    except Exception as e:
        # Surface the failure instead of silently substituting the mean —
        # a swallowed exception here hides a broken base model.
        st.warning(f"Base model '{fam}' failed to predict: {e}. Using mean fallback.")
        meta_inputs.append(np.full(len(X_val), fallback_mean))

if missing_families:
    st.warning(f"Missing base models: {missing_families}. Using mean predictions.")

# Stack meta features
if not meta_inputs:
    st.error("No meta features to predict — aborting.")
    st.stop()

X_meta_val = np.column_stack(meta_inputs)
n_meta_features_val = X_meta_val.shape[1]

# Align meta features between training and validation: pad missing columns
# with the fallback mean, truncate any surplus.
if n_meta_features_val < n_meta_features_trained:
    pad_cols = n_meta_features_trained - n_meta_features_val
    # np.full builds the 2-D pad directly (was np.tile over a 1-column np.full).
    X_meta_val = np.hstack([X_meta_val, np.full((len(X_val), pad_cols), fallback_mean)])
elif n_meta_features_val > n_meta_features_trained:
    X_meta_val = X_meta_val[:, :n_meta_features_trained]

# Final safety check — abort with a clear message rather than let
# meta.predict raise an opaque shape error.
if X_meta_val.shape[1] != n_meta_features_trained:
    st.error(f"Stack alignment failed: {X_meta_val.shape[1]} != {n_meta_features_trained}")
    st.stop()

# Meta prediction
y_meta_pred = meta.predict(X_meta_val)

# Final evaluation.
final_r2 = r2_score(y_val, y_meta_pred)
# mean_squared_error(squared=False) was deprecated in scikit-learn 1.4 and
# removed in 1.6; take the square root explicitly for forward compatibility.
final_rmse = float(np.sqrt(mean_squared_error(y_val, y_meta_pred)))

c1, c2 = st.columns(2)
c1.metric("Stacked Ensemble R² (holdout)", f"{final_r2:.4f}")
c2.metric("Stacked Ensemble RMSE (holdout)", f"{final_rmse:.4f}")

# Scatter comparison: predicted vs actual with the identity line for reference.
fig, ax = plt.subplots(figsize=(7, 4))
ax.scatter(y_val, y_meta_pred, alpha=0.6)
ax.plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--")
ax.set_xlabel("Actual")
ax.set_ylabel("Stacked Predicted")
st.pyplot(fig)

# Save trained stack artifacts (only the selected base models plus the meta model).
stack_artifact = os.path.join(DATA_DIR, f"stacked_{use_case.replace(' ', '_')}.joblib")
to_save = {
    "base_models": {bm["family"]: bm["model"] for bm in base_models if bm["family"] in selected},
    "meta": meta,
    "features": features,
    "selected": selected,
    "target": target,
}
joblib.dump(to_save, stack_artifact)
st.caption(f"✅ Stacked ensemble saved: {stack_artifact}")

# Explainability: SHAP summary on the top-ranked base model when it is tree-based.
st.markdown("### Explainability (approximate)")
try:
    top_base = next((b for b in base_models if b["family"] == selected[0]), None)
    if top_base and hasattr(top_base["model"], "predict"):
        # Subsample for SHAP speed.
        sample_X = X_val.sample(min(300, len(X_val)), random_state=42)
        if any(k in top_base["family"] for k in ["XGBoost", "LightGBM", "RandomForest", "ExtraTrees", "CatBoost"]):
            expl = shap.TreeExplainer(top_base["model"])
            shap_vals = expl.shap_values(sample_X)
            fig_sh = plt.figure(figsize=(8, 6))
            shap.summary_plot(shap_vals, sample_X, show=False)
            st.pyplot(fig_sh)
        else:
            st.info("Top model not tree-based; skipping SHAP summary.")
    else:
        st.info("No suitable base model for SHAP explanation.")
except Exception as e:
    st.warning(f"SHAP computation skipped: {e}")

st.success("✅ AutoML + Stacking complete — metrics, artifacts, and SHAP ready.")
|
|
|
|
| 772 |
|
| 773 |
# ----- Target & Business Impact tab
|
| 774 |
with tabs[5]:
|