|
|
import wfdb |
|
|
from wfdb import processing |
|
|
import numpy as np |
|
|
import joblib |
|
|
import pywt |
|
|
import os |
|
|
import cv2 |
|
|
from pdf2image import convert_from_path |
|
|
import warnings |
|
|
import pickle |
|
|
from scipy import signal as sg |
|
|
# Silence all warnings globally so library deprecation/runtime messages
# (e.g. from wfdb/scipy) do not clutter console output.
warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
def extract_hrv_features(rr_intervals):
    """
    Compute heart-rate-variability statistics from a series of RR intervals.

    Args:
        rr_intervals (numpy.ndarray): RR intervals in seconds

    Returns:
        list: Four HRV features [sdnn, rmssd, pnn50, tri_index]
    """
    # Too few beats to measure any variability.
    if len(rr_intervals) < 2:
        return [0, 0, 0, 0]

    sdnn = np.std(rr_intervals)

    successive_diffs = np.diff(rr_intervals)
    n_diffs = len(successive_diffs)
    if n_diffs > 0:
        rmssd = np.sqrt(np.mean(np.square(successive_diffs)))
        # Proportion of successive differences exceeding 50 ms.
        pnn50 = 100 * np.count_nonzero(np.abs(successive_diffs) > 0.05) / n_diffs
    else:
        rmssd = 0
        pnn50 = 0

    # Triangular index: total beat count divided by the height of the RR
    # histogram, using the standard 1/128 s bin width.
    tri_index = 0
    if len(rr_intervals) > 2:
        width = 1 / 128
        edges = np.arange(min(rr_intervals), max(rr_intervals) + width, width)
        counts, _ = np.histogram(rr_intervals, bins=edges)
        peak = np.max(counts)
        if peak > 0:
            tri_index = len(rr_intervals) / peak

    return [sdnn, rmssd, pnn50, tri_index]
|
|
|
|
|
|
|
|
def extract_qrs_features(signal, r_peaks, fs):
    """
    Extract QRS complex features from an ECG signal and detected R peaks.

    Args:
        signal (numpy.ndarray): ECG signal
        r_peaks (numpy.ndarray): Array of R peak indices
        fs (int): Sampling frequency in Hz

    Returns:
        list: Three QRS features [qrs_width_mean, qrs_width_std, qrs_amplitude_mean]
    """
    if len(r_peaks) < 2:
        return [0, 0, 0]

    half_window = int(0.1 * fs)  # search +/- 100 ms around each R peak
    widths = []
    for r_idx in r_peaks:
        lo = max(0, r_idx - half_window)
        hi = min(len(signal) - 1, r_idx + half_window)

        # Q point: deepest sample in the window preceding the R peak.
        q_idx = lo + np.argmin(signal[lo:r_idx]) if r_idx > lo else lo
        # S point: deepest sample in the window following the R peak.
        s_idx = r_idx + np.argmin(signal[r_idx:hi]) if r_idx < hi else r_idx

        if s_idx > q_idx:
            widths.append((s_idx - q_idx) / fs)

    mean_width = np.mean(widths) if widths else 0
    std_width = np.std(widths) if widths else 0
    mean_amplitude = np.mean([signal[r] for r in r_peaks]) if r_peaks.size > 0 else 0

    return [mean_width, std_width, mean_amplitude]
|
|
|
|
|
|
|
|
def digitize_ecg_from_pdf(pdf_path, output_file=None):
    """
    Process an ECG PDF file and convert it to a .dat signal file.

    The first PDF page is rendered to a grayscale image, three horizontal
    strips (assumed to hold three consecutive 10 s ECG traces) are cropped,
    the waveform in each strip is traced via contour detection, and the
    three traces are resampled onto a common 500 Hz timeline and written
    out as raw int16 samples (signal is also split into three 10 s
    segment files).

    Args:
        pdf_path (str): Path to the ECG PDF file
        output_file (str, optional): Path to save the output .dat file

    Returns:
        tuple: (path to the created .dat file, list of paths to segment files)
    """
    if output_file is None:
        output_file = 'calibrated_ecg.dat'

    # Render the PDF; only the first page is used — assumes the ECG chart
    # is on page one.
    images = convert_from_path(pdf_path)
    temp_image_path = 'temp_ecg_image.jpg'
    images[0].save(temp_image_path, 'JPEG')

    img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
    height, width = img.shape

    # Pixel-to-physical calibration. NOTE(review): constants look tuned to
    # one specific report layout/render resolution (197 px per 2 s,
    # 78.8 px per mV) — confirm against the source PDF template.
    calibration = {
        'seconds_per_pixel': 2.0 / 197.0,
        'mv_per_pixel': 1.0 / 78.8,
    }

    # Vertical bounds (as percentages of page height) of the three ECG
    # strips. NOTE(review): layout-specific; verify for other report formats.
    layer1_start = int(height * 35.35 / 100)
    layer1_end = int(height * 51.76 / 100)
    layer2_start = int(height * 51.82 / 100)
    layer2_end = int(height * 69.41 / 100)
    layer3_start = int(height * 69.47 / 100)
    layer3_end = int(height * 87.06 / 100)

    layers = [
        img[layer1_start:layer1_end, :],
        img[layer2_start:layer2_end, :],
        img[layer3_start:layer3_end, :]
    ]

    signals = []
    time_points = []
    layer_duration = 10.0  # each strip is treated as a 10-second trace

    for i, layer in enumerate(layers):
        # Dark trace on light background: invert-threshold so the waveform
        # becomes white foreground for contour detection.
        _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)

        # Largest external contour is assumed to be the ECG trace.
        # NOTE(review): max() raises ValueError if no contour is found —
        # an entirely blank strip would crash here.
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        waveform_contour = max(contours, key=cv2.contourArea)

        # Order contour points left-to-right so they form a time series.
        sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
        x_coords = np.array([point[0][0] for point in sorted_contour])
        y_coords = np.array([point[0][1] for point in sorted_contour])

        # Baseline assumed at 60% of the strip height — TODO confirm against
        # the actual report layout.
        isoelectric_line_y = layer.shape[0] * 0.6

        # Map horizontal pixel span linearly onto [0, 10] seconds.
        # NOTE(review): divides by (x_max - x_min); a single-column contour
        # would divide by zero.
        x_min, x_max = np.min(x_coords), np.max(x_coords)
        time = (x_coords - x_min) / (x_max - x_min) * layer_duration

        # Convert pixel offset from the baseline to millivolts (image y grows
        # downward, hence baseline minus y), then remove the DC offset.
        signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
        signal_mv = signal_mv - np.mean(signal_mv)

        time_points.append(time)
        signals.append(signal_mv)

    # Resample all three strips onto one uniform 500 Hz timeline (30 s total).
    total_duration = layer_duration * len(layers)
    sampling_frequency = 500
    num_samples = int(total_duration * sampling_frequency)
    combined_time = np.linspace(0, total_duration, num_samples)
    combined_signal = np.zeros(num_samples)

    for i, (time, signal) in enumerate(zip(time_points, signals)):
        # Strip i occupies [i*10, (i+1)*10) seconds of the combined record.
        start_time = i * layer_duration
        mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
        relevant_times = combined_time[mask]
        interpolated_signal = np.interp(relevant_times, start_time + time, signal)
        combined_signal[mask] = interpolated_signal

    # Remove overall DC offset, then rescale only when the amplitude is
    # implausible for ECG (outside roughly 0.5–4 mV peak).
    combined_signal = combined_signal - np.mean(combined_signal)
    signal_peak = np.max(np.abs(combined_signal))
    target_amplitude = 2.0

    if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
        scaling_factor = target_amplitude / signal_peak
        combined_signal = combined_signal * scaling_factor

    # Store as int16 with a fixed ADC gain of 1000 units/mV.
    adc_gain = 1000.0
    int_signal = (combined_signal * adc_gain).astype(np.int16)
    int_signal.tofile(output_file)

    # Clean up the temporary page render.
    if os.path.exists(temp_image_path):
        os.remove(temp_image_path)

    # Also write each 10 s strip as its own .dat segment file for
    # per-segment classification downstream.
    segment_files = []
    samples_per_segment = int(layer_duration * sampling_frequency)

    base_name = os.path.splitext(output_file)[0]
    for i in range(3):
        start_idx = i * samples_per_segment
        end_idx = (i + 1) * samples_per_segment
        segment = combined_signal[start_idx:end_idx]

        segment_file = f"{base_name}_segment{i+1}.dat"
        (segment * adc_gain).astype(np.int16).tofile(segment_file)
        segment_files.append(segment_file)

    return output_file, segment_files
|
|
|
|
|
|
|
|
def split_dat_into_segments(file_path, segment_duration=10.0):
    """
    Split a DAT file into equal-length segments.

    Args:
        file_path (str): Path to the DAT file (without extension)
        segment_duration (float): Duration of each segment in seconds

    Returns:
        list: Paths to the segment files
    """
    all_leads, fs = load_dat_signal(file_path)

    # Prefer lead index 1 when several leads are present, otherwise lead 0.
    if all_leads.shape[1] == 1:
        chosen_lead = 0
    else:
        chosen_lead = next((idx for idx in [1, 0] if idx < all_leads.shape[1]), 0)

    lead_signal = all_leads[:, chosen_lead]

    # Only whole segments are written; any trailing remainder is dropped.
    seg_len = int(segment_duration * fs)
    n_segments = len(lead_signal) // seg_len

    stem = os.path.splitext(file_path)[0]
    created = []

    for seg_no in range(n_segments):
        chunk = lead_signal[seg_no * seg_len:(seg_no + 1) * seg_len]
        out_path = f"{stem}_segment{seg_no + 1}.dat"
        chunk.reshape(-1, 1).tofile(out_path)
        created.append(out_path)

    return created
|
|
|
|
|
|
|
|
def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16, fs=500):
    """
    Load a DAT file containing raw ECG signal data.

    The file is read as a flat array of `dtype` samples and reshaped:

    * exactly ``n_leads * n_samples`` values -> ``(n_samples, n_leads)``
    * exactly ``n_samples`` values           -> single-lead ``(n_samples, 1)``
    * anything else                          -> single-lead ``(N, 1)`` fallback

    Note: an earlier revision looped over candidate lead counts
    ``[1, 2, 3, 6, 12]`` to guess the layout, but since 1 divides every
    size the loop always chose single-lead on its first iteration — the
    remaining candidates and the post-loop fallback were dead code. The
    loop is replaced here with the equivalent direct single-lead fallback.
    Preferring larger lead counts instead would change behavior for
    existing single-lead files (e.g. a 15000-sample digitized recording is
    divisible by 12), so that was deliberately not done.

    Args:
        file_path (str): Path to the DAT file (with or without .dat extension)
        n_leads (int): Expected number of leads in the signal
        n_samples (int): Expected number of samples per lead
        dtype: NumPy data type of the stored samples
        fs (int): Sampling frequency in Hz to report (default 500, matching
            the previously hard-coded value)

    Returns:
        tuple: (2-D numpy array of shape (samples, leads), sampling frequency)
    """
    dat_path = file_path if file_path.endswith('.dat') else file_path + '.dat'

    raw = np.fromfile(dat_path, dtype=dtype)

    if raw.size == n_leads * n_samples:
        # Exact multi-lead match: sample-major, column-per-lead layout.
        return raw.reshape(n_samples, n_leads), fs

    if raw.size == n_samples:
        # Exact single-lead match.
        return raw.reshape(n_samples, 1), fs

    # Unknown size: treat the whole file as one lead.
    return raw.reshape(-1, 1), fs
|
|
|
|
|
|
|
|
def extract_features_from_signal(signal):
    """
    Compute basic statistical and wavelet features from an ECG signal.

    Args:
        signal (numpy.ndarray): ECG signal

    Returns:
        list: 32 features — 8 time-domain statistics followed by
        mean/std/min/max of each of the 6 'db4' wavelet coefficient arrays
    """
    # Time-domain summary statistics.
    stats = [
        np.mean(signal),
        np.std(signal),
        np.median(signal),
        np.min(signal),
        np.max(signal),
        np.percentile(signal, 25),
        np.percentile(signal, 75),
        np.mean(np.diff(signal)),
    ]

    # A 5-level db4 decomposition yields 6 coefficient arrays
    # (approximation + 5 detail levels); summarize each one.
    wavelet_stats = []
    for band in pywt.wavedec(signal, 'db4', level=5):
        wavelet_stats.extend([np.mean(band), np.std(band), np.min(band), np.max(band)])

    return stats + wavelet_stats
|
|
|
|
|
|
|
|
def classify_new_ecg(file_path, model):
    """
    Classify a single ECG recording as Normal or Abnormal.

    Builds a fixed-length 45-element feature vector (basic signal statistics,
    beat/RR statistics, time-domain HRV, QRS morphology and frequency-domain
    HRV features) and feeds it to the trained model. The feature order and
    count must match what the model was trained on.

    Args:
        file_path (str): Path to the ECG file (without extension)
        model: The trained model for classification

    Returns:
        str: Classification result ("Normal" or "Abnormal")
    """
    signal_all_leads, fs = load_dat_signal(file_path)

    # Prefer lead index 1 when several leads are present, otherwise lead 0.
    if signal_all_leads.shape[1] == 1:
        lead_index = 0
    else:
        lead_priority = [1, 0]
        lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)

    signal = signal_all_leads[:, lead_index]
    # Z-score normalization so amplitude calibration does not affect features.
    signal = (signal - np.mean(signal)) / np.std(signal)

    try:
        r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not swallowed; QRS detection can fail on very noisy/short signals.
        r_peaks = np.array([])

    if len(r_peaks) < 2:
        # Not enough beats for rhythm features: pad the basic features to 45.
        basic_features = extract_features_from_signal(signal)
        record_features = basic_features + [0] * (45 - len(basic_features))
    else:
        rr_intervals = np.diff(r_peaks) / fs
        # NOTE(review): these are inter-beat sample differences — numerically
        # identical to rr_intervals * fs, not true QRS durations. Kept as-is
        # because the model was presumably trained on the same values.
        qrs_durations = np.array([r_peaks[i] - r_peaks[i - 1] for i in range(1, len(r_peaks))])

        record_features = []
        basic_features = extract_features_from_signal(signal)
        record_features.extend(basic_features)

        # Beat-count and RR-interval statistics (6 features).
        record_features.extend([
            len(r_peaks),
            np.mean(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.std(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.median(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.mean(qrs_durations) / fs if len(qrs_durations) > 0 else 0,
            np.std(qrs_durations) / fs if len(qrs_durations) > 0 else 0
        ])

        # Time-domain HRV features (4).
        hrv_features = extract_hrv_features(rr_intervals)
        record_features.extend(hrv_features)

        # QRS morphology features (3).
        qrs_features = extract_qrs_features(signal, r_peaks, fs)
        record_features.extend(qrs_features)

        # Frequency-domain HRV needs a minimum number of beats.
        if len(rr_intervals) >= 4:
            try:
                # Resample the irregular RR series onto a uniform 4 Hz grid
                # for spectral analysis.
                rr_times = np.cumsum(rr_intervals)
                rr_times = np.insert(rr_times, 0, 0)

                fs_interp = 4.0
                t_interp = np.arange(0, rr_times[-1], 1/fs_interp)
                rr_interp = np.interp(t_interp, rr_times[:-1], rr_intervals)

                freq, psd = sg.welch(rr_interp, fs=fs_interp, nperseg=min(256, len(rr_interp)))

                # Standard HRV frequency bands (VLF band computed but unused).
                vlf_mask = (freq >= 0.0033) & (freq < 0.04)
                lf_mask = (freq >= 0.04) & (freq < 0.15)
                hf_mask = (freq >= 0.15) & (freq < 0.4)

                lf_power = np.trapz(psd[lf_mask], freq[lf_mask]) if np.any(lf_mask) else 0
                hf_power = np.trapz(psd[hf_mask], freq[hf_mask]) if np.any(hf_mask) else 0

                lf_hf_ratio = lf_power / hf_power if hf_power > 0 else 0
                normalized_lf = lf_power / (lf_power + hf_power) if (lf_power + hf_power) > 0 else 0
            except Exception:
                # Narrowed from a bare `except:`; on any spectral-analysis
                # failure fall back to zeroed frequency features.
                lf_power = hf_power = lf_hf_ratio = normalized_lf = 0
        else:
            lf_power = hf_power = lf_hf_ratio = normalized_lf = 0

        record_features.extend([lf_power, hf_power, lf_hf_ratio, normalized_lf])

    # Force the feature vector to exactly 45 entries to match the model input.
    if len(record_features) < 45:
        record_features.extend([0] * (45 - len(record_features)))
    elif len(record_features) > 45:
        record_features = record_features[:45]

    prediction = model.predict([record_features])[0]
    result = "Abnormal" if prediction == 1 else "Normal"

    return result
|
|
|
|
|
|
|
|
def classify_ecg(file_path, model, is_pdf=False):
    """
    Wrapper function that handles both PDF and DAT ECG files with segment voting.

    The recording is split into 10 s segments, each segment is classified
    independently, and a majority vote across segments decides the final label.

    Args:
        file_path (str): Path to the ECG file (.pdf or without extension for .dat)
        model: The trained model for classification
        is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)

    Returns:
        str: Classification result ("Normal", "Abnormal", "Inconclusive" on a
        tied vote, or an error message string)
    """
    try:
        if model is None:
            return "Error: Model not loaded. Please check model compatibility."

        if is_pdf:
            # Digitize the PDF first; this also writes per-segment .dat files.
            base_name = os.path.splitext(os.path.basename(file_path))[0]
            output_dat = f"{base_name}_digitized.dat"

            dat_path, segment_files = digitize_ecg_from_pdf(
                pdf_path=file_path,
                output_file=output_dat
            )
        else:
            segment_files = split_dat_into_segments(file_path)

        if not segment_files:
            # Recording shorter than one segment: classify it as a whole.
            return classify_new_ecg(file_path, model)

        segment_results = []

        for segment_file in segment_files:
            segment_path = os.path.splitext(segment_file)[0]
            result = classify_new_ecg(segment_path, model)
            segment_results.append(result)

            # Best-effort cleanup of the temporary segment file.
            # Narrowed from a bare `except: pass` so only filesystem errors
            # are ignored, not KeyboardInterrupt/SystemExit.
            try:
                os.remove(segment_file)
            except OSError:
                pass

        if segment_results:
            normal_count = segment_results.count("Normal")
            abnormal_count = segment_results.count("Abnormal")

            # Majority vote across segments; a tie is inconclusive.
            if abnormal_count > normal_count:
                final_result = "Abnormal"
            elif normal_count > abnormal_count:
                final_result = "Normal"
            else:
                final_result = "Inconclusive"

            return final_result
        else:
            return "Error: No valid segments to classify"

    except Exception as e:
        # Boundary handler: surface any failure as an error string rather
        # than raising to the caller.
        error_msg = f"Classification error: {str(e)}"
        return error_msg
|
|
|
|
|
|
|
|
|
|
|
|