Spaces:
Running
Running
| """ | |
| Model loading and inference engine. | |
| Loads XGBoost model, scaler, and SHAP explainer once at startup. | |
| All prediction requests are handled through the Predictor class. | |
| Design principle: load once, predict many times. | |
| Loading a model on every request would be 10-100x slower. | |
| """ | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| import shap | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional | |
| from src.api.models import ( | |
| PatientFeatures, | |
| PredictionResponse, | |
| RiskCategory, | |
| RiskFactor, | |
| ) | |
# Module-level logger, named after this module so its records can be
# filtered independently of other modules.
logger = logging.getLogger(__name__)
# Process-wide default of INFO-level output. Per the stdlib docs,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# Column order of the model input. It must match the training-time order
# exactly: a mismatch produces silently wrong predictions — worse than an
# error. Every request is reindexed into this order before scaling.
FEATURE_ORDER = [
    # Admission
    'los_days',
    'los_days_log',
    'admission_month',
    'admission_dow',
    'is_emergency',
    # Prior utilisation
    'n_admissions_prior_6m',
    'n_admissions_prior_12m',
    'n_ed_visits_prior_6m',
    'days_since_last_admission',
    'has_prior_admission',
    # Comorbidities
    'charlson_score',
    'n_active_conditions',
    'has_heart_failure',
    'has_diabetes',
    'has_diabetes_complex',
    'has_copd',
    'has_ckd',
    'has_mi',
    'has_cancer',
    'has_dementia',
    'has_cerebrovascular',
    'has_pvd',
    # Medications
    'n_medications_capped',
    'is_high_polypharmacy',
    'has_insulin',
    'has_anticoagulant',
    'has_diuretic',
    'has_ace_inhibitor',
    # Demographics
    'age_at_admission',
    'gender_male',
    'income',
]
# Plain-English label for each model feature, keyed by the names used in
# FEATURE_ORDER. These labels populate the `description` field of each
# explained risk factor in the prediction response.
FEATURE_DESCRIPTIONS = dict(
    los_days='Length of hospital stay',
    los_days_log='Length of hospital stay (log scale)',
    admission_month='Month of admission',
    admission_dow='Day of week admitted',
    is_emergency='Emergency admission',
    n_admissions_prior_6m='Hospital admissions in past 6 months',
    n_admissions_prior_12m='Hospital admissions in past 12 months',
    n_ed_visits_prior_6m='Emergency department visits in past 6 months',
    days_since_last_admission='Days since last hospital admission',
    has_prior_admission='History of prior hospitalisation',
    charlson_score='Charlson Comorbidity Index score',
    n_active_conditions='Number of active medical conditions',
    has_heart_failure='Heart failure diagnosis',
    has_diabetes='Diabetes mellitus diagnosis',
    has_diabetes_complex='Diabetes with complications',
    has_copd='Chronic obstructive pulmonary disease',
    has_ckd='Chronic kidney disease',
    has_mi='History of myocardial infarction',
    has_cancer='Active cancer diagnosis',
    has_dementia='Dementia diagnosis',
    has_cerebrovascular='Cerebrovascular disease',
    has_pvd='Peripheral vascular disease',
    n_medications_capped='Number of active medications',
    is_high_polypharmacy='High polypharmacy (>10 medications)',
    has_insulin='Insulin therapy',
    has_anticoagulant='Anticoagulant therapy',
    has_diuretic='Diuretic therapy',
    has_ace_inhibitor='ACE inhibitor therapy',
    age_at_admission='Patient age at admission',
    gender_male='Male gender',
    income='Household income',
)
# Probability cut-offs used to bucket the model output into risk categories
# (lower bound inclusive, upper bound exclusive):
#   probability < 0.20          -> LOW    — standard discharge
#   0.20 <= probability < 0.50  -> MEDIUM — enhanced follow-up
#   probability >= 0.50         -> HIGH   — intensive intervention
# NOTE(review): described as calibrated to the model's output
# distribution — confirm against the validation report when retraining.
RISK_THRESHOLDS = dict(LOW=0.20, MEDIUM=0.50)
# Clinical recommendation text returned with each prediction, one entry per
# RiskCategory. The predictor indexes this mapping directly, so every
# category the risk bucketing can produce needs an entry here.
RECOMMENDATIONS: dict[RiskCategory, str] = {}

RECOMMENDATIONS[RiskCategory.LOW] = (
    "Standard discharge protocol. "
    "Schedule routine 30-day follow-up appointment. "
    "Provide patient with discharge summary and medication list."
)

RECOMMENDATIONS[RiskCategory.MEDIUM] = (
    "Enhanced discharge protocol. "
    "Schedule 14-day follow-up appointment. "
    "Conduct pharmacist medication reconciliation. "
    "Provide patient education on warning signs requiring ED return."
)

RECOMMENDATIONS[RiskCategory.HIGH] = (
    "High-risk discharge protocol. "
    "Schedule 7-day follow-up appointment. "
    "Assign care coordinator for post-discharge monitoring. "
    "Conduct pharmacist medication reconciliation before discharge. "
    "Consider social work referral if social support is limited. "
    "Confirm patient has transport and medication access post-discharge."
)
class Predictor:
    """
    Manages model lifecycle and inference.

    Intended to be loaded once at API startup (presumably from a FastAPI
    lifespan hook — confirm with the app wiring) and then used for many
    ``predict()`` calls. After ``load()`` the artifacts are only read.

    NOTE(review): safety under concurrent requests relies on the loaded
    model, scaler and SHAP explainer being safe to call from several
    threads at once. That is a property of those libraries, not of this
    class — confirm before relying on it.
    """

    def __init__(self, models_path: str = "models"):
        """
        Args:
            models_path: directory containing ``xgboost_model.pkl`` and
                ``scaler.pkl``. Nothing is read from disk until ``load()``.
        """
        self.models_path = Path(models_path)
        self.model = None
        self.scaler = None
        self.explainer = None
        # Reported verbatim in every PredictionResponse.
        self.model_version = "xgboost_v1_auc0.891"
        self._loaded = False

    def load(self) -> None:
        """
        Load model, scaler, and SHAP explainer from disk.

        Called once at API startup — not on every request. All three
        artifacts are built into locals first and committed to the instance
        only after every step succeeds. A failed (re)load therefore never
        leaves a mix of old and new artifacts behind, and never marks the
        predictor as loaded on partial state.

        Raises:
            FileNotFoundError: if ``xgboost_model.pkl`` or ``scaler.pkl``
                is missing under ``models_path`` (logged, then re-raised).
            Exception: any other failure while loading or building the
                explainer is logged with its traceback and re-raised.
        """
        logger.info("Loading model artifacts...")
        try:
            # NOTE(review): joblib.load unpickles arbitrary objects — only
            # point models_path at trusted, locally produced artifacts.
            model_path = self.models_path / "xgboost_model.pkl"
            model = joblib.load(model_path)
            logger.info("Model loaded from %s", model_path)

            # Must be the same scaler that was fitted during training.
            scaler_path = self.models_path / "scaler.pkl"
            scaler = joblib.load(scaler_path)
            logger.info("Scaler loaded from %s", scaler_path)

            # TreeExplainer suits tree ensembles such as XGBoost. It is
            # built once here so that each request only queries it.
            explainer = shap.TreeExplainer(model)
            logger.info("SHAP explainer ready")
        except FileNotFoundError as e:
            # logger.exception keeps the traceback; logger.error dropped it.
            logger.exception("Model file not found: %s", e)
            raise
        except Exception as e:
            logger.exception("Failed to load model: %s", e)
            raise

        # Commit atomically, only after every artifact loaded successfully.
        self.model = model
        self.scaler = scaler
        self.explainer = explainer
        self._loaded = True
        logger.info(
            "All artifacts loaded successfully. Model version: %s",
            self.model_version,
        )

    def is_loaded(self) -> bool:
        """Return True once ``load()`` has completed successfully."""
        return self._loaded

    def _features_to_array(
        self, features: PatientFeatures
    ) -> pd.DataFrame:
        """
        Convert the Pydantic request model to a one-row DataFrame.

        Columns follow FEATURE_ORDER exactly. A feature missing from
        ``features.model_dump()`` raises KeyError rather than being
        silently dropped or defaulted.
        """
        feature_dict = features.model_dump()
        ordered = {f: [feature_dict[f]] for f in FEATURE_ORDER}
        return pd.DataFrame(ordered)

    def _get_risk_category(
        self, probability: float
    ) -> RiskCategory:
        """
        Map a probability to its clinical risk category.

        Lower bounds are inclusive: exactly RISK_THRESHOLDS['LOW'] is
        MEDIUM, and exactly RISK_THRESHOLDS['MEDIUM'] is HIGH.
        """
        if probability < RISK_THRESHOLDS['LOW']:
            return RiskCategory.LOW
        elif probability < RISK_THRESHOLDS['MEDIUM']:
            return RiskCategory.MEDIUM
        else:
            return RiskCategory.HIGH

    @staticmethod
    def _positive_class_shap_row(shap_output: object) -> np.ndarray:
        """
        Return the 1-D SHAP vector for the first sample's positive-class output.

        A single-output model (e.g. binary XGBoost) yields a
        (n_samples, n_features) matrix. Per the shap docs, multi-output
        models may instead yield one matrix per output (a list), or a
        (n_samples, n_features, n_outputs) array in newer releases. In those
        cases the last output is taken, so the result still explains the
        positive class instead of being silently mislabelled.
        """
        if isinstance(shap_output, list):
            shap_output = shap_output[-1]
        values = np.asarray(shap_output)
        if values.ndim == 3:
            values = values[..., -1]
        return np.asarray(values[0]).reshape(-1)

    def _compute_shap_factors(
        self,
        raw_array: np.ndarray,
        features: PatientFeatures,
        top_n: int = 5
    ) -> list[RiskFactor]:
        """
        Compute SHAP values and return the top N risk factors
        with clinical plain-English descriptions.

        Args:
            raw_array: model input for one patient. ``predict()`` passes the
                *scaled* matrix, the same one given to ``predict_proba``, so
                the explanation matches what the model actually saw.
            features: original unscaled features, used for the displayed
                ``value`` of each factor.
            top_n: number of factors to return. A non-positive value
                returns an empty list.

        Returns:
            Up to ``top_n`` RiskFactor objects, largest |SHAP| first. Ties
            keep FEATURE_ORDER order because the sort is stable. ``impact``
            is in the explainer's output units (presumably log-odds for a
            default XGBoost TreeExplainer — confirm).

        Raises:
            ValueError: if the number of SHAP values does not match
                FEATURE_ORDER. Mislabelling features silently would be
                worse than failing.
        """
        if top_n <= 0:
            return []

        shap_row = self._positive_class_shap_row(
            self.explainer.shap_values(raw_array)
        )
        if len(shap_row) != len(FEATURE_ORDER):
            raise ValueError(
                f"SHAP returned {len(shap_row)} values but FEATURE_ORDER "
                f"has {len(FEATURE_ORDER)} features"
            )

        feature_dict = features.model_dump()
        # (feature, displayed raw value, shap value) triples
        factors = [
            (feat, feature_dict[feat], shap_val)
            for feat, shap_val in zip(FEATURE_ORDER, shap_row)
        ]
        # Largest absolute impact first, regardless of sign.
        top_factors = sorted(
            factors, key=lambda x: abs(x[2]), reverse=True
        )[:top_n]

        return [
            RiskFactor(
                feature=feat,
                value=round(float(val), 3),
                impact=round(float(shap_val), 4),
                # A zero contribution is reported as "decreases".
                direction="increases" if shap_val > 0 else "decreases",
                description=FEATURE_DESCRIPTIONS.get(feat, feat),
            )
            for feat, val, shap_val in top_factors
        ]

    def predict(self, features: PatientFeatures) -> PredictionResponse:
        """
        Run the full prediction pipeline for one patient.

        Steps:
            1. Convert features to a one-row frame in training order
            2. Scale features using the training scaler
            3. Get the probability from XGBoost
            4. Compute SHAP values for explainability
            5. Build the structured clinical response

        Returns:
            PredictionResponse with risk score, category,
            SHAP factors, and clinical recommendation.

        Raises:
            RuntimeError: if ``load()`` has not completed successfully.
            ValueError: if the SHAP output does not line up with
                FEATURE_ORDER.
        """
        if not self._loaded:
            raise RuntimeError(
                "Model not loaded. Call predictor.load() first."
            )
        # Step 1 — features to a one-row DataFrame in FEATURE_ORDER
        raw_df = self._features_to_array(features)
        # Step 2 — scale with the training-time scaler
        scaled_array = self.scaler.transform(raw_df)
        # Step 3 — column 1 of predict_proba, taken as the readmission probability
        probability = float(
            self.model.predict_proba(scaled_array)[0][1]
        )
        # Step 4 — SHAP explanation on the same scaled input the model saw
        risk_factors = self._compute_shap_factors(
            scaled_array, features, top_n=5
        )
        # Step 5 — build response
        risk_category = self._get_risk_category(probability)
        risk_percent = round(probability * 100, 1)
        return PredictionResponse(
            readmission_risk=round(probability, 4),
            risk_percent=risk_percent,
            risk_category=risk_category,
            risk_score_display=f"{risk_percent}% readmission risk",
            top_risk_factors=risk_factors,
            recommendation=RECOMMENDATIONS[risk_category],
            model_version=self.model_version,
        )
# Shared, module-level predictor that the routers import. Constructing it
# does no I/O — Predictor.__init__ only records the path — so the
# artifacts still need an explicit predictor.load() at startup.
predictor = Predictor(models_path="models")