---
language:
- en
license: mit
tags:
- medical
- clinical-decision-support
- tabular-classification
- scikit-learn
- xgboost
- shap
- ensemble-learning
- global-health
datasets:
- private-western-kenya-clinical-cohort
metrics:
- accuracy
- f1
- sensitivity
- specificity
---

# Hemaclass XAI: Deep Stacking Ensemble for Malaria and Sickle Cell Anemia

## Model Description
The **Hemaclass XAI** model is a phase-4 prototype Clinical Decision Support System (CDSS) that classifies patients into four states: **Malaria, Sickle Cell Anemia (SCA), Co-infection, and Negative (Healthy/Other)**. Developed for deployment in resource-constrained settings in Western Kenya, it combines a Deep Stacking Ensemble architecture with SHAP (SHapley Additive exPlanations) to give clinicians transparent, interpretable predictions.

- **Developer:** Mabonga Labs / Hemaclass Project
- **Model Type:** Multi-Class Tabular Classification (Deep Stacking Ensemble)
- **Primary Architecture:** Random Forest, Support Vector Machine (RBF kernel), and XGBoost as base learners, feeding a Logistic Regression meta-learner.
- **Hyperparameter Tuning:** Nested cross-validation with Bayesian search (`skopt`).

## Intended Use & Target Audience
- **Primary Use Case:** Triage and secondary diagnostic validation for clinicians operating in malaria-endemic regions with high SCA prevalence.
- **Target Audience:** Medical doctors, clinical officers, and healthcare technicians.
- **Out-of-Scope Uses:** This model is **not** a standalone diagnostic device; it is designed for *decision support*. It should not override human clinical judgment or replace definitive laboratory protocols (e.g., blood smears, Hb electrophoresis).

## Clinical Protocol & Hardcoded Overrides
To prioritize patient safety, the inference pipeline applies hardcoded clinical rules that override the model's probability outputs in critical, life-threatening scenarios:
1. **Severe Hyperhemolytic Crisis:** Hemoglobin (Hb) < 5.0 g/dL.
2. **Acute Hemolytic Malarial Crisis:** Reticulocyte count > 8.0% and a positive malaria RDT.
3. **Rapidly Progressing Vaso-occlusive Malarial Crisis:** Rapid Hb decline (> 1.5 g/dL in 48 h), a positive malaria RDT, and presence of the HbS genotype.

*When any rule is triggered, the system reports the diagnosis as "Co-infection" with 100% confidence and alerts the clinician to admit the patient to a high-dependency unit.*

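
The override logic above can be sketched as a small rule layer that runs before the model's output is reported. This is a minimal illustration, not the project's actual pipeline code: the dictionary keys (`hb`, `reticulocyte_pct`, `malaria_rdt`, `rapid_hb_decline_alert`, `hbs_pct`) and the function names are hypothetical, and treating "presence of the HbS genotype" as `hbs_pct > 0` is an assumption.

```python
from typing import Optional

# Thresholds taken from the override rules listed in this section.
HB_CRITICAL_G_DL = 5.0
RETIC_CRITICAL_PCT = 8.0


def clinical_override(patient: dict) -> Optional[str]:
    """Return the name of the first triggered override rule, or None."""
    rdt_positive = bool(patient.get("malaria_rdt"))
    if patient["hb"] < HB_CRITICAL_G_DL:
        return "Severe Hyperhemolytic Crisis"
    if patient["reticulocyte_pct"] > RETIC_CRITICAL_PCT and rdt_positive:
        return "Acute Hemolytic Malarial Crisis"
    # Assumption: any measurable HbS fraction counts as "HbS genotype present".
    has_hbs = patient.get("hbs_pct", 0.0) > 0.0
    if patient.get("rapid_hb_decline_alert") and rdt_positive and has_hbs:
        return "Rapidly Progressing Vaso-occlusive Malarial Crisis"
    return None


def triage(patient: dict, model_label: str, model_confidence: float) -> dict:
    """Apply the hardcoded rules first; fall back to the model's prediction."""
    rule = clinical_override(patient)
    if rule is not None:
        return {
            "label": "Co-infection",
            "confidence": 1.0,
            "override": rule,
            "action": "Admit to high-dependency unit",
        }
    return {
        "label": model_label,
        "confidence": model_confidence,
        "override": None,
        "action": None,
    }
```

Keeping the rules in a separate function makes them easy to audit and test independently of the trained ensemble.
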
## Model Input Features
The model takes 24 clinical input features and derives 3 engineered features from them:
* **Demographics:** Age, Sex.
* **Vitals & Symptoms:** Body Temperature, Fever, Chills, Headache, Muscle Aches, Fatigue, Loss of Appetite, Jaundice, Abdominal Pain, Joint Pain, Splenomegaly, Severe Pallor, Lymphadenopathy.
* **Laboratory Markers:** Malaria RDT (Binary), Hemoglobin (Hb), WBC Count, Platelet Count, Reticulocyte Count, Rapid Hb Decline Alert.
* **Hemoglobin Fractions:** HbA, HbS, HbF.
* **Engineered Features:** Symptom Severity Score, Age Group (Categorical), Infection-to-Anemia Ratio (WBC / Hb).

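
As a rough illustration of the three engineered features, the sketch below derives them from a single patient record. Only the Infection-to-Anemia Ratio formula (WBC / Hb) comes from this card. The severity score as a count of positive symptom flags, the age-band cut-offs, and all key names are assumptions made for the example.

```python
# Hypothetical column names for the binary symptom flags listed above.
SYMPTOM_KEYS = (
    "fever", "chills", "headache", "muscle_aches", "fatigue",
    "loss_of_appetite", "jaundice", "abdominal_pain", "joint_pain",
    "splenomegaly", "severe_pallor", "lymphadenopathy",
)


def age_group(age_years: float) -> str:
    """Assumed age bands; the card does not state the actual cut-offs."""
    if age_years < 5:
        return "under_5"
    if age_years < 18:
        return "5_to_17"
    return "adult"


def engineer_features(record: dict) -> dict:
    """Return a copy of the record with the three derived features added."""
    if record["hb"] <= 0:
        raise ValueError("Hb must be positive to compute WBC / Hb")
    out = dict(record)
    # Assumption: severity is the number of symptom flags that are present.
    out["symptom_severity_score"] = sum(
        int(bool(record.get(key, 0))) for key in SYMPTOM_KEYS
    )
    out["age_group"] = age_group(record["age"])
    # Stated in the card: Infection-to-Anemia Ratio = WBC / Hb.
    out["infection_to_anemia_ratio"] = record["wbc"] / record["hb"]
    return out
```
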
## Data Preprocessing & Augmentation
* **Missing Data:** Handled using Multiple Imputation by Chained Equations (MICE) with up to 30 iterations, to preserve the covariance structure between clinical variables.
* **Class Imbalance:** The foundational dataset (~350 retrospective patient records from Western Kenya) was highly imbalanced. **Extreme SMOTE** (Synthetic Minority Over-sampling Technique) was applied to synthesize a balanced training matrix of 6,000 clinical profiles (1,500 per target class), roughly a 17-fold expansion of the original data.
* **Encoding:** Deterministic ordinal encoding for categorical values; z-score normalization for numerical features.

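
To make the augmentation step concrete, the sketch below shows the core SMOTE idea in plain Python: a synthetic row is placed at a random point on the line segment between a minority-class row and one of its nearest same-class neighbours. It is an illustration of the general technique only, not the project's "Extreme SMOTE" configuration; the function names, `k`, and the seed are assumptions. In practice a maintained implementation, such as `SMOTE` from imbalanced-learn, would normally be used.

```python
import random


def smote_sample(minority, k=3, rng=None):
    """Create one synthetic row by interpolating towards a near neighbour."""
    if len(minority) < 2:
        raise ValueError("need at least two minority rows to interpolate")
    rng = rng or random.Random()
    base = rng.choice(minority)
    # Rank the other minority rows by squared Euclidean distance to `base`.
    others = [row for row in minority if row is not base]
    others.sort(key=lambda row: sum((a - b) ** 2 for a, b in zip(base, row)))
    neighbour = rng.choice(others[:k])
    lam = rng.random()  # interpolation weight in [0, 1)
    return [a + lam * (b - a) for a, b in zip(base, neighbour)]


def oversample(minority, target_size, k=3, seed=0):
    """Return the original rows plus synthetic rows up to `target_size`."""
    rng = random.Random(seed)
    n_new = max(0, target_size - len(minority))
    synthetic = [smote_sample(minority, k, rng) for _ in range(n_new)]
    return [list(row) for row in minority] + synthetic
```

Because every synthetic row is a convex combination of two real rows, it can never fall outside the range already spanned by the minority class, which is one reason heavily oversampled data adds little genuinely new information.
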
## Explainability (XAI)
The system applies **SHAP (TreeExplainer)** to the XGBoost component of the stacking ensemble. A local explanation (waterfall plot) is generated for every inference, quantifying how each biomarker (e.g., *HbS%* or *Symptom Severity*) contributed to the XGBoost component's predicted risk score.

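
A waterfall plot relies on SHAP's additive property: the explained output equals a base value plus the sum of the per-feature contributions (for tree boosters this is typically expressed in margin, or log-odds, space). The sketch below shows that bookkeeping without calling the `shap` library; the function name and the numbers in the usage example are purely illustrative and are not model outputs.

```python
def waterfall_rows(base_value, contributions):
    """Order contributions by magnitude and accumulate them from the base value.

    Returns a list of (feature, contribution, running_total) tuples; the last
    running total equals base_value + sum(contributions.values()).
    """
    rows = []
    running = base_value
    ordered = sorted(contributions.items(), key=lambda kv: abs(kv[1]), reverse=True)
    for feature, value in ordered:
        running += value
        rows.append((feature, value, running))
    return rows


# Illustrative values only.
example = waterfall_rows(0.25, {"HbS%": 0.5, "Symptom Severity": -0.125, "Hb": 0.25})
for feature, value, total in example:
    print(f"{feature:>18}: {value:+.3f} -> {total:.3f}")
```
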
## Evaluation & Metrics
The model was evaluated on a held-out clinical test set that was isolated before SMOTE augmentation.
* **Metrics Tracked:** Macro-F1 Score, Macro-Sensitivity (Recall), Macro-Specificity, and overall Accuracy.
* **Statistical Significance:** A Friedman test with Nemenyi post-hoc comparisons indicated that the Stacking Ensemble significantly outperforms the individual baseline models (p < 0.05).

*(Note: Refer to the model logs or the connected Gradio dashboard for live metric values.)*

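
For readers unfamiliar with macro-averaged metrics, the sketch below computes them from a multi-class confusion matrix using one-vs-rest counts per class, then averages across classes with equal weight. The function name and the matrix layout (rows are true classes, columns are predicted classes) are assumptions for the example; it does not reproduce the project's evaluation code or results.

```python
def macro_metrics(cm):
    """Macro sensitivity, specificity, F1, and accuracy from a confusion matrix.

    cm[i][j] is the number of samples of true class i predicted as class j.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    sens, spec, f1s = [], [], []
    for c in range(n):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp
        fp = sum(cm[r][c] for r in range(n)) - tp
        tn = total - tp - fn - fp
        sens.append(tp / (tp + fn) if tp + fn else 0.0)
        spec.append(tn / (tn + fp) if tn + fp else 0.0)
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return {
        "macro_sensitivity": sum(sens) / n,
        "macro_specificity": sum(spec) / n,
        "macro_f1": sum(f1s) / n,
        "accuracy": sum(cm[c][c] for c in range(n)) / total if total else 0.0,
    }
```

Macro averaging weights each of the four classes equally, so a rare class such as Co-infection affects the score as much as a common one.
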
## Ethical Considerations & Limitations
* **Geographic Bias:** The base dataset reflects the epidemiology of Western Kenya. Prevalence-dependent patterns (such as splenomegaly, which overlaps between SCA and malaria) may not generalize to populations outside Sub-Saharan Africa or to regions with different *Plasmodium falciparum* endemicity.
* **Synthetic Data Artifacts:** Because SMOTE was used extensively to augment the training data, predictions near extreme edge-case boundaries may reflect synthetic bias. Prospective validation in a large-scale, real-world clinical trial is required before Phase 5 medical device certification.
* **Explainability Proxy:** The SHAP explanations are derived from the XGBoost sub-estimator rather than the full Stacking Classifier, so they serve as a proxy for, not an exact mathematical reflection of, the meta-learner's decision.

## How to Get Started
To interact with the model via the UI, please visit the connected [Hugging Face Space](#).

For programmatic inference using Python:
```python
import joblib
import pandas as pd

# 1. Load artifacts
model = joblib.load("ensemble_model.pkl")
scaler = joblib.load("scaler.pkl")
imputer = joblib.load("imputer.pkl")


def predict(patient_df: pd.DataFrame):
    """Impute, scale, and classify; columns must match FEATURE_NAMES."""
    # 3. Preprocess
    X_imp = imputer.transform(patient_df)
    X_scaled = scaler.transform(X_imp)
    # 4. Predict
    predictions = model.predict(X_scaled)
    probabilities = model.predict_proba(X_scaled)
    return predictions, probabilities


# 2. Prepare patient data (ensure columns match FEATURE_NAMES)
# patient_df = pd.DataFrame({...})
# predictions, probabilities = predict(patient_df)
```