singhn9 committed on
Commit
50ad074
·
verified ·
1 Parent(s): 133a2b4

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +86 -71
src/streamlit_app.py CHANGED
@@ -651,109 +651,124 @@ with tabs[4]:
651
  meta = Ridge(alpha=1.0)
652
  meta.fit(X_stack, y)
653
 
654
- # evaluate stacked ensemble on a holdout split
 
655
  X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
656
- # predict with base models -> create meta inputs
657
- # build a map of family -> fitted model
 
 
 
 
 
 
 
658
  base_model_map = {bm["family"]: bm["model"] for bm in base_models}
659
-
660
  meta_inputs = []
661
  missing_families = []
 
 
 
662
  for fam in selected:
663
- bm = base_model_map.get(fam, None)
664
  if bm is None:
665
- # missing base model: fill with training mean as safe fallback
666
  missing_families.append(fam)
667
- meta_inputs.append(np.full(len(X_val), y_tr.mean()))
 
668
  continue
669
-
670
  try:
671
  preds = bm.predict(X_val)
672
- # make sure preds is 1D and correct length
673
- preds = np.asarray(preds).reshape(-1)
674
- if len(preds) != len(X_val):
675
- # fallback to mean if shape mismatch
676
- preds = np.full(len(X_val), y_tr.mean())
 
 
677
  meta_inputs.append(preds)
678
- except Exception:
679
- # fallback to mean predictions on error
680
- meta_inputs.append(np.full(len(X_val), y_tr.mean()))
681
-
682
- if len(missing_families) > 0:
683
- st.warning(f"Warning: missing base models for families: {missing_families}. Filled with mean predictions.")
684
-
685
- # Now stack into (n_samples, n_models_selected)
 
 
 
 
686
  X_meta_val = np.column_stack(meta_inputs)
687
-
688
- # Defensive check: ensure X_meta_val has same number of cols as meta was trained on
689
- n_meta_features_trained = X_stack.shape[1]
690
  n_meta_features_val = X_meta_val.shape[1]
691
- if n_meta_features_val != n_meta_features_trained:
692
- st.warning(f"Meta feature mismatch: trained on {n_meta_features_trained} cols, validating with {n_meta_features_val} cols. Aligning by padding/truncating.")
693
- # If fewer cols, pad with columns of means
694
- if n_meta_features_val < n_meta_features_trained:
695
- pad_cols = n_meta_features_trained - n_meta_features_val
696
- pad = np.tile(np.full((len(X_val),1), y_tr.mean()), (1, pad_cols))
697
- X_meta_val = np.hstack([X_meta_val, pad])
698
- # If more cols, truncate to the trained size (keeps leftmost selected order)
699
- else:
700
- X_meta_val = X_meta_val[:, :n_meta_features_trained]
701
-
702
- # final safety assert (will raise an informative error if still wrong)
703
  if X_meta_val.shape[1] != n_meta_features_trained:
704
- raise ValueError(f"Final X_meta_val columns ({X_meta_val.shape[1]}) != trained meta features ({n_meta_features_trained})")
705
-
706
- # predict
 
707
  y_meta_pred = meta.predict(X_meta_val)
708
-
709
-
710
  final_r2 = r2_score(y_val, y_meta_pred)
711
  final_rmse = mean_squared_error(y_val, y_meta_pred, squared=False)
712
-
713
  c1, c2 = st.columns(2)
714
  c1.metric("Stacked Ensemble R² (holdout)", f"{final_r2:.4f}")
715
  c2.metric("Stacked Ensemble RMSE (holdout)", f"{final_rmse:.4f}")
716
-
717
- # scatter plot
718
- fig, ax = plt.subplots(figsize=(7,4))
719
  ax.scatter(y_val, y_meta_pred, alpha=0.6)
720
  ax.plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--")
721
- ax.set_xlabel("Actual"); ax.set_ylabel("Stacked Predicted")
 
722
  st.pyplot(fig)
723
-
724
- # save artifacts: base models list + meta learner
725
- stack_artifact = os.path.join(DATA_DIR, f"stacked_{use_case.replace(' ','_')}.joblib")
726
- to_save = {"base_models": {bm["family"]: bm["model"] for bm in base_models if bm["family"] in selected}, "meta": meta, "features": features, "selected": selected, "target": target}
 
 
 
 
 
 
727
  joblib.dump(to_save, stack_artifact)
728
- st.caption(f"Stacked ensemble saved: {stack_artifact}")
729
-
730
- # --- SHAP on final stack: approximate by SHAP of top base model or meta contributions ---
731
  st.markdown("### Explainability (approximate)")
732
  try:
733
- # Prefer SHAP on top base model (tree) for interpretability
734
  top_base = next((b for b in base_models if b["family"] == selected[0]), None)
735
- if top_base is not None and hasattr(top_base["model"], "predict"):
736
- # sample for speed
737
  sample_X = X_val.sample(min(300, len(X_val)), random_state=42)
738
- if hasattr(top_base["model"], "predict") and ("XGBoost" in top_base["family"] or "LightGBM" in top_base["family"] or "RandomForest" in top_base["family"] or "ExtraTrees" in top_base["family"] or "CatBoost" in top_base["family"]):
739
- expl = None
740
- # safe tree explainer creation
741
- try:
742
- expl = shap.TreeExplainer(top_base["model"])
743
- shap_vals = expl.shap_values(sample_X)
744
- fig_sh = plt.figure(figsize=(8,6))
745
- shap.summary_plot(shap_vals, sample_X, show=False)
746
- st.pyplot(fig_sh)
747
- except Exception as e:
748
- st.warning(f"SHAP tree explainer unavailable: {e}")
749
  else:
750
- st.info("Top base model not tree-based; SHAP summary skipped. You can inspect per-base feature importances above.")
751
  else:
752
- st.info("No suitable base model for SHAP explanation found.")
753
  except Exception as e:
754
- st.warning(f"SHAP step failed gracefully: {e}")
 
 
 
755
 
756
- st.success("AutoML + Stacking complete. Review metrics and saved artifacts.")
757
 
758
  # ----- Target & Business Impact tab
759
  with tabs[5]:
 
651
# Train the Ridge meta-learner on the stacked base-model predictions
# (Ridge.fit returns self, so fitting can be chained onto construction).
meta = Ridge(alpha=1.0).fit(X_stack, y)

# --- Robust holdout evaluation & SHAP (safe for deployment) ---
# Carve out a 20% holdout split for evaluating the stacked ensemble.
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
657
+
658
def scalar_mean(arr):
    """Return the mean of *arr* as a plain Python float.

    Accepts scalars, lists, and arrays of any shape (they are flattened
    first). An empty input yields ``nan`` explicitly instead of relying on
    NumPy's RuntimeWarning-plus-NaN behavior. The previous broad
    ``except Exception`` fallback was dead code: ``np.mean(np.ravel(arr))``
    fails for exactly the same inputs as ``np.mean(arr)``.
    """
    values = np.ravel(arr)
    if values.size == 0:
        # np.mean([]) would emit a RuntimeWarning and return nan; be explicit.
        return float("nan")
    return float(np.mean(values))
664
+
665
# Build family → fitted-model map for quick lookup of the base learners.
base_model_map = {bm["family"]: bm["model"] for bm in base_models}

meta_inputs = []        # one 1-D holdout prediction vector per selected family
missing_families = []   # families that were selected but never fitted
n_meta_features_trained = X_stack.shape[1]

# Collect holdout predictions from each selected base model.
for fam in selected:
    bm = base_model_map.get(fam)
    if bm is None:
        # Missing base model: fall back to the training-target mean so the
        # meta-feature matrix keeps one column per selected family.
        missing_families.append(fam)
        meta_inputs.append(np.full(len(X_val), scalar_mean(y_tr)))
        continue

    try:
        preds = np.asarray(bm.predict(X_val))
        # Collapse multi-output predictions to 1D
        if preds.ndim == 2:
            preds = preds.mean(axis=1)
        preds = preds.reshape(-1)
        if preds.shape[0] != len(X_val):
            # Shape mismatch: substitute a safe constant vector.
            preds = np.full(len(X_val), scalar_mean(y_tr))
        meta_inputs.append(preds)
    except Exception as e:
        # Surface the failure instead of swallowing it silently (the caught
        # exception was previously captured but never reported), then keep
        # the pipeline alive with a mean-prediction fallback.
        st.warning(f"Base model '{fam}' failed to predict ({e}); using mean predictions.")
        meta_inputs.append(np.full(len(X_val), scalar_mean(y_tr)))

if missing_families:
    st.warning(f"Missing base models: {missing_families}. Using mean predictions.")

# Stack meta features
if not meta_inputs:
    st.error("No meta features to predict — aborting.")
    st.stop()
702
+
703
# Assemble the validation meta-feature matrix: (n_samples, n_selected_models).
X_meta_val = np.column_stack(meta_inputs)
n_meta_features_val = X_meta_val.shape[1]

# Align meta features between training and validation. Warn when alignment
# is needed — this was silently padded/truncated before, hiding a real
# configuration problem from the user.
if n_meta_features_val != n_meta_features_trained:
    st.warning(
        f"Meta feature mismatch: trained on {n_meta_features_trained} cols, "
        f"validating with {n_meta_features_val} cols. Aligning by padding/truncating."
    )
if n_meta_features_val < n_meta_features_trained:
    # Too few columns: pad with constant mean-prediction columns.
    pad_cols = n_meta_features_trained - n_meta_features_val
    pad = np.full((len(X_val), pad_cols), scalar_mean(y_tr))
    X_meta_val = np.hstack([X_meta_val, pad])
elif n_meta_features_val > n_meta_features_trained:
    # Too many columns: keep the leftmost (training-order) subset.
    X_meta_val = X_meta_val[:, :n_meta_features_trained]

# Final safety check before handing the matrix to the meta-learner.
if X_meta_val.shape[1] != n_meta_features_trained:
    st.error(f"Stack alignment failed: {X_meta_val.shape[1]} != {n_meta_features_trained}")
    st.stop()
718
+
719
# Meta prediction on the aligned holdout features.
y_meta_pred = meta.predict(X_meta_val)

# Final evaluation on the holdout split.
final_r2 = r2_score(y_val, y_meta_pred)
# NOTE: mean_squared_error(..., squared=False) was deprecated in
# scikit-learn 1.4 and removed in 1.6 (it raises TypeError there).
# Taking the square root explicitly is equivalent and version-portable.
final_rmse = float(np.sqrt(mean_squared_error(y_val, y_meta_pred)))

c1, c2 = st.columns(2)
c1.metric("Stacked Ensemble R² (holdout)", f"{final_r2:.4f}")
c2.metric("Stacked Ensemble RMSE (holdout)", f"{final_rmse:.4f}")

# Scatter of actual vs. stacked prediction, with the identity line as reference.
fig, ax = plt.subplots(figsize=(7, 4))
ax.scatter(y_val, y_meta_pred, alpha=0.6)
ax.plot([y_val.min(), y_val.max()], [y_val.min(), y_val.max()], "r--")
ax.set_xlabel("Actual")
ax.set_ylabel("Stacked Predicted")
st.pyplot(fig)
737
+
738
# Persist the full stack (selected base models + meta-learner + metadata)
# so the ensemble can be reloaded for inference later.
stack_artifact = os.path.join(DATA_DIR, f"stacked_{use_case.replace(' ', '_')}.joblib")
selected_bases = {bm["family"]: bm["model"] for bm in base_models if bm["family"] in selected}
to_save = {
    "base_models": selected_bases,
    "meta": meta,
    "features": features,
    "selected": selected,
    "target": target,
}
joblib.dump(to_save, stack_artifact)
st.caption(f"Stacked ensemble saved: {stack_artifact}")
749
+
750
# Approximate explainability: run a SHAP summary on the top-ranked base
# model when it is a tree ensemble; otherwise tell the user why it is skipped.
st.markdown("### Explainability (approximate)")
try:
    best = None
    for candidate in base_models:
        if candidate["family"] == selected[0]:
            best = candidate
            break
    if not (best and hasattr(best["model"], "predict")):
        st.info("No suitable base model for SHAP explanation.")
    else:
        # Subsample the holdout for speed.
        sample_X = X_val.sample(min(300, len(X_val)), random_state=42)
        tree_markers = ("XGBoost", "LightGBM", "RandomForest", "ExtraTrees", "CatBoost")
        if not any(marker in best["family"] for marker in tree_markers):
            st.info("Top model not tree-based; skipping SHAP summary.")
        else:
            explainer = shap.TreeExplainer(best["model"])
            shap_vals = explainer.shap_values(sample_X)
            summary_fig = plt.figure(figsize=(8, 6))
            shap.summary_plot(shap_vals, sample_X, show=False)
            st.pyplot(summary_fig)
except Exception as e:
    st.warning(f"SHAP computation skipped: {e}")

st.success("✅ AutoML + Stacking complete — metrics, artifacts, and SHAP ready.")
770
+
771
 
 
772
 
773
  # ----- Target & Business Impact tab
774
  with tabs[5]: