ECG / labels_refined.py
IFMedTechdemo's picture
Upload folder using huggingface_hub
e3b4744 verified
import re
import torch
import numpy as np
# Diagnostic classes. Each name's position in this list is its index in the
# multi-hot vector returned by get_refined_labels().
CLASSES = [
    'Sinus Rhythm',             # 0
    'Atrial Fibrillation',      # 1
    'Sinus Tachycardia',        # 2
    'Sinus Bradycardia',        # 3
    'Ventricular Tachycardia'   # 4
]


def _compile_all(*patterns):
    """Compile each regex source string once and return them as a tuple."""
    return tuple(re.compile(p) for p in patterns)


# (class index, compiled patterns) pairs. Patterns are written in lower case
# because the report text is lower-cased before matching.
_CLASS_PATTERNS = (
    (0, _compile_all(r'sinus rhythm')),
    # Synonyms: AFib, A-fib, Atrial Fib
    (1, _compile_all(r'atrial fibrillation', r'afib', r'a-fib', r'atrial fib')),
    (2, _compile_all(r'sinus tachycardia')),
    (3, _compile_all(r'sinus bradycardia')),
    # Synonyms: VTach, V-Tach, VT. "vt" is word-bounded so it cannot match
    # inside longer tokens.
    (4, _compile_all(r'ventricular tachycardia', r'vtach', r'\bvt\b', r'v-tach')),
)

# How many characters before a match are inspected for a negation cue.
_NEGATION_WINDOW = 25

# Negation cues as whole words. Whitespace between the words of a cue may be
# any run of spaces, tabs or newlines. "cannot" is listed explicitly so that
# phrases such as "cannot exclude ..." stay negated once "not" is word-bounded.
_NEGATION_RE = re.compile(
    r'\b(?:no|not|cannot|without|rule\s+out|denies|absence\s+of|free\s+of)\b'
)

# A negation cue does not reach past the end of a sentence or a contrastive
# conjunction. A period followed by a digit is treated as a decimal point.
# Commas are deliberately NOT breaks, so that "no afib, vt" negates both items.
_CLAUSE_BREAK_RE = re.compile(r'[.;](?!\d)|\b(?:but|however)\b')


def _is_negated(text, start):
    """Return True if the mention starting at *start* is preceded by a negation.

    Only the last ``_NEGATION_WINDOW`` characters before *start* are examined,
    and of those only the part after the most recent clause break.
    """
    scope_start = max(0, start - _NEGATION_WINDOW)
    # Searching with pos/endpos (instead of slicing) keeps ``\b`` aware of the
    # characters just outside the window, so a word cut in half by the window
    # edge is not mistaken for a cue.
    for clause_break in _CLAUSE_BREAK_RE.finditer(text, scope_start, start):
        scope_start = clause_break.end()
    return _NEGATION_RE.search(text, scope_start, start) is not None


def _has_positive_mention(text, patterns):
    """Return True if any pattern has at least one non-negated match in *text*."""
    for pattern in patterns:
        for mention in pattern.finditer(text):
            if not _is_negated(text, mention.start()):
                return True
    return False


def get_refined_labels(text):
    """
    Parses ECG report text to extract diagnostic labels using Regex with negation handling.

    A class is switched on when at least one of its patterns occurs in the
    text without a negation cue ("no", "not", "without", "rule out", ...)
    in the same clause within the 25 characters before the mention.

    Args:
        text (str): Combined text from report columns. Matching is
            case-insensitive. Any non-string value (e.g. None or NaN from an
            empty column) is treated as a report with no findings.

    Returns:
        torch.Tensor: Multi-hot encoded label vector of dtype float32 and
            length ``len(CLASSES)``; entry ``i`` is 1.0 when ``CLASSES[i]``
            is positively mentioned, else 0.0.
    """
    labels = torch.zeros(len(CLASSES), dtype=torch.float32)
    if not isinstance(text, str):
        return labels

    text = text.lower()
    for class_idx, patterns in _CLASS_PATTERNS:
        if _has_positive_mention(text, patterns):
            labels[class_idx] = 1.0
    return labels
if __name__ == "__main__":
    # Smoke test: each case pairs a report sentence with the multi-hot vector
    # (ordered as in CLASSES) that get_refined_labels() is expected to return.
    cases = (
        ("Normal sinus rhythm", [1, 0, 0, 0, 0]),
        ("Atrial fibrillation with rapid ventricular response", [0, 1, 0, 0, 0]),
        ("No atrial fibrillation detected", [0, 0, 0, 0, 0]),
        ("Sinus tachycardia", [0, 0, 1, 0, 0]),
        ("Rule out ventricular tachycardia", [0, 0, 0, 0, 0]),
        # A history mention is ambiguous, but is usually accepted as a label.
        ("Patient has history of afib", [0, 1, 0, 0, 0]),
        ("Sinus bradycardia observed", [0, 0, 0, 1, 0]),
    )

    print("Running Regex Label Tests...")
    for sentence, wanted in cases:
        predicted = get_refined_labels(sentence)
        # The int64 expectation is promoted to float for the element-wise
        # comparison; a case passes only when every position agrees.
        agrees = torch.all(predicted == torch.tensor(wanted)).item()
        print(f"'{sentence}' -> {predicted.numpy()} [{'✅' if agrees else '❌'}]")