GDMProjects commited on
Commit
2a03d03
·
verified ·
1 Parent(s): fe39161

Upload 6 files

Browse files
Files changed (6) hide show
  1. .dockerignore +4 -0
  2. Dockerfile +15 -0
  3. GTT.csv +10 -0
  4. app.py +409 -0
  5. requirements.txt +5 -0
  6. subset_best_model.pkl +3 -0
.dockerignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ subset_best_model.pkl
2
+ GTT.csv
3
+ app.py
4
+ requirements.txt
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+ WORKDIR /app
3
+
4
+ # Copy your app and data into the image
5
+ COPY app.py .
6
+ COPY requirements.txt .
7
+ COPY subset_best_model.pkl .
8
+ COPY data.csv .
9
+
10
+ RUN pip install --no-cache-dir -r requirements.txt
11
+
12
+ EXPOSE 7860
13
+ ENV GRADIO_SERVER_NAME=0.0.0.0
14
+
15
+ CMD ["python", "app.py"]
GTT.csv ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ID,GTT,age,BMI,history_of_htn,history_infectious_cardiovascular_diseae,Previos_Obsteric_History_AB,FBS_first_trimester,HB,CR,Hct,PLT,VIT D3,sono_nt.crl,sono_nt.nt
2
+ 3505,1,37,22,0,0,0,88,13.4,0.6,39.9,278,36,56,1.6
3
+ 3530,1,33,28,0,0,0,96,14,0,41.2,187,9,54,1.4
4
+ 4057,0,33,26,0,0,2,110,12,0.7,35.9,333,26,48.3,1.1
5
+ 4491,0,27,25,0,0,3,84,13.6,0.7,40.3,204,13,69,1.9
6
+ 4707,0,39,27,0,0,1,71,14.9,0.6,44,335,14,64,1
7
+ 4813,0,37,22,0,0,1,88,13.2,0,37.9,150,39,54.3,1
8
+ 5098,0,36,25,0,0,4,91,12.9,1.8,38,288,16,55,1.3
9
+ 5314,1,41,35,1,0,0,98,10.8,0.9,34.5,398,21,45.2,3.2
10
+ 5767,1,37,22,0,0,0,101,14.5,1,42,300,33,62.2,1
app.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # pip install "pycaret>=3.3,<4" gradio pandas shap matplotlib
3
+
4
+ # --- FORCE NON-INTERACTIVE MATPLOTLIB BACKEND (must be first!) ---
5
+ import os
6
+ os.environ["MPLBACKEND"] = "Agg" # prevents Tk backend init
7
+ import matplotlib
8
+ matplotlib.use("Agg", force=True)
9
+
10
+ import json
11
+ import numpy as np
12
+ import pandas as pd
13
+ import gradio as gr
14
+ import matplotlib.pyplot as plt
15
+ import shap
16
+
17
+ from pathlib import Path
18
+ from pycaret.classification import load_model
19
+
20
+ # --- config ---
21
+ MODEL_BASENAME = "subset_best_model"
22
+ SAMPLES_CSV = "GTT.csv" # fixed hidden file
23
+ TARGET_COL = "gtt"
24
+ POS_LABEL = 1
25
+
26
+ # subset features used by the model (normalized names)
27
+ SUBSET_FEATURES = [
28
+ "age",
29
+ "bmi",
30
+ "history_of_htn",
31
+ "history_infectious_cardiovascular_diseae",
32
+ "previos_obsteric_history_ab",
33
+ "fbs_first_trimester",
34
+ "hb",
35
+ "hct",
36
+ "cr",
37
+ "plt",
38
+ "vit_d3",
39
+ "sono_nt_nt",
40
+ "sono_nt_crl",
41
+ ]
42
+
43
+ # ---------- utils ----------
44
+ def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
45
+ out = df.copy()
46
+ out.columns = (
47
+ out.columns.str.strip()
48
+ .str.replace(r"[\s/\\\.\-]+", "_", regex=True)
49
+ .str.replace(r"__+", "_", regex=True)
50
+ .str.lower()
51
+ )
52
+ return out
53
+
54
+ def load_samples():
55
+ if not Path(SAMPLES_CSV).exists():
56
+ return None
57
+ df = pd.read_csv(SAMPLES_CSV)
58
+ df = normalize_cols(df)
59
+ needed = set(["id", TARGET_COL] + SUBSET_FEATURES)
60
+ if not needed.issubset(df.columns):
61
+ missing = needed - set(df.columns)
62
+ print(f"[WARN] samples file missing columns: {sorted(missing)}")
63
+ return None
64
+ df = df.reset_index(drop=False).rename(columns={"index": "_rid"}) # stable row id for dropdown
65
+ return df
66
+
67
+ def pretty_json(d):
68
+ return json.dumps(d, ensure_ascii=False, indent=2)
69
+
70
+ def as_bool(x, default=False):
71
+ if x is None or (isinstance(x, float) and pd.isna(x)):
72
+ return default
73
+ if isinstance(x, bool):
74
+ return x
75
+ if isinstance(x, (int,)):
76
+ return bool(x)
77
+ s = str(x).strip().lower()
78
+ yes = {"1","true","t","yes","y","on","pos","positive"}
79
+ no = {"0","false","f","no","n","off","neg","negative"}
80
+ if s in yes: return True
81
+ if s in no: return False
82
+ try:
83
+ return bool(int(float(s)))
84
+ except Exception:
85
+ return default
86
+
87
+ def f_or_none(v):
88
+ return float(v) if (v is not None and not (isinstance(v, float) and pd.isna(v))) else None
89
+
90
+ def build_row_dict(
91
+ age, bmi, ab_count,
92
+ htn, cvd,
93
+ fbs1, hb, hct, cr, plt, vitd3, sono_nt, sono_crl
94
+ ):
95
+ return {
96
+ "age": age,
97
+ "bmi": bmi,
98
+ "previos_obsteric_history_ab": ab_count,
99
+ "history_of_htn": 1 if htn else 0,
100
+ "history_infectious_cardiovascular_diseae": 1 if cvd else 0,
101
+ "fbs_first_trimester": fbs1,
102
+ "hb": hb,
103
+ "hct": hct,
104
+ "cr": cr,
105
+ "plt": plt,
106
+ "vit_d3": vitd3,
107
+ "sono_nt_nt": sono_nt,
108
+ "sono_nt_crl": sono_crl,
109
+ }
110
+
111
+ def _get_pos_index_and_classes(pipe, pos_label=1):
112
+ est = None
113
+ try:
114
+ est = getattr(pipe, "named_steps", {}).get("trained_model", None)
115
+ except Exception:
116
+ est = None
117
+ if est is None:
118
+ est = pipe
119
+ classes = getattr(est, "classes_", None)
120
+ if classes is not None and pos_label in list(classes):
121
+ return list(classes).index(pos_label), list(classes)
122
+ return -1, list(classes) if classes is not None else None
123
+
124
+ # ---------- model & samples ----------
125
+ model = load_model(MODEL_BASENAME)
126
+ samples_df = load_samples()
127
+
128
+ # ---------- SHAP: background + explainer (built once) ----------
129
+ def _prepare_background(df_samples: pd.DataFrame | None, max_rows: int = 200) -> pd.DataFrame:
130
+ if df_samples is None:
131
+ # if no CSV, make a tiny synthetic background of zeros
132
+ bg = pd.DataFrame([{k: 0.0 for k in SUBSET_FEATURES} for _ in range(50)])
133
+ else:
134
+ bg = df_samples[SUBSET_FEATURES].copy()
135
+ # numeric coercion + median impute
136
+ for c in SUBSET_FEATURES:
137
+ if c not in bg.columns:
138
+ bg[c] = np.nan
139
+ bg = bg.apply(pd.to_numeric, errors="coerce")
140
+ bg = bg.fillna(bg.median(numeric_only=True))
141
+ if len(bg) > max_rows:
142
+ bg = bg.sample(max_rows, random_state=42)
143
+ return bg.reset_index(drop=True)
144
+
145
+ BACKGROUND = _prepare_background(samples_df)
146
+ POS_IDX, _ = _get_pos_index_and_classes(model, POS_LABEL)
147
+
148
+ def _f_proba_pos(X_np: np.ndarray) -> np.ndarray:
149
+ """Model function returning P(class==1) for SHAP. X_np is numpy; convert to DataFrame with right columns."""
150
+ X_df = pd.DataFrame(X_np, columns=SUBSET_FEATURES)
151
+ return model.predict_proba(X_df)[:, POS_IDX]
152
+
153
+ # SHAP Explainer (KernelExplainer via unified interface)
154
+ try:
155
+ EXPLAINER = shap.Explainer(_f_proba_pos, BACKGROUND.values)
156
+ except Exception as e:
157
+ print("[WARN] SHAP explainer init failed:", e)
158
+ EXPLAINER = None
159
+
160
+ def _plot_local_shap(row_dict: dict):
161
+ """Returns a matplotlib Figure with local SHAP bar chart for the given row."""
162
+ if EXPLAINER is None:
163
+ return None
164
+ X = pd.DataFrame([row_dict], columns=SUBSET_FEATURES)
165
+ exp = EXPLAINER(X.values) # exp.values shape: (1, n_features)
166
+ vals = exp.values[0]
167
+ order = np.argsort(np.abs(vals))
168
+ fig, ax = plt.subplots(figsize=(7, 4.5))
169
+ ax.barh(np.array(SUBSET_FEATURES)[order], vals[order])
170
+ ax.axvline(0, linewidth=1)
171
+ ax.set_title("Local SHAP values (current input)")
172
+ ax.set_xlabel("Impact on P(class==1)")
173
+ fig.tight_layout()
174
+ return fig
175
+
176
+ def _plot_global_shap():
177
+ """Returns a matplotlib Figure with global mean(|SHAP|) bar chart over BACKGROUND."""
178
+ if EXPLAINER is None:
179
+ return None
180
+ exp = EXPLAINER(BACKGROUND.values)
181
+ mean_abs = np.mean(np.abs(exp.values), axis=0)
182
+ order = np.argsort(mean_abs)
183
+ fig, ax = plt.subplots(figsize=(7, 4.5))
184
+ ax.barh(np.array(SUBSET_FEATURES)[order], mean_abs[order])
185
+ ax.set_title("Global feature importance (mean |SHAP|)")
186
+ ax.set_xlabel("Mean |impact on P(class==1)|")
187
+ fig.tight_layout()
188
+ return fig
189
+
190
+ GLOBAL_FIG = _plot_global_shap()
191
+
192
+ # ---------- prediction ----------
193
+ def predict_manual(
194
+ threshold,
195
+ age, bmi, ab_count,
196
+ htn, cvd,
197
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
198
+ ):
199
+ row = build_row_dict(
200
+ age, bmi, ab_count,
201
+ htn, cvd,
202
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
203
+ )
204
+ df = pd.DataFrame([row], columns=SUBSET_FEATURES)
205
+ proba = model.predict_proba(df)
206
+ p1 = float(proba[0][POS_IDX])
207
+ decision = 1 if p1 >= float(threshold) else 0
208
+ return int(decision), round(p1, 4), ("Positive" if decision==1 else "Negative"), pretty_json(row)
209
+
210
+ def explain_local(
211
+ age, bmi, ab_count,
212
+ htn, cvd,
213
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
214
+ ):
215
+ row = build_row_dict(
216
+ age, bmi, ab_count,
217
+ htn, cvd,
218
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
219
+ )
220
+ fig = _plot_local_shap(row)
221
+ return fig
222
+
223
+ def explain_global():
224
+ return GLOBAL_FIG
225
+
226
+ def filter_sample_options(filter_target):
227
+ if samples_df is None:
228
+ return gr.update(choices=[], value=None)
229
+ df = samples_df
230
+ if filter_target in ("0", "1"):
231
+ df = df[df[TARGET_COL] == int(filter_target)]
232
+ opts = [ (f"{int(r['_rid'])}: y={int(r[TARGET_COL])}", int(r["_rid"])) for _, r in df.iterrows() ]
233
+ return gr.update(choices=opts, value=(opts[0][1] if opts else None))
234
+
235
+ def load_sample(rid):
236
+ if samples_df is None or rid is None:
237
+ return [gr.update()]*13 + [gr.update(value="")]
238
+ r = samples_df.loc[samples_df["_rid"] == int(rid)]
239
+ if r.empty:
240
+ return [gr.update()]*13 + [gr.update(value="")]
241
+ r = r.iloc[0]
242
+
243
+ updates = [
244
+ gr.update(value=f_or_none(r.get("age"))),
245
+ gr.update(value=f_or_none(r.get("bmi"))),
246
+ gr.update(value=int(r.get("previos_obsteric_history_ab", 0)) if pd.notna(r.get("previos_obsteric_history_ab")) else 0),
247
+
248
+ gr.update(value=as_bool(r.get("history_of_htn"))),
249
+ gr.update(value=as_bool(r.get("history_infectious_cardiovascular_diseae"))),
250
+
251
+ gr.update(value=f_or_none(r.get("fbs_first_trimester"))),
252
+ gr.update(value=f_or_none(r.get("hb"))),
253
+ gr.update(value=f_or_none(r.get("hct"))),
254
+ gr.update(value=f_or_none(r.get("cr"))),
255
+ gr.update(value=f_or_none(r.get("plt"))),
256
+ gr.update(value=f_or_none(r.get("vit_d3"))),
257
+ gr.update(value=f_or_none(r.get("sono_nt_nt"))),
258
+ gr.update(value=f_or_none(r.get("sono_nt_crl"))),
259
+
260
+ gr.update(value=str(int(r.get(TARGET_COL))) if pd.notna(r.get(TARGET_COL)) else "")
261
+ ]
262
+ return updates
263
+
264
+ def compare_correctness(gt_text, decision_label):
265
+ if gt_text is None or gt_text == "":
266
+ return "—"
267
+ try:
268
+ gt = int(float(gt_text))
269
+ except Exception:
270
+ return "—"
271
+ return "✅ Correct" if gt == int(decision_label) else "❌ Incorrect"
272
+
273
+ def get_feature_importance_text():
274
+ # Keep textual fallback if SHAP not available
275
+ est = None
276
+ try:
277
+ est = getattr(model, "named_steps", {}).get("trained_model", None)
278
+ except Exception:
279
+ est = None
280
+ if est is None:
281
+ est = model
282
+ fi = None
283
+ if hasattr(est, "feature_importances_"):
284
+ fi = list(est.feature_importances_)
285
+ elif hasattr(est, "coef_"):
286
+ coef = est.coef_
287
+ if coef is not None:
288
+ fi = list(coef.reshape(-1))
289
+ if not fi or len(fi) != len(SUBSET_FEATURES):
290
+ return "Not available for this model."
291
+ pairs = sorted(zip(SUBSET_FEATURES, fi), key=lambda x: abs(x[1]), reverse=True)
292
+ return "\n".join([f"- {k}: {v:.4f}" for k, v in pairs])
293
+
294
+ GLOBAL_FI_TEXT = get_feature_importance_text()
295
+
296
+ # ---------- theme ----------
297
+ theme = gr.themes.Soft(
298
+ primary_hue="violet",
299
+ neutral_hue="slate",
300
+ ).set(
301
+ body_background_fill_dark="#0b0f19",
302
+ block_border_width="1px"
303
+ )
304
+
305
+ # ---------- UI ----------
306
+ with gr.Blocks(theme=theme, title="GTT Classifier — Manual + Fixed Samples") as demo:
307
+ gr.Markdown("## GTT Prediction (Subset Features)\n**PyCaret pipeline · Auto-preprocessing · Thresholdable**")
308
+
309
+ with gr.Row():
310
+ # (1) Manual input
311
+ with gr.Column(scale=1):
312
+ gr.Markdown("### 1) Manual input")
313
+
314
+ age = gr.Number(label="Age (years)", value=0)
315
+ bmi = gr.Number(label="BMI", value=0)
316
+ ab_count = gr.Number(label="Previos Obsteric History of Abortion (count)", value=0, precision=0)
317
+
318
+ gr.Markdown("---\n**Clinical flags**")
319
+ htn = gr.Checkbox(label="History of Hypertension", value=False)
320
+ cvd = gr.Checkbox(label="History of Cardiovascular disease", value=False)
321
+
322
+ with gr.Accordion("More numeric features (optional)", open=False):
323
+ fbs1 = gr.Number(label="FBS of First trimester")
324
+ hb = gr.Number(label="HB")
325
+ hct = gr.Number(label="HCT")
326
+ cr = gr.Number(label="CR")
327
+ plt_v = gr.Number(label="PLT")
328
+ vitd3 = gr.Number(label="Vit D3")
329
+ sono_nt = gr.Number(label="Sonographic NT")
330
+ sono_crl = gr.Number(label="Sonographic CRL")
331
+
332
+ with gr.Row():
333
+ threshold = gr.Slider(0.05, 0.95, value=0.50, step=0.01, label="Decision threshold for class '1'")
334
+ reset_thr = gr.Button("↻", size="sm")
335
+
336
+ predict_btn = gr.Button("🚀 Predict (manual)", variant="primary")
337
+ explain_btn = gr.Button("🧠 Explain (SHAP for current input)")
338
+
339
+ # (2) Sample picker
340
+ with gr.Column(scale=1):
341
+ gr.Markdown("### 2) Sample picker (from fixed file)")
342
+ filt = gr.Dropdown(choices=["All", "0", "1"], value="All", label="Filter by target")
343
+ sample_dd = gr.Dropdown(choices=[], value=None, label="Choose sample row")
344
+ load_ok = gr.Button("Load sample into manual inputs", variant="secondary")
345
+
346
+ # (3) Results
347
+ with gr.Column(scale=1):
348
+ gr.Markdown("### 3) Results")
349
+
350
+ pred_label = gr.Number(label="Predicted label (with threshold decision)", interactive=False)
351
+ with gr.Row():
352
+ pred_prob = gr.Number(label="P(class==1)", value=0, interactive=False)
353
+ decision_text = gr.Textbox(label="Decision @ threshold", interactive=False)
354
+
355
+ gt_box = gr.Textbox(label="Ground truth (sample)", interactive=False)
356
+ correctness = gr.Textbox(label="Correct vs. ground truth?", interactive=False)
357
+
358
+ with gr.Accordion("Echoed input (row sent to model)", open=False):
359
+ echoed = gr.Code(label="", language="json")
360
+
361
+ with gr.Accordion("Global feature importance (SHAP)", open=False):
362
+ global_plot = gr.Plot(value=GLOBAL_FIG)
363
+ gr.Markdown("> Text fallback (native model importances):")
364
+ gr.Markdown(GLOBAL_FI_TEXT)
365
+
366
+ with gr.Accordion("Local explanation (SHAP) for current input", open=False):
367
+ local_plot = gr.Plot()
368
+
369
+ # events
370
+ demo.load(lambda: filter_sample_options("All"), inputs=None, outputs=[sample_dd], queue=False)
371
+ filt.change(filter_sample_options, inputs=[filt], outputs=[sample_dd])
372
+ reset_thr.click(fn=lambda: 0.5, inputs=None, outputs=[threshold])
373
+
374
+ load_ok.click(
375
+ fn=load_sample,
376
+ inputs=[sample_dd],
377
+ outputs=[
378
+ age, bmi, ab_count,
379
+ htn, cvd,
380
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl,
381
+ gt_box
382
+ ],
383
+ )
384
+
385
+ predict_btn.click(
386
+ fn=predict_manual,
387
+ inputs=[
388
+ threshold,
389
+ age, bmi, ab_count,
390
+ htn, cvd,
391
+ fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl
392
+ ],
393
+ outputs=[pred_label, pred_prob, decision_text, echoed],
394
+ ).then(
395
+ fn=compare_correctness,
396
+ inputs=[gt_box, pred_label],
397
+ outputs=[correctness]
398
+ )
399
+
400
+ explain_btn.click(
401
+ fn=explain_local,
402
+ inputs=[age, bmi, ab_count, htn, cvd, fbs1, hb, hct, cr, plt_v, vitd3, sono_nt, sono_crl],
403
+ outputs=[local_plot]
404
+ )
405
+
406
+ if __name__ == "__main__":
407
+ os.environ["NO_PROXY"] = "127.0.0.1,localhost"
408
+ os.environ["no_proxy"] = "127.0.0.1,localhost"
409
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pycaret>=3.3,<4
2
+ gradio
3
+ pandas
4
+ shap
5
+ matplotlib
subset_best_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b87c8c09e49f9423392d1c4da3b319759820003428bb66bfe74f3155a18b82dd
3
+ size 149680