# SheGuard / src/evaluate.py
# Author: 3v324v23
# Commit message: Deploy SheGuard - Maternal Risk Assessment with Mamba3 SSM
# Commit: 9686dbe
# src/evaluate.py
"""
MamaGuard β€” Model Evaluation
Produces accuracy, per-class metrics, confusion matrix, ROC-AUC, and a text report.
Usage: python -m src.evaluate
"""
import os
import pickle
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
roc_auc_score,
roc_curve,
)
from src.data_pipeline import load_and_preprocess
from src.model import MamaGuardMamba3
# ── Config ────────────────────────────────────────────────────────────────────
CSV_PATH = "data/maternal_health.csv"
MODEL_PATH = "models/mamaguard_mamba3.pt"
SCALER_PATH = "models/scaler.pkl"
REPORT_PATH = "models/evaluation_report.txt"
CM_PATH = "models/confusion_matrix.png"
ROC_PATH = "models/roc_curves.png"
CLASS_NAMES = ["Low risk", "Medium risk", "High risk"]
CLASS_COLORS = ["#2e7d32", "#e65100", "#c62828"]
# ── Load model ────────────────────────────────────────────────────────────────
def load_model(device: str, model_path=None):
    """Load trained weights into a fresh ``MamaGuardMamba3`` in eval mode.

    Args:
        device: Torch device string (e.g. ``"cpu"`` or ``"cuda"``). The
            checkpoint tensors are mapped onto it and the model is moved to it.
        model_path: Path to the saved ``state_dict``. ``None`` (the default)
            means ``MODEL_PATH``.

    Returns:
        The model on ``device`` with ``eval()`` already called.
    """
    if model_path is None:
        model_path = MODEL_PATH
    # These hyper-parameters presumably mirror the training configuration;
    # load_state_dict() (strict by default) fails on any shape mismatch.
    model = MamaGuardMamba3(
        input_dim=6, d_model=64, n_layers=4, n_classes=3, d_state=32
    )
    # The checkpoint goes straight into load_state_dict(), so it only needs
    # tensors/containers. weights_only=True refuses arbitrary pickled objects,
    # which would otherwise allow code execution from a tampered file.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
# ── Inference ─────────────────────────────────────────────────────────────────
def get_predictions(model, X: np.ndarray, device: str, batch_size: int = 64):
    """Run ``model`` over ``X`` in mini-batches and return class predictions.

    Args:
        model: Torch module mapping a float32 batch ``X[i:j]`` to logits
            whose last dimension indexes the classes.
        X: Input array; the first axis indexes samples (sequences).
        device: Torch device string the batches are moved to.
        batch_size: Number of samples per forward pass; must be >= 1.

    Returns:
        ``(y_pred, y_proba)``. ``y_proba`` stacks the per-sample softmax
        probabilities, shape ``(N, n_classes)``. ``y_pred`` (shape ``(N,)``)
        is the arg-max class index of each row.

    Raises:
        ValueError: If ``X`` is empty or ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(X) == 0:
        # Without this check np.vstack([]) fails with an unrelated message.
        raise ValueError("get_predictions() received an empty input array")

    model.eval()
    prob_chunks = []
    # A single no_grad context for the whole loop: no autograd graph is built.
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            batch = torch.tensor(
                X[start:start + batch_size], dtype=torch.float32, device=device
            )
            logits = model(batch)
            prob_chunks.append(F.softmax(logits, dim=-1).cpu().numpy())

    y_proba = np.vstack(prob_chunks)
    y_pred = y_proba.argmax(axis=1)
    return y_pred, y_proba
# ── Confusion matrix plot ─────────────────────────────────────────────────────
def plot_confusion_matrix(cm: np.ndarray, save_path: str):
    """Plot raw and row-normalised confusion matrices side by side.

    Args:
        cm: Square matrix of counts, rows = true label and columns =
            predicted label, ordered like ``CLASS_NAMES``.
        save_path: Output image path (saved at 150 dpi).

    Raises:
        ValueError: If ``cm`` is not ``len(CLASS_NAMES)`` x ``len(CLASS_NAMES)``.
    """
    n_classes = len(CLASS_NAMES)
    if cm.shape != (n_classes, n_classes):
        raise ValueError(
            f"expected a {n_classes}x{n_classes} confusion matrix, "
            f"got shape {cm.shape}"
        )

    counts = cm.astype(float)
    row_totals = counts.sum(axis=1, keepdims=True)
    # A class with no true samples has an all-zero row; plain division would
    # turn it into NaN (plus a RuntimeWarning), so such rows stay at 0.
    normalised_cm = np.divide(
        counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0
    )
    panels = [(cm, "Raw counts"), (normalised_cm, "Normalised (row %)")]
    # Raw-count cells above 60% of the maximum get white labels for contrast.
    raw_white_above = cm.max() * 0.6

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle("MamaGuard Confusion Matrix", fontsize=14, fontweight="bold")
    for ax_idx, (ax, (data, title)) in enumerate(zip(axes, panels)):
        is_normalised = ax_idx == 1
        im = ax.imshow(data, cmap="RdYlGn", vmin=0,
                       vmax=(1 if is_normalised else None))
        ax.set_xticks(range(n_classes))
        ax.set_yticks(range(n_classes))
        ax.set_xticklabels(CLASS_NAMES, rotation=15, ha="right")
        ax.set_yticklabels(CLASS_NAMES)
        ax.set_xlabel("Predicted label")
        ax.set_ylabel("True label")
        ax.set_title(title)
        for i in range(n_classes):
            for j in range(n_classes):
                val = data[i, j]
                if is_normalised:
                    text = f"{val:.2f}"
                    use_white = val < 0.4
                else:
                    text = str(int(val))
                    use_white = val > raw_white_above
                ax.text(j, i, text, ha="center", va="center", fontsize=11,
                        color="white" if use_white else "black",
                        fontweight="bold")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    print(f" Confusion matrix saved -> {save_path}")
# ── ROC curves ────────────────────────────────────────────────────────────────
def plot_roc_curves(y_true: np.ndarray, y_proba: np.ndarray, save_path: str):
    """Plot one-vs-rest ROC curves per class, with each AUC in the legend.

    A class whose one-vs-rest labels contain a single value (for example, no
    validation sample of that class) has an undefined ROC curve. It is
    skipped with a printed warning instead of aborting the evaluation.

    Args:
        y_true: True class indices in ``{0, 1, 2}``, ordered like ``CLASS_NAMES``.
        y_proba: Predicted probabilities; column ``i`` scores class ``i``.
        save_path: Output image path (saved at 150 dpi).
    """
    from sklearn.preprocessing import label_binarize

    y_bin = label_binarize(y_true, classes=[0, 1, 2])
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random (AUC = 0.50)")
    for i, (class_name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS)):
        target = y_bin[:, i]
        if np.unique(target).size < 2:
            # roc_auc_score raises ValueError when only one label is present.
            print(f" [!] Skipping ROC curve for {class_name}: "
                  "needs both positive and negative samples.")
            continue
        scores = y_proba[:, i]
        fpr, tpr, _ = roc_curve(target, scores)
        auc = roc_auc_score(target, scores)
        ax.plot(fpr, tpr, color=color, linewidth=2,
                label=f"{class_name} (AUC = {auc:.3f})")
    ax.set_xlabel("False Positive Rate", fontsize=12)
    ax.set_ylabel("True Positive Rate (Recall)", fontsize=12)
    ax.set_title("ROC Curves — MamaGuard Mamba3", fontsize=13, fontweight="bold")
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" ROC curves saved -> {save_path}")
# ── Text report ───────────────────────────────────────────────────────────────
def build_text_report(
    y_true, y_pred, y_proba,
    class_report: str,
    cm: np.ndarray,
    n_train: int,
    n_val: int
) -> str:
    """Build the plain-text evaluation report for the model card / README.

    Args:
        y_true: True class indices (0 = Low, 1 = Medium, 2 = High).
        y_pred: Predicted class indices, aligned with ``y_true``.
        y_proba: Predicted probabilities; column ``i`` scores class ``i``.
        class_report: Pre-formatted per-class report text, embedded verbatim.
        cm: 3x3 confusion matrix of counts (rows = true, columns = predicted).
        n_train: Number of training sequences (reported only).
        n_val: Number of validation sequences (reported only).

    Returns:
        The complete report as one string (starting with a newline).
    """
    from sklearn.preprocessing import label_binarize

    overall_acc = accuracy_score(y_true, y_pred)

    y_bin = label_binarize(y_true, classes=[0, 1, 2])
    auc_text = []
    for i in range(3):
        column = y_bin[:, i]
        # ROC-AUC is undefined (roc_auc_score raises ValueError) when the
        # one-vs-rest column holds a single label value; report "n/a".
        if np.unique(column).size < 2:
            auc_text.append("n/a")
        else:
            auc_text.append(f"{roc_auc_score(column, y_proba[:, i]):.3f}")

    # max(..., 1) avoids division by zero when no sample is truly / predicted
    # high-risk; the metric is then reported as 0.
    hr_recall = cm[2, 2] / max(cm[2, :].sum(), 1)
    hr_precision = cm[2, 2] / max(cm[:, 2].sum(), 1)

    width = 74
    banner = "\n".join([
        "╔" + "═" * width + "╗",
        "║" + "MAMAGUARD — MODEL EVALUATION REPORT".center(width) + "║",
        "║" + "Generated automatically by src/evaluate.py".center(width) + "║",
        "╚" + "═" * width + "╝",
    ])
    rule = "─" * width
    closing_rule = "═" * (width + 2)

    # "  True: " (8 chars) + 7-char label = 15 chars, matching the header pad.
    cm_header = " " * 15 + "".join(
        f"{h:>12s}" for h in ("Pred: Low", "Pred: Mid", "Pred: High")
    )
    cm_rows = "\n".join(
        f"  True: {label:<7s}" + "".join(f"{int(cm[i, j]):>12d}" for j in range(3))
        for i, label in enumerate(("Low", "Mid", "High"))
    )

    report = f"""
{banner}

MODEL
  Architecture  : MamaGuard-Mamba3 (Trapezoidal SSM + MIMO + Complex state)
  Parameters    : ~287,815
  Input         : Sequence of prenatal visits (up to 5) × 6 vital signs
  Output        : 3-class risk prediction (Low / Medium / High)
  Dataset       : UCI Maternal Health Risk (1,014 rows)
  Training set  : {n_train} sequences
  Validation set: {n_val} sequences

{rule}
PERFORMANCE METRICS (on held-out validation set)
{rule}
  Overall accuracy : {overall_acc:.3f} ({overall_acc * 100:.1f}%)

  ROC-AUC per class (one-vs-rest):
    Low risk    : {auc_text[0]}
    Medium risk : {auc_text[1]}
    High risk   : {auc_text[2]}

  ** HIGH RISK RECALL (most critical for patient safety):
     {hr_recall:.3f} — of all truly high-risk patients, {hr_recall * 100:.1f}% were correctly flagged

  ** HIGH RISK PRECISION (alarm fatigue indicator):
     {hr_precision:.3f} — of all patients flagged high-risk, {hr_precision * 100:.1f}% truly were

  Detailed per-class report:
{class_report}
{rule}
CONFUSION MATRIX (raw counts)
{rule}
  Rows = true label, Columns = predicted label

{cm_header}
{cm_rows}

  Most dangerous mistakes (False Negatives for High Risk):
    High-risk patients predicted as Low risk : {cm[2, 0]}
    High-risk patients predicted as Mid risk : {cm[2, 1]}

{rule}
HYBRID SYSTEM NOTE
{rule}
  MamaGuard uses a HYBRID architecture:
    1. Mamba3 model: handles subtle temporal patterns (learned from data)
    2. WHO clinical rules: hard overrides for obvious danger signs
       - Rule 1: SystolicBP >= 160                        -> RED
       - Rule 2: SystolicBP >= 140                        -> AMBER minimum
       - Rule 3: Blood sugar > 11.1                       -> AMBER minimum
       - Rule 4: BP rise >= 20 mmHg                       -> AMBER minimum
       - Rule 5: 3+ vitals escalating simultaneously      -> RED

  The metrics above reflect the NEURAL MODEL ONLY (without clinical rules).
  In deployment the rules only ever raise the alert level, so real-world
  recall for high-risk cases is at least as high as shown above.

{rule}
LIMITATIONS
{rule}
  1. SMALL DATASET: Trained on 1,014 rows from a single UCI dataset.
     Real-world clinical models typically require 10,000–100,000+ samples.

  2. SYNTHETIC SEQUENCES: The UCI dataset has no patient IDs or timestamps.
     We created artificial 5-visit sequences by sorting rows by age.
     These do not represent real patient trajectories.

  3. NOT CLINICALLY VALIDATED: This model has NOT been validated against
     real patient outcomes in a clinical setting. It must NOT be used
     for actual medical decisions without proper clinical validation trials.

  4. POPULATION BIAS: The UCI dataset was collected from a specific
     population. Performance may differ on patients from different
     regions, ethnicities, or healthcare contexts.

  5. RESEARCH PROTOTYPE: This is a proof-of-concept demonstrating the
     application of Mamba3 SSMs to maternal health risk prediction.
     The system design (alarm fatigue mitigation, resource-aware routing,
     OCR auto-fill) represents the primary contribution of this work.

{rule}
CITATION
{rule}
  If you use this work, please cite:
    MamaGuard: Maternal Mortality Early Warning using Mamba3 Sequential
    State-Space Models with Clinical Safety Rules.
    [Your Name], 2025. GitHub: [your-repo-url]
  Based on: Gu & Dao (2023) Mamba; UCI Maternal Health Risk dataset.
{closing_rule}
"""
    return report
# ── Main ──────────────────────────────────────────────────────────────────────
def evaluate():
    """Evaluate the trained checkpoint on the validation split and save artefacts.

    The function:

    - checks that the model, scaler and dataset files exist;
    - rebuilds the split with ``load_and_preprocess``;
    - runs inference on the validation sequences;
    - prints accuracy, the per-class report, the confusion matrix and
      high-risk recall;
    - writes ``CM_PATH``, ``ROC_PATH`` and ``REPORT_PATH``.

    Prints an error and returns ``None`` early if a required file is missing.
    """
    print("\n" + "=" * 60)
    print(" MamaGuard -- Model Evaluation")
    print("=" * 60 + "\n")

    for path, name in [(MODEL_PATH, "model"), (SCALER_PATH, "scaler"), (CSV_PATH, "dataset")]:
        if not os.path.exists(path):
            print(f"ERROR: {name} not found at {path}")
            print("Run python -m src.train first.")
            return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Load data.
    # NOTE(review): the train/val split is recomputed here rather than loaded,
    # and the scaler at SCALER_PATH is only checked for existence, never used.
    # Confirm load_and_preprocess is deterministic (fixed seed) so this is the
    # validation set that was actually held out during training.
    print("\nLoading and preprocessing data...")
    (X_train, _y_train, _q_train,
     X_val, y_val, _q_val,
     _scaler) = load_and_preprocess(CSV_PATH)
    n_train = len(X_train)
    n_val = len(X_val)
    print(f"Validation set: {n_val} sequences")
    print(f"Class distribution in validation: "
          f"Low={np.count_nonzero(y_val == 0)} "
          f"Mid={np.count_nonzero(y_val == 1)} "
          f"High={np.count_nonzero(y_val == 2)}")

    # Load model and predict.
    print("\nLoading model...")
    model = load_model(device)
    print("Running inference on validation set...")
    y_pred, y_proba = get_predictions(model, X_val, device)

    # Compute metrics.
    print("\nComputing metrics...")
    # Pin the label set: without it, a class absent from both y_val and y_pred
    # shrinks cm below 3x3 (breaking the fixed indexing below) and makes
    # target_names mismatch inside classification_report.
    labels = list(range(len(CLASS_NAMES)))
    overall_acc = accuracy_score(y_val, y_pred)
    cm = confusion_matrix(y_val, y_pred, labels=labels)
    class_report = classification_report(
        y_val, y_pred,
        labels=labels,
        target_names=CLASS_NAMES,
        digits=3,
    )

    print(f"\n{'-' * 50}")
    print(f" Overall Accuracy: {overall_acc:.3f} ({overall_acc * 100:.1f}%)")
    print(f"{'-' * 50}")
    print("\nPer-class metrics:")
    print(class_report)
    print("Confusion matrix (rows=true, cols=predicted):")
    # Header and rows both use a 16-char label column so the numbers align.
    print(f"{'':15s} " + " ".join(f"{h:>8s}" for h in ("Low", "Mid", "High")))
    for i, name in enumerate(CLASS_NAMES):
        print(f" {name:13s}  " + " ".join(f"{int(cm[i, j]):8d}" for j in range(3)))

    # max(..., 1) avoids division by zero when there are no high-risk samples.
    hr_recall = cm[2, 2] / max(cm[2, :].sum(), 1)
    print(f"\n** HIGH RISK RECALL: {hr_recall:.3f}")
    if hr_recall < 0.60:
        print(" [!] WARNING: Less than 60% of high-risk patients detected by model alone.")
        print(" The WHO clinical rules provide the safety floor in deployment.")
    elif hr_recall < 0.75:
        print(" [!] Moderate. Consider retraining with more data (Path B).")
    else:
        print(" [OK] Good recall -- model is learning the high-risk pattern.")

    # Save plots (create every output directory the artefact paths point into).
    print("\nSaving evaluation plots...")
    for out_dir in {os.path.dirname(p) for p in (CM_PATH, ROC_PATH, REPORT_PATH)}:
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    plot_confusion_matrix(cm, CM_PATH)
    plot_roc_curves(y_val, y_proba, ROC_PATH)

    # Save text report.
    report_text = build_text_report(
        y_val, y_pred, y_proba,
        class_report, cm, n_train, n_val
    )
    with open(REPORT_PATH, "w", encoding="utf-8") as f:
        f.write(report_text)
    print(f" Full report saved -> {REPORT_PATH}")

    # Final summary.
    print(f"\n{'=' * 60}")
    print(" EVALUATION COMPLETE")
    print(f"{'=' * 60}")
    print("\n Files saved:")
    print(f" {REPORT_PATH} <- paste into Hugging Face model card")
    print(f" {CM_PATH} <- include in LinkedIn post")
    print(f" {ROC_PATH} <- include in GitHub README")
    print(f"\n Overall accuracy : {overall_acc * 100:.1f}%")
    print(f" High-risk recall : {hr_recall * 100:.1f}%")
    if hr_recall < 0.60:
        print("\n RECOMMENDATION: Retrain with augmented data before publishing.")
        print(" Current model relies heavily on WHO clinical rules for safety.")
        print(" This is still publishable as a research prototype -- be transparent.")
    else:
        print("\n RECOMMENDATION: Model is ready to publish as research prototype.")
        print(" Include the evaluation_report.txt in your model card.")
    print()


if __name__ == "__main__":
    evaluate()