import wfdb
from wfdb import processing
import numpy as np
import joblib
import pywt
import os
import cv2
from pdf2image import convert_from_path
import warnings
import pickle
from scipy import signal as sg

warnings.filterwarnings('ignore')


def extract_hrv_features(rr_intervals):
    """
    Extract heart rate variability features from RR intervals.

    Args:
        rr_intervals (numpy.ndarray): RR intervals in seconds

    Returns:
        list: Four HRV features [sdnn, rmssd, pnn50, tri_index]
    """
    if len(rr_intervals) < 2:
        return [0, 0, 0, 0]

    sdnn = np.std(rr_intervals)
    diff_rr = np.diff(rr_intervals)
    rmssd = np.sqrt(np.mean(diff_rr ** 2)) if len(diff_rr) > 0 else 0
    # pNN50: percentage of successive RR differences larger than 50 ms.
    pnn50 = 100 * np.sum(np.abs(diff_rr) > 0.05) / len(diff_rr) if len(diff_rr) > 0 else 0

    if len(rr_intervals) > 2:
        # HRV triangular index: total beat count divided by the height of the
        # RR histogram, using the conventional 1/128 s bin width.
        bin_width = 1 / 128
        bins = np.arange(min(rr_intervals), max(rr_intervals) + bin_width, bin_width)
        if len(bins) > 1:
            n, _ = np.histogram(rr_intervals, bins=bins)
            tri_index = len(rr_intervals) / np.max(n) if np.max(n) > 0 else 0
        else:
            # All RR intervals identical -> a single bin edge, which would make
            # np.histogram raise; the index is meaningless here anyway.
            tri_index = 0
    else:
        tri_index = 0

    return [sdnn, rmssd, pnn50, tri_index]


def extract_qrs_features(signal, r_peaks, fs):
    """
    Extract QRS complex features from ECG signal and detected R peaks.

    Args:
        signal (numpy.ndarray): ECG signal
        r_peaks (numpy.ndarray): Array of R peak indices
        fs (int): Sampling frequency in Hz

    Returns:
        list: Three QRS features [qrs_width_mean, qrs_width_std, qrs_amplitude_mean]
    """
    if len(r_peaks) < 2:
        return [0, 0, 0]

    qrs_width = []
    for r_pos in r_peaks:
        # Look for the Q and S troughs within +/-100 ms of each R peak.
        window_before = max(0, r_pos - int(0.1 * fs))
        window_after = min(len(signal) - 1, r_pos + int(0.1 * fs))

        if r_pos > window_before:
            q_pos = window_before + np.argmin(signal[window_before:r_pos])
        else:
            q_pos = window_before

        if r_pos < window_after:
            s_pos = r_pos + np.argmin(signal[r_pos:window_after])
        else:
            s_pos = r_pos

        if s_pos > q_pos:
            qrs_width.append((s_pos - q_pos) / fs)

    qrs_width_mean = np.mean(qrs_width) if qrs_width else 0
    qrs_width_std = np.std(qrs_width) if qrs_width else 0
    qrs_amplitude_mean = np.mean([signal[r] for r in r_peaks]) if r_peaks.size > 0 else 0
    return [qrs_width_mean, qrs_width_std, qrs_amplitude_mean]


def digitize_ecg_from_pdf(pdf_path, output_file=None):
    """
    Process an ECG PDF file and convert it to a .dat signal file.

    Args:
        pdf_path (str): Path to the ECG PDF file
        output_file (str, optional): Path to save the output .dat file

    Returns:
        tuple: (path to the created .dat file, list of paths to segment files)

    Raises:
        ValueError: If a waveform strip contains no detectable trace.
    """
    if output_file is None:
        output_file = 'calibrated_ecg.dat'

    # Render the first PDF page to a temporary grayscale image.
    images = convert_from_path(pdf_path)
    temp_image_path = 'temp_ecg_image.jpg'
    images[0].save(temp_image_path, 'JPEG')
    img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
    height, width = img.shape

    # Pixel-to-physical calibration constants; presumably measured for one
    # specific report layout -- TODO confirm against the source PDFs.
    calibration = {
        'seconds_per_pixel': 2.0 / 197.0,
        'mv_per_pixel': 1.0 / 78.8,
    }

    # Vertical crop boundaries (percent of page height) for the three strips.
    layer1_start = int(height * 35.35 / 100)
    layer1_end = int(height * 51.76 / 100)
    layer2_start = int(height * 51.82 / 100)
    layer2_end = int(height * 69.41 / 100)
    layer3_start = int(height * 69.47 / 100)
    layer3_end = int(height * 87.06 / 100)

    layers = [
        img[layer1_start:layer1_end, :],
        img[layer2_start:layer2_end, :],
        img[layer3_start:layer3_end, :]
    ]

    signals = []
    time_points = []
    layer_duration = 10.0  # each strip spans 10 seconds of signal

    for i, layer in enumerate(layers):
        # Dark trace on light background -> invert-threshold, then take the
        # largest external contour as the waveform.
        _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            raise ValueError(f"No waveform trace found in ECG strip {i + 1}")
        waveform_contour = max(contours, key=cv2.contourArea)

        sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
        x_coords = np.array([point[0][0] for point in sorted_contour])
        y_coords = np.array([point[0][1] for point in sorted_contour])

        # Baseline assumed at 60% of strip height -- TODO confirm for layout.
        isoelectric_line_y = layer.shape[0] * 0.6
        x_min, x_max = np.min(x_coords), np.max(x_coords)
        x_span = x_max - x_min
        if x_span > 0:
            time = (x_coords - x_min) / x_span * layer_duration
        else:
            # Degenerate single-column contour: avoid 0/0 producing NaNs.
            time = np.zeros_like(x_coords, dtype=float)

        # Image y grows downward, so (baseline - y) gives upward-positive mV;
        # remove the DC offset per strip.
        signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
        signal_mv = signal_mv - np.mean(signal_mv)

        time_points.append(time)
        signals.append(signal_mv)

    # Resample the three strips onto one uniform 500 Hz time base.
    total_duration = layer_duration * len(layers)
    sampling_frequency = 500
    num_samples = int(total_duration * sampling_frequency)
    combined_time = np.linspace(0, total_duration, num_samples)
    combined_signal = np.zeros(num_samples)

    for i, (seg_time, seg_signal) in enumerate(zip(time_points, signals)):
        start_time = i * layer_duration
        mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
        relevant_times = combined_time[mask]
        interpolated_signal = np.interp(relevant_times, start_time + seg_time, seg_signal)
        combined_signal[mask] = interpolated_signal

    combined_signal = combined_signal - np.mean(combined_signal)

    # Rescale implausible amplitudes (outside 0.5-4 mV peak) to ~2 mV peak.
    signal_peak = np.max(np.abs(combined_signal))
    target_amplitude = 2.0
    if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
        scaling_factor = target_amplitude / signal_peak
        combined_signal = combined_signal * scaling_factor

    # Store as int16 with a fixed ADC gain of 1000 counts/mV.
    adc_gain = 1000.0
    int_signal = (combined_signal * adc_gain).astype(np.int16)
    int_signal.tofile(output_file)

    if os.path.exists(temp_image_path):
        os.remove(temp_image_path)

    # Also write one 10 s .dat file per strip for segment-wise voting.
    segment_files = []
    samples_per_segment = int(layer_duration * sampling_frequency)
    base_name = os.path.splitext(output_file)[0]
    for i in range(3):
        start_idx = i * samples_per_segment
        end_idx = (i + 1) * samples_per_segment
        segment = combined_signal[start_idx:end_idx]
        segment_file = f"{base_name}_segment{i+1}.dat"
        (segment * adc_gain).astype(np.int16).tofile(segment_file)
        segment_files.append(segment_file)

    return output_file, segment_files


def _select_lead(signal_all_leads):
    """
    Pick the analysis lead from a (samples, leads) array.

    Prefers Lead II (column index 1) when available, falling back to the
    first column. Shared by split_dat_into_segments and classify_new_ecg so
    both always analyze the same lead.

    Args:
        signal_all_leads (numpy.ndarray): 2-D array, one column per lead

    Returns:
        numpy.ndarray: 1-D signal of the selected lead
    """
    if signal_all_leads.shape[1] == 1:
        lead_index = 0
    else:
        lead_priority = [1, 0]  # Try Lead II (index 1), then I (index 0)
        lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
    return signal_all_leads[:, lead_index]


def split_dat_into_segments(file_path, segment_duration=10.0):
    """
    Split a DAT file into equal segments.

    Args:
        file_path (str): Path to the DAT file (without extension)
        segment_duration (float): Duration of each segment in seconds

    Returns:
        list: Paths to the segment files
    """
    signal_all_leads, fs = load_dat_signal(file_path)
    signal = _select_lead(signal_all_leads)

    samples_per_segment = int(segment_duration * fs)
    total_samples = len(signal)
    num_segments = total_samples // samples_per_segment  # drop the partial tail

    segment_files = []
    base_name = os.path.splitext(file_path)[0]
    for i in range(num_segments):
        start_idx = i * samples_per_segment
        end_idx = (i + 1) * samples_per_segment
        segment = signal[start_idx:end_idx]
        segment_file = f"{base_name}_segment{i+1}.dat"
        segment.reshape(-1, 1).tofile(segment_file)
        segment_files.append(segment_file)
    return segment_files


def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16):
    """
    Load a DAT file containing ECG signal data.

    Args:
        file_path (str): Path to the DAT file (without extension)
        n_leads (int): Number of leads in the signal
        n_samples (int): Number of samples per lead
        dtype: Data type of the signal

    Returns:
        tuple: (numpy array of signal data, sampling frequency)
    """
    dat_path = file_path if file_path.endswith('.dat') else file_path + '.dat'
    raw = np.fromfile(dat_path, dtype=dtype)

    if raw.size != n_leads * n_samples:
        if raw.size == n_samples:
            return raw.reshape(n_samples, 1), 500

        # NOTE(review): since raw.size % 1 == 0 always holds, this fallback
        # always matches on the first iteration and loads the file as a
        # single-lead column. The later lead counts and the final reshape
        # below are effectively unreachable; preserved as-is because the
        # trained model expects this behavior.
        possible_leads = [1, 2, 3, 6, 12]
        for possible_lead_count in possible_leads:
            if raw.size % possible_lead_count == 0:
                actual_samples = raw.size // possible_lead_count
                return raw.reshape(actual_samples, possible_lead_count), 500
        return raw.reshape(-1, 1), 500

    return raw.reshape(n_samples, n_leads), 500


def extract_features_from_signal(signal):
    """
    Extract features from an ECG signal.

    Args:
        signal (numpy.ndarray): ECG signal

    Returns:
        list: Basic features extracted from the signal (32 features:
              8 statistics + 6 wavelet levels x 4 statistics each)
    """
    features = []
    # Time-domain statistics (8 features).
    features.append(np.mean(signal))
    features.append(np.std(signal))
    features.append(np.median(signal))
    features.append(np.min(signal))
    features.append(np.max(signal))
    features.append(np.percentile(signal, 25))
    features.append(np.percentile(signal, 75))
    features.append(np.mean(np.diff(signal)))

    # Wavelet statistics: a 5-level db4 decomposition yields 6 coefficient
    # arrays, contributing 4 statistics each (24 features).
    coeffs = pywt.wavedec(signal, 'db4', level=5)
    for coeff in coeffs:
        features.append(np.mean(coeff))
        features.append(np.std(coeff))
        features.append(np.min(coeff))
        features.append(np.max(coeff))
    return features


def classify_new_ecg(file_path, model):
    """
    Classify a new ECG file.

    Args:
        file_path (str): Path to the ECG file (without extension)
        model: The trained model for classification

    Returns:
        str: Classification result ("Normal", "Abnormal", or error message)
    """
    signal_all_leads, fs = load_dat_signal(file_path)
    signal = _select_lead(signal_all_leads)

    # Z-score normalize; guard against a flat signal (std == 0).
    signal_std = np.std(signal)
    signal = signal - np.mean(signal)
    if signal_std > 0:
        signal = signal / signal_std

    try:
        r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
    except Exception:
        # Detector failure is treated as "no beats found", not a hard error.
        r_peaks = np.array([])

    if len(r_peaks) < 2:
        # Not enough beats for rhythm features: pad the basic features to the
        # fixed 45-element vector the model was trained on.
        basic_features = extract_features_from_signal(signal)
        record_features = basic_features + [0] * (45 - len(basic_features))
    else:
        rr_intervals = np.diff(r_peaks) / fs
        # NOTE(review): these are R-R gaps in samples, so the /fs statistics
        # below duplicate the RR-interval statistics; kept to preserve the
        # feature layout the model was trained on.
        qrs_durations = np.diff(r_peaks)

        record_features = []
        basic_features = extract_features_from_signal(signal)
        record_features.extend(basic_features)
        record_features.extend([
            len(r_peaks),
            np.mean(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.std(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.median(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.mean(qrs_durations) / fs if len(qrs_durations) > 0 else 0,
            np.std(qrs_durations) / fs if len(qrs_durations) > 0 else 0
        ])
        record_features.extend(extract_hrv_features(rr_intervals))
        record_features.extend(extract_qrs_features(signal, r_peaks, fs))

        # Frequency-domain HRV: resample the RR series at 4 Hz and integrate
        # Welch PSD over the standard LF/HF bands.
        if len(rr_intervals) >= 4:
            try:
                rr_times = np.cumsum(rr_intervals)
                rr_times = np.insert(rr_times, 0, 0)
                fs_interp = 4.0
                t_interp = np.arange(0, rr_times[-1], 1 / fs_interp)
                rr_interp = np.interp(t_interp, rr_times[:-1], rr_intervals)
                freq, psd = sg.welch(rr_interp, fs=fs_interp, nperseg=min(256, len(rr_interp)))
                lf_mask = (freq >= 0.04) & (freq < 0.15)
                hf_mask = (freq >= 0.15) & (freq < 0.4)
                lf_power = np.trapz(psd[lf_mask], freq[lf_mask]) if np.any(lf_mask) else 0
                hf_power = np.trapz(psd[hf_mask], freq[hf_mask]) if np.any(hf_mask) else 0
                lf_hf_ratio = lf_power / hf_power if hf_power > 0 else 0
                normalized_lf = lf_power / (lf_power + hf_power) if (lf_power + hf_power) > 0 else 0
            except Exception:
                lf_power = hf_power = lf_hf_ratio = normalized_lf = 0
        else:
            lf_power = hf_power = lf_hf_ratio = normalized_lf = 0

        # NOTE(review): with all groups present the vector reaches 49 entries,
        # so the truncation below always discards these 4 spectral features;
        # preserved because the model expects exactly 45 inputs.
        record_features.extend([lf_power, hf_power, lf_hf_ratio, normalized_lf])

    # Force the fixed 45-feature vector length.
    if len(record_features) < 45:
        record_features.extend([0] * (45 - len(record_features)))
    elif len(record_features) > 45:
        record_features = record_features[:45]

    prediction = model.predict([record_features])[0]
    return "Abnormal" if prediction == 1 else "Normal"


def classify_ecg(file_path, model, is_pdf=False):
    """
    Wrapper function that handles both PDF and DAT ECG files with segment voting.

    Args:
        file_path (str): Path to the ECG file (.pdf or without extension for .dat)
        model: The trained model for classification
        is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)

    Returns:
        str: Classification result ("Normal", "Abnormal", or error message)
    """
    try:
        if model is None:
            return "Error: Model not loaded. Please check model compatibility."

        if is_pdf:
            base_name = os.path.splitext(os.path.basename(file_path))[0]
            output_dat = f"{base_name}_digitized.dat"
            dat_path, segment_files = digitize_ecg_from_pdf(
                pdf_path=file_path, output_file=output_dat
            )
        else:
            segment_files = split_dat_into_segments(file_path)
            if not segment_files:
                # Recording shorter than one segment: classify it whole.
                return classify_new_ecg(file_path, model)

        # Classify each 10 s segment independently, then vote.
        segment_results = []
        for segment_file in segment_files:
            segment_path = os.path.splitext(segment_file)[0]
            segment_results.append(classify_new_ecg(segment_path, model))
            try:
                os.remove(segment_file)  # best-effort cleanup of temp segments
            except OSError:
                pass

        if segment_results:
            normal_count = segment_results.count("Normal")
            abnormal_count = segment_results.count("Abnormal")
            if abnormal_count > normal_count:
                final_result = "Abnormal"
            elif normal_count > abnormal_count:
                final_result = "Normal"
            else:
                final_result = "Inconclusive"
            return final_result
        else:
            return "Error: No valid segments to classify"
    except Exception as e:
        return f"Classification error: {str(e)}"