Synav committed on
Commit
21fcf07
·
verified ·
1 Parent(s): a27119e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -5
app.py CHANGED
@@ -628,6 +628,12 @@ def train_and_save(
628
  "selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None,
629
  "note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
630
  },
 
 
 
 
 
 
631
  "positive_class": str(pos_class),
632
  "metrics": metrics,
633
  }
@@ -668,6 +674,131 @@ def ensure_model_repo_exists(model_repo_id: str, token: str):
668
  except Exception:
669
  pass
670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
671
 
672
  def publish_to_hub(model_repo_id: str, version_tag: str):
673
  """
@@ -1276,7 +1407,18 @@ with tab_train:
1276
  use_dimred=use_dimred,
1277
  svd_components=svd_components
1278
  )
1279
-
 
 
 
 
 
 
 
 
 
 
 
1280
  explainer = build_shap_explainer(pipe, X_train)
1281
 
1282
  st.session_state.pipe = pipe
@@ -1509,6 +1651,12 @@ with tab_train:
1509
  try:
1510
  with st.spinner("Uploading to Hugging Face Model repo..."):
1511
  paths = publish_to_hub(MODEL_REPO_ID, version_tag)
 
 
 
 
 
 
1512
 
1513
  st.success("Uploaded successfully to your model repository.")
1514
  st.json(paths)
@@ -1576,6 +1724,21 @@ with tab_predict:
1576
  num_cols = meta["schema"]["numeric"]
1577
  cat_cols = meta["schema"]["categorical"]
1578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1579
  # 2) Now we can build lookup
1580
  FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols}
1581
 
@@ -2429,9 +2592,25 @@ with tab_predict:
2429
  X_batch_t = transform_before_clf(pipe, X_batch)
2430
 
2431
  explainer = st.session_state.get("explainer")
2432
- if explainer is None:
2433
- st.session_state.explainer = build_shap_explainer(pipe, X_inf)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2434
  explainer = st.session_state.explainer
 
2435
 
2436
  shap_vals_batch = explainer.shap_values(X_batch_t)
2437
  if isinstance(shap_vals_batch, list):
@@ -2627,9 +2806,25 @@ with tab_predict:
2627
  X_one_t = transform_before_clf(pipe, X_one)
2628
 
2629
  explainer = st.session_state.get("explainer")
2630
- if explainer is None:
2631
- st.session_state.explainer = build_shap_explainer(pipe, X_inf)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2632
  explainer = st.session_state.explainer
 
2633
 
2634
  shap_vals = explainer.shap_values(X_one_t)
2635
  if isinstance(shap_vals, list):
 
628
  "selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None,
629
  "note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
630
  },
631
+ "shap_background": {
632
+ "file": "background.csv",
633
+ "max_rows": 200,
634
+ "note": "Raw (pre-transform) background sample for SHAP LinearExplainer."
635
+ },
636
+
637
  "positive_class": str(pos_class),
638
  "metrics": metrics,
639
  }
 
674
  except Exception:
675
  pass
676
 
677
def coerce_X_like_schema(X: pd.DataFrame, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame:
    """
    Return a copy of *X* restricted to ``feature_cols`` and coerced to the
    training/inference dtype convention.

    Numeric columns are forced through ``pd.to_numeric`` (unparseable values
    become NaN); categorical columns become object dtype whose entries are
    either ``np.nan`` or plain ``str``.
    """
    out = X[feature_cols].copy().replace({pd.NA: np.nan})

    for col in num_cols:
        if col in out.columns:
            out[col] = pd.to_numeric(out[col], errors="coerce")

    for col in cat_cols:
        if col not in out.columns:
            continue
        out[col] = out[col].astype("object")
        # Normalize every NA-like value to np.nan before stringifying.
        out.loc[out[col].isna(), col] = np.nan
        out[col] = out[col].map(lambda v: v if pd.isna(v) else str(v))

    return out
694
+
695
+
696
def get_shap_background_auto(model_repo_id: str, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame | None:
    """
    Fetch the published SHAP background from the HF model repo.

    Returns the schema-coerced background frame, or None when the file is
    unavailable or lacks any required feature column.
    """
    background = load_latest_background(model_repo_id)
    if background is None:
        return None

    # Reject a background that does not cover every training feature.
    if any(col not in background.columns for col in feature_cols):
        return None

    return coerce_X_like_schema(background, feature_cols, num_cols, cat_cols)
710
+
711
+
712
+
713
+ # ============================================================
714
+ # SHAP background persistence (best practice)
715
+ # ============================================================
716
+
717
def save_background_sample_csv(X_bg: pd.DataFrame, feature_cols: list[str], max_rows: int = 200, out_path: str = "background.csv"):
    """
    Persist a small *raw* (pre-transform) background dataset for the SHAP
    explainer as CSV.

    The file contains exactly ``feature_cols`` (in that order); at most
    ``max_rows`` rows are kept via a seeded random sample so the published
    background is reproducible.

    Returns ``out_path``. Raises ValueError when ``X_bg`` is None or empty.
    """
    if X_bg is None or len(X_bg) == 0:
        raise ValueError("X_bg is empty; cannot save background sample.")

    sample = X_bg[feature_cols].copy()

    limit = int(max_rows)
    if len(sample) > limit:
        # Fixed seed keeps repeated runs byte-comparable.
        sample = sample.sample(limit, random_state=42)

    # Preserve exact columns for future loading.
    sample.to_csv(out_path, index=False, encoding="utf-8")
    return out_path
733
+
734
+
735
def publish_background_to_hub(model_repo_id: str, version_tag: str, background_path: str = "background.csv"):
    """
    Upload ``background_path`` to the HF model repo, once under the versioned
    release folder and once under ``latest/``.

    Requires a write-capable HF_TOKEN in the environment; raises RuntimeError
    when it is absent.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN not found. Add it in Space Settings → Secrets.")
    api = HfApi(token=token)

    version_bg_path = f"releases/{version_tag}/background.csv"

    # Same file goes to both destinations; only path and message differ.
    uploads = [
        (version_bg_path, f"Upload SHAP background ({version_tag})"),
        ("latest/background.csv", f"Update latest SHAP background ({version_tag})"),
    ]
    for repo_path, message in uploads:
        api.upload_file(
            path_or_fileobj=background_path,
            path_in_repo=repo_path,
            repo_id=model_repo_id,
            repo_type="model",
            commit_message=message,
        )

    return {
        "version_bg_path": version_bg_path,
        "latest_bg_path": "latest/background.csv",
    }
769
+
770
+
771
def load_latest_background(model_repo_id: str) -> pd.DataFrame | None:
    """
    Download and parse ``latest/background.csv`` from the model repo.

    Best-effort: any failure (missing file, network issue, parse error)
    yields None so callers can fall back to another background source.
    """
    try:
        local_path = hf_hub_download(
            repo_id=model_repo_id,
            repo_type="model",
            filename="latest/background.csv",
        )
        return pd.read_csv(local_path)
    except Exception:
        # Deliberate swallow: a missing background is not fatal here.
        return None
785
+
786
+
787
def load_background_by_version(model_repo_id: str, version_tag: str) -> pd.DataFrame | None:
    """
    Download and parse ``releases/<version_tag>/background.csv`` from the
    model repo.

    Best-effort: returns None on any failure (missing file, network issue,
    parse error) instead of raising.
    """
    try:
        local_path = hf_hub_download(
            repo_id=model_repo_id,
            repo_type="model",
            filename=f"releases/{version_tag}/background.csv",
        )
        return pd.read_csv(local_path)
    except Exception:
        # A release without a background is an expected condition.
        return None
801
+
802
 
803
  def publish_to_hub(model_repo_id: str, version_tag: str):
804
  """
 
1407
  use_dimred=use_dimred,
1408
  svd_components=svd_components
1409
  )
1410
+ # --- Save background sample for SHAP (raw X_train) ---
1411
+ try:
1412
+ save_background_sample_csv(
1413
+ X_bg=X_train,
1414
+ feature_cols=feature_cols,
1415
+ max_rows=200,
1416
+ out_path="background.csv"
1417
+ )
1418
+ st.success("Saved SHAP background sample (background.csv).")
1419
+ except Exception as e:
1420
+ st.warning(f"Could not save SHAP background sample: {e}")
1421
+
1422
  explainer = build_shap_explainer(pipe, X_train)
1423
 
1424
  st.session_state.pipe = pipe
 
1651
  try:
1652
  with st.spinner("Uploading to Hugging Face Model repo..."):
1653
  paths = publish_to_hub(MODEL_REPO_ID, version_tag)
1654
+ # Upload background.csv if it exists
1655
+ if os.path.exists("background.csv"):
1656
+ bg_paths = publish_background_to_hub(MODEL_REPO_ID, version_tag, background_path="background.csv")
1657
+ paths.update(bg_paths)
1658
+ else:
1659
+ st.warning("background.csv not found; SHAP background will not be uploaded.")
1660
 
1661
  st.success("Uploaded successfully to your model repository.")
1662
  st.json(paths)
 
1724
  num_cols = meta["schema"]["numeric"]
1725
  cat_cols = meta["schema"]["categorical"]
1726
 
1727
+ # ------------------------------------------------------------
1728
+ # SHAP background: prefer inference file, else HF background.csv
1729
+ # ------------------------------------------------------------
1730
+ df_inf = st.session_state.get("df_inf")
1731
+
1732
+ if df_inf is not None:
1733
+ # use user cohort as background (optional)
1734
+ X_bg = coerce_X_like_schema(df_inf, feature_cols, num_cols, cat_cols)
1735
+ else:
1736
+ # fall back to published background
1737
+ X_bg = get_shap_background_auto(MODEL_REPO_ID, feature_cols, num_cols, cat_cols)
1738
+
1739
+ st.session_state.X_bg_for_shap = X_bg
1740
+
1741
+
1742
  # 2) Now we can build lookup
1743
  FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols}
1744
 
 
2592
  X_batch_t = transform_before_clf(pipe, X_batch)
2593
 
2594
  explainer = st.session_state.get("explainer")
2595
+ explainer_sig = st.session_state.get("explainer_sig")
2596
+
2597
+ # Create a simple signature that changes if model changes or background changes
2598
+ # (using version + number of background rows is usually enough)
2599
+ current_sig = (
2600
+ selected, # or meta.get("created_at_utc") or meta.get("metrics", {}).get("roc_auc")
2601
+ None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
2602
+ )
2603
+
2604
+ if explainer is None or explainer_sig != current_sig:
2605
+ X_bg = st.session_state.get("X_bg_for_shap")
2606
+ if X_bg is None:
2607
+ st.error("SHAP background not available. Admin must publish latest/background.csv.")
2608
+ st.stop()
2609
+
2610
+ st.session_state.explainer = build_shap_explainer(pipe, X_bg)
2611
+ st.session_state.explainer_sig = current_sig
2612
  explainer = st.session_state.explainer
2613
+
2614
 
2615
  shap_vals_batch = explainer.shap_values(X_batch_t)
2616
  if isinstance(shap_vals_batch, list):
 
2806
  X_one_t = transform_before_clf(pipe, X_one)
2807
 
2808
  explainer = st.session_state.get("explainer")
2809
+ explainer_sig = st.session_state.get("explainer_sig")
2810
+
2811
+ # Create a simple signature that changes if model changes or background changes
2812
+ # (using version + number of background rows is usually enough)
2813
+ current_sig = (
2814
+ selected, # or meta.get("created_at_utc") or meta.get("metrics", {}).get("roc_auc")
2815
+ None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
2816
+ )
2817
+
2818
+ if explainer is None or explainer_sig != current_sig:
2819
+ X_bg = st.session_state.get("X_bg_for_shap")
2820
+ if X_bg is None:
2821
+ st.error("SHAP background not available. Admin must publish latest/background.csv.")
2822
+ st.stop()
2823
+
2824
+ st.session_state.explainer = build_shap_explainer(pipe, X_bg)
2825
+ st.session_state.explainer_sig = current_sig
2826
  explainer = st.session_state.explainer
2827
+
2828
 
2829
  shap_vals = explainer.shap_values(X_one_t)
2830
  if isinstance(shap_vals, list):