Spaces:
Running
Running
| import re | |
| import torch | |
| import numpy as np | |
# Output label space. The position of each name is the index of its bit in
# the multi-hot vector returned by get_refined_labels().
CLASSES = [
    'Sinus Rhythm',             # 0
    'Atrial Fibrillation',      # 1
    'Sinus Tachycardia',        # 2
    'Sinus Bradycardia',        # 3
    'Ventricular Tachycardia',  # 4
]

# Regex synonyms for each class, keyed by class name so the mapping cannot
# silently drift out of sync with the ordering of CLASSES.
_CLASS_PATTERNS = {
    # Checked explicitly even though sinus rhythm is often the default reading.
    'Sinus Rhythm': [r'sinus rhythm'],
    'Atrial Fibrillation': [r'atrial fibrillation', r'afib', r'a-fib', r'atrial fib'],
    'Sinus Tachycardia': [r'sinus tachycardia'],
    'Sinus Bradycardia': [r'sinus bradycardia'],
    # Word boundaries keep the bare abbreviation "vt" from matching inside other words.
    'Ventricular Tachycardia': [r'ventricular tachycardia', r'vtach', r'\bvt\b', r'v-tach'],
}

# Number of characters before a match that are scanned for a negation cue.
_NEGATION_WINDOW = 25

# Negation cues. The leading \b stops e.g. "cannot" from firing the "not" cue;
# the trailing \s keeps the original "cue followed by whitespace" requirement.
_NEGATION_RE = re.compile(r'\b(?:no|not|rule out|denies|absence of|free of)\s')

# Delimiters that end the scope of a preceding negation cue.
_CLAUSE_BREAK_RE = re.compile(r'[.;\n]')


def _is_negated(text, start):
    """Return True if a negation cue governs the match beginning at ``start``."""
    window = text[max(0, start - _NEGATION_WINDOW):start]
    # Only the clause containing the match counts: in
    # "no st changes. sinus tachycardia" the "no" belongs to the prior sentence.
    clause = _CLAUSE_BREAK_RE.split(window)[-1]
    return _NEGATION_RE.search(clause) is not None


def _has_condition(text, patterns, exclusion_patterns=None):
    """Return True if any pattern has at least one non-negated match in ``text``.

    If any of ``exclusion_patterns`` matches anywhere in ``text`` the result is
    False regardless of ``patterns``.
    """
    if exclusion_patterns and any(re.search(excl, text) for excl in exclusion_patterns):
        return False
    return any(
        not _is_negated(text, match.start())
        for pat in patterns
        for match in re.finditer(pat, text)
    )


def get_refined_labels(text):
    """Parse ECG report text into a multi-hot diagnostic label vector.

    Each class in CLASSES is matched against its regex synonyms. A match is
    ignored when a negation cue ("no", "not", "rule out", "denies",
    "absence of", "free of") appears in the preceding _NEGATION_WINDOW
    characters of the same clause (clauses end at '.', ';' or a newline).

    Args:
        text (str | None): Combined text from report columns. Matching is
            case-insensitive. None is treated as an empty report; other
            non-string values are converted with str().

    Returns:
        torch.Tensor: float32 tensor of shape (len(CLASSES),) holding 1.0 at
        the index of every detected class and 0.0 elsewhere.
    """
    text = '' if text is None else str(text).lower()
    labels = torch.zeros(len(CLASSES), dtype=torch.float32)
    for idx, name in enumerate(CLASSES):
        if _has_condition(text, _CLASS_PATTERNS[name]):
            labels[idx] = 1.0
    return labels
if __name__ == "__main__":
    # Smoke tests: (report text, expected multi-hot vector in CLASSES order).
    test_sentences = [
        ("Normal sinus rhythm", [1, 0, 0, 0, 0]),
        ("Atrial fibrillation with rapid ventricular response", [0, 1, 0, 0, 0]),
        ("No atrial fibrillation detected", [0, 0, 0, 0, 0]),
        ("Sinus tachycardia", [0, 0, 1, 0, 0]),
        ("Rule out ventricular tachycardia", [0, 0, 0, 0, 0]),
        # "History of" is ambiguous, but it is expected to count as a positive label.
        ("Patient has history of afib", [0, 1, 0, 0, 0]),
        ("Sinus bradycardia observed", [0, 0, 0, 1, 0]),
    ]
    print("Running Regex Label Tests...")
    failures = 0
    for txt, expected in test_sentences:
        res = get_refined_labels(txt)
        match = torch.all(res == torch.tensor(expected)).item()
        if not match:
            failures += 1
        print(f"'{txt}' -> {res.numpy()} [{'✅' if match else '❌'}]")
    print(f"{len(test_sentences) - failures}/{len(test_sentences)} passed")
    if failures:
        # Exit non-zero so shell scripts / CI notice a labelling regression.
        raise SystemExit(1)