# -*- coding: utf-8 -*-
"""Enhanced CI Outcome Predictor with Advanced UI (HF Spaces Ready)"""
import os
import pickle
import re
import tempfile
from pathlib import Path

import gradio as gr
import joblib
import numpy as np
import pandas as pd
# =========================
# PATHS (Hugging Face Spaces / repo root)
# Place these files in the SAME folder as app.py:
# - validation_data.csv
# - Cochlear_Implant_Dataset.csv
# - ci_success_classifier.pkl
# - ci_speech_score_regressor.pkl
# =========================
# Resolve the folder that holds app.py; fall back to the working directory
# when the module has no __file__ (e.g. an interactive session).
if "__file__" in globals():
    BASE_DIR = Path(__file__).resolve().parent
else:
    BASE_DIR = Path.cwd()
VAL_CSV_PATH = BASE_DIR.joinpath("validation_data.csv")
MAIN_CSV_PATH = BASE_DIR.joinpath("Cochlear_Implant_Dataset.csv")
CLF_PKL_PATH = BASE_DIR.joinpath("ci_success_classifier.pkl")
REG_PKL_PATH = BASE_DIR.joinpath("ci_speech_score_regressor.pkl")
# Writable location on HF Spaces
BATCH_OUT_PATH = Path("/tmp") / "predictions_output.csv"
def _require(p: Path) -> bool:
    """Return True when `p` names an existing regular file (directories don't count)."""
    # Path.is_file() is already False for paths that do not exist, so a
    # separate exists() check adds nothing.
    return p.is_file()
def _missing_files_message():
    """Return the setup notice listing which required repo-root files are absent."""
    required = (VAL_CSV_PATH, MAIN_CSV_PATH, CLF_PKL_PATH, REG_PKL_PATH)
    missing = [p.name for p in required if not _require(p)]
    missing_text = " ".join(missing) if missing else "β"
    return f"""
β οΈ Setup required: Missing required files in repo root:
{missing_text}
Upload the files to the same folder as app.py and restart the Space.
"""
# =========================
# Load data + models (guarded so Space can boot even if files missing)
# =========================
APP_READY = all(
    _require(required)
    for required in (VAL_CSV_PATH, MAIN_CSV_PATH, CLF_PKL_PATH, REG_PKL_PATH)
)
# Empty placeholders so the rest of the module can load when files are missing.
val_df = pd.DataFrame()
main_df = pd.DataFrame()
clf_model = reg_model = None
if APP_READY:
    val_df, main_df = pd.read_csv(VAL_CSV_PATH), pd.read_csv(MAIN_CSV_PATH)

    def load_model(path: Path):
        """Deserialize a saved model: try joblib first, then plain pickle.

        Only use on trusted files -- unpickling can execute arbitrary code.
        """
        try:
            return joblib.load(path)
        except Exception:
            # joblib could not read it; retry with the stdlib pickle reader.
            with open(path, "rb") as handle:
                return pickle.load(handle)

    clf_model, reg_model = load_model(CLF_PKL_PATH), load_model(REG_PKL_PATH)
def get_model_feature_names(m):
    """Return the input column names recorded on a fitted model, or None.

    The estimator itself is checked first; for objects exposing
    `named_steps` (e.g. pipelines) the first step carrying
    `feature_names_in_` is used. None is returned for a None model or when
    no names are recorded anywhere.
    """
    if m is None:
        return None
    if hasattr(m, "feature_names_in_"):
        return list(m.feature_names_in_)
    steps = m.named_steps.values() if hasattr(m, "named_steps") else ()
    named = next((step for step in steps if hasattr(step, "feature_names_in_")), None)
    return None if named is None else list(named.feature_names_in_)
clf_expected = get_model_feature_names(clf_model) or []
reg_expected = get_model_feature_names(reg_model) or []
# Union of expected columns, keeping first-seen order (classifier's first).
input_cols = list(dict.fromkeys([*clf_expected, *reg_expected]))
# Neither model exposes feature names: fall back to the validation CSV columns.
if not input_cols and APP_READY:
    input_cols = list(val_df.columns)
# =========================
# Build Gene dropdown choices from MAIN dataset
# =========================
def find_gene_column(df: pd.DataFrame):
    """Return the name of the gene column in `df`, or None.

    An exact "Gene" header wins; otherwise the first column whose name
    contains "gene" (case-insensitive). None/empty frames yield None.
    """
    if df is None or df.empty:
        return None
    if "Gene" in df.columns:
        return "Gene"
    return next((col for col in df.columns if "gene" in col.lower()), None)
def normalize_str_series(s: pd.Series) -> pd.Series:
    """Cast to str, strip whitespace, and map textual null markers to NaN.

    Markers are matched exactly (case-sensitive): "null", "NULL", "None",
    "none", "", "nan", "NaN".
    """
    null_markers = ("null", "NULL", "None", "none", "", "nan", "NaN")
    stripped = s.astype(str).str.strip()
    return stripped.replace(dict.fromkeys(null_markers, np.nan))
gene_choices = []
if APP_READY:
    # Prefer gene names from the main dataset; consult the validation data
    # only when the main dataset yields none.
    gene_col_main = find_gene_column(main_df)
    if gene_col_main is not None:
        main_genes = normalize_str_series(main_df[gene_col_main]).dropna()
        gene_choices = sorted(set(main_genes))
    if not gene_choices:
        gene_col_val = find_gene_column(val_df)
        if gene_col_val is not None:
            val_genes = normalize_str_series(val_df[gene_col_val]).dropna()
            gene_choices = sorted(set(val_genes))
# =========================
# Helpers
# =========================
def parse_age_to_years(age_raw: str, mode: str):
    """
    Convert a free-text age into a float number of years.

    mode:
      - "Years.Months (1.11 = 1y 11m)" -> 1 + 11/12
      - "Decimal (1.11 = 1.11 years)" -> 1.11
    Any mode that does not start with "Decimal" is treated as Years.Months.

    Every character other than digits and "." is dropped first, so "1.6YRS",
    "2yrs", etc. are accepted. In Years.Months mode the part after the single
    dot is read as a month count when it is one or two digits in 0..11
    ("1.6" -> 1y 6m, "1.06" -> 1y 6m, "1.11" -> 1y 11m). Anything else
    (e.g. "1.12", "2", ".6") falls back to a plain decimal reading.

    Returns:
        Age in years as a float, or np.nan when the input is missing/unparseable.
    """
    if age_raw is None:
        return np.nan
    s = str(age_raw).strip()
    if s == "" or s.lower() in {"nan", "none", "null"}:
        return np.nan
    cleaned = re.sub(r"[^0-9.]", "", s)

    if not mode.startswith("Decimal") and cleaned.count(".") == 1:
        years_part, months_part = cleaned.split(".")
        # Both parts must be pure digits (".6" has an empty years part).
        # One- and two-digit month parts are both valid: "1.6" means 1y 6m.
        if years_part.isdigit() and months_part.isdigit() and len(months_part) in (1, 2):
            months = int(months_part)
            if months <= 11:
                return int(years_part) + months / 12.0

    # Decimal mode, or a Years.Months value that is not a valid y.m pair.
    try:
        return float(cleaned)
    except ValueError:
        return np.nan
def safe_pct(x):
    """Convert a 0..1 fraction into a rounded integer percentage.

    Returns None when `x` cannot be converted: None, non-numeric text,
    NaN (round raises ValueError) or infinity (round raises OverflowError).
    """
    try:
        return int(round(float(x) * 100))
    except (TypeError, ValueError, OverflowError):
        return None
def get_gene_feature_name(cols):
    """Pick the model feature that should receive the gene value.

    A name equal to "gene" (ignoring case) is preferred; otherwise the first
    name containing "gene". Returns None when nothing matches.
    """
    exact = next((name for name in cols if name.lower() == "gene"), None)
    if exact is not None:
        return exact
    return next((name for name in cols if "gene" in name.lower()), None)
def get_age_feature_names(cols):
    """Return the column names that should receive the patient's age.

    A name qualifies when "age"/"ages" appears as a separate word. Words are
    split on non-letters and on camelCase/UPPER boundaries, so "Age",
    "age_at_ci", "ImplantAge" and "AGE_AT_CI" match. Words that merely
    contain the letters, such as "Language", "Stage" or "Percentage", do not.
    If no name matches that way, fall back to a plain substring test so
    run-together names like "patientage" are still found.
    """
    cols = list(cols)  # iterated up to twice below

    def _has_age_token(name):
        tokens = re.findall(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+", name)
        return any(token.lower() in ("age", "ages") for token in tokens)

    by_token = [c for c in cols if _has_age_token(c)]
    if by_token:
        return by_token
    return [c for c in cols if "age" in c.lower()]
# Model features that receive the gene / parsed age; only resolved when the
# models (and therefore input_cols) are available.
if APP_READY:
    GENE_FEAT = get_gene_feature_name(input_cols)
    AGE_FEATS = get_age_feature_names(input_cols)
else:
    GENE_FEAT = None
    AGE_FEATS = []
def align_to_expected(df: pd.DataFrame, expected_cols):
    """Return `df` restricted to and ordered by `expected_cols`.

    Expected columns absent from `df` are added as all-NaN columns on a copy;
    the input frame is never modified. When `expected_cols` is empty, `df`
    itself is returned unchanged.
    """
    if not expected_cols:
        return df
    absent = [col for col in expected_cols if col not in df.columns]
    aligned = df.copy()
    for col in absent:
        aligned[col] = np.nan
    return aligned[expected_cols]
def render_single_result_html(gene, age_entered, age_used_years, parse_mode, label, prob, speech):
    """Build the result card text for one prediction.

    Args:
        gene: gene value sent to the models; None is shown as a placeholder.
        age_entered: the raw age text exactly as the user typed it.
        age_used_years: parsed age in years; passed to np.isfinite, so it
            must be numeric (None would raise TypeError).
        parse_mode: the selected parsing-mode label; only the text before
            the first "(" is displayed.
        label: predicted class; 1 and 0 pick the "Likely Success" /
            "Lower Likelihood" styling, any other value shows "Unavailable".
        prob: class-1 probability in 0..1, or None when unavailable.
        speech: regressor output; shown with 3 decimals, or a placeholder
            when it cannot be converted to float.

    Returns:
        The formatted result string (an f-string template).
    """
    # Status text and decorations depend only on the predicted class.
    # NOTE(review): `badge`, `icon` and `bar_width` do not appear in the template
    # text visible here -- presumably consumed by markup in the full template; confirm.
    if label == 1:
        status = "Likely Success"
        badge = "ok"
        icon = "β"
        emoji = "π"
    elif label == 0:
        status = "Lower Likelihood"
        badge = "warn"
        icon = "!"
        emoji = "β οΈ"
    else:
        status = "Unavailable"
        badge = "neutral"
        icon = "?"
        emoji = "β"
    prob_pct = safe_pct(prob) if prob is not None else None
    prob_text = f"{prob_pct}%" if prob_pct is not None else "β"
    bar_width = f"{prob_pct}%" if prob_pct is not None else "0%"
    try:
        speech_disp = f"{float(speech):.3f}"
    except:
        # Any non-numeric regressor output falls back to the placeholder.
        speech_disp = "β"
    age_used_disp = f"{float(age_used_years):.3f} years" if np.isfinite(age_used_years) else "β"
    gene_disp = str(gene) if gene is not None else "β"
    return f"""
{emoji}
{status}
Prediction Complete
π§¬
Gene
{gene_disp}
π
Age Entered
{age_entered}
π¬
Model Age
{age_used_disp}
π
Parse Mode
{parse_mode.split('(')[0].strip()}
Success Probability{prob_text}
0%25%50%75%100%
Predicted Label
{label}
Speech Score
{speech_disp}
Informational tool only. Not medical advice. Consult healthcare professionals for clinical decisions.
"""
def predict_single(gene, age_text, parse_mode):
    """Run both models for one gene/age pair and render the result card.

    Returns the setup notice instead when required files are missing.
    Raises gr.Error when no gene is selected or the age cannot be parsed.
    Model features other than the gene/age ones are sent as NaN.
    """
    if not APP_READY:
        return _missing_files_message()
    if gene is None or not str(gene).strip():
        raise gr.Error("Please select a Gene.")

    age_used = parse_age_to_years(age_text, parse_mode)
    age_is_valid = isinstance(age_used, (float, np.floating)) and np.isfinite(age_used)
    if not age_is_valid:
        raise gr.Error("Please enter a valid Age (e.g., 1.6YRS, 1.11, 2.3).")

    def feature_value(col):
        # Only the gene and the age are known for a single patient.
        if GENE_FEAT and col == GENE_FEAT:
            return gene
        if col in AGE_FEATS:
            return age_used
        return np.nan

    features = pd.DataFrame([{col: feature_value(col) for col in input_cols}])
    clf_input = align_to_expected(features, clf_expected)
    reg_input = align_to_expected(features, reg_expected)

    label = int(clf_model.predict(clf_input)[0])
    prob = None
    if hasattr(clf_model, "predict_proba"):
        class_probs = clf_model.predict_proba(clf_input)[0]
        # Assumes probability column 1 is the positive (label 1) class.
        if len(class_probs) >= 2:
            prob = float(class_probs[1])
    speech = reg_model.predict(reg_input)[0]
    return render_single_result_html(gene, age_text, age_used, parse_mode, label, prob, speech)
def _file_to_path(file_obj):
    """
    Resolve a Gradio file-upload value to a filesystem path.

    Gradio v3/v4 compatibility:
    - Sometimes a string path
    - Sometimes object with .name
    - In Gradio 4: often has .path
    - Sometimes dict-like with 'name' or 'path'
    os.PathLike objects (e.g. pathlib.Path) are also accepted.

    Returns None when no usable path can be found.
    """
    if file_obj is None:
        return None
    # Handle PathLike before probing attributes: pathlib.Path has a `.name`
    # attribute, but it is only the final component, not the full path.
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    for attr in ("path", "name"):
        value = getattr(file_obj, attr, None)
        if value:
            return value
    if isinstance(file_obj, dict):
        for key in ("path", "name"):
            if file_obj.get(key):
                return file_obj[key]
    return None
def _find_upload_gene_column(df: pd.DataFrame, df_cols_lower: dict):
    """Return the uploaded column that holds gene values, or None if absent.

    A case-insensitive match on the model's gene feature name (GENE_FEAT)
    wins; otherwise the first column whose name contains "gene" is used.
    """
    if GENE_FEAT and GENE_FEAT.lower() in df_cols_lower:
        return df_cols_lower[GENE_FEAT.lower()]
    return next((c for c in df.columns if "gene" in c.lower()), None)


def _find_upload_age_column(df: pd.DataFrame):
    """Return the uploaded column that holds raw ages, or None if absent.

    A column named exactly "age" (any case) is preferred; otherwise the first
    column accepted by get_age_feature_names is used. Preferring the exact
    name stops an earlier column that merely contains the letters "age"
    (e.g. "Language") from being parsed as the patient's age.
    """
    candidates = get_age_feature_names(list(df.columns))
    exact = [c for c in candidates if c.lower() == "age"]
    if exact:
        return exact[0]
    return candidates[0] if candidates else None


def predict_batch(csv_file, parse_mode):
    """Score every row of an uploaded CSV with both models.

    The CSV must contain a gene column and an age column. Ages are parsed
    with parse_age_to_years using `parse_mode`. Every other model feature is
    copied from a same-named (case-insensitive) CSV column when present,
    otherwise left as NaN.

    Returns:
        (summary text, first 20 scored rows as a DataFrame, path of a CSV
        file holding all scored rows)

    Raises:
        gr.Error: when model/data files are missing, no file was uploaded,
            the CSV is unreadable or empty, the gene/age column is missing,
            or any age value cannot be parsed.
    """
    if not APP_READY:
        raise gr.Error("Missing required model/data files in repo root. Please upload them and restart the Space.")
    path = _file_to_path(csv_file)
    if not path:
        raise gr.Error("Please upload a CSV file.")
    try:
        df = pd.read_csv(path)
    except pd.errors.EmptyDataError:
        # A zero-byte file raises here instead of producing an empty frame.
        raise gr.Error("Uploaded CSV is empty.") from None
    except (pd.errors.ParserError, UnicodeDecodeError) as err:
        raise gr.Error(f"Could not read the uploaded CSV: {err}") from err
    if df.empty:
        raise gr.Error("Uploaded CSV is empty.")

    df_cols_lower = {c.lower(): c for c in df.columns}
    gene_col = _find_upload_gene_column(df, df_cols_lower)
    if gene_col is None:
        raise gr.Error("CSV must include a Gene column (e.g., 'Gene').")
    age_source_col = _find_upload_age_column(df)
    if age_source_col is None:
        raise gr.Error("CSV must include an Age column (e.g., 'Age').")

    parsed_age = df[age_source_col].apply(lambda v: parse_age_to_years(v, parse_mode))
    invalid = parsed_age.isna().to_numpy()
    if invalid.any():
        bad_n = int(invalid.sum())
        first_rows = ", ".join(str(i + 1) for i in np.flatnonzero(invalid)[:5])
        raise gr.Error(
            f"{bad_n} rows have invalid Age values for the selected parsing mode. "
            f"First affected data rows (1 = first row after the header): {first_rows}."
        )

    X = pd.DataFrame(index=df.index)
    for col in input_cols:
        if GENE_FEAT and col == GENE_FEAT:
            X[col] = df[gene_col]
        elif col in AGE_FEATS:
            X[col] = parsed_age
        else:
            src = df_cols_lower.get(col.lower())
            X[col] = df[src] if src is not None else np.nan
    Xc = align_to_expected(X, clf_expected)
    Xr = align_to_expected(X, reg_expected)

    out = df.copy()
    out["success_label_pred"] = clf_model.predict(Xc)
    if hasattr(clf_model, "predict_proba"):
        proba = clf_model.predict_proba(Xc)
        if proba.shape[1] == 2:
            out["success_prob_class1"] = proba[:, 1]
    out["speech_score_pred"] = reg_model.predict(Xr)

    # A fresh file per request (next to BATCH_OUT_PATH): one shared path could
    # be overwritten by a concurrent batch job before this result is served.
    fd, out_path = tempfile.mkstemp(
        prefix=f"{BATCH_OUT_PATH.stem}_",
        suffix=BATCH_OUT_PATH.suffix,
        dir=str(BATCH_OUT_PATH.parent),
    )
    os.close(fd)
    out.to_csv(out_path, index=False)

    n = len(out)
    succ = int((out["success_label_pred"] == 1).sum())
    succ_pct = int(round((succ / n) * 100)) if n else 0
    avg_prob_txt = "β"
    if "success_prob_class1" in out.columns:
        try:
            avg_prob_txt = f"{int(round(float(out['success_prob_class1'].mean())*100))}%"
        except (TypeError, ValueError, OverflowError):
            # e.g. an all-NaN probability column: keep the placeholder.
            pass
    avg_speech_txt = "β"
    try:
        avg_speech_txt = f"{float(pd.to_numeric(out['speech_score_pred'], errors='coerce').mean()):.3f}"
    except (TypeError, ValueError):
        pass
    summary = f"""
πBatch Analysis Complete
{n} rows processed
β
Predicted Success
{succ}
{succ_pct}% of total
π―
Avg Probability
{avg_prob_txt}
π€
Avg Speech Score
{avg_speech_txt}
βοΈ
Parse Mode
{parse_mode.split('(')[0].strip()}
Download the complete results CSV below
"""
    return summary, out.head(20), out_path
def age_preview(age_text, parse_mode):
v = parse_age_to_years(age_text, parse_mode)
if isinstance(v, (float, np.floating)) and np.isfinite(v):
return f"""
""")
theme_state = gr.State("light")
theme_btn = gr.Button("π Dark Mode", elem_id="theme-toggle-btn")
theme_btn.click(fn=None, js=JS, outputs=theme_state)
with gr.Tabs():
with gr.Tab("π― Single Prediction"):
with gr.Row():
with gr.Column(scale=1):
gr.HTML('
π Input Parameters
')
gene_in = gr.Dropdown(
choices=gene_choices,
value=gene_choices[0] if gene_choices else None,
label="𧬠Select Gene",
info="Choose the gene variant for analysis",
filterable=True,
)
age_in = gr.Textbox(
label="π Patient Age",
placeholder="Examples: 1.11 | 1.6YRS | 2.3",
info="Enter age in supported format"
)
parse_mode = gr.Radio(
choices=[
"Decimal (1.11 = 1.11 years)",
"Years.Months (1.11 = 1y 11m)"
],
value="Decimal (1.11 = 1.11 years)",
label="βοΈ Age Parsing Mode",
info="Select how age should be interpreted"
)
age_hint = gr.HTML(value=age_preview("", "Decimal (1.11 = 1.11 years)"))
btn = gr.Button("π Run Prediction", elem_id="predict-btn", size="lg")
with gr.Column(scale=1):
gr.HTML('
π Prediction Results
')
single_out = gr.HTML(
value=_missing_files_message() if not APP_READY else
"
π Enter parameters and click 'Run Prediction' to see results
π Requirements: Your CSV must include Gene and Age columns.
Additional features will be auto-filled. Model expects {len(input_cols) if APP_READY else 0} total feature columns.