GDMProjects commited on
Commit
93a36b7
·
verified ·
1 Parent(s): 6430c95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +414 -409
app.py CHANGED
@@ -1,409 +1,414 @@
1
- # app.py
2
- # pip install "pycaret>=3.3,<4" gradio pandas shap matplotlib
3
-
4
- # --- FORCE NON-INTERACTIVE MATPLOTLIB BACKEND (must be first!) ---
5
- import os
6
- os.environ["MPLBACKEND"] = "Agg" # prevents Tk backend init
7
- import matplotlib
8
- matplotlib.use("Agg", force=True)
9
-
10
- import json
11
- import numpy as np
12
- import pandas as pd
13
- import gradio as gr
14
- import matplotlib.pyplot as plt
15
- import shap
16
-
17
- from pathlib import Path
18
- from pycaret.classification import load_model
19
-
20
- # --- config ---
21
- MODEL_BASENAME = "subset_best_model"
22
- SAMPLES_CSV = "GTT.csv" # fixed hidden file
23
- TARGET_COL = "gtt"
24
- POS_LABEL = 1
25
-
26
- # subset features used by the model (normalized names)
27
- SUBSET_FEATURES = [
28
- "age",
29
- "bmi",
30
- "history_of_htn",
31
- "history_infectious_cardiovascular_diseae",
32
- "previos_obsteric_history_ab",
33
- "fbs_first_trimester",
34
- "hb",
35
- "hct",
36
- "cr",
37
- "plt",
38
- "vit_d3",
39
- "sono_nt_nt",
40
- "sono_nt_crl",
41
- ]
42
-
43
- # ---------- utils ----------
44
- def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
45
- out = df.copy()
46
- out.columns = (
47
- out.columns.str.strip()
48
- .str.replace(r"[\s/\\\.\-]+", "_", regex=True)
49
- .str.replace(r"__+", "_", regex=True)
50
- .str.lower()
51
- )
52
- return out
53
-
54
- def load_samples():
55
- if not Path(SAMPLES_CSV).exists():
56
- return None
57
- df = pd.read_csv(SAMPLES_CSV)
58
- df = normalize_cols(df)
59
- needed = set(["id", TARGET_COL] + SUBSET_FEATURES)
60
- if not needed.issubset(df.columns):
61
- missing = needed - set(df.columns)
62
- print(f"[WARN] samples file missing columns: {sorted(missing)}")
63
- return None
64
- df = df.reset_index(drop=False).rename(columns={"index": "_rid"}) # stable row id for dropdown
65
- return df
66
-
67
- def pretty_json(d):
68
- return json.dumps(d, ensure_ascii=False, indent=2)
69
-
70
- def as_bool(x, default=False):
71
- if x is None or (isinstance(x, float) and pd.isna(x)):
72
- return default
73
- if isinstance(x, bool):
74
- return x
75
- if isinstance(x, (int,)):
76
- return bool(x)
77
- s = str(x).strip().lower()
78
- yes = {"1","true","t","yes","y","on","pos","positive"}
79
- no = {"0","false","f","no","n","off","neg","negative"}
80
- if s in yes: return True
81
- if s in no: return False
82
- try:
83
- return bool(int(float(s)))
84
- except Exception:
85
- return default
86
-
87
- def f_or_none(v):
88
- return float(v) if (v is not None and not (isinstance(v, float) and pd.isna(v))) else None
89
-
90
- def build_row_dict(
91
- age, bmi, ab_count,
92
- htn, cvd,
93
- fbs1, hb, hct, cr, plt, vitd3, sono_nt, sono_crl
94
- ):
95
- return {
96
- "age": age,
97
- "bmi": bmi,
98
- "previos_obsteric_history_ab": ab_count,
99
- "history_of_htn": 1 if htn else 0,
100
- "history_infectious_cardiovascular_diseae": 1 if cvd else 0,
101
- "fbs_first_trimester": fbs1,
102
- "hb": hb,
103
- "hct": hct,
104
- "cr": cr,
105
- "plt": plt,
106
- "vit_d3": vitd3,
107
- "sono_nt_nt": sono_nt,
108
- "sono_nt_crl": sono_crl,
109
- }
110
-
111
- def _get_pos_index_and_classes(pipe, pos_label=1):
112
- est = None
113
- try:
114
- est = getattr(pipe, "named_steps", {}).get("trained_model", None)
115
- except Exception:
116
- est = None
117
- if est is None:
118
- est = pipe
119
- classes = getattr(est, "classes_", None)
120
- if classes is not None and pos_label in list(classes):
121
- return list(classes).index(pos_label), list(classes)
122
- return -1, list(classes) if classes is not None else None
123
-
124
- # ---------- model & samples ----------
125
- model = load_model(MODEL_BASENAME)
126
- samples_df = load_samples()
127
-
128
- # ---------- SHAP: background + explainer (built once) ----------
129
- def _prepare_background(df_samples: pd.DataFrame | None, max_rows: int = 200) -> pd.DataFrame:
130
- if df_samples is None:
131
- # if no CSV, make a tiny synthetic background of zeros
132
- bg = pd.DataFrame([{k: 0.0 for k in SUBSET_FEATURES} for _ in range(50)])
133
- else:
134
- bg = df_samples[SUBSET_FEATURES].copy()
135
- # numeric coercion + median impute
136
- for c in SUBSET_FEATURES:
137
- if c not in bg.columns:
138
- bg[c] = np.nan
139
- bg = bg.apply(pd.to_numeric, errors="coerce")
140
- bg = bg.fillna(bg.median(numeric_only=True))
141
- if len(bg) > max_rows:
142
- bg = bg.sample(max_rows, random_state=42)
143
- return bg.reset_index(drop=True)
144
-
145
- BACKGROUND = _prepare_background(samples_df)
146
- POS_IDX, _ = _get_pos_index_and_classes(model, POS_LABEL)
147
-
148
- def _f_proba_pos(X_np: np.ndarray) -> np.ndarray:
149
- """Model function returning P(class==1) for SHAP. X_np is numpy; convert to DataFrame with right columns."""
150
- X_df = pd.DataFrame(X_np, columns=SUBSET_FEATURES)
151
- return model.predict_proba(X_df)[:, POS_IDX]
152
-
153
- # SHAP Explainer (KernelExplainer via unified interface)
154
- try:
155
- EXPLAINER = shap.Explainer(_f_proba_pos, BACKGROUND.values)
156
- except Exception as e:
157
- print("[WARN] SHAP explainer init failed:", e)
158
- EXPLAINER = None
159
-
160
- def _plot_local_shap(row_dict: dict):
161
- """Returns a matplotlib Figure with local SHAP bar chart for the given row."""
162
- if EXPLAINER is None:
163
- return None
164
- X = pd.DataFrame([row_dict], columns=SUBSET_FEATURES)
165
- exp = EXPLAINER(X.values) # exp.values shape: (1, n_features)
166
- vals = exp.values[0]
167
- order = np.argsort(np.abs(vals))
168
- fig, ax = plt.subplots(figsize=(7, 4.5))
169
- ax.barh(np.array(SUBSET_FEATURES)[order], vals[order])
170
- ax.axvline(0, linewidth=1)
171
- ax.set_title("Local SHAP values (current input)")
172
- ax.set_xlabel("Impact on P(class==1)")
173
- fig.tight_layout()
174
- return fig
175
-
176
- def _plot_global_shap():
177
- """Returns a matplotlib Figure with global mean(|SHAP|) bar chart over BACKGROUND."""
178
- if EXPLAINER is None:
179
- return None
180
- exp = EXPLAINER(BACKGROUND.values)
181
- mean_abs = np.mean(np.abs(exp.values), axis=0)
182
- order = np.argsort(mean_abs)
183
- fig, ax = plt.subplots(figsize=(7, 4.5))
184
- ax.barh(np.array(SUBSET_FEATURES)[order], mean_abs[order])
185
- ax.set_title("Global feature importance (mean |SHAP|)")
186
- ax.set_xlabel("Mean |impact on P(class==1)|")
187
- fig.tight_layout()
188
- return fig
189
-
190
- GLOBAL_FIG = _plot_global_shap()
191
-
192
- # ---------- prediction ----------
193
- def predict_manual(
194
- threshold,
195
- age, bmi, ab_count,
196
- htn, cvd,
197
- fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
198
- ):
199
- row = build_row_dict(
200
- age, bmi, ab_count,
201
- htn, cvd,
202
- fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
203
- )
204
- df = pd.DataFrame([row], columns=SUBSET_FEATURES)
205
- proba = model.predict_proba(df)
206
- p1 = float(proba[0][POS_IDX])
207
- decision = 1 if p1 >= float(threshold) else 0
208
- return int(decision), round(p1, 4), ("Positive" if decision==1 else "Negative"), pretty_json(row)
209
-
210
- def explain_local(
211
- age, bmi, ab_count,
212
- htn, cvd,
213
- fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
214
- ):
215
- row = build_row_dict(
216
- age, bmi, ab_count,
217
- htn, cvd,
218
- fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
219
- )
220
- fig = _plot_local_shap(row)
221
- return fig
222
-
223
- def explain_global():
224
- return GLOBAL_FIG
225
-
226
- def filter_sample_options(filter_target):
227
- if samples_df is None:
228
- return gr.update(choices=[], value=None)
229
- df = samples_df
230
- if filter_target in ("0", "1"):
231
- df = df[df[TARGET_COL] == int(filter_target)]
232
- opts = [ (f"{int(r['_rid'])}: y={int(r[TARGET_COL])}", int(r["_rid"])) for _, r in df.iterrows() ]
233
- return gr.update(choices=opts, value=(opts[0][1] if opts else None))
234
-
235
- def load_sample(rid):
236
- if samples_df is None or rid is None:
237
- return [gr.update()]*13 + [gr.update(value="")]
238
- r = samples_df.loc[samples_df["_rid"] == int(rid)]
239
- if r.empty:
240
- return [gr.update()]*13 + [gr.update(value="")]
241
- r = r.iloc[0]
242
-
243
- updates = [
244
- gr.update(value=f_or_none(r.get("age"))),
245
- gr.update(value=f_or_none(r.get("bmi"))),
246
- gr.update(value=int(r.get("previos_obsteric_history_ab", 0)) if pd.notna(r.get("previos_obsteric_history_ab")) else 0),
247
-
248
- gr.update(value=as_bool(r.get("history_of_htn"))),
249
- gr.update(value=as_bool(r.get("history_infectious_cardiovascular_diseae"))),
250
-
251
- gr.update(value=f_or_none(r.get("fbs_first_trimester"))),
252
- gr.update(value=f_or_none(r.get("hb"))),
253
- gr.update(value=f_or_none(r.get("hct"))),
254
- gr.update(value=f_or_none(r.get("cr"))),
255
- gr.update(value=f_or_none(r.get("plt"))),
256
- gr.update(value=f_or_none(r.get("vit_d3"))),
257
- gr.update(value=f_or_none(r.get("sono_nt_nt"))),
258
- gr.update(value=f_or_none(r.get("sono_nt_crl"))),
259
-
260
- gr.update(value=str(int(r.get(TARGET_COL))) if pd.notna(r.get(TARGET_COL)) else "")
261
- ]
262
- return updates
263
-
264
- def compare_correctness(gt_text, decision_label):
265
- if gt_text is None or gt_text == "":
266
- return "—"
267
- try:
268
- gt = int(float(gt_text))
269
- except Exception:
270
- return ""
271
- return "✅ Correct" if gt == int(decision_label) else "❌ Incorrect"
272
-
273
- def get_feature_importance_text():
274
- # Keep textual fallback if SHAP not available
275
- est = None
276
- try:
277
- est = getattr(model, "named_steps", {}).get("trained_model", None)
278
- except Exception:
279
- est = None
280
- if est is None:
281
- est = model
282
- fi = None
283
- if hasattr(est, "feature_importances_"):
284
- fi = list(est.feature_importances_)
285
- elif hasattr(est, "coef_"):
286
- coef = est.coef_
287
- if coef is not None:
288
- fi = list(coef.reshape(-1))
289
- if not fi or len(fi) != len(SUBSET_FEATURES):
290
- return "Not available for this model."
291
- pairs = sorted(zip(SUBSET_FEATURES, fi), key=lambda x: abs(x[1]), reverse=True)
292
- return "\n".join([f"- {k}: {v:.4f}" for k, v in pairs])
293
-
294
- GLOBAL_FI_TEXT = get_feature_importance_text()
295
-
296
- # ---------- theme ----------
297
- theme = gr.themes.Soft(
298
- primary_hue="violet",
299
- neutral_hue="slate",
300
- ).set(
301
- body_background_fill_dark="#0b0f19",
302
- block_border_width="1px"
303
- )
304
-
305
- # ---------- UI ----------
306
- with gr.Blocks(theme=theme, title="GTT Classifier — Manual + Fixed Samples") as demo:
307
- gr.Markdown("## GTT Prediction (Subset Features)\n**PyCaret pipeline · Auto-preprocessing · Thresholdable**")
308
-
309
- with gr.Row():
310
- # (1) Manual input
311
- with gr.Column(scale=1):
312
- gr.Markdown("### 1) Manual input")
313
-
314
- age = gr.Number(label="Age (years)", value=0)
315
- bmi = gr.Number(label="BMI", value=0)
316
- ab_count = gr.Number(label="Previos Obsteric History of Abortion (count)", value=0, precision=0)
317
-
318
- gr.Markdown("---\n**Clinical flags**")
319
- htn = gr.Checkbox(label="History of Hypertension", value=False)
320
- cvd = gr.Checkbox(label="History of Cardiovascular disease", value=False)
321
-
322
- with gr.Accordion("More numeric features (optional)", open=False):
323
- fbs1 = gr.Number(label="FBS of First trimester")
324
- hb = gr.Number(label="HB")
325
- hct = gr.Number(label="HCT")
326
- cr = gr.Number(label="CR")
327
- plt_v = gr.Number(label="PLT")
328
- vitd3 = gr.Number(label="Vit D3")
329
- sono_nt = gr.Number(label="Sonographic NT")
330
- sono_crl = gr.Number(label="Sonographic CRL")
331
-
332
- with gr.Row():
333
- threshold = gr.Slider(0.05, 0.95, value=0.50, step=0.01, label="Decision threshold for class '1'")
334
- reset_thr = gr.Button("↻", size="sm")
335
-
336
- predict_btn = gr.Button("🚀 Predict (manual)", variant="primary")
337
- explain_btn = gr.Button("🧠 Explain (SHAP for current input)")
338
-
339
- # (2) Sample picker
340
- with gr.Column(scale=1):
341
- gr.Markdown("### 2) Sample picker (from fixed file)")
342
- filt = gr.Dropdown(choices=["All", "0", "1"], value="All", label="Filter by target")
343
- sample_dd = gr.Dropdown(choices=[], value=None, label="Choose sample row")
344
- load_ok = gr.Button("Load sample into manual inputs", variant="secondary")
345
-
346
- # (3) Results
347
- with gr.Column(scale=1):
348
- gr.Markdown("### 3) Results")
349
-
350
- pred_label = gr.Number(label="Predicted label (with threshold decision)", interactive=False)
351
- with gr.Row():
352
- pred_prob = gr.Number(label="P(class==1)", value=0, interactive=False)
353
- decision_text = gr.Textbox(label="Decision @ threshold", interactive=False)
354
-
355
- gt_box = gr.Textbox(label="Ground truth (sample)", interactive=False)
356
- correctness = gr.Textbox(label="Correct vs. ground truth?", interactive=False)
357
-
358
- with gr.Accordion("Echoed input (row sent to model)", open=False):
359
- echoed = gr.Code(label="", language="json")
360
-
361
- with gr.Accordion("Global feature importance (SHAP)", open=False):
362
- global_plot = gr.Plot(value=GLOBAL_FIG)
363
- gr.Markdown("> Text fallback (native model importances):")
364
- gr.Markdown(GLOBAL_FI_TEXT)
365
-
366
- with gr.Accordion("Local explanation (SHAP) for current input", open=False):
367
- local_plot = gr.Plot()
368
-
369
- # events
370
- demo.load(lambda: filter_sample_options("All"), inputs=None, outputs=[sample_dd], queue=False)
371
- filt.change(filter_sample_options, inputs=[filt], outputs=[sample_dd])
372
- reset_thr.click(fn=lambda: 0.5, inputs=None, outputs=[threshold])
373
-
374
- load_ok.click(
375
- fn=load_sample,
376
- inputs=[sample_dd],
377
- outputs=[
378
- age, bmi, ab_count,
379
- htn, cvd,
380
- fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl,
381
- gt_box
382
- ],
383
- )
384
-
385
- predict_btn.click(
386
- fn=predict_manual,
387
- inputs=[
388
- threshold,
389
- age, bmi, ab_count,
390
- htn, cvd,
391
- fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
392
- ],
393
- outputs=[pred_label, pred_prob, decision_text, echoed],
394
- ).then(
395
- fn=compare_correctness,
396
- inputs=[gt_box, pred_label],
397
- outputs=[correctness]
398
- )
399
-
400
- explain_btn.click(
401
- fn=explain_local,
402
- inputs=[age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl],
403
- outputs=[local_plot]
404
- )
405
-
406
- if __name__ == "__main__":
407
- os.environ["NO_PROXY"] = "127.0.0.1,localhost"
408
- os.environ["no_proxy"] = "127.0.0.1,localhost"
409
- demo.launch()
 
 
 
 
 
 
1
+ # app.py
2
+ # pip install "pycaret>=3.3,<4" gradio pandas shap matplotlib
3
+
4
+ # --- FORCE NON-INTERACTIVE MATPLOTLIB BACKEND (must be first!) ---
5
+ import os
6
+ os.environ["MPLBACKEND"] = "Agg" # prevents Tk backend init
7
+ import matplotlib
8
+ matplotlib.use("Agg", force=True)
9
+
10
+ import json
11
+ import numpy as np
12
+ import pandas as pd
13
+ import gradio as gr
14
+ import matplotlib.pyplot as plt
15
+ import shap
16
+
17
+ from pathlib import Path
18
+ from pycaret.classification import load_model
19
+ from huggingface_hub import hf_hub_download
20
+ # --- config ---
21
+ MODEL_BASENAME = "subset_best_model"
22
+ SAMPLES_CSV = "GTT.csv" # fixed hidden file
23
+ TARGET_COL = "gtt"
24
+ POS_LABEL = 1
25
+
26
+ REPO = os.getenv("MODEL_REPO", "GDMProjects/my-private-model")
27
+ FNAME = os.getenv("MODEL_FILE", "subset_best_model.pkl")
28
+ TOKEN = os.getenv("HF_TOKEN")
29
+
30
+ # subset features used by the model (normalized names)
31
+ SUBSET_FEATURES = [
32
+ "age",
33
+ "bmi",
34
+ "history_of_htn",
35
+ "history_infectious_cardiovascular_diseae",
36
+ "previos_obsteric_history_ab",
37
+ "fbs_first_trimester",
38
+ "hb",
39
+ "hct",
40
+ "cr",
41
+ "plt",
42
+ "vit_d3",
43
+ "sono_nt_nt",
44
+ "sono_nt_crl",
45
+ ]
46
+
47
+ # ---------- utils ----------
48
+ def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
49
+ out = df.copy()
50
+ out.columns = (
51
+ out.columns.str.strip()
52
+ .str.replace(r"[\s/\\\.\-]+", "_", regex=True)
53
+ .str.replace(r"__+", "_", regex=True)
54
+ .str.lower()
55
+ )
56
+ return out
57
+
58
+ def load_samples():
59
+ if not Path(SAMPLES_CSV).exists():
60
+ return None
61
+ df = pd.read_csv(SAMPLES_CSV)
62
+ df = normalize_cols(df)
63
+ needed = set(["id", TARGET_COL] + SUBSET_FEATURES)
64
+ if not needed.issubset(df.columns):
65
+ missing = needed - set(df.columns)
66
+ print(f"[WARN] samples file missing columns: {sorted(missing)}")
67
+ return None
68
+ df = df.reset_index(drop=False).rename(columns={"index": "_rid"}) # stable row id for dropdown
69
+ return df
70
+
71
+ def pretty_json(d):
72
+ return json.dumps(d, ensure_ascii=False, indent=2)
73
+
74
+ def as_bool(x, default=False):
75
+ if x is None or (isinstance(x, float) and pd.isna(x)):
76
+ return default
77
+ if isinstance(x, bool):
78
+ return x
79
+ if isinstance(x, (int,)):
80
+ return bool(x)
81
+ s = str(x).strip().lower()
82
+ yes = {"1","true","t","yes","y","on","pos","positive"}
83
+ no = {"0","false","f","no","n","off","neg","negative"}
84
+ if s in yes: return True
85
+ if s in no: return False
86
+ try:
87
+ return bool(int(float(s)))
88
+ except Exception:
89
+ return default
90
+
91
+ def f_or_none(v):
92
+ return float(v) if (v is not None and not (isinstance(v, float) and pd.isna(v))) else None
93
+
94
+ def build_row_dict(
95
+ age, bmi, ab_count,
96
+ htn, cvd,
97
+ fbs1, hb, hct, cr, plt, vitd3, sono_nt, sono_crl
98
+ ):
99
+ return {
100
+ "age": age,
101
+ "bmi": bmi,
102
+ "previos_obsteric_history_ab": ab_count,
103
+ "history_of_htn": 1 if htn else 0,
104
+ "history_infectious_cardiovascular_diseae": 1 if cvd else 0,
105
+ "fbs_first_trimester": fbs1,
106
+ "hb": hb,
107
+ "hct": hct,
108
+ "cr": cr,
109
+ "plt": plt,
110
+ "vit_d3": vitd3,
111
+ "sono_nt_nt": sono_nt,
112
+ "sono_nt_crl": sono_crl,
113
+ }
114
+
115
+ def _get_pos_index_and_classes(pipe, pos_label=1):
116
+ est = None
117
+ try:
118
+ est = getattr(pipe, "named_steps", {}).get("trained_model", None)
119
+ except Exception:
120
+ est = None
121
+ if est is None:
122
+ est = pipe
123
+ classes = getattr(est, "classes_", None)
124
+ if classes is not None and pos_label in list(classes):
125
+ return list(classes).index(pos_label), list(classes)
126
+ return -1, list(classes) if classes is not None else None
127
+
128
+ # ---------- model & samples ----------
129
+ local_path = hf_hub_download(repo_id=REPO, filename=FNAME, token=TOKEN)
130
+ model = load_model(local_path)
131
+ samples_df = load_samples()
132
+
133
+ # ---------- SHAP: background + explainer (built once) ----------
134
+ def _prepare_background(df_samples: pd.DataFrame | None, max_rows: int = 200) -> pd.DataFrame:
135
+ if df_samples is None:
136
+ # if no CSV, make a tiny synthetic background of zeros
137
+ bg = pd.DataFrame([{k: 0.0 for k in SUBSET_FEATURES} for _ in range(50)])
138
+ else:
139
+ bg = df_samples[SUBSET_FEATURES].copy()
140
+ # numeric coercion + median impute
141
+ for c in SUBSET_FEATURES:
142
+ if c not in bg.columns:
143
+ bg[c] = np.nan
144
+ bg = bg.apply(pd.to_numeric, errors="coerce")
145
+ bg = bg.fillna(bg.median(numeric_only=True))
146
+ if len(bg) > max_rows:
147
+ bg = bg.sample(max_rows, random_state=42)
148
+ return bg.reset_index(drop=True)
149
+
150
+ BACKGROUND = _prepare_background(samples_df)
151
+ POS_IDX, _ = _get_pos_index_and_classes(model, POS_LABEL)
152
+
153
+ def _f_proba_pos(X_np: np.ndarray) -> np.ndarray:
154
+ """Model function returning P(class==1) for SHAP. X_np is numpy; convert to DataFrame with right columns."""
155
+ X_df = pd.DataFrame(X_np, columns=SUBSET_FEATURES)
156
+ return model.predict_proba(X_df)[:, POS_IDX]
157
+
158
+ # SHAP Explainer (KernelExplainer via unified interface)
159
+ try:
160
+ EXPLAINER = shap.Explainer(_f_proba_pos, BACKGROUND.values)
161
+ except Exception as e:
162
+ print("[WARN] SHAP explainer init failed:", e)
163
+ EXPLAINER = None
164
+
165
+ def _plot_local_shap(row_dict: dict):
166
+ """Returns a matplotlib Figure with local SHAP bar chart for the given row."""
167
+ if EXPLAINER is None:
168
+ return None
169
+ X = pd.DataFrame([row_dict], columns=SUBSET_FEATURES)
170
+ exp = EXPLAINER(X.values) # exp.values shape: (1, n_features)
171
+ vals = exp.values[0]
172
+ order = np.argsort(np.abs(vals))
173
+ fig, ax = plt.subplots(figsize=(7, 4.5))
174
+ ax.barh(np.array(SUBSET_FEATURES)[order], vals[order])
175
+ ax.axvline(0, linewidth=1)
176
+ ax.set_title("Local SHAP values (current input)")
177
+ ax.set_xlabel("Impact on P(class==1)")
178
+ fig.tight_layout()
179
+ return fig
180
+
181
+ def _plot_global_shap():
182
+ """Returns a matplotlib Figure with global mean(|SHAP|) bar chart over BACKGROUND."""
183
+ if EXPLAINER is None:
184
+ return None
185
+ exp = EXPLAINER(BACKGROUND.values)
186
+ mean_abs = np.mean(np.abs(exp.values), axis=0)
187
+ order = np.argsort(mean_abs)
188
+ fig, ax = plt.subplots(figsize=(7, 4.5))
189
+ ax.barh(np.array(SUBSET_FEATURES)[order], mean_abs[order])
190
+ ax.set_title("Global feature importance (mean |SHAP|)")
191
+ ax.set_xlabel("Mean |impact on P(class==1)|")
192
+ fig.tight_layout()
193
+ return fig
194
+
195
+ GLOBAL_FIG = _plot_global_shap()
196
+
197
+ # ---------- prediction ----------
198
+ def predict_manual(
199
+ threshold,
200
+ age, bmi, ab_count,
201
+ htn, cvd,
202
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
203
+ ):
204
+ row = build_row_dict(
205
+ age, bmi, ab_count,
206
+ htn, cvd,
207
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
208
+ )
209
+ df = pd.DataFrame([row], columns=SUBSET_FEATURES)
210
+ proba = model.predict_proba(df)
211
+ p1 = float(proba[0][POS_IDX])
212
+ decision = 1 if p1 >= float(threshold) else 0
213
+ return int(decision), round(p1, 4), ("Positive" if decision==1 else "Negative"), pretty_json(row)
214
+
215
+ def explain_local(
216
+ age, bmi, ab_count,
217
+ htn, cvd,
218
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
219
+ ):
220
+ row = build_row_dict(
221
+ age, bmi, ab_count,
222
+ htn, cvd,
223
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
224
+ )
225
+ fig = _plot_local_shap(row)
226
+ return fig
227
+
228
+ def explain_global():
229
+ return GLOBAL_FIG
230
+
231
+ def filter_sample_options(filter_target):
232
+ if samples_df is None:
233
+ return gr.update(choices=[], value=None)
234
+ df = samples_df
235
+ if filter_target in ("0", "1"):
236
+ df = df[df[TARGET_COL] == int(filter_target)]
237
+ opts = [ (f"{int(r['_rid'])}: y={int(r[TARGET_COL])}", int(r["_rid"])) for _, r in df.iterrows() ]
238
+ return gr.update(choices=opts, value=(opts[0][1] if opts else None))
239
+
240
+ def load_sample(rid):
241
+ if samples_df is None or rid is None:
242
+ return [gr.update()]*13 + [gr.update(value="")]
243
+ r = samples_df.loc[samples_df["_rid"] == int(rid)]
244
+ if r.empty:
245
+ return [gr.update()]*13 + [gr.update(value="")]
246
+ r = r.iloc[0]
247
+
248
+ updates = [
249
+ gr.update(value=f_or_none(r.get("age"))),
250
+ gr.update(value=f_or_none(r.get("bmi"))),
251
+ gr.update(value=int(r.get("previos_obsteric_history_ab", 0)) if pd.notna(r.get("previos_obsteric_history_ab")) else 0),
252
+
253
+ gr.update(value=as_bool(r.get("history_of_htn"))),
254
+ gr.update(value=as_bool(r.get("history_infectious_cardiovascular_diseae"))),
255
+
256
+ gr.update(value=f_or_none(r.get("fbs_first_trimester"))),
257
+ gr.update(value=f_or_none(r.get("hb"))),
258
+ gr.update(value=f_or_none(r.get("hct"))),
259
+ gr.update(value=f_or_none(r.get("cr"))),
260
+ gr.update(value=f_or_none(r.get("plt"))),
261
+ gr.update(value=f_or_none(r.get("vit_d3"))),
262
+ gr.update(value=f_or_none(r.get("sono_nt_nt"))),
263
+ gr.update(value=f_or_none(r.get("sono_nt_crl"))),
264
+
265
+ gr.update(value=str(int(r.get(TARGET_COL))) if pd.notna(r.get(TARGET_COL)) else "")
266
+ ]
267
+ return updates
268
+
269
+ def compare_correctness(gt_text, decision_label):
270
+ if gt_text is None or gt_text == "":
271
+ return ""
272
+ try:
273
+ gt = int(float(gt_text))
274
+ except Exception:
275
+ return "—"
276
+ return "✅ Correct" if gt == int(decision_label) else "❌ Incorrect"
277
+
278
+ def get_feature_importance_text():
279
+ # Keep textual fallback if SHAP not available
280
+ est = None
281
+ try:
282
+ est = getattr(model, "named_steps", {}).get("trained_model", None)
283
+ except Exception:
284
+ est = None
285
+ if est is None:
286
+ est = model
287
+ fi = None
288
+ if hasattr(est, "feature_importances_"):
289
+ fi = list(est.feature_importances_)
290
+ elif hasattr(est, "coef_"):
291
+ coef = est.coef_
292
+ if coef is not None:
293
+ fi = list(coef.reshape(-1))
294
+ if not fi or len(fi) != len(SUBSET_FEATURES):
295
+ return "Not available for this model."
296
+ pairs = sorted(zip(SUBSET_FEATURES, fi), key=lambda x: abs(x[1]), reverse=True)
297
+ return "\n".join([f"- {k}: {v:.4f}" for k, v in pairs])
298
+
299
+ GLOBAL_FI_TEXT = get_feature_importance_text()
300
+
301
+ # ---------- theme ----------
302
+ theme = gr.themes.Soft(
303
+ primary_hue="violet",
304
+ neutral_hue="slate",
305
+ ).set(
306
+ body_background_fill_dark="#0b0f19",
307
+ block_border_width="1px"
308
+ )
309
+
310
+ # ---------- UI ----------
311
+ with gr.Blocks(theme=theme, title="GTT Classifier — Manual + Fixed Samples") as demo:
312
+ gr.Markdown("## GTT Prediction (Subset Features)\n**PyCaret pipeline · Auto-preprocessing · Thresholdable**")
313
+
314
+ with gr.Row():
315
+ # (1) Manual input
316
+ with gr.Column(scale=1):
317
+ gr.Markdown("### 1) Manual input")
318
+
319
+ age = gr.Number(label="Age (years)", value=0)
320
+ bmi = gr.Number(label="BMI", value=0)
321
+ ab_count = gr.Number(label="Previos Obsteric History of Abortion (count)", value=0, precision=0)
322
+
323
+ gr.Markdown("---\n**Clinical flags**")
324
+ htn = gr.Checkbox(label="History of Hypertension", value=False)
325
+ cvd = gr.Checkbox(label="History of Cardiovascular disease", value=False)
326
+
327
+ with gr.Accordion("More numeric features (optional)", open=False):
328
+ fbs1 = gr.Number(label="FBS of First trimester")
329
+ hb = gr.Number(label="HB")
330
+ hct = gr.Number(label="HCT")
331
+ cr = gr.Number(label="CR")
332
+ plt_v = gr.Number(label="PLT")
333
+ vitd3 = gr.Number(label="Vit D3")
334
+ sono_nt = gr.Number(label="Sonographic NT")
335
+ sono_crl = gr.Number(label="Sonographic CRL")
336
+
337
+ with gr.Row():
338
+ threshold = gr.Slider(0.05, 0.95, value=0.50, step=0.01, label="Decision threshold for class '1'")
339
+ reset_thr = gr.Button("↻", size="sm")
340
+
341
+ predict_btn = gr.Button("🚀 Predict (manual)", variant="primary")
342
+ explain_btn = gr.Button("🧠 Explain (SHAP for current input)")
343
+
344
+ # (2) Sample picker
345
+ with gr.Column(scale=1):
346
+ gr.Markdown("### 2) Sample picker (from fixed file)")
347
+ filt = gr.Dropdown(choices=["All", "0", "1"], value="All", label="Filter by target")
348
+ sample_dd = gr.Dropdown(choices=[], value=None, label="Choose sample row")
349
+ load_ok = gr.Button("Load sample into manual inputs", variant="secondary")
350
+
351
+ # (3) Results
352
+ with gr.Column(scale=1):
353
+ gr.Markdown("### 3) Results")
354
+
355
+ pred_label = gr.Number(label="Predicted label (with threshold decision)", interactive=False)
356
+ with gr.Row():
357
+ pred_prob = gr.Number(label="P(class==1)", value=0, interactive=False)
358
+ decision_text = gr.Textbox(label="Decision @ threshold", interactive=False)
359
+
360
+ gt_box = gr.Textbox(label="Ground truth (sample)", interactive=False)
361
+ correctness = gr.Textbox(label="Correct vs. ground truth?", interactive=False)
362
+
363
+ with gr.Accordion("Echoed input (row sent to model)", open=False):
364
+ echoed = gr.Code(label="", language="json")
365
+
366
+ with gr.Accordion("Global feature importance (SHAP)", open=False):
367
+ global_plot = gr.Plot(value=GLOBAL_FIG)
368
+ gr.Markdown("> Text fallback (native model importances):")
369
+ gr.Markdown(GLOBAL_FI_TEXT)
370
+
371
+ with gr.Accordion("Local explanation (SHAP) for current input", open=False):
372
+ local_plot = gr.Plot()
373
+
374
+ # events
375
+ demo.load(lambda: filter_sample_options("All"), inputs=None, outputs=[sample_dd], queue=False)
376
+ filt.change(filter_sample_options, inputs=[filt], outputs=[sample_dd])
377
+ reset_thr.click(fn=lambda: 0.5, inputs=None, outputs=[threshold])
378
+
379
+ load_ok.click(
380
+ fn=load_sample,
381
+ inputs=[sample_dd],
382
+ outputs=[
383
+ age, bmi, ab_count,
384
+ htn, cvd,
385
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl,
386
+ gt_box
387
+ ],
388
+ )
389
+
390
+ predict_btn.click(
391
+ fn=predict_manual,
392
+ inputs=[
393
+ threshold,
394
+ age, bmi, ab_count,
395
+ htn, cvd,
396
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
397
+ ],
398
+ outputs=[pred_label, pred_prob, decision_text, echoed],
399
+ ).then(
400
+ fn=compare_correctness,
401
+ inputs=[gt_box, pred_label],
402
+ outputs=[correctness]
403
+ )
404
+
405
+ explain_btn.click(
406
+ fn=explain_local,
407
+ inputs=[age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl],
408
+ outputs=[local_plot]
409
+ )
410
+
411
+ if __name__ == "__main__":
412
+ os.environ["NO_PROXY"] = "127.0.0.1,localhost"
413
+ os.environ["no_proxy"] = "127.0.0.1,localhost"
414
+ demo.launch()