# SundewAIHealth/app/rules/medical_rules.py
# mgbam - Upload 4 files - b9a9c97 (verified)
"""
Enhanced medical rule library for ECG monitoring.
Implements evidence-based clinical rules for:
- Atrial Fibrillation (AFib)
- Tachycardia / Bradycardia
- Patient risk stratification (age, prior stroke)
Listed as goals, but not implemented by any rule in this module yet:
- ST-segment changes (ischemia indicators)
- Ectopic beats
"""
from typing import Any, Dict, List, Optional
class RuleFiredEvent:
    """Record of a single triggered rule together with its metadata."""

    def __init__(self, rule_name: str, severity: str, explanation: str, confidence: float = 1.0):
        # severity is one of: 'none', 'notify', 'escalate'
        self.rule_name, self.severity = rule_name, severity
        self.explanation, self.confidence = explanation, confidence

    def to_dict(self) -> Dict[str, Any]:
        """Return the event as a plain dict keyed by its four attribute names."""
        field_names = ("rule_name", "severity", "explanation", "confidence")
        return {field: getattr(self, field) for field in field_names}
class MedicalRuleEngine:
    """
    Rule engine that turns ECG model output plus patient context into alerts.

    Rules are bound methods held in ``self.rules`` and are evaluated in list
    order. Each rule takes the ``(patient, model)`` dictionaries and returns a
    RuleFiredEvent when it triggers, otherwise None.

    Input handling:
    - ``score``, ``hr`` and ``age`` are parsed defensively. A missing, ``None``
      or unparseable value is treated as absent, so one malformed field cannot
      abort evaluation of the remaining rules.
    - A heart rate of 0 is a real reading and is evaluated like any other
      value; only ``None``/booleans/unparseable values count as "no reading".
    """

    # Model labels that the rules treat as an arrhythmia / AFib finding.
    ARRHYTHMIA_LABELS = frozenset({"arrhythmia", "afib", "suspected_afib"})

    # Name of the catch-all rule; evaluate() drops it when a real alert fired.
    BASELINE_RULE_NAME = "baseline_monitoring"

    # Thresholds (class-level so a subclass can tune them).
    AFIB_ESCALATE_SCORE = 0.85
    AFIB_NOTIFY_SCORE = 0.6
    SEVERE_TACHYCARDIA_BPM = 140
    MODERATE_TACHYCARDIA_BPM = 120
    SEVERE_BRADYCARDIA_BPM = 40
    MODERATE_BRADYCARDIA_MAX_BPM = 50
    ELDERLY_AGE_YEARS = 75
    ELDERLY_RHYTHM_SCORE = 0.5
    PRIOR_STROKE_SCORE = 0.6

    def __init__(self):
        # List order is evaluation order and therefore the order of the
        # explanations returned by evaluate().
        self.rules = [
            self.rule_high_confidence_afib,
            self.rule_suspected_afib,
            self.rule_severe_tachycardia,
            self.rule_moderate_tachycardia,
            self.rule_severe_bradycardia,
            self.rule_moderate_bradycardia,
            self.rule_high_risk_patient_with_arrhythmia,
            self.rule_elderly_with_abnormal_rhythm,
            self.rule_prior_stroke_escalation,
            self.rule_baseline_monitoring,
        ]

    # -------------------------------------------------------------------------
    # Input parsing helpers
    # -------------------------------------------------------------------------
    @staticmethod
    def _as_float(value: Any, default: float = 0.0) -> float:
        """Return ``value`` as a float, or ``default`` if it is None/unparseable."""
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _as_int(value: Any) -> Optional[int]:
        """
        Return ``value`` truncated to an int, or None when there is no usable number.

        Accepts ints, floats and numeric strings (including "72.5"). Booleans,
        None, NaN/inf and non-numeric values yield None.
        """
        if value is None or isinstance(value, bool):
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            pass
        try:
            # Second chance for strings such as "72.5", which int() rejects.
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            return None

    # -------------------------------------------------------------------------
    # Evaluation
    # -------------------------------------------------------------------------
    def evaluate(
        self,
        patient_context: Dict[str, Any],
        model_output: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Evaluate all rules and return the aggregated result.

        Args:
            patient_context: patient metadata; rules read ``age`` and
                ``has_prior_stroke``. ``None`` is treated as an empty dict.
            model_output: model result; rules read ``label``, ``score`` and
                ``hr``. ``None`` is treated as an empty dict.

        Returns:
            {
                "alert_level": "none" | "notify" | "escalate",
                "explanations": [str, ...],
                "fired_rules": [dict, ...],  # RuleFiredEvent.to_dict() output
            }
            The baseline "routine monitoring" entry is included only when no
            notify/escalate rule fired, so it never contradicts a real alert.
        """
        patient_context = patient_context or {}
        model_output = model_output or {}

        fired_rules: List["RuleFiredEvent"] = []
        for rule in self.rules:
            event = rule(patient_context, model_output)
            if event:
                fired_rules.append(event)

        # Highest severity wins.
        severities = {event.severity for event in fired_rules}
        if "escalate" in severities:
            alert_level = "escalate"
        elif "notify" in severities:
            alert_level = "notify"
        else:
            alert_level = "none"

        # The baseline rule cannot see other rules' results, so its
        # "No critical alerts" message is filtered here when an alert exists.
        if alert_level != "none":
            fired_rules = [
                event
                for event in fired_rules
                if event.rule_name != self.BASELINE_RULE_NAME
            ]

        return {
            "alert_level": alert_level,
            "explanations": [event.explanation for event in fired_rules],
            "fired_rules": [event.to_dict() for event in fired_rules],
        }

    # -------------------------------------------------------------------------
    # Individual Rules
    # -------------------------------------------------------------------------
    def rule_high_confidence_afib(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Escalate when an arrhythmia label has score >= AFIB_ESCALATE_SCORE."""
        score = self._as_float(model.get("score"))
        if model.get("label") in self.ARRHYTHMIA_LABELS and score >= self.AFIB_ESCALATE_SCORE:
            return RuleFiredEvent(
                rule_name="high_confidence_afib",
                severity="escalate",
                explanation=f"High-confidence AFib detected (confidence: {score:.2f}). Immediate review recommended.",
                confidence=score,
            )
        return None

    def rule_suspected_afib(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Notify when an arrhythmia label has AFIB_NOTIFY_SCORE <= score < AFIB_ESCALATE_SCORE."""
        score = self._as_float(model.get("score"))
        if (
            model.get("label") in self.ARRHYTHMIA_LABELS
            and self.AFIB_NOTIFY_SCORE <= score < self.AFIB_ESCALATE_SCORE
        ):
            return RuleFiredEvent(
                rule_name="suspected_afib",
                severity="notify",
                explanation=f"AFib suspected (confidence: {score:.2f}). Monitor closely.",
                confidence=score,
            )
        return None

    def rule_severe_tachycardia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Severe tachycardia (HR >= 140 bpm)."""
        hr = model.get("hr")
        bpm = self._as_int(hr)
        if bpm is not None and bpm >= self.SEVERE_TACHYCARDIA_BPM:
            return RuleFiredEvent(
                rule_name="severe_tachycardia",
                severity="escalate",
                explanation=f"Severe tachycardia detected (HR: {hr} bpm). Clinical intervention may be needed.",
            )
        return None

    def rule_moderate_tachycardia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Moderate tachycardia (HR 120-139 bpm)."""
        hr = model.get("hr")
        bpm = self._as_int(hr)
        if bpm is not None and self.MODERATE_TACHYCARDIA_BPM <= bpm < self.SEVERE_TACHYCARDIA_BPM:
            return RuleFiredEvent(
                rule_name="moderate_tachycardia",
                severity="notify",
                explanation=f"Tachycardia detected (HR: {hr} bpm). Continue monitoring.",
            )
        return None

    def rule_severe_bradycardia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Severe bradycardia (HR < 40 bpm), including a reported HR of 0."""
        hr = model.get("hr")
        bpm = self._as_int(hr)
        # `is not None` rather than truthiness: a reading of 0 must still alert.
        if bpm is not None and bpm < self.SEVERE_BRADYCARDIA_BPM:
            return RuleFiredEvent(
                rule_name="severe_bradycardia",
                severity="escalate",
                explanation=f"Severe bradycardia detected (HR: {hr} bpm). Immediate assessment required.",
            )
        return None

    def rule_moderate_bradycardia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Moderate bradycardia (HR 40-50 bpm, both ends inclusive)."""
        hr = model.get("hr")
        bpm = self._as_int(hr)
        if bpm is not None and self.SEVERE_BRADYCARDIA_BPM <= bpm <= self.MODERATE_BRADYCARDIA_MAX_BPM:
            return RuleFiredEvent(
                rule_name="moderate_bradycardia",
                severity="notify",
                explanation=f"Bradycardia detected (HR: {hr} bpm). Monitor for symptoms.",
            )
        return None

    def rule_high_risk_patient_with_arrhythmia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Escalate for a high-risk patient (age >= 75 OR prior stroke) with an arrhythmia label."""
        if model.get("label") not in self.ARRHYTHMIA_LABELS:
            return None

        age = patient.get("age")
        age_years = self._as_int(age)
        is_elderly = age_years is not None and age_years >= self.ELDERLY_AGE_YEARS
        has_prior_stroke = patient.get("has_prior_stroke", False)
        if not (is_elderly or has_prior_stroke):
            return None

        risk_factors = []
        if is_elderly:
            risk_factors.append(f"age {age}")
        if has_prior_stroke:
            risk_factors.append("prior stroke")
        return RuleFiredEvent(
            rule_name="high_risk_patient_with_arrhythmia",
            severity="escalate",
            explanation=f"High-risk patient ({', '.join(risk_factors)}) with arrhythmia. Escalate to cardiologist.",
        )

    def rule_elderly_with_abnormal_rhythm(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Notify for age >= 75 with label exactly "arrhythmia" and score >= 0.5."""
        age = patient.get("age")
        age_years = self._as_int(age)
        score = self._as_float(model.get("score"))
        if (
            age_years is not None
            and age_years >= self.ELDERLY_AGE_YEARS
            and model.get("label") in {"arrhythmia"}
            and score >= self.ELDERLY_RHYTHM_SCORE
        ):
            return RuleFiredEvent(
                rule_name="elderly_with_abnormal_rhythm",
                severity="notify",
                explanation=f"Elderly patient (age {age}) with abnormal rhythm. Increased monitoring advised.",
            )
        return None

    def rule_prior_stroke_escalation(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """
        Notify for a patient with prior stroke history when score >= 0.6.

        NOTE(review): despite the name this rule emits severity "notify", and
        it does not look at the model label - confirm both are intended.
        """
        has_prior_stroke = patient.get("has_prior_stroke", False)
        score = self._as_float(model.get("score"))
        if has_prior_stroke and score >= self.PRIOR_STROKE_SCORE:
            return RuleFiredEvent(
                rule_name="prior_stroke_escalation",
                severity="notify",
                explanation="Patient has prior stroke history. Vigilant monitoring required for any arrhythmia indicators.",
            )
        return None

    def rule_baseline_monitoring(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional["RuleFiredEvent"]:
        """Baseline: routine-monitoring event with severity "none"."""
        # Always returns an event; evaluate() discards it when any
        # notify/escalate rule fired, since this rule cannot see the others.
        return RuleFiredEvent(
            rule_name=self.BASELINE_RULE_NAME,
            severity="none",
            explanation="No critical alerts. Routine monitoring in progress.",
        )
# Shared module-level engine instance; evaluate_medical_rules() delegates to it.
medical_rule_engine = MedicalRuleEngine()
def evaluate_medical_rules(
    patient_context: Dict[str, Any],
    model_output: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Public API: run the shared rule engine on one patient/model pair.

    Args:
        patient_context: patient metadata; the rules read the ``age`` and
            ``has_prior_stroke`` keys.
        model_output: ML model output; the rules read the ``label``,
            ``score`` and ``hr`` keys.

    Returns:
        The engine's result dict:
        {
            "alert_level": str,
            "explanations": [str],
            "fired_rules": [dict],
        }
    """
    outcome = medical_rule_engine.evaluate(patient_context, model_output)
    return outcome