# readmitiq-api/src/api/predictor.py
# feat: deploy ReadmitIQ API to HuggingFace Spaces (commit 742cd9a)
"""
Model loading and inference engine.
Loads XGBoost model, scaler, and SHAP explainer once at startup.
All prediction requests are handled through the Predictor class.
Design principle: load once, predict many times.
Loading a model on every request would be 10-100x slower.
"""
import joblib
import numpy as np
import pandas as pd
import shap
import logging
from pathlib import Path
from typing import Optional
from src.api.models import (
PatientFeatures,
PredictionResponse,
RiskCategory,
RiskFactor,
)
# Module-level logger named after this module; INFO-level logging is
# requested for the root logger as soon as the module is imported.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Feature order must match exactly what the model was trained on
# Any mismatch causes silently wrong predictions — worse than an error
FEATURE_ORDER = [
    # Index admission: length of stay and timing
    "los_days",
    "los_days_log",
    "admission_month",
    "admission_dow",
    "is_emergency",
    # Prior utilisation
    "n_admissions_prior_6m",
    "n_admissions_prior_12m",
    "n_ed_visits_prior_6m",
    "days_since_last_admission",
    "has_prior_admission",
    # Comorbidity burden and individual diagnoses
    "charlson_score",
    "n_active_conditions",
    "has_heart_failure",
    "has_diabetes",
    "has_diabetes_complex",
    "has_copd",
    "has_ckd",
    "has_mi",
    "has_cancer",
    "has_dementia",
    "has_cerebrovascular",
    "has_pvd",
    # Medications
    "n_medications_capped",
    "is_high_polypharmacy",
    "has_insulin",
    "has_anticoagulant",
    "has_diuretic",
    "has_ace_inhibitor",
    # Demographics
    "age_at_admission",
    "gender_male",
    "income",
]
# Clinical descriptions for each feature
# Used to generate plain-English explanations in the response
FEATURE_DESCRIPTIONS = {
    # Index admission
    "los_days": "Length of hospital stay",
    "los_days_log": "Length of hospital stay (log scale)",
    "admission_month": "Month of admission",
    "admission_dow": "Day of week admitted",
    "is_emergency": "Emergency admission",
    # Prior utilisation
    "n_admissions_prior_6m": "Hospital admissions in past 6 months",
    "n_admissions_prior_12m": "Hospital admissions in past 12 months",
    "n_ed_visits_prior_6m": "Emergency department visits in past 6 months",
    "days_since_last_admission": "Days since last hospital admission",
    "has_prior_admission": "History of prior hospitalisation",
    # Comorbidities
    "charlson_score": "Charlson Comorbidity Index score",
    "n_active_conditions": "Number of active medical conditions",
    "has_heart_failure": "Heart failure diagnosis",
    "has_diabetes": "Diabetes mellitus diagnosis",
    "has_diabetes_complex": "Diabetes with complications",
    "has_copd": "Chronic obstructive pulmonary disease",
    "has_ckd": "Chronic kidney disease",
    "has_mi": "History of myocardial infarction",
    "has_cancer": "Active cancer diagnosis",
    "has_dementia": "Dementia diagnosis",
    "has_cerebrovascular": "Cerebrovascular disease",
    "has_pvd": "Peripheral vascular disease",
    # Medications
    "n_medications_capped": "Number of active medications",
    "is_high_polypharmacy": "High polypharmacy (>10 medications)",
    "has_insulin": "Insulin therapy",
    "has_anticoagulant": "Anticoagulant therapy",
    "has_diuretic": "Diuretic therapy",
    "has_ace_inhibitor": "ACE inhibitor therapy",
    # Demographics
    "age_at_admission": "Patient age at admission",
    "gender_male": "Male gender",
    "income": "Household income",
}
# Risk thresholds — calibrated to our model's output distribution
# LOW: < 20% — standard discharge
# MEDIUM: 20% to < 50% — enhanced follow-up
# HIGH: >= 50% — intensive intervention
RISK_THRESHOLDS = {
    # Exclusive upper bounds: a probability strictly below the bound falls
    # in that category; anything at or above the MEDIUM bound is HIGH.
    "LOW": 0.20,
    "MEDIUM": 0.50,
}
# Clinical recommendations by risk category
RECOMMENDATIONS = {
    # Each recommendation is a sequence of sentences joined by single spaces.
    RiskCategory.LOW: " ".join((
        "Standard discharge protocol.",
        "Schedule routine 30-day follow-up appointment.",
        "Provide patient with discharge summary and medication list.",
    )),
    RiskCategory.MEDIUM: " ".join((
        "Enhanced discharge protocol.",
        "Schedule 14-day follow-up appointment.",
        "Conduct pharmacist medication reconciliation.",
        "Provide patient education on warning signs requiring ED return.",
    )),
    RiskCategory.HIGH: " ".join((
        "High-risk discharge protocol.",
        "Schedule 7-day follow-up appointment.",
        "Assign care coordinator for post-discharge monitoring.",
        "Conduct pharmacist medication reconciliation before discharge.",
        "Consider social work referral if social support is limited.",
        "Confirm patient has transport and medication access post-discharge.",
    )),
}
class Predictor:
    """
    Manages model lifecycle and inference.

    Intended to be constructed once and loaded once at API startup (the
    original note mentions FastAPI lifespan events; the caller is not
    visible in this module). After ``load()`` succeeds, the prediction
    methods only read ``self.model``, ``self.scaler`` and
    ``self.explainer``, so one instance can serve many requests — this
    assumes the underlying libraries tolerate concurrent read-only use.
    """

    def __init__(self, models_path: str = "models"):
        """
        Record where artifacts live; performs no disk I/O.

        Args:
            models_path: Directory expected to contain
                ``xgboost_model.pkl`` and ``scaler.pkl``.
        """
        self.models_path = Path(models_path)
        self.model = None
        self.scaler = None
        self.explainer = None
        self.model_version = "xgboost_v1_auc0.891"
        self._loaded = False

    def load(self) -> None:
        """
        Load model, scaler, and SHAP explainer from disk.

        Meant to be called once at API startup — not on every request.
        Artifacts are first loaded into local variables and only assigned
        to the instance once all three are available, so a failure part-way
        through (including on a reload) never leaves a new model paired
        with an old scaler or explainer.

        Raises:
            FileNotFoundError: An artifact file is missing.
            Exception: Any other load/build failure, re-raised after logging.
        """
        logger.info("Loading model artifacts...")
        try:
            # NOTE: joblib.load unpickles arbitrary objects — only point
            # models_path at trusted artifact files.
            model_path = self.models_path / "xgboost_model.pkl"
            model = joblib.load(model_path)
            logger.info("Model loaded from %s", model_path)

            # Must be the same scaler that was used in training
            scaler_path = self.models_path / "scaler.pkl"
            scaler = joblib.load(scaler_path)
            logger.info("Scaler loaded from %s", scaler_path)

            # TreeExplainer for the tree model; built once here and then
            # queried per request
            explainer = shap.TreeExplainer(model)
            logger.info("SHAP explainer ready")
        except FileNotFoundError as e:
            logger.error("Model file not found: %s", e)
            raise
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise

        # Publish the complete, consistent set of artifacts together
        self.model = model
        self.scaler = scaler
        self.explainer = explainer
        self._loaded = True
        logger.info(
            "All artifacts loaded successfully. Model version: %s",
            self.model_version,
        )

    @property
    def is_loaded(self) -> bool:
        """True once ``load()`` has completed successfully."""
        return self._loaded

    def _features_to_array(
        self, features: PatientFeatures
    ) -> pd.DataFrame:
        """
        Convert the Pydantic model to a one-row DataFrame.

        Columns follow ``FEATURE_ORDER`` exactly. Despite the method name,
        the result is a DataFrame, not a NumPy array.

        Raises:
            KeyError: ``features`` lacks a field listed in ``FEATURE_ORDER``.
        """
        feature_dict = features.model_dump()
        ordered = {name: [feature_dict[name]] for name in FEATURE_ORDER}
        return pd.DataFrame(ordered)

    def _get_risk_category(
        self, probability: float
    ) -> RiskCategory:
        """
        Map a probability to a clinical risk category.

        Below ``RISK_THRESHOLDS['LOW']`` is LOW, below
        ``RISK_THRESHOLDS['MEDIUM']`` is MEDIUM, everything else is HIGH
        (so a value exactly on the MEDIUM bound is HIGH).
        """
        if probability < RISK_THRESHOLDS['LOW']:
            return RiskCategory.LOW
        if probability < RISK_THRESHOLDS['MEDIUM']:
            return RiskCategory.MEDIUM
        return RiskCategory.HIGH

    def _compute_shap_factors(
        self,
        raw_array: np.ndarray,
        features: PatientFeatures,
        top_n: int = 5,
    ) -> list[RiskFactor]:
        """
        Compute SHAP values and return the top N risk factors
        with clinical plain-English descriptions.

        Args:
            raw_array: Single-row model input. Despite the name,
                ``predict`` passes the *scaled* row here, i.e. the same
                values the model was scored on.
            features: The original request; its unscaled values are what
                get reported in each ``RiskFactor.value``.
            top_n: Number of factors to return, ranked by absolute SHAP
                value (largest impact first).

        Raises:
            ValueError: The SHAP output for the row is not a 1-D vector
                with one entry per ``FEATURE_ORDER`` feature. Pairing a
                mismatched vector with feature names would silently
                attribute impacts to the wrong features.
        """
        shap_row = np.asarray(self.explainer.shap_values(raw_array)[0])
        expected = len(FEATURE_ORDER)
        if shap_row.shape != (expected,):
            raise ValueError(
                f"SHAP output shape {shap_row.shape} does not match "
                f"the {expected} features in FEATURE_ORDER"
            )

        feature_dict = features.model_dump()
        # (feature name, reported value, SHAP contribution)
        factors = [
            (feat, feature_dict[feat], shap_val)
            for feat, shap_val in zip(FEATURE_ORDER, shap_row)
        ]
        # Largest absolute impact first, then keep the top N
        strongest = sorted(
            factors, key=lambda item: abs(item[2]), reverse=True
        )[:top_n]

        return [
            RiskFactor(
                feature=feat,
                value=round(float(val), 3),
                impact=round(float(shap_val), 4),
                # A contribution of exactly zero is reported as "decreases"
                direction="increases" if shap_val > 0 else "decreases",
                description=FEATURE_DESCRIPTIONS.get(feat, feat),
            )
            for feat, val, shap_val in strongest
        ]

    def predict(self, features: PatientFeatures) -> PredictionResponse:
        """
        Run the full prediction pipeline for one patient.

        Steps:
            1. Convert features to a one-row frame in the correct order
            2. Scale features using the training scaler
            3. Get the positive-class probability from the model
            4. Compute SHAP values for explainability
            5. Build the structured clinical response

        Returns:
            PredictionResponse with risk score, category,
            SHAP factors, and clinical recommendation.

        Raises:
            RuntimeError: ``load()`` has not completed successfully.
        """
        if not self._loaded:
            raise RuntimeError(
                "Model not loaded. Call predictor.load() first."
            )

        # Step 1 — features to one-row frame
        raw_df = self._features_to_array(features)

        # Step 2 — scale
        scaled_array = self.scaler.transform(raw_df)

        # Step 3 — probability of the positive class (column 1) for row 0
        probability = float(
            self.model.predict_proba(scaled_array)[0][1]
        )

        # Step 4 — SHAP explanation on the same scaled row the model saw
        risk_factors = self._compute_shap_factors(
            scaled_array, features, top_n=5
        )

        # Step 5 — build response; the category uses the unrounded value
        risk_category = self._get_risk_category(probability)
        risk_percent = round(probability * 100, 1)
        return PredictionResponse(
            readmission_risk=round(probability, 4),
            risk_percent=risk_percent,
            risk_category=risk_category,
            risk_score_display=f"{risk_percent}% readmission risk",
            top_risk_factors=risk_factors,
            recommendation=RECOMMENDATIONS[risk_category],
            model_version=self.model_version,
        )
# Global predictor instance, shared by whatever imports this module.
# Constructing it performs no disk I/O: Predictor.__init__ only records the
# models path and initial state; artifacts are read when .load() is called.
# Per the original note, routers import this object and it is loaded once
# at startup (those callers are not visible in this file).
predictor = Predictor()