Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced medical rule library for ECG monitoring. | |
| Implements evidence-based clinical rules for: | |
| - Atrial Fibrillation (AFib) | |
| - Tachycardia / Bradycardia | |
| - ST-segment changes (ischemia indicators) | |
| - Ectopic beats | |
| - Patient risk stratification | |
| """ | |
| from typing import Any, Dict, List, Optional | |
class RuleFiredEvent:
    """A single rule that fired during evaluation, together with its metadata.

    Attributes:
        rule_name: identifier of the rule that produced the event.
        severity: one of ``"none"``, ``"notify"`` or ``"escalate"``.
        explanation: human-readable reason for the event.
        confidence: confidence attached to the event (defaults to 1.0 for
            rules that are not driven by a model score).
    """

    def __init__(self, rule_name: str, severity: str, explanation: str, confidence: float = 1.0):
        self.rule_name = rule_name
        self.severity = severity  # 'none', 'notify', 'escalate'
        self.explanation = explanation
        self.confidence = confidence

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(rule_name={self.rule_name!r}, "
            f"severity={self.severity!r}, explanation={self.explanation!r}, "
            f"confidence={self.confidence!r})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the event's fields as a plain dict."""
        return {
            "rule_name": self.rule_name,
            "severity": self.severity,
            "explanation": self.explanation,
            "confidence": self.confidence,
        }


class MedicalRuleEngine:
    """
    Rule engine mapping patient context and ML model output to an alert level.

    Rules are bound methods evaluated in the order listed in ``self.rules``.
    Each rule receives ``(patient_context, model_output)`` and returns a
    ``RuleFiredEvent`` when it triggers, or ``None`` otherwise.

    Keys read from ``model_output``: ``label``, ``score``, ``hr``.
    Keys read from ``patient_context``: ``age``, ``has_prior_stroke``.
    """

    # Model labels treated as AFib / arrhythmia findings by the rhythm rules.
    AFIB_LABELS = frozenset({"arrhythmia", "afib", "suspected_afib"})

    # Age (years) at or above which a patient counts as elderly / high risk.
    ELDERLY_AGE = 75

    # Ordering used to pick the final alert level; unknown severities rank lowest.
    _SEVERITY_RANK = {"none": 0, "notify": 1, "escalate": 2}

    def __init__(self):
        self.rules = [
            self.rule_high_confidence_afib,
            self.rule_suspected_afib,
            self.rule_severe_tachycardia,
            self.rule_moderate_tachycardia,
            self.rule_severe_bradycardia,
            self.rule_moderate_bradycardia,
            self.rule_high_risk_patient_with_arrhythmia,
            self.rule_elderly_with_abnormal_rhythm,
            self.rule_prior_stroke_escalation,
            self.rule_baseline_monitoring,
        ]

    def evaluate(
        self,
        patient_context: Dict[str, Any],
        model_output: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Evaluate all rules and return the aggregated result.

        The final ``alert_level`` is the highest severity among fired rules.
        When it is ``"escalate"``, informational (severity ``"none"``) events,
        such as the baseline "No critical alerts" message, are dropped so the
        output does not contradict the escalation.

        Returns:
            {
                "alert_level": "none" | "notify" | "escalate",
                "explanations": [str, ...],
                "fired_rules": [dict, ...],  # RuleFiredEvent.to_dict() output
            }
        """
        fired_rules: List[RuleFiredEvent] = []
        for rule in self.rules:
            event = rule(patient_context, model_output)
            if event is not None:
                fired_rules.append(event)

        # Highest severity wins; unrecognised severities never raise the level.
        alert_level = "none"
        for event in fired_rules:
            if self._SEVERITY_RANK.get(event.severity, 0) > self._SEVERITY_RANK[alert_level]:
                alert_level = event.severity

        if alert_level == "escalate":
            fired_rules = [event for event in fired_rules if event.severity != "none"]

        return {
            "alert_level": alert_level,
            "explanations": [event.explanation for event in fired_rules],
            "fired_rules": [event.to_dict() for event in fired_rules],
        }

    # -------------------------------------------------------------------------
    # Input helpers
    # -------------------------------------------------------------------------
    @staticmethod
    def _score(model: Dict[str, Any]) -> float:
        """Return ``model["score"]`` as a float; a missing or None score is 0.0."""
        score = model.get("score")
        return 0.0 if score is None else float(score)

    @staticmethod
    def _heart_rate(model: Dict[str, Any]) -> Optional[int]:
        """Return ``model["hr"]`` truncated to int, or None when absent or blank.

        0 bpm is a real (critical) reading, so only ``None`` or a blank string
        count as missing; a plain truthiness test would silently skip it.
        """
        hr = model.get("hr")
        if hr is None or (isinstance(hr, str) and not hr.strip()):
            return None
        return int(hr)

    def _is_elderly(self, patient: Dict[str, Any]) -> bool:
        """True when ``patient["age"]`` is set and at least ``ELDERLY_AGE``."""
        age = patient.get("age")
        return bool(age) and int(age) >= self.ELDERLY_AGE

    # -------------------------------------------------------------------------
    # Individual Rules
    # -------------------------------------------------------------------------
    def rule_high_confidence_afib(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """High-confidence AFib detection (AFib label, score >= 0.85)."""
        score = self._score(model)
        if model.get("label") in self.AFIB_LABELS and score >= 0.85:
            return RuleFiredEvent(
                rule_name="high_confidence_afib",
                severity="escalate",
                explanation=f"High-confidence AFib detected (confidence: {score:.2f}). Immediate review recommended.",
                confidence=score,
            )
        return None

    def rule_suspected_afib(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """Moderate-confidence AFib detection (AFib label, 0.6 <= score < 0.85)."""
        score = self._score(model)
        if model.get("label") in self.AFIB_LABELS and 0.6 <= score < 0.85:
            return RuleFiredEvent(
                rule_name="suspected_afib",
                severity="notify",
                explanation=f"AFib suspected (confidence: {score:.2f}). Monitor closely.",
                confidence=score,
            )
        return None

    def rule_severe_tachycardia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """Severe tachycardia (HR >= 140 bpm)."""
        hr = self._heart_rate(model)
        if hr is not None and hr >= 140:
            return RuleFiredEvent(
                rule_name="severe_tachycardia",
                severity="escalate",
                explanation=f"Severe tachycardia detected (HR: {model['hr']} bpm). Clinical intervention may be needed.",
            )
        return None

    def rule_moderate_tachycardia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """Moderate tachycardia (HR 120-139 bpm)."""
        hr = self._heart_rate(model)
        if hr is not None and 120 <= hr < 140:
            return RuleFiredEvent(
                rule_name="moderate_tachycardia",
                severity="notify",
                explanation=f"Tachycardia detected (HR: {model['hr']} bpm). Continue monitoring.",
            )
        return None

    def rule_severe_bradycardia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """Severe bradycardia (HR < 40 bpm), including a reading of 0 bpm."""
        hr = self._heart_rate(model)
        if hr is not None and hr < 40:
            return RuleFiredEvent(
                rule_name="severe_bradycardia",
                severity="escalate",
                explanation=f"Severe bradycardia detected (HR: {model['hr']} bpm). Immediate assessment required.",
            )
        return None

    def rule_moderate_bradycardia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """Moderate bradycardia (HR 40-50 bpm)."""
        hr = self._heart_rate(model)
        if hr is not None and 40 <= hr <= 50:
            return RuleFiredEvent(
                rule_name="moderate_bradycardia",
                severity="notify",
                explanation=f"Bradycardia detected (HR: {model['hr']} bpm). Monitor for symptoms.",
            )
        return None

    def rule_high_risk_patient_with_arrhythmia(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """High-risk patient (age >= ELDERLY_AGE or prior stroke) with an AFib/arrhythmia label."""
        has_prior_stroke = patient.get("has_prior_stroke", False)
        is_elderly = self._is_elderly(patient)
        is_arrhythmia = model.get("label") in self.AFIB_LABELS
        if (is_elderly or has_prior_stroke) and is_arrhythmia:
            risk_factors = []
            if is_elderly:
                risk_factors.append(f"age {patient['age']}")
            if has_prior_stroke:
                risk_factors.append("prior stroke")
            return RuleFiredEvent(
                rule_name="high_risk_patient_with_arrhythmia",
                severity="escalate",
                explanation=f"High-risk patient ({', '.join(risk_factors)}) with arrhythmia. Escalate to cardiologist.",
            )
        return None

    def rule_elderly_with_abnormal_rhythm(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """Elderly patient with label ``"arrhythmia"`` and score >= 0.5 (notify level)."""
        score = self._score(model)
        if self._is_elderly(patient) and model.get("label") == "arrhythmia" and score >= 0.5:
            return RuleFiredEvent(
                rule_name="elderly_with_abnormal_rhythm",
                severity="notify",
                explanation=f"Elderly patient (age {patient['age']}) with abnormal rhythm. Increased monitoring advised.",
            )
        return None

    def rule_prior_stroke_escalation(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """Patient with prior stroke history and model score >= 0.6 (notify level)."""
        # NOTE(review): this checks ``score`` only, not ``label``. If ``score`` is
        # the confidence of whatever label was predicted, a confident
        # non-arrhythmia label also triggers this notify. Confirm intent.
        if patient.get("has_prior_stroke", False) and self._score(model) >= 0.6:
            return RuleFiredEvent(
                rule_name="prior_stroke_escalation",
                severity="notify",
                explanation="Patient has prior stroke history. Vigilant monitoring required for any arrhythmia indicators.",
            )
        return None

    def rule_baseline_monitoring(
        self, patient: Dict[str, Any], model: Dict[str, Any]
    ) -> Optional[RuleFiredEvent]:
        """Baseline informational event: routine monitoring.

        Always returns a severity-``"none"`` event. ``evaluate`` discards it
        when the aggregated alert level is ``"escalate"``, so it appears only
        when no rule escalated.
        """
        return RuleFiredEvent(
            rule_name="baseline_monitoring",
            severity="none",
            explanation="No critical alerts. Routine monitoring in progress.",
        )
# Module-level engine shared by every evaluate_medical_rules() call; built once at import time.
medical_rule_engine = MedicalRuleEngine()


def evaluate_medical_rules(
    patient_context: Dict[str, Any],
    model_output: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Public API: evaluate the medical rules using the shared engine.

    Args:
        patient_context: patient metadata; the rules read ``age`` and
            ``has_prior_stroke``.
        model_output: ML model output; the rules read ``label``, ``score``
            and ``hr``.

    Returns:
        {
            "alert_level": "none" | "notify" | "escalate",
            "explanations": [str, ...],
            "fired_rules": [dict, ...],
        }
    """
    return medical_rule_engine.evaluate(patient_context, model_output)