Spaces:
Build error
Build error
Update inference.py
Browse files- inference.py +30 -25
inference.py
CHANGED
|
@@ -9,10 +9,13 @@ import warnings
|
|
| 9 |
|
| 10 |
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# ---------------------------------------------------------------------------
|
| 13 |
# Compatibility patch — inject _RemainderColsList if the installed sklearn
|
| 14 |
-
# version does not have it (added in sklearn 1.4+).
|
| 15 |
-
# saved with a newer sklearn to load correctly on older environments.
|
| 16 |
# ---------------------------------------------------------------------------
|
| 17 |
import sklearn.compose._column_transformer as _ct
|
| 18 |
if not hasattr(_ct, "_RemainderColsList"):
|
|
@@ -24,6 +27,7 @@ if not hasattr(_ct, "_RemainderColsList"):
|
|
| 24 |
_ct._RemainderColsList = _RemainderColsList
|
| 25 |
import sklearn.compose
|
| 26 |
sklearn.compose._RemainderColsList = _RemainderColsList
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
# ---------------------------------------------------------------------------
|
|
@@ -32,21 +36,9 @@ if not hasattr(_ct, "_RemainderColsList"):
|
|
| 32 |
|
| 33 |
NUM_COLUMNS = ["AGE", "NACS2YR"]
|
| 34 |
CATEG_COLUMNS = [
|
| 35 |
-
"AGEGPFF",
|
| 36 |
-
"
|
| 37 |
-
"
|
| 38 |
-
"DONORF",
|
| 39 |
-
"GRAFTYPE",
|
| 40 |
-
"CONDGRPF",
|
| 41 |
-
"CONDGRP_FINAL",
|
| 42 |
-
"ATGF",
|
| 43 |
-
"GVHD_FINAL",
|
| 44 |
-
"HLA_FINAL",
|
| 45 |
-
"RCMVPR",
|
| 46 |
-
"EXCHTFPR",
|
| 47 |
-
"VOC2YPR",
|
| 48 |
-
"VOCFRQPR",
|
| 49 |
-
"SCATXRSN",
|
| 50 |
]
|
| 51 |
|
| 52 |
FEATURE_NAMES = NUM_COLUMNS + CATEG_COLUMNS
|
|
@@ -82,23 +74,29 @@ DEFAULT_N_BOOT_CI = 500
|
|
| 82 |
# ---------------------------------------------------------------------------
|
| 83 |
|
| 84 |
def _load_skops_model(fname):
|
|
|
|
|
|
|
|
|
|
| 85 |
try:
|
| 86 |
untrusted = sio.get_untrusted_types(file=fname)
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
except Exception as e:
|
| 89 |
-
|
| 90 |
-
sys.exit(1)
|
| 91 |
|
| 92 |
|
|
|
|
| 93 |
preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops"))
|
| 94 |
|
|
|
|
| 95 |
classification_model_data = {}
|
| 96 |
for _o in CLASSIFICATION_OUTCOMES:
|
| 97 |
_path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops")
|
| 98 |
if os.path.exists(_path):
|
| 99 |
classification_model_data[_o] = _load_skops_model(_path)
|
| 100 |
else:
|
| 101 |
-
print(f"Warning: Model for {_o} not found at {_path}. Skipping.")
|
| 102 |
|
| 103 |
classification_models = {o: d["models"] for o, d in classification_model_data.items()}
|
| 104 |
betas = {o: d["beta"] for o, d in classification_model_data.items()}
|
|
@@ -119,7 +117,7 @@ for _o, _d in classification_model_data.items():
|
|
| 119 |
_cal = _d["calibrator"]
|
| 120 |
else:
|
| 121 |
print(
|
| 122 |
-
f"Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. "
|
| 123 |
"Skipping non-isotonic calibrator (isotonic-only policy)."
|
| 124 |
)
|
| 125 |
elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
|
|
@@ -138,11 +136,14 @@ ohe = preprocessor.named_transformers_["cat"]
|
|
| 138 |
ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
|
| 139 |
processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])
|
| 140 |
|
|
|
|
|
|
|
| 141 |
|
| 142 |
# ---------------------------------------------------------------------------
|
| 143 |
# SHAP background data
|
| 144 |
# ---------------------------------------------------------------------------
|
| 145 |
|
|
|
|
| 146 |
np.random.seed(23)
|
| 147 |
_n_background = 500
|
| 148 |
|
|
@@ -177,9 +178,10 @@ _background_data = {
|
|
| 177 |
], _n_background),
|
| 178 |
}
|
| 179 |
|
| 180 |
-
_background_df
|
| 181 |
-
_X_background
|
| 182 |
shap_background = shap.maskers.Independent(_X_background)
|
|
|
|
| 183 |
|
| 184 |
|
| 185 |
# ---------------------------------------------------------------------------
|
|
@@ -611,4 +613,7 @@ def icon_array(probability, outcome):
|
|
| 611 |
plot_bgcolor="white",
|
| 612 |
paper_bgcolor="white",
|
| 613 |
)
|
| 614 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
|
| 11 |
|
| 12 |
+
print("===== Application Startup =====")
|
| 13 |
+
print(f"Working directory: {os.getcwd()}")
|
| 14 |
+
print(f"Files present: {os.listdir('.')}")
|
| 15 |
+
|
| 16 |
# ---------------------------------------------------------------------------
|
| 17 |
# Compatibility patch — inject _RemainderColsList if the installed sklearn
|
| 18 |
+
# version does not have it (added in sklearn 1.4+).
|
|
|
|
| 19 |
# ---------------------------------------------------------------------------
|
| 20 |
import sklearn.compose._column_transformer as _ct
|
| 21 |
if not hasattr(_ct, "_RemainderColsList"):
|
|
|
|
| 27 |
_ct._RemainderColsList = _RemainderColsList
|
| 28 |
import sklearn.compose
|
| 29 |
sklearn.compose._RemainderColsList = _RemainderColsList
|
| 30 |
+
print("Patched _RemainderColsList into sklearn.compose")
|
| 31 |
|
| 32 |
|
| 33 |
# ---------------------------------------------------------------------------
|
|
|
|
| 36 |
|
| 37 |
NUM_COLUMNS = ["AGE", "NACS2YR"]
|
| 38 |
CATEG_COLUMNS = [
|
| 39 |
+
"AGEGPFF", "SEX", "KPS", "DONORF", "GRAFTYPE", "CONDGRPF",
|
| 40 |
+
"CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL",
|
| 41 |
+
"RCMVPR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
]
|
| 43 |
|
| 44 |
FEATURE_NAMES = NUM_COLUMNS + CATEG_COLUMNS
|
|
|
|
| 74 |
# ---------------------------------------------------------------------------
|
| 75 |
|
| 76 |
def _load_skops_model(fname):
|
| 77 |
+
"""Load a skops model file. Raises RuntimeError on failure (no sys.exit)."""
|
| 78 |
+
if not os.path.exists(fname):
|
| 79 |
+
raise RuntimeError(f"Model file not found: {fname}")
|
| 80 |
try:
|
| 81 |
untrusted = sio.get_untrusted_types(file=fname)
|
| 82 |
+
model = sio.load(fname, trusted=untrusted)
|
| 83 |
+
print(f" Loaded: {fname}")
|
| 84 |
+
return model
|
| 85 |
except Exception as e:
|
| 86 |
+
raise RuntimeError(f"Failed to load '{fname}': {type(e).__name__}: {e}") from e
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
+
print("Loading preprocessor...")
|
| 90 |
preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops"))
|
| 91 |
|
| 92 |
+
print("Loading ensemble models...")
|
| 93 |
classification_model_data = {}
|
| 94 |
for _o in CLASSIFICATION_OUTCOMES:
|
| 95 |
_path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops")
|
| 96 |
if os.path.exists(_path):
|
| 97 |
classification_model_data[_o] = _load_skops_model(_path)
|
| 98 |
else:
|
| 99 |
+
print(f" Warning: Model for {_o} not found at {_path}. Skipping.")
|
| 100 |
|
| 101 |
classification_models = {o: d["models"] for o, d in classification_model_data.items()}
|
| 102 |
betas = {o: d["beta"] for o, d in classification_model_data.items()}
|
|
|
|
| 117 |
_cal = _d["calibrator"]
|
| 118 |
else:
|
| 119 |
print(
|
| 120 |
+
f" Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. "
|
| 121 |
"Skipping non-isotonic calibrator (isotonic-only policy)."
|
| 122 |
)
|
| 123 |
elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
|
|
|
|
| 136 |
ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
|
| 137 |
processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])
|
| 138 |
|
| 139 |
+
print(f"Models loaded: {list(classification_models.keys())}")
|
| 140 |
+
|
| 141 |
|
| 142 |
# ---------------------------------------------------------------------------
|
| 143 |
# SHAP background data
|
| 144 |
# ---------------------------------------------------------------------------
|
| 145 |
|
| 146 |
+
print("Building SHAP background...")
|
| 147 |
np.random.seed(23)
|
| 148 |
_n_background = 500
|
| 149 |
|
|
|
|
| 178 |
], _n_background),
|
| 179 |
}
|
| 180 |
|
| 181 |
+
_background_df = pd.DataFrame(_background_data)[FEATURE_NAMES]
|
| 182 |
+
_X_background = preprocessor.transform(_background_df)
|
| 183 |
shap_background = shap.maskers.Independent(_X_background)
|
| 184 |
+
print("SHAP background ready.")
|
| 185 |
|
| 186 |
|
| 187 |
# ---------------------------------------------------------------------------
|
|
|
|
| 613 |
plot_bgcolor="white",
|
| 614 |
paper_bgcolor="white",
|
| 615 |
)
|
| 616 |
+
return fig
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
print("===== inference.py loaded successfully =====")
|