techatcreated committed on
Commit
c929715
·
verified ·
1 Parent(s): 2bf0c7c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +638 -0
app.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Navya_Mrig.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/10xqPbYcTUoYEytn7C0HJoSNObUNmuCxZ
8
+ """
9
+
10
import html
import os
import pickle
import re
import tempfile

import gradio as gr
import joblib
import numpy as np
import pandas as pd
16
+
17
+ # =========================
18
+ # PATHS
19
+ # =========================
20
def _resolve_data_dir() -> str:
    """Return the directory holding the datasets and pickled models.

    Resolution order:
      1. the ``CI_DATA_DIR`` environment variable, when set;
      2. ``/content`` when that directory exists (keeps the original Colab
         paths unchanged);
      3. the directory containing this script.
    """
    override = os.environ.get("CI_DATA_DIR")
    if override:
        return override
    if os.path.isdir("/content"):
        return "/content"
    # ``__file__`` is missing in some interactive sessions; fall back to CWD.
    return os.path.dirname(os.path.abspath(globals().get("__file__", "app.py")))


_DATA_DIR = _resolve_data_dir()
VAL_CSV_PATH = os.path.join(_DATA_DIR, "validation_data.csv")
MAIN_CSV_PATH = os.path.join(_DATA_DIR, "Cochlear_Implant_Dataset.csv")
CLF_PKL_PATH = os.path.join(_DATA_DIR, "ci_success_classifier.pkl")
REG_PKL_PATH = os.path.join(_DATA_DIR, "ci_speech_score_regressor.pkl")
24
+
25
+ # =========================
26
+ # Load data + models
27
+ # =========================
28
# Both CSVs are read eagerly at import time. ``main_df`` is the primary source
# of Gene dropdown choices; ``val_df`` is only used as a fallback (for gene
# choices and for the input column list when the models expose no names).
val_df = pd.read_csv(VAL_CSV_PATH)
main_df = pd.read_csv(MAIN_CSV_PATH)
30
+
31
def load_model(path: str):
    """Deserialize a model from ``path``.

    ``joblib.load`` is tried first; if it raises anything, the file is opened
    in binary mode and read with ``pickle.load`` instead. An error from the
    pickle attempt propagates to the caller.

    NOTE: both loaders execute code embedded in the file, so ``path`` must
    point at a trusted artifact.
    """
    try:
        model = joblib.load(path)
    except Exception:
        with open(path, "rb") as handle:
            model = pickle.load(handle)
    return model
37
+
38
# Classifier (success label / probability) and regressor (speech score),
# loaded once at import time and shared by every request.
clf_model = load_model(CLF_PKL_PATH)
reg_model = load_model(REG_PKL_PATH)
40
+
41
def get_model_feature_names(m):
    """Return the feature names a fitted model expects, or ``None``.

    Looks for ``feature_names_in_`` on ``m`` itself first. Failing that, and
    if ``m`` exposes ``named_steps``, the first step carrying
    ``feature_names_in_`` supplies the names.
    """
    if hasattr(m, "feature_names_in_"):
        return list(m.feature_names_in_)
    if not hasattr(m, "named_steps"):
        return None
    first_with_names = next(
        (stage for stage in m.named_steps.values() if hasattr(stage, "feature_names_in_")),
        None,
    )
    if first_with_names is None:
        return None
    return list(first_with_names.feature_names_in_)
49
+
50
clf_expected = get_model_feature_names(clf_model) or []
reg_expected = get_model_feature_names(reg_model) or []

# Ordered union of both models' feature names (classifier first, duplicates
# dropped). When neither model exposes names, use the validation CSV columns.
input_cols = list(dict.fromkeys([*clf_expected, *reg_expected]))
if not input_cols:
    input_cols = list(val_df.columns)
61
+
62
+ # =========================
63
+ # Build Gene dropdown choices from MAIN dataset
64
+ # =========================
65
def find_gene_column(df: pd.DataFrame):
    """Locate the gene column of ``df``.

    An exact ``"Gene"`` column wins; otherwise the first column whose
    lower-cased name contains ``"gene"`` is returned. ``None`` if no match.
    """
    names = df.columns
    if "Gene" in names:
        return "Gene"
    return next((name for name in names if "gene" in name.lower()), None)
72
+
73
def normalize_str_series(s: pd.Series) -> pd.Series:
    """Stringify and strip ``s``, mapping placeholder tokens to NaN.

    After ``astype(str)`` and whitespace stripping, values exactly equal to
    one of ``null``, ``NULL``, ``None``, ``none``, the empty string, ``nan``
    or ``NaN`` are replaced by ``np.nan``.
    """
    placeholders = dict.fromkeys(
        ("null", "NULL", "None", "none", "", "nan", "NaN"),
        np.nan,
    )
    stripped = s.astype(str).str.strip()
    return stripped.replace(placeholders)
80
+
81
def _unique_genes(frame, column):
    """Sorted distinct non-missing values of ``frame[column]``; ``[]`` if no column."""
    if column is None:
        return []
    return sorted(set(normalize_str_series(frame[column]).dropna().tolist()))


# Dropdown options come from the main dataset; the validation dataset is only
# consulted when the main one yields nothing.
gene_col_main = find_gene_column(main_df)
gene_choices = _unique_genes(main_df, gene_col_main)
if not gene_choices:
    gene_col_val = find_gene_column(val_df)
    gene_choices = _unique_genes(val_df, gene_col_val)
89
+
90
+ # =========================
91
+ # Helpers
92
+ # =========================
93
def parse_age_to_years(age_raw: str, mode: str):
    """Convert a free-text age into a number of years.

    Args:
        age_raw: Age as entered, e.g. ``"1.11"``, ``"1.6YRS"``, ``"2yrs"``.
            ``None``, blank, and ``nan``/``none``/``null`` yield NaN.
        mode: Parsing mode label.
            - starts with ``"Decimal"``: the number is taken as decimal years
              (``1.11`` -> 1.11).
            - anything else ("Years.Months"): a value with exactly two digits
              after the point whose fraction is 00-11 is read as
              years.months (``1.11`` -> 1 + 11/12); every other shape falls
              back to decimal years.

    Returns:
        A float number of years, or ``np.nan`` when no single unambiguous
        number can be extracted.

    The text must contain exactly one numeric token. Previously all
    non-digit characters were deleted first, which glued separate numbers
    together ("2 years 3 months" -> 23.0, "3-4" -> 34.0); such inputs are now
    rejected as NaN instead of silently producing a wrong age.
    """
    if age_raw is None:
        return np.nan

    text = str(age_raw).strip()
    if not text or text.lower() in {"nan", "none", "null"}:
        return np.nan

    # ASCII digits only, optionally with a single decimal point.
    tokens = re.findall(r"[0-9]+(?:\.[0-9]*)?|\.[0-9]+", text)
    if len(tokens) != 1:
        return np.nan
    token = tokens[0]
    decimal_years = float(token)

    if mode.startswith("Decimal"):
        return decimal_years

    # Years.Months mode: only "<years>.<two-digit months 00-11>" qualifies.
    whole, point, fraction = token.partition(".")
    if point and whole and len(fraction) == 2:
        months = int(fraction)
        if months <= 11:
            return int(whole) + months / 12.0
    return decimal_years
133
+
134
def safe_pct(x):
    """Convert a 0-1 fraction to a whole-number percentage.

    Returns ``int(round(float(x) * 100))``, or ``None`` when ``x`` cannot be
    converted (non-numeric, ``None``, NaN or infinity). Only conversion errors
    are caught, so interrupts such as ``KeyboardInterrupt`` are no longer
    swallowed.
    """
    try:
        return int(round(float(x) * 100))
    except (TypeError, ValueError, OverflowError):
        return None
139
+
140
def get_gene_feature_name(cols):
    """Pick the gene feature from ``cols``.

    The first name equal to ``"gene"`` ignoring case is preferred; otherwise
    the first name containing ``"gene"`` ignoring case. ``None`` if neither
    exists.
    """
    exact = next((name for name in cols if name.lower() == "gene"), None)
    if exact is not None:
        return exact
    return next((name for name in cols if "gene" in name.lower()), None)
150
+
151
def get_age_feature_names(cols):
    """Return every name in ``cols`` whose lower-cased form contains ``"age"``.

    NOTE: this is a plain substring test, so names such as ``"Stage"`` or
    ``"Language"`` are matched as well.
    """
    matches = []
    for name in cols:
        if "age" in name.lower():
            matches.append(name)
    return matches
153
+
154
# Model feature that receives the selected gene (None when no feature name
# contains "gene") and the list of features that receive the parsed age.
GENE_FEAT = get_gene_feature_name(input_cols)
AGE_FEATS = get_age_feature_names(input_cols)
156
+
157
def align_to_expected(df: pd.DataFrame, expected_cols):
    """Return ``df`` restricted to and ordered by ``expected_cols``.

    Columns missing from ``df`` are added filled with NaN. The input frame is
    not modified. When ``expected_cols`` is empty, ``df`` itself is returned
    unchanged.
    """
    if not expected_cols:
        return df
    aligned = df.copy()
    absent = [name for name in expected_cols if name not in aligned.columns]
    for name in absent:
        aligned[name] = np.nan
    return aligned[expected_cols]
165
+
166
def render_single_result_html(gene, age_entered, age_used_years, parse_mode, label, prob, speech):
    """Build the HTML result card for a single prediction.

    Args:
        gene: Selected gene (shown as "—" when ``None``).
        age_entered: Raw age text typed by the user.
        age_used_years: Parsed age in years; non-numeric, ``None`` or
            non-finite values are shown as "—".
        parse_mode: Label of the age parsing mode that was applied.
        label: Predicted class; 1 -> "Likely Success", 0 -> "Lower
            Likelihood", anything else -> "Unavailable".
        prob: Probability of class 1 as a 0-1 fraction, or ``None``.
        speech: Predicted speech score; shown as "—" when not numeric.

    Returns:
        An HTML string.

    Every caller-supplied text value is HTML-escaped before interpolation:
    ``age_entered`` is free text from the UI and would otherwise be injected
    into the page verbatim.
    """
    if label == 1:
        status, badge, icon = "Likely Success", "ok", "✓"
    elif label == 0:
        status, badge, icon = "Lower Likelihood", "warn", "!"
    else:
        status, badge, icon = "Unavailable", "neutral", "?"

    prob_pct = safe_pct(prob) if prob is not None else None
    if prob_pct is None:
        prob_text, bar_width = "—", "0%"
    else:
        prob_text = f"{prob_pct}%"
        # Keep the CSS width valid even for an out-of-range probability.
        bar_width = f"{min(max(prob_pct, 0), 100)}%"

    try:
        speech_disp = f"{float(speech):.3f}"
    except (TypeError, ValueError):
        speech_disp = "—"

    # np.isfinite(None) raises TypeError, so coerce defensively first.
    try:
        age_value = float(age_used_years)
    except (TypeError, ValueError):
        age_value = float("nan")
    age_used_disp = f"{age_value:.3f} years" if np.isfinite(age_value) else "—"

    gene_disp = html.escape(str(gene)) if gene is not None else "—"
    age_entered_disp = html.escape(str(age_entered))
    parse_mode_disp = html.escape(str(parse_mode))
    label_disp = html.escape(str(label))

    return f"""
<div class="result-card">
  <div class="result-head">
    <div class="result-title">Prediction</div>
    <div class="pill {badge}">
      <span class="dot"></span>
      <span class="pill-ic">{icon}</span>
      <span>{status}</span>
    </div>
  </div>

  <div class="grid2">
    <div class="box">
      <div class="k">Gene</div>
      <div class="v mono">{gene_disp}</div>
    </div>
    <div class="box">
      <div class="k">Age entered</div>
      <div class="v mono">{age_entered_disp}</div>
    </div>
  </div>

  <div class="box" style="margin-top:12px;">
    <div class="k">Age used by model</div>
    <div class="v mono">{age_used_disp}</div>
    <div class="sub">Parsing mode: <span class="mono">{parse_mode_disp}</span></div>
  </div>

  <div class="box" style="margin-top:12px;">
    <div class="k">Success probability (Class 1)</div>
    <div class="prob-row">
      <div class="prob-bar"><div class="prob-fill" style="width:{bar_width};"></div></div>
      <div class="prob-txt mono">{prob_text}</div>
    </div>
  </div>

  <div class="grid2" style="margin-top:12px;">
    <div class="box">
      <div class="k">Predicted label</div>
      <div class="v mono">{label_disp}</div>
    </div>
    <div class="box">
      <div class="k">Predicted speech score</div>
      <div class="v mono">{speech_disp}</div>
    </div>
  </div>

  <div class="fine">
    Informational tool only. Not medical advice.
  </div>
</div>
"""
244
+
245
def predict_single(gene, age_text, parse_mode):
    """Run both models for one gene/age pair and return the result card HTML.

    Raises ``gr.Error`` when no gene is selected or when the age text does not
    parse to a finite number under ``parse_mode``.
    """
    if gene is None or not str(gene).strip():
        raise gr.Error("Please select a Gene.")

    age_used = parse_age_to_years(age_text, parse_mode)
    age_ok = isinstance(age_used, (float, np.floating)) and np.isfinite(age_used)
    if not age_ok:
        raise gr.Error("Please enter a valid Age (e.g., 1.6YRS, 1.11, 2.3).")

    def cell_value(feature):
        # Gene feature gets the gene, age features the parsed age, rest NaN.
        if GENE_FEAT and feature == GENE_FEAT:
            return gene
        return age_used if feature in AGE_FEATS else np.nan

    features = pd.DataFrame([{name: cell_value(name) for name in input_cols}])
    clf_input = align_to_expected(features, clf_expected)
    reg_input = align_to_expected(features, reg_expected)

    label = int(clf_model.predict(clf_input)[0])

    # Probability is the second column of predict_proba, when available.
    prob = None
    if hasattr(clf_model, "predict_proba"):
        class_probs = clf_model.predict_proba(clf_input)[0]
        if len(class_probs) >= 2:
            prob = float(class_probs[1])

    speech = reg_model.predict(reg_input)[0]
    return render_single_result_html(
        gene, age_text, age_used, parse_mode, label, prob, speech
    )
276
+
277
+ def _file_to_path(file_obj):
278
+ """Gradio File can be a string path, or have .name, or be dict-like depending on version."""
279
+ if file_obj is None:
280
+ return None
281
+ if isinstance(file_obj, str):
282
+ return file_obj
283
+ if hasattr(file_obj, "name"):
284
+ return file_obj.name
285
+ if isinstance(file_obj, dict) and "name" in file_obj:
286
+ return file_obj["name"]
287
+ return None
288
+
289
def predict_batch(csv_file, parse_mode):
    """Run both models over an uploaded CSV.

    Args:
        csv_file: Value delivered by the Gradio File input.
        parse_mode: Label of the age parsing mode to apply to the age column.

    Returns:
        ``(summary_html, preview_df, out_path)``: an HTML summary card, the
        first 20 rows of the input with prediction columns appended, and the
        path of the full output CSV.

    Raises:
        gr.Error: no file, unreadable or empty CSV, no gene column, no age
            column, or any age value that does not parse.

    Changes from the previous version:
      * The output CSV is written to a fresh temporary directory per call
        instead of a fixed file in the working directory, so concurrent
        users cannot overwrite each other's results. The basename stays
        ``predictions_output.csv``.
      * The age column is re-read as text before parsing. A numeric column is
        otherwise loaded as float, turning "1.10" into 1.1 and defeating the
        Years.Months interpretation.
      * CSV read failures are reported as ``gr.Error``; bare ``except``
        clauses are narrowed; ``parse_mode`` is HTML-escaped.
    """
    path = _file_to_path(csv_file)
    if not path:
        raise gr.Error("Please upload a CSV file.")

    try:
        df = pd.read_csv(path)
    except pd.errors.EmptyDataError as exc:
        raise gr.Error("Uploaded CSV is empty.") from exc
    except (pd.errors.ParserError, UnicodeDecodeError, OSError) as exc:
        raise gr.Error(f"Could not read the uploaded CSV: {exc}") from exc
    if df.empty:
        raise gr.Error("Uploaded CSV is empty.")

    df_cols_lower = {c.lower(): c for c in df.columns}

    # Gene column: case-insensitive match on the model's gene feature, else
    # the first column containing "gene".
    if GENE_FEAT and GENE_FEAT.lower() in df_cols_lower:
        gene_col = df_cols_lower[GENE_FEAT.lower()]
    else:
        gene_col = next((c for c in df.columns if "gene" in c.lower()), None)
    if gene_col is None:
        raise gr.Error("CSV must include a Gene column (e.g., 'Gene').")

    # Age column: the first column containing "age".
    age_source_col = next((c for c in df.columns if "age" in c.lower()), None)
    if age_source_col is None:
        raise gr.Error("CSV must include an Age column (e.g., 'Age').")

    # Parse ages from the raw text so trailing zeros ("1.10") survive.
    try:
        raw_age = pd.read_csv(path, usecols=[age_source_col], dtype=str)[age_source_col]
    except (ValueError, KeyError):
        raw_age = df[age_source_col]
    if len(raw_age) != len(df):
        raw_age = df[age_source_col]
    parsed_age = raw_age.apply(lambda v: parse_age_to_years(v, parse_mode))

    if parsed_age.isna().any():
        bad_n = int(parsed_age.isna().sum())
        raise gr.Error(f"{bad_n} rows have invalid Age values for the selected parsing mode.")

    # Build X in model column order; features absent from the CSV become NaN.
    X = pd.DataFrame(index=df.index)
    for col in input_cols:
        if GENE_FEAT and col == GENE_FEAT:
            X[col] = df[gene_col]
        elif col in AGE_FEATS:
            X[col] = parsed_age
        else:
            src = df_cols_lower.get(col.lower())
            X[col] = df[src] if src is not None else np.nan

    Xc = align_to_expected(X, clf_expected)
    Xr = align_to_expected(X, reg_expected)

    out = df.copy()
    out["success_label_pred"] = clf_model.predict(Xc)

    if hasattr(clf_model, "predict_proba"):
        proba = clf_model.predict_proba(Xc)
        if proba.shape[1] == 2:
            out["success_prob_class1"] = proba[:, 1]

    out["speech_score_pred"] = reg_model.predict(Xr)

    # One directory per call keeps the download name stable and private.
    out_dir = tempfile.mkdtemp(prefix="ci_predictions_")
    out_path = os.path.join(out_dir, "predictions_output.csv")
    out.to_csv(out_path, index=False)

    n = len(out)
    succ = int((out["success_label_pred"] == 1).sum())
    succ_pct = int(round((succ / n) * 100)) if n else 0

    avg_prob_txt = "—"
    if "success_prob_class1" in out.columns:
        try:
            avg_prob_txt = f"{int(round(float(out['success_prob_class1'].mean()) * 100))}%"
        except (TypeError, ValueError, OverflowError):
            pass  # leave the dash when the mean is not a usable number

    avg_speech_txt = "—"
    try:
        avg_speech_txt = f"{float(pd.to_numeric(out['speech_score_pred'], errors='coerce').mean()):.3f}"
    except (TypeError, ValueError):
        pass  # leave the dash when the mean is not a usable number

    parse_mode_disp = html.escape(str(parse_mode))

    summary = f"""
<div class="result-card">
  <div class="result-head">
    <div class="result-title">Batch Summary</div>
    <div class="pill neutral"><span class="dot"></span><span class="pill-ic">↯</span><span>{n} rows</span></div>
  </div>
  <div class="grid3">
    <div class="box"><div class="k">Predicted success</div><div class="v mono">{succ}</div></div>
    <div class="box"><div class="k">Predicted success (%)</div><div class="v mono">{succ_pct}%</div></div>
    <div class="box"><div class="k">Avg prob (Class 1)</div><div class="v mono">{avg_prob_txt}</div></div>
  </div>
  <div class="box" style="margin-top:12px;">
    <div class="k">Avg speech score</div><div class="v mono">{avg_speech_txt}</div>
    <div class="sub">Parsing mode: <span class="mono">{parse_mode_disp}</span></div>
  </div>
  <div class="fine">Download the output CSV below.</div>
</div>
"""
    return summary, out.head(20), out_path
393
+
394
def age_preview(age_text, parse_mode):
    """Return the hint HTML showing the age value the model will receive.

    Shows the parsed age to three decimals, or a dash when the text does not
    parse to a finite float under ``parse_mode``.
    """
    years = parse_age_to_years(age_text, parse_mode)
    usable = isinstance(years, (float, np.floating)) and np.isfinite(years)
    shown = f"<b>{years:.3f}</b> years" if usable else "—"
    return f"<div class='hint'>Model will use: <span class='mono'>{shown}</span></div>"
399
+
400
+ # =========================
401
+ # CSS: minimal, clean, mobile responsive + hide Gradio footer
402
+ # =========================
403
# Stylesheet passed to gr.Blocks(css=...). It defines the colour variables,
# the classes used by the result/summary HTML cards (result-card, grid2,
# grid3, box, pill, prob-bar, hint, ...), hides the Gradio footer, and
# collapses the grids to one column below 740px.
CSS = """
:root{
--bg:#f6f7fb;
--card:#ffffff;
--border:#e5e7eb;
--text:#0f172a;
--muted:#64748b;
--accent:#2563eb;
--ok:#16a34a;
--warn:#d97706;
--shadow: 0 10px 30px rgba(15, 23, 42, .08);
--radius: 16px;
}

.gradio-container{
background: var(--bg);
color: var(--text);
}

/* Hide Gradio footer / API bar */
footer, .footer, #footer, .gradio-footer { display:none !important; height:0 !important; }

/* Page wrapper */
#wrap{ max-width: 980px; margin: 0 auto; padding: 14px 12px 28px; }

/* Make Rows wrap on small screens */
.gr-row{ flex-wrap: wrap !important; gap: 12px !important; }
.gr-column{ min-width: 280px; }

/* Hero */
.hero{
padding: 16px 16px;
border-radius: var(--radius);
border: 1px solid var(--border);
background: linear-gradient(180deg, #ffffff, #fbfdff);
box-shadow: var(--shadow);
margin-bottom: 12px;
}
.hero h1{ margin:0; font-size: 18px; font-weight: 800; letter-spacing:.2px; }
.hero p{ margin:6px 0 0; color: var(--muted); font-size: 13px; line-height:1.35; }

/* Card wrapper for inputs/outputs */
.card{
background: var(--card);
border: 1px solid var(--border);
border-radius: var(--radius);
box-shadow: var(--shadow);
padding: 14px;
}

.mono{ font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; }

/* Results */
.result-card{
background: #ffffff;
border: 1px solid var(--border);
border-radius: var(--radius);
padding: 14px;
box-shadow: var(--shadow);
}
.result-head{ display:flex; align-items:center; justify-content:space-between; gap:10px; margin-bottom:12px; }
.result-title{ font-size: 13px; font-weight: 900; letter-spacing:.3px; }

.grid2{ display:grid; grid-template-columns: 1fr 1fr; gap: 10px; }
.grid3{ display:grid; grid-template-columns: 1fr 1fr 1fr; gap: 10px; }

.box{
border: 1px solid var(--border);
background: #fbfcff;
border-radius: 14px;
padding: 12px;
}
.k{ color: var(--muted); font-size: 12px; }
.v{ color: var(--text); font-size: 14px; font-weight: 800; margin-top: 3px; }
.sub{ margin-top:6px; color: var(--muted); font-size: 11px; }

.pill{
display:flex; align-items:center; gap:8px;
padding: 8px 10px;
border-radius: 999px;
border: 1px solid var(--border);
background: #ffffff;
font-size: 12px;
white-space: nowrap;
}
.pill .dot{ width:10px; height:10px; border-radius:999px; background: rgba(100,116,139,.25); }
.pill.ok{ border-color: rgba(22,163,74,.25); }
.pill.ok .dot{ background: var(--ok); }
.pill.warn{ border-color: rgba(217,119,6,.25); }
.pill.warn .dot{ background: var(--warn); }
.pill.neutral{ border-color: rgba(37,99,235,.20); }
.pill.neutral .dot{ background: var(--accent); }
.pill-ic{ font-weight: 900; }

.prob-row{ display:flex; align-items:center; gap: 10px; margin-top: 6px; }
.prob-bar{
flex: 1;
height: 10px;
border-radius: 999px;
background: #eef2ff;
border: 1px solid rgba(37,99,235,.15);
overflow: hidden;
}
.prob-fill{
height: 100%;
background: linear-gradient(90deg, rgba(37,99,235,.95), rgba(22,163,74,.85));
border-radius: 999px;
}
.prob-txt{ width: 56px; text-align:right; color: var(--text); font-weight: 900; }

.fine{
margin-top: 12px;
font-size: 11px;
color: var(--muted);
line-height: 1.35;
}

.hint{
margin-top: 6px;
font-size: 12px;
color: var(--muted);
padding: 8px 10px;
border: 1px dashed rgba(100,116,139,.35);
border-radius: 12px;
background: #ffffff;
}

/* Primary button styling + full width on mobile */
#primaryBtn button{
border-radius: 14px !important;
border: 1px solid rgba(37,99,235,.35) !important;
background: var(--accent) !important;
color: white !important;
font-weight: 900 !important;
}
@media (max-width: 740px){
#primaryBtn button{ width: 100% !important; }
.grid2{ grid-template-columns: 1fr; }
.grid3{ grid-template-columns: 1fr; }
.result-head{ flex-direction: column; align-items: flex-start; }
.gr-column{ min-width: 100%; }
}
"""
546
+
547
# Base Gradio theme: blue/emerald/slate hues, large radii, Inter with
# system sans-serif fallbacks. Passed to gr.Blocks(theme=...).
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="emerald",
    neutral_hue="slate",
    radius_size="lg",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)
554
+
555
+ # =========================
556
+ # UI
557
+ # =========================
558
# Two-tab interface: "Single Prediction" (gene dropdown + free-text age with a
# live parse preview) and "Batch Prediction (CSV)" (upload, summary card,
# 20-row preview, downloadable output file).
# NOTE(review): the nesting depth of the layout below is reconstructed from
# context — confirm against the original file.
with gr.Blocks(theme=theme, css=CSS, title="CI Outcome Predictor") as demo:
    with gr.Column(elem_id="wrap"):
        gr.HTML("""
<div class="hero">
<h1>CI Outcome Predictor</h1>
<p>Minimal UI for single and batch predictions. Gene options are loaded from the main dataset. Age parsing is shown transparently.</p>
</div>
""")

        with gr.Tabs():
            with gr.Tab("Single Prediction"):
                with gr.Row():
                    with gr.Column(scale=1):
                        with gr.Group(elem_classes=["card"]):
                            gene_in = gr.Dropdown(
                                choices=gene_choices,
                                value=gene_choices[0] if gene_choices else None,
                                label="Gene",
                                filterable=True,
                            )
                            age_in = gr.Textbox(
                                label="Age",
                                placeholder="Examples: 1.11 | 1.6YRS | 2.3"
                            )
                            parse_mode = gr.Radio(
                                choices=[
                                    "Decimal (1.11 = 1.11 years)",
                                    "Years.Months (1.11 = 1y 11m)"
                                ],
                                value="Decimal (1.11 = 1.11 years)",
                                label="Age parsing"
                            )

                            # Initial hint rendered for an empty age in the default mode.
                            age_hint = gr.HTML(value=age_preview("", "Decimal (1.11 = 1.11 years)"))

                            # NOTE(review): the batch tab's button reuses this same elem_id.
                            btn = gr.Button("Run Prediction", elem_id="primaryBtn")

                    with gr.Column(scale=1):
                        single_out = gr.HTML(value="", elem_classes=["card"])

                # Live preview of how age will be interpreted
                age_in.change(fn=age_preview, inputs=[age_in, parse_mode], outputs=[age_hint])
                parse_mode.change(fn=age_preview, inputs=[age_in, parse_mode], outputs=[age_hint])

                btn.click(
                    fn=predict_single,
                    inputs=[gene_in, age_in, parse_mode],
                    outputs=[single_out]
                )

            with gr.Tab("Batch Prediction (CSV)"):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown(
                        "**Minimum required columns:** `Gene`, `Age` \n"
                        f"**Model feature columns (auto-filled if missing):** `{len(input_cols)}` total",
                        elem_classes=["mono"]
                    )

                    # Separate mode selector so the batch tab does not depend on the single tab's state.
                    parse_mode_b = gr.Radio(
                        choices=[
                            "Decimal (1.11 = 1.11 years)",
                            "Years.Months (1.11 = 1y 11m)"
                        ],
                        value="Decimal (1.11 = 1.11 years)",
                        label="Age parsing"
                    )

                    csv_in = gr.File(file_types=[".csv"], label="Upload CSV")
                    run_b = gr.Button("Run Batch Prediction", elem_id="primaryBtn")

                batch_summary = gr.HTML(value="")
                preview = gr.Dataframe(label="Preview (first 20 rows)", wrap=True)
                out_file = gr.File(label="Download predictions_output.csv")

                # predict_batch returns (summary HTML, preview frame, output path) in this order.
                run_b.click(
                    fn=predict_batch,
                    inputs=[csv_in, parse_mode_b],
                    outputs=[batch_summary, preview, out_file]
                )

# Runs at import time (no __main__ guard) and requests a public share link.
demo.launch(share=True)