# Deploy_El7a2ny_Application / ECG / ECG_Classify.py
# Hussein El-Hadidy
# Latest ECG
# 676f928
import wfdb
from wfdb import processing
import numpy as np
import joblib
import pywt
import os
import cv2
from pdf2image import convert_from_path
import warnings
import pickle
from scipy import signal as sg
# NOTE(review): this installs a process-wide "ignore" filter at import time,
# silencing every warning in the interpreter (not only those raised by this
# module). Consider scoping it with warnings.catch_warnings() around the
# specific calls that need it.
warnings.filterwarnings('ignore')
def extract_hrv_features(rr_intervals):
"""
Extract heart rate variability features from RR intervals.
Args:
rr_intervals (numpy.ndarray): RR intervals in seconds
Returns:
list: Four HRV features [sdnn, rmssd, pnn50, tri_index]
"""
if len(rr_intervals) < 2:
return [0, 0, 0, 0]
sdnn = np.std(rr_intervals)
diff_rr = np.diff(rr_intervals)
rmssd = np.sqrt(np.mean(diff_rr**2)) if len(diff_rr) > 0 else 0
pnn50 = 100 * np.sum(np.abs(diff_rr) > 0.05) / len(diff_rr) if len(diff_rr) > 0 else 0
if len(rr_intervals) > 2:
bin_width = 1/128
bins = np.arange(min(rr_intervals), max(rr_intervals) + bin_width, bin_width)
n, _ = np.histogram(rr_intervals, bins=bins)
tri_index = len(rr_intervals) / np.max(n) if np.max(n) > 0 else 0
else:
tri_index = 0
return [sdnn, rmssd, pnn50, tri_index]
def extract_qrs_features(signal, r_peaks, fs):
    """
    Extract QRS complex features from an ECG signal and its detected R peaks.

    For each R peak, Q is the minimum of the signal in the 100 ms before the
    peak, and S is the minimum in the 100 ms after it. The QRS width is the
    Q-to-S distance in seconds.

    Args:
        signal (array-like): ECG signal
        r_peaks (array-like): R peak sample indices (ndarray or list)
        fs (int): Sampling frequency in Hz
    Returns:
        list: Three QRS features [qrs_width_mean, qrs_width_std, qrs_amplitude_mean].
              All three are 0 when fewer than two R peaks are given.
    """
    signal = np.asarray(signal)
    r_peaks = np.asarray(r_peaks)
    if len(r_peaks) < 2:
        return [0, 0, 0]

    half_window = int(0.1 * fs)  # 100 ms search window on each side of R
    last_index = len(signal) - 1
    qrs_width = []
    for r_pos in r_peaks:
        window_before = max(0, r_pos - half_window)
        window_after = min(last_index, r_pos + half_window)
        if r_pos > window_before:
            q_pos = window_before + np.argmin(signal[window_before:r_pos])
        else:
            q_pos = window_before
        if r_pos < window_after:
            # The slice end is exclusive, so the sample at window_after itself
            # is not searched.
            s_pos = r_pos + np.argmin(signal[r_pos:window_after])
        else:
            s_pos = r_pos
        if s_pos > q_pos:
            qrs_width.append((s_pos - q_pos) / fs)

    qrs_width_mean = np.mean(qrs_width) if qrs_width else 0
    qrs_width_std = np.std(qrs_width) if qrs_width else 0
    # At least two peaks exist here, so the mean amplitude is always defined.
    qrs_amplitude_mean = np.mean(signal[r_peaks])
    return [qrs_width_mean, qrs_width_std, qrs_amplitude_mean]
def digitize_ecg_from_pdf(pdf_path, output_file=None):
    """
    Digitize the rhythm strips of an ECG PDF into int16 .dat signal files.

    Steps:
      1. Rasterize the first page of the PDF.
      2. Crop three horizontal strips, each at a fixed fraction of the page height.
      3. Trace the largest dark contour in each strip as that strip's waveform.
      4. Treat each strip as 10 s of signal and resample all three onto one
         500 Hz grid (30 s total).
      5. Mean-centre the result and rescale it to a 2 mV peak if its peak lies
         outside 0.5-4 mV.
      6. Write it as raw int16 samples, value = mV * 1000.

    Args:
        pdf_path (str): Path to the ECG PDF file
        output_file (str, optional): Path of the combined .dat file
            (defaults to 'calibrated_ecg.dat')
    Returns:
        tuple: (path to the combined .dat file, list of per-strip segment file
               paths named '<output base>_segment<k>.dat', one per strip)
    Raises:
        ValueError: if the PDF has no pages, the rasterized page cannot be
            read back, or no usable waveform is found in a strip.
    """
    import tempfile  # local import: only this function needs a temp file

    if output_file is None:
        output_file = 'calibrated_ecg.dat'

    # Only the first page is used, so do not rasterize the rest.
    images = convert_from_path(pdf_path, first_page=1, last_page=1)
    if not images:
        raise ValueError(f"No pages found in PDF: {pdf_path}")

    # Keep the JPEG round-trip so the pixels reaching the threshold below are
    # unchanged. A unique temp path avoids clashes between concurrent calls,
    # and finally guarantees the file is removed even if reading fails.
    fd, temp_image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        images[0].save(temp_image_path, 'JPEG')
        img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
    finally:
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
    if img is None:
        raise ValueError(f"Could not read the rasterized first page of {pdf_path}")

    height = img.shape[0]
    # Vertical scale: 78.8 px per mV. (A nominal horizontal scale of 197 px per
    # 2 s exists but is not used: each strip's traced extent is stretched onto
    # layer_duration instead.)
    mv_per_pixel = 1.0 / 78.8

    # Vertical extent of the three rhythm strips, as percentages of page height.
    strip_bounds_pct = [(35.35, 51.76), (51.82, 69.41), (69.47, 87.06)]
    layers = [img[int(height * top / 100):int(height * bottom / 100), :]
              for top, bottom in strip_bounds_pct]

    layer_duration = 10.0
    time_points = []
    signals = []
    for i, layer in enumerate(layers):
        _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            raise ValueError(f"No waveform found in ECG strip {i + 1} of {pdf_path}")
        waveform_contour = max(contours, key=cv2.contourArea)
        sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
        x_coords = np.array([point[0][0] for point in sorted_contour])
        y_coords = np.array([point[0][1] for point in sorted_contour])
        x_min, x_max = np.min(x_coords), np.max(x_coords)
        if x_max == x_min:
            raise ValueError(f"Degenerate (zero-width) waveform in ECG strip {i + 1} of {pdf_path}")
        # Stretch the traced horizontal extent onto the strip's nominal duration.
        time = (x_coords - x_min) / (x_max - x_min) * layer_duration
        # Image y grows downward, hence baseline - y. The constant baseline
        # offset cancels out in the mean-centring on the next line.
        isoelectric_line_y = layer.shape[0] * 0.6
        signal_mv = (isoelectric_line_y - y_coords) * mv_per_pixel
        signal_mv = signal_mv - np.mean(signal_mv)
        time_points.append(time)
        signals.append(signal_mv)

    sampling_frequency = 500
    total_duration = layer_duration * len(layers)
    num_samples = int(total_duration * sampling_frequency)
    # Uniform grid over [0, total_duration) with step exactly 1/fs.
    # endpoint=False keeps the step exact and strip boundaries on sample
    # boundaries. With the endpoint included, the last sample fell outside
    # every strip mask and stayed 0.
    combined_time = np.linspace(0, total_duration, num_samples, endpoint=False)
    combined_signal = np.zeros(num_samples)
    for i, (time, strip_signal) in enumerate(zip(time_points, signals)):
        start_time = i * layer_duration
        mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
        combined_signal[mask] = np.interp(combined_time[mask], start_time + time, strip_signal)

    combined_signal = combined_signal - np.mean(combined_signal)
    # Rescale to a 2 mV peak only when the traced peak lies outside 0.5-4 mV.
    signal_peak = np.max(np.abs(combined_signal))
    target_amplitude = 2.0
    if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
        combined_signal = combined_signal * (target_amplitude / signal_peak)

    adc_gain = 1000.0  # stored units: mV * 1000 (microvolts), truncated to int16
    int_signal = (combined_signal * adc_gain).astype(np.int16)
    int_signal.tofile(output_file)

    # One segment file per strip (layer_duration seconds each).
    samples_per_segment = int(layer_duration * sampling_frequency)
    base_name = os.path.splitext(output_file)[0]
    segment_files = []
    for i in range(len(layers)):
        segment_file = f"{base_name}_segment{i+1}.dat"
        int_signal[i * samples_per_segment:(i + 1) * samples_per_segment].tofile(segment_file)
        segment_files.append(segment_file)
    return output_file, segment_files
def split_dat_into_segments(file_path, segment_duration=10.0):
    """
    Write consecutive fixed-length chunks of one ECG lead to separate DAT files.

    The record is read with load_dat_signal. A single-lead record uses its only
    column; otherwise column 1 (Lead II) is preferred over column 0 (Lead I).
    Only whole chunks are written, so trailing samples that do not fill a
    chunk are dropped.

    Args:
        file_path (str): Path to the DAT file (without extension)
        segment_duration (float): Length of each chunk in seconds
    Returns:
        list: Paths of the written files, '<base>_segment<k>.dat' (k from 1);
              empty if the lead is shorter than one chunk.
    """
    leads, fs = load_dat_signal(file_path)
    lead_count = leads.shape[1]
    if lead_count == 1:
        chosen = 0
    else:
        # Try Lead II (index 1), then Lead I (index 0).
        chosen = next((idx for idx in (1, 0) if idx < lead_count), 0)
    lead = leads[:, chosen]

    chunk = int(segment_duration * fs)
    stem = os.path.splitext(file_path)[0]
    written = []
    for k in range(len(lead) // chunk):
        out_path = f"{stem}_segment{k + 1}.dat"
        lead[k * chunk:(k + 1) * chunk].reshape(-1, 1).tofile(out_path)
        written.append(out_path)
    return written
def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16, fs=500):
    """
    Load a headerless binary DAT file of ECG samples.

    The file is read as a flat array of `dtype` values and interpreted as:
      * an (n_samples, n_leads) array, row-major with leads interleaved per
        sample, when it holds exactly n_leads * n_samples values;
      * a single lead (one column) otherwise.

    NOTE(review): without a header the lead count of an unexpected-length file
    cannot be inferred reliably, because any size is divisible by 1. So a
    multi-lead file whose length differs from n_leads * n_samples is read as
    ONE column of interleaved samples.

    Args:
        file_path (str): Path to the DAT file, with or without the '.dat' extension
        n_leads (int): Expected number of leads
        n_samples (int): Expected number of samples per lead
        dtype: NumPy dtype of the stored samples
        fs (int): Sampling frequency to report in Hz (the file does not store it)
    Returns:
        tuple: (2-D numpy array of shape (samples, leads), sampling frequency)
    Raises:
        ValueError: if the file contains no samples.
    """
    dat_path = file_path if file_path.endswith('.dat') else file_path + '.dat'
    raw = np.fromfile(dat_path, dtype=dtype)
    if raw.size == 0:
        raise ValueError(f"No samples found in {dat_path}")
    if raw.size == n_leads * n_samples:
        return raw.reshape(n_samples, n_leads), fs
    # Any other length (including exactly n_samples values) is one lead.
    return raw.reshape(-1, 1), fs
def extract_features_from_signal(signal):
    """
    Compute summary statistics of an ECG signal and its wavelet decomposition.

    Args:
        signal (numpy.ndarray): ECG signal
    Returns:
        list: 32 features. The first 8 are mean, std, median, min, max, 25th
              and 75th percentiles, and mean first difference of the signal.
              Then, for each of the 6 coefficient arrays of a level-5 'db4'
              wavelet decomposition, its mean, std, min and max.
    """
    summary = [
        np.mean(signal),
        np.std(signal),
        np.median(signal),
        np.min(signal),
        np.max(signal),
        np.percentile(signal, 25),
        np.percentile(signal, 75),
        np.mean(np.diff(signal)),
    ]
    # level=5 yields 6 coefficient arrays; four statistics per array, in order.
    wavelet_stats = [
        stat(band)
        for band in pywt.wavedec(signal, 'db4', level=5)
        for stat in (np.mean, np.std, np.min, np.max)
    ]
    return summary + wavelet_stats
def classify_new_ecg(file_path, model):
    """
    Classify a single ECG record as "Normal" or "Abnormal".

    Processing steps:
      1. Load the record with load_dat_signal.
      2. Select one lead: the only column, or column 1 (Lead II) for
         multi-lead records.
      3. Z-score normalize the lead.
      4. Detect R peaks with wfdb's GQRS detector.
      5. Pass a fixed 45-value feature vector to ``model.predict``.

    The feature vector is laid out as:
      * [0:32]  signal/wavelet statistics from extract_features_from_signal
                (32 values per its docstring)
      * [32:38] beat count, mean/std/median RR interval (s), and mean/std RR
                recomputed from sample counts (see note in the code)
      * [38:42] time-domain HRV features from extract_hrv_features
      * [42:45] QRS features from extract_qrs_features

    When fewer than two R peaks are found, everything after [0:32] is zero.

    Args:
        file_path (str): Path to the ECG file (without extension)
        model: Trained classifier exposing ``predict``. A prediction equal to 1
            is reported as "Abnormal", anything else as "Normal".
    Returns:
        str: "Normal" or "Abnormal"
    Raises:
        ValueError: if the selected lead is empty or constant and so cannot be
            normalized. Errors from loading the file or from ``model.predict``
            propagate to the caller.
    """
    signal_all_leads, fs = load_dat_signal(file_path)
    # Single-lead record: use it; otherwise prefer Lead II (index 1).
    lead_index = 1 if signal_all_leads.shape[1] > 1 else 0
    signal = signal_all_leads[:, lead_index]

    std = np.std(signal)
    if not np.isfinite(std) or std == 0:
        # Dividing by a zero/NaN std would turn every sample into NaN/inf and
        # feed meaningless features to the model.
        raise ValueError(f"ECG lead in {file_path} is empty or flat; cannot normalize")
    signal = (signal - np.mean(signal)) / std

    try:
        r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
    except Exception:
        # Best effort: a detector failure is treated like "no beats found",
        # so only signal-level features are used.
        r_peaks = np.array([])

    record_features = list(extract_features_from_signal(signal))
    if len(r_peaks) >= 2:
        rr_samples = np.diff(r_peaks)    # RR intervals in samples
        rr_intervals = rr_samples / fs   # RR intervals in seconds
        record_features.extend([
            len(r_peaks),
            np.mean(rr_intervals),
            np.std(rr_intervals),
            np.median(rr_intervals),
            # NOTE(review): these were previously named "qrs_durations", but
            # they are R-to-R distances, so they duplicate mean/std RR.
            # Presumably kept because the trained model expects this exact
            # layout; confirm before changing.
            np.mean(rr_samples) / fs,
            np.std(rr_samples) / fs,
        ])
        record_features.extend(extract_hrv_features(rr_intervals))
        record_features.extend(extract_qrs_features(signal, r_peaks, fs))
        # NOTE(review): the vector already holds 32 + 6 + 4 + 3 = 45 values
        # here, so the 4 frequency-domain features appended below are always
        # removed by the truncation that follows. Confirm against the
        # training pipeline.
        record_features.extend(_frequency_hrv_features(rr_intervals))

    n_features = 45
    # Zero-pad or truncate to the fixed length of 45 used by this pipeline.
    record_features = (record_features + [0] * n_features)[:n_features]
    prediction = model.predict([record_features])[0]
    return "Abnormal" if prediction == 1 else "Normal"


def _frequency_hrv_features(rr_intervals):
    """
    Return [lf_power, hf_power, lf_hf_ratio, normalized_lf] for RR intervals (s).

    The RR series is resampled at 4 Hz. Its Welch PSD is then integrated over
    the LF (0.04-0.15 Hz) and HF (0.15-0.4 Hz) bands. Returns zeros when fewer
    than four intervals are given or the spectral estimate cannot be computed.
    """
    if len(rr_intervals) < 4:
        return [0, 0, 0, 0]
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    try:
        rr_times = np.insert(np.cumsum(rr_intervals), 0, 0)
        fs_interp = 4.0
        t_interp = np.arange(0, rr_times[-1], 1 / fs_interp)
        # Each interval is placed at the time at which it starts.
        rr_interp = np.interp(t_interp, rr_times[:-1], rr_intervals)
        freq, psd = sg.welch(rr_interp, fs=fs_interp, nperseg=min(256, len(rr_interp)))
        lf_mask = (freq >= 0.04) & (freq < 0.15)
        hf_mask = (freq >= 0.15) & (freq < 0.4)
        lf_power = trapezoid(psd[lf_mask], freq[lf_mask]) if np.any(lf_mask) else 0
        hf_power = trapezoid(psd[hf_mask], freq[hf_mask]) if np.any(hf_mask) else 0
    except Exception:
        # Best effort: spectral features are optional; fall back to zeros.
        return [0, 0, 0, 0]
    lf_hf_ratio = lf_power / hf_power if hf_power > 0 else 0
    total_power = lf_power + hf_power
    normalized_lf = lf_power / total_power if total_power > 0 else 0
    return [lf_power, hf_power, lf_hf_ratio, normalized_lf]
def classify_ecg(file_path, model, is_pdf=False):
    """
    Classify an ECG record (PDF or DAT) by majority vote over its segments.

    Inputs are handled as follows:
      * PDF input is digitized with digitize_ecg_from_pdf, one segment per strip.
      * DAT input is split with split_dat_into_segments.
      * A DAT record too short to yield any segment is classified as a whole.

    Each segment is classified with classify_new_ecg and the majority label
    wins; a tie gives "Inconclusive". Segment files are deleted afterwards on a
    best-effort basis, even when classification fails part-way.

    Args:
        file_path (str): Path to the ECG file (.pdf, or without extension for .dat)
        model: The trained model for classification
        is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)
    Returns:
        str: "Normal", "Abnormal", "Inconclusive", or an error message starting
             with "Error:" or "Classification error:". Exceptions are reported
             through this string rather than raised.
    """
    segment_files = []
    try:
        if model is None:
            return "Error: Model not loaded. Please check model compatibility."
        if is_pdf:
            base_name = os.path.splitext(os.path.basename(file_path))[0]
            output_dat = f"{base_name}_digitized.dat"
            # NOTE(review): the combined digitized file is written to the
            # current working directory and is not deleted here.
            _, segment_files = digitize_ecg_from_pdf(
                pdf_path=file_path,
                output_file=output_dat
            )
        else:
            segment_files = split_dat_into_segments(file_path)
            if not segment_files:
                return classify_new_ecg(file_path, model)

        segment_results = [
            classify_new_ecg(os.path.splitext(segment_file)[0], model)
            for segment_file in segment_files
        ]
        if not segment_results:
            return "Error: No valid segments to classify"

        normal_count = segment_results.count("Normal")
        abnormal_count = segment_results.count("Abnormal")
        if abnormal_count > normal_count:
            return "Abnormal"
        if normal_count > abnormal_count:
            return "Normal"
        return "Inconclusive"
    except Exception as e:
        return f"Classification error: {str(e)}"
    finally:
        # Runs on every exit path, so segments are not left behind when a
        # segment fails to classify.
        for segment_file in segment_files:
            try:
                os.remove(segment_file)
            except OSError:
                pass  # best-effort cleanup: a missing or locked file is not an error