""" Enhanced medical rule library for ECG monitoring. Implements evidence-based clinical rules for: - Atrial Fibrillation (AFib) - Tachycardia / Bradycardia - ST-segment changes (ischemia indicators) - Ectopic beats - Patient risk stratification """ from typing import Any, Dict, List, Optional class RuleFiredEvent: """Represents a rule that fired with metadata.""" def __init__(self, rule_name: str, severity: str, explanation: str, confidence: float = 1.0): self.rule_name = rule_name self.severity = severity # 'none', 'notify', 'escalate' self.explanation = explanation self.confidence = confidence def to_dict(self) -> Dict[str, Any]: return { "rule_name": self.rule_name, "severity": self.severity, "explanation": self.explanation, "confidence": self.confidence, } class MedicalRuleEngine: """ Enhanced rule engine with richer medical logic. Rules are organized by clinical condition and evaluated in order. Each rule returns a RuleFiredEvent if triggered. """ def __init__(self): self.rules = [ self.rule_high_confidence_afib, self.rule_suspected_afib, self.rule_severe_tachycardia, self.rule_moderate_tachycardia, self.rule_severe_bradycardia, self.rule_moderate_bradycardia, self.rule_high_risk_patient_with_arrhythmia, self.rule_elderly_with_abnormal_rhythm, self.rule_prior_stroke_escalation, self.rule_baseline_monitoring, ] def evaluate( self, patient_context: Dict[str, Any], model_output: Dict[str, Any], ) -> Dict[str, Any]: """ Evaluate all rules and return aggregated result. Returns: { "alert_level": "none" | "notify" | "escalate", "explanations": [str, ...], "fired_rules": [RuleFiredEvent, ...], } """ fired_rules: List[RuleFiredEvent] = [] for rule in self.rules: event = rule(patient_context, model_output) if event: fired_rules.append(event) # Determine final alert level (highest severity wins) alert_level = "none" for event in fired_rules: if event.severity == "escalate": alert_level = "escalate" break elif event.severity == "notify" and alert_level == "none": alert_level = "notify" explanations = [event.explanation for event in fired_rules] return { "alert_level": alert_level, "explanations": explanations, "fired_rules": [e.to_dict() for e in fired_rules], } # ------------------------------------------------------------------------- # Individual Rules # ------------------------------------------------------------------------- def rule_high_confidence_afib( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """High-confidence AFib detection.""" label = model.get("label") score = float(model.get("score", 0.0)) afib_labels = {"arrhythmia", "afib", "suspected_afib"} if label in afib_labels and score >= 0.85: return RuleFiredEvent( rule_name="high_confidence_afib", severity="escalate", explanation=f"High-confidence AFib detected (confidence: {score:.2f}). Immediate review recommended.", confidence=score, ) return None def rule_suspected_afib( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """Moderate-confidence AFib detection.""" label = model.get("label") score = float(model.get("score", 0.0)) afib_labels = {"arrhythmia", "afib", "suspected_afib"} if label in afib_labels and 0.6 <= score < 0.85: return RuleFiredEvent( rule_name="suspected_afib", severity="notify", explanation=f"AFib suspected (confidence: {score:.2f}). Monitor closely.", confidence=score, ) return None def rule_severe_tachycardia( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """Severe tachycardia (HR >= 140 bpm).""" hr = model.get("hr") if hr and int(hr) >= 140: return RuleFiredEvent( rule_name="severe_tachycardia", severity="escalate", explanation=f"Severe tachycardia detected (HR: {hr} bpm). Clinical intervention may be needed.", ) return None def rule_moderate_tachycardia( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """Moderate tachycardia (HR 120-139 bpm).""" hr = model.get("hr") if hr and 120 <= int(hr) < 140: return RuleFiredEvent( rule_name="moderate_tachycardia", severity="notify", explanation=f"Tachycardia detected (HR: {hr} bpm). Continue monitoring.", ) return None def rule_severe_bradycardia( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """Severe bradycardia (HR < 40 bpm).""" hr = model.get("hr") if hr and int(hr) < 40: return RuleFiredEvent( rule_name="severe_bradycardia", severity="escalate", explanation=f"Severe bradycardia detected (HR: {hr} bpm). Immediate assessment required.", ) return None def rule_moderate_bradycardia( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """Moderate bradycardia (HR 40-50 bpm).""" hr = model.get("hr") if hr and 40 <= int(hr) <= 50: return RuleFiredEvent( rule_name="moderate_bradycardia", severity="notify", explanation=f"Bradycardia detected (HR: {hr} bpm). Monitor for symptoms.", ) return None def rule_high_risk_patient_with_arrhythmia( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """High-risk patient (age >= 75, prior stroke) with any arrhythmia.""" age = patient.get("age") has_prior_stroke = patient.get("has_prior_stroke", False) label = model.get("label") is_high_risk = (age and int(age) >= 75) or has_prior_stroke is_arrhythmia = label in {"arrhythmia", "afib", "suspected_afib"} if is_high_risk and is_arrhythmia: risk_factors = [] if age and int(age) >= 75: risk_factors.append(f"age {age}") if has_prior_stroke: risk_factors.append("prior stroke") return RuleFiredEvent( rule_name="high_risk_patient_with_arrhythmia", severity="escalate", explanation=f"High-risk patient ({', '.join(risk_factors)}) with arrhythmia. Escalate to cardiologist.", ) return None def rule_elderly_with_abnormal_rhythm( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """Elderly patient (age >= 75) with abnormal rhythm (notify level).""" age = patient.get("age") score = float(model.get("score", 0.0)) label = model.get("label") if age and int(age) >= 75 and label in {"arrhythmia"} and score >= 0.5: return RuleFiredEvent( rule_name="elderly_with_abnormal_rhythm", severity="notify", explanation=f"Elderly patient (age {age}) with abnormal rhythm. Increased monitoring advised.", ) return None def rule_prior_stroke_escalation( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """Patient with prior stroke history and any concerning signal.""" has_prior_stroke = patient.get("has_prior_stroke", False) score = float(model.get("score", 0.0)) if has_prior_stroke and score >= 0.6: return RuleFiredEvent( rule_name="prior_stroke_escalation", severity="notify", explanation="Patient has prior stroke history. Vigilant monitoring required for any arrhythmia indicators.", ) return None def rule_baseline_monitoring( self, patient: Dict[str, Any], model: Dict[str, Any] ) -> Optional[RuleFiredEvent]: """Baseline: no alerts triggered, routine monitoring.""" # This rule always fires if no other rules have escalated return RuleFiredEvent( rule_name="baseline_monitoring", severity="none", explanation="No critical alerts. Routine monitoring in progress.", ) # Singleton instance medical_rule_engine = MedicalRuleEngine() def evaluate_medical_rules( patient_context: Dict[str, Any], model_output: Dict[str, Any], ) -> Dict[str, Any]: """ Public API: evaluate medical rules. Args: patient_context: patient metadata (age, prior_stroke, etc.) model_output: ML model output (label, score, hr) Returns: { "alert_level": str, "explanations": [str], "fired_rules": [dict], } """ return medical_rule_engine.evaluate(patient_context, model_output)