GDMProjects commited on
Commit
3fc7211
·
verified ·
1 Parent(s): 991eadf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -31
app.py CHANGED
@@ -1,4 +1,4 @@
1
-
2
  HOST, PORT, SHARE = "0.0.0.0", 7860, True
3
 
4
  # ---------- Env hygiene ----------
@@ -10,6 +10,11 @@ for _k in ("HTTP_PROXY","http_proxy","HTTPS_PROXY","https_proxy"):
10
  os.environ.setdefault("GRADIO_OPEN_BROWSER", "false")
11
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
12
 
 
 
 
 
 
13
  # ---------- Imports ----------
14
  from typing import Any, Dict, Optional, Tuple, List
15
  import re
@@ -17,16 +22,22 @@ import numpy as np
17
  import pandas as pd
18
  import gradio as gr
19
  from pathlib import Path
 
 
 
20
  from pycaret.classification import load_model, predict_model
21
  from huggingface_hub import hf_hub_download
22
- REPO = os.getenv("MODEL_REPO", "GDMProjects/my-private-model")
 
 
23
  FNAME = os.getenv("MODEL_FILE", "best_insulin_model.pkl")
24
  TOKEN = os.getenv("HF_TOKEN")
25
 
 
 
 
 
26
 
27
- SAMPLE_FILE = "INS.xlsx"
28
- TARGET_NAME = "insulin"
29
- POS_CLASS = 1
30
  FEATURES = [
31
  "age",
32
  "BMI",
@@ -40,13 +51,11 @@ FEATURES = [
40
  "Previos_Obsteric_History_AB",
41
  "infertility",
42
  ]
 
43
  NUMERIC_INPUTS = {"age", "BMI", "Previos_Obsteric_History_AB"}
44
- BOOL_FEATURES = [f for f in FEATURES if f not in NUMERIC_INPUTS] # 8 flags
45
 
46
  # ---------- Utilities ----------
47
- def strip_pkl(x: str) -> str:
48
- return x[:-4] if x.lower().endswith(".pkl") else x
49
-
50
  def normalize(s: str) -> str:
51
  return re.sub(r"[^a-z0-9]+", "", str(s).lower())
52
 
@@ -58,23 +67,25 @@ def coerce_numeric(val: Any) -> Optional[float]:
58
  def truthy(val: Any) -> bool:
59
  if pd.isna(val): return False
60
  s = str(val).strip().lower()
61
- return s in {"1","true","yes","y","t"} or val is True or val == 1
62
 
63
  def extract_probability_for_positive(preds: pd.DataFrame, positive_label=1) -> Optional[float]:
64
  str_pos = str(positive_label)
 
65
  if str_pos in preds.columns:
66
  return float(preds.iloc[0][str_pos])
67
  for c in preds.columns:
68
  if str_pos == str(c) or str(c).endswith("_"+str_pos):
69
  try: return float(preds.iloc[0][c])
70
  except: pass
71
- for cname in ("prediction_score","Score"):
72
  if cname in preds.columns:
73
  try: return float(preds.iloc[0][cname])
74
  except: pass
75
  return None
76
 
77
  def get_global_importance_table(model) -> Optional[pd.DataFrame]:
 
78
  try:
79
  if hasattr(model, "named_steps"):
80
  est = model.named_steps.get("trained_model", list(model.named_steps.values())[-1])
@@ -84,6 +95,7 @@ def get_global_importance_table(model) -> Optional[pd.DataFrame]:
84
  est = model
85
  except Exception:
86
  est = model
 
87
  X_cols = getattr(model, "feature_names_in_", None)
88
  if hasattr(est, "feature_importances_"):
89
  vals = np.asarray(est.feature_importances_)
@@ -92,6 +104,7 @@ def get_global_importance_table(model) -> Optional[pd.DataFrame]:
92
  else:
93
  df_imp = pd.DataFrame({"feature": [f"f{i}" for i in range(len(vals))], "importance": vals})
94
  return df_imp.sort_values("importance", ascending=False).reset_index(drop=True)
 
95
  if hasattr(est, "coef_"):
96
  coef = np.array(est.coef_)
97
  if coef.ndim > 1: coef = coef[0]
@@ -100,14 +113,35 @@ def get_global_importance_table(model) -> Optional[pd.DataFrame]:
100
  df_coef = pd.DataFrame({"feature": list(X_cols), "coefficient": coef})
101
  else:
102
  df_coef = pd.DataFrame({"feature": [f"f{i}" for i in range(len(coef))], "coefficient": coef})
103
- return df_coef.reindex(df_coef.iloc[:, -1].abs().sort_values(ascending=False).index).reset_index(drop=True)
 
 
104
  return None
105
 
106
- # ---------- Load model ----------
107
  local_path = hf_hub_download(repo_id=REPO, filename=FNAME, token=TOKEN)
108
  MODEL = load_model(str(Path(local_path).with_suffix("")))
109
 
110
- # ---------- Load fixed sample file ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def load_sample_dataframe(path: str) -> Tuple[pd.DataFrame, str]:
112
  if not os.path.exists(path):
113
  raise FileNotFoundError(f"Sample file not found: {path}")
@@ -139,13 +173,11 @@ def load_sample_dataframe(path: str) -> Tuple[pd.DataFrame, str]:
139
  try:
140
  SAMPLE_DF, SAMPLE_TARGET = load_sample_dataframe(SAMPLE_FILE)
141
  except Exception as e:
142
- # Fall back to empty DF but keep the app alive with a warning in UI
143
  SAMPLE_DF, SAMPLE_TARGET = pd.DataFrame(columns=FEATURES+[TARGET_NAME]), TARGET_NAME
144
  SAMPLE_ERROR = f"⚠️ Could not load sample file: {e}"
145
  else:
146
  SAMPLE_ERROR = ""
147
 
148
- # Build initial dropdown choices
149
  def build_sample_choices(df: pd.DataFrame, tgt: str, flt: str = "All") -> List[str]:
150
  if df.empty: return []
151
  if flt == "All":
@@ -155,6 +187,76 @@ def build_sample_choices(df: pd.DataFrame, tgt: str, flt: str = "All") -> List[s
155
  idxs = [i for i in range(len(df)) if str(df.iloc[i][tgt]) == str(want)]
156
  return [f"{i}: y={df.iloc[i][tgt]}" for i in idxs]
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  # ---------- Gradio UI ----------
159
  with gr.Blocks(theme=gr.themes.Soft(), css="""
160
  * { font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI; }
@@ -188,23 +290,27 @@ hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 8px 0 14px; }
188
  checkbox_map[feat] = gr.Checkbox(label=feat, value=False)
189
 
190
  gr.Markdown("<hr class='sep'/>")
191
- thr = gr.Slider(0.05, 0.95, value=0.50, step=0.01, label="Decision threshold for class '1'")
192
- run_btn = gr.Button("🚀 Predict (manual)", variant="primary")
 
 
193
 
194
  # -------- Sample picker (fixed file) --------
195
  gr.Markdown("<hr class='sep'/>")
196
  gr.Markdown("### 2) Sample picker (from fixed file)")
197
- grp_dd = gr.Dropdown(label="Filter by target", choices=["All","0","1"], value="All")
198
- choices0 = build_sample_choices(SAMPLE_DF, SAMPLE_TARGET, "All")
199
- sample_dd= gr.Dropdown(label="Choose sample row", choices=choices0, value=(choices0[0] if choices0 else None))
200
- pred_btn = gr.Button("🎯 Predict & compare (sample)", variant="primary")
 
 
201
 
202
  # -------- Right: Results --------
203
  with gr.Column(scale=1):
204
  gr.Markdown("### 3) Results")
205
  pred_label = gr.Textbox(label="Predicted label (with threshold decision)", interactive=False)
206
  with gr.Row():
207
- prob_out = gr.Number(label="P(class==1)", interactive=False, precision=6)
208
  decision = gr.Textbox(label="Decision @ threshold", interactive=False)
209
  with gr.Row():
210
  gt_out = gr.Textbox(label="Ground truth (sample)", interactive=False)
@@ -212,12 +318,14 @@ hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 8px 0 14px; }
212
  with gr.Accordion("Echoed input (row sent to model)", open=False):
213
  echoed = gr.Dataframe(wrap=True)
214
 
215
- GI = get_global_importance_table(MODEL)
216
- if GI is not None and not GI.empty:
217
- with gr.Accordion("Global feature importance / coefficients", open=False):
218
- gr.Dataframe(value=GI, interactive=False, wrap=True)
219
- else:
220
- gr.Markdown("<div class='card small'>No native importances/coefficients available for this estimator.</div>")
 
 
221
 
222
  # -------- Manual predict --------
223
  def do_predict_manual(age, bmi, prev_ab_cnt, threshold, *flag_values):
@@ -246,6 +354,23 @@ hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 8px 0 14px; }
246
  outputs=[pred_label, prob_out, decision, gt_out, match_out, echoed],
247
  )
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  # -------- Update sample choices on filter change --------
250
  def update_choices(group_value):
251
  ch = build_sample_choices(SAMPLE_DF, SAMPLE_TARGET, group_value)
@@ -253,6 +378,27 @@ hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 8px 0 14px; }
253
 
254
  grp_dd.change(update_choices, inputs=[grp_dd], outputs=[sample_dd])
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  # -------- Predict & compare for selected sample --------
257
  def predict_sample(sample_choice, threshold):
258
  if SAMPLE_DF.empty or sample_choice is None or str(sample_choice).strip() == "":
@@ -274,7 +420,6 @@ hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 8px 0 14px; }
274
  label = preds.iloc[0][label_col] if label_col else None
275
  p = extract_probability_for_positive(preds, positive_label=POS_CLASS)
276
 
277
- # Decision & compare
278
  if p is not None:
279
  dec = 1 if float(p) >= float(threshold) else 0
280
  pretty = f"{label} (threshold {threshold:.2f} ⇒ decision={dec})"
@@ -294,4 +439,4 @@ hr.sep { border: none; border-top: 1px solid #e5e7eb; margin: 8px 0 14px; }
294
 
295
  # ---------- Launch ----------
296
  if __name__ == "__main__":
297
- demo.launch()
 
1
+ # ---------- Host/port ----------
2
  HOST, PORT, SHARE = "0.0.0.0", 7860, True
3
 
4
  # ---------- Env hygiene ----------
 
10
  os.environ.setdefault("GRADIO_OPEN_BROWSER", "false")
11
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
12
 
13
+ # --- FORCE NON-INTERACTIVE MATPLOTLIB BACKEND (for SHAP plots) ---
14
+ os.environ["MPLBACKEND"] = "Agg"
15
+ import matplotlib
16
+ matplotlib.use("Agg", force=True)
17
+
18
  # ---------- Imports ----------
19
  from typing import Any, Dict, Optional, Tuple, List
20
  import re
 
22
  import pandas as pd
23
  import gradio as gr
24
  from pathlib import Path
25
+ import matplotlib.pyplot as plt
26
+ import shap
27
+
28
  from pycaret.classification import load_model, predict_model
29
  from huggingface_hub import hf_hub_download
30
+
31
+ # ---------- Hub model ----------
32
+ REPO = os.getenv("MODEL_REPO", "GDMProjects/my-private-model")
33
  FNAME = os.getenv("MODEL_FILE", "best_insulin_model.pkl")
34
  TOKEN = os.getenv("HF_TOKEN")
35
 
36
+ # ---------- Data / schema ----------
37
+ SAMPLE_FILE = "INS.xlsx"
38
+ TARGET_NAME = "insulin"
39
+ POS_CLASS = 1
40
 
 
 
 
41
  FEATURES = [
42
  "age",
43
  "BMI",
 
51
  "Previos_Obsteric_History_AB",
52
  "infertility",
53
  ]
54
+
55
  NUMERIC_INPUTS = {"age", "BMI", "Previos_Obsteric_History_AB"}
56
+ BOOL_FEATURES = [f for f in FEATURES if f not in NUMERIC_INPUTS] # flags
57
 
58
  # ---------- Utilities ----------
 
 
 
59
  def normalize(s: str) -> str:
60
  return re.sub(r"[^a-z0-9]+", "", str(s).lower())
61
 
 
67
  def truthy(val: Any) -> bool:
68
  if pd.isna(val): return False
69
  s = str(val).strip().lower()
70
+ return s in {"1","true","yes","y","t","on"} or val is True or val == 1
71
 
72
  def extract_probability_for_positive(preds: pd.DataFrame, positive_label=1) -> Optional[float]:
73
  str_pos = str(positive_label)
74
+ # PyCaret predict_model often outputs per-class columns named as labels
75
  if str_pos in preds.columns:
76
  return float(preds.iloc[0][str_pos])
77
  for c in preds.columns:
78
  if str_pos == str(c) or str(c).endswith("_"+str_pos):
79
  try: return float(preds.iloc[0][c])
80
  except: pass
81
+ for cname in ("prediction_score","Score","score"):
82
  if cname in preds.columns:
83
  try: return float(preds.iloc[0][cname])
84
  except: pass
85
  return None
86
 
87
  def get_global_importance_table(model) -> Optional[pd.DataFrame]:
88
+ """Fallback (non-SHAP) importances/coefficients from the final estimator."""
89
  try:
90
  if hasattr(model, "named_steps"):
91
  est = model.named_steps.get("trained_model", list(model.named_steps.values())[-1])
 
95
  est = model
96
  except Exception:
97
  est = model
98
+
99
  X_cols = getattr(model, "feature_names_in_", None)
100
  if hasattr(est, "feature_importances_"):
101
  vals = np.asarray(est.feature_importances_)
 
104
  else:
105
  df_imp = pd.DataFrame({"feature": [f"f{i}" for i in range(len(vals))], "importance": vals})
106
  return df_imp.sort_values("importance", ascending=False).reset_index(drop=True)
107
+
108
  if hasattr(est, "coef_"):
109
  coef = np.array(est.coef_)
110
  if coef.ndim > 1: coef = coef[0]
 
113
  df_coef = pd.DataFrame({"feature": list(X_cols), "coefficient": coef})
114
  else:
115
  df_coef = pd.DataFrame({"feature": [f"f{i}" for i in range(len(coef))], "coefficient": coef})
116
+ order = df_coef.iloc[:, -1].abs().sort_values(ascending=False).index
117
+ return df_coef.reindex(order).reset_index(drop=True)
118
+
119
  return None
120
 
121
+ # ---------- Load model (strip .pkl because PyCaret appends) ----------
122
  local_path = hf_hub_download(repo_id=REPO, filename=FNAME, token=TOKEN)
123
  MODEL = load_model(str(Path(local_path).with_suffix("")))
124
 
125
+ # ---------- Helpers to find positive-class index for predict_proba ----------
126
+ def _get_pos_index_and_classes(pipe, pos_label=1):
127
+ est = None
128
+ try:
129
+ est = getattr(pipe, "named_steps", {}).get("trained_model", None)
130
+ except Exception:
131
+ est = None
132
+ if est is None:
133
+ est = pipe
134
+ classes = getattr(est, "classes_", None)
135
+ if classes is not None and pos_label in list(classes):
136
+ return list(classes).index(pos_label), list(classes)
137
+ # fallback: assume last column is positive if 2-class
138
+ if classes is not None and len(classes) == 2:
139
+ return 1, list(classes)
140
+ return -1, list(classes) if classes is not None else None
141
+
142
+ POS_IDX, _CLASSES = _get_pos_index_and_classes(MODEL, POS_CLASS)
143
+
144
+ # ---------- Load fixed sample file (+ normalizer) ----------
145
  def load_sample_dataframe(path: str) -> Tuple[pd.DataFrame, str]:
146
  if not os.path.exists(path):
147
  raise FileNotFoundError(f"Sample file not found: {path}")
 
173
  try:
174
  SAMPLE_DF, SAMPLE_TARGET = load_sample_dataframe(SAMPLE_FILE)
175
  except Exception as e:
 
176
  SAMPLE_DF, SAMPLE_TARGET = pd.DataFrame(columns=FEATURES+[TARGET_NAME]), TARGET_NAME
177
  SAMPLE_ERROR = f"⚠️ Could not load sample file: {e}"
178
  else:
179
  SAMPLE_ERROR = ""
180
 
 
181
  def build_sample_choices(df: pd.DataFrame, tgt: str, flt: str = "All") -> List[str]:
182
  if df.empty: return []
183
  if flt == "All":
 
187
  idxs = [i for i in range(len(df)) if str(df.iloc[i][tgt]) == str(want)]
188
  return [f"{i}: y={df.iloc[i][tgt]}" for i in idxs]
189
 
190
+ # ---------- SHAP background / explainer ----------
191
+ def _prepare_background(df_samples: pd.DataFrame | None, max_rows: int = 200) -> pd.DataFrame:
192
+ if df_samples is None or df_samples.empty:
193
+ # tiny synthetic background of zeros
194
+ bg = pd.DataFrame([{k: 0.0 for k in FEATURES} for _ in range(50)])
195
+ else:
196
+ bg = df_samples[FEATURES].copy()
197
+ # numeric coercion + boolean to {0,1} + median impute
198
+ for c in FEATURES:
199
+ if c not in bg.columns:
200
+ bg[c] = np.nan
201
+ for c in FEATURES:
202
+ if c in NUMERIC_INPUTS:
203
+ bg[c] = pd.to_numeric(bg[c], errors="coerce")
204
+ else:
205
+ bg[c] = bg[c].apply(lambda v: 1.0 if truthy(v) else 0.0)
206
+ bg = bg.fillna(bg.median(numeric_only=True))
207
+ if len(bg) > max_rows:
208
+ bg = bg.sample(max_rows, random_state=42)
209
+ return bg.reset_index(drop=True)
210
+
211
+ BACKGROUND = _prepare_background(SAMPLE_DF)
212
+
213
+ def _f_proba_pos(X_np: np.ndarray) -> np.ndarray:
214
+ X_df = pd.DataFrame(X_np, columns=FEATURES)
215
+ proba = MODEL.predict_proba(X_df)
216
+ if POS_IDX >= 0 and POS_IDX < proba.shape[1]:
217
+ return proba[:, POS_IDX]
218
+ # fallback: try class "1" if present
219
+ if proba.shape[1] >= 2:
220
+ return proba[:, 1]
221
+ return proba[:, 0]
222
+
223
+ try:
224
+ EXPLAINER = shap.Explainer(_f_proba_pos, BACKGROUND.values)
225
+ except Exception as e:
226
+ print("[WARN] SHAP explainer init failed:", e)
227
+ EXPLAINER = None
228
+
229
+ def _plot_local_shap(row_dict: dict):
230
+ if EXPLAINER is None:
231
+ return None
232
+ X = pd.DataFrame([row_dict], columns=FEATURES)
233
+ exp = EXPLAINER(X.values) # (1, n_features)
234
+ vals = exp.values[0]
235
+ order = np.argsort(np.abs(vals))
236
+ fig, ax = plt.subplots(figsize=(7, 4.5))
237
+ ax.barh(np.array(FEATURES)[order], vals[order])
238
+ ax.axvline(0, linewidth=1)
239
+ ax.set_title("Local SHAP values (current input)")
240
+ ax.set_xlabel(f"Impact on P(class=={POS_CLASS})")
241
+ fig.tight_layout()
242
+ return fig
243
+
244
+ def _plot_global_shap():
245
+ if EXPLAINER is None:
246
+ return None
247
+ exp = EXPLAINER(BACKGROUND.values)
248
+ mean_abs = np.mean(np.abs(exp.values), axis=0)
249
+ order = np.argsort(mean_abs)
250
+ fig, ax = plt.subplots(figsize=(7, 4.5))
251
+ ax.barh(np.array(FEATURES)[order], mean_abs[order])
252
+ ax.set_title("Global feature importance (mean |SHAP|)")
253
+ ax.set_xlabel(f"Mean |impact on P(class=={POS_CLASS})|")
254
+ fig.tight_layout()
255
+ return fig
256
+
257
+ GLOBAL_FIG = _plot_global_shap()
258
+ GLOBAL_FI_TEXT = (get_global_importance_table(MODEL) or pd.DataFrame())
259
+
260
  # ---------- Gradio UI ----------
261
  with gr.Blocks(theme=gr.themes.Soft(), css="""
262
  * { font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI; }
 
290
  checkbox_map[feat] = gr.Checkbox(label=feat, value=False)
291
 
292
  gr.Markdown("<hr class='sep'/>")
293
+ thr = gr.Slider(0.05, 0.95, value=0.50, step=0.01, label=f"Decision threshold for class '{POS_CLASS}'")
294
+ with gr.Row():
295
+ run_btn = gr.Button("🚀 Predict (manual)", variant="primary")
296
+ explain_btn = gr.Button("🧠 Explain (SHAP for current input)")
297
 
298
  # -------- Sample picker (fixed file) --------
299
  gr.Markdown("<hr class='sep'/>")
300
  gr.Markdown("### 2) Sample picker (from fixed file)")
301
+ grp_dd = gr.Dropdown(label="Filter by target", choices=["All","0","1"], value="All")
302
+ choices0 = build_sample_choices(SAMPLE_DF, SAMPLE_TARGET, "All")
303
+ sample_dd = gr.Dropdown(label="Choose sample row", choices=choices0, value=(choices0[0] if choices0 else None))
304
+ with gr.Row():
305
+ load_btn = gr.Button("📥 Load sample into manual inputs", variant="secondary")
306
+ pred_btn = gr.Button("🎯 Predict & compare (sample)", variant="primary")
307
 
308
  # -------- Right: Results --------
309
  with gr.Column(scale=1):
310
  gr.Markdown("### 3) Results")
311
  pred_label = gr.Textbox(label="Predicted label (with threshold decision)", interactive=False)
312
  with gr.Row():
313
+ prob_out = gr.Number(label=f"P(class=={POS_CLASS})", interactive=False, precision=6)
314
  decision = gr.Textbox(label="Decision @ threshold", interactive=False)
315
  with gr.Row():
316
  gt_out = gr.Textbox(label="Ground truth (sample)", interactive=False)
 
318
  with gr.Accordion("Echoed input (row sent to model)", open=False):
319
  echoed = gr.Dataframe(wrap=True)
320
 
321
+ with gr.Accordion("Global feature importance (SHAP)", open=False):
322
+ gr.Plot(value=GLOBAL_FIG)
323
+ if isinstance(GLOBAL_FI_TEXT, pd.DataFrame) and not GLOBAL_FI_TEXT.empty:
324
+ gr.Markdown("> Text fallback (native model importances/coefficients):")
325
+ gr.Dataframe(value=GLOBAL_FI_TEXT, interactive=False, wrap=True)
326
+
327
+ with gr.Accordion("Local explanation (SHAP) for current input", open=False):
328
+ local_plot = gr.Plot()
329
 
330
  # -------- Manual predict --------
331
  def do_predict_manual(age, bmi, prev_ab_cnt, threshold, *flag_values):
 
354
  outputs=[pred_label, prob_out, decision, gt_out, match_out, echoed],
355
  )
356
 
357
+ # -------- Local SHAP for current manual input --------
358
+ def do_explain_local(age, bmi, prev_ab_cnt, *flag_values):
359
+ row = {c: None for c in FEATURES}
360
+ row["age"] = coerce_numeric(age)
361
+ row["BMI"] = coerce_numeric(bmi)
362
+ row["Previos_Obsteric_History_AB"] = coerce_numeric(prev_ab_cnt)
363
+ for feat, val in zip(BOOL_FEATURES, flag_values):
364
+ row[feat] = 1.0 if bool(val) else 0.0
365
+ fig = _plot_local_shap(row)
366
+ return fig
367
+
368
+ explain_btn.click(
369
+ do_explain_local,
370
+ inputs=[age_in, bmi_in, prev_ab] + [checkbox_map[f] for f in BOOL_FEATURES],
371
+ outputs=[local_plot],
372
+ )
373
+
374
  # -------- Update sample choices on filter change --------
375
  def update_choices(group_value):
376
  ch = build_sample_choices(SAMPLE_DF, SAMPLE_TARGET, group_value)
 
378
 
379
  grp_dd.change(update_choices, inputs=[grp_dd], outputs=[sample_dd])
380
 
381
+ # -------- Load selected sample INTO manual inputs --------
382
+ def load_into_manual(sample_choice):
383
+ if SAMPLE_DF.empty or sample_choice is None or str(sample_choice).strip() == "":
384
+ raise gr.Error("Sample file is empty or no row selected. Check SAMPLE_FILE path.")
385
+ idx = int(str(sample_choice).split(":")[0])
386
+ srow = SAMPLE_DF.iloc[idx]
387
+
388
+ updates = [
389
+ gr.update(value=coerce_numeric(srow["age"])),
390
+ gr.update(value=coerce_numeric(srow["BMI"])),
391
+ gr.update(value=coerce_numeric(srow["Previos_Obsteric_History_AB"])),
392
+ ]
393
+ for feat in BOOL_FEATURES:
394
+ updates.append(gr.update(value=bool(truthy(srow[feat]))))
395
+ # also surface ground truth to the Results panel
396
+ updates.append(gr.update(value=str(srow[SAMPLE_TARGET])))
397
+ return updates
398
+
399
+ load_into_outputs = [age_in, bmi_in, prev_ab] + [checkbox_map[f] for f in BOOL_FEATURES] + [gt_out]
400
+ load_btn.click(load_into_manual, inputs=[sample_dd], outputs=load_into_outputs)
401
+
402
  # -------- Predict & compare for selected sample --------
403
  def predict_sample(sample_choice, threshold):
404
  if SAMPLE_DF.empty or sample_choice is None or str(sample_choice).strip() == "":
 
420
  label = preds.iloc[0][label_col] if label_col else None
421
  p = extract_probability_for_positive(preds, positive_label=POS_CLASS)
422
 
 
423
  if p is not None:
424
  dec = 1 if float(p) >= float(threshold) else 0
425
  pretty = f"{label} (threshold {threshold:.2f} ⇒ decision={dec})"
 
439
 
440
  # ---------- Launch ----------
441
  if __name__ == "__main__":
442
+ demo.launch(server_name=HOST, server_port=PORT, share=SHARE)