mgbam commited on
Commit
b9a9c97
·
verified ·
1 Parent(s): 7b4a49e

Upload 4 files

Browse files
app/rules/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ Symbolic/neurosymbolic rules for ECG and other signals.
3
+ """
4
+
app/rules/engine.py CHANGED
@@ -1,21 +1,27 @@
1
  from typing import Any, Dict
 
2
 
3
  from app.rules import ecg_rules
4
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def evaluate_ecg_rules(
7
  patient_context: Dict[str, Any],
8
  model_output: Dict[str, Any],
9
  ) -> Dict[str, Any]:
10
- """
11
- Apply ECG-specific rules to the model output and patient context.
12
 
13
- Returns:
14
- dict with keys:
15
- - alert_level: str
16
- - explanations: list[str]
17
- """
18
  result = ecg_rules.apply_rules(patient_context, model_output)
19
- alert_level = result.get("alert_level", "none")
20
- explanations = result.get("explanations", [])
21
- return {"alert_level": alert_level, "explanations": explanations}
 
1
  from typing import Any, Dict
2
+ import os
3
 
4
  from app.rules import ecg_rules
5
 
6
+ USE_ENHANCED_RULES = os.getenv("USE_ENHANCED_RULES", "true").lower() == "true"
7
+
8
+ if USE_ENHANCED_RULES:
9
+ try:
10
+ from app.rules.medical_rules import evaluate_medical_rules
11
+ _evaluator = evaluate_medical_rules
12
+ except ImportError:
13
+ _evaluator = None
14
+ USE_ENHANCED_RULES = False
15
+ else:
16
+ _evaluator = None
17
+
18
 
19
  def evaluate_ecg_rules(
20
  patient_context: Dict[str, Any],
21
  model_output: Dict[str, Any],
22
  ) -> Dict[str, Any]:
23
+ if USE_ENHANCED_RULES and _evaluator:
24
+ return _evaluator(patient_context, model_output)
25
 
 
 
 
 
 
26
  result = ecg_rules.apply_rules(patient_context, model_output)
27
+ return {"alert_level": result.get("alert_level", "none"), "explanations": result.get("explanations", [])}
 
 
app/rules/medical_rules.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced medical rule library for ECG monitoring.
3
+
4
+ Implements evidence-based clinical rules for:
5
+ - Atrial Fibrillation (AFib)
6
+ - Tachycardia / Bradycardia
7
+ - ST-segment changes (ischemia indicators)
8
+ - Ectopic beats
9
+ - Patient risk stratification
10
+ """
11
+ from typing import Any, Dict, List, Optional
12
+
13
+
14
+ class RuleFiredEvent:
15
+ """Represents a rule that fired with metadata."""
16
+
17
+ def __init__(self, rule_name: str, severity: str, explanation: str, confidence: float = 1.0):
18
+ self.rule_name = rule_name
19
+ self.severity = severity # 'none', 'notify', 'escalate'
20
+ self.explanation = explanation
21
+ self.confidence = confidence
22
+
23
+ def to_dict(self) -> Dict[str, Any]:
24
+ return {
25
+ "rule_name": self.rule_name,
26
+ "severity": self.severity,
27
+ "explanation": self.explanation,
28
+ "confidence": self.confidence,
29
+ }
30
+
31
+
32
+ class MedicalRuleEngine:
33
+ """
34
+ Enhanced rule engine with richer medical logic.
35
+
36
+ Rules are organized by clinical condition and evaluated in order.
37
+ Each rule returns a RuleFiredEvent if triggered.
38
+ """
39
+
40
+ def __init__(self):
41
+ self.rules = [
42
+ self.rule_high_confidence_afib,
43
+ self.rule_suspected_afib,
44
+ self.rule_severe_tachycardia,
45
+ self.rule_moderate_tachycardia,
46
+ self.rule_severe_bradycardia,
47
+ self.rule_moderate_bradycardia,
48
+ self.rule_high_risk_patient_with_arrhythmia,
49
+ self.rule_elderly_with_abnormal_rhythm,
50
+ self.rule_prior_stroke_escalation,
51
+ self.rule_baseline_monitoring,
52
+ ]
53
+
54
+ def evaluate(
55
+ self,
56
+ patient_context: Dict[str, Any],
57
+ model_output: Dict[str, Any],
58
+ ) -> Dict[str, Any]:
59
+ """
60
+ Evaluate all rules and return aggregated result.
61
+
62
+ Returns:
63
+ {
64
+ "alert_level": "none" | "notify" | "escalate",
65
+ "explanations": [str, ...],
66
+ "fired_rules": [RuleFiredEvent, ...],
67
+ }
68
+ """
69
+ fired_rules: List[RuleFiredEvent] = []
70
+
71
+ for rule in self.rules:
72
+ event = rule(patient_context, model_output)
73
+ if event:
74
+ fired_rules.append(event)
75
+
76
+ # Determine final alert level (highest severity wins)
77
+ alert_level = "none"
78
+ for event in fired_rules:
79
+ if event.severity == "escalate":
80
+ alert_level = "escalate"
81
+ break
82
+ elif event.severity == "notify" and alert_level == "none":
83
+ alert_level = "notify"
84
+
85
+ explanations = [event.explanation for event in fired_rules]
86
+
87
+ return {
88
+ "alert_level": alert_level,
89
+ "explanations": explanations,
90
+ "fired_rules": [e.to_dict() for e in fired_rules],
91
+ }
92
+
93
+ # -------------------------------------------------------------------------
94
+ # Individual Rules
95
+ # -------------------------------------------------------------------------
96
+
97
+ def rule_high_confidence_afib(
98
+ self, patient: Dict[str, Any], model: Dict[str, Any]
99
+ ) -> Optional[RuleFiredEvent]:
100
+ """High-confidence AFib detection."""
101
+ label = model.get("label")
102
+ score = float(model.get("score", 0.0))
103
+
104
+ afib_labels = {"arrhythmia", "afib", "suspected_afib"}
105
+
106
+ if label in afib_labels and score >= 0.85:
107
+ return RuleFiredEvent(
108
+ rule_name="high_confidence_afib",
109
+ severity="escalate",
110
+ explanation=f"High-confidence AFib detected (confidence: {score:.2f}). Immediate review recommended.",
111
+ confidence=score,
112
+ )
113
+ return None
114
+
115
+ def rule_suspected_afib(
116
+ self, patient: Dict[str, Any], model: Dict[str, Any]
117
+ ) -> Optional[RuleFiredEvent]:
118
+ """Moderate-confidence AFib detection."""
119
+ label = model.get("label")
120
+ score = float(model.get("score", 0.0))
121
+
122
+ afib_labels = {"arrhythmia", "afib", "suspected_afib"}
123
+
124
+ if label in afib_labels and 0.6 <= score < 0.85:
125
+ return RuleFiredEvent(
126
+ rule_name="suspected_afib",
127
+ severity="notify",
128
+ explanation=f"AFib suspected (confidence: {score:.2f}). Monitor closely.",
129
+ confidence=score,
130
+ )
131
+ return None
132
+
133
+ def rule_severe_tachycardia(
134
+ self, patient: Dict[str, Any], model: Dict[str, Any]
135
+ ) -> Optional[RuleFiredEvent]:
136
+ """Severe tachycardia (HR >= 140 bpm)."""
137
+ hr = model.get("hr")
138
+ if hr and int(hr) >= 140:
139
+ return RuleFiredEvent(
140
+ rule_name="severe_tachycardia",
141
+ severity="escalate",
142
+ explanation=f"Severe tachycardia detected (HR: {hr} bpm). Clinical intervention may be needed.",
143
+ )
144
+ return None
145
+
146
+ def rule_moderate_tachycardia(
147
+ self, patient: Dict[str, Any], model: Dict[str, Any]
148
+ ) -> Optional[RuleFiredEvent]:
149
+ """Moderate tachycardia (HR 120-139 bpm)."""
150
+ hr = model.get("hr")
151
+ if hr and 120 <= int(hr) < 140:
152
+ return RuleFiredEvent(
153
+ rule_name="moderate_tachycardia",
154
+ severity="notify",
155
+ explanation=f"Tachycardia detected (HR: {hr} bpm). Continue monitoring.",
156
+ )
157
+ return None
158
+
159
+ def rule_severe_bradycardia(
160
+ self, patient: Dict[str, Any], model: Dict[str, Any]
161
+ ) -> Optional[RuleFiredEvent]:
162
+ """Severe bradycardia (HR < 40 bpm)."""
163
+ hr = model.get("hr")
164
+ if hr and int(hr) < 40:
165
+ return RuleFiredEvent(
166
+ rule_name="severe_bradycardia",
167
+ severity="escalate",
168
+ explanation=f"Severe bradycardia detected (HR: {hr} bpm). Immediate assessment required.",
169
+ )
170
+ return None
171
+
172
+ def rule_moderate_bradycardia(
173
+ self, patient: Dict[str, Any], model: Dict[str, Any]
174
+ ) -> Optional[RuleFiredEvent]:
175
+ """Moderate bradycardia (HR 40-50 bpm)."""
176
+ hr = model.get("hr")
177
+ if hr and 40 <= int(hr) <= 50:
178
+ return RuleFiredEvent(
179
+ rule_name="moderate_bradycardia",
180
+ severity="notify",
181
+ explanation=f"Bradycardia detected (HR: {hr} bpm). Monitor for symptoms.",
182
+ )
183
+ return None
184
+
185
+ def rule_high_risk_patient_with_arrhythmia(
186
+ self, patient: Dict[str, Any], model: Dict[str, Any]
187
+ ) -> Optional[RuleFiredEvent]:
188
+ """High-risk patient (age >= 75, prior stroke) with any arrhythmia."""
189
+ age = patient.get("age")
190
+ has_prior_stroke = patient.get("has_prior_stroke", False)
191
+ label = model.get("label")
192
+
193
+ is_high_risk = (age and int(age) >= 75) or has_prior_stroke
194
+ is_arrhythmia = label in {"arrhythmia", "afib", "suspected_afib"}
195
+
196
+ if is_high_risk and is_arrhythmia:
197
+ risk_factors = []
198
+ if age and int(age) >= 75:
199
+ risk_factors.append(f"age {age}")
200
+ if has_prior_stroke:
201
+ risk_factors.append("prior stroke")
202
+
203
+ return RuleFiredEvent(
204
+ rule_name="high_risk_patient_with_arrhythmia",
205
+ severity="escalate",
206
+ explanation=f"High-risk patient ({', '.join(risk_factors)}) with arrhythmia. Escalate to cardiologist.",
207
+ )
208
+ return None
209
+
210
+ def rule_elderly_with_abnormal_rhythm(
211
+ self, patient: Dict[str, Any], model: Dict[str, Any]
212
+ ) -> Optional[RuleFiredEvent]:
213
+ """Elderly patient (age >= 75) with abnormal rhythm (notify level)."""
214
+ age = patient.get("age")
215
+ score = float(model.get("score", 0.0))
216
+ label = model.get("label")
217
+
218
+ if age and int(age) >= 75 and label in {"arrhythmia"} and score >= 0.5:
219
+ return RuleFiredEvent(
220
+ rule_name="elderly_with_abnormal_rhythm",
221
+ severity="notify",
222
+ explanation=f"Elderly patient (age {age}) with abnormal rhythm. Increased monitoring advised.",
223
+ )
224
+ return None
225
+
226
+ def rule_prior_stroke_escalation(
227
+ self, patient: Dict[str, Any], model: Dict[str, Any]
228
+ ) -> Optional[RuleFiredEvent]:
229
+ """Patient with prior stroke history and any concerning signal."""
230
+ has_prior_stroke = patient.get("has_prior_stroke", False)
231
+ score = float(model.get("score", 0.0))
232
+
233
+ if has_prior_stroke and score >= 0.6:
234
+ return RuleFiredEvent(
235
+ rule_name="prior_stroke_escalation",
236
+ severity="notify",
237
+ explanation="Patient has prior stroke history. Vigilant monitoring required for any arrhythmia indicators.",
238
+ )
239
+ return None
240
+
241
+ def rule_baseline_monitoring(
242
+ self, patient: Dict[str, Any], model: Dict[str, Any]
243
+ ) -> Optional[RuleFiredEvent]:
244
+ """Baseline: no alerts triggered, routine monitoring."""
245
+ # This rule always fires if no other rules have escalated
246
+ return RuleFiredEvent(
247
+ rule_name="baseline_monitoring",
248
+ severity="none",
249
+ explanation="No critical alerts. Routine monitoring in progress.",
250
+ )
251
+
252
+
253
+ # Singleton instance
254
+ medical_rule_engine = MedicalRuleEngine()
255
+
256
+
257
+ def evaluate_medical_rules(
258
+ patient_context: Dict[str, Any],
259
+ model_output: Dict[str, Any],
260
+ ) -> Dict[str, Any]:
261
+ """
262
+ Public API: evaluate medical rules.
263
+
264
+ Args:
265
+ patient_context: patient metadata (age, prior_stroke, etc.)
266
+ model_output: ML model output (label, score, hr)
267
+
268
+ Returns:
269
+ {
270
+ "alert_level": str,
271
+ "explanations": [str],
272
+ "fired_rules": [dict],
273
+ }
274
+ """
275
+ return medical_rule_engine.evaluate(patient_context, model_output)