# GTT / app.py — Gradio demo for the GTT classifier
# (Hugging Face Space "GDMProjects", commit afdeb02; file-viewer header converted to comments)
import os
os.environ["MPLBACKEND"] = "Agg"
import matplotlib
matplotlib.use("Agg", force=True)
import json
import numpy as np
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import shap
from pathlib import Path
from pycaret.classification import load_model
from huggingface_hub import hf_hub_download
# --- config ---
# Basename (no extension) of the PyCaret model artifact; load_model appends ".pkl".
MODEL_BASENAME = "subset_best_model"
# Local CSV with labelled sample rows used by the UI's sample picker.
SAMPLES_CSV = "GTT.csv"
# Binary target column name inside SAMPLES_CSV (after column normalization).
TARGET_COL = "gtt"
# Class value treated as "positive" when reading predict_proba output.
POS_LABEL = 1
# Hugging Face Hub location of the model file; overridable via environment.
REPO = os.getenv("MODEL_REPO", "GDMProjects/my-private-model")
FNAME = os.getenv("MODEL_FILE", "subset_best_model.pkl")
# Hub access token for private repos (may be None for public ones).
TOKEN = os.getenv("HF_TOKEN")
# Feature columns (normalized names) the model expects, in training order.
# NOTE(review): the "previos_obsteric" / "diseae" spellings match the training
# data schema — do not "fix" them without retraining.
SUBSET_FEATURES = [
    "age",
    "bmi",
    "history_of_htn",
    "history_infectious_cardiovascular_diseae",
    "previos_obsteric_history_ab",
    "fbs_first_trimester",
    "hb",
    "hct",
    "cr",
    "plt",
    "vit_d3",
    "sono_nt_nt",
    "sono_nt_crl",
]
# ---------- utils ----------
def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with normalized column names.

    Whitespace is trimmed, runs of spaces/slashes/backslashes/dots/hyphens
    collapse to a single underscore, repeated underscores are squashed, and
    names are lower-cased — matching the SUBSET_FEATURES naming convention.
    """
    normalized = df.copy()
    cols = normalized.columns.str.strip()
    cols = cols.str.replace(r"[\s/\\\.\-]+", "_", regex=True)
    cols = cols.str.replace(r"__+", "_", regex=True)
    normalized.columns = cols.str.lower()
    return normalized
def load_samples():
    """Load the fixed sample CSV, or None if it is absent or malformed.

    The returned frame has normalized column names plus a "_rid" column
    holding the original row index, which the sample-picker dropdown uses
    as a stable row identifier.
    """
    path = Path(SAMPLES_CSV)
    if not path.exists():
        return None
    frame = normalize_cols(pd.read_csv(path))
    required = {"id", TARGET_COL, *SUBSET_FEATURES}
    absent = required - set(frame.columns)
    if absent:
        print(f"[WARN] samples file missing columns: {sorted(absent)}")
        return None
    return frame.reset_index(drop=False).rename(columns={"index": "_rid"})
def pretty_json(d):
    """Serialize *d* as human-readable JSON (2-space indent, unicode kept)."""
    return json.dumps(d, indent=2, ensure_ascii=False)
def as_bool(x, default=False):
    """Coerce *x* to bool, tolerating numeric and common yes/no spellings.

    None and float NaN map to *default*; bools and ints pass through; strings
    are matched case-insensitively against truthy/falsy word sets and then
    parsed numerically. Anything unparseable falls back to *default*.
    """
    if x is None or (isinstance(x, float) and pd.isna(x)):
        return default
    if isinstance(x, bool):
        return x
    if isinstance(x, int):
        return bool(x)
    token = str(x).strip().lower()
    truthy = {"1", "true", "t", "yes", "y", "on", "pos", "positive"}
    falsy = {"0", "false", "f", "no", "n", "off", "neg", "negative"}
    if token in truthy:
        return True
    if token in falsy:
        return False
    try:
        return bool(int(float(token)))
    except Exception:
        return default
def f_or_none(v):
    """Cast *v* to float, mapping None and float NaN to None."""
    if v is None:
        return None
    if isinstance(v, float) and pd.isna(v):
        return None
    return float(v)
def build_row_dict(
    age, bmi, ab_count,
    htn, cvd,
    fbs1, hb, hct, cr, plt, vitd3, sono_nt, sono_crl
):
    """Assemble one model-input row keyed by the SUBSET_FEATURES names.

    Boolean clinical flags are encoded as 0/1 ints; all other values pass
    through untouched. Key order mirrors SUBSET_FEATURES insertion order.
    """
    row = {
        "age": age,
        "bmi": bmi,
        "previos_obsteric_history_ab": ab_count,
        "history_of_htn": int(bool(htn)),
        "history_infectious_cardiovascular_diseae": int(bool(cvd)),
        "fbs_first_trimester": fbs1,
        "hb": hb,
        "hct": hct,
        "cr": cr,
        "plt": plt,
        "vit_d3": vitd3,
        "sono_nt_nt": sono_nt,
        "sono_nt_crl": sono_crl,
    }
    return row
def _get_pos_index_and_classes(pipe, pos_label=1):
est = None
try:
est = getattr(pipe, "named_steps", {}).get("trained_model", None)
except Exception:
est = None
if est is None:
est = pipe
classes = getattr(est, "classes_", None)
if classes is not None and pos_label in list(classes):
return list(classes).index(pos_label), list(classes)
return -1, list(classes) if classes is not None else None
# ---------- model & samples ----------
# Download the pickled pipeline from the Hub (cached locally by hf_hub_download).
local_path = hf_hub_download(repo_id=REPO, filename=FNAME, token=TOKEN)
# PyCaret's load_model expects the path WITHOUT the ".pkl" suffix, hence with_suffix("").
model = load_model(str(Path(local_path).with_suffix("")))
# May be None when GTT.csv is absent or lacks the required columns.
samples_df = load_samples()
# ---------- SHAP: background + explainer (built once) ----------
def _prepare_background(df_samples: pd.DataFrame | None, max_rows: int = 200) -> pd.DataFrame:
    """Build the numeric background frame SHAP uses as reference data.

    With no samples available, falls back to 50 all-zero rows. Otherwise the
    feature columns are coerced to numeric, NaNs are imputed with the column
    median, and the frame is capped at *max_rows* rows (deterministic
    sample, ``random_state=42``). The index is always reset.
    """
    if df_samples is None:
        background = pd.DataFrame([{f: 0.0 for f in SUBSET_FEATURES} for _ in range(50)])
    else:
        background = df_samples[SUBSET_FEATURES].copy()
    # Defensive: guarantee every expected feature column exists.
    for col in SUBSET_FEATURES:
        if col not in background.columns:
            background[col] = np.nan
    background = background.apply(pd.to_numeric, errors="coerce")
    background = background.fillna(background.median(numeric_only=True))
    if len(background) > max_rows:
        background = background.sample(max_rows, random_state=42)
    return background.reset_index(drop=True)
# Reference data for the SHAP explainer, built once at startup.
BACKGROUND = _prepare_background(samples_df)
# Column index of the positive class in predict_proba output (-1 = last-column fallback).
POS_IDX, _ = _get_pos_index_and_classes(model, POS_LABEL)
def _f_proba_pos(X_np: np.ndarray) -> np.ndarray:
    """Return P(class==1) per row of *X_np* — the model function SHAP calls.

    SHAP hands over a bare numpy array, so rebuild a DataFrame with the
    training column names before calling the pipeline.
    """
    frame = pd.DataFrame(data=X_np, columns=SUBSET_FEATURES)
    probabilities = model.predict_proba(frame)
    return probabilities[:, POS_IDX]
# SHAP Explainer
# Built once at import time against the background data; kept as None when
# construction fails so the UI degrades gracefully (explanation plots disabled).
try:
    EXPLAINER = shap.Explainer(_f_proba_pos, BACKGROUND.values)
except Exception as e:
    print("[WARN] SHAP explainer init failed:", e)
    EXPLAINER = None
def _plot_local_shap(row_dict: dict):
    """Render a horizontal bar chart of local SHAP values for one input row.

    Features are ordered by absolute impact (largest at the top). Returns a
    matplotlib Figure, or None when the explainer is unavailable.
    """
    if EXPLAINER is None:
        return None
    features = pd.DataFrame([row_dict], columns=SUBSET_FEATURES)
    explanation = EXPLAINER(features.values)
    shap_vals = explanation.values[0]
    ranking = np.argsort(np.abs(shap_vals))
    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.barh(np.array(SUBSET_FEATURES)[ranking], shap_vals[ranking])
    ax.axvline(0, linewidth=1)  # zero line separates risk-raising vs. risk-lowering features
    ax.set_title("Local SHAP values (current input)")
    ax.set_xlabel("Impact on P(class==1)")
    fig.tight_layout()
    return fig
def _plot_global_shap():
    """Render global feature importance as mean(|SHAP|) over BACKGROUND.

    Returns a matplotlib Figure, or None when the explainer is unavailable.
    """
    if EXPLAINER is None:
        return None
    explanation = EXPLAINER(BACKGROUND.values)
    importance = np.abs(explanation.values).mean(axis=0)
    ranking = np.argsort(importance)
    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.barh(np.array(SUBSET_FEATURES)[ranking], importance[ranking])
    ax.set_title("Global feature importance (mean |SHAP|)")
    ax.set_xlabel("Mean |impact on P(class==1)|")
    fig.tight_layout()
    return fig
GLOBAL_FIG = _plot_global_shap()
# ---------- prediction ----------
def predict_manual(
    threshold,
    age, bmi, ab_count,
    htn, cvd,
    fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
):
    """Score one manually-entered row and apply the decision threshold.

    Returns a 4-tuple: (decision 0/1, P(class==1) rounded to 4 dp,
    "Positive"/"Negative" label, echoed input row as pretty JSON).
    """
    row = build_row_dict(
        age, bmi, ab_count,
        htn, cvd,
        fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
    )
    features = pd.DataFrame([row], columns=SUBSET_FEATURES)
    p_pos = float(model.predict_proba(features)[0][POS_IDX])
    decision = int(p_pos >= float(threshold))
    label = "Positive" if decision == 1 else "Negative"
    return decision, round(p_pos, 4), label, pretty_json(row)
def explain_local(
    age, bmi, ab_count,
    htn, cvd,
    fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
):
    """Build the local SHAP figure for the current manual inputs.

    Returns a matplotlib Figure, or None when SHAP is unavailable.
    """
    row = build_row_dict(
        age, bmi, ab_count,
        htn, cvd,
        fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
    )
    return _plot_local_shap(row)
def explain_global():
    """Return the precomputed global SHAP figure (None if unavailable)."""
    return GLOBAL_FIG
def filter_sample_options(filter_target):
    """Refresh the sample-picker dropdown for a target filter ("All"/"0"/"1").

    Each choice is ``("<rid>: y=<target>", rid)``; the first option becomes
    the selected value. Empties the dropdown when no samples are loaded.
    """
    if samples_df is None:
        return gr.update(choices=[], value=None)
    pool = samples_df
    if filter_target in ("0", "1"):
        pool = pool[pool[TARGET_COL] == int(filter_target)]
    choices = []
    for _, rec in pool.iterrows():
        rid = int(rec["_rid"])
        choices.append((f"{rid}: y={int(rec[TARGET_COL])}", rid))
    default = choices[0][1] if choices else None
    return gr.update(choices=choices, value=default)
def load_sample(rid):
    """Copy sample row *rid* into the 13 manual-input widgets + ground truth.

    Returns 14 ``gr.update`` objects in widget order (age … sono_crl, then
    the ground-truth textbox). Missing/unknown rids leave the inputs
    untouched and blank the ground-truth box.
    """
    untouched = [gr.update()] * 13 + [gr.update(value="")]
    if samples_df is None or rid is None:
        return untouched
    matches = samples_df.loc[samples_df["_rid"] == int(rid)]
    if matches.empty:
        return untouched
    rec = matches.iloc[0]

    def _count(v):
        # abortion count: integer, defaulting to 0 when NaN/missing
        return int(v) if pd.notna(v) else 0

    def _gt(v):
        # ground truth rendered as "0"/"1" text, empty when NaN/missing
        return str(int(v)) if pd.notna(v) else ""

    # (column name, converter) in exact widget/output order.
    spec = [
        ("age", f_or_none),
        ("bmi", f_or_none),
        ("previos_obsteric_history_ab", _count),
        ("history_of_htn", as_bool),
        ("history_infectious_cardiovascular_diseae", as_bool),
        ("fbs_first_trimester", f_or_none),
        ("hb", f_or_none),
        ("hct", f_or_none),
        ("cr", f_or_none),
        ("plt", f_or_none),
        ("vit_d3", f_or_none),
        ("sono_nt_nt", f_or_none),
        ("sono_nt_crl", f_or_none),
        (TARGET_COL, _gt),
    ]
    return [gr.update(value=convert(rec.get(column))) for column, convert in spec]
def compare_correctness(gt_text, decision_label):
    """Compare a ground-truth string against the predicted label.

    Returns "—" when the ground truth is missing or unparseable, otherwise
    a check/cross marker string.
    """
    # Explicit None/"" test (not plain truthiness) so a numeric 0 still counts.
    if gt_text is None or gt_text == "":
        return "—"
    try:
        ground_truth = int(float(gt_text))
    except Exception:
        return "—"
    matched = ground_truth == int(decision_label)
    return "✅ Correct" if matched else "❌ Incorrect"
def get_feature_importance_text():
    """Format the model's native importances as a markdown bullet list.

    Looks for ``feature_importances_`` (tree models) or ``coef_`` (linear
    models) on the fitted estimator inside the pipeline, sorted by absolute
    magnitude. Returns a fallback message when neither attribute exists or
    the length does not match SUBSET_FEATURES.
    """
    try:
        estimator = getattr(model, "named_steps", {}).get("trained_model", None)
    except Exception:
        estimator = None
    if estimator is None:
        estimator = model
    importances = None
    if hasattr(estimator, "feature_importances_"):
        importances = list(estimator.feature_importances_)
    elif hasattr(estimator, "coef_"):
        coefs = estimator.coef_
        if coefs is not None:
            importances = list(coefs.reshape(-1))
    if not importances or len(importances) != len(SUBSET_FEATURES):
        return "Not available for this model."
    ranked = sorted(zip(SUBSET_FEATURES, importances), key=lambda p: abs(p[1]), reverse=True)
    return "\n".join(f"- {name}: {value:.4f}" for name, value in ranked)
# Native-importance text computed once at startup (text fallback under the SHAP plot).
GLOBAL_FI_TEXT = get_feature_importance_text()
# ---------- theme ----------
# Soft violet theme with a custom dark-mode background and thin block borders.
theme = gr.themes.Soft(
    primary_hue="violet",
    neutral_hue="slate",
).set(
    body_background_fill_dark="#0b0f19",
    block_border_width="1px"
)
# ---------- UI ----------
# Three-column layout: manual inputs | sample picker | results & explanations.
with gr.Blocks(theme=theme, title="GTT Classifier") as demo:
    gr.Markdown("## GTT Prediction \n**Auto-preprocessing · Thresholdable**")
    with gr.Row():
        # (1) Manual input — one widget per SUBSET_FEATURES entry.
        with gr.Column(scale=1):
            gr.Markdown("### 1) Manual input")
            age = gr.Number(label="Age (years)", value=0)
            bmi = gr.Number(label="BMI", value=0)
            ab_count = gr.Number(label="Previos Obsteric History of Abortion (count)", value=0, precision=0)
            gr.Markdown("---\n**Clinical flags**")
            htn = gr.Checkbox(label="History of Hypertension", value=False)
            cvd = gr.Checkbox(label="History of Cardiovascular disease", value=False)
            with gr.Accordion("Numeric features", open=False):
                fbs1 = gr.Number(label="First trimester FBS")
                hb = gr.Number(label="First trimester HB")
                hct = gr.Number(label="First trimester HCT")
                cr = gr.Number(label="First trimester CR")
                plt_v = gr.Number(label="First trimester PLT")
                vitd3 = gr.Number(label="First trimester Vit D3")
                sono_nt = gr.Number(label="First trimester Sonographic NT (nt)")
                sono_crl = gr.Number(label="First trimester Sonographic NT (crl)")
            with gr.Row():
                # Decision threshold applied to P(class==1); "↻" resets it to 0.50.
                threshold = gr.Slider(0.05, 0.95, value=0.50, step=0.01, label="Decision threshold for class '1'")
                reset_thr = gr.Button("↻", size="sm")
            predict_btn = gr.Button("🚀 Predict (manual)", variant="primary")
            explain_btn = gr.Button("🧠 Explain (SHAP for current input)")
        # (2) Sample picker — pre-fill the manual inputs from a labelled CSV row.
        with gr.Column(scale=1):
            gr.Markdown("### 2) Sample picker (from fixed file)")
            filt = gr.Dropdown(choices=["All", "0", "1"], value="All", label="Filter by target")
            sample_dd = gr.Dropdown(choices=[], value=None, label="Choose sample row")
            load_ok = gr.Button("Load sample into manual inputs", variant="secondary")
        # (3) Results — prediction, probability, correctness vs. sample label, explanations.
        with gr.Column(scale=1):
            gr.Markdown("### 3) Results")
            pred_label = gr.Number(label="Predicted label (with threshold decision)", interactive=False)
            with gr.Row():
                pred_prob = gr.Number(label="P(class==1)", value=0, interactive=False)
                decision_text = gr.Textbox(label="Decision @ threshold", interactive=False)
            gt_box = gr.Textbox(label="Ground truth (sample)", interactive=False)
            correctness = gr.Textbox(label="Correct vs. ground truth?", interactive=False)
            with gr.Accordion("Echoed input (row sent to model)", open=False):
                echoed = gr.Code(label="", language="json")
            with gr.Accordion("Global feature importance (SHAP)", open=False):
                global_plot = gr.Plot(value=GLOBAL_FIG)
                gr.Markdown("> Text fallback (native model importances):")
                gr.Markdown(GLOBAL_FI_TEXT)
            with gr.Accordion("Local explanation (SHAP) for current input", open=False):
                local_plot = gr.Plot()
    # events
    # Populate the sample dropdown on page load (unfiltered).
    demo.load(lambda: filter_sample_options("All"), inputs=None, outputs=[sample_dd], queue=False)
    filt.change(filter_sample_options, inputs=[filt], outputs=[sample_dd])
    reset_thr.click(fn=lambda: 0.5, inputs=None, outputs=[threshold])
    # Copy the chosen sample's feature values + ground truth into the widgets.
    load_ok.click(
        fn=load_sample,
        inputs=[sample_dd],
        outputs=[
            age, bmi, ab_count,
            htn, cvd,
            fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl,
            gt_box
        ],
    )
    # Predict, then chain a correctness check against the loaded ground truth.
    predict_btn.click(
        fn=predict_manual,
        inputs=[
            threshold,
            age, bmi, ab_count,
            htn, cvd,
            fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
        ],
        outputs=[pred_label, pred_prob, decision_text, echoed],
    ).then(
        fn=compare_correctness,
        inputs=[gt_box, pred_label],
        outputs=[correctness]
    )
    explain_btn.click(
        fn=explain_local,
        inputs=[age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl],
        outputs=[local_plot]
    )
if __name__ == "__main__":
    # Keep local Gradio traffic off any configured HTTP proxy (both spellings,
    # since different libraries read different cases).
    os.environ["NO_PROXY"] = "127.0.0.1,localhost"
    os.environ["no_proxy"] = "127.0.0.1,localhost"
    demo.launch()