singhn9 commited on
Commit
a7bac14
·
verified ·
1 Parent(s): cb7a53e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +20 -0
src/streamlit_app.py CHANGED
@@ -870,13 +870,33 @@ with tabs[4]:
870
  # SHAP direction analysis
871
  expl = shap.TreeExplainer(model)
872
  shap_vals = expl.shap_values(sample_X)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
873
  mean_abs = np.abs(shap_vals).mean(axis=0)
874
  mean_sign = np.sign(shap_vals).mean(axis=0)
 
875
  importance = pd.DataFrame({
876
  "Feature": sample_X.columns,
877
  "Mean |SHAP|": mean_abs,
878
  "Mean SHAP Sign": mean_sign
879
  }).sort_values("Mean |SHAP|", ascending=False)
 
880
 
881
  # Display Top 5 Drivers
882
  st.markdown("### Top 5 Operational Drivers Influencing Target")
 
870
  # SHAP direction analysis
871
  expl = shap.TreeExplainer(model)
872
  shap_vals = expl.shap_values(sample_X)
873
+
874
+ # --- Normalize SHAP output structure (handles list, ndarray, or multi-dim cases) ---
875
+ if isinstance(shap_vals, list): # e.g., for multiclass models
876
+ shap_vals = shap_vals[0]
877
+
878
+ shap_vals = np.array(shap_vals)
879
+
880
+ # If SHAP output has >2 dims, reduce to (n_samples, n_features)
881
+ if shap_vals.ndim > 2:
882
+ shap_vals = shap_vals.reshape(shap_vals.shape[0], -1)
883
+
884
+ # Align SHAP features to DataFrame
885
+ if shap_vals.shape[1] != sample_X.shape[1]:
886
+ min_feats = min(shap_vals.shape[1], sample_X.shape[1])
887
+ shap_vals = shap_vals[:, :min_feats]
888
+ sample_X = sample_X.iloc[:, :min_feats]
889
+
890
+ # Compute robust means
891
  mean_abs = np.abs(shap_vals).mean(axis=0)
892
  mean_sign = np.sign(shap_vals).mean(axis=0)
893
+
894
  importance = pd.DataFrame({
895
  "Feature": sample_X.columns,
896
  "Mean |SHAP|": mean_abs,
897
  "Mean SHAP Sign": mean_sign
898
  }).sort_values("Mean |SHAP|", ascending=False)
899
+
900
 
901
  # Display Top 5 Drivers
902
  st.markdown("### Top 5 Operational Drivers Influencing Target")