Spaces:
Running
Running
File size: 4,183 Bytes
e3b4744 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | import re
import torch
import numpy as np
# Define classes
CLASSES = [
    'Sinus Rhythm',             # 0
    'Atrial Fibrillation',      # 1
    'Sinus Tachycardia',        # 2
    'Sinus Bradycardia',        # 3
    'Ventricular Tachycardia',  # 4
]

# Number of characters before a match that are inspected for a negation cue.
_NEGATION_WINDOW = 25

# Negation cues, anchored on a word boundary so that a cue is never matched
# inside a longer word (e.g. the "no " at the end of "casino ").
# "cannot" is listed explicitly: plain substring matching used to catch it via
# "not ", and hedged phrases such as "cannot exclude X" should stay negated.
_NEGATION_RE = re.compile(
    r'\b(?:no|not|cannot|rule\s+out|denies|absence\s+of|free\s+of)\s'
)

# A negation only applies within its own clause: anything before one of these
# characters belongs to a previous statement and must not negate the match.
_CLAUSE_BREAK_RE = re.compile(r'[.;\n]')

# Regex patterns per class; position i corresponds to CLASSES[i].
_CLASS_PATTERNS = (
    # "Sinus rhythm" is often the default, but it is still checked explicitly.
    (r'sinus rhythm',),
    # Synonyms: AFib, A-fib, Atrial Fib
    (r'atrial fibrillation', r'afib', r'a-fib', r'atrial fib'),
    (r'sinus tachycardia',),
    (r'sinus bradycardia',),
    # Synonyms: VTach, V-Tach, VT. The bare "vt" needs word boundaries so it
    # does not match inside unrelated words.
    (r'ventricular tachycardia', r'vtach', r'\bvt\b', r'v-tach'),
)


def _is_negated(text, start_idx):
    """Return True if the match starting at start_idx is preceded by a negation cue."""
    window = text[max(0, start_idx - _NEGATION_WINDOW):start_idx]
    # Keep only the part of the window that lies in the same clause as the match.
    clause = _CLAUSE_BREAK_RE.split(window)[-1]
    return _NEGATION_RE.search(clause) is not None


def _has_condition(text, patterns, exclusion_patterns=None):
    """Return True if any pattern has at least one non-negated match in text.

    If any of the optional exclusion_patterns matches anywhere in text, the
    condition is rejected outright.
    """
    if exclusion_patterns and any(re.search(excl, text) for excl in exclusion_patterns):
        return False
    return any(
        not _is_negated(text, match.start())
        for pat in patterns
        for match in re.finditer(pat, text)
    )


def get_refined_labels(text):
    """
    Parses ECG report text to extract diagnostic labels using Regex with negation handling.

    A class is set when one of its patterns occurs in the lower-cased text and
    that occurrence is not preceded, within the same clause and within
    ``_NEGATION_WINDOW`` characters, by a negation cue such as "no" or
    "rule out". Clauses are delimited by '.', ';' and newlines.

    Args:
        text (str): Combined text from report columns. ``None`` is treated as
            an empty report.

    Returns:
        torch.Tensor: Multi-hot encoded float32 label vector of length
        ``len(CLASSES)``.
    """
    labels = torch.zeros(len(CLASSES), dtype=torch.float32)
    if text is None:
        # Nothing to parse: return the all-zero vector instead of crashing.
        return labels

    text = text.lower()
    for class_idx, patterns in enumerate(_CLASS_PATTERNS):
        if _has_condition(text, patterns):
            labels[class_idx] = 1.0
    return labels
if __name__ == "__main__":
    # Smoke-test table: each entry pairs a report snippet with the multi-hot
    # vector it is expected to produce (index order follows CLASSES).
    cases = [
        ("Normal sinus rhythm", [1, 0, 0, 0, 0]),
        ("Atrial fibrillation with rapid ventricular response", [0, 1, 0, 0, 0]),
        ("No atrial fibrillation detected", [0, 0, 0, 0, 0]),
        ("Sinus tachycardia", [0, 0, 1, 0, 0]),
        ("Rule out ventricular tachycardia", [0, 0, 0, 0, 0]),
        # A "history of" mention is arguably ambiguous, but is counted as a
        # positive label here.
        ("Patient has history of afib", [0, 1, 0, 0, 0]),
        ("Sinus bradycardia observed", [0, 0, 0, 1, 0]),
    ]

    print("Running Regex Label Tests...")
    for report, wanted in cases:
        predicted = get_refined_labels(report)
        # Element-wise comparison against the expected vector; pass only if
        # every position agrees.
        agrees = (predicted == torch.tensor(wanted)).all().item()
        mark = '✅' if agrees else '❌'
        print(f"'{report}' -> {predicted.numpy()} [{mark}]")
|