Spaces:
Sleeping
Sleeping
| # src/evaluate.py | |
| """ | |
| MamaGuard β Model Evaluation | |
| Produces accuracy, per-class metrics, confusion matrix, ROC-AUC, and a text report. | |
| Usage: python -m src.evaluate | |
| """ | |
| import os | |
| import pickle | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| classification_report, | |
| confusion_matrix, | |
| roc_auc_score, | |
| roc_curve, | |
| ) | |
| from src.data_pipeline import load_and_preprocess | |
| from src.model import MamaGuardMamba3 | |
| # ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| CSV_PATH = "data/maternal_health.csv" | |
| MODEL_PATH = "models/mamaguard_mamba3.pt" | |
| SCALER_PATH = "models/scaler.pkl" | |
| REPORT_PATH = "models/evaluation_report.txt" | |
| CM_PATH = "models/confusion_matrix.png" | |
| ROC_PATH = "models/roc_curves.png" | |
| CLASS_NAMES = ["Low risk", "Medium risk", "High risk"] | |
| CLASS_COLORS = ["#2e7d32", "#e65100", "#c62828"] | |
| # ββ Load model ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def load_model(device: str): | |
| """Load trained model weights into a fresh MamaGuardMamba3 instance.""" | |
| model = MamaGuardMamba3( | |
| input_dim=6, d_model=64, n_layers=4, n_classes=3, d_state=32 | |
| ) | |
| model.load_state_dict( | |
| torch.load(MODEL_PATH, map_location=device) | |
| ) | |
| model.to(device) | |
| model.eval() | |
| return model | |
| # ββ Inference βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def get_predictions(model, X: np.ndarray, device: str, batch_size: int = 64): | |
| """ | |
| Run the model on all validation sequences in batches. | |
| Returns: y_pred (N,), y_proba (N, 3) | |
| """ | |
| model.eval() | |
| all_probs = [] | |
| for i in range(0, len(X), batch_size): | |
| batch = torch.tensor(X[i:i+batch_size], dtype=torch.float32).to(device) | |
| with torch.no_grad(): | |
| logits = model(batch) | |
| probs = F.softmax(logits, dim=-1).cpu().numpy() | |
| all_probs.append(probs) | |
| y_proba = np.vstack(all_probs) | |
| y_pred = y_proba.argmax(axis=1) | |
| return y_pred, y_proba | |
| # ββ Confusion matrix plot βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def plot_confusion_matrix(cm: np.ndarray, save_path: str): | |
| """Plot raw and normalised confusion matrices side by side.""" | |
| fig, axes = plt.subplots(1, 2, figsize=(14, 5)) | |
| fig.suptitle("MamaGuard Confusion Matrix", fontsize=14, fontweight="bold") | |
| for ax_idx, (data, title) in enumerate([ | |
| (cm, "Raw counts"), | |
| (cm.astype(float) / cm.sum(axis=1, keepdims=True), "Normalised (row %)") | |
| ]): | |
| im = axes[ax_idx].imshow(data, cmap="RdYlGn", vmin=0, | |
| vmax=(1 if ax_idx == 1 else None)) | |
| axes[ax_idx].set_xticks(range(3)) | |
| axes[ax_idx].set_yticks(range(3)) | |
| axes[ax_idx].set_xticklabels(CLASS_NAMES, rotation=15, ha="right") | |
| axes[ax_idx].set_yticklabels(CLASS_NAMES) | |
| axes[ax_idx].set_xlabel("Predicted label") | |
| axes[ax_idx].set_ylabel("True label") | |
| axes[ax_idx].set_title(title) | |
| for i in range(3): | |
| for j in range(3): | |
| val = data[i, j] | |
| text = f"{val:.2f}" if ax_idx == 1 else str(int(val)) | |
| color = "white" if (ax_idx == 1 and val < 0.4) or \ | |
| (ax_idx == 0 and val > cm.max() * 0.6) else "black" | |
| axes[ax_idx].text(j, i, text, ha="center", va="center", | |
| fontsize=11, color=color, fontweight="bold") | |
| plt.colorbar(im, ax=axes[ax_idx], fraction=0.046, pad=0.04) | |
| plt.tight_layout() | |
| plt.savefig(save_path, dpi=150, bbox_inches="tight") | |
| plt.close() | |
| print(f" Confusion matrix saved -> {save_path}") | |
| # ββ ROC curves ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def plot_roc_curves(y_true: np.ndarray, y_proba: np.ndarray, save_path: str): | |
| """Plot one-vs-rest ROC curves per class with AUC scores.""" | |
| from sklearn.preprocessing import label_binarize | |
| y_bin = label_binarize(y_true, classes=[0, 1, 2]) | |
| fig, ax = plt.subplots(figsize=(8, 6)) | |
| ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label="Random (AUC = 0.50)") | |
| for i, (class_name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS)): | |
| fpr, tpr, _ = roc_curve(y_bin[:, i], y_proba[:, i]) | |
| auc = roc_auc_score(y_bin[:, i], y_proba[:, i]) | |
| ax.plot(fpr, tpr, color=color, linewidth=2, | |
| label=f"{class_name} (AUC = {auc:.3f})") | |
| ax.set_xlabel("False Positive Rate", fontsize=12) | |
| ax.set_ylabel("True Positive Rate (Recall)", fontsize=12) | |
| ax.set_title("ROC Curves β MamaGuard Mamba3", fontsize=13, fontweight="bold") | |
| ax.legend(loc="lower right", fontsize=10) | |
| ax.grid(alpha=0.3) | |
| plt.tight_layout() | |
| plt.savefig(save_path, dpi=150, bbox_inches="tight") | |
| plt.close() | |
| print(f" ROC curves saved -> {save_path}") | |
| # ββ Text report βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def build_text_report( | |
| y_true, y_pred, y_proba, | |
| class_report: str, | |
| cm: np.ndarray, | |
| n_train: int, | |
| n_val: int | |
| ) -> str: | |
| """Build a complete text report for model card / README.""" | |
| overall_acc = accuracy_score(y_true, y_pred) | |
| from sklearn.preprocessing import label_binarize | |
| y_bin = label_binarize(y_true, classes=[0, 1, 2]) | |
| aucs = [roc_auc_score(y_bin[:, i], y_proba[:, i]) for i in range(3)] | |
| hr_recall = cm[2, 2] / max(cm[2, :].sum(), 1) | |
| hr_precision = cm[2, 2] / max(cm[:, 2].sum(), 1) | |
| report = f""" | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| β MAMAGUARD β MODEL EVALUATION REPORT β | |
| β Generated automatically by src/evaluate.py β | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| MODEL | |
| Architecture : MamaGuard-Mamba3 (Trapezoidal SSM + MIMO + Complex state) | |
| Parameters : ~287,815 | |
| Input : Sequence of prenatal visits (up to 5) Γ 6 vital signs | |
| Output : 3-class risk prediction (Low / Medium / High) | |
| Dataset : UCI Maternal Health Risk (1,014 rows) | |
| Training set : {n_train} sequences | |
| Validation set: {n_val} sequences | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| PERFORMANCE METRICS (on held-out validation set) | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| Overall accuracy : {overall_acc:.3f} ({overall_acc*100:.1f}%) | |
| ROC-AUC per class: | |
| Low risk : {aucs[0]:.3f} | |
| Medium risk : {aucs[1]:.3f} | |
| High risk : {aucs[2]:.3f} | |
| ** HIGH RISK RECALL (most critical for patient safety): | |
| {hr_recall:.3f} β of all truly high-risk patients, {hr_recall*100:.1f}% were correctly flagged | |
| ** HIGH RISK PRECISION (alarm fatigue indicator): | |
| {hr_precision:.3f} β of all patients flagged high-risk, {hr_precision*100:.1f}% truly were | |
| Detailed per-class report: | |
| {class_report} | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| CONFUSION MATRIX (raw counts) | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| Rows = true label, Columns = predicted label | |
| Pred: Low Pred: Mid Pred: High | |
| True: Low {cm[0,0]:5d} {cm[0,1]:5d} {cm[0,2]:5d} | |
| True: Mid {cm[1,0]:5d} {cm[1,1]:5d} {cm[1,2]:5d} | |
| True: High {cm[2,0]:5d} {cm[2,1]:5d} {cm[2,2]:5d} | |
| Most dangerous mistakes (False Negatives for High Risk): | |
| High-risk patients predicted as Low risk : {cm[2,0]} | |
| High-risk patients predicted as Mid risk : {cm[2,1]} | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| HYBRID SYSTEM NOTE | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| MamaGuard uses a HYBRID architecture: | |
| 1. Mamba3 model: handles subtle temporal patterns (learned from data) | |
| 2. WHO clinical rules: hard overrides for obvious danger signs | |
| - Rule 1: SystolicBP >= 160 -> RED | |
| - Rule 2: SystolicBP >= 140 -> AMBER minimum | |
| - Rule 3: Blood sugar > 11.1 -> AMBER minimum | |
| - Rule 4: BP rise >= 20 mmHg -> AMBER minimum | |
| - Rule 5: 3+ vitals escalating simultaneously -> RED | |
| The metrics above reflect the NEURAL MODEL ONLY (without clinical rules). | |
| In deployment, the clinical rules provide an additional safety floor, | |
| meaning real-world recall for high-risk cases is higher than shown above. | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| LIMITATIONS | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| 1. SMALL DATASET: Trained on 1,014 rows from a single UCI dataset. | |
| Real-world clinical models typically require 10,000β100,000+ samples. | |
| 2. SYNTHETIC SEQUENCES: The UCI dataset has no patient IDs or timestamps. | |
| We created artificial 5-visit sequences by sorting rows by age. | |
| These do not represent real patient trajectories. | |
| 3. NOT CLINICALLY VALIDATED: This model has NOT been validated against | |
| real patient outcomes in a clinical setting. It must NOT be used | |
| for actual medical decisions without proper clinical validation trials. | |
| 4. POPULATION BIAS: The UCI dataset was collected from a specific | |
| population. Performance may differ on patients from different | |
| regions, ethnicities, or healthcare contexts. | |
| 5. RESEARCH PROTOTYPE: This is a proof-of-concept demonstrating the | |
| application of Mamba3 SSMs to maternal health risk prediction. | |
| The system design (alarm fatigue mitigation, resource-aware routing, | |
| OCR auto-fill) represents the primary contribution of this work. | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| CITATION | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| If you use this work, please cite: | |
| MamaGuard: Maternal Mortality Early Warning using Mamba3 Sequential | |
| State-Space Models with Clinical Safety Rules. | |
| [Your Name], 2025. GitHub: [your-repo-url] | |
| Based on: Gu & Dao (2023) Mamba; UCI Maternal Health Risk dataset. | |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| """ | |
| return report | |
| # ββ Main ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def evaluate(): | |
| print("\n" + "="*60) | |
| print(" MamaGuard -- Model Evaluation") | |
| print("="*60 + "\n") | |
| for path, name in [(MODEL_PATH, "model"), (SCALER_PATH, "scaler"), (CSV_PATH, "dataset")]: | |
| if not os.path.exists(path): | |
| print(f"ERROR: {name} not found at {path}") | |
| print("Run python -m src.train first.") | |
| return | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Device: {device}") | |
| # Load data | |
| print("\nLoading and preprocessing data...") | |
| (X_train, y_train, q_train, | |
| X_val, y_val, q_val, | |
| scaler) = load_and_preprocess(CSV_PATH) | |
| n_train = len(X_train) | |
| n_val = len(X_val) | |
| print(f"Validation set: {n_val} sequences") | |
| print(f"Class distribution in validation: " | |
| f"Low={sum(y_val==0)} Mid={sum(y_val==1)} High={sum(y_val==2)}") | |
| # Load model and predict | |
| print("\nLoading model...") | |
| model = load_model(device) | |
| print("Running inference on validation set...") | |
| y_pred, y_proba = get_predictions(model, X_val, device) | |
| # Compute metrics | |
| print("\nComputing metrics...") | |
| overall_acc = accuracy_score(y_val, y_pred) | |
| cm = confusion_matrix(y_val, y_pred) | |
| class_report = classification_report( | |
| y_val, y_pred, | |
| target_names=CLASS_NAMES, | |
| digits=3 | |
| ) | |
| print(f"\n{'-'*50}") | |
| print(f" Overall Accuracy: {overall_acc:.3f} ({overall_acc*100:.1f}%)") | |
| print(f"{'-'*50}") | |
| print("\nPer-class metrics:") | |
| print(class_report) | |
| print("Confusion matrix (rows=true, cols=predicted):") | |
| header = f"{'':15s} {'Low':>8s} {'Mid':>8s} {'High':>8s}" | |
| print(header) | |
| for i, name in enumerate(CLASS_NAMES): | |
| row = f" {name:13s} " + " ".join(f"{cm[i,j]:8d}" for j in range(3)) | |
| print(row) | |
| hr_recall = cm[2, 2] / max(cm[2, :].sum(), 1) | |
| print(f"\n** HIGH RISK RECALL: {hr_recall:.3f}") | |
| if hr_recall < 0.60: | |
| print(" [!] WARNING: Less than 60% of high-risk patients detected by model alone.") | |
| print(" The WHO clinical rules provide the safety floor in deployment.") | |
| elif hr_recall < 0.75: | |
| print(" [!] Moderate. Consider retraining with more data (Path B).") | |
| else: | |
| print(" [OK] Good recall -- model is learning the high-risk pattern.") | |
| # Save plots | |
| print(f"\nSaving evaluation plots...") | |
| os.makedirs("models", exist_ok=True) | |
| plot_confusion_matrix(cm, CM_PATH) | |
| plot_roc_curves(y_val, y_proba, ROC_PATH) | |
| # Save text report | |
| report_text = build_text_report( | |
| y_val, y_pred, y_proba, | |
| class_report, cm, n_train, n_val | |
| ) | |
| with open(REPORT_PATH, "w", encoding="utf-8") as f: | |
| f.write(report_text) | |
| print(f" Full report saved -> {REPORT_PATH}") | |
| # Final summary | |
| print(f"\n{'='*60}") | |
| print(" EVALUATION COMPLETE") | |
| print(f"{'='*60}") | |
| print(f"\n Files saved:") | |
| print(f" {REPORT_PATH} <- paste into Hugging Face model card") | |
| print(f" {CM_PATH} <- include in LinkedIn post") | |
| print(f" {ROC_PATH} <- include in GitHub README") | |
| print(f"\n Overall accuracy : {overall_acc*100:.1f}%") | |
| print(f" High-risk recall : {hr_recall*100:.1f}%") | |
| if hr_recall < 0.60: | |
| print("\n RECOMMENDATION: Retrain with augmented data before publishing.") | |
| print(" Current model relies heavily on WHO clinical rules for safety.") | |
| print(" This is still publishable as a research prototype -- be transparent.") | |
| else: | |
| print("\n RECOMMENDATION: Model is ready to publish as research prototype.") | |
| print(" Include the evaluation_report.txt in your model card.") | |
| print() | |
| if __name__ == "__main__": | |
| evaluate() |