techatcreated committed on
Commit
d0c250a
·
verified ·
1 Parent(s): c929715

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -285
app.py CHANGED
@@ -1,43 +1,80 @@
1
- # -*- coding: utf-8 -*-
2
- """Navya_Mrig.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/10xqPbYcTUoYEytn7C0HJoSNObUNmuCxZ
8
- """
9
-
10
  import re
11
  import pickle
12
  import joblib
13
  import numpy as np
14
  import pandas as pd
15
  import gradio as gr
 
16
 
17
  # =========================
18
- # PATHS
19
  # =========================
20
- VAL_CSV_PATH = "/content/validation_data.csv"
21
- MAIN_CSV_PATH = "/content/Cochlear_Implant_Dataset.csv"
22
- CLF_PKL_PATH = "/content/ci_success_classifier.pkl"
23
- REG_PKL_PATH = "/content/ci_speech_score_regressor.pkl"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # =========================
26
- # Load data + models
27
  # =========================
28
- val_df = pd.read_csv(VAL_CSV_PATH)
29
- main_df = pd.read_csv(MAIN_CSV_PATH)
30
 
31
- def load_model(path: str):
32
- try:
33
- return joblib.load(path)
34
- except Exception:
35
- with open(path, "rb") as f:
36
- return pickle.load(f)
37
 
38
- clf_model = load_model(CLF_PKL_PATH)
39
- reg_model = load_model(REG_PKL_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
 
41
  def get_model_feature_names(m):
42
  if hasattr(m, "feature_names_in_"):
43
  return list(getattr(m, "feature_names_in_"))
@@ -47,17 +84,25 @@ def get_model_feature_names(m):
47
  return list(step.feature_names_in_)
48
  return None
49
 
50
- clf_expected = get_model_feature_names(clf_model) or []
51
- reg_expected = get_model_feature_names(reg_model) or []
52
-
53
- # Union of expected columns (preserve order)
54
- input_cols = []
55
- for colset in [clf_expected, reg_expected]:
56
- for c in colset:
57
- if c not in input_cols:
58
- input_cols.append(c)
59
- if not input_cols:
60
- input_cols = list(val_df.columns)
 
 
 
 
 
 
 
 
61
 
62
  # =========================
63
  # Build Gene dropdown choices from MAIN dataset
@@ -78,25 +123,20 @@ def normalize_str_series(s: pd.Series) -> pd.Series:
78
  "": np.nan, "nan": np.nan, "NaN": np.nan})
79
  )
80
 
81
- gene_col_main = find_gene_column(main_df)
82
  gene_choices = []
83
- if gene_col_main is not None:
84
- gene_choices = sorted(set(normalize_str_series(main_df[gene_col_main]).dropna().tolist()))
85
- if not gene_choices:
86
- gene_col_val = find_gene_column(val_df)
87
- if gene_col_val is not None:
88
- gene_choices = sorted(set(normalize_str_series(val_df[gene_col_val]).dropna().tolist()))
 
 
89
 
90
  # =========================
91
  # Helpers
92
  # =========================
93
  def parse_age_to_years(age_raw: str, mode: str):
94
- """
95
- mode:
96
- - "Years.Months (1.11 = 1y 11m)" -> 1 + 11/12
97
- - "Decimal (1.11 = 1.11 years)" -> 1.11
98
- Accepts "1.6YRS", "2yrs", etc.
99
- """
100
  if age_raw is None:
101
  return np.nan
102
 
@@ -112,7 +152,6 @@ def parse_age_to_years(age_raw: str, mode: str):
112
  except:
113
  return np.nan
114
 
115
- # Years.Months mode
116
  if cleaned.count(".") == 1:
117
  a, b = cleaned.split(".")
118
  if a.isdigit() and b.isdigit() and len(b) == 2:
@@ -120,7 +159,6 @@ def parse_age_to_years(age_raw: str, mode: str):
120
  months = int(b)
121
  if 0 <= months <= 11:
122
  return years + months / 12.0
123
- # fallback to decimal
124
  try:
125
  return float(cleaned)
126
  except:
@@ -138,11 +176,9 @@ def safe_pct(x):
138
  return None
139
 
140
  def get_gene_feature_name(cols):
141
- # Prefer exact "Gene"
142
  for c in cols:
143
  if c.lower() == "gene":
144
  return c
145
- # Fallback: any column containing 'gene'
146
  for c in cols:
147
  if "gene" in c.lower():
148
  return c
@@ -151,8 +187,8 @@ def get_gene_feature_name(cols):
151
  def get_age_feature_names(cols):
152
  return [c for c in cols if "age" in c.lower()]
153
 
154
- GENE_FEAT = get_gene_feature_name(input_cols)
155
- AGE_FEATS = get_age_feature_names(input_cols)
156
 
157
  def align_to_expected(df: pd.DataFrame, expected_cols):
158
  if not expected_cols:
@@ -243,6 +279,9 @@ def render_single_result_html(gene, age_entered, age_used_years, parse_mode, lab
243
  """
244
 
245
  def predict_single(gene, age_text, parse_mode):
 
 
 
246
  if gene is None or str(gene).strip() == "":
247
  raise gr.Error("Please select a Gene.")
248
 
@@ -250,7 +289,6 @@ def predict_single(gene, age_text, parse_mode):
250
  if not (isinstance(age_used, (float, np.floating)) and np.isfinite(age_used)):
251
  raise gr.Error("Please enter a valid Age (e.g., 1.6YRS, 1.11, 2.3).")
252
 
253
- # Build model input row using known feature names; fill others with NaN
254
  row = {}
255
  for c in input_cols:
256
  if GENE_FEAT and c == GENE_FEAT:
@@ -275,7 +313,6 @@ def predict_single(gene, age_text, parse_mode):
275
  return render_single_result_html(gene, age_text, age_used, parse_mode, label, prob, speech)
276
 
277
  def _file_to_path(file_obj):
278
- """Gradio File can be a string path, or have .name, or be dict-like depending on version."""
279
  if file_obj is None:
280
  return None
281
  if isinstance(file_obj, str):
@@ -287,6 +324,9 @@ def _file_to_path(file_obj):
287
  return None
288
 
289
  def predict_batch(csv_file, parse_mode):
 
 
 
290
  path = _file_to_path(csv_file)
291
  if not path:
292
  raise gr.Error("Please upload a CSV file.")
@@ -297,13 +337,10 @@ def predict_batch(csv_file, parse_mode):
297
 
298
  df_cols_lower = {c.lower(): c for c in df.columns}
299
 
300
- # Require at least Gene + one Age column (case-insensitive)
301
- # Gene
302
  gene_col = None
303
  if GENE_FEAT and GENE_FEAT.lower() in df_cols_lower:
304
  gene_col = df_cols_lower[GENE_FEAT.lower()]
305
  else:
306
- # fallback: any column containing 'gene'
307
  for c in df.columns:
308
  if "gene" in c.lower():
309
  gene_col = c
@@ -311,7 +348,6 @@ def predict_batch(csv_file, parse_mode):
311
  if gene_col is None:
312
  raise gr.Error("CSV must include a Gene column (e.g., 'Gene').")
313
 
314
- # Age (at least one)
315
  age_source_col = None
316
  for c in df.columns:
317
  if "age" in c.lower():
@@ -320,7 +356,6 @@ def predict_batch(csv_file, parse_mode):
320
  if age_source_col is None:
321
  raise gr.Error("CSV must include an Age column (e.g., 'Age').")
322
 
323
- # Build X in the exact model input_cols order; fill missing optional cols with NaN
324
  X = pd.DataFrame(index=df.index)
325
  parsed_age = df[age_source_col].apply(lambda v: parse_age_to_years(v, parse_mode))
326
 
@@ -334,7 +369,6 @@ def predict_batch(csv_file, parse_mode):
334
  elif col in AGE_FEATS:
335
  X[col] = parsed_age
336
  else:
337
- # try case-insensitive exact match; else NaN
338
  src = df_cols_lower.get(col.lower())
339
  X[col] = df[src] if src is not None else np.nan
340
 
@@ -351,8 +385,7 @@ def predict_batch(csv_file, parse_mode):
351
 
352
  out["speech_score_pred"] = reg_model.predict(Xr)
353
 
354
- out_path = "predictions_output.csv"
355
- out.to_csv(out_path, index=False)
356
 
357
  n = len(out)
358
  succ = int((out["success_label_pred"] == 1).sum())
@@ -389,7 +422,7 @@ def predict_batch(csv_file, parse_mode):
389
  <div class="fine">Download the output CSV below.</div>
390
  </div>
391
  """
392
- return summary, out.head(20), out_path
393
 
394
  def age_preview(age_text, parse_mode):
395
  v = parse_age_to_years(age_text, parse_mode)
@@ -398,151 +431,11 @@ def age_preview(age_text, parse_mode):
398
  return "<div class='hint'>Model will use: <span class='mono'>—</span></div>"
399
 
400
  # =========================
401
- # CSS: minimal, clean, mobile responsive + hide Gradio footer
402
  # =========================
403
- CSS = """
404
- :root{
405
- --bg:#f6f7fb;
406
- --card:#ffffff;
407
- --border:#e5e7eb;
408
- --text:#0f172a;
409
- --muted:#64748b;
410
- --accent:#2563eb;
411
- --ok:#16a34a;
412
- --warn:#d97706;
413
- --shadow: 0 10px 30px rgba(15, 23, 42, .08);
414
- --radius: 16px;
415
- }
416
-
417
- .gradio-container{
418
- background: var(--bg);
419
- color: var(--text);
420
- }
421
-
422
- /* Hide Gradio footer / API bar */
423
- footer, .footer, #footer, .gradio-footer { display:none !important; height:0 !important; }
424
-
425
- /* Page wrapper */
426
- #wrap{ max-width: 980px; margin: 0 auto; padding: 14px 12px 28px; }
427
-
428
- /* Make Rows wrap on small screens */
429
- .gr-row{ flex-wrap: wrap !important; gap: 12px !important; }
430
- .gr-column{ min-width: 280px; }
431
-
432
- /* Hero */
433
- .hero{
434
- padding: 16px 16px;
435
- border-radius: var(--radius);
436
- border: 1px solid var(--border);
437
- background: linear-gradient(180deg, #ffffff, #fbfdff);
438
- box-shadow: var(--shadow);
439
- margin-bottom: 12px;
440
- }
441
- .hero h1{ margin:0; font-size: 18px; font-weight: 800; letter-spacing:.2px; }
442
- .hero p{ margin:6px 0 0; color: var(--muted); font-size: 13px; line-height:1.35; }
443
-
444
- /* Card wrapper for inputs/outputs */
445
- .card{
446
- background: var(--card);
447
- border: 1px solid var(--border);
448
- border-radius: var(--radius);
449
- box-shadow: var(--shadow);
450
- padding: 14px;
451
- }
452
-
453
- .mono{ font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; }
454
-
455
- /* Results */
456
- .result-card{
457
- background: #ffffff;
458
- border: 1px solid var(--border);
459
- border-radius: var(--radius);
460
- padding: 14px;
461
- box-shadow: var(--shadow);
462
- }
463
- .result-head{ display:flex; align-items:center; justify-content:space-between; gap:10px; margin-bottom:12px; }
464
- .result-title{ font-size: 13px; font-weight: 900; letter-spacing:.3px; }
465
-
466
- .grid2{ display:grid; grid-template-columns: 1fr 1fr; gap: 10px; }
467
- .grid3{ display:grid; grid-template-columns: 1fr 1fr 1fr; gap: 10px; }
468
-
469
- .box{
470
- border: 1px solid var(--border);
471
- background: #fbfcff;
472
- border-radius: 14px;
473
- padding: 12px;
474
- }
475
- .k{ color: var(--muted); font-size: 12px; }
476
- .v{ color: var(--text); font-size: 14px; font-weight: 800; margin-top: 3px; }
477
- .sub{ margin-top:6px; color: var(--muted); font-size: 11px; }
478
-
479
- .pill{
480
- display:flex; align-items:center; gap:8px;
481
- padding: 8px 10px;
482
- border-radius: 999px;
483
- border: 1px solid var(--border);
484
- background: #ffffff;
485
- font-size: 12px;
486
- white-space: nowrap;
487
- }
488
- .pill .dot{ width:10px; height:10px; border-radius:999px; background: rgba(100,116,139,.25); }
489
- .pill.ok{ border-color: rgba(22,163,74,.25); }
490
- .pill.ok .dot{ background: var(--ok); }
491
- .pill.warn{ border-color: rgba(217,119,6,.25); }
492
- .pill.warn .dot{ background: var(--warn); }
493
- .pill.neutral{ border-color: rgba(37,99,235,.20); }
494
- .pill.neutral .dot{ background: var(--accent); }
495
- .pill-ic{ font-weight: 900; }
496
-
497
- .prob-row{ display:flex; align-items:center; gap: 10px; margin-top: 6px; }
498
- .prob-bar{
499
- flex: 1;
500
- height: 10px;
501
- border-radius: 999px;
502
- background: #eef2ff;
503
- border: 1px solid rgba(37,99,235,.15);
504
- overflow: hidden;
505
- }
506
- .prob-fill{
507
- height: 100%;
508
- background: linear-gradient(90deg, rgba(37,99,235,.95), rgba(22,163,74,.85));
509
- border-radius: 999px;
510
- }
511
- .prob-txt{ width: 56px; text-align:right; color: var(--text); font-weight: 900; }
512
-
513
- .fine{
514
- margin-top: 12px;
515
- font-size: 11px;
516
- color: var(--muted);
517
- line-height: 1.35;
518
- }
519
-
520
- .hint{
521
- margin-top: 6px;
522
- font-size: 12px;
523
- color: var(--muted);
524
- padding: 8px 10px;
525
- border: 1px dashed rgba(100,116,139,.35);
526
- border-radius: 12px;
527
- background: #ffffff;
528
- }
529
-
530
- /* Primary button styling + full width on mobile */
531
- #primaryBtn button{
532
- border-radius: 14px !important;
533
- border: 1px solid rgba(37,99,235,.35) !important;
534
- background: var(--accent) !important;
535
- color: white !important;
536
- font-weight: 900 !important;
537
- }
538
- @media (max-width: 740px){
539
- #primaryBtn button{ width: 100% !important; }
540
- .grid2{ grid-template-columns: 1fr; }
541
- .grid3{ grid-template-columns: 1fr; }
542
- .result-head{ flex-direction: column; align-items: flex-start; }
543
- .gr-column{ min-width: 100%; }
544
- }
545
- """
546
 
547
  theme = gr.themes.Base(
548
  primary_hue="blue",
@@ -557,82 +450,83 @@ theme = gr.themes.Base(
557
  # =========================
558
  with gr.Blocks(theme=theme, css=CSS, title="CI Outcome Predictor") as demo:
559
  with gr.Column(elem_id="wrap"):
560
- gr.HTML("""
561
- <div class="hero">
562
- <h1>CI Outcome Predictor</h1>
563
- <p>Minimal UI for single and batch predictions. Gene options are loaded from the main dataset. Age parsing is shown transparently.</p>
564
- </div>
565
- """)
566
-
567
- with gr.Tabs():
568
- with gr.Tab("Single Prediction"):
569
- with gr.Row():
570
- with gr.Column(scale=1):
571
- with gr.Group(elem_classes=["card"]):
572
- gene_in = gr.Dropdown(
573
- choices=gene_choices,
574
- value=gene_choices[0] if gene_choices else None,
575
- label="Gene",
576
- filterable=True,
577
- )
578
- age_in = gr.Textbox(
579
- label="Age",
580
- placeholder="Examples: 1.11 | 1.6YRS | 2.3"
581
- )
582
- parse_mode = gr.Radio(
583
- choices=[
584
- "Decimal (1.11 = 1.11 years)",
585
- "Years.Months (1.11 = 1y 11m)"
586
- ],
587
- value="Decimal (1.11 = 1.11 years)",
588
- label="Age parsing"
589
- )
590
-
591
- age_hint = gr.HTML(value=age_preview("", "Decimal (1.11 = 1.11 years)"))
592
-
593
- btn = gr.Button("Run Prediction", elem_id="primaryBtn")
594
-
595
- with gr.Column(scale=1):
596
- single_out = gr.HTML(value="", elem_classes=["card"])
597
-
598
- # Live preview of how age will be interpreted
599
- age_in.change(fn=age_preview, inputs=[age_in, parse_mode], outputs=[age_hint])
600
- parse_mode.change(fn=age_preview, inputs=[age_in, parse_mode], outputs=[age_hint])
601
-
602
- btn.click(
603
- fn=predict_single,
604
- inputs=[gene_in, age_in, parse_mode],
605
- outputs=[single_out]
606
- )
607
-
608
- with gr.Tab("Batch Prediction (CSV)"):
609
- with gr.Group(elem_classes=["card"]):
610
- gr.Markdown(
611
- "**Minimum required columns:** `Gene`, `Age` \n"
612
- f"**Model feature columns (auto-filled if missing):** `{len(input_cols)}` total",
613
- elem_classes=["mono"]
614
  )
615
 
616
- parse_mode_b = gr.Radio(
617
- choices=[
618
- "Decimal (1.11 = 1.11 years)",
619
- "Years.Months (1.11 = 1y 11m)"
620
- ],
621
- value="Decimal (1.11 = 1.11 years)",
622
- label="Age parsing"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
  )
624
 
625
- csv_in = gr.File(file_types=[".csv"], label="Upload CSV")
626
- run_b = gr.Button("Run Batch Prediction", elem_id="primaryBtn")
627
-
628
- batch_summary = gr.HTML(value="")
629
- preview = gr.Dataframe(label="Preview (first 20 rows)", wrap=True)
630
- out_file = gr.File(label="Download predictions_output.csv")
631
-
632
- run_b.click(
633
- fn=predict_batch,
634
- inputs=[csv_in, parse_mode_b],
635
- outputs=[batch_summary, preview, out_file]
636
- )
637
-
638
- demo.launch(share=True)
 
1
+ import os
 
 
 
 
 
 
 
 
2
  import re
3
  import pickle
4
  import joblib
5
  import numpy as np
6
  import pandas as pd
7
  import gradio as gr
8
+ from pathlib import Path
9
 
10
  # =========================
11
+ # PATHS (Hugging Face Spaces safe)
12
  # =========================
13
+ BASE_DIR = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()
14
+
15
+ # Put your files either in repo root OR in ./assets/
16
+ ASSETS_DIR = BASE_DIR / "assets"
17
+ if not ASSETS_DIR.exists():
18
+ ASSETS_DIR = BASE_DIR
19
+
20
+ VAL_CSV_PATH = ASSETS_DIR / "validation_data.csv"
21
+ MAIN_CSV_PATH = ASSETS_DIR / "Cochlear_Implant_Dataset.csv"
22
+ CLF_PKL_PATH = ASSETS_DIR / "ci_success_classifier.pkl"
23
+ REG_PKL_PATH = ASSETS_DIR / "ci_speech_score_regressor.pkl"
24
+
25
+ # Batch output: /tmp is writable on HF Spaces
26
+ BATCH_OUT_PATH = Path("/tmp/predictions_output.csv")
27
+
28
+
29
+ def _require_file(path: Path, label: str):
30
+ if not path.exists():
31
+ raise FileNotFoundError(
32
+ f"Missing required file: {label}. "
33
+ f"Expected at: {path}. "
34
+ f"Upload it to your Space repo (recommended: /assets folder)."
35
+ )
36
 
37
  # =========================
38
+ # Load data + models (guarded for HF)
39
  # =========================
40
+ APP_READY = True
41
+ APP_ERROR_MSG = ""
42
 
43
+ try:
44
+ _require_file(VAL_CSV_PATH, "validation_data.csv")
45
+ _require_file(MAIN_CSV_PATH, "Cochlear_Implant_Dataset.csv")
46
+ _require_file(CLF_PKL_PATH, "ci_success_classifier.pkl")
47
+ _require_file(REG_PKL_PATH, "ci_speech_score_regressor.pkl")
 
48
 
49
+ val_df = pd.read_csv(VAL_CSV_PATH)
50
+ main_df = pd.read_csv(MAIN_CSV_PATH)
51
+
52
+ def load_model(path: Path):
53
+ try:
54
+ return joblib.load(path)
55
+ except Exception:
56
+ with open(path, "rb") as f:
57
+ return pickle.load(f)
58
+
59
+ clf_model = load_model(CLF_PKL_PATH)
60
+ reg_model = load_model(REG_PKL_PATH)
61
+
62
+ except Exception:
63
+ APP_READY = False
64
+ # Keep errors user-safe (no stacktraces); admins can view logs in HF
65
+ APP_ERROR_MSG = (
66
+ "This app is not configured yet. Please upload the required model and dataset files to the Space.\n\n"
67
+ "Required files:\n"
68
+ "- validation_data.csv\n"
69
+ "- Cochlear_Implant_Dataset.csv\n"
70
+ "- ci_success_classifier.pkl\n"
71
+ "- ci_speech_score_regressor.pkl\n\n"
72
+ "Recommended location: a folder named 'assets' in the Space repo."
73
+ )
74
 
75
+ # =========================
76
+ # Feature name extraction
77
+ # =========================
78
  def get_model_feature_names(m):
79
  if hasattr(m, "feature_names_in_"):
80
  return list(getattr(m, "feature_names_in_"))
 
84
  return list(step.feature_names_in_)
85
  return None
86
 
87
+ # If app isn't ready, define minimal placeholders to avoid NameErrors
88
+ if not APP_READY:
89
+ val_df = pd.DataFrame()
90
+ main_df = pd.DataFrame()
91
+ clf_model = None
92
+ reg_model = None
93
+ clf_expected, reg_expected, input_cols = [], [], []
94
+ else:
95
+ clf_expected = get_model_feature_names(clf_model) or []
96
+ reg_expected = get_model_feature_names(reg_model) or []
97
+
98
+ # Union of expected columns (preserve order)
99
+ input_cols = []
100
+ for colset in [clf_expected, reg_expected]:
101
+ for c in colset:
102
+ if c not in input_cols:
103
+ input_cols.append(c)
104
+ if not input_cols:
105
+ input_cols = list(val_df.columns)
106
 
107
  # =========================
108
  # Build Gene dropdown choices from MAIN dataset
 
123
  "": np.nan, "nan": np.nan, "NaN": np.nan})
124
  )
125
 
 
126
  gene_choices = []
127
+ if APP_READY:
128
+ gene_col_main = find_gene_column(main_df)
129
+ if gene_col_main is not None:
130
+ gene_choices = sorted(set(normalize_str_series(main_df[gene_col_main]).dropna().tolist()))
131
+ if not gene_choices:
132
+ gene_col_val = find_gene_column(val_df)
133
+ if gene_col_val is not None:
134
+ gene_choices = sorted(set(normalize_str_series(val_df[gene_col_val]).dropna().tolist()))
135
 
136
  # =========================
137
  # Helpers
138
  # =========================
139
  def parse_age_to_years(age_raw: str, mode: str):
 
 
 
 
 
 
140
  if age_raw is None:
141
  return np.nan
142
 
 
152
  except:
153
  return np.nan
154
 
 
155
  if cleaned.count(".") == 1:
156
  a, b = cleaned.split(".")
157
  if a.isdigit() and b.isdigit() and len(b) == 2:
 
159
  months = int(b)
160
  if 0 <= months <= 11:
161
  return years + months / 12.0
 
162
  try:
163
  return float(cleaned)
164
  except:
 
176
  return None
177
 
178
  def get_gene_feature_name(cols):
 
179
  for c in cols:
180
  if c.lower() == "gene":
181
  return c
 
182
  for c in cols:
183
  if "gene" in c.lower():
184
  return c
 
187
  def get_age_feature_names(cols):
188
  return [c for c in cols if "age" in c.lower()]
189
 
190
+ GENE_FEAT = get_gene_feature_name(input_cols) if APP_READY else None
191
+ AGE_FEATS = get_age_feature_names(input_cols) if APP_READY else []
192
 
193
  def align_to_expected(df: pd.DataFrame, expected_cols):
194
  if not expected_cols:
 
279
  """
280
 
281
  def predict_single(gene, age_text, parse_mode):
282
+ if not APP_READY:
283
+ raise gr.Error("App is not configured. Please upload required files to the Space.")
284
+
285
  if gene is None or str(gene).strip() == "":
286
  raise gr.Error("Please select a Gene.")
287
 
 
289
  if not (isinstance(age_used, (float, np.floating)) and np.isfinite(age_used)):
290
  raise gr.Error("Please enter a valid Age (e.g., 1.6YRS, 1.11, 2.3).")
291
 
 
292
  row = {}
293
  for c in input_cols:
294
  if GENE_FEAT and c == GENE_FEAT:
 
313
  return render_single_result_html(gene, age_text, age_used, parse_mode, label, prob, speech)
314
 
315
  def _file_to_path(file_obj):
 
316
  if file_obj is None:
317
  return None
318
  if isinstance(file_obj, str):
 
324
  return None
325
 
326
  def predict_batch(csv_file, parse_mode):
327
+ if not APP_READY:
328
+ raise gr.Error("App is not configured. Please upload required files to the Space.")
329
+
330
  path = _file_to_path(csv_file)
331
  if not path:
332
  raise gr.Error("Please upload a CSV file.")
 
337
 
338
  df_cols_lower = {c.lower(): c for c in df.columns}
339
 
 
 
340
  gene_col = None
341
  if GENE_FEAT and GENE_FEAT.lower() in df_cols_lower:
342
  gene_col = df_cols_lower[GENE_FEAT.lower()]
343
  else:
 
344
  for c in df.columns:
345
  if "gene" in c.lower():
346
  gene_col = c
 
348
  if gene_col is None:
349
  raise gr.Error("CSV must include a Gene column (e.g., 'Gene').")
350
 
 
351
  age_source_col = None
352
  for c in df.columns:
353
  if "age" in c.lower():
 
356
  if age_source_col is None:
357
  raise gr.Error("CSV must include an Age column (e.g., 'Age').")
358
 
 
359
  X = pd.DataFrame(index=df.index)
360
  parsed_age = df[age_source_col].apply(lambda v: parse_age_to_years(v, parse_mode))
361
 
 
369
  elif col in AGE_FEATS:
370
  X[col] = parsed_age
371
  else:
 
372
  src = df_cols_lower.get(col.lower())
373
  X[col] = df[src] if src is not None else np.nan
374
 
 
385
 
386
  out["speech_score_pred"] = reg_model.predict(Xr)
387
 
388
+ out.to_csv(BATCH_OUT_PATH, index=False)
 
389
 
390
  n = len(out)
391
  succ = int((out["success_label_pred"] == 1).sum())
 
422
  <div class="fine">Download the output CSV below.</div>
423
  </div>
424
  """
425
+ return summary, out.head(20), str(BATCH_OUT_PATH)
426
 
427
  def age_preview(age_text, parse_mode):
428
  v = parse_age_to_years(age_text, parse_mode)
 
431
  return "<div class='hint'>Model will use: <span class='mono'>—</span></div>"
432
 
433
  # =========================
434
+ # CSS (unchanged)
435
  # =========================
436
+ CSS = """<YOUR EXISTING CSS HERE>"""
437
+ # ↑ Keep your CSS block exactly as-is.
438
+ # (I’m not re-pasting it here to keep the patch focused.)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
  theme = gr.themes.Base(
441
  primary_hue="blue",
 
450
  # =========================
451
  with gr.Blocks(theme=theme, css=CSS, title="CI Outcome Predictor") as demo:
452
  with gr.Column(elem_id="wrap"):
453
+ if not APP_READY:
454
+ gr.Markdown(APP_ERROR_MSG)
455
+ else:
456
+ gr.HTML("""
457
+ <div class="hero">
458
+ <h1>CI Outcome Predictor</h1>
459
+ <p>Single and batch predictions. Gene options are loaded from the dataset. Age parsing is shown transparently.</p>
460
+ </div>
461
+ """)
462
+
463
+ with gr.Tabs():
464
+ with gr.Tab("Single Prediction"):
465
+ with gr.Row():
466
+ with gr.Column(scale=1):
467
+ with gr.Group(elem_classes=["card"]):
468
+ gene_in = gr.Dropdown(
469
+ choices=gene_choices,
470
+ value=gene_choices[0] if gene_choices else None,
471
+ label="Gene",
472
+ filterable=True,
473
+ )
474
+ age_in = gr.Textbox(
475
+ label="Age",
476
+ placeholder="Examples: 1.11 | 1.6YRS | 2.3"
477
+ )
478
+ parse_mode = gr.Radio(
479
+ choices=[
480
+ "Decimal (1.11 = 1.11 years)",
481
+ "Years.Months (1.11 = 1y 11m)"
482
+ ],
483
+ value="Decimal (1.11 = 1.11 years)",
484
+ label="Age format"
485
+ )
486
+
487
+ age_hint = gr.HTML(value=age_preview("", "Decimal (1.11 = 1.11 years)"))
488
+ btn = gr.Button("Run Prediction", elem_id="primaryBtn")
489
+
490
+ with gr.Column(scale=1):
491
+ single_out = gr.HTML(value="", elem_classes=["card"])
492
+
493
+ age_in.change(fn=age_preview, inputs=[age_in, parse_mode], outputs=[age_hint])
494
+ parse_mode.change(fn=age_preview, inputs=[age_in, parse_mode], outputs=[age_hint])
495
+
496
+ btn.click(
497
+ fn=predict_single,
498
+ inputs=[gene_in, age_in, parse_mode],
499
+ outputs=[single_out]
 
 
 
 
 
 
 
500
  )
501
 
502
+ with gr.Tab("Batch Prediction (CSV)"):
503
+ with gr.Group(elem_classes=["card"]):
504
+ gr.Markdown(
505
+ "**Required columns:** `Gene`, `Age`",
506
+ elem_classes=["mono"]
507
+ )
508
+
509
+ parse_mode_b = gr.Radio(
510
+ choices=[
511
+ "Decimal (1.11 = 1.11 years)",
512
+ "Years.Months (1.11 = 1y 11m)"
513
+ ],
514
+ value="Decimal (1.11 = 1.11 years)",
515
+ label="Age format"
516
+ )
517
+
518
+ csv_in = gr.File(file_types=[".csv"], label="Upload CSV")
519
+ run_b = gr.Button("Run Batch Prediction", elem_id="primaryBtn")
520
+
521
+ batch_summary = gr.HTML(value="")
522
+ preview = gr.Dataframe(label="Preview (first 20 rows)", wrap=True)
523
+ out_file = gr.File(label="Download results")
524
+
525
+ run_b.click(
526
+ fn=predict_batch,
527
+ inputs=[csv_in, parse_mode_b],
528
+ outputs=[batch_summary, preview, out_file]
529
  )
530
 
531
+ # Hugging Face Spaces provides the external URL; don't use share=True there.
532
+ demo.launch(show_error=False, quiet=True)