Update app.py
Browse files
app.py
CHANGED
|
@@ -32,7 +32,7 @@ from sklearn.linear_model import LogisticRegression
|
|
| 32 |
from sklearn.model_selection import train_test_split
|
| 33 |
|
| 34 |
#Figures setting block
|
| 35 |
-
|
| 36 |
|
| 37 |
|
| 38 |
|
|
@@ -47,6 +47,7 @@ def make_fig(figsize=(5.5, 3.6), dpi=120):
|
|
| 47 |
|
| 48 |
|
| 49 |
def fig_to_png_bytes(fig, dpi=600):
|
|
|
|
| 50 |
buf = io.BytesIO()
|
| 51 |
fig.savefig(buf, format="png", dpi=int(dpi), bbox_inches="tight")
|
| 52 |
buf.seek(0)
|
|
@@ -62,6 +63,7 @@ def render_plot_with_download(
|
|
| 62 |
export_dpi: int = 600,
|
| 63 |
key: Optional[str] = None
|
| 64 |
):
|
|
|
|
| 65 |
png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
|
| 66 |
st.pyplot(fig, clear_figure=False)
|
| 67 |
st.download_button(
|
|
@@ -76,6 +78,7 @@ def render_plot_with_download(
|
|
| 76 |
|
| 77 |
|
| 78 |
|
|
|
|
| 79 |
# ============================================================
|
| 80 |
# Fixed schema definition (PLACEHOLDER FRAMEWORK)
|
| 81 |
# ============================================================
|
|
@@ -191,6 +194,82 @@ def find_col(df: pd.DataFrame, candidates: list[str]) -> str | None:
|
|
| 191 |
return lookup[k]
|
| 192 |
return None
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
# ============================================================
|
| 196 |
# Model pipeline
|
|
@@ -684,82 +763,33 @@ def train_and_save(
|
|
| 684 |
|
| 685 |
survival_trained = False
|
| 686 |
surv_notes = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
|
| 688 |
if time_days is not None and event01 is not None:
|
| 689 |
try:
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
df_surv.loc[df_surv[c].isna(), c] = np.nan
|
| 700 |
-
df_surv[c] = df_surv[c].map(lambda v: v if pd.isna(v) else str(v))
|
| 701 |
-
|
| 702 |
-
df_surv["time_days"] = time_days
|
| 703 |
-
df_surv["event"] = event01
|
| 704 |
-
|
| 705 |
-
df_surv = df_surv.dropna(subset=["time_days", "event"])
|
| 706 |
-
|
| 707 |
-
# one-hot
|
| 708 |
-
# one-hot (THIS LINE IS MISSING IN YOUR CODE)
|
| 709 |
-
df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True)
|
| 710 |
-
|
| 711 |
-
# Remove duplicate columns (good safety)
|
| 712 |
-
df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy()
|
| 713 |
-
|
| 714 |
-
duration_col = "time_days"
|
| 715 |
-
event_col = "event"
|
| 716 |
-
X_cols = [c for c in df_surv_oh.columns if c not in (duration_col, event_col)]
|
| 717 |
-
|
| 718 |
-
# Force numeric predictors
|
| 719 |
-
df_surv_oh[X_cols] = df_surv_oh[X_cols].apply(pd.to_numeric, errors="coerce")
|
| 720 |
-
|
| 721 |
-
# Impute (training-time)
|
| 722 |
-
imp = SimpleImputer(strategy="median")
|
| 723 |
-
|
| 724 |
-
# DEBUG
|
| 725 |
-
st.write("X_cols count:", len(X_cols))
|
| 726 |
-
st.write("df_surv_oh[X_cols] shape:", df_surv_oh[X_cols].shape)
|
| 727 |
-
|
| 728 |
-
X_imp = imp.fit_transform(df_surv_oh[X_cols])
|
| 729 |
-
st.write("imputer output shape:", X_imp.shape)
|
| 730 |
-
|
| 731 |
-
# Assign back (preserve index/columns)
|
| 732 |
-
df_surv_oh.loc[:, X_cols] = pd.DataFrame(X_imp, columns=X_cols, index=df_surv_oh.index)
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
# fit Cox
|
| 737 |
-
cph = CoxPHFitter(penalizer=0.1)
|
| 738 |
-
cph.fit(df_surv_oh[[*X_cols, duration_col, event_col]], duration_col=duration_col, event_col=event_col)
|
| 739 |
-
|
| 740 |
-
bundle = {
|
| 741 |
-
"model": cph,
|
| 742 |
-
"columns": X_cols, # predictors only
|
| 743 |
-
"imputer": imp,
|
| 744 |
-
"cat_cols": cat_cols,
|
| 745 |
-
"num_cols": num_cols,
|
| 746 |
-
"feature_cols": feature_cols,
|
| 747 |
-
"duration_col": duration_col,
|
| 748 |
-
"event_col": event_col,
|
| 749 |
-
"version": 1
|
| 750 |
-
}
|
| 751 |
joblib.dump(bundle, "survival_bundle.joblib", compress=3)
|
| 752 |
-
|
| 753 |
-
|
| 754 |
survival_trained = True
|
| 755 |
-
surv_notes = "Survival model trained successfully."
|
| 756 |
-
|
| 757 |
except Exception as e:
|
| 758 |
survival_trained = False
|
| 759 |
surv_notes = f"Survival model training failed: {e}"
|
| 760 |
else:
|
| 761 |
surv_notes = "Survival columns missing or could not be parsed; survival model not trained."
|
| 762 |
|
|
|
|
| 763 |
|
| 764 |
|
| 765 |
|
|
@@ -833,18 +863,38 @@ def train_and_save(
|
|
| 833 |
# SHAP
|
| 834 |
# ============================================================
|
| 835 |
def build_shap_explainer(pipe, X_bg, max_bg=200):
|
| 836 |
-
import shap
|
| 837 |
-
|
| 838 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 839 |
|
| 840 |
clf = pipe.named_steps["clf"]
|
| 841 |
Xt_bg = transform_before_clf(pipe, X_bg)
|
| 842 |
|
| 843 |
explainer = shap.LinearExplainer(
|
| 844 |
-
clf,
|
|
|
|
|
|
|
| 845 |
)
|
| 846 |
return explainer
|
| 847 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 848 |
def ensure_model_repo_exists(model_repo_id: str, token: str):
|
| 849 |
"""
|
| 850 |
Optional helper: create the model repo if it doesn't exist.
|
|
@@ -1324,32 +1374,27 @@ def normalize_country_name(x: str) -> str | None:
|
|
| 1324 |
|
| 1325 |
from typing import Optional
|
| 1326 |
|
| 1327 |
-
def country_to_region(country:
|
| 1328 |
"""
|
| 1329 |
-
|
| 1330 |
-
Returns one of: Africa, Americas, Asia, Europe, Oceania, Unknown
|
| 1331 |
-
Lazy-imports country_converter to reduce startup memory.
|
| 1332 |
"""
|
| 1333 |
-
if not country
|
| 1334 |
return REGION_UNKNOWN
|
| 1335 |
|
| 1336 |
-
|
| 1337 |
-
|
| 1338 |
-
import country_converter as coco # lazy import
|
| 1339 |
|
| 1340 |
r = coco.convert(names=country, to="continent")
|
| 1341 |
-
|
| 1342 |
if not r or str(r).lower() in ("not found", "nan", "none"):
|
| 1343 |
return REGION_UNKNOWN
|
| 1344 |
|
| 1345 |
if r == "America":
|
| 1346 |
return "Americas"
|
| 1347 |
-
|
| 1348 |
return str(r)
|
| 1349 |
|
| 1350 |
|
|
|
|
| 1351 |
def add_ethnicity_region(df: pd.DataFrame, eth_col: str = "Ethnicity", out_col: str = "Ethnicity_Region") -> pd.DataFrame:
|
| 1352 |
-
"""Adds an analytics-only region column derived from the Ethnicity/nationality column."""
|
| 1353 |
if eth_col not in df.columns:
|
| 1354 |
df[out_col] = REGION_UNKNOWN
|
| 1355 |
return df
|
|
@@ -1359,6 +1404,7 @@ def add_ethnicity_region(df: pd.DataFrame, eth_col: str = "Ethnicity", out_col:
|
|
| 1359 |
return df
|
| 1360 |
|
| 1361 |
|
|
|
|
| 1362 |
# ============================================================
|
| 1363 |
# Streamlit UI
|
| 1364 |
# ============================================================
|
|
@@ -3120,7 +3166,8 @@ with tab_predict:
|
|
| 3120 |
|
| 3121 |
# Dense conversion once (used for summary + waterfalls)
|
| 3122 |
try:
|
| 3123 |
-
X_dense =
|
|
|
|
| 3124 |
except Exception:
|
| 3125 |
X_dense = np.array(X_batch_t)
|
| 3126 |
|
|
|
|
| 32 |
from sklearn.model_selection import train_test_split
|
| 33 |
|
| 34 |
#Figures setting block
|
| 35 |
+
|
| 36 |
|
| 37 |
|
| 38 |
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
def fig_to_png_bytes(fig, dpi=600):
|
| 50 |
+
import io
|
| 51 |
buf = io.BytesIO()
|
| 52 |
fig.savefig(buf, format="png", dpi=int(dpi), bbox_inches="tight")
|
| 53 |
buf.seek(0)
|
|
|
|
| 63 |
export_dpi: int = 600,
|
| 64 |
key: Optional[str] = None
|
| 65 |
):
|
| 66 |
+
import matplotlib.pyplot as plt # lazy
|
| 67 |
png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
|
| 68 |
st.pyplot(fig, clear_figure=False)
|
| 69 |
st.download_button(
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
|
| 81 |
+
|
| 82 |
# ============================================================
|
| 83 |
# Fixed schema definition (PLACEHOLDER FRAMEWORK)
|
| 84 |
# ============================================================
|
|
|
|
| 194 |
return lookup[k]
|
| 195 |
return None
|
| 196 |
|
| 197 |
+
def train_survival_bundle(
|
| 198 |
+
df: pd.DataFrame,
|
| 199 |
+
feature_cols: list[str],
|
| 200 |
+
num_cols: list[str],
|
| 201 |
+
cat_cols: list[str],
|
| 202 |
+
time_days: np.ndarray,
|
| 203 |
+
event01: np.ndarray,
|
| 204 |
+
*,
|
| 205 |
+
penalizer: float = 0.1
|
| 206 |
+
):
|
| 207 |
+
"""
|
| 208 |
+
Returns (bundle_dict, notes). Raises exceptions if hard-fail.
|
| 209 |
+
Lazy-imports lifelines.
|
| 210 |
+
"""
|
| 211 |
+
from lifelines import CoxPHFitter # lazy
|
| 212 |
+
from sklearn.impute import SimpleImputer # light, ok here too
|
| 213 |
+
|
| 214 |
+
# build survival DF
|
| 215 |
+
df_surv = df[feature_cols].copy().replace({pd.NA: np.nan})
|
| 216 |
+
|
| 217 |
+
# coerce numeric/cat
|
| 218 |
+
for c in num_cols:
|
| 219 |
+
if c in df_surv.columns:
|
| 220 |
+
df_surv[c] = pd.to_numeric(df_surv[c], errors="coerce")
|
| 221 |
+
|
| 222 |
+
for c in cat_cols:
|
| 223 |
+
if c in df_surv.columns:
|
| 224 |
+
df_surv[c] = df_surv[c].astype("object")
|
| 225 |
+
df_surv.loc[df_surv[c].isna(), c] = np.nan
|
| 226 |
+
df_surv[c] = df_surv[c].map(lambda v: v if pd.isna(v) else str(v))
|
| 227 |
+
|
| 228 |
+
df_surv["time_days"] = time_days
|
| 229 |
+
df_surv["event"] = event01
|
| 230 |
+
df_surv = df_surv.dropna(subset=["time_days", "event"])
|
| 231 |
+
|
| 232 |
+
duration_col = "time_days"
|
| 233 |
+
event_col = "event"
|
| 234 |
+
|
| 235 |
+
# one-hot
|
| 236 |
+
df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True)
|
| 237 |
+
|
| 238 |
+
# remove duplicate columns if any messy headers caused duplicates
|
| 239 |
+
df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy()
|
| 240 |
+
|
| 241 |
+
# predictor columns
|
| 242 |
+
X_cols = [c for c in df_surv_oh.columns if c not in (duration_col, event_col)]
|
| 243 |
+
|
| 244 |
+
# force numeric for Cox predictors
|
| 245 |
+
df_surv_oh[X_cols] = df_surv_oh[X_cols].apply(pd.to_numeric, errors="coerce")
|
| 246 |
+
|
| 247 |
+
# impute predictors
|
| 248 |
+
imp = SimpleImputer(strategy="median")
|
| 249 |
+
X_imp = imp.fit_transform(df_surv_oh[X_cols])
|
| 250 |
+
|
| 251 |
+
# assign back safely with same columns + index
|
| 252 |
+
df_surv_oh.loc[:, X_cols] = pd.DataFrame(X_imp, columns=X_cols, index=df_surv_oh.index)
|
| 253 |
+
|
| 254 |
+
# fit Cox
|
| 255 |
+
cph = CoxPHFitter(penalizer=float(penalizer))
|
| 256 |
+
cph.fit(df_surv_oh[[*X_cols, duration_col, event_col]],
|
| 257 |
+
duration_col=duration_col,
|
| 258 |
+
event_col=event_col)
|
| 259 |
+
|
| 260 |
+
bundle = {
|
| 261 |
+
"model": cph,
|
| 262 |
+
"columns": X_cols, # predictors only
|
| 263 |
+
"imputer": imp, # fitted imputer
|
| 264 |
+
"cat_cols": cat_cols,
|
| 265 |
+
"num_cols": num_cols,
|
| 266 |
+
"feature_cols": feature_cols,
|
| 267 |
+
"duration_col": duration_col,
|
| 268 |
+
"event_col": event_col,
|
| 269 |
+
"version": 1
|
| 270 |
+
}
|
| 271 |
+
return bundle, "Survival model trained successfully."
|
| 272 |
+
|
| 273 |
|
| 274 |
# ============================================================
|
| 275 |
# Model pipeline
|
|
|
|
| 763 |
|
| 764 |
survival_trained = False
|
| 765 |
surv_notes = None
|
| 766 |
+
surv_used_cols = None
|
| 767 |
+
|
| 768 |
+
try:
|
| 769 |
+
time_days, event01, surv_used_cols = build_survival_targets(df)
|
| 770 |
+
except Exception:
|
| 771 |
+
time_days, event01, surv_used_cols = None, None, None
|
| 772 |
|
| 773 |
if time_days is not None and event01 is not None:
|
| 774 |
try:
|
| 775 |
+
bundle, surv_notes = train_survival_bundle(
|
| 776 |
+
df=df,
|
| 777 |
+
feature_cols=feature_cols,
|
| 778 |
+
num_cols=num_cols,
|
| 779 |
+
cat_cols=cat_cols,
|
| 780 |
+
time_days=time_days,
|
| 781 |
+
event01=event01,
|
| 782 |
+
penalizer=0.1
|
| 783 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 784 |
joblib.dump(bundle, "survival_bundle.joblib", compress=3)
|
|
|
|
|
|
|
| 785 |
survival_trained = True
|
|
|
|
|
|
|
| 786 |
except Exception as e:
|
| 787 |
survival_trained = False
|
| 788 |
surv_notes = f"Survival model training failed: {e}"
|
| 789 |
else:
|
| 790 |
surv_notes = "Survival columns missing or could not be parsed; survival model not trained."
|
| 791 |
|
| 792 |
+
|
| 793 |
|
| 794 |
|
| 795 |
|
|
|
|
| 863 |
# SHAP
|
| 864 |
# ============================================================
|
| 865 |
def build_shap_explainer(pipe, X_bg, max_bg=200):
|
| 866 |
+
import shap # lazy
|
| 867 |
+
|
| 868 |
+
if X_bg is None or len(X_bg) == 0:
|
| 869 |
+
raise ValueError("SHAP background is empty.")
|
| 870 |
+
|
| 871 |
+
if len(X_bg) > int(max_bg):
|
| 872 |
+
X_bg = X_bg.sample(int(max_bg), random_state=42)
|
| 873 |
|
| 874 |
clf = pipe.named_steps["clf"]
|
| 875 |
Xt_bg = transform_before_clf(pipe, X_bg)
|
| 876 |
|
| 877 |
explainer = shap.LinearExplainer(
|
| 878 |
+
clf,
|
| 879 |
+
Xt_bg,
|
| 880 |
+
feature_perturbation="interventional"
|
| 881 |
)
|
| 882 |
return explainer
|
| 883 |
|
| 884 |
+
|
| 885 |
+
def safe_dense(Xt, max_rows: int = 2000):
|
| 886 |
+
"""
|
| 887 |
+
Convert sparse->dense carefully. Avoid converting huge matrices to dense.
|
| 888 |
+
"""
|
| 889 |
+
if hasattr(Xt, "shape") and Xt.shape[0] > max_rows:
|
| 890 |
+
Xt = Xt[:max_rows]
|
| 891 |
+
try:
|
| 892 |
+
return Xt.toarray()
|
| 893 |
+
except Exception:
|
| 894 |
+
return np.array(Xt)
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
|
| 898 |
def ensure_model_repo_exists(model_repo_id: str, token: str):
|
| 899 |
"""
|
| 900 |
Optional helper: create the model repo if it doesn't exist.
|
|
|
|
| 1374 |
|
| 1375 |
from typing import Optional
|
| 1376 |
|
| 1377 |
+
def country_to_region(country: str | None) -> str:
|
| 1378 |
"""
|
| 1379 |
+
Lazy-import country_converter to reduce startup memory.
|
| 1380 |
+
Returns one of: Africa, Americas, Asia, Europe, Oceania, Unknown
|
|
|
|
| 1381 |
"""
|
| 1382 |
+
if not country:
|
| 1383 |
return REGION_UNKNOWN
|
| 1384 |
|
| 1385 |
+
import country_converter as coco # lazy
|
|
|
|
|
|
|
| 1386 |
|
| 1387 |
r = coco.convert(names=country, to="continent")
|
|
|
|
| 1388 |
if not r or str(r).lower() in ("not found", "nan", "none"):
|
| 1389 |
return REGION_UNKNOWN
|
| 1390 |
|
| 1391 |
if r == "America":
|
| 1392 |
return "Americas"
|
|
|
|
| 1393 |
return str(r)
|
| 1394 |
|
| 1395 |
|
| 1396 |
+
|
| 1397 |
def add_ethnicity_region(df: pd.DataFrame, eth_col: str = "Ethnicity", out_col: str = "Ethnicity_Region") -> pd.DataFrame:
|
|
|
|
| 1398 |
if eth_col not in df.columns:
|
| 1399 |
df[out_col] = REGION_UNKNOWN
|
| 1400 |
return df
|
|
|
|
| 1404 |
return df
|
| 1405 |
|
| 1406 |
|
| 1407 |
+
|
| 1408 |
# ============================================================
|
| 1409 |
# Streamlit UI
|
| 1410 |
# ============================================================
|
|
|
|
| 3166 |
|
| 3167 |
# Dense conversion once (used for summary + waterfalls)
|
| 3168 |
try:
|
| 3169 |
+
X_dense = safe_dense(X_batch_t, max_rows=200)
|
| 3170 |
+
|
| 3171 |
except Exception:
|
| 3172 |
X_dense = np.array(X_batch_t)
|
| 3173 |
|