Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- .dockerignore +4 -0
- Dockerfile +15 -0
- GTT.csv +10 -0
- app.py +409 -0
- requirements.txt +5 -0
- subset_best_model.pkl +3 -0
.dockerignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
subset_best_model.pkl
|
| 2 |
+
GTT.csv
|
| 3 |
+
app.py
|
| 4 |
+
requirements.txt
|
Dockerfile
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
WORKDIR /app
|
| 3 |
+
|
| 4 |
+
# Copy your app and data into the image
|
| 5 |
+
COPY app.py .
|
| 6 |
+
COPY requirements.txt .
|
| 7 |
+
COPY subset_best_model.pkl .
|
| 8 |
+
COPY data.csv .
|
| 9 |
+
|
| 10 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
+
|
| 12 |
+
EXPOSE 7860
|
| 13 |
+
ENV GRADIO_SERVER_NAME=0.0.0.0
|
| 14 |
+
|
| 15 |
+
CMD ["python", "app.py"]
|
GTT.csv
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ID,GTT,age,BMI,history_of_htn,history_infectious_cardiovascular_diseae,Previos_Obsteric_History_AB,FBS_first_trimester,HB,CR,Hct,PLT,VIT D3,sono_nt.crl,sono_nt.nt
|
| 2 |
+
3505,1,37,22,0,0,0,88,13.4,0.6,39.9,278,36,56,1.6
|
| 3 |
+
3530,1,33,28,0,0,0,96,14,0,41.2,187,9,54,1.4
|
| 4 |
+
4057,0,33,26,0,0,2,110,12,0.7,35.9,333,26,48.3,1.1
|
| 5 |
+
4491,0,27,25,0,0,3,84,13.6,0.7,40.3,204,13,69,1.9
|
| 6 |
+
4707,0,39,27,0,0,1,71,14.9,0.6,44,335,14,64,1
|
| 7 |
+
4813,0,37,22,0,0,1,88,13.2,0,37.9,150,39,54.3,1
|
| 8 |
+
5098,0,36,25,0,0,4,91,12.9,1.8,38,288,16,55,1.3
|
| 9 |
+
5314,1,41,35,1,0,0,98,10.8,0.9,34.5,398,21,45.2,3.2
|
| 10 |
+
5767,1,37,22,0,0,0,101,14.5,1,42,300,33,62.2,1
|
app.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
# pip install "pycaret>=3.3,<4" gradio pandas shap matplotlib
|
| 3 |
+
|
| 4 |
+
# --- FORCE NON-INTERACTIVE MATPLOTLIB BACKEND (must be first!) ---
|
| 5 |
+
import os
|
| 6 |
+
os.environ["MPLBACKEND"] = "Agg" # prevents Tk backend init
|
| 7 |
+
import matplotlib
|
| 8 |
+
matplotlib.use("Agg", force=True)
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import shap
|
| 16 |
+
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from pycaret.classification import load_model
|
| 19 |
+
|
| 20 |
+
# --- config ---
|
| 21 |
+
MODEL_BASENAME = "subset_best_model"
|
| 22 |
+
SAMPLES_CSV = "GTT.csv" # fixed hidden file
|
| 23 |
+
TARGET_COL = "gtt"
|
| 24 |
+
POS_LABEL = 1
|
| 25 |
+
|
| 26 |
+
# subset features used by the model (normalized names)
|
| 27 |
+
SUBSET_FEATURES = [
|
| 28 |
+
"age",
|
| 29 |
+
"bmi",
|
| 30 |
+
"history_of_htn",
|
| 31 |
+
"history_infectious_cardiovascular_diseae",
|
| 32 |
+
"previos_obsteric_history_ab",
|
| 33 |
+
"fbs_first_trimester",
|
| 34 |
+
"hb",
|
| 35 |
+
"hct",
|
| 36 |
+
"cr",
|
| 37 |
+
"plt",
|
| 38 |
+
"vit_d3",
|
| 39 |
+
"sono_nt_nt",
|
| 40 |
+
"sono_nt_crl",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
# ---------- utils ----------
|
| 44 |
+
def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
|
| 45 |
+
out = df.copy()
|
| 46 |
+
out.columns = (
|
| 47 |
+
out.columns.str.strip()
|
| 48 |
+
.str.replace(r"[\s/\\\.\-]+", "_", regex=True)
|
| 49 |
+
.str.replace(r"__+", "_", regex=True)
|
| 50 |
+
.str.lower()
|
| 51 |
+
)
|
| 52 |
+
return out
|
| 53 |
+
|
| 54 |
+
def load_samples():
|
| 55 |
+
if not Path(SAMPLES_CSV).exists():
|
| 56 |
+
return None
|
| 57 |
+
df = pd.read_csv(SAMPLES_CSV)
|
| 58 |
+
df = normalize_cols(df)
|
| 59 |
+
needed = set(["id", TARGET_COL] + SUBSET_FEATURES)
|
| 60 |
+
if not needed.issubset(df.columns):
|
| 61 |
+
missing = needed - set(df.columns)
|
| 62 |
+
print(f"[WARN] samples file missing columns: {sorted(missing)}")
|
| 63 |
+
return None
|
| 64 |
+
df = df.reset_index(drop=False).rename(columns={"index": "_rid"}) # stable row id for dropdown
|
| 65 |
+
return df
|
| 66 |
+
|
| 67 |
+
def pretty_json(d):
|
| 68 |
+
return json.dumps(d, ensure_ascii=False, indent=2)
|
| 69 |
+
|
| 70 |
+
def as_bool(x, default=False):
|
| 71 |
+
if x is None or (isinstance(x, float) and pd.isna(x)):
|
| 72 |
+
return default
|
| 73 |
+
if isinstance(x, bool):
|
| 74 |
+
return x
|
| 75 |
+
if isinstance(x, (int,)):
|
| 76 |
+
return bool(x)
|
| 77 |
+
s = str(x).strip().lower()
|
| 78 |
+
yes = {"1","true","t","yes","y","on","pos","positive"}
|
| 79 |
+
no = {"0","false","f","no","n","off","neg","negative"}
|
| 80 |
+
if s in yes: return True
|
| 81 |
+
if s in no: return False
|
| 82 |
+
try:
|
| 83 |
+
return bool(int(float(s)))
|
| 84 |
+
except Exception:
|
| 85 |
+
return default
|
| 86 |
+
|
| 87 |
+
def f_or_none(v):
|
| 88 |
+
return float(v) if (v is not None and not (isinstance(v, float) and pd.isna(v))) else None
|
| 89 |
+
|
| 90 |
+
def build_row_dict(
|
| 91 |
+
age, bmi, ab_count,
|
| 92 |
+
htn, cvd,
|
| 93 |
+
fbs1, hb, hct, cr, plt, vitd3, sono_nt, sono_crl
|
| 94 |
+
):
|
| 95 |
+
return {
|
| 96 |
+
"age": age,
|
| 97 |
+
"bmi": bmi,
|
| 98 |
+
"previos_obsteric_history_ab": ab_count,
|
| 99 |
+
"history_of_htn": 1 if htn else 0,
|
| 100 |
+
"history_infectious_cardiovascular_diseae": 1 if cvd else 0,
|
| 101 |
+
"fbs_first_trimester": fbs1,
|
| 102 |
+
"hb": hb,
|
| 103 |
+
"hct": hct,
|
| 104 |
+
"cr": cr,
|
| 105 |
+
"plt": plt,
|
| 106 |
+
"vit_d3": vitd3,
|
| 107 |
+
"sono_nt_nt": sono_nt,
|
| 108 |
+
"sono_nt_crl": sono_crl,
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
def _get_pos_index_and_classes(pipe, pos_label=1):
|
| 112 |
+
est = None
|
| 113 |
+
try:
|
| 114 |
+
est = getattr(pipe, "named_steps", {}).get("trained_model", None)
|
| 115 |
+
except Exception:
|
| 116 |
+
est = None
|
| 117 |
+
if est is None:
|
| 118 |
+
est = pipe
|
| 119 |
+
classes = getattr(est, "classes_", None)
|
| 120 |
+
if classes is not None and pos_label in list(classes):
|
| 121 |
+
return list(classes).index(pos_label), list(classes)
|
| 122 |
+
return -1, list(classes) if classes is not None else None
|
| 123 |
+
|
| 124 |
+
# ---------- model & samples ----------
|
| 125 |
+
model = load_model(MODEL_BASENAME)
|
| 126 |
+
samples_df = load_samples()
|
| 127 |
+
|
| 128 |
+
# ---------- SHAP: background + explainer (built once) ----------
|
| 129 |
+
def _prepare_background(df_samples: pd.DataFrame | None, max_rows: int = 200) -> pd.DataFrame:
|
| 130 |
+
if df_samples is None:
|
| 131 |
+
# if no CSV, make a tiny synthetic background of zeros
|
| 132 |
+
bg = pd.DataFrame([{k: 0.0 for k in SUBSET_FEATURES} for _ in range(50)])
|
| 133 |
+
else:
|
| 134 |
+
bg = df_samples[SUBSET_FEATURES].copy()
|
| 135 |
+
# numeric coercion + median impute
|
| 136 |
+
for c in SUBSET_FEATURES:
|
| 137 |
+
if c not in bg.columns:
|
| 138 |
+
bg[c] = np.nan
|
| 139 |
+
bg = bg.apply(pd.to_numeric, errors="coerce")
|
| 140 |
+
bg = bg.fillna(bg.median(numeric_only=True))
|
| 141 |
+
if len(bg) > max_rows:
|
| 142 |
+
bg = bg.sample(max_rows, random_state=42)
|
| 143 |
+
return bg.reset_index(drop=True)
|
| 144 |
+
|
| 145 |
+
BACKGROUND = _prepare_background(samples_df)
|
| 146 |
+
POS_IDX, _ = _get_pos_index_and_classes(model, POS_LABEL)
|
| 147 |
+
|
| 148 |
+
def _f_proba_pos(X_np: np.ndarray) -> np.ndarray:
|
| 149 |
+
"""Model function returning P(class==1) for SHAP. X_np is numpy; convert to DataFrame with right columns."""
|
| 150 |
+
X_df = pd.DataFrame(X_np, columns=SUBSET_FEATURES)
|
| 151 |
+
return model.predict_proba(X_df)[:, POS_IDX]
|
| 152 |
+
|
| 153 |
+
# SHAP Explainer (KernelExplainer via unified interface)
|
| 154 |
+
try:
|
| 155 |
+
EXPLAINER = shap.Explainer(_f_proba_pos, BACKGROUND.values)
|
| 156 |
+
except Exception as e:
|
| 157 |
+
print("[WARN] SHAP explainer init failed:", e)
|
| 158 |
+
EXPLAINER = None
|
| 159 |
+
|
| 160 |
+
def _plot_local_shap(row_dict: dict):
|
| 161 |
+
"""Returns a matplotlib Figure with local SHAP bar chart for the given row."""
|
| 162 |
+
if EXPLAINER is None:
|
| 163 |
+
return None
|
| 164 |
+
X = pd.DataFrame([row_dict], columns=SUBSET_FEATURES)
|
| 165 |
+
exp = EXPLAINER(X.values) # exp.values shape: (1, n_features)
|
| 166 |
+
vals = exp.values[0]
|
| 167 |
+
order = np.argsort(np.abs(vals))
|
| 168 |
+
fig, ax = plt.subplots(figsize=(7, 4.5))
|
| 169 |
+
ax.barh(np.array(SUBSET_FEATURES)[order], vals[order])
|
| 170 |
+
ax.axvline(0, linewidth=1)
|
| 171 |
+
ax.set_title("Local SHAP values (current input)")
|
| 172 |
+
ax.set_xlabel("Impact on P(class==1)")
|
| 173 |
+
fig.tight_layout()
|
| 174 |
+
return fig
|
| 175 |
+
|
| 176 |
+
def _plot_global_shap():
|
| 177 |
+
"""Returns a matplotlib Figure with global mean(|SHAP|) bar chart over BACKGROUND."""
|
| 178 |
+
if EXPLAINER is None:
|
| 179 |
+
return None
|
| 180 |
+
exp = EXPLAINER(BACKGROUND.values)
|
| 181 |
+
mean_abs = np.mean(np.abs(exp.values), axis=0)
|
| 182 |
+
order = np.argsort(mean_abs)
|
| 183 |
+
fig, ax = plt.subplots(figsize=(7, 4.5))
|
| 184 |
+
ax.barh(np.array(SUBSET_FEATURES)[order], mean_abs[order])
|
| 185 |
+
ax.set_title("Global feature importance (mean |SHAP|)")
|
| 186 |
+
ax.set_xlabel("Mean |impact on P(class==1)|")
|
| 187 |
+
fig.tight_layout()
|
| 188 |
+
return fig
|
| 189 |
+
|
| 190 |
+
GLOBAL_FIG = _plot_global_shap()
|
| 191 |
+
|
| 192 |
+
# ---------- prediction ----------
|
| 193 |
+
def predict_manual(
|
| 194 |
+
threshold,
|
| 195 |
+
age, bmi, ab_count,
|
| 196 |
+
htn, cvd,
|
| 197 |
+
fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
|
| 198 |
+
):
|
| 199 |
+
row = build_row_dict(
|
| 200 |
+
age, bmi, ab_count,
|
| 201 |
+
htn, cvd,
|
| 202 |
+
fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
|
| 203 |
+
)
|
| 204 |
+
df = pd.DataFrame([row], columns=SUBSET_FEATURES)
|
| 205 |
+
proba = model.predict_proba(df)
|
| 206 |
+
p1 = float(proba[0][POS_IDX])
|
| 207 |
+
decision = 1 if p1 >= float(threshold) else 0
|
| 208 |
+
return int(decision), round(p1, 4), ("Positive" if decision==1 else "Negative"), pretty_json(row)
|
| 209 |
+
|
| 210 |
+
def explain_local(
|
| 211 |
+
age, bmi, ab_count,
|
| 212 |
+
htn, cvd,
|
| 213 |
+
fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
|
| 214 |
+
):
|
| 215 |
+
row = build_row_dict(
|
| 216 |
+
age, bmi, ab_count,
|
| 217 |
+
htn, cvd,
|
| 218 |
+
fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
|
| 219 |
+
)
|
| 220 |
+
fig = _plot_local_shap(row)
|
| 221 |
+
return fig
|
| 222 |
+
|
| 223 |
+
def explain_global():
|
| 224 |
+
return GLOBAL_FIG
|
| 225 |
+
|
| 226 |
+
def filter_sample_options(filter_target):
|
| 227 |
+
if samples_df is None:
|
| 228 |
+
return gr.update(choices=[], value=None)
|
| 229 |
+
df = samples_df
|
| 230 |
+
if filter_target in ("0", "1"):
|
| 231 |
+
df = df[df[TARGET_COL] == int(filter_target)]
|
| 232 |
+
opts = [ (f"{int(r['_rid'])}: y={int(r[TARGET_COL])}", int(r["_rid"])) for _, r in df.iterrows() ]
|
| 233 |
+
return gr.update(choices=opts, value=(opts[0][1] if opts else None))
|
| 234 |
+
|
| 235 |
+
def load_sample(rid):
|
| 236 |
+
if samples_df is None or rid is None:
|
| 237 |
+
return [gr.update()]*13 + [gr.update(value="")]
|
| 238 |
+
r = samples_df.loc[samples_df["_rid"] == int(rid)]
|
| 239 |
+
if r.empty:
|
| 240 |
+
return [gr.update()]*13 + [gr.update(value="")]
|
| 241 |
+
r = r.iloc[0]
|
| 242 |
+
|
| 243 |
+
updates = [
|
| 244 |
+
gr.update(value=f_or_none(r.get("age"))),
|
| 245 |
+
gr.update(value=f_or_none(r.get("bmi"))),
|
| 246 |
+
gr.update(value=int(r.get("previos_obsteric_history_ab", 0)) if pd.notna(r.get("previos_obsteric_history_ab")) else 0),
|
| 247 |
+
|
| 248 |
+
gr.update(value=as_bool(r.get("history_of_htn"))),
|
| 249 |
+
gr.update(value=as_bool(r.get("history_infectious_cardiovascular_diseae"))),
|
| 250 |
+
|
| 251 |
+
gr.update(value=f_or_none(r.get("fbs_first_trimester"))),
|
| 252 |
+
gr.update(value=f_or_none(r.get("hb"))),
|
| 253 |
+
gr.update(value=f_or_none(r.get("hct"))),
|
| 254 |
+
gr.update(value=f_or_none(r.get("cr"))),
|
| 255 |
+
gr.update(value=f_or_none(r.get("plt"))),
|
| 256 |
+
gr.update(value=f_or_none(r.get("vit_d3"))),
|
| 257 |
+
gr.update(value=f_or_none(r.get("sono_nt_nt"))),
|
| 258 |
+
gr.update(value=f_or_none(r.get("sono_nt_crl"))),
|
| 259 |
+
|
| 260 |
+
gr.update(value=str(int(r.get(TARGET_COL))) if pd.notna(r.get(TARGET_COL)) else "")
|
| 261 |
+
]
|
| 262 |
+
return updates
|
| 263 |
+
|
| 264 |
+
def compare_correctness(gt_text, decision_label):
|
| 265 |
+
if gt_text is None or gt_text == "":
|
| 266 |
+
return "—"
|
| 267 |
+
try:
|
| 268 |
+
gt = int(float(gt_text))
|
| 269 |
+
except Exception:
|
| 270 |
+
return "—"
|
| 271 |
+
return "✅ Correct" if gt == int(decision_label) else "❌ Incorrect"
|
| 272 |
+
|
| 273 |
+
def get_feature_importance_text():
|
| 274 |
+
# Keep textual fallback if SHAP not available
|
| 275 |
+
est = None
|
| 276 |
+
try:
|
| 277 |
+
est = getattr(model, "named_steps", {}).get("trained_model", None)
|
| 278 |
+
except Exception:
|
| 279 |
+
est = None
|
| 280 |
+
if est is None:
|
| 281 |
+
est = model
|
| 282 |
+
fi = None
|
| 283 |
+
if hasattr(est, "feature_importances_"):
|
| 284 |
+
fi = list(est.feature_importances_)
|
| 285 |
+
elif hasattr(est, "coef_"):
|
| 286 |
+
coef = est.coef_
|
| 287 |
+
if coef is not None:
|
| 288 |
+
fi = list(coef.reshape(-1))
|
| 289 |
+
if not fi or len(fi) != len(SUBSET_FEATURES):
|
| 290 |
+
return "Not available for this model."
|
| 291 |
+
pairs = sorted(zip(SUBSET_FEATURES, fi), key=lambda x: abs(x[1]), reverse=True)
|
| 292 |
+
return "\n".join([f"- {k}: {v:.4f}" for k, v in pairs])
|
| 293 |
+
|
| 294 |
+
GLOBAL_FI_TEXT = get_feature_importance_text()
|
| 295 |
+
|
| 296 |
+
# ---------- theme ----------
|
| 297 |
+
theme = gr.themes.Soft(
|
| 298 |
+
primary_hue="violet",
|
| 299 |
+
neutral_hue="slate",
|
| 300 |
+
).set(
|
| 301 |
+
body_background_fill_dark="#0b0f19",
|
| 302 |
+
block_border_width="1px"
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# ---------- UI ----------
|
| 306 |
+
with gr.Blocks(theme=theme, title="GTT Classifier — Manual + Fixed Samples") as demo:
|
| 307 |
+
gr.Markdown("## GTT Prediction (Subset Features)\n**PyCaret pipeline · Auto-preprocessing · Thresholdable**")
|
| 308 |
+
|
| 309 |
+
with gr.Row():
|
| 310 |
+
# (1) Manual input
|
| 311 |
+
with gr.Column(scale=1):
|
| 312 |
+
gr.Markdown("### 1) Manual input")
|
| 313 |
+
|
| 314 |
+
age = gr.Number(label="Age (years)", value=0)
|
| 315 |
+
bmi = gr.Number(label="BMI", value=0)
|
| 316 |
+
ab_count = gr.Number(label="Previos Obsteric History of Abortion (count)", value=0, precision=0)
|
| 317 |
+
|
| 318 |
+
gr.Markdown("---\n**Clinical flags**")
|
| 319 |
+
htn = gr.Checkbox(label="History of Hypertension", value=False)
|
| 320 |
+
cvd = gr.Checkbox(label="History of Cardiovascular disease", value=False)
|
| 321 |
+
|
| 322 |
+
with gr.Accordion("More numeric features (optional)", open=False):
|
| 323 |
+
fbs1 = gr.Number(label="FBS of First trimester")
|
| 324 |
+
hb = gr.Number(label="HB")
|
| 325 |
+
hct = gr.Number(label="HCT")
|
| 326 |
+
cr = gr.Number(label="CR")
|
| 327 |
+
plt_v = gr.Number(label="PLT")
|
| 328 |
+
vitd3 = gr.Number(label="Vit D3")
|
| 329 |
+
sono_nt = gr.Number(label="Sonographic NT")
|
| 330 |
+
sono_crl = gr.Number(label="Sonographic CRL")
|
| 331 |
+
|
| 332 |
+
with gr.Row():
|
| 333 |
+
threshold = gr.Slider(0.05, 0.95, value=0.50, step=0.01, label="Decision threshold for class '1'")
|
| 334 |
+
reset_thr = gr.Button("↻", size="sm")
|
| 335 |
+
|
| 336 |
+
predict_btn = gr.Button("🚀 Predict (manual)", variant="primary")
|
| 337 |
+
explain_btn = gr.Button("🧠 Explain (SHAP for current input)")
|
| 338 |
+
|
| 339 |
+
# (2) Sample picker
|
| 340 |
+
with gr.Column(scale=1):
|
| 341 |
+
gr.Markdown("### 2) Sample picker (from fixed file)")
|
| 342 |
+
filt = gr.Dropdown(choices=["All", "0", "1"], value="All", label="Filter by target")
|
| 343 |
+
sample_dd = gr.Dropdown(choices=[], value=None, label="Choose sample row")
|
| 344 |
+
load_ok = gr.Button("Load sample into manual inputs", variant="secondary")
|
| 345 |
+
|
| 346 |
+
# (3) Results
|
| 347 |
+
with gr.Column(scale=1):
|
| 348 |
+
gr.Markdown("### 3) Results")
|
| 349 |
+
|
| 350 |
+
pred_label = gr.Number(label="Predicted label (with threshold decision)", interactive=False)
|
| 351 |
+
with gr.Row():
|
| 352 |
+
pred_prob = gr.Number(label="P(class==1)", value=0, interactive=False)
|
| 353 |
+
decision_text = gr.Textbox(label="Decision @ threshold", interactive=False)
|
| 354 |
+
|
| 355 |
+
gt_box = gr.Textbox(label="Ground truth (sample)", interactive=False)
|
| 356 |
+
correctness = gr.Textbox(label="Correct vs. ground truth?", interactive=False)
|
| 357 |
+
|
| 358 |
+
with gr.Accordion("Echoed input (row sent to model)", open=False):
|
| 359 |
+
echoed = gr.Code(label="", language="json")
|
| 360 |
+
|
| 361 |
+
with gr.Accordion("Global feature importance (SHAP)", open=False):
|
| 362 |
+
global_plot = gr.Plot(value=GLOBAL_FIG)
|
| 363 |
+
gr.Markdown("> Text fallback (native model importances):")
|
| 364 |
+
gr.Markdown(GLOBAL_FI_TEXT)
|
| 365 |
+
|
| 366 |
+
with gr.Accordion("Local explanation (SHAP) for current input", open=False):
|
| 367 |
+
local_plot = gr.Plot()
|
| 368 |
+
|
| 369 |
+
# events
|
| 370 |
+
demo.load(lambda: filter_sample_options("All"), inputs=None, outputs=[sample_dd], queue=False)
|
| 371 |
+
filt.change(filter_sample_options, inputs=[filt], outputs=[sample_dd])
|
| 372 |
+
reset_thr.click(fn=lambda: 0.5, inputs=None, outputs=[threshold])
|
| 373 |
+
|
| 374 |
+
load_ok.click(
|
| 375 |
+
fn=load_sample,
|
| 376 |
+
inputs=[sample_dd],
|
| 377 |
+
outputs=[
|
| 378 |
+
age, bmi, ab_count,
|
| 379 |
+
htn, cvd,
|
| 380 |
+
fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl,
|
| 381 |
+
gt_box
|
| 382 |
+
],
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
predict_btn.click(
|
| 386 |
+
fn=predict_manual,
|
| 387 |
+
inputs=[
|
| 388 |
+
threshold,
|
| 389 |
+
age, bmi, ab_count,
|
| 390 |
+
htn, cvd,
|
| 391 |
+
fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
|
| 392 |
+
],
|
| 393 |
+
outputs=[pred_label, pred_prob, decision_text, echoed],
|
| 394 |
+
).then(
|
| 395 |
+
fn=compare_correctness,
|
| 396 |
+
inputs=[gt_box, pred_label],
|
| 397 |
+
outputs=[correctness]
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
explain_btn.click(
|
| 401 |
+
fn=explain_local,
|
| 402 |
+
inputs=[age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl],
|
| 403 |
+
outputs=[local_plot]
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
if __name__ == "__main__":
|
| 407 |
+
os.environ["NO_PROXY"] = "127.0.0.1,localhost"
|
| 408 |
+
os.environ["no_proxy"] = "127.0.0.1,localhost"
|
| 409 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pycaret>=3.3,<4
|
| 2 |
+
gradio
|
| 3 |
+
pandas
|
| 4 |
+
shap
|
| 5 |
+
matplotlib
|
subset_best_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b87c8c09e49f9423392d1c4da3b319759820003428bb66bfe74f3155a18b82dd
|
| 3 |
+
size 149680
|