"""Gradio app: threshold-adjustable GTT classifier with SHAP explanations.

At import time the module downloads a serialized model from the Hugging Face
Hub, optionally loads a CSV of labelled sample rows, builds a SHAP explainer
over a background set, and assembles the Gradio UI (``demo``).
"""

import os

# The backend must be pinned before pyplot is imported anywhere (including
# transitively, e.g. by shap), otherwise a GUI backend may be selected.
os.environ["MPLBACKEND"] = "Agg"

import matplotlib

matplotlib.use("Agg", force=True)

import json
import numpy as np
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import shap
from pathlib import Path
from pycaret.classification import load_model
from huggingface_hub import hf_hub_download

# --- config ---
MODEL_BASENAME = "subset_best_model"
SAMPLES_CSV = "GTT.csv"
TARGET_COL = "gtt"
POS_LABEL = 1

REPO = os.getenv("MODEL_REPO", "GDMProjects/my-private-model")
FNAME = os.getenv("MODEL_FILE", "subset_best_model.pkl")
TOKEN = os.getenv("HF_TOKEN")

# Column names (after normalize_cols) the model is fed, in model input order.
# The spelling of these keys is part of the data/model contract; do not "fix".
SUBSET_FEATURES = [
    "age",
    "bmi",
    "history_of_htn",
    "history_infectious_cardiovascular_diseae",
    "previos_obsteric_history_ab",
    "fbs_first_trimester",
    "hb",
    "hct",
    "cr",
    "plt",
    "vit_d3",
    "sono_nt_nt",
    "sono_nt_crl",
]


# ---------- utils ----------
def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with stripped, snake-cased, lower-case column names.

    Runs of whitespace, ``/``, ``\\``, ``.`` and ``-`` become a single ``_``.
    """
    out = df.copy()
    out.columns = (
        out.columns.str.strip()
        .str.replace(r"[\s/\\\.\-]+", "_", regex=True)
        .str.replace(r"__+", "_", regex=True)
        .str.lower()
    )
    return out


def load_samples():
    """Load ``SAMPLES_CSV`` as a DataFrame, or return ``None`` if unusable.

    Returns ``None`` when the file does not exist or lacks any of ``id``,
    ``TARGET_COL`` or the ``SUBSET_FEATURES`` columns (a warning is printed in
    the latter case). On success the frame gains a ``_rid`` column holding the
    original row position, which the UI uses as the sample identifier.
    """
    if not Path(SAMPLES_CSV).exists():
        return None
    df = normalize_cols(pd.read_csv(SAMPLES_CSV))
    needed = {"id", TARGET_COL, *SUBSET_FEATURES}
    missing = needed - set(df.columns)
    if missing:
        print(f"[WARN] samples file missing columns: {sorted(missing)}")
        return None
    return df.reset_index(drop=False).rename(columns={"index": "_rid"})


def pretty_json(d):
    """Serialize *d* as indented JSON, keeping non-ASCII characters as-is."""
    return json.dumps(d, ensure_ascii=False, indent=2)


def _is_missing(x) -> bool:
    """Return True for ``None`` and for scalar NaN/NA values of any type."""
    if x is None:
        return True
    try:
        return bool(pd.isna(x))
    except (TypeError, ValueError):
        # pd.isna on a non-scalar yields an array; treat that as "not missing".
        return False


def as_bool(x, default=False):
    """Interpret *x* as a boolean flag, returning *default* when undecidable.

    Accepts bools, integers, and common textual spellings ("yes", "neg", ...).
    Any other value is parsed as a number and tested for non-zero.
    """
    if _is_missing(x):
        return default
    if isinstance(x, (bool, np.bool_, int, np.integer)):
        return bool(x)
    s = str(x).strip().lower()
    yes = {"1", "true", "t", "yes", "y", "on", "pos", "positive"}
    no = {"0", "false", "f", "no", "n", "off", "neg", "negative"}
    if s in yes:
        return True
    if s in no:
        return False
    try:
        return bool(int(float(s)))
    except (TypeError, ValueError, OverflowError):
        return default


def f_or_none(v):
    """Return ``float(v)``, or ``None`` when *v* is missing or not numeric."""
    if _is_missing(v):
        return None
    try:
        return float(v)
    except (TypeError, ValueError):
        return None


def build_row_dict(
    age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt, vitd3, sono_nt, sono_crl
):
    """Map the UI input values onto the model's feature names.

    The two history flags are encoded as 0/1 integers; all other values are
    passed through unchanged (they may be ``None`` for blank inputs).

    NOTE: the ``plt`` parameter shadows the pyplot alias inside this function;
    the name is kept because it is part of the existing signature.
    """
    return {
        "age": age,
        "bmi": bmi,
        "previos_obsteric_history_ab": ab_count,
        "history_of_htn": 1 if htn else 0,
        "history_infectious_cardiovascular_diseae": 1 if cvd else 0,
        "fbs_first_trimester": fbs1,
        "hb": hb,
        "hct": hct,
        "cr": cr,
        "plt": plt,
        "vit_d3": vitd3,
        "sono_nt_nt": sono_nt,
        "sono_nt_crl": sono_crl,
    }


def _row_to_frame(row_dict: dict) -> pd.DataFrame:
    """Build a one-row frame in ``SUBSET_FEATURES`` order with numeric dtypes.

    Blank UI inputs arrive as ``None``; coercing turns them into ``NaN`` so the
    columns are numeric instead of object dtype.
    """
    frame = pd.DataFrame([row_dict], columns=SUBSET_FEATURES)
    return frame.apply(pd.to_numeric, errors="coerce")


def _final_estimator(pipe):
    """Return the ``trained_model`` step of *pipe*, or *pipe* itself if absent."""
    est = None
    steps = getattr(pipe, "named_steps", None)
    if steps is not None:
        try:
            est = steps.get("trained_model")
        except (AttributeError, KeyError, TypeError):
            est = None
    return pipe if est is None else est


def _get_pos_index_and_classes(pipe, pos_label=1):
    """Locate *pos_label* among the estimator's ``classes_``.

    Returns ``(index, classes)``. When the label (or ``classes_``) is missing
    the index is ``-1``, i.e. callers fall back to the last probability column.
    """
    classes = getattr(_final_estimator(pipe), "classes_", None)
    if classes is None:
        return -1, None
    classes = list(classes)
    if pos_label in classes:
        return classes.index(pos_label), classes
    return -1, classes


# ---------- model & samples ----------
# NOTE(review): the model file is presumably a pickle (FNAME ends in ".pkl");
# unpickling executes arbitrary code, so MODEL_REPO must be a trusted source.
local_path = hf_hub_download(repo_id=REPO, filename=FNAME, token=TOKEN)
# load_model is given the path without its suffix (presumably because it
# appends the extension itself -- confirm against the PyCaret docs).
model = load_model(str(Path(local_path).with_suffix("")))
samples_df = load_samples()


# ---------- SHAP: background + explainer (built once) ----------
def _prepare_background(df_samples: pd.DataFrame | None, max_rows: int = 200) -> pd.DataFrame:
    """Build the numeric background set used by the SHAP explainer.

    Without samples, 50 all-zero rows are used. Otherwise the feature columns
    are coerced to numeric, missing values are filled with the column median
    (0.0 if the whole column is missing), and at most *max_rows* rows are kept
    via a seeded random sample.
    """
    if df_samples is None:
        bg = pd.DataFrame(0.0, index=range(50), columns=SUBSET_FEATURES)
    else:
        # reindex adds any absent feature as an all-NaN column instead of raising.
        bg = df_samples.reindex(columns=SUBSET_FEATURES)
    bg = bg.apply(pd.to_numeric, errors="coerce")
    bg = bg.fillna(bg.median(numeric_only=True)).fillna(0.0)
    if len(bg) > max_rows:
        bg = bg.sample(max_rows, random_state=42)
    return bg.reset_index(drop=True)


BACKGROUND = _prepare_background(samples_df)
POS_IDX, _ = _get_pos_index_and_classes(model, POS_LABEL)


def _f_proba_pos(X_np: np.ndarray) -> np.ndarray:
    """Model function returning P(class==1) for SHAP.

    X_np is numpy; convert to DataFrame with right columns.
    """
    X_df = pd.DataFrame(X_np, columns=SUBSET_FEATURES)
    return model.predict_proba(X_df)[:, POS_IDX]


# SHAP Explainer
try:
    EXPLAINER = shap.Explainer(_f_proba_pos, BACKGROUND.values)
except Exception as e:  # best effort: the app still works without explanations
    print("[WARN] SHAP explainer init failed:", e)
    EXPLAINER = None


def _barh_figure(values, title: str, xlabel: str, *, zero_line: bool = False) -> Figure:
    """Draw a horizontal bar chart of per-feature *values*, sorted by magnitude.

    The figure is created without pyplot so it is not held in pyplot's global
    figure registry; repeated calls in a long-running server do not accumulate.
    """
    values = np.asarray(values, dtype=float)
    order = np.argsort(np.abs(values))
    fig = Figure(figsize=(7, 4.5))
    ax = fig.subplots()
    ax.barh(np.array(SUBSET_FEATURES)[order], values[order])
    if zero_line:
        ax.axvline(0, linewidth=1)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    fig.tight_layout()
    return fig


def _plot_local_shap(row_dict: dict):
    """Returns a matplotlib Figure with local SHAP bar chart for the given row.

    Returns ``None`` when the explainer could not be built.
    """
    if EXPLAINER is None:
        return None
    X = _row_to_frame(row_dict)
    exp = EXPLAINER(X.to_numpy(dtype=float))
    return _barh_figure(
        exp.values[0],
        "Local SHAP values (current input)",
        "Impact on P(class==1)",
        zero_line=True,
    )


def _plot_global_shap():
    """Returns a matplotlib Figure with global mean(|SHAP|) bar chart over BACKGROUND.

    Returns ``None`` when the explainer could not be built.
    """
    if EXPLAINER is None:
        return None
    exp = EXPLAINER(BACKGROUND.values)
    mean_abs = np.mean(np.abs(exp.values), axis=0)
    return _barh_figure(
        mean_abs,
        "Global feature importance (mean |SHAP|)",
        "Mean |impact on P(class==1)|",
    )


try:
    GLOBAL_FIG = _plot_global_shap()
except Exception as e:  # same best-effort policy as the explainer init above
    print("[WARN] global SHAP computation failed:", e)
    GLOBAL_FIG = None


# ---------- prediction ----------
def predict_manual(
    threshold, age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
):
    """Score the manual inputs and apply the decision threshold.

    Returns ``(label, probability, decision text, echoed row as JSON)`` where
    ``label`` is 1 when the positive-class probability is >= *threshold*.
    """
    row = build_row_dict(
        age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
    )
    proba = model.predict_proba(_row_to_frame(row))
    p1 = float(proba[0][POS_IDX])
    decision = 1 if p1 >= float(threshold) else 0
    decision_text = "Positive" if decision == 1 else "Negative"
    return int(decision), round(p1, 4), decision_text, pretty_json(row)


def explain_local(
    age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
):
    """Return the local SHAP figure for the current manual inputs."""
    row = build_row_dict(
        age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
    )
    return _plot_local_shap(row)


def explain_global():
    """Return the global SHAP figure computed once at start-up."""
    return GLOBAL_FIG


def filter_sample_options(filter_target):
    """Build the sample dropdown choices, optionally filtered by target class.

    *filter_target* of ``"0"`` or ``"1"`` restricts rows to that target value;
    anything else (e.g. ``"All"``) keeps every row. The first choice, if any,
    is preselected.
    """
    if samples_df is None:
        return gr.update(choices=[], value=None)
    df = samples_df
    if filter_target in ("0", "1"):
        df = df[df[TARGET_COL] == int(filter_target)]
    opts = []
    for rid, y in zip(df["_rid"], df[TARGET_COL]):
        # Rows without a label are still selectable; show "?" rather than crash.
        y_txt = str(int(y)) if pd.notna(y) else "?"
        opts.append((f"{int(rid)}: y={y_txt}", int(rid)))
    return gr.update(choices=opts, value=(opts[0][1] if opts else None))


def load_sample(rid):
    """Return updates that copy sample row *rid* into the manual inputs.

    The result has 14 entries: the 13 input components in UI order followed by
    the ground-truth textbox. If the row cannot be found, the inputs are left
    untouched and the ground truth is cleared.
    """
    unchanged = [gr.update()] * 13 + [gr.update(value="")]
    if samples_df is None or rid is None:
        return unchanged
    match = samples_df.loc[samples_df["_rid"] == int(rid)]
    if match.empty:
        return unchanged
    r = match.iloc[0]

    ab_raw = r.get("previos_obsteric_history_ab", 0)
    ab_value = int(ab_raw) if pd.notna(ab_raw) else 0
    target_raw = r.get(TARGET_COL)
    target_text = str(int(target_raw)) if pd.notna(target_raw) else ""

    return [
        gr.update(value=f_or_none(r.get("age"))),
        gr.update(value=f_or_none(r.get("bmi"))),
        gr.update(value=ab_value),
        gr.update(value=as_bool(r.get("history_of_htn"))),
        gr.update(value=as_bool(r.get("history_infectious_cardiovascular_diseae"))),
        gr.update(value=f_or_none(r.get("fbs_first_trimester"))),
        gr.update(value=f_or_none(r.get("hb"))),
        gr.update(value=f_or_none(r.get("hct"))),
        gr.update(value=f_or_none(r.get("cr"))),
        gr.update(value=f_or_none(r.get("plt"))),
        gr.update(value=f_or_none(r.get("vit_d3"))),
        gr.update(value=f_or_none(r.get("sono_nt_nt"))),
        gr.update(value=f_or_none(r.get("sono_nt_crl"))),
        gr.update(value=target_text),
    ]


def compare_correctness(gt_text, decision_label):
    """Compare the predicted label with the sample's ground truth.

    Returns "—" when either side is missing or not a number. This runs as a
    follow-up step after prediction, so *decision_label* may be ``None`` when
    no prediction has been produced.
    """
    if gt_text is None or gt_text == "" or decision_label is None:
        return "—"
    try:
        gt = int(float(gt_text))
        pred = int(float(decision_label))
    except (TypeError, ValueError, OverflowError):
        return "—"
    return "✅ Correct" if gt == pred else "❌ Incorrect"


def get_feature_importance_text():
    """Format the estimator's native importances as a Markdown bullet list.

    Uses ``feature_importances_`` if present, otherwise flattened ``coef_``.
    Returns a fallback message when neither exists or when the number of
    values does not match ``SUBSET_FEATURES``.
    """
    est = _final_estimator(model)
    fi = None
    if hasattr(est, "feature_importances_"):
        fi = list(est.feature_importances_)
    elif hasattr(est, "coef_"):
        coef = est.coef_
        if coef is not None:
            fi = list(coef.reshape(-1))
    if not fi or len(fi) != len(SUBSET_FEATURES):
        return "Not available for this model."
    pairs = sorted(zip(SUBSET_FEATURES, fi), key=lambda kv: abs(kv[1]), reverse=True)
    return "\n".join(f"- {k}: {v:.4f}" for k, v in pairs)


GLOBAL_FI_TEXT = get_feature_importance_text()

# ---------- theme ----------
theme = gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
).set(
    body_background_fill_dark="#0b0f19",
    block_border_width="1px",
)

# ---------- UI ----------
with gr.Blocks(theme=theme, title="GTT Classifier") as demo:
    gr.Markdown("## GTT Prediction \n**Auto-preprocessing · Thresholdable**")

    with gr.Row():
        # (1) Manual input
        with gr.Column(scale=1):
            gr.Markdown("### 1) Manual input")
            age = gr.Number(label="Age (years)", value=0)
            bmi = gr.Number(label="BMI", value=0)
            ab_count = gr.Number(
                label="Previos Obsteric History of Abortion (count)", value=0, precision=0
            )
            gr.Markdown("---\n**Clinical flags**")
            htn = gr.Checkbox(label="History of Hypertension", value=False)
            cvd = gr.Checkbox(label="History of Cardiovascular disease", value=False)
            with gr.Accordion("Numeric features", open=False):
                fbs1 = gr.Number(label="First trimester FBS")
                hb = gr.Number(label="First trimester HB")
                hct = gr.Number(label="First trimester HCT")
                cr = gr.Number(label="First trimester CR")
                plt_v = gr.Number(label="First trimester PLT")
                vitd3 = gr.Number(label="First trimester Vit D3")
                sono_nt = gr.Number(label="First trimester Sonographic NT (nt)")
                sono_crl = gr.Number(label="First trimester Sonographic NT (crl)")
            with gr.Row():
                threshold = gr.Slider(
                    0.05, 0.95, value=0.50, step=0.01, label="Decision threshold for class '1'"
                )
                reset_thr = gr.Button("↻", size="sm")
            predict_btn = gr.Button("🚀 Predict (manual)", variant="primary")
            explain_btn = gr.Button("🧠 Explain (SHAP for current input)")

        # (2) Sample picker
        with gr.Column(scale=1):
            gr.Markdown("### 2) Sample picker (from fixed file)")
            filt = gr.Dropdown(choices=["All", "0", "1"], value="All", label="Filter by target")
            sample_dd = gr.Dropdown(choices=[], value=None, label="Choose sample row")
            load_ok = gr.Button("Load sample into manual inputs", variant="secondary")

        # (3) Results
        with gr.Column(scale=1):
            gr.Markdown("### 3) Results")
            pred_label = gr.Number(
                label="Predicted label (with threshold decision)", interactive=False
            )
            with gr.Row():
                pred_prob = gr.Number(label="P(class==1)", value=0, interactive=False)
                decision_text = gr.Textbox(label="Decision @ threshold", interactive=False)
            gt_box = gr.Textbox(label="Ground truth (sample)", interactive=False)
            correctness = gr.Textbox(label="Correct vs. ground truth?", interactive=False)
            with gr.Accordion("Echoed input (row sent to model)", open=False):
                echoed = gr.Code(label="", language="json")
            with gr.Accordion("Global feature importance (SHAP)", open=False):
                global_plot = gr.Plot(value=GLOBAL_FIG)
                gr.Markdown("> Text fallback (native model importances):")
                gr.Markdown(GLOBAL_FI_TEXT)
            with gr.Accordion("Local explanation (SHAP) for current input", open=False):
                local_plot = gr.Plot()

    # events
    # Order of this list must match the 14 updates returned by load_sample.
    manual_inputs = [
        age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
    ]

    demo.load(lambda: filter_sample_options("All"), inputs=None, outputs=[sample_dd], queue=False)
    filt.change(filter_sample_options, inputs=[filt], outputs=[sample_dd])
    reset_thr.click(fn=lambda: 0.5, inputs=None, outputs=[threshold])

    load_ok.click(
        fn=load_sample,
        inputs=[sample_dd],
        outputs=manual_inputs + [gt_box],
    )

    predict_btn.click(
        fn=predict_manual,
        inputs=[threshold] + manual_inputs,
        outputs=[pred_label, pred_prob, decision_text, echoed],
    ).then(
        fn=compare_correctness,
        inputs=[gt_box, pred_label],
        outputs=[correctness],
    )

    explain_btn.click(
        fn=explain_local,
        inputs=manual_inputs,
        outputs=[local_plot],
    )

if __name__ == "__main__":
    # Keep loopback traffic off any configured HTTP proxy.
    os.environ["NO_PROXY"] = "127.0.0.1,localhost"
    os.environ["no_proxy"] = "127.0.0.1,localhost"
    demo.launch()