Update app.py
Browse files
app.py
CHANGED
|
@@ -628,6 +628,12 @@ def train_and_save(
|
|
| 628 |
"selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None,
|
| 629 |
"note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
|
| 630 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
"positive_class": str(pos_class),
|
| 632 |
"metrics": metrics,
|
| 633 |
}
|
|
@@ -668,6 +674,131 @@ def ensure_model_repo_exists(model_repo_id: str, token: str):
|
|
| 668 |
except Exception:
|
| 669 |
pass
|
| 670 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 671 |
|
| 672 |
def publish_to_hub(model_repo_id: str, version_tag: str):
|
| 673 |
"""
|
|
@@ -1276,7 +1407,18 @@ with tab_train:
|
|
| 1276 |
use_dimred=use_dimred,
|
| 1277 |
svd_components=svd_components
|
| 1278 |
)
|
| 1279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1280 |
explainer = build_shap_explainer(pipe, X_train)
|
| 1281 |
|
| 1282 |
st.session_state.pipe = pipe
|
|
@@ -1509,6 +1651,12 @@ with tab_train:
|
|
| 1509 |
try:
|
| 1510 |
with st.spinner("Uploading to Hugging Face Model repo..."):
|
| 1511 |
paths = publish_to_hub(MODEL_REPO_ID, version_tag)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1512 |
|
| 1513 |
st.success("Uploaded successfully to your model repository.")
|
| 1514 |
st.json(paths)
|
|
@@ -1576,6 +1724,21 @@ with tab_predict:
|
|
| 1576 |
num_cols = meta["schema"]["numeric"]
|
| 1577 |
cat_cols = meta["schema"]["categorical"]
|
| 1578 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1579 |
# 2) Now we can build lookup
|
| 1580 |
FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols}
|
| 1581 |
|
|
@@ -2429,9 +2592,25 @@ with tab_predict:
|
|
| 2429 |
X_batch_t = transform_before_clf(pipe, X_batch)
|
| 2430 |
|
| 2431 |
explainer = st.session_state.get("explainer")
|
| 2432 |
-
|
| 2433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2434 |
explainer = st.session_state.explainer
|
|
|
|
| 2435 |
|
| 2436 |
shap_vals_batch = explainer.shap_values(X_batch_t)
|
| 2437 |
if isinstance(shap_vals_batch, list):
|
|
@@ -2627,9 +2806,25 @@ with tab_predict:
|
|
| 2627 |
X_one_t = transform_before_clf(pipe, X_one)
|
| 2628 |
|
| 2629 |
explainer = st.session_state.get("explainer")
|
| 2630 |
-
|
| 2631 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2632 |
explainer = st.session_state.explainer
|
|
|
|
| 2633 |
|
| 2634 |
shap_vals = explainer.shap_values(X_one_t)
|
| 2635 |
if isinstance(shap_vals, list):
|
|
|
|
| 628 |
"selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None,
|
| 629 |
"note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
|
| 630 |
},
|
| 631 |
+
"shap_background": {
|
| 632 |
+
"file": "background.csv",
|
| 633 |
+
"max_rows": 200,
|
| 634 |
+
"note": "Raw (pre-transform) background sample for SHAP LinearExplainer."
|
| 635 |
+
},
|
| 636 |
+
|
| 637 |
"positive_class": str(pos_class),
|
| 638 |
"metrics": metrics,
|
| 639 |
}
|
|
|
|
| 674 |
except Exception:
|
| 675 |
pass
|
| 676 |
|
| 677 |
+
def coerce_X_like_schema(X: pd.DataFrame, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame:
|
| 678 |
+
"""
|
| 679 |
+
Ensure X has correct columns and coercions, matching your training/inference convention.
|
| 680 |
+
"""
|
| 681 |
+
X = X[feature_cols].copy().replace({pd.NA: np.nan})
|
| 682 |
+
|
| 683 |
+
for c in num_cols:
|
| 684 |
+
if c in X.columns:
|
| 685 |
+
X[c] = pd.to_numeric(X[c], errors="coerce")
|
| 686 |
+
|
| 687 |
+
for c in cat_cols:
|
| 688 |
+
if c in X.columns:
|
| 689 |
+
X[c] = X[c].astype("object")
|
| 690 |
+
X.loc[X[c].isna(), c] = np.nan
|
| 691 |
+
X[c] = X[c].map(lambda v: v if pd.isna(v) else str(v))
|
| 692 |
+
|
| 693 |
+
return X
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def get_shap_background_auto(model_repo_id: str, feature_cols: list[str], num_cols: list[str], cat_cols: list[str]) -> pd.DataFrame | None:
|
| 697 |
+
"""
|
| 698 |
+
Attempts to load SHAP background from HF repo. Returns coerced background or None.
|
| 699 |
+
"""
|
| 700 |
+
df_bg = load_latest_background(model_repo_id)
|
| 701 |
+
if df_bg is None:
|
| 702 |
+
return None
|
| 703 |
+
|
| 704 |
+
# Ensure required columns exist
|
| 705 |
+
missing = [c for c in feature_cols if c not in df_bg.columns]
|
| 706 |
+
if missing:
|
| 707 |
+
return None
|
| 708 |
+
|
| 709 |
+
return coerce_X_like_schema(df_bg, feature_cols, num_cols, cat_cols)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
# ============================================================
|
| 714 |
+
# SHAP background persistence (best practice)
|
| 715 |
+
# ============================================================
|
| 716 |
+
|
| 717 |
+
def save_background_sample_csv(X_bg: pd.DataFrame, feature_cols: list[str], max_rows: int = 200, out_path: str = "background.csv"):
|
| 718 |
+
"""
|
| 719 |
+
Saves a small *raw* background dataset (pre-transform) for SHAP explainer.
|
| 720 |
+
Must contain columns exactly matching feature_cols.
|
| 721 |
+
"""
|
| 722 |
+
if X_bg is None or len(X_bg) == 0:
|
| 723 |
+
raise ValueError("X_bg is empty; cannot save background sample.")
|
| 724 |
+
|
| 725 |
+
X_bg = X_bg[feature_cols].copy()
|
| 726 |
+
|
| 727 |
+
if len(X_bg) > int(max_rows):
|
| 728 |
+
X_bg = X_bg.sample(int(max_rows), random_state=42)
|
| 729 |
+
|
| 730 |
+
# Preserve exact columns for future loading
|
| 731 |
+
X_bg.to_csv(out_path, index=False, encoding="utf-8")
|
| 732 |
+
return out_path
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def publish_background_to_hub(model_repo_id: str, version_tag: str, background_path: str = "background.csv"):
|
| 736 |
+
"""
|
| 737 |
+
Uploads background.csv to both versioned and latest paths.
|
| 738 |
+
Requires HF_TOKEN with write permissions.
|
| 739 |
+
"""
|
| 740 |
+
token = os.environ.get("HF_TOKEN")
|
| 741 |
+
if not token:
|
| 742 |
+
raise RuntimeError("HF_TOKEN not found. Add it in Space Settings → Secrets.")
|
| 743 |
+
api = HfApi(token=token)
|
| 744 |
+
|
| 745 |
+
version_bg_path = f"releases/{version_tag}/background.csv"
|
| 746 |
+
|
| 747 |
+
# Versioned
|
| 748 |
+
api.upload_file(
|
| 749 |
+
path_or_fileobj=background_path,
|
| 750 |
+
path_in_repo=version_bg_path,
|
| 751 |
+
repo_id=model_repo_id,
|
| 752 |
+
repo_type="model",
|
| 753 |
+
commit_message=f"Upload SHAP background ({version_tag})"
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
# Latest
|
| 757 |
+
api.upload_file(
|
| 758 |
+
path_or_fileobj=background_path,
|
| 759 |
+
path_in_repo="latest/background.csv",
|
| 760 |
+
repo_id=model_repo_id,
|
| 761 |
+
repo_type="model",
|
| 762 |
+
commit_message=f"Update latest SHAP background ({version_tag})"
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
return {
|
| 766 |
+
"version_bg_path": version_bg_path,
|
| 767 |
+
"latest_bg_path": "latest/background.csv",
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
def load_latest_background(model_repo_id: str) -> pd.DataFrame | None:
|
| 772 |
+
"""
|
| 773 |
+
Loads latest/background.csv if present. Returns None if not found / cannot load.
|
| 774 |
+
"""
|
| 775 |
+
try:
|
| 776 |
+
bg_file = hf_hub_download(
|
| 777 |
+
repo_id=model_repo_id,
|
| 778 |
+
repo_type="model",
|
| 779 |
+
filename="latest/background.csv",
|
| 780 |
+
)
|
| 781 |
+
df_bg = pd.read_csv(bg_file)
|
| 782 |
+
return df_bg
|
| 783 |
+
except Exception:
|
| 784 |
+
return None
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def load_background_by_version(model_repo_id: str, version_tag: str) -> pd.DataFrame | None:
|
| 788 |
+
"""
|
| 789 |
+
Loads releases/<version>/background.csv if present.
|
| 790 |
+
"""
|
| 791 |
+
try:
|
| 792 |
+
bg_file = hf_hub_download(
|
| 793 |
+
repo_id=model_repo_id,
|
| 794 |
+
repo_type="model",
|
| 795 |
+
filename=f"releases/{version_tag}/background.csv",
|
| 796 |
+
)
|
| 797 |
+
df_bg = pd.read_csv(bg_file)
|
| 798 |
+
return df_bg
|
| 799 |
+
except Exception:
|
| 800 |
+
return None
|
| 801 |
+
|
| 802 |
|
| 803 |
def publish_to_hub(model_repo_id: str, version_tag: str):
|
| 804 |
"""
|
|
|
|
| 1407 |
use_dimred=use_dimred,
|
| 1408 |
svd_components=svd_components
|
| 1409 |
)
|
| 1410 |
+
# --- Save background sample for SHAP (raw X_train) ---
|
| 1411 |
+
try:
|
| 1412 |
+
save_background_sample_csv(
|
| 1413 |
+
X_bg=X_train,
|
| 1414 |
+
feature_cols=feature_cols,
|
| 1415 |
+
max_rows=200,
|
| 1416 |
+
out_path="background.csv"
|
| 1417 |
+
)
|
| 1418 |
+
st.success("Saved SHAP background sample (background.csv).")
|
| 1419 |
+
except Exception as e:
|
| 1420 |
+
st.warning(f"Could not save SHAP background sample: {e}")
|
| 1421 |
+
|
| 1422 |
explainer = build_shap_explainer(pipe, X_train)
|
| 1423 |
|
| 1424 |
st.session_state.pipe = pipe
|
|
|
|
| 1651 |
try:
|
| 1652 |
with st.spinner("Uploading to Hugging Face Model repo..."):
|
| 1653 |
paths = publish_to_hub(MODEL_REPO_ID, version_tag)
|
| 1654 |
+
# Upload background.csv if it exists
|
| 1655 |
+
if os.path.exists("background.csv"):
|
| 1656 |
+
bg_paths = publish_background_to_hub(MODEL_REPO_ID, version_tag, background_path="background.csv")
|
| 1657 |
+
paths.update(bg_paths)
|
| 1658 |
+
else:
|
| 1659 |
+
st.warning("background.csv not found; SHAP background will not be uploaded.")
|
| 1660 |
|
| 1661 |
st.success("Uploaded successfully to your model repository.")
|
| 1662 |
st.json(paths)
|
|
|
|
| 1724 |
num_cols = meta["schema"]["numeric"]
|
| 1725 |
cat_cols = meta["schema"]["categorical"]
|
| 1726 |
|
| 1727 |
+
# ------------------------------------------------------------
|
| 1728 |
+
# SHAP background: prefer inference file, else HF background.csv
|
| 1729 |
+
# ------------------------------------------------------------
|
| 1730 |
+
df_inf = st.session_state.get("df_inf")
|
| 1731 |
+
|
| 1732 |
+
if df_inf is not None:
|
| 1733 |
+
# use user cohort as background (optional)
|
| 1734 |
+
X_bg = coerce_X_like_schema(df_inf, feature_cols, num_cols, cat_cols)
|
| 1735 |
+
else:
|
| 1736 |
+
# fall back to published background
|
| 1737 |
+
X_bg = get_shap_background_auto(MODEL_REPO_ID, feature_cols, num_cols, cat_cols)
|
| 1738 |
+
|
| 1739 |
+
st.session_state.X_bg_for_shap = X_bg
|
| 1740 |
+
|
| 1741 |
+
|
| 1742 |
# 2) Now we can build lookup
|
| 1743 |
FEATURE_LOOKUP = {norm_col(c): c for c in feature_cols}
|
| 1744 |
|
|
|
|
| 2592 |
X_batch_t = transform_before_clf(pipe, X_batch)
|
| 2593 |
|
| 2594 |
explainer = st.session_state.get("explainer")
|
| 2595 |
+
explainer_sig = st.session_state.get("explainer_sig")
|
| 2596 |
+
|
| 2597 |
+
# Create a simple signature that changes if model changes or background changes
|
| 2598 |
+
# (using version + number of background rows is usually enough)
|
| 2599 |
+
current_sig = (
|
| 2600 |
+
selected, # or meta.get("created_at_utc") or meta.get("metrics", {}).get("roc_auc")
|
| 2601 |
+
None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
|
| 2602 |
+
)
|
| 2603 |
+
|
| 2604 |
+
if explainer is None or explainer_sig != current_sig:
|
| 2605 |
+
X_bg = st.session_state.get("X_bg_for_shap")
|
| 2606 |
+
if X_bg is None:
|
| 2607 |
+
st.error("SHAP background not available. Admin must publish latest/background.csv.")
|
| 2608 |
+
st.stop()
|
| 2609 |
+
|
| 2610 |
+
st.session_state.explainer = build_shap_explainer(pipe, X_bg)
|
| 2611 |
+
st.session_state.explainer_sig = current_sig
|
| 2612 |
explainer = st.session_state.explainer
|
| 2613 |
+
|
| 2614 |
|
| 2615 |
shap_vals_batch = explainer.shap_values(X_batch_t)
|
| 2616 |
if isinstance(shap_vals_batch, list):
|
|
|
|
| 2806 |
X_one_t = transform_before_clf(pipe, X_one)
|
| 2807 |
|
| 2808 |
explainer = st.session_state.get("explainer")
|
| 2809 |
+
explainer_sig = st.session_state.get("explainer_sig")
|
| 2810 |
+
|
| 2811 |
+
# Create a simple signature that changes if model changes or background changes
|
| 2812 |
+
# (using version + number of background rows is usually enough)
|
| 2813 |
+
current_sig = (
|
| 2814 |
+
selected, # or meta.get("created_at_utc") or meta.get("metrics", {}).get("roc_auc")
|
| 2815 |
+
None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
|
| 2816 |
+
)
|
| 2817 |
+
|
| 2818 |
+
if explainer is None or explainer_sig != current_sig:
|
| 2819 |
+
X_bg = st.session_state.get("X_bg_for_shap")
|
| 2820 |
+
if X_bg is None:
|
| 2821 |
+
st.error("SHAP background not available. Admin must publish latest/background.csv.")
|
| 2822 |
+
st.stop()
|
| 2823 |
+
|
| 2824 |
+
st.session_state.explainer = build_shap_explainer(pipe, X_bg)
|
| 2825 |
+
st.session_state.explainer_sig = current_sig
|
| 2826 |
explainer = st.session_state.explainer
|
| 2827 |
+
|
| 2828 |
|
| 2829 |
shap_vals = explainer.shap_values(X_one_t)
|
| 2830 |
if isinstance(shap_vals, list):
|