Spaces:
Sleeping
Sleeping
| import os | |
| os.environ["MPLBACKEND"] = "Agg" | |
| import matplotlib | |
| matplotlib.use("Agg", force=True) | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import shap | |
| from pathlib import Path | |
| from pycaret.classification import load_model | |
| from huggingface_hub import hf_hub_download | |
# --- config ---
# Base name of the serialized model (not referenced by the code visible in this file).
MODEL_BASENAME = "subset_best_model"
# Local CSV that feeds the sample picker and the SHAP background set.
SAMPLES_CSV = "GTT.csv"
# Normalized name of the label column in SAMPLES_CSV.
TARGET_COL = "gtt"
# Class label treated as "positive" when locating its predict_proba column.
POS_LABEL = 1

# Hub coordinates; each value can be overridden through the environment.
REPO = os.environ.get("MODEL_REPO", "GDMProjects/my-private-model")
FNAME = os.environ.get("MODEL_FILE", "subset_best_model.pkl")
# None when HF_TOKEN is not set.
TOKEN = os.environ.get("HF_TOKEN")

# Model input columns, in the order used for every DataFrame built in this
# file. The spellings are matched verbatim against the normalized CSV header,
# so they must not be "corrected" here.
SUBSET_FEATURES = [
    "age",
    "bmi",
    "history_of_htn",
    "history_infectious_cardiovascular_diseae",
    "previos_obsteric_history_ab",
    "fbs_first_trimester",
    "hb",
    "hct",
    "cr",
    "plt",
    "vit_d3",
    "sono_nt_nt",
    "sono_nt_crl",
]
# ---------- utils ----------
def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose column labels are normalized.

    Each label is stripped of surrounding whitespace, every run of
    whitespace, ``/``, ``\\``, ``.`` or ``-`` is replaced by a single
    underscore, repeated underscores are collapsed, and the result is
    lower-cased. The input frame is left unchanged.
    """
    labels = df.columns.str.strip()
    labels = labels.str.replace(r"[\s/\\\.\-]+", "_", regex=True)
    labels = labels.str.replace(r"__+", "_", regex=True)

    renamed = df.copy()
    renamed.columns = labels.str.lower()
    return renamed
def load_samples():
    """Load the sample CSV used by the sample picker.

    Returns:
        The CSV as a DataFrame with normalized column names and an extra
        ``_rid`` column holding the original row position, or ``None`` when
        the file does not exist or lacks any of ``id``, the target column or
        the model features (a warning listing the missing columns is printed
        in that case).
    """
    csv_path = Path(SAMPLES_CSV)
    if not csv_path.exists():
        return None

    frame = normalize_cols(pd.read_csv(SAMPLES_CSV))

    required = {"id", TARGET_COL, *SUBSET_FEATURES}
    absent = required.difference(frame.columns)
    if absent:
        print(f"[WARN] samples file missing columns: {sorted(absent)}")
        return None

    # Keep the positional index as a column so rows can be addressed by "_rid".
    frame = frame.reset_index(drop=False)
    return frame.rename(columns={"index": "_rid"})
def pretty_json(d):
    """Serialize ``d`` to JSON with a 2-space indent, keeping non-ASCII characters unescaped."""
    dump_options = {"ensure_ascii": False, "indent": 2}
    return json.dumps(d, **dump_options)
def as_bool(x, default=False):
    """Interpret ``x`` as a boolean flag.

    ``None`` and float NaN yield ``default``. Booleans and ints are converted
    with ``bool``. Anything else is compared, as a trimmed lower-case string,
    against a fixed set of yes/no words; failing that it is parsed as a
    number (truncated to int), and ``default`` is returned when parsing fails.
    """
    is_missing = x is None or (isinstance(x, float) and pd.isna(x))
    if is_missing:
        return default

    # bool is a subclass of int, so one check covers both.
    if isinstance(x, int):
        return bool(x)

    token = str(x).strip().lower()
    truthy = {"1","true","t","yes","y","on","pos","positive"}
    falsy = {"0","false","f","no","n","off","neg","negative"}
    if token in truthy or token in falsy:
        return token in truthy

    try:
        number = int(float(token))
    except Exception:
        return default
    return bool(number)
def f_or_none(v):
    """Return ``float(v)``, or ``None`` when ``v`` is ``None`` or a float NaN."""
    if v is None:
        return None
    if isinstance(v, float) and pd.isna(v):
        return None
    return float(v)
def build_row_dict(
    age, bmi, ab_count,
    htn, cvd,
    fbs1, hb, hct, cr, plt, vitd3, sono_nt, sono_crl
):
    """Map the form values onto the model's feature names.

    The two history flags are stored as 1/0 according to their truthiness;
    every other value is passed through unchanged. Key insertion order is
    fixed because the dict is later echoed back as JSON.
    """
    row = {}
    row["age"] = age
    row["bmi"] = bmi
    row["previos_obsteric_history_ab"] = ab_count
    row["history_of_htn"] = int(bool(htn))
    row["history_infectious_cardiovascular_diseae"] = int(bool(cvd))
    row["fbs_first_trimester"] = fbs1
    row["hb"] = hb
    row["hct"] = hct
    row["cr"] = cr
    row["plt"] = plt
    row["vit_d3"] = vitd3
    row["sono_nt_nt"] = sono_nt
    row["sono_nt_crl"] = sono_crl
    return row
def _get_pos_index_and_classes(pipe, pos_label=1):
    """Locate the ``predict_proba`` column that corresponds to ``pos_label``.

    The estimator is taken from ``pipe.named_steps["trained_model"]`` when
    that lookup succeeds, otherwise ``pipe`` itself is used.

    Returns:
        ``(index, classes)`` where ``classes`` is the estimator's ``classes_``
        as a list (``None`` if it has no such attribute) and ``index`` is the
        position of ``pos_label`` in it, or ``-1`` (i.e. the last column) when
        the label cannot be found.
    """
    try:
        estimator = getattr(pipe, "named_steps", {}).get("trained_model", None)
    except Exception:
        estimator = None
    if estimator is None:
        estimator = pipe

    raw_classes = getattr(estimator, "classes_", None)
    if raw_classes is None:
        return -1, None

    labels = list(raw_classes)
    if pos_label in labels:
        return labels.index(pos_label), labels
    return -1, labels
# ---------- model & samples ----------
# Fetch the serialized model file from the Hub repo (TOKEN may be None).
local_path = hf_hub_download(repo_id=REPO, filename=FNAME, token=TOKEN)
# load_model receives the downloaded path with its final suffix stripped
# (presumably because PyCaret appends the extension itself — confirm against
# the PyCaret load_model documentation).
model = load_model(str(Path(local_path).with_suffix("")))
# None when the sample CSV is absent or incomplete; downstream code checks for that.
samples_df = load_samples()
# ---------- SHAP: background + explainer (built once) ----------
def _prepare_background(df_samples: pd.DataFrame | None, max_rows: int = 200) -> pd.DataFrame:
    """Build the numeric background table handed to the SHAP explainer.

    Args:
        df_samples: Sample table as returned by ``load_samples``, or ``None``
            when no samples are available.
        max_rows: Upper bound on the number of background rows; larger tables
            are down-sampled with a fixed seed.

    Returns:
        A DataFrame with exactly the ``SUBSET_FEATURES`` columns (in that
        order) and a fresh 0..n-1 index. Values are coerced to numbers and
        missing entries are replaced by the column median; a column with no
        numeric value at all stays NaN.
    """
    if df_samples is None:
        # No samples: fall back to 50 all-zero rows.
        bg = pd.DataFrame(0.0, index=range(50), columns=SUBSET_FEATURES)
    else:
        # reindex instead of df[SUBSET_FEATURES]: plain selection raises
        # KeyError on a missing feature, so the old "add the column as NaN"
        # fallback could never run. reindex creates the NaN column directly.
        bg = df_samples.reindex(columns=SUBSET_FEATURES)

    bg = bg.apply(pd.to_numeric, errors="coerce")
    bg = bg.fillna(bg.median(numeric_only=True))
    if len(bg) > max_rows:
        bg = bg.sample(max_rows, random_state=42)
    return bg.reset_index(drop=True)


BACKGROUND = _prepare_background(samples_df)
# Column of predict_proba that holds the positive class (-1, i.e. the last
# column, when POS_LABEL is not found among the estimator's classes).
POS_IDX, _ = _get_pos_index_and_classes(model, POS_LABEL)


def _f_proba_pos(X_np: np.ndarray) -> np.ndarray:
    """Prediction function given to SHAP.

    Wraps the array in a DataFrame labelled with ``SUBSET_FEATURES`` and
    returns the ``POS_IDX`` column of the model's ``predict_proba`` output.
    """
    frame = pd.DataFrame(X_np, columns=SUBSET_FEATURES)
    probabilities = model.predict_proba(frame)
    return probabilities[:, POS_IDX]
# SHAP Explainer
# Default to "no explainer"; the plotting helpers return None in that case.
EXPLAINER = None
try:
    EXPLAINER = shap.Explainer(_f_proba_pos, BACKGROUND.values)
except Exception as e:
    # Best effort: the app still starts without explanations.
    print("[WARN] SHAP explainer init failed:", e)
def _plot_local_shap(row_dict: dict):
    """Draw a horizontal bar chart of SHAP values for one input row.

    Args:
        row_dict: Feature name -> value mapping (as built by
            ``build_row_dict``); entries may be ``None``.

    Returns:
        A matplotlib ``Figure`` with one bar per feature, ordered by absolute
        SHAP value, or ``None`` when no explainer is available.
    """
    if EXPLAINER is None:
        return None

    # Build the figure without pyplot: pyplot keeps every figure it creates
    # registered until it is closed, and this function runs once per click,
    # so plt.subplots() here would accumulate figures for the process lifetime.
    from matplotlib.figure import Figure

    X = pd.DataFrame([row_dict], columns=SUBSET_FEATURES)
    # None entries would make .values an object array; coerce them to NaN so
    # the explainer gets a float matrix like the background data.
    X = X.apply(pd.to_numeric, errors="coerce")
    exp = EXPLAINER(X.to_numpy(dtype=float))
    vals = np.asarray(exp.values[0], dtype=float)
    order = np.argsort(np.abs(vals))

    fig = Figure(figsize=(7, 4.5))
    ax = fig.subplots()
    ax.barh(np.array(SUBSET_FEATURES)[order], vals[order])
    ax.axvline(0, linewidth=1)
    ax.set_title("Local SHAP values (current input)")
    ax.set_xlabel("Impact on P(class==1)")
    fig.tight_layout()
    return fig
def _plot_global_shap():
    """Draw the mean absolute SHAP value per feature over ``BACKGROUND``.

    Returns:
        A matplotlib Figure with one horizontal bar per feature, sorted by
        mean |SHAP|, or ``None`` when no explainer is available.
    """
    if EXPLAINER is None:
        return None

    explanation = EXPLAINER(BACKGROUND.values)
    importance = np.abs(explanation.values).mean(axis=0)
    ranking = np.argsort(importance)
    labels = np.array(SUBSET_FEATURES)[ranking]

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.barh(labels, importance[ranking])
    ax.set_title("Global feature importance (mean |SHAP|)")
    ax.set_xlabel("Mean |impact on P(class==1)|")
    fig.tight_layout()
    return fig


# Computed once at import time and reused for every request.
GLOBAL_FIG = _plot_global_shap()
# ---------- prediction ----------
def predict_manual(
    threshold,
    age, bmi, ab_count,
    htn, cvd,
    fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
):
    """Score one manually entered row and apply the decision threshold.

    Returns:
        A 4-tuple: the 0/1 decision, the positive-class probability rounded
        to 4 decimals, the text "Positive" or "Negative", and the input row
        rendered as JSON.
    """
    row = build_row_dict(
        age, bmi, ab_count,
        htn, cvd,
        fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
    )
    frame = pd.DataFrame([row], columns=SUBSET_FEATURES)
    p1 = float(model.predict_proba(frame)[0][POS_IDX])

    is_positive = p1 >= float(threshold)
    verdict = "Positive" if is_positive else "Negative"
    return int(is_positive), round(p1, 4), verdict, pretty_json(row)
def explain_local(
    age, bmi, ab_count,
    htn, cvd,
    fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
):
    """Return the local SHAP figure for the current form values (``None`` without an explainer)."""
    return _plot_local_shap(
        build_row_dict(
            age, bmi, ab_count,
            htn, cvd,
            fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
        )
    )
def explain_global():
    """Return the global SHAP figure that was computed once at import time (may be ``None``)."""
    return GLOBAL_FIG
def filter_sample_options(filter_target):
    """Build the dropdown choices for the sample picker.

    Args:
        filter_target: ``"0"`` or ``"1"`` to keep only rows with that target
            value; any other value (e.g. ``"All"``) keeps every row.

    Returns:
        A ``gr.update`` whose choices are ``(label, row id)`` pairs and whose
        value is the first row id, or ``None`` when there are no rows. With no
        sample table loaded the choices are empty.
    """
    if samples_df is None:
        return gr.update(choices=[], value=None)

    df = samples_df
    if filter_target in ("0", "1"):
        df = df[df[TARGET_COL] == int(filter_target)]

    opts = []
    # Iterate the two needed columns directly instead of materializing a
    # Series per row with iterrows().
    for rid, target in zip(df["_rid"], df[TARGET_COL]):
        # A blank target cell is NaN and int(NaN) raises; show "?" instead so
        # one unlabeled row cannot break the whole dropdown.
        shown = int(target) if pd.notna(target) else "?"
        opts.append((f"{int(rid)}: y={shown}", int(rid)))
    return gr.update(choices=opts, value=(opts[0][1] if opts else None))
def load_sample(rid):
    """Produce the widget updates that copy sample row ``rid`` into the form.

    Returns:
        A list of 14 ``gr.update`` objects: the 13 manual inputs (in the order
        age, bmi, abortion count, two history flags, eight lab/sonography
        values) followed by the ground-truth text box. When no sample table is
        loaded, ``rid`` is ``None`` or the row is not found, the inputs are
        left untouched and the ground-truth box is cleared.
    """

    def _leave_inputs_alone():
        return [gr.update()]*13 + [gr.update(value="")]

    if samples_df is None or rid is None:
        return _leave_inputs_alone()

    hits = samples_df.loc[samples_df["_rid"] == int(rid)]
    if hits.empty:
        return _leave_inputs_alone()
    rec = hits.iloc[0]

    def _num(column):
        return gr.update(value=f_or_none(rec.get(column)))

    def _flag(column):
        return gr.update(value=as_bool(rec.get(column)))

    updates = [_num("age"), _num("bmi")]

    # Missing abortion count falls back to 0.
    ab_raw = rec.get("previos_obsteric_history_ab")
    updates.append(gr.update(value=int(ab_raw) if pd.notna(ab_raw) else 0))

    updates.append(_flag("history_of_htn"))
    updates.append(_flag("history_infectious_cardiovascular_diseae"))

    for column in (
        "fbs_first_trimester",
        "hb",
        "hct",
        "cr",
        "plt",
        "vit_d3",
        "sono_nt_nt",
        "sono_nt_crl",
    ):
        updates.append(_num(column))

    # Ground truth is shown as text; empty when the target is missing.
    truth = rec.get(TARGET_COL)
    updates.append(gr.update(value=str(int(truth)) if pd.notna(truth) else ""))
    return updates
def compare_correctness(gt_text, decision_label):
    """Compare a ground-truth label with the predicted label.

    Args:
        gt_text: Ground-truth label as text (e.g. ``"1"`` or ``"1.0"``);
            ``None`` or ``""`` means no ground truth is available.
        decision_label: Predicted 0/1 label; may be ``None`` when no
            prediction is present.

    Returns:
        ``"✅ Correct"`` or ``"❌ Incorrect"`` when both labels can be read as
        integers, otherwise ``"—"``.
    """
    if gt_text is None or gt_text == "":
        return "—"
    try:
        truth = int(float(gt_text))
        # Parsed inside the guard: previously a None or non-numeric predicted
        # label raised here instead of yielding the "no comparison" marker.
        predicted = int(decision_label)
    except (TypeError, ValueError, OverflowError):
        return "—"
    return "✅ Correct" if truth == predicted else "❌ Incorrect"
def get_feature_importance_text():
    """Render the model's native feature importances as a Markdown list.

    The estimator is taken from ``model.named_steps["trained_model"]`` when
    that lookup succeeds, otherwise ``model`` itself is used. Its
    ``feature_importances_`` are preferred, then its flattened ``coef_``.

    Returns:
        One ``- name: value`` line per feature, sorted by absolute value in
        descending order, or "Not available for this model." when no weights
        are found or their count differs from ``SUBSET_FEATURES``.
    """
    try:
        estimator = getattr(model, "named_steps", {}).get("trained_model", None)
    except Exception:
        estimator = None
    if estimator is None:
        estimator = model

    weights = None
    if hasattr(estimator, "feature_importances_"):
        weights = list(estimator.feature_importances_)
    else:
        coef = getattr(estimator, "coef_", None)
        if coef is not None:
            weights = list(coef.reshape(-1))

    if not weights or len(weights) != len(SUBSET_FEATURES):
        return "Not available for this model."

    ranked = sorted(
        zip(SUBSET_FEATURES, weights),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    return "\n".join(f"- {k}: {v:.4f}" for k, v in ranked)


# Computed once at import time and shown as the text fallback in the UI.
GLOBAL_FI_TEXT = get_feature_importance_text()
# ---------- theme ----------
# Soft base theme with violet/slate hues, then two overrides applied via
# .set(): the dark-mode page background and the block border width.
theme = gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
).set(
    body_background_fill_dark="#0b0f19",
    block_border_width="1px"
)
# ---------- UI ----------
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the order of the components — confirm it against the deployed app.
with gr.Blocks(theme=theme, title="GTT Classifier") as demo:
    gr.Markdown("## GTT Prediction \n**Auto-preprocessing · Thresholdable**")
    with gr.Row():
        # (1) Manual input
        with gr.Column(scale=1):
            gr.Markdown("### 1) Manual input")
            age = gr.Number(label="Age (years)", value=0)
            bmi = gr.Number(label="BMI", value=0)
            ab_count = gr.Number(label="Previos Obsteric History of Abortion (count)", value=0, precision=0)
            gr.Markdown("---\n**Clinical flags**")
            htn = gr.Checkbox(label="History of Hypertension", value=False)
            cvd = gr.Checkbox(label="History of Cardiovascular disease", value=False)
            # These fields have no default value.
            with gr.Accordion("Numeric features", open=False):
                fbs1 = gr.Number(label="First trimester FBS")
                hb = gr.Number(label="First trimester HB")
                hct = gr.Number(label="First trimester HCT")
                cr = gr.Number(label="First trimester CR")
                plt_v = gr.Number(label="First trimester PLT")
                vitd3 = gr.Number(label="First trimester Vit D3")
                sono_nt = gr.Number(label="First trimester Sonographic NT (nt)")
                sono_crl = gr.Number(label="First trimester Sonographic NT (crl)")
            with gr.Row():
                threshold = gr.Slider(0.05, 0.95, value=0.50, step=0.01, label="Decision threshold for class '1'")
                reset_thr = gr.Button("↻", size="sm")
            predict_btn = gr.Button("🚀 Predict (manual)", variant="primary")
            explain_btn = gr.Button("🧠 Explain (SHAP for current input)")
        # (2) Sample picker
        with gr.Column(scale=1):
            gr.Markdown("### 2) Sample picker (from fixed file)")
            filt = gr.Dropdown(choices=["All", "0", "1"], value="All", label="Filter by target")
            sample_dd = gr.Dropdown(choices=[], value=None, label="Choose sample row")
            load_ok = gr.Button("Load sample into manual inputs", variant="secondary")
        # (3) Results
        with gr.Column(scale=1):
            gr.Markdown("### 3) Results")
            pred_label = gr.Number(label="Predicted label (with threshold decision)", interactive=False)
            with gr.Row():
                pred_prob = gr.Number(label="P(class==1)", value=0, interactive=False)
                decision_text = gr.Textbox(label="Decision @ threshold", interactive=False)
            gt_box = gr.Textbox(label="Ground truth (sample)", interactive=False)
            correctness = gr.Textbox(label="Correct vs. ground truth?", interactive=False)
            with gr.Accordion("Echoed input (row sent to model)", open=False):
                echoed = gr.Code(label="", language="json")
            with gr.Accordion("Global feature importance (SHAP)", open=False):
                global_plot = gr.Plot(value=GLOBAL_FIG)
                gr.Markdown("> Text fallback (native model importances):")
                gr.Markdown(GLOBAL_FI_TEXT)
            with gr.Accordion("Local explanation (SHAP) for current input", open=False):
                local_plot = gr.Plot()
    # events
    # Fill the sample dropdown on page load and whenever the filter changes.
    demo.load(lambda: filter_sample_options("All"), inputs=None, outputs=[sample_dd], queue=False)
    filt.change(filter_sample_options, inputs=[filt], outputs=[sample_dd])
    reset_thr.click(fn=lambda: 0.5, inputs=None, outputs=[threshold])
    # The output order must match the 14 updates returned by load_sample.
    load_ok.click(
        fn=load_sample,
        inputs=[sample_dd],
        outputs=[
            age, bmi, ab_count,
            htn, cvd,
            fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl,
            gt_box
        ],
    )
    # Predict first, then compare the predicted label with the ground-truth box.
    predict_btn.click(
        fn=predict_manual,
        inputs=[
            threshold,
            age, bmi, ab_count,
            htn, cvd,
            fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
        ],
        outputs=[pred_label, pred_prob, decision_text, echoed],
    ).then(
        fn=compare_correctness,
        inputs=[gt_box, pred_label],
        outputs=[correctness]
    )
    explain_btn.click(
        fn=explain_local,
        inputs=[age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl],
        outputs=[local_plot]
    )
if __name__ == "__main__":
    # Set both spellings of the proxy-bypass variable before launching
    # (presumably so requests to the local server skip any configured proxy).
    os.environ.update(
        {
            "NO_PROXY": "127.0.0.1,localhost",
            "no_proxy": "127.0.0.1,localhost",
        }
    )
    demo.launch()