import re
import torch
import numpy as np

# Define classes
CLASSES = [
    'Sinus Rhythm',             # 0
    'Atrial Fibrillation',      # 1
    'Sinus Tachycardia',        # 2
    'Sinus Bradycardia',        # 3
    'Ventricular Tachycardia',  # 4
]

# Default number of characters before a finding that are scanned for a
# negation trigger.
NEGATION_WINDOW = 25

# Negation triggers, matched as whole words so that e.g. "mono" or "cannot"
# are not mistaken for "no" / "not".
_NEGATION_RE = re.compile(
    r"\b(?:no|not|rule\s+out|denies|absence\s+of|free\s+of)\b"
)

# Characters that end a sentence/clause. A negation never reaches across them,
# so "No acute changes. Atrial fibrillation." is a positive AF finding.
_CLAUSE_BREAKS = ('.', ';', '!', '?', '\n')

# (index into CLASSES, regex patterns) -- any non-negated match sets the label.
_RAW_LABEL_PATTERNS = (
    # Class 0: Sinus Rhythm -- often the default, but checked explicitly.
    (0, (r'sinus rhythm',)),
    # Class 1: Atrial Fibrillation. Synonyms: AFib, A-fib, Atrial Fib.
    (1, (r'atrial fibrillation', r'afib', r'a-fib', r'atrial fib')),
    # Class 2: Sinus Tachycardia
    (2, (r'sinus tachycardia',)),
    # Class 3: Sinus Bradycardia
    (3, (r'sinus bradycardia',)),
    # Class 4: Ventricular Tachycardia. Synonyms: VTach, V-Tach, VT.
    # "vt" is word-bounded so it does not match inside unrelated tokens.
    (4, (r'ventricular tachycardia', r'vtach', r'\bvt\b', r'v-tach')),
)

# Compiled once at import time instead of on every call.
_LABEL_PATTERNS = tuple(
    (idx, tuple(re.compile(p) for p in pats))
    for idx, pats in _RAW_LABEL_PATTERNS
)


def _is_negated(text, start_idx, negation_spans, window):
    """Return True if a negation trigger precedes ``start_idx`` in the same clause.

    A trigger counts only if it lies entirely within the ``window`` characters
    before ``start_idx`` and no clause break sits between it and the finding.
    """
    window_start = max(0, start_idx - window)
    for brk in _CLAUSE_BREAKS:
        # rfind returns -1 when the break is absent, so +1 yields 0 (no-op).
        window_start = max(window_start, text.rfind(brk, 0, start_idx) + 1)
    return any(window_start <= s and e <= start_idx for s, e in negation_spans)


def _has_condition(text, patterns, negation_spans, window, exclusion_patterns=None):
    """Return True if any pattern has at least one non-negated match in ``text``.

    If any regex in ``exclusion_patterns`` matches anywhere in ``text``, the
    condition is reported absent regardless of ``patterns``.
    """
    if exclusion_patterns and any(re.search(ex, text) for ex in exclusion_patterns):
        return False
    return any(
        not _is_negated(text, m.start(), negation_spans, window)
        for pat in patterns
        for m in re.finditer(pat, text)
    )


def get_refined_labels(text, *, negation_window=NEGATION_WINDOW):
    """
    Parses ECG report text to extract diagnostic labels using Regex with negation handling.

    A finding is counted unless a negation trigger ("no", "not", "rule out",
    "denies", "absence of", "free of") appears as a whole word within the
    ``negation_window`` characters before it *and* in the same clause
    (clauses end at '.', ';', '!', '?' or a newline).

    Args:
        text (str | None): Combined text from report columns. ``None`` is
            treated as an empty report; other non-str values are converted
            with ``str()``.
        negation_window (int): Number of characters before a finding that are
            scanned for a negation trigger. Defaults to 25.

    Returns:
        torch.Tensor: Multi-hot float32 vector of length ``len(CLASSES)``;
        position i is 1.0 when CLASSES[i] is mentioned positively.
    """
    text = "" if text is None else str(text).lower()
    labels = torch.zeros(len(CLASSES), dtype=torch.float32)

    # Locate negation triggers once over the whole text so word boundaries are
    # judged against real neighbouring characters, not a window edge that
    # might split a word (e.g. "...mono" -> "no").
    negation_spans = [m.span() for m in _NEGATION_RE.finditer(text)]

    for idx, patterns in _LABEL_PATTERNS:
        if _has_condition(text, patterns, negation_spans, negation_window):
            labels[idx] = 1.0
    return labels


if __name__ == "__main__":
    # Test cases
    test_sentences = [
        ("Normal sinus rhythm", [1, 0, 0, 0, 0]),
        ("Atrial fibrillation with rapid ventricular response", [0, 1, 0, 0, 0]),
        ("No atrial fibrillation detected", [0, 0, 0, 0, 0]),
        ("Sinus tachycardia", [0, 0, 1, 0, 0]),
        ("Rule out ventricular tachycardia", [0, 0, 0, 0, 0]),
        ("Patient has history of afib", [0, 1, 0, 0, 0]),  # History might be ambiguous, but usually valid for label
        ("Sinus bradycardia observed", [0, 0, 0, 1, 0]),
        # Negation must not leak across a sentence boundary.
        ("No acute changes. Atrial fibrillation.", [0, 1, 0, 0, 0]),
    ]

    print("Running Regex Label Tests...")
    failures = 0
    for txt, expected in test_sentences:
        res = get_refined_labels(txt)
        match = torch.equal(res, torch.tensor(expected, dtype=res.dtype))
        if not match:
            failures += 1
        print(f"'{txt}' -> {res.numpy()} [{'✅' if match else '❌'}]")

    print(f"{len(test_sentences) - failures}/{len(test_sentences)} passed")
    # Non-zero exit status lets CI detect a regression.
    raise SystemExit(1 if failures else 0)