# stroke_risk_xgboost / handler.py
# Author: amobionovo — commit 4a71291 ("Update handler.py", verified)
# handler.py — Quantium insights Inference Endpoint (fixes XGBWrappedModel unpickle + Residence_type)
import os
import sys
import types
import json
import traceback
from typing import Any, Dict, List, Tuple
import joblib
import numpy as np
import pandas as pd
# =========================
# Re-declare the custom wrapper class and register it where pickle expects it
# =========================
class XGBWrappedModel:
    """Re-declaration of the estimator wrapper pickled into ``model.joblib``.

    Unpickling looks this class up by module and name and then restores the
    saved instance attributes, so the attribute names below must match the
    training-time class.

    Attributes:
        preprocessor_: fitted transformer exposing ``transform`` (an sklearn
            ``ColumnTransformer`` at training time).
        model_: fitted classifier exposing ``predict_proba`` (``XGBClassifier``
            or similar).
        explainer_: optional SHAP explainer; ``None`` disables ``top_contrib``.
        feature_names_out_: names of the preprocessed columns, or ``None``.
        cat_prefix / num_prefix: prefixes in front of categorical / numeric
            output names (defaults ``"cat__"`` / ``"num__"``).

    Provides:
        predict_proba(X_df) and top_contrib(X_df, k).
    """

    def __init__(self, preprocessor=None, booster=None, explainer=None,
                 feat_names_out=None, cat_prefix="cat__", num_prefix="num__"):
        self.preprocessor_ = preprocessor
        self.model_ = booster
        self.explainer_ = explainer
        self.feature_names_out_ = (
            np.array(feat_names_out).astype(str) if feat_names_out is not None else None
        )
        self.cat_prefix = cat_prefix
        self.num_prefix = num_prefix

    def predict_proba(self, X_df: pd.DataFrame):
        """Preprocess ``X_df`` and return the wrapped model's class probabilities.

        For a binary XGBoost classifier this has shape ``(n, 2)``.
        """
        Z = self.preprocessor_.transform(X_df)
        return self.model_.predict_proba(Z)

    def top_contrib(self, X_df: pd.DataFrame, k: int = 5) -> Tuple[List[str], List[float]]:
        """Return the ``k`` input features with the largest SHAP impact on row 0.

        Output columns produced by one-hot encoding are folded back into the
        input column they came from. Absolute SHAP values are summed for
        ranking, and signed values are summed for the reported effect.

        Returns:
            ``(names, values)`` ordered by descending absolute contribution,
            or ``([], [])`` when no explainer is attached.
        """
        if getattr(self, "explainer_", None) is None:
            return [], []
        sv_row = self._shap_row(self.preprocessor_.transform(X_df))

        if self.feature_names_out_ is None:
            names_out = [f"f{i}" for i in range(len(sv_row))]
        else:
            names_out = list(self.feature_names_out_)

        known = self._input_columns()
        abs_sum: Dict[str, float] = {}
        signed_sum: Dict[str, float] = {}
        for out_name, v in zip(names_out, sv_row):
            orig = self._to_orig(str(out_name), known)
            abs_sum[orig] = abs_sum.get(orig, 0.0) + abs(float(v))
            signed_sum[orig] = signed_sum.get(orig, 0.0) + float(v)

        ranked = sorted(abs_sum.items(), key=lambda kv: kv[1], reverse=True)[:k]
        return [n for n, _ in ranked], [signed_sum[n] for n, _ in ranked]

    def _shap_row(self, Z) -> np.ndarray:
        """Return the SHAP values of the first preprocessed row as a float array."""
        try:
            sv = self.explainer_.shap_values(Z)
            if isinstance(sv, list):
                # Per-class list: take the positive class when there are two or more.
                sv = sv[1] if len(sv) > 1 else sv[0]
        except Exception:
            # Explainers without ``shap_values`` are called directly and expose ``.values``.
            sv = self.explainer_(Z).values
        row = np.asarray(sv[0], dtype=float)
        if row.ndim > 1:
            # NOTE(review): assumes a (features, outputs) layout with the positive
            # class last — confirm against the explainer in use.
            row = row[..., -1]
        return row

    def _input_columns(self) -> List[str]:
        """Return the preprocessor's input column names, or [] if it does not record them."""
        cols = getattr(getattr(self, "preprocessor_", None), "feature_names_in_", None)
        return [] if cols is None else [str(c) for c in cols]

    def _to_orig(self, name: str, known: List[str]) -> str:
        """Map a preprocessed output name back to the input column it came from.

        Categorical names are ``<prefix><column>_<category>``, and column names
        themselves contain underscores (``ever_married``). So the longest
        known input column that matches is used instead of splitting at the
        first underscore. Without known columns, the result falls back to the
        text before the first underscore.
        """
        cat_prefix = getattr(self, "cat_prefix", "cat__")
        num_prefix = getattr(self, "num_prefix", "num__")
        if name.startswith(cat_prefix):
            stripped = name[len(cat_prefix):]
        elif name.startswith(num_prefix):
            return name[len(num_prefix):]
        else:
            stripped = name
        candidates = [c for c in known if stripped == c or stripped.startswith(c + "_")]
        if candidates:
            return max(candidates, key=len)
        return stripped.split("_", 1)[0]
# Make the class resolvable at every module path the pickle may reference:
# the training run saved it from __main__, and sometimes from 'train_export_xgb'.
# An empty stand-in 'train_export_xgb' module is created only when no module of
# that name has been imported yet.
setattr(sys.modules["__main__"], "XGBWrappedModel", XGBWrappedModel)
setattr(
    sys.modules.setdefault("train_export_xgb", types.ModuleType("train_export_xgb")),
    "XGBWrappedModel",
    XGBWrappedModel,
)
# =========================
# Feature schema (canonical)
# =========================
# Numeric inputs; the two 0/1 flags are coerced and stored as numbers too.
NUMERIC_COLS = [
    "age",
    "avg_glucose_level",
    "bmi",
    "hypertension",
    "heart_disease",
]
# Categorical inputs; the canonical Residence key is spelled with a capital R.
CAT_COLS = [
    "gender",
    "ever_married",
    "work_type",
    "smoking_status",
    "Residence_type",
]
# Every canonical input column, numerics first.
ALL_CANON = [*NUMERIC_COLS, *CAT_COLS]
# Order in which per-feature explanations are reported (same order as ALL_CANON,
# kept as its own list).
EXPLAIN_ORDER = list(ALL_CANON)
# =========================
# Utility: dtype coercion
# =========================
def _to_int01(x: Any) -> int:
if isinstance(x, (bool, np.bool_)):
return int(bool(x))
try:
if isinstance(x, str):
s = x.strip().lower()
if s in {"1", "true", "t", "yes", "y"}:
return 1
if s in {"0", "false", "f", "no", "n"}:
return 0
return int(float(x))
except Exception:
return 0
def _coerce_dataframe(rows: List[Dict[str, Any]]) -> pd.DataFrame:
    """Turn raw request records into a model-ready DataFrame.

    * Only the canonical columns (ALL_CANON) are kept, in that order; a key
      missing from a record starts out as None.
    * 'residence_type' (lowercase) is accepted as an alias when the canonical
      'Residence_type' key is absent.
    * hypertension / heart_disease go through _to_int01. age,
      avg_glucose_level and bmi are parsed with pd.to_numeric (unparseable
      values become NaN). All NUMERIC_COLS end up float64.
    * Categorical columns hold plain ``str`` objects; missing values become
      "Unknown".
    * A lowercase 'residence_type' copy of 'Residence_type' is appended for
      legacy models.
    """
    def _canonical_record(raw: Any) -> Dict[str, Any]:
        record = dict(raw or {})
        if "Residence_type" not in record and "residence_type" in record:
            record["Residence_type"] = record["residence_type"]
        return {col: record.get(col, None) for col in ALL_CANON}

    frame = pd.DataFrame([_canonical_record(raw) for raw in rows], columns=ALL_CANON)

    for flag in ("hypertension", "heart_disease"):
        frame[flag] = frame[flag].map(_to_int01)
    for measure in ("age", "avg_glucose_level", "bmi"):
        frame[measure] = pd.to_numeric(frame[measure], errors="coerce")
    frame[NUMERIC_COLS] = frame[NUMERIC_COLS].astype("float64")

    for cat in CAT_COLS:
        filled = frame[cat].where(frame[cat].notna(), "Unknown")
        frame[cat] = filled.map(lambda v: "Unknown" if v is None else str(v)).astype(object)

    # Legacy alias column, kept alongside the canonical one.
    frame["residence_type"] = frame["Residence_type"].astype(object)
    return frame
# =========================
# Safety patches for OHE
# =========================
def _iter_estimators(est):
yield est
if hasattr(est, "named_steps"):
for step in est.named_steps.values():
yield from _iter_estimators(step)
if hasattr(est, "transformers"):
for _, tr, _ in est.transformers:
yield from _iter_estimators(tr)
def _numeric_like(x) -> bool:
if x is None:
return True
if isinstance(x, (int, np.integer, float, np.floating)):
return True
if isinstance(x, str):
try:
float(x)
return True
except Exception:
return False
return False
def _sanitize_onehot_categories(model):
    """Give each fitted OneHotEncoder's ``categories_`` one consistent dtype.

    Mixed object arrays (e.g. strings plus a float NaN) can make sklearn call
    ``np.isnan`` on strings and crash. For every OneHotEncoder reachable
    through ``_iter_estimators``:
      * a category list made only of real numbers / ``None`` becomes a float
        array, with ``None`` -> NaN;
      * anything else becomes an object array of ``str``, with ``None``/NaN
        replaced by ``"Unknown"`` (the placeholder ``_coerce_dataframe`` uses
        for missing categoricals);
      * ``handle_unknown`` is set to ``"ignore"`` so unseen values encode as
        all zeros instead of raising.

    Mutates ``model`` in place. Does nothing if scikit-learn cannot be imported.
    """
    try:
        from sklearn.preprocessing import OneHotEncoder  # type: ignore
    except Exception:
        # scikit-learn unavailable or broken: there is nothing to sanitise.
        return

    def _is_missing(v: Any) -> bool:
        return v is None or (isinstance(v, (float, np.floating)) and bool(np.isnan(v)))

    for node in _iter_estimators(model):
        if not isinstance(node, OneHotEncoder) or not hasattr(node, "categories_"):
            continue
        new_cats = []
        for cats in node.categories_:
            arr = np.asarray(cats, dtype=object)
            # Numeric-looking *strings* ("0", "1") deliberately stay strings:
            # converting them to float would stop them matching the str inputs
            # that _coerce_dataframe produces for categorical columns.
            if all(_numeric_like(v) and not isinstance(v, str) for v in arr):
                vals = []
                for v in arr:
                    try:
                        vals.append(np.nan if v is None else float(v))
                    except (TypeError, ValueError, OverflowError):
                        vals.append(np.nan)
                new_cats.append(np.asarray(vals, dtype=float))
            else:
                new_cats.append(
                    np.asarray(["Unknown" if _is_missing(v) else str(v) for v in arr], dtype=object)
                )
        node.categories_ = new_cats
        if hasattr(node, "handle_unknown"):
            node.handle_unknown = "ignore"
def _patch_check_unknown():
"""Patch sklearn _check_unknown to avoid np.isnan on object arrays (older builds)."""
try:
from sklearn.utils import _encode # type: ignore
_orig = _encode._check_unknown
def _safe_check_unknown(values, known_values, return_mask=False):
try:
return _orig(values, known_values, return_mask=return_mask)
except TypeError:
vals = np.asarray(values, dtype=object)
known = np.asarray(known_values, dtype=object)
mask = np.isin(vals, known, assume_unique=False)
diff = vals[~mask]
if return_mask:
return diff, mask
return diff
_encode._check_unknown = _safe_check_unknown # type: ignore[attr-defined]
print("[handler] Patched sklearn.utils._encode._check_unknown", flush=True)
except Exception as e:
print(f"[handler] Patch for _check_unknown not applied: {e}", flush=True)
# =========================
# Model introspection (debug)
# =========================
def _introspect_model(model) -> Dict[str, Any]:
info: Dict[str, Any] = {"type": str(type(model))}
try:
if hasattr(model, "named_steps"):
info["pipeline_steps"] = list(model.named_steps.keys())
for name, step in model.named_steps.items():
if step.__class__.__name__ == "ColumnTransformer":
info["column_transformer"] = str(step)
try:
info["transformers_"] = [(n, str(t.__class__), cols) for (n, t, cols) in step.transformers]
except Exception:
pass
except Exception:
pass
try:
info["feature_names_in_"] = list(getattr(model, "feature_names_in_", []))
except Exception:
pass
return info
# =========================
# Handler
# =========================
class EndpointHandler:
    """Inference Endpoint entry point for the stroke-risk model.

    ``__init__`` loads ``model.joblib`` once. ``__call__`` serves one request.
    """

    def __init__(self, path: str = "/repository") -> None:
        """Load ``<path>/model.joblib`` and read the decision threshold.

        The threshold comes from the ``THRESHOLD`` environment variable. It
        falls back to 0.38 when the variable is unset or not a valid float.
        """
        _patch_check_unknown()  # apply the sklearn safety patch before unpickling
        model_path = os.path.join(path, "model.joblib")
        # NOTE: joblib.load unpickles arbitrary objects; only load trusted repository files.
        self.model = joblib.load(model_path)
        try:
            self.threshold = float(os.getenv("THRESHOLD", "0.38"))
        except ValueError:
            self.threshold = 0.38
        self.explainer = getattr(self.model, "explainer_", None)
        _sanitize_onehot_categories(self.model)
        print("[handler] Model loaded", flush=True)
        print(f"[handler] Using threshold: {self.threshold}", flush=True)

    @staticmethod
    def _as_flag(value: Any) -> bool:
        """Interpret a request flag; strings such as "false" or "0" count as False."""
        if isinstance(value, str):
            return value.strip().lower() in {"1", "true", "t", "yes", "y"}
        return bool(value)

    def _debug_info(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Build the diagnostic payload.

        Only built on demand, because model introspection stringifies the
        whole model.
        """
        return {
            "columns": list(df.columns),
            "dtypes": {c: str(df[c].dtype) for c in df.columns},
            "threshold": self.threshold,
            "model": _introspect_model(self.model),
            "head": df.head(1).to_dict(orient="records"),
        }

    def _explain(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Return the explanation keys to merge into the response (first record only)."""
        if hasattr(self.model, "top_contrib"):
            try:
                names, vals = self.model.top_contrib(df, k=5)
                return {"shap": {"feature_names": names, "values": vals}} if names else {}
            except Exception as e:
                return {"shap_error": f"top_contrib failed: {e}"}
        if self.explainer is None:
            return {}
        try:
            # NOTE(review): applies the explainer to the raw coerced frame (including
            # the mirrored 'residence_type' column) and indexes SHAP values by that
            # frame's column positions — only meaningful if the explainer was built
            # on raw inputs; confirm.
            shap_vals = self.explainer(df)
            vals = shap_vals.values[0] if hasattr(shap_vals, "values") else shap_vals[0]
            columns = list(df.columns)
            contrib = [
                {"feature": feat, "effect": float(vals[columns.index(feat)])}
                for feat in EXPLAIN_ORDER
                if feat in columns
            ]
            return {"shap": {"contrib": contrib}}
        except Exception as e:
            return {"shap_error": f"explainer failed: {e}"}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score one or more patient records.

        Request keys:
            inputs: a record dict or a non-empty list of record dicts.
            debug: truthy to add a ``debug`` section to the response.
            explain: truthy to add SHAP contributions for the first record.

        Returns:
            ``risk_probability`` / ``risk_label`` for the first record plus the
            ``threshold`` used. When several records are sent, the response
            also has ``risk_probabilities`` / ``risk_labels`` for all of them.
            Failures are reported under an ``error`` key instead of raising.
        """
        if not isinstance(data, dict):
            return {"error": "request body must be a JSON object", "threshold": self.threshold}
        debug = self._as_flag(data.get("debug", False))
        explain = self._as_flag(data.get("explain", False))

        rows = data.get("inputs") or []
        if isinstance(rows, dict):
            rows = [rows]
        if not isinstance(rows, list) or not rows:
            return {"error": "inputs must be a non-empty list of records", "threshold": self.threshold}

        try:
            df = _coerce_dataframe(rows)
        except Exception as e:
            # e.g. a record that is not a mapping: report it instead of failing the request.
            return {
                "error": f"invalid inputs: {e}",
                "trace": traceback.format_exc(),
                "threshold": self.threshold,
            }

        try:
            if hasattr(self.model, "predict_proba"):
                proba = self.model.predict_proba(df)[:, 1].astype(float)
            else:
                # NOTE(review): treats predict() output as a raw margin and squashes it
                # with a sigmoid — only meaningful for margin-returning models; confirm.
                raw = self.model.predict(df).astype(float)
                proba = 1.0 / (1.0 + np.exp(-raw))
        except Exception as e:
            return {
                "error": f"model.predict failed: {e}",
                "trace": traceback.format_exc(),
                "debug": self._debug_info(df),
                "threshold": self.threshold,
            }

        labels = [int(float(q) >= self.threshold) for q in proba]
        p = float(proba[0])
        label = labels[0]
        resp: Dict[str, Any] = {
            "risk_probability": p,
            "risk_label": label,
            "threshold": self.threshold,
        }
        if len(labels) > 1:
            # Extra keys only; single-record responses are unchanged.
            resp["risk_probabilities"] = [float(q) for q in proba]
            resp["risk_labels"] = labels
        if explain:
            resp.update(self._explain(df))
        if debug:
            resp["debug"] = self._debug_info(df)
        print(f"[handler] prob={p:.4f} label={label}", flush=True)
        return resp