Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import shap | |
| import matplotlib | |
| import traceback | |
| import warnings | |
| from sklearn.metrics import accuracy_score, confusion_matrix | |
| warnings.filterwarnings('ignore') | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
# ==========================================
# 1. LOAD TRAINED ARTIFACTS FROM DISK
# ==========================================
# Every artifact name is bound to None up front so that, if loading fails,
# the module-level names still exist and later calls fail with an explicit
# error about a missing artifact rather than a NameError on an unbound global.
best_model = scaler = imputer = encoder = None
FEATURE_NAMES = cat_columns = None
xgb_base = explainer = None

print("Loading Model Artifacts...")
try:
    best_model = joblib.load('ensemble_model.pkl')
    scaler = joblib.load('scaler.pkl')
    imputer = joblib.load('imputer.pkl')
    # --- HOTFIX FOR SKLEARN VERSION MISMATCH ---
    # Pickled models from older sklearn versions may lack _fill_dtype,
    # causing an AttributeError during SimpleImputer.transform()
    if hasattr(imputer, 'initial_imputer_'):
        if not hasattr(imputer.initial_imputer_, '_fill_dtype'):
            imputer.initial_imputer_._fill_dtype = getattr(imputer.initial_imputer_, '_fit_dtype', None)
    if not hasattr(imputer, '_fill_dtype'):
        imputer._fill_dtype = getattr(imputer, '_fit_dtype', None)
    # -------------------------------------------
    encoder = joblib.load('encoder.pkl')
    FEATURE_NAMES = joblib.load('feature_names.pkl')
    cat_columns = joblib.load('cat_columns.pkl')
    print("All artifacts loaded successfully.")
except Exception as e:
    print(f"Error loading artifacts: {e}. Ensure the training script ran successfully.")
    traceback.print_exc()

# Extract XGBoost from StackingClassifier for SHAP explainability. This is kept
# separate from artifact loading so a SHAP problem only disables the
# explanation plot and is not reported as a missing-artifact error.
if best_model is not None:
    try:
        xgb_base = best_model.named_estimators_['xgb']
        explainer = shap.TreeExplainer(xgb_base)
    except Exception as e:
        print(f"SHAP explainer unavailable: {e}")
        traceback.print_exc()

target_names = ['Negative', 'Malaria', 'SCA', 'Co-infection']
# ==========================================
# 2. CORE PROCESSING & PREDICTION LOGIC
# ==========================================
def preprocess_input(input_df):
    """Convert raw patient rows into the scaled feature matrix used at training time.

    Steps, in order:
    1. Derived features: symptom count, age band and WBC/Hb ratio.
    2. Alignment to ``FEATURE_NAMES``; absent features become NaN.
    3. Encoding of ``cat_columns`` with the fitted ``encoder``. The
       'MISSING_CAT' placeholder code is mapped back to NaN afterwards.
    4. Numeric coercion, then the fitted ``imputer`` and ``scaler``.

    Args:
        input_df: DataFrame of raw inputs; it is copied, never modified.

    Returns:
        pd.DataFrame: scaled features with columns ``FEATURE_NAMES``.
    """
    frame = input_df.copy()

    # --- Derived features ---
    symptom_cols = ['fever', 'chills', 'headache', 'muscle_aches', 'fatigue',
                    'loss_of_appetite', 'jaundice', 'abdominal_pain', 'joint_pain',
                    'splenomegaly', 'pallor', 'lymphadenopathy']
    observed_symptoms = [name for name in symptom_cols if name in frame.columns]
    frame['symptom_severity_score'] = frame[observed_symptoms].sum(axis=1)

    if 'age' in frame.columns:
        age_band = pd.cut(frame['age'], bins=[-1, 5, 12, 55, 120], labels=[0, 1, 2, 3])
        frame['age_group'] = age_band.astype(float)

    if 'hb' in frame.columns and 'wbc' in frame.columns:
        # The epsilon avoids dividing by exactly zero when hb is 0.
        frame['infection_anemia_ratio'] = frame['wbc'] / (frame['hb'] + 1e-5)

    # --- Align columns with the training layout ---
    absent = set(FEATURE_NAMES).difference(frame.columns)
    for name in absent:
        frame[name] = np.nan
    aligned = frame[FEATURE_NAMES].copy()

    # --- Categorical encoding ---
    placeholder = 'MISSING_CAT'
    cats = [name for name in cat_columns if name in aligned.columns] if cat_columns else []
    if cats:
        aligned[cats] = aligned[cats].astype(str).replace(['nan', 'None'], np.nan)
        aligned[cats] = aligned[cats].fillna(placeholder)
        aligned[cats] = encoder.transform(aligned[cats])
        # Codes assigned to the placeholder go back to NaN so the imputer fills them.
        for position, name in enumerate(cat_columns):
            if name in cats and placeholder in encoder.categories_[position]:
                placeholder_code = list(encoder.categories_[position]).index(placeholder)
                aligned[name] = aligned[name].replace(placeholder_code, np.nan)

    # --- Numeric coercion, imputation, scaling ---
    for name in aligned.columns:
        aligned[name] = pd.to_numeric(aligned[name], errors='coerce')
    imputed = pd.DataFrame(imputer.transform(aligned), columns=FEATURE_NAMES)
    scaled = pd.DataFrame(scaler.transform(imputed), columns=FEATURE_NAMES)
    return scaled
def get_specific_coinfection_type(hb, retic, hb_decline, hb_s):
    """Name the co-infection sub-type from the critical laboratory markers.

    Rules are checked in priority order and the first match wins:
    Hb < 5.0, then reticulocytes > 8.0, then a rapid Hb decline together
    with HbS > 0. If none match, the generic concurrent-crisis label is used.

    Args:
        hb: hemoglobin value.
        retic: reticulocyte count.
        hb_decline: truthy when a rapid Hb decline was reported.
        hb_s: HbS fraction; only compared when ``hb_decline`` is truthy.

    Returns:
        str: a label starting with "Co-infection: ".
    """
    # Predicates are lambdas so later checks are only evaluated when earlier
    # ones fail, exactly like an if/elif chain.
    ordered_rules = (
        (lambda: hb < 5.0,
         "Co-infection: Severe Hyperhemolytic Malarial Crisis"),
        (lambda: retic > 8.0,
         "Co-infection: Acute Hemolytic Malarial Crisis"),
        (lambda: hb_decline and hb_s > 0,
         "Co-infection: Rapidly Progressing Vaso-occlusive Malarial Crisis"),
    )
    fallback = "Co-infection: Concurrent Malaria & Sickle Cell Crisis"
    return next((label for check, label in ordered_rules if check()), fallback)
def get_clinical_recs(diag, rule_triggered=None):
    """Build the Markdown clinical-decision-support panel for a diagnosis.

    Args:
        diag: diagnosis label, e.g. 'Negative', 'Malaria', 'SCA' or a
            'Co-infection: ...' sub-type name.
        rule_triggered: description of the override rule that fired, if any.
            A falsy value omits the "Critical Protocol Triggered" line.

    Returns:
        str: Markdown with the protocol section followed by fixed context notes.
    """
    sections = ["### Clinical Decision Support Protocol\n\n"]
    if rule_triggered:
        sections.append(f"**Critical Protocol Triggered:** *{rule_triggered}*\n\n")

    is_coinfection = 'Co-infection' in diag
    if 'Malaria' in diag and not is_coinfection:
        sections.append("**Protocol:** Initiate Artemisinin-based Combination Therapy (ACT) per WHO guidelines.\n")
    elif diag == 'SCA':
        sections.append("**Protocol:** Administer IV Fluids, oxygen therapy, and comprehensive pain management.\n")
    elif is_coinfection:
        sections.extend([
            "**Urgent Protocol:** High risk of hyperhemolytic or severe vaso-occlusive crisis.\n",
            "- **Action:** Immediate admission to high-dependency unit. Initiate rapid intravenous antimalarials, aggressive hydration, and prepare for potential blood transfusion.\n",
        ])
    else:
        sections.extend([
            "**Action:** Patient is currently negative for active Malaria and SCA crisis.\n",
            "- **Follow-up:** Screen for Typhoid, Dengue, or other viral infections if febrile symptoms persist.\n",
        ])

    sections.extend([
        "\n---\n### Diagnostic Context Notes\n",
        "- **Overlapping Symptoms:** Fever, Fatigue, Jaundice, Splenomegaly, and Headache *(Headache is uncommon in SCA unless accompanied by severe anemia, cerebral malaria, or stroke risk).* \n",
        "- **Co-infection Prevalences:** Key clinical indicators for Co-infection include Severe Pallor + Jaundice, High fever, Splenomegaly + malaria, and Extreme Reticulocyte (>8%) + malaria.",
    ])
    return "".join(sections)
def _select_class_shap(shap_values, expected_value, class_idx):
    """Return (first-row SHAP vector, scalar base value) for output class ``class_idx``.

    Supported ``shap_values`` layouts:
    - a list with one (n_samples, n_features) array per class;
    - a single (n_samples, n_features, n_classes) array;
    - a single 2-D array for single-output models.

    ``expected_value`` may be a scalar, list or array. A multi-element value
    is indexed by ``class_idx``, because waterfall plots need a scalar base.
    """
    if isinstance(shap_values, list):
        row = np.asarray(shap_values[class_idx])[0]
    else:
        values = np.asarray(shap_values)
        row = values[0, :, class_idx] if values.ndim == 3 else values[0]
    base = np.ravel(expected_value)
    base_val = base[class_idx] if base.size > 1 else base[0]
    return row, float(base_val)


def generate_shap_plot(X_scaled):
    """Render a SHAP waterfall plot of the Co-infection score for the first row of ``X_scaled``.

    Uses the module-level ``explainer`` (a TreeExplainer over the stacked
    XGBoost base model) and ``FEATURE_NAMES``. Any failure is drawn as a text
    placeholder figure instead of being raised, so the UI always gets a plot.

    Returns:
        matplotlib.figure.Figure: the figure to display. It is closed in
        pyplot before returning so repeated requests don't accumulate open
        figures; the Figure object itself can still be rendered/saved.
    """
    fig = None
    try:
        class_idx = target_names.index('Co-infection')
        shap_values = explainer.shap_values(X_scaled)
        pat_shap, base_val = _select_class_shap(shap_values, explainer.expected_value, class_idx)

        fig, ax = plt.subplots(figsize=(7, 5))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # NOTE: the feature values shown are the scaler-transformed model
        # inputs, not the raw clinical entries.
        explanation = shap.Explanation(values=pat_shap, base_values=base_val,
                                       data=X_scaled.iloc[0], feature_names=FEATURE_NAMES)
        shap.waterfall_plot(explanation, show=False)
        plt.title("XAI Feature Contribution (Impact on Co-Infection Risk)", fontsize=11, fontweight='bold')
        plt.tight_layout()
        # plt.title/tight_layout act on the current figure, so return that one
        # (closing the pre-created figure if SHAP switched to another).
        drawn = plt.gcf()
        if drawn is not fig:
            plt.close(fig)
        fig = drawn
    except Exception as e:
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.axis('off')
        ax.text(0.5, 0.5, f"Interpretability Module Offline:\n{str(e)}", ha='center', va='center')
    plt.close(fig)
    return fig
def _as_flag(value):
    """Encode a truthy/falsy form value as the 1.0 / 0.0 float used in the model input."""
    return 1.0 if value else 0.0


def _evaluate_override_rules(hb, reticulocyte, malaria_rdt, hb_rapid_decline, hb_s):
    """Return the description of the first critical override rule that fires, or "" if none.

    Rules are checked in priority order; only the first match is reported.
    """
    rdt_positive = malaria_rdt == "Positive"
    # NOTE(review): unlike the two rules below, this one ignores the malaria
    # RDT, so an RDT-negative patient with Hb < 5.0 is still reported as a
    # co-infection. Confirm this is the intended clinical policy.
    if hb < 5.0:
        return "Hemoglobin below critical threshold (5.0 g/dL)"
    if reticulocyte > 8.0 and rdt_positive:
        return "Extreme Reticulocyte (>8%) + Positive Malaria RDT"
    # A missing HbS value cannot satisfy the SCA-genotype criterion.
    if hb_rapid_decline and rdt_positive and hb_s is not None and hb_s > 0:
        return "Rapid Hb decline (>1.5g/dL in 48h) + Positive Malaria + SCA Genotype"
    return ""


def _format_confidence_breakdown(primary_diag, prob_dict):
    """Render the diagnosis header plus per-class confidences (highest first) as Markdown."""
    ranked = sorted(prob_dict.items(), key=lambda item: item[1], reverse=True)
    lines = "".join(f"- **{disease}**: {conf:.1f}%\n" for disease, conf in ranked)
    return f"## Primary Diagnosis: {primary_diag}\n\n### Comprehensive Confidence Breakdown:\n{lines}"


def manual_inference(age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, malaria_rdt, reticulocyte, hb_rapid_decline,
                     fever, chills, headache, muscle_aches, fatigue, loss_of_appetite, jaundice, abdominal_pain, joint_pain, splenomegaly, pallor, lymphadenopathy):
    """Run one patient through the override rules and the ensemble model.

    The critical override rules are evaluated first. If one fires, the
    diagnosis is forced to a co-infection sub-type. Otherwise the model's
    arg-max class is used, and a predicted 'Co-infection' is renamed to its
    sub-type. Hemoglobin and reticulocyte count are required because every
    rule compares them; other empty numeric fields are passed through for the
    imputer to fill.

    Returns:
        tuple: (diagnosis Markdown, recommendations Markdown, SHAP figure or
        None). Input and runtime errors are reported in the first element
        instead of being raised.
    """
    missing = [label for label, value in (("Hemoglobin", hb), ("Reticulocyte Count", reticulocyte))
               if value is None]
    if missing:
        return (
            f"### Input Error\nMissing required laboratory value(s): {', '.join(missing)}.",
            "Enter the missing value(s) and run the validation again.",
            None,
        )

    try:
        rule_triggered = _evaluate_override_rules(hb, reticulocyte, malaria_rdt, hb_rapid_decline, hb_s)
        # A missing HbS is treated as 0 only when choosing the sub-type label.
        hb_s_for_subtype = 0.0 if hb_s is None else hb_s

        symptoms = {
            'fever': fever, 'chills': chills, 'headache': headache,
            'muscle_aches': muscle_aches, 'fatigue': fatigue,
            'loss_of_appetite': loss_of_appetite, 'jaundice': jaundice,
            'abdominal_pain': abdominal_pain, 'joint_pain': joint_pain,
            'splenomegaly': splenomegaly, 'pallor': pallor,
            'lymphadenopathy': lymphadenopathy,
        }
        input_data = pd.DataFrame({
            'age': [age], 'sex': [sex], 'temp': [temp], 'hb': [hb], 'wbc': [wbc], 'platelets': [platelets],
            'hb_a': [hb_a], 'hb_s': [hb_s], 'hb_f': [hb_f],
            'malaria_rdt': [_as_flag(malaria_rdt == "Positive")],
            'reticulocyte': [reticulocyte], 'hb_rapid_decline': [_as_flag(hb_rapid_decline)],
            **{name: [_as_flag(checked)] for name, checked in symptoms.items()},
        })

        X_scaled = preprocess_input(input_data)
        probs = best_model.predict_proba(X_scaled)[0]
        # Map probabilities to class names (as percentages)
        prob_dict = {name: probs[i] * 100 for i, name in enumerate(target_names)}

        if rule_triggered:
            primary_diag = get_specific_coinfection_type(hb, reticulocyte, hb_rapid_decline, hb_s_for_subtype)
            # The override pins the sub-type at 100% and keeps the model's
            # Malaria/SCA scores for reference.
            prob_dict = {
                primary_diag: 100.0,
                'Malaria (Override)': prob_dict['Malaria'],
                'SCA (Override)': prob_dict['SCA'],
                'Negative': 0.0,
            }
        else:
            primary_diag = target_names[int(np.argmax(probs))]
            # If the model predicts co-infection without a rule firing, still name the sub-type.
            if primary_diag == 'Co-infection':
                primary_diag = get_specific_coinfection_type(hb, reticulocyte, hb_rapid_decline, hb_s_for_subtype)
                prob_dict[primary_diag] = prob_dict.pop('Co-infection')

        diag_output = _format_confidence_breakdown(primary_diag, prob_dict)
        recs = get_clinical_recs(primary_diag, rule_triggered)
        fig = generate_shap_plot(X_scaled)
        return diag_output, recs, fig
    except Exception:
        return f"### Inference Error\n```\n{traceback.format_exc()}\n```", "System Error.", None
# ==========================================
# 3. SYSTEM VALIDATION HELPER FUNCTIONS
# ==========================================
def load_systematic_metrics():
    """Summarize held-out performance as Markdown.

    Loads the true labels ('y_test_val.pkl') and predicted class
    probabilities ('y_probs_val.pkl'), and takes the arg-max as the
    prediction. Reports overall accuracy plus macro-averaged one-vs-rest
    sensitivity and specificity. A class whose denominator is zero
    contributes 0 to the macro average. Any failure is returned as an error
    string rather than raised.
    """
    try:
        truth = joblib.load('y_test_val.pkl')
        scores = joblib.load('y_probs_val.pkl')
        predicted = np.argmax(scores, axis=1)

        acc = accuracy_score(truth, predicted)
        cm = confusion_matrix(truth, predicted)

        # One-vs-rest counts for every class at once.
        tp = np.diag(cm)
        fn = cm.sum(axis=1) - tp
        fp = cm.sum(axis=0) - tp
        tn = cm.sum() - tp - fn - fp

        def _ratios(numerators, denominators):
            return [num / den if den > 0 else 0 for num, den in zip(numerators, denominators)]

        sens = np.mean(_ratios(tp, tp + fn))
        spec = np.mean(_ratios(tn, tn + fp))
        return f"### Systematic Evaluation Metrics (Held-out Cohort)\n\n- **Overall Accuracy**: {acc*100:.2f}%\n- **Sensitivity (Macro)**: {sens*100:.2f}%\n- **Specificity (Macro)**: {spec*100:.2f}%"
    except Exception as e:
        return f"Error loading validation metrics: Ensure 'y_test_val.pkl' and 'y_probs_val.pkl' exist in memory. \n({str(e)})"
def check_calibration(class_name):
    """Plot the one-vs-rest reliability (calibration) curve for one target class.

    Loads held-out labels from 'y_test_val.pkl' and predicted class
    probabilities from 'y_probs_val.pkl'. The labels are binarized against
    the index of ``class_name`` in ``target_names``, and a 10-bin
    ``CalibrationDisplay`` is drawn.

    Args:
        class_name: one of ``target_names``.

    Returns:
        matplotlib.figure.Figure: the curve, or a text figure describing the
        error if anything fails. The figure is closed in pyplot before
        returning so repeated requests don't accumulate open figures.
    """
    fig = None
    try:
        from sklearn.calibration import CalibrationDisplay
        # np.asarray so list-like pickles still support elementwise == and 2-D slicing.
        y_test_val = np.asarray(joblib.load('y_test_val.pkl'))
        y_probs_val = np.asarray(joblib.load('y_probs_val.pkl'))
        class_idx = target_names.index(class_name)
        y_true_binary = (y_test_val == class_idx).astype(int)
        y_prob_class = y_probs_val[:, class_idx]

        fig, ax = plt.subplots(figsize=(6, 5))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        CalibrationDisplay.from_predictions(y_true_binary, y_prob_class, n_bins=10, ax=ax, name=class_name)
        ax.set_title(f"Reliability Curve (Calibration) for {class_name}", fontweight='bold')
        fig.tight_layout()
    except Exception as e:
        # Discard a partially drawn figure before building the error placeholder.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots()
        ax.axis('off')
        ax.text(0.5, 0.5, f"Calibration Error:\n{str(e)}", ha='center', va='center')
    plt.close(fig)
    return fig
# ==========================================
# 4. GRADIO UI DEFINITION
# ==========================================
# Monochrome base theme with slate/gray accents. The Inter web font comes
# first, followed by generic system fallbacks.
_UI_FONT_STACK = [gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"]
custom_theme = gr.themes.Monochrome(
    primary_hue="slate",
    secondary_hue="gray",
    font=_UI_FONT_STACK,
)
# 10 predefined clinical scenarios used to auto-populate the inference form.
# Each row has 24 values, in the same order as `input_components` and the
# parameters of `manual_inference`:
# [age, sex, temp, hb, wbc, platelets, hb_a, hb_s, hb_f, rdt, retic, hb_decline,
#  fever, chills, headache, muscle, fatigue, appetite, jaundice, abd_pain,
#  joint_pain, spleno, pallor, lymph]
clinical_examples = [
    [8, "Male", 39.5, 11.5, 9.5, 150, 98.0, 0.0, 2.0, "Positive", 1.5, False,
     True, True, True, True, True, True, False, False, False, False, False, False],  # 1. Uncomplicated Malaria
    [22, "Female", 39.0, 7.5, 12.0, 90, 95.0, 0.0, 2.0, "Positive", 4.0, False,
     True, True, True, True, True, True, True, False, False, True, True, False],  # 2. Severe Malaria
    [15, "Male", 37.2, 8.0, 11.0, 250, 5.0, 85.0, 10.0, "Negative", 6.0, False,
     False, False, False, True, True, False, True, True, True, False, True, False],  # 3. SCA Vaso-occlusive Crisis
    [18, "Female", 37.5, 4.5, 14.0, 300, 2.0, 90.0, 8.0, "Negative", 10.0, True,
     False, False, False, False, True, False, True, False, True, True, True, False],  # 4. SCA Hyperhemolytic (Trigger Hb<5)
    [12, "Male", 38.8, 6.5, 16.0, 110, 10.0, 80.0, 10.0, "Positive", 9.5, False,
     True, True, True, True, True, True, True, True, True, True, True, False],  # 5. Co-infection (Acute Hemolytic, Retic>8)
    [25, "Female", 39.2, 7.0, 15.0, 100, 5.0, 85.0, 10.0, "Positive", 5.0, True,
     True, True, True, True, True, True, True, False, True, True, True, False],  # 6. Co-infection (Rapidly Progressing)
    [30, "Male", 36.8, 14.0, 6.5, 250, 98.0, 0.0, 2.0, "Negative", 1.0, False,
     False, False, False, False, False, False, False, False, False, False, False, False],  # 7. Healthy Adult
    [45, "Female", 37.8, 13.5, 5.0, 210, 97.0, 0.0, 2.0, "Negative", 1.2, False,
     True, False, True, True, True, False, False, False, False, False, False, True],  # 8. Viral Infection (Non-malarial)
    [10, "Male", 39.8, 6.0, 18.0, 80, 95.0, 0.0, 3.0, "Positive", 7.0, False,
     True, True, True, False, True, True, True, True, False, True, True, False],  # 9. Malaria with Overlapping Symptoms
    [28, "Female", 37.0, 12.5, 7.0, 220, 60.0, 38.0, 2.0, "Negative", 1.5, False,
     False, False, False, False, False, False, False, False, False, False, False, False],  # 10. SCA Trait (Asymptomatic)
]
# Three-tab dashboard: single-patient inference, held-out metrics, and
# per-class calibration curves.
with gr.Blocks(theme=custom_theme, title="Hemaclass Clinical Dashboard") as demo:
    gr.Markdown("# Hemaclass Clinical Decision Support System")
    gr.Markdown("Deep Stacking Ensemble Model for Malaria and Sickle Cell Anemia Classification.")
    with gr.Tabs():
        # --- TAB 1: CORE INFERENCE ---
        with gr.TabItem("Single Patient Validation"):
            with gr.Row():
                # Left column: patient inputs.
                with gr.Column(scale=1):
                    gr.Markdown("### Demographics & Vitals")
                    with gr.Row():
                        age_in = gr.Number(label="Age", value=25)
                        sex_in = gr.Dropdown(["Male", "Female"], label="Sex", value="Female")
                        temp_in = gr.Number(label="Temperature (°C)", value=37.5)
                    gr.Markdown("### Clinical Symptoms")
                    with gr.Row():
                        fever_in = gr.Checkbox(label="Fever")
                        chills_in = gr.Checkbox(label="Chills")
                        headache_in = gr.Checkbox(label="Headache")
                        fatigue_in = gr.Checkbox(label="Fatigue")
                    with gr.Row():
                        jaundice_in = gr.Checkbox(label="Jaundice")
                        splenomegaly_in = gr.Checkbox(label="Splenomegaly")
                        pallor_in = gr.Checkbox(label="Severe Pallor")
                        muscle_in = gr.Checkbox(label="Muscle Aches")
                    # Less common symptoms are collapsed by default.
                    with gr.Accordion("Additional Symptoms", open=False):
                        loss_appetite_in = gr.Checkbox(label="Loss of Appetite")
                        abd_pain_in = gr.Checkbox(label="Abdominal Pain")
                        joint_pain_in = gr.Checkbox(label="Joint Pain")
                        lymph_in = gr.Checkbox(label="Lymphadenopathy")
                    gr.Markdown("### Critical Laboratory Markers")
                    with gr.Row():
                        rdt_in = gr.Radio(["Negative", "Positive"], label="Malaria RDT", value="Negative")
                        retic_in = gr.Number(label="Reticulocyte Count (%)", value=2.0)
                    with gr.Row():
                        hb_in = gr.Number(label="Hemoglobin (g/dL)", value=12.0)
                        hb_decline_in = gr.Checkbox(label="Rapid Hb Decline (>1.5g/dl in 48h)")
                    with gr.Row():
                        hb_a_in = gr.Number(label="HbA Fraction (%)", value=98.0)
                        hb_s_in = gr.Number(label="HbS Fraction (%)", value=0.0)
                        hb_f_in = gr.Number(label="HbF Fraction (%)", value=2.0)
                    with gr.Row():
                        wbc_in = gr.Number(label="WBC Count (x10^9/L)", value=8.0)
                        platelets_in = gr.Number(label="Platelet Count", value=200)
                    manual_btn = gr.Button("Validate Diagnosis", variant="primary", size="lg")
                # Right column: diagnosis, recommendations and SHAP explanation.
                with gr.Column(scale=1):
                    gr.Markdown("### System Output")
                    out_diag = gr.Markdown()
                    out_recs = gr.Markdown()
                    out_shap = gr.Plot(label="Feature Contribution Analysis")
            gr.Markdown("---")
            gr.Markdown("### Load Clinical Scenarios")
            gr.Markdown("Select a predefined clinical case to auto-populate the diagnostic fields.")
            # Order must match both the rows of `clinical_examples` and the
            # positional parameters of `manual_inference`.
            input_components =[
                age_in, sex_in, temp_in, hb_in, wbc_in, platelets_in, hb_a_in, hb_s_in, hb_f_in,
                rdt_in, retic_in, hb_decline_in, fever_in, chills_in, headache_in, muscle_in,
                fatigue_in, loss_appetite_in, jaundice_in, abd_pain_in, joint_pain_in,
                splenomegaly_in, pallor_in, lymph_in
            ]
            # Clicking an example fills `input_components` with that row's values.
            gr.Examples(
                examples=clinical_examples,
                inputs=input_components,
                label="Predefined Patient Cases"
            )
            # manual_inference returns (diagnosis Markdown, recommendations Markdown, figure).
            manual_btn.click(
                manual_inference,
                inputs=input_components,
                outputs=[out_diag, out_recs, out_shap]
            )
        # --- TAB 2: PERFORMANCE METRICS ---
        with gr.TabItem("Systematic Testing"):
            gr.Markdown("### Overall Model Performance on Unseen Test Cohort")
            metrics_btn = gr.Button("Calculate Systematic Metrics", variant="secondary")
            out_metrics = gr.Markdown()
            metrics_btn.click(load_systematic_metrics, inputs=[], outputs=[out_metrics])
        # --- TAB 3: ADVANCED CALIBRATION ---
        with gr.TabItem("Advanced Validation"):
            gr.Markdown("### Evaluate Diagnosis Calibration")
            gr.Markdown("Select a disease class below to verify the alignment between predicted probabilities and true clinical frequencies.")
            with gr.Row():
                class_dropdown = gr.Dropdown(target_names, label="Select Target Class", value="Co-infection")
                calib_btn = gr.Button("Check Calibration", variant="secondary")
            out_calib = gr.Plot()
            calib_btn.click(check_calibration, inputs=[class_dropdown], outputs=[out_calib])
# Start the Gradio server when this file is run as a script.
# NOTE(review): share=True requests a public share link; presumably it has no
# effect when hosted (e.g. on Hugging Face Spaces) — confirm for the target deployment.
if __name__ == "__main__":
    demo.launch(share=True)