# Source: Deploy_El7a2ny_Application / ECG / ECG_MultiClass.py
# Author: Hussein El-Hadidy — commit 676f928 ("Latest ECG")
"""
ECG Analysis Pipeline: From PDF to Arrhythmia Classification
-----------------------------------------------------------
This module provides functions to:
1. Digitize ECG from PDF files
2. Process the digitized ECG signal
3. Classify arrhythmias using a trained CNN model
"""
import cv2
import numpy as np
import os
import tensorflow as tf
import pickle
from scipy.interpolate import interp1d
from pdf2image import convert_from_path
# Output labels, in the order the CNN's final layer emits them.
ARRHYTHMIA_CLASSES = ["Conduction Abnormalities", "Atrial Arrhythmias", "Tachyarrhythmias", "Normal"]
# Samples per second of the digitized ECG signal.
SAMPLING_RATE = 500
# Length of each analysis window, in seconds.
SEGMENT_DURATION = 10.0
# Samples per model input segment (SAMPLING_RATE * SEGMENT_DURATION).
TARGET_SEGMENT_LENGTH = 5000
# Default filename for the digitized full-signal output.
DEFAULT_OUTPUT_FILE = 'calibrated_ecg.dat'
# int16 counts -> millivolts; inverse of the 1000x ADC gain used when writing.
DAT_SCALE_FACTOR = 0.001
def digitize_ecg_from_pdf(pdf_path, output_file=None):
    """
    Process an ECG PDF file and convert it to a .dat signal file.

    The first page is rendered to a grayscale image, three horizontal
    strips (ECG layers) are cut at fixed fractional heights, the trace
    in each strip is recovered via contour detection, and the three
    10-second traces are concatenated into one 30-second signal stored
    as int16 samples (1 mV == 1000 counts, matching DAT_SCALE_FACTOR).

    Args:
        pdf_path (str): Path to the ECG PDF file
        output_file (str, optional): Path to save the output .dat file

    Returns:
        tuple: (path to the created .dat file, list of paths to segment files)

    Raises:
        ValueError: If the rendered page cannot be read, a layer has no
            waveform contour, or a contour is degenerate (zero width).
    """
    if output_file is None:
        output_file = DEFAULT_OUTPUT_FILE

    images = convert_from_path(pdf_path)
    temp_image_path = 'temp_ecg_image.jpg'
    images[0].save(temp_image_path, 'JPEG')
    try:
        img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError(f"Could not read rendered ECG image for {pdf_path}")
        height = img.shape[0]

        # Empirically determined calibration for the expected report layout.
        calibration = {
            'seconds_per_pixel': 2.0 / 197.0,
            'mv_per_pixel': 1.0 / 78.8,
        }

        # The three ECG strips occupy fixed fractional bands of the page
        # (percent of page height, top/bottom of each band).
        layer_bounds = [
            (35.35, 51.76),
            (51.82, 69.41),
            (69.47, 87.06),
        ]
        layers = [
            img[int(height * top / 100):int(height * bottom / 100), :]
            for top, bottom in layer_bounds
        ]

        signals = []
        time_points = []
        layer_duration = 10.0  # seconds of ECG per strip
        for layer in layers:
            # Dark trace on light background: invert so the trace is foreground.
            _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
            contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                raise ValueError("No waveform contour found in ECG layer")
            waveform_contour = max(contours, key=cv2.contourArea)
            sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
            x_coords = np.array([point[0][0] for point in sorted_contour])
            y_coords = np.array([point[0][1] for point in sorted_contour])

            # Assume the isoelectric line sits 60% down the strip.
            isoelectric_line_y = layer.shape[0] * 0.6
            x_min, x_max = np.min(x_coords), np.max(x_coords)
            x_span = x_max - x_min
            if x_span == 0:
                raise ValueError("Degenerate waveform contour (zero width)")
            time = (x_coords - x_min) / x_span * layer_duration
            # Image y grows downward, so invert before scaling to millivolts.
            signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
            signal_mv = signal_mv - np.mean(signal_mv)  # remove DC offset
            time_points.append(time)
            signals.append(signal_mv)

        # Resample all layers onto one uniform time grid.
        total_duration = layer_duration * len(layers)
        sampling_frequency = SAMPLING_RATE
        num_samples = int(total_duration * sampling_frequency)
        combined_time = np.linspace(0, total_duration, num_samples)
        combined_signal = np.zeros(num_samples)
        for i, (time, signal) in enumerate(zip(time_points, signals)):
            start_time = i * layer_duration
            mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
            combined_signal[mask] = np.interp(combined_time[mask], start_time + time, signal)

        # Re-center baseline; if the peak amplitude is implausible for ECG,
        # rescale toward a physiological ~2 mV peak.
        combined_signal = combined_signal - np.mean(combined_signal)
        signal_peak = np.max(np.abs(combined_signal))
        target_amplitude = 2.0
        if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
            combined_signal = combined_signal * (target_amplitude / signal_peak)

        # Store as int16 with a 1000x gain (1 mV -> 1000 counts), the
        # inverse of DAT_SCALE_FACTOR used by read_ecg_dat_file.
        adc_gain = 1000.0
        (combined_signal * adc_gain).astype(np.int16).tofile(output_file)

        # Also write one .dat file per 10-second layer.
        segment_files = []
        samples_per_segment = int(layer_duration * sampling_frequency)
        base_name = os.path.splitext(output_file)[0]
        for i in range(len(layers)):
            segment = combined_signal[i * samples_per_segment:(i + 1) * samples_per_segment]
            segment_file = f"{base_name}_segment{i+1}.dat"
            (segment * adc_gain).astype(np.int16).tofile(segment_file)
            segment_files.append(segment_file)
        return output_file, segment_files
    finally:
        # Remove the temporary rendered image even when processing fails
        # (the original leaked it on any exception after rendering).
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
def read_ecg_dat_file(dat_file_path):
    """
    Read an int16 .dat file and scale it to millivolts.

    Parameters:
    -----------
    dat_file_path : str
        Path to the .dat file (with or without .dat extension)

    Returns:
    --------
    numpy.ndarray
        ECG signal in millivolts with shape (total_samples,)

    Raises:
    -------
    OSError
        If the file cannot be read (propagated from numpy.fromfile).
    """
    if not dat_file_path.endswith('.dat'):
        dat_file_path += '.dat'
    # int16 counts -> millivolts; DAT_SCALE_FACTOR is the inverse of the
    # 1000x ADC gain applied when the file was written.
    # (The original wrapped this in `try: ... except Exception: raise`,
    # which is a no-op and has been removed.)
    data = np.fromfile(dat_file_path, dtype=np.int16)
    return data * DAT_SCALE_FACTOR
def segment_signal(signal):
    """
    Split a signal into consecutive, equal-length segments.

    Parameters:
    -----------
    signal : numpy.ndarray
        The full signal to segment

    Returns:
    --------
    list
        List of signal segments; a trailing remainder shorter than one
        full segment is discarded
    """
    seg_len = int(SAMPLING_RATE * SEGMENT_DURATION)
    n_full = len(signal) // seg_len
    # Slice out each complete window of seg_len samples.
    return [signal[k * seg_len:(k + 1) * seg_len] for k in range(n_full)]
def process_segment(segment):
    """
    Prepare one raw ECG segment for model input.

    Parameters:
    -----------
    segment : numpy.ndarray
        Raw ECG segment

    Returns:
    --------
    numpy.ndarray
        Segment resampled to TARGET_SEGMENT_LENGTH points and
        z-score normalized
    """
    n = len(segment)
    if n != TARGET_SEGMENT_LENGTH:
        # Linearly resample onto a uniform grid of the target length.
        old_grid = np.linspace(0, 1, n)
        new_grid = np.linspace(0, 1, TARGET_SEGMENT_LENGTH)
        segment = np.interp(new_grid, old_grid, segment)
    # Zero-mean / unit-variance; epsilon guards flat (zero-std) input.
    return (segment - np.mean(segment)) / (np.std(segment) + 1e-8)
def predict_with_cnn_model(signal_data, model):
    """
    Segment the signal, run the CNN on each segment, and average the
    per-segment predictions into one final classification.

    Parameters:
    -----------
    signal_data : numpy.ndarray
        Raw signal data
    model : tensorflow.keras.Model
        Loaded CNN model

    Returns:
    --------
    dict
        Per-segment predictions, the averaged prediction, and the top
        class; or {"error": ...} when the signal yields no complete
        segments.
    """
    segments = segment_signal(signal_data)
    all_predictions = []
    for segment in segments:
        processed = process_segment(segment)
        # Model expects input shaped (batch, samples, channels).
        X = processed.reshape(1, TARGET_SEGMENT_LENGTH, 1)
        all_predictions.append(model.predict(X, verbose=0)[0])

    # Guard clause: no complete 10-second segment could be extracted.
    if not all_predictions:
        return {"error": "No valid segments for prediction"}

    avg_prediction = np.mean(all_predictions, axis=0)
    # Cast to plain int so the result is JSON-serializable (np.argmax
    # returns numpy integer types).
    top_class_idx = int(np.argmax(avg_prediction))
    return {
        "segment_predictions": all_predictions,
        "averaged_prediction": avg_prediction,
        "top_class_index": top_class_idx,
        "top_class": ARRHYTHMIA_CLASSES[top_class_idx],
        "probability": float(avg_prediction[top_class_idx]),
    }
def analyze_ecg_pdf(pdf_path, model_path, cleanup=True):
    """
    Complete ECG analysis pipeline: digitizes a PDF ECG, analyzes it with the model,
    and returns the arrhythmia classification with highest probability.

    Args:
        pdf_path (str): Path to the ECG PDF file
        model_path (str): Path to the model (.h5) file
        cleanup (bool, optional): Whether to remove temporary files after processing

    Returns:
        dict: {
            "arrhythmia_class": str,   # Top arrhythmia class
            "probability": float,      # Probability of top class
            "all_probabilities": dict, # All classes with probabilities
            "digitized_file": str      # Path to digitized file (if cleanup=False)
        }
        or {"error": str} if any step of the pipeline fails.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow info/warning logs
    dat_file_path = None
    segment_files = []
    try:
        dat_file_path, segment_files = digitize_ecg_from_pdf(pdf_path)
        ecg_model = tf.keras.models.load_model(model_path)
        ecg_signal = read_ecg_dat_file(dat_file_path)
        classification_results = predict_with_cnn_model(ecg_signal, ecg_model)

        arrhythmia_result = {
            "arrhythmia_class": classification_results.get("top_class"),
            "probability": classification_results.get("probability", 0.0),
            "all_probabilities": {},
        }
        if "averaged_prediction" in classification_results:
            averaged = classification_results["averaged_prediction"]
            for idx, class_name in enumerate(ARRHYTHMIA_CLASSES):
                arrhythmia_result["all_probabilities"][class_name] = float(averaged[idx])
        if not cleanup:
            arrhythmia_result["digitized_file"] = dat_file_path
        return arrhythmia_result
    except Exception as e:
        # Boundary handler: callers receive an error dict, never an exception.
        error_msg = f"Error in ECG analysis: {str(e)}"
        return {"error": error_msg}
    finally:
        # Remove intermediate files even when analysis fails mid-pipeline
        # (the original only cleaned up on the success path, leaking the
        # .dat and segment files on any error).
        if cleanup:
            for path in [dat_file_path] + segment_files:
                if path and os.path.exists(path):
                    os.remove(path)