singhn9 committed on
Commit
a838781
·
verified ·
1 Parent(s): 2730c0e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +42 -7
src/streamlit_app.py CHANGED
@@ -1,4 +1,4 @@
1
- # sail_modex_stable.py
2
  import os
3
  import json
4
  import time
@@ -654,7 +654,7 @@ with tabs[4]:
654
  result = tune_family(fam, X, y, n_trials=max_trials)
655
  model_obj = result.get("model_obj")
656
 
657
- # Fix: ensure model is safe to access before fitting
658
  if hasattr(model_obj, "estimators_"):
659
  delattr(model_obj, "estimators_") # clear stale ref if any
660
  result["model_obj"] = model_obj
@@ -768,11 +768,39 @@ with tabs[4]:
768
  try:
769
  top_base = next((b for b in base_models if b["family"] == selected[0]), None)
770
  if top_base and hasattr(top_base["model"], "predict"):
771
- sample_X = X_val.sample(min(300, len(X_val)), random_state=42)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
772
  model = top_base["model"]
773
  expl = shap.TreeExplainer(model)
774
  shap_vals = expl.shap_values(sample_X)
775
- if isinstance(shap_vals, list): shap_vals = shap_vals[0]
 
776
  shap_vals = np.array(shap_vals)
777
  mean_abs = np.abs(shap_vals).mean(axis=0)
778
  mean_sign = np.sign(shap_vals).mean(axis=0)
@@ -781,17 +809,24 @@ with tabs[4]:
781
  "Mean |SHAP|": mean_abs,
782
  "Mean SHAP Sign": mean_sign
783
  }).sort_values("Mean |SHAP|", ascending=False)
 
784
  st.markdown("### Top 5 Operational Drivers")
785
  st.dataframe(importance.head(5))
 
786
  recommendations = []
787
  for _, row in importance.head(5).iterrows():
788
  f, s = row["Feature"], row["Mean SHAP Sign"]
789
- if s > 0.05: recommendations.append(f"Increase `{f}` likely increases `{target}`")
790
- elif s < -0.05: recommendations.append(f"Decrease `{f}` likely increases `{target}`")
791
- else: recommendations.append(f"`{f}` neutral for `{target}`")
 
 
 
 
792
  st.markdown("### Suggested Operator Adjustments")
793
  st.write("\n".join(recommendations))
794
 
 
795
  # --- Call HF Llama-3-70B-Instruct API for summary ---
796
  import requests
797
  HF_TOKEN = st.secrets.get("HF_TOKEN", os.getenv("HF_TOKEN"))
 
1
+
2
  import os
3
  import json
4
  import time
 
654
  result = tune_family(fam, X, y, n_trials=max_trials)
655
  model_obj = result.get("model_obj")
656
 
657
+ # Fix: ensure model is safe to access before fitting
658
  if hasattr(model_obj, "estimators_"):
659
  delattr(model_obj, "estimators_") # clear stale ref if any
660
  result["model_obj"] = model_obj
 
768
  try:
769
  top_base = next((b for b in base_models if b["family"] == selected[0]), None)
770
  if top_base and hasattr(top_base["model"], "predict"):
771
+ # --- Ensure numeric dtypes for SHAP ---
772
+ sample_X = X_val.sample(min(300, len(X_val)), random_state=42).copy()
773
+ for col in sample_X.columns:
774
+ if sample_X[col].dtype == object:
775
+ # Clean any bracketed, comma, or sci-notation strings
776
+ sample_X[col] = (
777
+ sample_X[col]
778
+ .astype(str)
779
+ .str.replace("[", "", regex=False)
780
+ .str.replace("]", "", regex=False)
781
+ .str.replace(",", "", regex=False)
782
+ .str.replace("E", "e", regex=False)
783
+ .str.replace("nan", "0", regex=False)
784
+ .str.strip()
785
+ )
786
+ # Force numeric conversion for all columns
787
+ sample_X[col] = pd.to_numeric(sample_X[col], errors="coerce")
788
+
789
+ # Replace NaN with 0 for SHAP stability
790
+ sample_X = sample_X.fillna(0)
791
+
792
+ # Optional: show columns that were coerced
793
+ non_numeric_cols = [c for c in sample_X.columns if not np.issubdtype(sample_X[c].dtype, np.number)]
794
+ if non_numeric_cols:
795
+ st.warning(f"Non-numeric columns coerced: {non_numeric_cols}")
796
+
797
+
798
+ # --- SHAP computation ---
799
  model = top_base["model"]
800
  expl = shap.TreeExplainer(model)
801
  shap_vals = expl.shap_values(sample_X)
802
+ if isinstance(shap_vals, list):
803
+ shap_vals = shap_vals[0]
804
  shap_vals = np.array(shap_vals)
805
  mean_abs = np.abs(shap_vals).mean(axis=0)
806
  mean_sign = np.sign(shap_vals).mean(axis=0)
 
809
  "Mean |SHAP|": mean_abs,
810
  "Mean SHAP Sign": mean_sign
811
  }).sort_values("Mean |SHAP|", ascending=False)
812
+
813
  st.markdown("### Top 5 Operational Drivers")
814
  st.dataframe(importance.head(5))
815
+
816
  recommendations = []
817
  for _, row in importance.head(5).iterrows():
818
  f, s = row["Feature"], row["Mean SHAP Sign"]
819
+ if s > 0.05:
820
+ recommendations.append(f"Increase `{f}` likely increases `{target}`")
821
+ elif s < -0.05:
822
+ recommendations.append(f"Decrease `{f}` likely increases `{target}`")
823
+ else:
824
+ recommendations.append(f"`{f}` neutral for `{target}`")
825
+
826
  st.markdown("### Suggested Operator Adjustments")
827
  st.write("\n".join(recommendations))
828
 
829
+
830
  # --- Call HF Llama-3-70B-Instruct API for summary ---
831
  import requests
832
  HF_TOKEN = st.secrets.get("HF_TOKEN", os.getenv("HF_TOKEN"))