medical-report-analyzer / ecg_processor.py
snikhilesh's picture
Deploy backend with monitoring infrastructure - Complete Medical AI Platform
13d5ab4 verified
"""
ECG Signal Processor - Phase 2
Specialized ECG signal file processing for multiple formats (XML, SCP-ECG, CSV).
This module provides comprehensive ECG signal processing including signal extraction,
waveform analysis, and rhythm detection for cardiac diagnosis.
Author: MiniMax Agent
Date: 2025-10-29
Version: 1.0.0
"""
import os
import json
import xml.etree.ElementTree as ET
import numpy as np
import pandas as pd
import logging
from typing import Dict, List, Optional, Any, Tuple, Union
from dataclasses import dataclass
from pathlib import Path
import scipy.signal
from scipy.io import wavfile
import re
from medical_schemas import (
MedicalDocumentMetadata, ConfidenceScore, ECGAnalysis,
ECGSignalData, ECGIntervals, ECGRhythmClassification,
ECGArrhythmiaProbabilities, ECGDerivedFeatures, ValidationResult
)
# Module-level logger, named after this module, shared by every processing step in this file.
logger = logging.getLogger(__name__)
@dataclass
class ECGProcessingResult:
    """Result of ECG signal processing.

    The format-specific readers create instances with empty analysis fields;
    ``ECGSignalProcessor.process_ecg_file`` then fills in the analysis,
    confidence score and processing time.
    """
    # Waveform samples keyed by lead name
    signal_data: Dict[str, List[float]]
    # Samples per second; 0 when the rate could not be determined
    sampling_rate: int
    # Recording length in seconds (longest lead's sample count / sampling rate)
    duration: float
    # Names of the leads present in ``signal_data``
    lead_names: List[str]
    # Interval measurements such as "rr_ms", "pr_ms", "qrs_ms", "qt_ms", "qtc_ms", "heart_rate_bpm"
    intervals: Dict[str, Optional[float]]
    # Rhythm description: "regularity", "rr_variability_ms", "primary_rhythm"
    rhythm_info: Dict[str, Any]
    # Heuristic scores keyed by arrhythmia name
    arrhythmia_analysis: Dict[str, float]
    # Extra features such as "st_deviation_mv" and "qrs_amplitude_mv"
    derived_features: Dict[str, Any]
    # Overall confidence; the processor caps it at 1.0
    confidence_score: float
    # Wall-clock processing time in seconds
    processing_time: float
    # Free-form details (source format, lead count, or an "error" message on failure)
    metadata: Dict[str, Any]
class ECGSignalProcessor:
    """ECG signal processing for multiple file formats"""

    def __init__(self):
        # The twelve conventional leads: limb, augmented limb, then chest leads V1-V6
        limb_leads = ['I', 'II', 'III']
        augmented_leads = ['aVR', 'aVL', 'aVF']
        chest_leads = [f'V{position}' for position in range(1, 7)]
        self.standard_leads = limb_leads + augmented_leads + chest_leads

        # RR-interval bounds in seconds and the heart rates they correspond to
        self.max_rr_interval = 2.0  # 30 bpm
        self.min_rr_interval = 0.3  # 200 bpm
def process_ecg_file(self, file_path: str, file_format: str = "auto") -> ECGProcessingResult:
    """
    Process ECG file and extract signal data

    The file is read with the format-specific reader, validated, analysed,
    and scored. This method does not raise: any failure (including an
    unsupported format) is logged and reported as an empty result whose
    ``metadata["error"]`` holds the error message.

    Args:
        file_path: Path to ECG file
        file_format: File format ("xml", "scp", "csv", "auto")

    Returns:
        ECGProcessingResult with processed ECG data
    """
    import time

    start_time = time.time()
    try:
        # Auto-detect format if not specified
        if file_format == "auto":
            file_format = self._detect_file_format(file_path)

        # Extract signal data based on format
        if file_format == "xml":
            result = self._process_xml_ecg(file_path)
        elif file_format == "scp":
            result = self._process_scp_ecg(file_path)
        elif file_format == "csv":
            result = self._process_csv_ecg(file_path)
        else:
            raise ValueError(f"Unsupported ECG file format: {file_format}")

        # Validate signal data. Errors and warnings are reported separately:
        # the validator only marks a signal invalid when it has errors, so
        # warnings have to be logged regardless of validity to be seen at all.
        validation_result = self._validate_signal_data(result.signal_data)
        if validation_result.get("errors"):
            logger.warning(f"Signal validation errors: {validation_result['errors']}")
        if validation_result.get("warnings"):
            logger.warning(f"Signal validation warnings: {validation_result['warnings']}")

        # Perform ECG analysis
        analysis_results = self._perform_ecg_analysis(
            result.signal_data, result.sampling_rate
        )

        # Update result with analysis
        result.intervals.update(analysis_results["intervals"])
        result.rhythm_info.update(analysis_results["rhythm"])
        result.arrhythmia_analysis.update(analysis_results["arrhythmia"])
        result.derived_features.update(analysis_results["features"])

        # Calculate confidence score
        result.confidence_score = self._calculate_ecg_confidence(
            result, validation_result
        )
        result.processing_time = time.time() - start_time
        return result
    except Exception as e:
        # Top-level boundary: callers always get a result object back
        logger.error(f"ECG processing error for {file_path}: {str(e)}")
        return ECGProcessingResult(
            signal_data={},
            sampling_rate=0,
            duration=0.0,
            lead_names=[],
            intervals={},
            rhythm_info={},
            arrhythmia_analysis={},
            derived_features={},
            confidence_score=0.0,
            processing_time=time.time() - start_time,
            metadata={"error": str(e)}
        )
def _detect_file_format(self, file_path: str) -> str:
    """Auto-detect ECG file format.

    The file extension decides first. For unknown extensions the first
    1000 bytes of the file are inspected for an XML or SCP-ECG marker.

    Args:
        file_path: Path to the ECG file.

    Returns:
        "xml", "scp" or "csv". "csv" is the fallback for anything that is
        not recognised, including files that cannot be opened.
    """
    file_ext = Path(file_path).suffix.lower()

    # Check file extension first
    if file_ext == ".xml":
        return "xml"
    if file_ext in (".scp", ".scpe"):
        return "scp"
    if file_ext in (".csv", ".txt", ".dat"):
        return "csv"  # .txt/.dat are often CSV-like

    # Check content for format detection
    try:
        with open(file_path, 'rb') as f:
            header = f.read(1000).decode('utf-8', errors='ignore').lower()
    except (OSError, ValueError):
        # Best effort only: an unreadable path is reported by the reader
        # that opens the file next, so fall back to the default here.
        return "csv"

    if '<?xml' in header or '<ecg' in header:
        return "xml"
    if 'scp-ecg' in header:
        return "scp"

    # Default to CSV for unknown formats (this also covers headers that
    # look like a time/lead/voltage table)
    return "csv"
def _process_xml_ecg(self, file_path: str) -> ECGProcessingResult:
    """Process ECG data from XML format.

    Reads every ``<lead>`` element (named by its ``name`` or ``id``
    attribute, else "Unknown") and collects the numeric text of its
    ``<sample>`` descendants. The first parseable ``<samplingRate>`` and
    ``<duration>`` elements are used. When no duration is given it is
    derived from the longest lead and the sampling rate.

    Only elements without an XML namespace are matched by these lookups.

    Args:
        file_path: Path to the XML file.

    Returns:
        ECGProcessingResult with signal fields filled in and empty
        analysis fields.

    Raises:
        Exception: any parsing or I/O error is logged and re-raised.
    """
    try:
        root = ET.parse(file_path).getroot()

        # Extract lead data
        ecg_data = {}
        for lead_elem in root.findall('.//lead'):
            lead_name = lead_elem.get('name', lead_elem.get('id', 'Unknown'))

            # Extract waveform data, skipping samples that are not numbers
            waveform_data = []
            for sample_elem in lead_elem.findall('.//sample'):
                try:
                    waveform_data.append(float(sample_elem.text))
                except (ValueError, TypeError):
                    continue

            if waveform_data:
                ecg_data[lead_name] = waveform_data

        # Extract sampling rate. Going through float() accepts values
        # written as "500.0", which int() alone would reject.
        sampling_rate = 0
        for sample_rate_elem in root.findall('.//samplingRate'):
            try:
                sampling_rate = int(round(float(sample_rate_elem.text)))
                break
            except (ValueError, TypeError, OverflowError):
                continue

        # Extract duration
        duration = 0.0
        for duration_elem in root.findall('.//duration'):
            try:
                duration = float(duration_elem.text)
                break
            except (ValueError, TypeError):
                continue

        # Calculate duration if not provided
        if duration == 0 and sampling_rate > 0 and ecg_data:
            max_samples = max(len(data) for data in ecg_data.values())
            duration = max_samples / sampling_rate

        return ECGProcessingResult(
            signal_data=ecg_data,
            sampling_rate=sampling_rate,
            duration=duration,
            lead_names=list(ecg_data.keys()),
            intervals={},
            rhythm_info={},
            arrhythmia_analysis={},
            derived_features={},
            confidence_score=0.0,
            processing_time=0.0,
            metadata={"format": "xml", "leads_found": len(ecg_data)}
        )
    except Exception as e:
        logger.error(f"XML ECG processing error: {str(e)}")
        raise
def _process_scp_ecg(self, file_path: str) -> ECGProcessingResult:
    """Process SCP-ECG format (placeholder implementation).

    SCP-ECG is a binary format and is NOT decoded here. The file is only
    opened to confirm it is readable; the returned lead "II" is a
    synthetic sine wave of 1000 samples at 250 Hz, not patient data.
    A warning is logged and the metadata is flagged accordingly.

    Args:
        file_path: Path to the SCP-ECG file.

    Returns:
        ECGProcessingResult carrying the synthetic placeholder waveform.

    Raises:
        Exception: I/O errors are logged and re-raised.
    """
    try:
        # Opening the file keeps the failure behaviour for missing or
        # unreadable paths; the content itself is not parsed.
        with open(file_path, 'rb'):
            pass

        sampling_rate = 250  # Common SCP-ECG sampling rate

        # TODO: replace with a real SCP-ECG parser before relying on this output
        logger.warning(
            f"SCP-ECG parsing is not implemented; returning a synthetic placeholder waveform for {file_path}"
        )
        ecg_data = {
            'II': [0.1 * np.sin(2 * np.pi * 1 * t / sampling_rate) for t in range(1000)]
        }
        duration = len(ecg_data['II']) / sampling_rate

        return ECGProcessingResult(
            signal_data=ecg_data,
            sampling_rate=sampling_rate,
            duration=duration,
            lead_names=list(ecg_data.keys()),
            intervals={},
            rhythm_info={},
            arrhythmia_analysis={},
            derived_features={},
            confidence_score=0.0,
            processing_time=0.0,
            metadata={"format": "scp", "note": "simplified_parser", "synthetic": True}
        )
    except Exception as e:
        logger.error(f"SCP-ECG processing error: {str(e)}")
        raise
def _process_csv_ecg(self, file_path: str) -> ECGProcessingResult:
    """Process ECG data from CSV format.

    Column handling:
      * The first column whose name contains "time" (or is exactly "t")
        is the time axis. The sampling rate is the rounded inverse of its
        mean step, so the time values are taken to be in seconds.
      * A column is a lead when one of the alphanumeric tokens of its
        header equals a standard lead name, ignoring case and an optional
        "LEAD" prefix ("II", "lead_ii", "Lead II (mV)", "aVR", ...).
        Whole-token matching keeps "II" and "III" distinct from "I".
      * If no header names a lead, up to 12 numeric columns (excluding
        the time column) are used, keyed by their upper-cased header.
      * When several columns map to the same lead, the first one that has
        numeric data wins.

    Args:
        file_path: Path to the CSV file.

    Returns:
        ECGProcessingResult with signal fields filled in and empty
        analysis fields.

    Raises:
        Exception: any read or parse error is logged and re-raised.
    """
    try:
        # Read CSV file
        df = pd.read_csv(file_path)

        # Detect time column
        time_col = None
        for col in df.columns:
            col_lower = str(col).lower()
            if 'time' in col_lower or col_lower == 't':
                time_col = col
                break

        # Detect lead columns by whole-token match against the standard leads
        canonical_leads = {lead.upper(): lead for lead in self.standard_leads}
        column_leads = {}
        for col in df.columns:
            if col == time_col:
                continue
            for token in re.split(r'[^A-Za-z0-9]+', str(col).upper()):
                if token.startswith('LEAD'):
                    token = token[4:]  # "LEADII" -> "II"
                if token in canonical_leads:
                    column_leads[col] = canonical_leads[token]
                    break

        # If no explicit leads found, assume numeric columns are leads
        if not column_leads:
            numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            if time_col in numeric_cols:
                numeric_cols.remove(time_col)
            # Limit to 12 leads
            column_leads = {col: str(col).upper() for col in numeric_cols[:12]}

        # Calculate sampling rate from time column if available
        sampling_rate = 0
        if time_col is not None and len(df) > 1:
            time_values = pd.to_numeric(df[time_col], errors='coerce').dropna()
            if len(time_values) > 1:
                dt = np.mean(np.diff(time_values))
                if dt > 0:
                    # round() rather than int(): 1/0.002 can evaluate to
                    # 499.999..., which truncation would report as 499 Hz
                    sampling_rate = int(round(1 / dt))

        # Extract lead data
        ecg_data = {}
        for col, lead_name in column_leads.items():
            if lead_name in ecg_data:
                continue
            values = pd.to_numeric(df[col], errors='coerce').dropna().tolist()
            if values:
                ecg_data[lead_name] = values

        # Calculate duration
        duration = 0.0
        if sampling_rate > 0 and ecg_data:
            max_samples = max(len(data) for data in ecg_data.values())
            duration = max_samples / sampling_rate

        return ECGProcessingResult(
            signal_data=ecg_data,
            sampling_rate=sampling_rate,
            duration=duration,
            lead_names=list(ecg_data.keys()),
            intervals={},
            rhythm_info={},
            arrhythmia_analysis={},
            derived_features={},
            confidence_score=0.0,
            processing_time=0.0,
            metadata={"format": "csv", "leads_found": len(ecg_data), "total_samples": len(df)}
        )
    except Exception as e:
        logger.error(f"CSV ECG processing error: {str(e)}")
        raise
def _validate_signal_data(self, signal_data: Dict[str, List[float]]) -> Dict[str, Any]:
    """Check the quality of the extracted lead signals.

    Args:
        signal_data: Samples keyed by lead name.

    Returns:
        Dict with "is_valid" (False only when there are errors),
        "warnings" and "errors" (both lists of messages).
    """
    # Without any lead there is nothing further to inspect
    if not signal_data:
        return {"is_valid": False, "warnings": [], "errors": ["No signal data found"]}

    notes = []
    problems = []

    # All leads are expected to have the same number of samples
    distinct_lengths = {len(samples) for samples in signal_data.values()}
    if len(distinct_lengths) > 1:
        notes.append("Inconsistent signal lengths across leads")

    # Amplitude plausibility per lead (thresholds in mV)
    for name, samples in signal_data.items():
        if not samples:
            continue
        peak_magnitude = np.max(np.abs(np.array(samples)))
        if peak_magnitude > 5.0:  # >5mV is unusual
            notes.append(f"Unusually high voltage in lead {name}")
        if peak_magnitude < 0.01:  # <0.01mV is very low
            notes.append(f"Unusually low voltage in lead {name}")

    # Flat-line check (possible signal loss), applied to longer signals only
    flat_leads = [
        name
        for name, samples in signal_data.items()
        if len(samples) > 100 and np.std(np.array(samples)) < 0.001
    ]
    notes.extend(f"Lead {name} appears to be flat" for name in flat_leads)

    return {"is_valid": not problems, "warnings": notes, "errors": problems}
def _perform_ecg_analysis(self, signal_data: Dict[str, List[float]],
                          sampling_rate: int) -> Dict[str, Dict]:
    """Perform comprehensive ECG analysis on one representative lead.

    Lead II is preferred; if it is missing or empty, the first lead that
    has samples is used. The lead is filtered, QRS peaks are detected and,
    when at least two peaks are found, intervals, rhythm and arrhythmia
    scores are computed. Derived features are computed whenever the
    signal could be processed.

    Args:
        signal_data: Samples keyed by lead name.
        sampling_rate: Samples per second.

    Returns:
        Dict with the keys "intervals", "rhythm", "arrhythmia" and
        "features". Sections stay empty when there is no usable signal,
        the sampling rate is not positive, or a step fails (failures are
        logged, not raised).
    """
    analysis_results = {
        "intervals": {},
        "rhythm": {},
        "arrhythmia": {},
        "features": {}
    }

    # Nothing to analyse: handled explicitly rather than via a caught IndexError
    if not signal_data:
        return analysis_results

    # Prefer lead II, then fall back to the first lead that has samples
    lead_order = (['II'] if 'II' in signal_data else []) + [
        name for name in signal_data if name != 'II'
    ]
    primary_lead = next(
        (name for name in lead_order if len(signal_data[name]) > 0), None
    )
    if primary_lead is None:
        return analysis_results

    if sampling_rate <= 0:
        # Filtering and timing both divide by the sampling rate
        logger.warning(f"ECG analysis skipped: invalid sampling rate {sampling_rate}")
        return analysis_results

    try:
        signal = np.array(signal_data[primary_lead])

        # Preprocess signal
        processed_signal = self._preprocess_signal(signal, sampling_rate)

        # Detect QRS complexes
        qrs_peaks = self._detect_qrs_complexes(processed_signal, sampling_rate)

        # Interval, rhythm and arrhythmia analysis need at least one RR interval
        if len(qrs_peaks) > 1:
            rr_intervals = np.diff(qrs_peaks) / sampling_rate
            analysis_results["intervals"] = self._calculate_intervals(
                rr_intervals, processed_signal, qrs_peaks, sampling_rate
            )
            analysis_results["rhythm"] = self._analyze_rhythm(rr_intervals)
            analysis_results["arrhythmia"] = self._detect_arrhythmias(
                rr_intervals, processed_signal, qrs_peaks, sampling_rate
            )

        # Calculate derived features
        analysis_results["features"] = self._calculate_derived_features(
            processed_signal, qrs_peaks, sampling_rate
        )
    except Exception as e:
        logger.error(f"ECG analysis error: {str(e)}")
    return analysis_results
def _preprocess_signal(self, signal: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Centre the signal on zero and band-pass filter it.

    Args:
        signal: Raw samples of one lead.
        sampling_rate: Samples per second.

    Returns:
        The filtered signal (4th-order Butterworth band-pass, 0.5-40 Hz,
        applied forwards and backwards with ``filtfilt``).
    """
    # Subtract the mean to drop the DC offset
    centred = signal - np.mean(signal)

    # Cut-offs are expressed as fractions of the Nyquist frequency
    half_rate = sampling_rate / 2
    passband = [0.5 / half_rate, 40 / half_rate]
    numerator, denominator = scipy.signal.butter(4, passband, btype='band')

    return scipy.signal.filtfilt(numerator, denominator, centred)
def _detect_qrs_complexes(self, signal: np.ndarray, sampling_rate: int) -> List[int]:
    """Locate QRS complexes with a simple peak search.

    Peaks are searched in the rectified signal. A peak must reach half
    the signal's standard deviation and lie at least 200 ms after the
    previous one.

    Args:
        signal: Preprocessed samples of one lead.
        sampling_rate: Samples per second.

    Returns:
        Sample indices of the detected peaks; an empty list if the
        search fails (the failure is logged).
    """
    try:
        refractory_samples = int(0.2 * sampling_rate)  # Minimum 200ms between beats
        amplitude_floor = np.std(signal) * 0.5
        rectified = np.abs(signal)
        beat_indices, _ = scipy.signal.find_peaks(
            rectified,
            height=amplitude_floor,
            distance=refractory_samples,
        )
        return beat_indices.tolist()
    except Exception as e:
        logger.error(f"QRS detection error: {str(e)}")
        return []
def _calculate_intervals(self, rr_intervals: np.ndarray, signal: np.ndarray,
                         qrs_peaks: List[int], sampling_rate: int) -> Dict[str, Optional[float]]:
    """Derive interval values (in ms) and heart rate from the RR intervals.

    Only the RR interval and heart rate come from the data. PR and QRS
    are fixed nominal values, and QT is a simplified estimate from the
    mean RR interval. ``signal``, ``qrs_peaks`` and ``sampling_rate``
    are accepted but not used.

    Args:
        rr_intervals: RR intervals in seconds.
        signal: Preprocessed lead samples.
        qrs_peaks: Sample indices of the detected QRS peaks.
        sampling_rate: Samples per second.

    Returns:
        Dict with "rr_ms", "pr_ms", "qrs_ms", "qt_ms", "qtc_ms" and
        "heart_rate_bpm"; empty when there are no RR intervals or the
        calculation fails (failures are logged).
    """
    intervals = {}
    try:
        if len(rr_intervals) == 0:
            return intervals

        avg_rr = np.mean(rr_intervals)
        rr_is_positive = avg_rr > 0

        # Heart rate from RR intervals
        bpm = 60.0 / avg_rr if rr_is_positive else None

        # Nominal placeholders, not measurements
        nominal_pr = 0.16  # Normal PR interval ~160ms
        nominal_qrs = 0.08  # Normal QRS duration ~80ms

        # Simplified QT estimate based on the square root of the RR interval
        rr_root = np.sqrt(avg_rr)
        qt_estimate = rr_root * 0.4
        corrected_qt = (qt_estimate / rr_root) * 1000 if rr_is_positive else None

        intervals.update({
            "rr_ms": avg_rr * 1000,
            "pr_ms": nominal_pr * 1000,
            "qrs_ms": nominal_qrs * 1000,
            "qt_ms": qt_estimate * 1000,
            "qtc_ms": corrected_qt,
            "heart_rate_bpm": bpm
        })
    except Exception as e:
        logger.error(f"Interval calculation error: {str(e)}")
    return intervals
def _analyze_rhythm(self, rr_intervals: np.ndarray) -> Dict[str, Any]:
    """Describe how regular the rhythm is, based on RR-interval spread.

    Args:
        rr_intervals: RR intervals in seconds.

    Returns:
        Dict with "regularity", "rr_variability_ms" (standard deviation
        of the RR intervals in ms) and "primary_rhythm"; empty when
        there are no RR intervals or the analysis fails (failures are
        logged).
    """
    rhythm_info = {}
    try:
        if len(rr_intervals) > 0:
            spread = np.std(rr_intervals)
            centre = np.mean(rr_intervals)
            # Coefficient of variation of the RR intervals
            variation = spread / centre if centre > 0 else 0

            # First threshold the variation stays below decides the label
            regularity_bands = ((0.1, "regular"), (0.2, "slightly irregular"))
            regularity = next(
                (label for limit, label in regularity_bands if variation < limit),
                "irregular",
            )

            primary = "sinus" if variation < 0.15 else "irregular"

            rhythm_info.update({
                "regularity": regularity,
                "rr_variability_ms": spread * 1000,  # Convert to ms
                "primary_rhythm": primary
            })
    except Exception as e:
        logger.error(f"Rhythm analysis error: {str(e)}")
    return rhythm_info
def _detect_arrhythmias(self, rr_intervals: np.ndarray, signal: np.ndarray,
                        qrs_peaks: List[int], sampling_rate: int) -> Dict[str, float]:
    """Score potential arrhythmias from RR-interval statistics.

    The scores are simple heuristics based on RR variability and mean
    heart rate. ``signal``, ``qrs_peaks`` and ``sampling_rate`` are
    accepted but not used.

    Args:
        rr_intervals: RR intervals in seconds.
        signal: Preprocessed lead samples.
        qrs_peaks: Sample indices of the detected QRS peaks.
        sampling_rate: Samples per second.

    Returns:
        Dict of scores keyed by arrhythmia name; empty when there are no
        RR intervals, and a fixed default set when the calculation fails
        (the failure is logged).
    """
    arrhythmia_probs = {}
    try:
        if len(rr_intervals) > 0:
            avg_rr = np.mean(rr_intervals)
            # Relative RR variability drives the fibrillation / normal scores
            variability = np.std(rr_intervals) / avg_rr

            # Atrial fibrillation detection (simplified): high variability
            arrhythmia_probs["atrial_fibrillation"] = (
                min(0.7, variability) if variability > 0.2 else 0.1
            )

            # Normal rhythm probability
            arrhythmia_probs["normal_rhythm"] = max(0.3, 1.0 - variability)

            # Tachycardia/Bradycardia detection
            bpm = 60.0 / avg_rr if avg_rr > 0 else 60
            arrhythmia_probs["tachycardia"] = (
                min(0.8, (bpm - 100) / 50) if bpm > 100 else 0.1
            )
            arrhythmia_probs["bradycardia"] = (
                min(0.8, (60 - bpm) / 30) if bpm < 60 else 0.1
            )

            # Set other arrhythmias to low probability
            arrhythmia_probs.update(
                atrial_flutter=0.05,
                ventricular_tachycardia=0.05,
                heart_block=0.05,
                premature_beats=0.1,
            )
    except Exception as e:
        logger.error(f"Arrhythmia detection error: {str(e)}")
        # Set default low probabilities
        arrhythmia_probs = dict(
            normal_rhythm=0.5,
            atrial_fibrillation=0.1,
            atrial_flutter=0.1,
            ventricular_tachycardia=0.1,
            heart_block=0.1,
            premature_beats=0.1,
        )
    return arrhythmia_probs
def _calculate_derived_features(self, signal: np.ndarray, qrs_peaks: List[int],
                                sampling_rate: int) -> Dict[str, Any]:
    """Calculate derived ECG features around the detected QRS peaks.

    Two features are produced:
      * "st_deviation_mv": mean/std of the average signal level in the
        window from each peak to 300 ms later (or the next peak, if that
        comes first). Needs more than two peaks.
      * "qrs_amplitude_mv": mean/std of the peak-to-peak amplitude in a
        +/-50 ms window around each peak.

    Args:
        signal: Preprocessed lead samples.
        qrs_peaks: Sample indices of the detected QRS peaks (list or array).
        sampling_rate: Samples per second.

    Returns:
        Dict holding whichever features could be computed; empty when
        there are no peaks or the calculation fails (failures are logged).
    """
    features = {}
    try:
        peaks = list(qrs_peaks)
        total_samples = len(signal)

        # ST segment analysis (simplified)
        if len(peaks) > 2:
            st_window = int(0.3 * sampling_rate)
            st_segments = []
            # Pair each peak with its successor directly; looking the peak
            # up again with list.index() made this loop quadratic.
            for peak, next_peak in zip(peaks, peaks[1:]):
                st_end = min(peak + st_window, next_peak)
                # Skip empty windows, which would average to NaN
                if peak < st_end < total_samples:
                    st_segments.append(np.mean(signal[peak:st_end]))

            if st_segments:
                features["st_deviation_mv"] = {
                    "mean": np.mean(st_segments),
                    "std": np.std(st_segments)
                }

        # QRS amplitude analysis
        if len(peaks) > 0:
            half_window = int(0.05 * sampling_rate)
            qrs_amplitudes = []
            for peak in peaks:
                window_start = max(0, peak - half_window)
                window_end = min(total_samples, peak + half_window)
                if window_end > window_start:
                    window = signal[window_start:window_end]
                    qrs_amplitudes.append(np.max(window) - np.min(window))

            if qrs_amplitudes:
                features["qrs_amplitude_mv"] = {
                    "mean": np.mean(qrs_amplitudes),
                    "std": np.std(qrs_amplitudes)
                }
    except Exception as e:
        logger.error(f"Derived features calculation error: {str(e)}")
    return features
def _calculate_ecg_confidence(self, result: ECGProcessingResult,
                              validation_result: Dict[str, Any]) -> float:
    """Score how trustworthy a processing result is.

    Each satisfied criterion contributes a fixed weight; the weights are
    summed and the total is capped at 1.0.

    Args:
        result: Processing result with signal and analysis fields filled in.
        validation_result: Output of the signal validation; only its
            "is_valid" entry is read.

    Returns:
        Confidence score, at most 1.0.
    """
    weighted_checks = [
        # Signal quality factors
        (bool(result.signal_data), 0.3),        # Signal data present
        (len(result.lead_names) >= 3, 0.2),     # Multiple leads available
        (result.sampling_rate > 200, 0.2),      # Adequate sampling rate
        (result.duration > 5.0, 0.1),           # Sufficient recording length
        # Validation factor: always counted, worth more when valid
        (True, 0.2 if validation_result["is_valid"] else 0.1),
        # Analysis completion factors
        (bool(result.intervals), 0.2),
        (bool(result.rhythm_info), 0.1),
    ]
    earned = [weight for passed, weight in weighted_checks if passed]
    return min(1.0, sum(earned))
def convert_to_ecg_schema(self, result: ECGProcessingResult) -> Dict[str, Any]:
    """Convert ECG processing result to schema format.

    Builds the schema objects from ``medical_schemas`` and returns their
    ``.dict()`` representations together with a short summary.

    Args:
        result: Processing result to convert.

    Returns:
        Dict with "metadata", "signal_data", "intervals",
        "rhythm_classification", "arrhythmia_probabilities",
        "derived_features", "confidence", "clinical_summary" and
        "recommendations". On any failure the error is logged and
        ``{"error": <message>}`` is returned instead.
    """
    try:
        # Create metadata
        metadata = MedicalDocumentMetadata(
            source_type="ECG",
            data_completeness=result.confidence_score
        )

        # Create confidence score
        # NOTE(review): model_confidence and data_quality are fixed values,
        # not derived from the analysis.
        confidence = ConfidenceScore(
            extraction_confidence=result.confidence_score,
            model_confidence=0.8,  # Assuming good analysis quality
            data_quality=0.9
        )

        # Create signal data
        signal_data = ECGSignalData(
            lead_names=result.lead_names,
            sampling_rate_hz=result.sampling_rate,
            signal_arrays=result.signal_data,
            duration_seconds=result.duration,
            num_samples=max(len(data) for data in result.signal_data.values()) if result.signal_data else 0
        )

        # Create intervals
        intervals = ECGIntervals(
            pr_ms=result.intervals.get("pr_ms"),
            qrs_ms=result.intervals.get("qrs_ms"),
            qt_ms=result.intervals.get("qt_ms"),
            qtc_ms=result.intervals.get("qtc_ms"),
            rr_ms=result.intervals.get("rr_ms")
        )

        # Heart rate is rounded to the nearest beat; plain int() would
        # truncate e.g. 74.99 down to 74.
        heart_rate = result.intervals.get("heart_rate_bpm")
        heart_rate_bpm = int(round(heart_rate)) if heart_rate else None

        # Create rhythm classification
        rhythm_classification = ECGRhythmClassification(
            primary_rhythm=result.rhythm_info.get("primary_rhythm"),
            rhythm_confidence=0.8,  # Assuming good analysis
            arrhythmia_types=[],
            heart_rate_bpm=heart_rate_bpm,
            heart_rate_regularity=result.rhythm_info.get("regularity")
        )

        # Create arrhythmia probabilities
        arrhythmia_probs = ECGArrhythmiaProbabilities(
            normal_rhythm=result.arrhythmia_analysis.get("normal_rhythm", 0.5),
            atrial_fibrillation=result.arrhythmia_analysis.get("atrial_fibrillation", 0.1),
            atrial_flutter=result.arrhythmia_analysis.get("atrial_flutter", 0.1),
            ventricular_tachycardia=result.arrhythmia_analysis.get("ventricular_tachycardia", 0.1),
            heart_block=result.arrhythmia_analysis.get("heart_block", 0.1),
            premature_beats=result.arrhythmia_analysis.get("premature_beats", 0.1)
        )

        # Create derived features
        # NOTE(review): "st_deviation_mv" (a mean/std dict) is passed to a
        # field named st_elevation_mm; confirm the schema expects this
        # shape and unit.
        derived_features = ECGDerivedFeatures(
            st_elevation_mm=result.derived_features.get("st_deviation_mv", {}),
            st_depression_mm=None,
            t_wave_abnormalities=[],
            q_wave_indicators=[],
            voltage_criteria=result.derived_features.get("qrs_amplitude_mv", {}),
            axis_deviation=None
        )

        return {
            "metadata": metadata.dict(),
            "signal_data": signal_data.dict(),
            "intervals": intervals.dict(),
            "rhythm_classification": rhythm_classification.dict(),
            "arrhythmia_probabilities": arrhythmia_probs.dict(),
            "derived_features": derived_features.dict(),
            "confidence": confidence.dict(),
            "clinical_summary": f"ECG analysis completed for {len(result.lead_names)} leads over {result.duration:.1f} seconds",
            "recommendations": ["Review by cardiologist recommended"] if result.confidence_score < 0.8 else []
        }
    except Exception as e:
        logger.error(f"ECG schema conversion error: {str(e)}")
        return {"error": str(e)}
# Export main classes: these are the only names provided by
# ``from <this module> import *``.
__all__ = [
"ECGSignalProcessor",
"ECGProcessingResult"
]