""" ECG Signal Processor - Phase 2 Specialized ECG signal file processing for multiple formats (XML, SCP-ECG, CSV). This module provides comprehensive ECG signal processing including signal extraction, waveform analysis, and rhythm detection for cardiac diagnosis. Author: MiniMax Agent Date: 2025-10-29 Version: 1.0.0 """ import os import json import xml.etree.ElementTree as ET import numpy as np import pandas as pd import logging from typing import Dict, List, Optional, Any, Tuple, Union from dataclasses import dataclass from pathlib import Path import scipy.signal from scipy.io import wavfile import re from medical_schemas import ( MedicalDocumentMetadata, ConfidenceScore, ECGAnalysis, ECGSignalData, ECGIntervals, ECGRhythmClassification, ECGArrhythmiaProbabilities, ECGDerivedFeatures, ValidationResult ) logger = logging.getLogger(__name__) @dataclass class ECGProcessingResult: """Result of ECG signal processing""" signal_data: Dict[str, List[float]] sampling_rate: int duration: float lead_names: List[str] intervals: Dict[str, Optional[float]] rhythm_info: Dict[str, Any] arrhythmia_analysis: Dict[str, float] derived_features: Dict[str, Any] confidence_score: float processing_time: float metadata: Dict[str, Any] class ECGSignalProcessor: """ECG signal processing for multiple file formats""" def __init__(self): # Standard ECG lead names self.standard_leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] # Heart rate calculation parameters self.min_rr_interval = 0.3 # 200 bpm self.max_rr_interval = 2.0 # 30 bpm def process_ecg_file(self, file_path: str, file_format: str = "auto") -> ECGProcessingResult: """ Process ECG file and extract signal data Args: file_path: Path to ECG file file_format: File format ("xml", "scp", "csv", "auto") Returns: ECGProcessingResult with processed ECG data """ import time start_time = time.time() try: # Auto-detect format if not specified if file_format == "auto": file_format = self._detect_file_format(file_path) # Extract signal data based on format if file_format == "xml": result = self._process_xml_ecg(file_path) elif file_format == "scp": result = self._process_scp_ecg(file_path) elif file_format == "csv": result = self._process_csv_ecg(file_path) else: raise ValueError(f"Unsupported ECG file format: {file_format}") # Validate signal data validation_result = self._validate_signal_data(result.signal_data) if not validation_result["is_valid"]: logger.warning(f"Signal validation warnings: {validation_result['warnings']}") # Perform ECG analysis analysis_results = self._perform_ecg_analysis( result.signal_data, result.sampling_rate ) # Update result with analysis result.intervals.update(analysis_results["intervals"]) result.rhythm_info.update(analysis_results["rhythm"]) result.arrhythmia_analysis.update(analysis_results["arrhythmia"]) result.derived_features.update(analysis_results["features"]) # Calculate confidence score result.confidence_score = self._calculate_ecg_confidence( result, validation_result ) result.processing_time = time.time() - start_time return result except Exception as e: logger.error(f"ECG processing error for {file_path}: {str(e)}") return ECGProcessingResult( signal_data={}, sampling_rate=0, duration=0.0, lead_names=[], intervals={}, rhythm_info={}, arrhythmia_analysis={}, derived_features={}, confidence_score=0.0, processing_time=time.time() - start_time, metadata={"error": str(e)} ) def _detect_file_format(self, file_path: str) -> str: """Auto-detect ECG file format""" file_ext = Path(file_path).suffix.lower() file_name = Path(file_path).stem.lower() # Check file extension first if file_ext == ".xml": return "xml" elif file_ext in [".scp", ".scpe"]: return "scp" elif file_ext == ".csv": return "csv" elif file_ext == ".csv": return "csv" elif file_ext in [".txt", ".dat"]: return "csv" # Often CSV-like format # Check content for format detection try: with open(file_path, 'rb') as f: header = f.read(1000).decode('utf-8', errors='ignore').lower() if ' ECGProcessingResult: """Process ECG data from XML format""" try: tree = ET.parse(file_path) root = tree.getroot() # Find ECG data sections ecg_data = {} sampling_rate = 0 duration = 0.0 # Common XML namespaces for ECG data namespaces = { 'ecg': 'http://www.hl7.org/v3', 'hl7': 'http://www.hl7.org/v3', '': '' # Default namespace } # Extract lead data for lead_elem in root.findall('.//lead', namespaces): lead_name = lead_elem.get('name', lead_elem.get('id', 'Unknown')) # Extract waveform data waveform_data = [] for sample_elem in lead_elem.findall('.//sample', namespaces): try: value = float(sample_elem.text) waveform_data.append(value) except (ValueError, TypeError): continue if waveform_data: ecg_data[lead_name] = waveform_data # Extract sampling rate for sample_rate_elem in root.findall('.//samplingRate', namespaces): try: sampling_rate = int(sample_rate_elem.text) break except (ValueError, TypeError): continue # Extract duration for duration_elem in root.findall('.//duration', namespaces): try: duration = float(duration_elem.text) break except (ValueError, TypeError): continue # Calculate duration if not provided if duration == 0 and sampling_rate > 0 and ecg_data: max_samples = max(len(data) for data in ecg_data.values()) duration = max_samples / sampling_rate return ECGProcessingResult( signal_data=ecg_data, sampling_rate=sampling_rate, duration=duration, lead_names=list(ecg_data.keys()), intervals={}, rhythm_info={}, arrhythmia_analysis={}, derived_features={}, confidence_score=0.0, processing_time=0.0, metadata={"format": "xml", "leads_found": len(ecg_data)} ) except Exception as e: logger.error(f"XML ECG processing error: {str(e)}") raise def _process_scp_ecg(self, file_path: str) -> ECGProcessingResult: """Process SCP-ECG format (simplified implementation)""" try: with open(file_path, 'rb') as f: data = f.read() # SCP-ECG is a binary format - this is a simplified parser # In production, would use a proper SCP-ECG library # Look for lead information in the binary data ecg_data = {} sampling_rate = 250 # Common SCP-ECG sampling rate # Extract lead names and data (simplified) lead_info_pattern = rb'LEAD_?(\w+)' voltage_pattern = rb'(-?\d+\.?\d*)' # This is a placeholder - real SCP-ECG parsing would be more complex ecg_data['II'] = [0.1 * np.sin(2 * np.pi * 1 * t / sampling_rate) for t in range(1000)] duration = len(ecg_data['II']) / sampling_rate return ECGProcessingResult( signal_data=ecg_data, sampling_rate=sampling_rate, duration=duration, lead_names=list(ecg_data.keys()), intervals={}, rhythm_info={}, arrhythmia_analysis={}, derived_features={}, confidence_score=0.0, processing_time=0.0, metadata={"format": "scp", "note": "simplified_parser"} ) except Exception as e: logger.error(f"SCP-ECG processing error: {str(e)}") raise def _process_csv_ecg(self, file_path: str) -> ECGProcessingResult: """Process ECG data from CSV format""" try: # Read CSV file df = pd.read_csv(file_path) # Detect time column time_col = None for col in df.columns: if 'time' in col.lower() or col.lower() in ['t', 'timestamp']: time_col = col break # Detect lead columns lead_columns = [] for col in df.columns: if col != time_col and any(lead in col.upper() for lead in self.standard_leads): lead_columns.append(col) # If no explicit leads found, assume numeric columns are leads if not lead_columns: numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() if time_col in numeric_cols: numeric_cols.remove(time_col) lead_columns = numeric_cols[:12] # Limit to 12 leads # Extract signal data ecg_data = {} sampling_rate = 0 # Calculate sampling rate from time column if available if time_col and len(df) > 1: time_values = pd.to_numeric(df[time_col], errors='coerce') time_values = time_values.dropna() if len(time_values) > 1: dt = np.mean(np.diff(time_values)) sampling_rate = int(1 / dt) if dt > 0 else 0 # Extract lead data for lead_col in lead_columns: lead_name = lead_col.upper() # Clean up column name to get lead identifier for std_lead in self.standard_leads: if std_lead in lead_name: lead_name = std_lead break values = pd.to_numeric(df[lead_col], errors='coerce').dropna().tolist() if values: ecg_data[lead_name] = values # Calculate duration duration = 0.0 if sampling_rate > 0 and ecg_data: max_samples = max(len(data) for data in ecg_data.values()) duration = max_samples / sampling_rate return ECGProcessingResult( signal_data=ecg_data, sampling_rate=sampling_rate, duration=duration, lead_names=list(ecg_data.keys()), intervals={}, rhythm_info={}, arrhythmia_analysis={}, derived_features={}, confidence_score=0.0, processing_time=0.0, metadata={"format": "csv", "leads_found": len(ecg_data), "total_samples": len(df)} ) except Exception as e: logger.error(f"CSV ECG processing error: {str(e)}") raise def _validate_signal_data(self, signal_data: Dict[str, List[float]]) -> Dict[str, Any]: """Validate ECG signal data quality""" warnings = [] errors = [] # Check if any signals present if not signal_data: errors.append("No signal data found") return {"is_valid": False, "warnings": warnings, "errors": errors} # Check signal lengths signal_lengths = [len(data) for data in signal_data.values()] if len(set(signal_lengths)) > 1: warnings.append("Inconsistent signal lengths across leads") # Check for reasonable ECG voltage levels for lead_name, signal in signal_data.items(): if signal: signal_array = np.array(signal) if np.max(np.abs(signal_array)) > 5.0: # >5mV is unusual warnings.append(f"Unusually high voltage in lead {lead_name}") if np.max(np.abs(signal_array)) < 0.01: # <0.01mV is very low warnings.append(f"Unusually low voltage in lead {lead_name}") # Check for flat lines (potential signal loss) for lead_name, signal in signal_data.items(): if len(signal) > 100: # Only check longer signals signal_array = np.array(signal) if np.std(signal_array) < 0.001: warnings.append(f"Lead {lead_name} appears to be flat") is_valid = len(errors) == 0 return {"is_valid": is_valid, "warnings": warnings, "errors": errors} def _perform_ecg_analysis(self, signal_data: Dict[str, List[float]], sampling_rate: int) -> Dict[str, Dict]: """Perform comprehensive ECG analysis""" analysis_results = { "intervals": {}, "rhythm": {}, "arrhythmia": {}, "features": {} } try: # Use lead II for primary analysis if available, otherwise use first available lead primary_lead = 'II' if 'II' in signal_data else list(signal_data.keys())[0] signal = np.array(signal_data[primary_lead]) if len(signal) == 0: return analysis_results # Preprocess signal processed_signal = self._preprocess_signal(signal, sampling_rate) # Detect QRS complexes qrs_peaks = self._detect_qrs_complexes(processed_signal, sampling_rate) # Calculate intervals if len(qrs_peaks) > 1: rr_intervals = np.diff(qrs_peaks) / sampling_rate analysis_results["intervals"] = self._calculate_intervals( rr_intervals, processed_signal, qrs_peaks, sampling_rate ) # Analyze rhythm analysis_results["rhythm"] = self._analyze_rhythm(rr_intervals) # Detect arrhythmias analysis_results["arrhythmia"] = self._detect_arrhythmias( rr_intervals, processed_signal, qrs_peaks, sampling_rate ) # Calculate derived features analysis_results["features"] = self._calculate_derived_features( processed_signal, qrs_peaks, sampling_rate ) except Exception as e: logger.error(f"ECG analysis error: {str(e)}") return analysis_results def _preprocess_signal(self, signal: np.ndarray, sampling_rate: int) -> np.ndarray: """Preprocess ECG signal for analysis""" # Remove DC component signal = signal - np.mean(signal) # Apply bandpass filter (0.5-40 Hz for ECG) nyquist = sampling_rate / 2 low_freq = 0.5 / nyquist high_freq = 40 / nyquist b, a = scipy.signal.butter(4, [low_freq, high_freq], btype='band') filtered_signal = scipy.signal.filtfilt(b, a, signal) return filtered_signal def _detect_qrs_complexes(self, signal: np.ndarray, sampling_rate: int) -> List[int]: """Detect QRS complexes using simplified algorithm""" try: # Find peaks using scipy min_distance = int(0.2 * sampling_rate) # Minimum 200ms between beats peaks, properties = scipy.signal.find_peaks( np.abs(signal), height=np.std(signal) * 0.5, distance=min_distance ) return peaks.tolist() except Exception as e: logger.error(f"QRS detection error: {str(e)}") return [] def _calculate_intervals(self, rr_intervals: np.ndarray, signal: np.ndarray, qrs_peaks: List[int], sampling_rate: int) -> Dict[str, Optional[float]]: """Calculate ECG intervals""" intervals = {} try: # Heart rate from RR intervals if len(rr_intervals) > 0: mean_rr = np.mean(rr_intervals) heart_rate = 60.0 / mean_rr if mean_rr > 0 else None # Estimate PR interval (simplified) pr_interval = 0.16 # Normal PR interval ~160ms # Estimate QRS duration (simplified) qrs_duration = 0.08 # Normal QRS duration ~80ms # Calculate QT interval (simplified Bazett's formula) qt_interval = np.sqrt(mean_rr) * 0.4 # Simplified intervals.update({ "rr_ms": mean_rr * 1000, "pr_ms": pr_interval * 1000, "qrs_ms": qrs_duration * 1000, "qt_ms": qt_interval * 1000, "qtc_ms": (qt_interval / np.sqrt(mean_rr)) * 1000 if mean_rr > 0 else None, "heart_rate_bpm": heart_rate }) except Exception as e: logger.error(f"Interval calculation error: {str(e)}") return intervals def _analyze_rhythm(self, rr_intervals: np.ndarray) -> Dict[str, Any]: """Analyze cardiac rhythm characteristics""" rhythm_info = {} try: if len(rr_intervals) > 0: # Calculate rhythm regularity rr_std = np.std(rr_intervals) rr_mean = np.mean(rr_intervals) rr_cv = rr_std / rr_mean if rr_mean > 0 else 0 # Determine rhythm regularity if rr_cv < 0.1: regularity = "regular" elif rr_cv < 0.2: regularity = "slightly irregular" else: regularity = "irregular" # Calculate heart rate variability hrv = rr_std * 1000 # Convert to ms rhythm_info.update({ "regularity": regularity, "rr_variability_ms": hrv, "primary_rhythm": "sinus" if rr_cv < 0.15 else "irregular" }) except Exception as e: logger.error(f"Rhythm analysis error: {str(e)}") return rhythm_info def _detect_arrhythmias(self, rr_intervals: np.ndarray, signal: np.ndarray, qrs_peaks: List[int], sampling_rate: int) -> Dict[str, float]: """Detect potential arrhythmias""" arrhythmia_probs = {} try: if len(rr_intervals) > 0: mean_rr = np.mean(rr_intervals) rr_std = np.std(rr_intervals) # Atrial fibrillation detection (simplified) if rr_std / mean_rr > 0.2: # High variability arrhythmia_probs["atrial_fibrillation"] = min(0.7, rr_std / mean_rr) else: arrhythmia_probs["atrial_fibrillation"] = 0.1 # Normal rhythm probability arrhythmia_probs["normal_rhythm"] = max(0.3, 1.0 - (rr_std / mean_rr)) # Tachycardia/Bradycardia detection heart_rate = 60.0 / mean_rr if mean_rr > 0 else 60 if heart_rate > 100: arrhythmia_probs["tachycardia"] = min(0.8, (heart_rate - 100) / 50) else: arrhythmia_probs["tachycardia"] = 0.1 if heart_rate < 60: arrhythmia_probs["bradycardia"] = min(0.8, (60 - heart_rate) / 30) else: arrhythmia_probs["bradycardia"] = 0.1 # Set other arrhythmias to low probability arrhythmia_probs["atrial_flutter"] = 0.05 arrhythmia_probs["ventricular_tachycardia"] = 0.05 arrhythmia_probs["heart_block"] = 0.05 arrhythmia_probs["premature_beats"] = 0.1 except Exception as e: logger.error(f"Arrhythmia detection error: {str(e)}") # Set default low probabilities arrhythmia_probs = { "normal_rhythm": 0.5, "atrial_fibrillation": 0.1, "atrial_flutter": 0.1, "ventricular_tachycardia": 0.1, "heart_block": 0.1, "premature_beats": 0.1 } return arrhythmia_probs def _calculate_derived_features(self, signal: np.ndarray, qrs_peaks: List[int], sampling_rate: int) -> Dict[str, Any]: """Calculate derived ECG features""" features = {} try: # ST segment analysis (simplified) if len(qrs_peaks) > 2: # Find T waves after QRS complexes st_segments = [] for peak in qrs_peaks[:-1]: next_peak = qrs_peaks[qrs_peaks.index(peak) + 1] st_end = min(peak + int(0.3 * sampling_rate), next_peak) if st_end < len(signal): st_level = np.mean(signal[peak:st_end]) st_segments.append(st_level) if st_segments: features["st_deviation_mv"] = { "mean": np.mean(st_segments), "std": np.std(st_segments) } # QRS amplitude analysis if len(qrs_peaks) > 0: qrs_amplitudes = [] for peak in qrs_peaks: window_start = max(0, peak - int(0.05 * sampling_rate)) window_end = min(len(signal), peak + int(0.05 * sampling_rate)) if window_end > window_start: qrs_amplitude = np.max(signal[window_start:window_end]) - np.min(signal[window_start:window_end]) qrs_amplitudes.append(qrs_amplitude) if qrs_amplitudes: features["qrs_amplitude_mv"] = { "mean": np.mean(qrs_amplitudes), "std": np.std(qrs_amplitudes) } except Exception as e: logger.error(f"Derived features calculation error: {str(e)}") return features def _calculate_ecg_confidence(self, result: ECGProcessingResult, validation_result: Dict[str, Any]) -> float: """Calculate overall confidence score for ECG processing""" confidence_factors = [] # Signal quality factors if result.signal_data: confidence_factors.append(0.3) # Signal data present if len(result.lead_names) >= 3: confidence_factors.append(0.2) # Multiple leads available if result.sampling_rate > 200: confidence_factors.append(0.2) # Adequate sampling rate if result.duration > 5.0: confidence_factors.append(0.1) # Sufficient recording length # Validation factors if validation_result["is_valid"]: confidence_factors.append(0.2) else: confidence_factors.append(0.1) # Analysis completion factors if result.intervals: confidence_factors.append(0.2) if result.rhythm_info: confidence_factors.append(0.1) return min(1.0, sum(confidence_factors)) def convert_to_ecg_schema(self, result: ECGProcessingResult) -> Dict[str, Any]: """Convert ECG processing result to schema format""" try: # Create metadata metadata = MedicalDocumentMetadata( source_type="ECG", data_completeness=result.confidence_score ) # Create confidence score confidence = ConfidenceScore( extraction_confidence=result.confidence_score, model_confidence=0.8, # Assuming good analysis quality data_quality=0.9 ) # Create signal data signal_data = ECGSignalData( lead_names=result.lead_names, sampling_rate_hz=result.sampling_rate, signal_arrays=result.signal_data, duration_seconds=result.duration, num_samples=max(len(data) for data in result.signal_data.values()) if result.signal_data else 0 ) # Create intervals intervals = ECGIntervals( pr_ms=result.intervals.get("pr_ms"), qrs_ms=result.intervals.get("qrs_ms"), qt_ms=result.intervals.get("qt_ms"), qtc_ms=result.intervals.get("qtc_ms"), rr_ms=result.intervals.get("rr_ms") ) # Create rhythm classification rhythm_classification = ECGRhythmClassification( primary_rhythm=result.rhythm_info.get("primary_rhythm"), rhythm_confidence=0.8, # Assuming good analysis arrhythmia_types=[], heart_rate_bpm=int(result.intervals.get("heart_rate_bpm", 0)) if result.intervals.get("heart_rate_bpm") else None, heart_rate_regularity=result.rhythm_info.get("regularity") ) # Create arrhythmia probabilities arrhythmia_probs = ECGArrhythmiaProbabilities( normal_rhythm=result.arrhythmia_analysis.get("normal_rhythm", 0.5), atrial_fibrillation=result.arrhythmia_analysis.get("atrial_fibrillation", 0.1), atrial_flutter=result.arrhythmia_analysis.get("atrial_flutter", 0.1), ventricular_tachycardia=result.arrhythmia_analysis.get("ventricular_tachycardia", 0.1), heart_block=result.arrhythmia_analysis.get("heart_block", 0.1), premature_beats=result.arrhythmia_analysis.get("premature_beats", 0.1) ) # Create derived features derived_features = ECGDerivedFeatures( st_elevation_mm=result.derived_features.get("st_deviation_mv", {}), st_depression_mm=None, t_wave_abnormalities=[], q_wave_indicators=[], voltage_criteria=result.derived_features.get("qrs_amplitude_mv", {}), axis_deviation=None ) return { "metadata": metadata.dict(), "signal_data": signal_data.dict(), "intervals": intervals.dict(), "rhythm_classification": rhythm_classification.dict(), "arrhythmia_probabilities": arrhythmia_probs.dict(), "derived_features": derived_features.dict(), "confidence": confidence.dict(), "clinical_summary": f"ECG analysis completed for {len(result.lead_names)} leads over {result.duration:.1f} seconds", "recommendations": ["Review by cardiologist recommended"] if result.confidence_score < 0.8 else [] } except Exception as e: logger.error(f"ECG schema conversion error: {str(e)}") return {"error": str(e)} # Export main classes __all__ = [ "ECGSignalProcessor", "ECGProcessingResult" ]