Hussein El-Hadidy committed on
Commit ·
676f928
1
Parent(s): ef22c1c
Latest ECG
Browse files- deep-multiclass.h5 → Arrhythmia_Model_with_SMOTE.h5 +2 -2
- ECG/ECG_Classify.py +238 -392
- ECG/ECG_MultiClass.py +127 -368
- SkinBurns_Classification.py → SkinBurns/SkinBurns_Classification.py +0 -0
- SkinBurns_Segmentation.py → SkinBurns/SkinBurns_Segmentation.py +0 -0
- app.py +12 -27
- voting_classifier.pkl +0 -3
- deep-multiclass.pkl → voting_classifier_arrhythmia.pkl +2 -2
deep-multiclass.h5 → Arrhythmia_Model_with_SMOTE.h5
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7145bc906536b54953df3e7034d0da786c24261ec4080fbfd51031d19895e713
|
| 3 |
+
size 8923320
|
ECG/ECG_Classify.py
CHANGED
|
@@ -1,57 +1,113 @@
|
|
| 1 |
-
import wfdb
|
| 2 |
-
from wfdb import processing
|
| 3 |
-
import numpy as np
|
| 4 |
-
import joblib
|
| 5 |
-
import pywt
|
| 6 |
-
import os
|
| 7 |
-
import cv2
|
| 8 |
-
from pdf2image import convert_from_path
|
| 9 |
import warnings
|
| 10 |
import pickle
|
| 11 |
-
import
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
| 16 |
Process an ECG PDF file and convert it to a .dat signal file.
|
| 17 |
|
| 18 |
Args:
|
| 19 |
pdf_path (str): Path to the ECG PDF file
|
| 20 |
-
output_file (str): Path to save the output .dat file
|
| 21 |
-
debug (bool): Whether to print debug information
|
| 22 |
-
save_segments (bool): Whether to save individual segments
|
| 23 |
|
| 24 |
Returns:
|
| 25 |
tuple: (path to the created .dat file, list of paths to segment files)
|
| 26 |
"""
|
| 27 |
-
if
|
| 28 |
-
|
| 29 |
|
| 30 |
-
# Convert PDF to image
|
| 31 |
images = convert_from_path(pdf_path)
|
| 32 |
temp_image_path = 'temp_ecg_image.jpg'
|
| 33 |
images[0].save(temp_image_path, 'JPEG')
|
| 34 |
|
| 35 |
-
if debug:
|
| 36 |
-
print(f"Converted PDF to image: {temp_image_path}")
|
| 37 |
-
|
| 38 |
-
# Load the image
|
| 39 |
img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
|
| 40 |
height, width = img.shape
|
| 41 |
|
| 42 |
-
if debug:
|
| 43 |
-
print(f"Image dimensions: {width}x{height}")
|
| 44 |
-
|
| 45 |
-
# Fixed calibration parameters
|
| 46 |
calibration = {
|
| 47 |
-
'seconds_per_pixel': 2.0 / 197.0,
|
| 48 |
-
'mv_per_pixel': 1.0 / 78.8,
|
| 49 |
}
|
| 50 |
|
| 51 |
-
if debug:
|
| 52 |
-
print(f"Calibration parameters: {calibration}")
|
| 53 |
-
|
| 54 |
-
# Calculate layer boundaries using percentages
|
| 55 |
layer1_start = int(height * 35.35 / 100)
|
| 56 |
layer1_end = int(height * 51.76 / 100)
|
| 57 |
layer2_start = int(height * 51.82 / 100)
|
|
@@ -59,210 +115,123 @@ def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=Fals
|
|
| 59 |
layer3_start = int(height * 69.47 / 100)
|
| 60 |
layer3_end = int(height * 87.06 / 100)
|
| 61 |
|
| 62 |
-
if debug:
|
| 63 |
-
print(f"Layer 1 boundaries: {layer1_start}-{layer1_end}")
|
| 64 |
-
print(f"Layer 2 boundaries: {layer2_start}-{layer2_end}")
|
| 65 |
-
print(f"Layer 3 boundaries: {layer3_start}-{layer3_end}")
|
| 66 |
-
|
| 67 |
-
# Crop each layer
|
| 68 |
layers = [
|
| 69 |
-
img[layer1_start:layer1_end, :],
|
| 70 |
-
img[layer2_start:layer2_end, :],
|
| 71 |
-
img[layer3_start:layer3_end, :]
|
| 72 |
]
|
| 73 |
|
| 74 |
-
# Process each layer to extract waveform contours
|
| 75 |
signals = []
|
| 76 |
time_points = []
|
| 77 |
-
layer_duration = 10.0
|
| 78 |
|
| 79 |
for i, layer in enumerate(layers):
|
| 80 |
-
if debug:
|
| 81 |
-
print(f"Processing layer {i+1}...")
|
| 82 |
-
|
| 83 |
-
# Binary thresholding
|
| 84 |
_, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
|
| 85 |
|
| 86 |
-
# Detect contours
|
| 87 |
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 88 |
-
waveform_contour = max(contours, key=cv2.contourArea)
|
| 89 |
-
|
| 90 |
-
if debug:
|
| 91 |
-
print(f" - Found {len(contours)} contours")
|
| 92 |
-
print(f" - Selected contour with {len(waveform_contour)} points")
|
| 93 |
|
| 94 |
-
# Sort contour points and extract coordinates
|
| 95 |
sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
|
| 96 |
x_coords = np.array([point[0][0] for point in sorted_contour])
|
| 97 |
y_coords = np.array([point[0][1] for point in sorted_contour])
|
| 98 |
|
| 99 |
-
# Calculate isoelectric line (one-third from the bottom)
|
| 100 |
isoelectric_line_y = layer.shape[0] * 0.6
|
| 101 |
|
| 102 |
-
# Convert to time using fixed layer duration
|
| 103 |
x_min, x_max = np.min(x_coords), np.max(x_coords)
|
| 104 |
time = (x_coords - x_min) / (x_max - x_min) * layer_duration
|
| 105 |
|
| 106 |
-
# Calculate signal in millivolts and apply baseline correction
|
| 107 |
signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
|
| 108 |
signal_mv = signal_mv - np.mean(signal_mv)
|
| 109 |
|
| 110 |
-
if debug:
|
| 111 |
-
print(f" - Layer {i+1} signal range: {np.min(signal_mv):.2f} mV to {np.max(signal_mv):.2f} mV")
|
| 112 |
-
|
| 113 |
-
# Store the time points and calibrated signal
|
| 114 |
time_points.append(time)
|
| 115 |
signals.append(signal_mv)
|
| 116 |
|
| 117 |
-
# Save individual segments if requested
|
| 118 |
-
segment_files = []
|
| 119 |
-
sampling_frequency = 500 # Standard ECG frequency
|
| 120 |
-
samples_per_segment = int(layer_duration * sampling_frequency) # 5000 samples per 10-second segment
|
| 121 |
-
|
| 122 |
-
if save_segments:
|
| 123 |
-
base_name = os.path.splitext(output_file)[0]
|
| 124 |
-
|
| 125 |
-
for i, signal in enumerate(signals):
|
| 126 |
-
# Interpolate to get evenly sampled signal
|
| 127 |
-
segment_time = np.linspace(0, layer_duration, samples_per_segment)
|
| 128 |
-
interpolated_signal = np.interp(segment_time, time_points[i], signals[i])
|
| 129 |
-
|
| 130 |
-
# Normalize and scale
|
| 131 |
-
interpolated_signal = interpolated_signal - np.mean(interpolated_signal)
|
| 132 |
-
signal_peak = np.max(np.abs(interpolated_signal))
|
| 133 |
-
|
| 134 |
-
if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
|
| 135 |
-
scaling_factor = 2.0 / signal_peak # Target peak amplitude of 2.0 mV
|
| 136 |
-
interpolated_signal = interpolated_signal * scaling_factor
|
| 137 |
-
|
| 138 |
-
# Convert to 16-bit integers
|
| 139 |
-
adc_gain = 1000.0
|
| 140 |
-
int_signal = (interpolated_signal * adc_gain).astype(np.int16)
|
| 141 |
-
|
| 142 |
-
# Save segment
|
| 143 |
-
segment_file = f"{base_name}_segment{i+1}.dat"
|
| 144 |
-
int_signal.reshape(-1, 1).tofile(segment_file)
|
| 145 |
-
segment_files.append(segment_file)
|
| 146 |
-
|
| 147 |
-
if debug:
|
| 148 |
-
print(f"Saved segment {i+1} to {segment_file}")
|
| 149 |
-
|
| 150 |
-
# Combine signals with proper time alignment for the full record
|
| 151 |
total_duration = layer_duration * len(layers)
|
|
|
|
| 152 |
num_samples = int(total_duration * sampling_frequency)
|
| 153 |
combined_time = np.linspace(0, total_duration, num_samples)
|
| 154 |
combined_signal = np.zeros(num_samples)
|
| 155 |
|
| 156 |
-
if debug:
|
| 157 |
-
print(f"Combining signals with {sampling_frequency} Hz sampling rate, total duration: {total_duration}s")
|
| 158 |
-
|
| 159 |
-
# Place each lead at the correct time position
|
| 160 |
for i, (time, signal) in enumerate(zip(time_points, signals)):
|
| 161 |
start_time = i * layer_duration
|
| 162 |
mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
|
| 163 |
relevant_times = combined_time[mask]
|
| 164 |
interpolated_signal = np.interp(relevant_times, start_time + time, signal)
|
| 165 |
combined_signal[mask] = interpolated_signal
|
| 166 |
-
|
| 167 |
-
if debug:
|
| 168 |
-
print(f" - Added layer {i+1} signal from {start_time}s to {start_time + layer_duration}s")
|
| 169 |
|
| 170 |
-
# Baseline correction and amplitude scaling
|
| 171 |
combined_signal = combined_signal - np.mean(combined_signal)
|
| 172 |
signal_peak = np.max(np.abs(combined_signal))
|
| 173 |
-
target_amplitude = 2.0
|
| 174 |
-
|
| 175 |
-
if debug:
|
| 176 |
-
print(f"Signal peak before scaling: {signal_peak:.2f} mV")
|
| 177 |
|
| 178 |
if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
|
| 179 |
scaling_factor = target_amplitude / signal_peak
|
| 180 |
combined_signal = combined_signal * scaling_factor
|
| 181 |
-
if debug:
|
| 182 |
-
print(f"Applied scaling factor: {scaling_factor:.2f}")
|
| 183 |
-
print(f"Signal peak after scaling: {np.max(np.abs(combined_signal)):.2f} mV")
|
| 184 |
|
| 185 |
-
|
| 186 |
-
adc_gain = 1000.0 # Standard gain: 1000 units per mV
|
| 187 |
int_signal = (combined_signal * adc_gain).astype(np.int16)
|
| 188 |
int_signal.tofile(output_file)
|
| 189 |
|
| 190 |
-
if debug:
|
| 191 |
-
print(f"Saved signal to {output_file} with {len(int_signal)} samples")
|
| 192 |
-
print(f"Integer signal range: {np.min(int_signal)} to {np.max(int_signal)}")
|
| 193 |
-
|
| 194 |
-
# Clean up temporary files
|
| 195 |
if os.path.exists(temp_image_path):
|
| 196 |
os.remove(temp_image_path)
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
return output_file, segment_files
|
| 201 |
|
| 202 |
-
|
| 203 |
-
def split_dat_into_segments(file_path, segment_duration=10.0
|
| 204 |
"""
|
| 205 |
Split a DAT file into equal segments.
|
| 206 |
|
| 207 |
Args:
|
| 208 |
file_path (str): Path to the DAT file (without extension)
|
| 209 |
segment_duration (float): Duration of each segment in seconds
|
| 210 |
-
debug (bool): Whether to print debug information
|
| 211 |
|
| 212 |
Returns:
|
| 213 |
list: Paths to the segment files
|
| 214 |
"""
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
# Choose a lead
|
| 223 |
-
if signal_all_leads.shape[1] == 1:
|
| 224 |
-
lead_index = 0
|
| 225 |
-
else:
|
| 226 |
-
lead_priority = [1, 0] # Try Lead II (index 1), then I (index 0)
|
| 227 |
-
lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
|
| 228 |
-
|
| 229 |
-
signal = signal_all_leads[:, lead_index]
|
| 230 |
-
|
| 231 |
-
# Calculate samples per segment
|
| 232 |
-
samples_per_segment = int(segment_duration * fs)
|
| 233 |
-
total_samples = len(signal)
|
| 234 |
-
num_segments = total_samples // samples_per_segment
|
| 235 |
-
|
| 236 |
-
if debug:
|
| 237 |
-
print(f"Splitting signal into {num_segments} segments of {segment_duration} seconds each")
|
| 238 |
-
|
| 239 |
-
segment_files = []
|
| 240 |
-
|
| 241 |
-
# Split and save each segment
|
| 242 |
-
base_name = os.path.splitext(file_path)[0]
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
return segment_files
|
| 258 |
-
|
| 259 |
-
except Exception as e:
|
| 260 |
-
if debug:
|
| 261 |
-
print(f"Error splitting DAT file: {str(e)}")
|
| 262 |
-
return []
|
| 263 |
|
| 264 |
-
|
| 265 |
-
def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16, debug=False):
|
| 266 |
"""
|
| 267 |
Load a DAT file containing ECG signal data.
|
| 268 |
|
|
@@ -271,79 +240,46 @@ def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16, debug
|
|
| 271 |
n_leads (int): Number of leads in the signal
|
| 272 |
n_samples (int): Number of samples per lead
|
| 273 |
dtype: Data type of the signal
|
| 274 |
-
debug (bool): Whether to print debug information
|
| 275 |
|
| 276 |
Returns:
|
| 277 |
tuple: (numpy array of signal data, sampling frequency)
|
| 278 |
"""
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
else:
|
| 284 |
-
dat_path = file_path + '.dat'
|
| 285 |
-
|
| 286 |
-
if debug:
|
| 287 |
-
print(f"Loading signal from: {dat_path}")
|
| 288 |
-
|
| 289 |
-
raw = np.fromfile(dat_path, dtype=dtype)
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
if debug:
|
| 297 |
-
print(f"Unexpected size: {raw.size}, expected {n_leads * n_samples}")
|
| 298 |
-
print("Attempting to infer number of leads...")
|
| 299 |
-
|
| 300 |
-
# Check if single lead
|
| 301 |
-
if raw.size == n_samples:
|
| 302 |
-
if debug:
|
| 303 |
-
print("Detected single lead signal")
|
| 304 |
-
signal = raw.reshape(n_samples, 1)
|
| 305 |
-
return signal, 500
|
| 306 |
-
|
| 307 |
-
# Try common lead counts
|
| 308 |
-
possible_leads = [1, 2, 3, 6, 12]
|
| 309 |
-
for possible_lead_count in possible_leads:
|
| 310 |
-
if raw.size % possible_lead_count == 0:
|
| 311 |
-
actual_samples = raw.size // possible_lead_count
|
| 312 |
-
if debug:
|
| 313 |
-
print(f"Inferred {possible_lead_count} leads with {actual_samples} samples each")
|
| 314 |
-
signal = raw.reshape(actual_samples, possible_lead_count)
|
| 315 |
-
return signal, 500
|
| 316 |
-
|
| 317 |
-
# If we can't determine it reliably, reshape as single lead
|
| 318 |
-
if debug:
|
| 319 |
-
print("Could not infer lead count, reshaping as single lead")
|
| 320 |
-
signal = raw.reshape(-1, 1)
|
| 321 |
return signal, 500
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
-
|
| 333 |
-
def extract_features_from_signal(signal, debug=False):
|
| 334 |
"""
|
| 335 |
Extract features from an ECG signal.
|
| 336 |
|
| 337 |
Args:
|
| 338 |
signal (numpy.ndarray): ECG signal
|
| 339 |
-
debug (bool): Whether to print debug information
|
| 340 |
|
| 341 |
Returns:
|
| 342 |
-
list:
|
| 343 |
"""
|
| 344 |
-
if debug:
|
| 345 |
-
print("Extracting features from signal...")
|
| 346 |
-
|
| 347 |
features = []
|
| 348 |
features.append(np.mean(signal))
|
| 349 |
features.append(np.std(signal))
|
|
@@ -354,129 +290,108 @@ def extract_features_from_signal(signal, debug=False):
|
|
| 354 |
features.append(np.percentile(signal, 75))
|
| 355 |
features.append(np.mean(np.diff(signal)))
|
| 356 |
|
| 357 |
-
if debug:
|
| 358 |
-
print("Computing wavelet decomposition...")
|
| 359 |
-
|
| 360 |
coeffs = pywt.wavedec(signal, 'db4', level=5)
|
| 361 |
-
for
|
| 362 |
features.append(np.mean(coeff))
|
| 363 |
features.append(np.std(coeff))
|
| 364 |
features.append(np.min(coeff))
|
| 365 |
features.append(np.max(coeff))
|
| 366 |
-
|
| 367 |
-
if debug and i == 0:
|
| 368 |
-
print(f"Wavelet features for level {i}: mean={np.mean(coeff):.4f}, std={np.std(coeff):.4f}")
|
| 369 |
|
| 370 |
-
if debug:
|
| 371 |
-
print(f"Extracted {len(features)} features")
|
| 372 |
-
|
| 373 |
return features
|
| 374 |
|
| 375 |
-
|
| 376 |
-
def classify_new_ecg(file_path, model
|
| 377 |
"""
|
| 378 |
Classify a new ECG file.
|
| 379 |
|
| 380 |
Args:
|
| 381 |
file_path (str): Path to the ECG file (without extension)
|
| 382 |
model: The trained model for classification
|
| 383 |
-
debug (bool): Whether to print debug information
|
| 384 |
|
| 385 |
Returns:
|
| 386 |
str: Classification result ("Normal", "Abnormal", or error message)
|
| 387 |
"""
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
if
|
| 395 |
-
print(f"Loaded signal with shape {signal_all_leads.shape}, sampling rate {fs} Hz")
|
| 396 |
-
|
| 397 |
-
# Choose lead for analysis - priority order
|
| 398 |
-
if signal_all_leads.shape[1] == 1:
|
| 399 |
-
lead_index = 0
|
| 400 |
-
if debug:
|
| 401 |
-
print("Using single lead")
|
| 402 |
-
else:
|
| 403 |
-
lead_priority = [1, 0] # Try Lead II (index 1), then I (index 0)
|
| 404 |
-
lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
|
| 405 |
-
if debug:
|
| 406 |
-
print(f"Using lead index {lead_index}")
|
| 407 |
-
|
| 408 |
-
# Extract the signal
|
| 409 |
-
signal = signal_all_leads[:, lead_index]
|
| 410 |
-
|
| 411 |
-
# Normalize signal
|
| 412 |
-
signal = (signal - np.mean(signal)) / np.std(signal)
|
| 413 |
-
|
| 414 |
-
if debug:
|
| 415 |
-
print("Signal normalized")
|
| 416 |
-
print(f"Detecting QRS complexes...")
|
| 417 |
-
|
| 418 |
-
# Detect QRS complexes
|
| 419 |
-
try:
|
| 420 |
-
xqrs = processing.XQRS(sig=signal, fs=fs)
|
| 421 |
-
xqrs.detect()
|
| 422 |
-
r_peaks = xqrs.qrs_inds
|
| 423 |
-
if debug:
|
| 424 |
-
print(f"Detected {len(r_peaks)} QRS complexes with XQRS method")
|
| 425 |
-
except Exception as e:
|
| 426 |
-
if debug:
|
| 427 |
-
print(f"XQRS detection failed: {str(e)}")
|
| 428 |
-
print("Falling back to GQRS detector")
|
| 429 |
-
r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
|
| 430 |
-
if debug:
|
| 431 |
-
print(f"Detected {len(r_peaks)} QRS complexes with GQRS method")
|
| 432 |
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
| 438 |
|
| 439 |
-
|
|
|
|
|
|
|
|
|
|
| 440 |
rr_intervals = np.diff(r_peaks) / fs
|
| 441 |
qrs_durations = np.array([r_peaks[i] - r_peaks[i - 1] for i in range(1, len(r_peaks))])
|
| 442 |
-
|
| 443 |
-
if debug:
|
| 444 |
-
print(f"Mean RR interval: {np.mean(rr_intervals):.4f} s")
|
| 445 |
-
print(f"Mean QRS duration: {np.mean(qrs_durations) / fs:.4f} s")
|
| 446 |
|
| 447 |
-
|
| 448 |
-
|
|
|
|
| 449 |
|
| 450 |
-
|
| 451 |
-
features.extend([
|
| 452 |
len(r_peaks),
|
| 453 |
np.mean(rr_intervals) if len(rr_intervals) > 0 else 0,
|
| 454 |
np.std(rr_intervals) if len(rr_intervals) > 0 else 0,
|
| 455 |
np.median(rr_intervals) if len(rr_intervals) > 0 else 0,
|
| 456 |
-
np.mean(qrs_durations) if len(qrs_durations) > 0 else 0,
|
| 457 |
-
np.std(qrs_durations) if len(qrs_durations) > 0 else 0
|
| 458 |
])
|
| 459 |
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
# Make prediction
|
| 464 |
-
prediction = model.predict([features])[0]
|
| 465 |
-
result = "Abnormal" if prediction == 1 else "Normal"
|
| 466 |
|
| 467 |
-
|
| 468 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
|
| 472 |
-
except Exception as e:
|
| 473 |
-
error_msg = f"Error: {str(e)}"
|
| 474 |
-
if debug:
|
| 475 |
-
print(error_msg)
|
| 476 |
-
return error_msg
|
| 477 |
|
| 478 |
-
|
| 479 |
-
def classify_ecg(file_path, model, is_pdf=False, debug=False):
|
| 480 |
"""
|
| 481 |
Wrapper function that handles both PDF and DAT ECG files with segment voting.
|
| 482 |
|
|
@@ -484,86 +399,44 @@ def classify_ecg(file_path, model, is_pdf=False, debug=False):
|
|
| 484 |
file_path (str): Path to the ECG file (.pdf or without extension for .dat)
|
| 485 |
model: The trained model for classification
|
| 486 |
is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)
|
| 487 |
-
debug (bool): Enable debug output
|
| 488 |
|
| 489 |
Returns:
|
| 490 |
str: Classification result ("Normal", "Abnormal", or error message)
|
| 491 |
"""
|
| 492 |
try:
|
| 493 |
-
# Check if model is valid
|
| 494 |
if model is None:
|
| 495 |
return "Error: Model not loaded. Please check model compatibility."
|
| 496 |
|
| 497 |
if is_pdf:
|
| 498 |
-
if debug:
|
| 499 |
-
print(f"Processing PDF file: {file_path}")
|
| 500 |
-
|
| 501 |
-
# Extract file name without extension for output
|
| 502 |
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 503 |
output_dat = f"{base_name}_digitized.dat"
|
| 504 |
|
| 505 |
-
# Digitize the PDF to a DAT file and get segment files
|
| 506 |
dat_path, segment_files = digitize_ecg_from_pdf(
|
| 507 |
pdf_path=file_path,
|
| 508 |
-
output_file=output_dat
|
| 509 |
-
debug=debug
|
| 510 |
)
|
| 511 |
-
|
| 512 |
-
if debug:
|
| 513 |
-
print(f"Digitized ECG saved to: {dat_path}")
|
| 514 |
-
print(f"Created {len(segment_files)} segment files")
|
| 515 |
else:
|
| 516 |
-
|
| 517 |
-
print(f"Processing DAT file: {file_path}")
|
| 518 |
-
|
| 519 |
-
# For DAT files, we need to split into segments
|
| 520 |
-
segment_files = split_dat_into_segments(file_path, debug=debug)
|
| 521 |
|
| 522 |
if not segment_files:
|
| 523 |
-
|
| 524 |
-
return classify_new_ecg(file_path, model, debug=debug)
|
| 525 |
|
| 526 |
-
# Process each segment and collect votes
|
| 527 |
segment_results = []
|
| 528 |
|
| 529 |
-
for
|
| 530 |
-
if debug:
|
| 531 |
-
print(f"\n--- Processing Segment {i+1} ---")
|
| 532 |
-
|
| 533 |
-
# Get file path without extension
|
| 534 |
segment_path = os.path.splitext(segment_file)[0]
|
| 535 |
-
|
| 536 |
-
# Classify this segment
|
| 537 |
-
result = classify_new_ecg(segment_path, model, debug=debug)
|
| 538 |
-
|
| 539 |
-
if debug:
|
| 540 |
-
print(f"Segment {i+1} classification: {result}")
|
| 541 |
-
|
| 542 |
segment_results.append(result)
|
| 543 |
|
| 544 |
-
# Remove temporary segment files
|
| 545 |
try:
|
| 546 |
os.remove(segment_file)
|
| 547 |
-
if debug:
|
| 548 |
-
print(f"Removed temporary segment file: {segment_file}")
|
| 549 |
except:
|
| 550 |
pass
|
| 551 |
|
| 552 |
-
# Count results and use majority voting
|
| 553 |
if segment_results:
|
| 554 |
normal_count = segment_results.count("Normal")
|
| 555 |
abnormal_count = segment_results.count("Abnormal")
|
| 556 |
-
error_count = len(segment_results) - normal_count - abnormal_count
|
| 557 |
|
| 558 |
-
if debug:
|
| 559 |
-
print(f"\n--- Voting Results ---")
|
| 560 |
-
print(f"Normal votes: {normal_count}")
|
| 561 |
-
print(f"Abnormal votes: {abnormal_count}")
|
| 562 |
-
print(f"Errors/Inconclusive: {error_count}")
|
| 563 |
-
|
| 564 |
-
# Decision rules:
|
| 565 |
-
# 1. If any segment is abnormal, classify as abnormal
|
| 566 |
-
# 2. Only classify as normal if majority of segments are normal
|
| 567 |
if abnormal_count > normal_count:
|
| 568 |
final_result = "Abnormal"
|
| 569 |
elif normal_count > abnormal_count:
|
|
@@ -571,41 +444,14 @@ def classify_ecg(file_path, model, is_pdf=False, debug=False):
|
|
| 571 |
else:
|
| 572 |
final_result = "Inconclusive"
|
| 573 |
|
| 574 |
-
if debug:
|
| 575 |
-
print(f"Final decision: {final_result}")
|
| 576 |
-
|
| 577 |
return final_result
|
| 578 |
else:
|
| 579 |
return "Error: No valid segments to classify"
|
| 580 |
|
| 581 |
except Exception as e:
|
| 582 |
error_msg = f"Classification error: {str(e)}"
|
| 583 |
-
if debug:
|
| 584 |
-
print(error_msg)
|
| 585 |
return error_msg
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
model_path = 'voting_classifier.pkl'
|
| 589 |
-
if os.path.exists(model_path):
|
| 590 |
-
voting_loaded = joblib.load(model_path)
|
| 591 |
-
else:
|
| 592 |
-
# Try to find the model in the current or parent directories
|
| 593 |
-
for root, dirs, files in os.walk('.'):
|
| 594 |
-
for file in files:
|
| 595 |
-
if file.endswith('.pkl') and 'voting' in file.lower():
|
| 596 |
-
model_path = os.path.join(root, file)
|
| 597 |
-
voting_loaded = joblib.load(model_path)
|
| 598 |
-
break
|
| 599 |
-
if 'voting_loaded' in locals():
|
| 600 |
-
break
|
| 601 |
-
|
| 602 |
-
if 'voting_loaded' not in locals():
|
| 603 |
-
voting_loaded = None
|
| 604 |
-
except Exception as e:
|
| 605 |
-
voting_loaded = None
|
| 606 |
|
| 607 |
-
|
| 608 |
-
test_pdf_path = "sample.pdf"
|
| 609 |
-
if os.path.exists(test_pdf_path) and voting_loaded is not None:
|
| 610 |
-
result_pdf = classify_ecg(test_pdf_path, voting_loaded, is_pdf=True)
|
| 611 |
-
print(f"Classification result: {result_pdf}")
|
|
|
|
| 1 |
+
import wfdb
|
| 2 |
+
from wfdb import processing
|
| 3 |
+
import numpy as np
|
| 4 |
+
import joblib
|
| 5 |
+
import pywt
|
| 6 |
+
import os
|
| 7 |
+
import cv2
|
| 8 |
+
from pdf2image import convert_from_path
|
| 9 |
import warnings
|
| 10 |
import pickle
|
| 11 |
+
from scipy import signal as sg
|
| 12 |
+
warnings.filterwarnings('ignore')
|
| 13 |
|
| 14 |
+
|
| 15 |
+
def extract_hrv_features(rr_intervals):
    """
    Extract heart rate variability features from RR intervals.

    Args:
        rr_intervals (array-like): RR intervals in seconds

    Returns:
        list: Four HRV features [sdnn, rmssd, pnn50, tri_index].
            All four are 0 when fewer than two intervals are supplied;
            tri_index is 0 when fewer than three are supplied.
    """
    rr_intervals = np.asarray(rr_intervals, dtype=float)
    if len(rr_intervals) < 2:
        return [0, 0, 0, 0]

    sdnn = np.std(rr_intervals)

    # With at least two intervals the successive-difference array is never
    # empty, so no further emptiness checks are needed below.
    diff_rr = np.diff(rr_intervals)
    rmssd = np.sqrt(np.mean(diff_rr ** 2))
    # Percentage of successive differences larger than 0.05 s (50 ms).
    pnn50 = 100 * np.sum(np.abs(diff_rr) > 0.05) / len(diff_rr)

    if len(rr_intervals) > 2:
        bin_width = 1 / 128
        rr_min = rr_intervals.min()
        rr_max = rr_intervals.max()

        # Use an explicit bin count instead of float-stepped np.arange edges:
        # for perfectly regular intervals (rr_max == rr_min) arange yields a
        # single edge, the histogram comes back empty and np.max() raises.
        n_bins = max(1, int(np.ceil((rr_max - rr_min) / bin_width)))
        # Guard against rounding pushing the upper edge just below rr_max,
        # which would silently drop the largest interval from the histogram.
        upper_edge = max(rr_max, rr_min + n_bins * bin_width)
        counts, _ = np.histogram(
            rr_intervals, bins=n_bins, range=(rr_min, upper_edge)
        )

        peak_count = np.max(counts)
        tri_index = len(rr_intervals) / peak_count if peak_count > 0 else 0
    else:
        tri_index = 0

    return [sdnn, rmssd, pnn50, tri_index]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def extract_qrs_features(signal, r_peaks, fs):
    """
    Extract QRS complex features from ECG signal and detected R peaks.

    For every R peak the Q point is taken as the signal minimum in the
    0.1 s window before the peak and the S point as the minimum in the
    0.1 s window starting at the peak; the QRS width is the Q-to-S distance.

    Args:
        signal (array-like): ECG signal
        r_peaks (array-like): R peak sample indices
        fs (int): Sampling frequency in Hz

    Returns:
        list: Three QRS features [qrs_width_mean, qrs_width_std, qrs_amplitude_mean].
            All three are 0 when fewer than two R peaks are supplied.
    """
    # Coerce once so plain lists work as well as numpy arrays.
    signal = np.asarray(signal)
    r_peaks = np.asarray(r_peaks, dtype=int)
    if r_peaks.size < 2:
        return [0, 0, 0]

    # Loop invariants: search half-window in samples and last valid index.
    half_window = int(0.1 * fs)
    last_index = len(signal) - 1

    qrs_width = []
    for r_pos in r_peaks:
        window_before = max(0, r_pos - half_window)
        window_after = min(last_index, r_pos + half_window)

        # Q point: minimum strictly before the R peak (falls back to the
        # window start when the peak sits at the very beginning).
        if r_pos > window_before:
            q_pos = window_before + np.argmin(signal[window_before:r_pos])
        else:
            q_pos = window_before

        # S point: minimum from the R peak up to (not including) window_after.
        if r_pos < window_after:
            s_pos = r_pos + np.argmin(signal[r_pos:window_after])
        else:
            s_pos = r_pos

        # Only count beats where a positive width could be measured.
        if s_pos > q_pos:
            qrs_width.append((s_pos - q_pos) / fs)

    qrs_width_mean = np.mean(qrs_width) if qrs_width else 0
    qrs_width_std = np.std(qrs_width) if qrs_width else 0
    qrs_amplitude_mean = np.mean(signal[r_peaks])

    return [qrs_width_mean, qrs_width_std, qrs_amplitude_mean]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def digitize_ecg_from_pdf(pdf_path, output_file=None):
|
| 86 |
"""
|
| 87 |
Process an ECG PDF file and convert it to a .dat signal file.
|
| 88 |
|
| 89 |
Args:
|
| 90 |
pdf_path (str): Path to the ECG PDF file
|
| 91 |
+
output_file (str, optional): Path to save the output .dat file
|
|
|
|
|
|
|
| 92 |
|
| 93 |
Returns:
|
| 94 |
tuple: (path to the created .dat file, list of paths to segment files)
|
| 95 |
"""
|
| 96 |
+
if output_file is None:
|
| 97 |
+
output_file = 'calibrated_ecg.dat'
|
| 98 |
|
|
|
|
| 99 |
images = convert_from_path(pdf_path)
|
| 100 |
temp_image_path = 'temp_ecg_image.jpg'
|
| 101 |
images[0].save(temp_image_path, 'JPEG')
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
|
| 104 |
height, width = img.shape
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
calibration = {
|
| 107 |
+
'seconds_per_pixel': 2.0 / 197.0,
|
| 108 |
+
'mv_per_pixel': 1.0 / 78.8,
|
| 109 |
}
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
layer1_start = int(height * 35.35 / 100)
|
| 112 |
layer1_end = int(height * 51.76 / 100)
|
| 113 |
layer2_start = int(height * 51.82 / 100)
|
|
|
|
| 115 |
layer3_start = int(height * 69.47 / 100)
|
| 116 |
layer3_end = int(height * 87.06 / 100)
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
layers = [
|
| 119 |
+
img[layer1_start:layer1_end, :],
|
| 120 |
+
img[layer2_start:layer2_end, :],
|
| 121 |
+
img[layer3_start:layer3_end, :]
|
| 122 |
]
|
| 123 |
|
|
|
|
| 124 |
signals = []
|
| 125 |
time_points = []
|
| 126 |
+
layer_duration = 10.0
|
| 127 |
|
| 128 |
for i, layer in enumerate(layers):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
_, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
|
| 130 |
|
|
|
|
| 131 |
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 132 |
+
waveform_contour = max(contours, key=cv2.contourArea)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
|
|
|
| 134 |
sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
|
| 135 |
x_coords = np.array([point[0][0] for point in sorted_contour])
|
| 136 |
y_coords = np.array([point[0][1] for point in sorted_contour])
|
| 137 |
|
|
|
|
| 138 |
isoelectric_line_y = layer.shape[0] * 0.6
|
| 139 |
|
|
|
|
| 140 |
x_min, x_max = np.min(x_coords), np.max(x_coords)
|
| 141 |
time = (x_coords - x_min) / (x_max - x_min) * layer_duration
|
| 142 |
|
|
|
|
| 143 |
signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
|
| 144 |
signal_mv = signal_mv - np.mean(signal_mv)
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
time_points.append(time)
|
| 147 |
signals.append(signal_mv)
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
total_duration = layer_duration * len(layers)
|
| 150 |
+
sampling_frequency = 500
|
| 151 |
num_samples = int(total_duration * sampling_frequency)
|
| 152 |
combined_time = np.linspace(0, total_duration, num_samples)
|
| 153 |
combined_signal = np.zeros(num_samples)
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
for i, (time, signal) in enumerate(zip(time_points, signals)):
|
| 156 |
start_time = i * layer_duration
|
| 157 |
mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
|
| 158 |
relevant_times = combined_time[mask]
|
| 159 |
interpolated_signal = np.interp(relevant_times, start_time + time, signal)
|
| 160 |
combined_signal[mask] = interpolated_signal
|
|
|
|
|
|
|
|
|
|
| 161 |
|
|
|
|
| 162 |
combined_signal = combined_signal - np.mean(combined_signal)
|
| 163 |
signal_peak = np.max(np.abs(combined_signal))
|
| 164 |
+
target_amplitude = 2.0
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
|
| 167 |
scaling_factor = target_amplitude / signal_peak
|
| 168 |
combined_signal = combined_signal * scaling_factor
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
adc_gain = 1000.0
|
|
|
|
| 171 |
int_signal = (combined_signal * adc_gain).astype(np.int16)
|
| 172 |
int_signal.tofile(output_file)
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
if os.path.exists(temp_image_path):
|
| 175 |
os.remove(temp_image_path)
|
| 176 |
+
|
| 177 |
+
segment_files = []
|
| 178 |
+
samples_per_segment = int(layer_duration * sampling_frequency)
|
| 179 |
+
|
| 180 |
+
base_name = os.path.splitext(output_file)[0]
|
| 181 |
+
for i in range(3):
|
| 182 |
+
start_idx = i * samples_per_segment
|
| 183 |
+
end_idx = (i + 1) * samples_per_segment
|
| 184 |
+
segment = combined_signal[start_idx:end_idx]
|
| 185 |
+
|
| 186 |
+
segment_file = f"{base_name}_segment{i+1}.dat"
|
| 187 |
+
(segment * adc_gain).astype(np.int16).tofile(segment_file)
|
| 188 |
+
segment_files.append(segment_file)
|
| 189 |
|
| 190 |
return output_file, segment_files
|
| 191 |
|
| 192 |
+
|
| 193 |
+
def split_dat_into_segments(file_path, segment_duration=10.0):
    """
    Split a DAT file into equal, non-overlapping segments.

    One lead is read through load_dat_signal (column index 1 when the
    recording has more than one lead, otherwise the only column) and cut
    into consecutive windows of ``segment_duration`` seconds. Each window is
    written as raw samples to ``<base>_segment<N>.dat`` with N starting at 1.
    Trailing samples that do not fill a complete window are discarded.

    Args:
        file_path (str): Path to the DAT file, with or without the ``.dat``
            extension
        segment_duration (float): Duration of each segment in seconds

    Returns:
        list: Paths to the segment files (empty when the recording is
            shorter than one segment)

    Raises:
        ValueError: If segment_duration amounts to less than one sample
    """
    signal_all_leads, fs = load_dat_signal(file_path)

    # Prefer Lead II (index 1) whenever it exists, otherwise fall back to
    # the first column.
    lead_index = 1 if signal_all_leads.shape[1] > 1 else 0
    signal = signal_all_leads[:, lead_index]

    samples_per_segment = int(segment_duration * fs)
    if samples_per_segment <= 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(
            f"segment_duration must cover at least one sample at {fs} Hz, "
            f"got {segment_duration!r}"
        )

    num_segments = len(signal) // samples_per_segment

    # Strip only a literal '.dat' suffix, mirroring how load_dat_signal
    # resolves the path; other dots in the file name are part of the name.
    if file_path.endswith('.dat'):
        base_name = file_path[:-len('.dat')]
    else:
        base_name = file_path

    segment_files = []
    for i in range(num_segments):
        start_idx = i * samples_per_segment
        segment = signal[start_idx:start_idx + samples_per_segment]

        segment_file = f"{base_name}_segment{i + 1}.dat"
        # reshape() yields a contiguous single-column copy before writing.
        segment.reshape(-1, 1).tofile(segment_file)
        segment_files.append(segment_file)

    return segment_files
|
| 232 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
+
def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16):
    """
    Load a DAT file containing ECG signal data.

    The file is read as a flat array of ``dtype`` values. When it holds
    exactly ``n_leads * n_samples`` values it is returned with shape
    ``(n_samples, n_leads)``; any other size is returned as a single
    column of shape ``(size, 1)``.

    Args:
        file_path (str): Path to the DAT file, with or without the ``.dat``
            extension
        n_leads (int): Number of leads in the signal
        n_samples (int): Number of samples per lead
        dtype: Data type of the signal

    Returns:
        tuple: (numpy array of signal data, sampling frequency)
    """
    # The sampling frequency is not stored in the file; a fixed value is
    # reported to callers.
    sampling_frequency = 500

    dat_path = file_path if file_path.endswith('.dat') else file_path + '.dat'
    raw = np.fromfile(dat_path, dtype=dtype)

    if raw.size == n_leads * n_samples:
        return raw.reshape(n_samples, n_leads), sampling_frequency

    # Any other size is treated as one lead. The earlier lead-count probing
    # began with a count of 1, which divides every size, so it always
    # produced exactly this single-column layout.
    return raw.reshape(-1, 1), sampling_frequency
|
| 271 |
+
|
| 272 |
|
| 273 |
+
def extract_features_from_signal(signal):
|
|
|
|
| 274 |
"""
|
| 275 |
Extract features from an ECG signal.
|
| 276 |
|
| 277 |
Args:
|
| 278 |
signal (numpy.ndarray): ECG signal
|
|
|
|
| 279 |
|
| 280 |
Returns:
|
| 281 |
+
list: Basic features extracted from the signal (32 features)
|
| 282 |
"""
|
|
|
|
|
|
|
|
|
|
| 283 |
features = []
|
| 284 |
features.append(np.mean(signal))
|
| 285 |
features.append(np.std(signal))
|
|
|
|
| 290 |
features.append(np.percentile(signal, 75))
|
| 291 |
features.append(np.mean(np.diff(signal)))
|
| 292 |
|
|
|
|
|
|
|
|
|
|
| 293 |
coeffs = pywt.wavedec(signal, 'db4', level=5)
|
| 294 |
+
for coeff in coeffs:
|
| 295 |
features.append(np.mean(coeff))
|
| 296 |
features.append(np.std(coeff))
|
| 297 |
features.append(np.min(coeff))
|
| 298 |
features.append(np.max(coeff))
|
|
|
|
|
|
|
|
|
|
| 299 |
|
|
|
|
|
|
|
|
|
|
| 300 |
return features
|
| 301 |
|
| 302 |
+
|
| 303 |
+
def _spectral_hrv_features(rr_intervals):
    """Return [LF power, HF power, LF/HF ratio, normalized LF] for RR intervals.

    The RR series is resampled at 4 Hz and its Welch power spectrum is
    integrated over the 0.04-0.15 Hz (LF) and 0.15-0.4 Hz (HF) bands.
    All four values are 0 when fewer than four intervals are available or
    when the estimate fails.
    """
    if len(rr_intervals) < 4:
        return [0, 0, 0, 0]

    try:
        rr_times = np.insert(np.cumsum(rr_intervals), 0, 0)

        fs_interp = 4.0
        t_interp = np.arange(0, rr_times[-1], 1 / fs_interp)
        rr_interp = np.interp(t_interp, rr_times[:-1], rr_intervals)

        freq, psd = sg.welch(rr_interp, fs=fs_interp, nperseg=min(256, len(rr_interp)))

        lf_mask = (freq >= 0.04) & (freq < 0.15)
        hf_mask = (freq >= 0.15) & (freq < 0.4)

        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
        # resolve whichever exists so a future removal does not silently
        # zero these features through the except clause below.
        integrate = getattr(np, "trapezoid", None) or np.trapz

        lf_power = integrate(psd[lf_mask], freq[lf_mask]) if np.any(lf_mask) else 0
        hf_power = integrate(psd[hf_mask], freq[hf_mask]) if np.any(hf_mask) else 0

        lf_hf_ratio = lf_power / hf_power if hf_power > 0 else 0
        total_power = lf_power + hf_power
        normalized_lf = lf_power / total_power if total_power > 0 else 0
    except Exception:
        # Deliberate best effort: a failed spectral estimate degrades to
        # zeros rather than aborting the classification.
        return [0, 0, 0, 0]

    return [lf_power, hf_power, lf_hf_ratio, normalized_lf]


def classify_new_ecg(file_path, model):
    """
    Classify a new ECG file.

    Loads the recording, selects one lead (column index 1 when several
    leads are present, otherwise the only column), z-score normalizes it,
    detects R peaks and assembles a 45-value feature vector that is passed
    to ``model.predict``.

    Args:
        file_path (str): Path to the ECG file (without extension)
        model: The trained model for classification

    Returns:
        str: "Abnormal" if the model predicts 1, otherwise "Normal"

    Raises:
        ValueError: If the selected lead is empty or has zero variance
    """
    n_features = 45  # length of the feature vector passed to model.predict

    signal_all_leads, fs = load_dat_signal(file_path)

    # Prefer Lead II (index 1) whenever it exists, otherwise the first column.
    lead_index = 1 if signal_all_leads.shape[1] > 1 else 0
    signal = signal_all_leads[:, lead_index]

    signal_std = np.std(signal) if len(signal) > 0 else 0
    if not signal_std > 0:
        # Dividing by a zero (or NaN) deviation would fill the signal with
        # NaN/inf and feed meaningless features to the model; fail loudly.
        raise ValueError("ECG signal is empty or flat and cannot be normalized")
    signal = (signal - np.mean(signal)) / signal_std

    try:
        r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
    except Exception:
        # Peak detection is best effort; fall back to the basic features.
        r_peaks = np.array([])

    record_features = list(extract_features_from_signal(signal))

    if len(r_peaks) >= 2:
        peak_gaps = np.diff(r_peaks)  # consecutive R-peak distances, in samples
        rr_intervals = peak_gaps / fs

        record_features.extend([
            len(r_peaks),
            np.mean(rr_intervals),
            np.std(rr_intervals),
            np.median(rr_intervals),
            # NOTE(review): these two slots were named "qrs_durations" but
            # are computed from peak-to-peak gaps; kept unchanged because the
            # model presumably expects this layout - confirm against training.
            np.mean(peak_gaps) / fs,
            np.std(peak_gaps) / fs,
        ])

        record_features.extend(extract_hrv_features(rr_intervals))
        record_features.extend(extract_qrs_features(signal, r_peaks, fs))
        record_features.extend(_spectral_hrv_features(rr_intervals))

    # Zero-pad or truncate to the fixed length.
    if len(record_features) < n_features:
        record_features.extend([0] * (n_features - len(record_features)))
    elif len(record_features) > n_features:
        record_features = record_features[:n_features]

    prediction = model.predict([record_features])[0]
    return "Abnormal" if prediction == 1 else "Normal"
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
+
def classify_ecg(file_path, model, is_pdf=False):
|
|
|
|
| 395 |
"""
|
| 396 |
Wrapper function that handles both PDF and DAT ECG files with segment voting.
|
| 397 |
|
|
|
|
| 399 |
file_path (str): Path to the ECG file (.pdf or without extension for .dat)
|
| 400 |
model: The trained model for classification
|
| 401 |
is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)
|
|
|
|
| 402 |
|
| 403 |
Returns:
|
| 404 |
str: Classification result ("Normal", "Abnormal", or error message)
|
| 405 |
"""
|
| 406 |
try:
|
|
|
|
| 407 |
if model is None:
|
| 408 |
return "Error: Model not loaded. Please check model compatibility."
|
| 409 |
|
| 410 |
if is_pdf:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 412 |
output_dat = f"{base_name}_digitized.dat"
|
| 413 |
|
|
|
|
| 414 |
dat_path, segment_files = digitize_ecg_from_pdf(
|
| 415 |
pdf_path=file_path,
|
| 416 |
+
output_file=output_dat
|
|
|
|
| 417 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
else:
|
| 419 |
+
segment_files = split_dat_into_segments(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
if not segment_files:
|
| 422 |
+
return classify_new_ecg(file_path, model)
|
|
|
|
| 423 |
|
|
|
|
| 424 |
segment_results = []
|
| 425 |
|
| 426 |
+
for segment_file in segment_files:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
segment_path = os.path.splitext(segment_file)[0]
|
| 428 |
+
result = classify_new_ecg(segment_path, model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
segment_results.append(result)
|
| 430 |
|
|
|
|
| 431 |
try:
|
| 432 |
os.remove(segment_file)
|
|
|
|
|
|
|
| 433 |
except:
|
| 434 |
pass
|
| 435 |
|
|
|
|
| 436 |
if segment_results:
|
| 437 |
normal_count = segment_results.count("Normal")
|
| 438 |
abnormal_count = segment_results.count("Abnormal")
|
|
|
|
| 439 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
if abnormal_count > normal_count:
|
| 441 |
final_result = "Abnormal"
|
| 442 |
elif normal_count > abnormal_count:
|
|
|
|
| 444 |
else:
|
| 445 |
final_result = "Inconclusive"
|
| 446 |
|
|
|
|
|
|
|
|
|
|
| 447 |
return final_result
|
| 448 |
else:
|
| 449 |
return "Error: No valid segments to classify"
|
| 450 |
|
| 451 |
except Exception as e:
|
| 452 |
error_msg = f"Classification error: {str(e)}"
|
|
|
|
|
|
|
| 453 |
return error_msg
|
| 454 |
+
|
| 455 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
|
| 457 |
+
|
|
|
|
|
|
|
|
|
|
|
|
ECG/ECG_MultiClass.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
"""
|
| 2 |
-
ECG Analysis Pipeline: From PDF to
|
| 3 |
-
-------------------------------------------
|
| 4 |
This module provides functions to:
|
| 5 |
1. Digitize ECG from PDF files
|
| 6 |
2. Process the digitized ECG signal
|
| 7 |
-
3.
|
| 8 |
"""
|
| 9 |
|
| 10 |
import cv2
|
|
@@ -13,50 +13,42 @@ import os
|
|
| 13 |
import tensorflow as tf
|
| 14 |
import pickle
|
| 15 |
from scipy.interpolate import interp1d
|
| 16 |
-
from collections import Counter
|
| 17 |
from pdf2image import convert_from_path
|
| 18 |
-
import matplotlib.pyplot as plt # Added for visualization
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
"""
|
| 22 |
Process an ECG PDF file and convert it to a .dat signal file.
|
| 23 |
|
| 24 |
Args:
|
| 25 |
pdf_path (str): Path to the ECG PDF file
|
| 26 |
-
output_file (str): Path to save the output .dat file
|
| 27 |
-
debug (bool): Whether to print debug information
|
| 28 |
|
| 29 |
Returns:
|
| 30 |
-
|
| 31 |
"""
|
| 32 |
-
if
|
| 33 |
-
|
| 34 |
|
| 35 |
-
# Convert PDF to image
|
| 36 |
images = convert_from_path(pdf_path)
|
| 37 |
temp_image_path = 'temp_ecg_image.jpg'
|
| 38 |
images[0].save(temp_image_path, 'JPEG')
|
| 39 |
|
| 40 |
-
if debug:
|
| 41 |
-
print(f"Converted PDF to image: {temp_image_path}")
|
| 42 |
-
|
| 43 |
-
# Load the image
|
| 44 |
img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
|
| 45 |
height, width = img.shape
|
| 46 |
|
| 47 |
-
if debug:
|
| 48 |
-
print(f"Image dimensions: {width}x{height}")
|
| 49 |
-
|
| 50 |
-
# Fixed calibration parameters
|
| 51 |
calibration = {
|
| 52 |
-
'seconds_per_pixel': 2.0 / 197.0,
|
| 53 |
-
'mv_per_pixel': 1.0 / 78.8,
|
| 54 |
}
|
| 55 |
|
| 56 |
-
if debug:
|
| 57 |
-
print(f"Calibration parameters: {calibration}")
|
| 58 |
-
|
| 59 |
-
# Calculate layer boundaries using percentages
|
| 60 |
layer1_start = int(height * 35.35 / 100)
|
| 61 |
layer1_end = int(height * 51.76 / 100)
|
| 62 |
layer2_start = int(height * 51.82 / 100)
|
|
@@ -64,239 +56,126 @@ def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=Fals
|
|
| 64 |
layer3_start = int(height * 69.47 / 100)
|
| 65 |
layer3_end = int(height * 87.06 / 100)
|
| 66 |
|
| 67 |
-
if debug:
|
| 68 |
-
print(f"Layer 1 boundaries: {layer1_start}-{layer1_end}")
|
| 69 |
-
print(f"Layer 2 boundaries: {layer2_start}-{layer2_end}")
|
| 70 |
-
print(f"Layer 3 boundaries: {layer3_start}-{layer3_end}")
|
| 71 |
-
|
| 72 |
-
# Crop each layer
|
| 73 |
layers = [
|
| 74 |
-
img[layer1_start:layer1_end, :],
|
| 75 |
-
img[layer2_start:layer2_end, :],
|
| 76 |
-
img[layer3_start:layer3_end, :]
|
| 77 |
]
|
| 78 |
|
| 79 |
-
# Process each layer to extract waveform contours
|
| 80 |
signals = []
|
| 81 |
time_points = []
|
| 82 |
-
layer_duration = 10.0
|
| 83 |
|
| 84 |
for i, layer in enumerate(layers):
|
| 85 |
-
if debug:
|
| 86 |
-
print(f"Processing layer {i+1}...")
|
| 87 |
-
|
| 88 |
-
# Binary thresholding
|
| 89 |
_, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
|
| 90 |
|
| 91 |
-
# Detect contours
|
| 92 |
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 93 |
-
waveform_contour = max(contours, key=cv2.contourArea)
|
| 94 |
-
|
| 95 |
-
if debug:
|
| 96 |
-
print(f" - Found {len(contours)} contours")
|
| 97 |
-
print(f" - Selected contour with {len(waveform_contour)} points")
|
| 98 |
|
| 99 |
-
# Sort contour points and extract coordinates
|
| 100 |
sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
|
| 101 |
x_coords = np.array([point[0][0] for point in sorted_contour])
|
| 102 |
y_coords = np.array([point[0][1] for point in sorted_contour])
|
| 103 |
|
| 104 |
-
# Calculate isoelectric line (one-third from the bottom)
|
| 105 |
isoelectric_line_y = layer.shape[0] * 0.6
|
| 106 |
|
| 107 |
-
# Convert to time using fixed layer duration
|
| 108 |
x_min, x_max = np.min(x_coords), np.max(x_coords)
|
| 109 |
time = (x_coords - x_min) / (x_max - x_min) * layer_duration
|
| 110 |
|
| 111 |
-
# Calculate signal in millivolts and apply baseline correction
|
| 112 |
signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
|
| 113 |
signal_mv = signal_mv - np.mean(signal_mv)
|
| 114 |
|
| 115 |
-
if debug:
|
| 116 |
-
print(f" - Layer {i+1} signal range: {np.min(signal_mv):.2f} mV to {np.max(signal_mv):.2f} mV")
|
| 117 |
-
|
| 118 |
-
# Store the time points and calibrated signal
|
| 119 |
time_points.append(time)
|
| 120 |
signals.append(signal_mv)
|
| 121 |
|
| 122 |
-
# Combine signals with proper time alignment
|
| 123 |
total_duration = layer_duration * len(layers)
|
| 124 |
-
sampling_frequency = 500
|
| 125 |
num_samples = int(total_duration * sampling_frequency)
|
| 126 |
combined_time = np.linspace(0, total_duration, num_samples)
|
| 127 |
combined_signal = np.zeros(num_samples)
|
| 128 |
|
| 129 |
-
if debug:
|
| 130 |
-
print(f"Combining signals with {sampling_frequency} Hz sampling rate, total duration: {total_duration}s")
|
| 131 |
-
|
| 132 |
-
# Place each lead at the correct time position
|
| 133 |
for i, (time, signal) in enumerate(zip(time_points, signals)):
|
| 134 |
start_time = i * layer_duration
|
| 135 |
mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
|
| 136 |
relevant_times = combined_time[mask]
|
| 137 |
interpolated_signal = np.interp(relevant_times, start_time + time, signal)
|
| 138 |
combined_signal[mask] = interpolated_signal
|
| 139 |
-
|
| 140 |
-
if debug:
|
| 141 |
-
print(f" - Added layer {i+1} signal from {start_time}s to {start_time + layer_duration}s")
|
| 142 |
|
| 143 |
-
# Baseline correction and amplitude scaling
|
| 144 |
combined_signal = combined_signal - np.mean(combined_signal)
|
| 145 |
signal_peak = np.max(np.abs(combined_signal))
|
| 146 |
-
target_amplitude = 2.0
|
| 147 |
-
|
| 148 |
-
if debug:
|
| 149 |
-
print(f"Signal peak before scaling: {signal_peak:.2f} mV")
|
| 150 |
|
| 151 |
if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
|
| 152 |
scaling_factor = target_amplitude / signal_peak
|
| 153 |
combined_signal = combined_signal * scaling_factor
|
| 154 |
-
if debug:
|
| 155 |
-
print(f"Applied scaling factor: {scaling_factor:.2f}")
|
| 156 |
-
print(f"Signal peak after scaling: {np.max(np.abs(combined_signal)):.2f} mV")
|
| 157 |
|
| 158 |
-
|
| 159 |
-
adc_gain = 1000.0 # Standard gain: 1000 units per mV
|
| 160 |
int_signal = (combined_signal * adc_gain).astype(np.int16)
|
| 161 |
int_signal.tofile(output_file)
|
| 162 |
|
| 163 |
-
if debug:
|
| 164 |
-
print(f"Saved signal to {output_file} with {len(int_signal)} samples")
|
| 165 |
-
print(f"Integer signal range: {np.min(int_signal)} to {np.max(int_signal)}")
|
| 166 |
-
|
| 167 |
-
# Clean up temporary files
|
| 168 |
if os.path.exists(temp_image_path):
|
| 169 |
os.remove(temp_image_path)
|
| 170 |
-
if debug:
|
| 171 |
-
print(f"Removed temporary image: {temp_image_path}")
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
def visualize_ecg_signal(signal, sampling_rate=500, title="Digitized ECG Signal"):
|
| 176 |
-
"""
|
| 177 |
-
Visualize an ECG signal with proper time axis.
|
| 178 |
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
# Create figure with appropriate size
|
| 192 |
-
plt.figure(figsize=(15, 5))
|
| 193 |
-
plt.plot(time, signal)
|
| 194 |
-
plt.title(title)
|
| 195 |
-
plt.xlabel('Time (seconds)')
|
| 196 |
-
plt.ylabel('Amplitude (mV)')
|
| 197 |
-
plt.grid(True)
|
| 198 |
-
|
| 199 |
-
# Add 1mV scale bar
|
| 200 |
-
plt.plot([1, 1], [-0.5, 0.5], 'r-', linewidth=2)
|
| 201 |
-
plt.text(1.1, 0, '1mV', va='center')
|
| 202 |
-
|
| 203 |
-
# Add time scale bar (1 second)
|
| 204 |
-
y_min = np.min(signal)
|
| 205 |
-
plt.plot([1, 2], [y_min, y_min], 'r-', linewidth=2)
|
| 206 |
-
plt.text(1.5, y_min - 0.1, '1s', ha='center')
|
| 207 |
-
|
| 208 |
-
plt.tight_layout()
|
| 209 |
-
plt.show()
|
| 210 |
|
| 211 |
-
|
|
|
|
| 212 |
"""
|
| 213 |
-
Read a
|
| 214 |
|
| 215 |
Parameters:
|
| 216 |
-----------
|
| 217 |
dat_file_path : str
|
| 218 |
Path to the .dat file (with or without .dat extension)
|
| 219 |
-
sampling_rate : int
|
| 220 |
-
Sampling rate in Hz (default 500Hz)
|
| 221 |
-
data_format : str
|
| 222 |
-
Data format of the binary file: '16' for 16-bit integers, '32' for 32-bit floats
|
| 223 |
-
scale_factor : float
|
| 224 |
-
Scale factor to convert units (0.001 for converting Β΅V to mV)
|
| 225 |
|
| 226 |
Returns:
|
| 227 |
--------
|
| 228 |
numpy.ndarray
|
| 229 |
-
ECG signal data
|
| 230 |
"""
|
| 231 |
-
# Ensure the path ends with .dat
|
| 232 |
if not dat_file_path.endswith('.dat'):
|
| 233 |
dat_file_path += '.dat'
|
| 234 |
|
| 235 |
-
# Expected samples for full 30 seconds
|
| 236 |
-
expected_samples = sampling_rate * 30
|
| 237 |
-
|
| 238 |
-
# Read the binary data
|
| 239 |
try:
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
data = np.fromfile(dat_file_path, dtype=np.int16)
|
| 243 |
-
elif data_format == '32':
|
| 244 |
-
# 32-bit floating point (less common)
|
| 245 |
-
data = np.fromfile(dat_file_path, dtype=np.float32)
|
| 246 |
-
else:
|
| 247 |
-
raise ValueError(f"Unsupported data format: {data_format}")
|
| 248 |
-
|
| 249 |
-
# Apply scaling to convert Β΅V to mV
|
| 250 |
-
signal = data * scale_factor
|
| 251 |
-
|
| 252 |
-
# Handle if signal is not exactly 30 seconds
|
| 253 |
-
if len(signal) < expected_samples:
|
| 254 |
-
# Pad with zeros if too short
|
| 255 |
-
padded_signal = np.zeros(expected_samples)
|
| 256 |
-
padded_signal[:len(signal)] = signal
|
| 257 |
-
signal = padded_signal
|
| 258 |
-
elif len(signal) > expected_samples:
|
| 259 |
-
# Truncate if too long
|
| 260 |
-
signal = signal[:expected_samples]
|
| 261 |
-
|
| 262 |
return signal
|
| 263 |
|
| 264 |
except Exception as e:
|
| 265 |
raise
|
| 266 |
|
| 267 |
-
def segment_signal(signal
|
| 268 |
"""
|
| 269 |
-
Segment a
|
| 270 |
|
| 271 |
Parameters:
|
| 272 |
-----------
|
| 273 |
signal : numpy.ndarray
|
| 274 |
The full signal to segment
|
| 275 |
-
sampling_rate : int
|
| 276 |
-
Sampling rate in Hz
|
| 277 |
|
| 278 |
Returns:
|
| 279 |
--------
|
| 280 |
list
|
| 281 |
-
List of
|
| 282 |
"""
|
| 283 |
-
|
| 284 |
-
segment_samples = sampling_rate * 10
|
| 285 |
-
|
| 286 |
-
# Expected samples for full 30 seconds
|
| 287 |
-
expected_samples = sampling_rate * 30
|
| 288 |
-
|
| 289 |
-
# Ensure the signal is 30 seconds long
|
| 290 |
-
if len(signal) != expected_samples:
|
| 291 |
-
# Resample to 30 seconds
|
| 292 |
-
x = np.linspace(0, 1, len(signal))
|
| 293 |
-
x_new = np.linspace(0, 1, expected_samples)
|
| 294 |
-
f = interp1d(x, signal, kind='linear', bounds_error=False, fill_value="extrapolate")
|
| 295 |
-
signal = f(x_new)
|
| 296 |
|
| 297 |
-
# Split the signal into three 10-second segments
|
| 298 |
segments = []
|
| 299 |
-
|
|
|
|
|
|
|
| 300 |
start_idx = i * segment_samples
|
| 301 |
end_idx = (i + 1) * segment_samples
|
| 302 |
segment = signal[start_idx:end_idx]
|
|
@@ -304,7 +183,7 @@ def segment_signal(signal, sampling_rate=500):
|
|
| 304 |
|
| 305 |
return segments
|
| 306 |
|
| 307 |
-
def process_segment(segment
|
| 308 |
"""
|
| 309 |
Process a segment of ECG data to ensure it's properly formatted for the model
|
| 310 |
|
|
@@ -312,243 +191,123 @@ def process_segment(segment, sampling_rate=500):
|
|
| 312 |
-----------
|
| 313 |
segment : numpy.ndarray
|
| 314 |
Raw ECG segment
|
| 315 |
-
sampling_rate : int
|
| 316 |
-
Sampling rate of the ECG
|
| 317 |
|
| 318 |
Returns:
|
| 319 |
--------
|
| 320 |
numpy.ndarray
|
| 321 |
Processed segment ready for model input
|
| 322 |
"""
|
| 323 |
-
|
| 324 |
-
if len(segment) != 5000:
|
| 325 |
x = np.linspace(0, 1, len(segment))
|
| 326 |
-
x_new = np.linspace(0, 1,
|
| 327 |
f = interp1d(x, segment, kind='linear', bounds_error=False, fill_value="extrapolate")
|
| 328 |
segment = f(x_new)
|
| 329 |
|
|
|
|
|
|
|
| 330 |
return segment
|
| 331 |
|
| 332 |
-
|
|
|
|
| 333 |
"""
|
| 334 |
-
Process
|
| 335 |
-
make predictions on each segment, and return the class with highest average probability.
|
| 336 |
|
| 337 |
Parameters:
|
| 338 |
-----------
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
mlb_path : str, optional
|
| 344 |
-
Path to the saved MultiLabelBinarizer pickle file for label decoding
|
| 345 |
-
sampling_rate : int
|
| 346 |
-
Sampling rate in Hz (default 500Hz)
|
| 347 |
-
scale_factor : float
|
| 348 |
-
Scale factor to convert units (0.001 for converting Β΅V to mV)
|
| 349 |
-
debug : bool
|
| 350 |
-
Whether to print debug information
|
| 351 |
|
| 352 |
Returns:
|
| 353 |
--------
|
| 354 |
dict
|
| 355 |
-
Dictionary containing
|
| 356 |
"""
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
dat_file_path,
|
| 364 |
-
sampling_rate=sampling_rate,
|
| 365 |
-
scale_factor=scale_factor
|
| 366 |
-
)
|
| 367 |
-
|
| 368 |
-
if debug:
|
| 369 |
-
print(f"Signal loaded: {len(full_signal)} samples, range: {np.min(full_signal):.2f} to {np.max(full_signal):.2f} mV")
|
| 370 |
-
|
| 371 |
-
# Step 2: Split into three 10-second segments
|
| 372 |
-
segments = segment_signal(full_signal, sampling_rate)
|
| 373 |
-
|
| 374 |
-
if debug:
|
| 375 |
-
print(f"Split into {len(segments)} segments of {len(segments[0])} samples each")
|
| 376 |
-
|
| 377 |
-
# Step 3: Load the model (load once to improve performance)
|
| 378 |
-
if debug:
|
| 379 |
-
print(f"Loading model from {model_path}")
|
| 380 |
-
|
| 381 |
-
model = tf.keras.models.load_model(model_path)
|
| 382 |
-
|
| 383 |
-
# Load MLB if provided
|
| 384 |
-
mlb = None
|
| 385 |
-
if mlb_path and os.path.exists(mlb_path):
|
| 386 |
-
if debug:
|
| 387 |
-
print(f"Loading label binarizer from {mlb_path}")
|
| 388 |
-
with open(mlb_path, 'rb') as f:
|
| 389 |
-
mlb = pickle.load(f)
|
| 390 |
-
|
| 391 |
-
# Step 4: Process each segment and collect predictions
|
| 392 |
-
segment_results = []
|
| 393 |
-
all_predictions = []
|
| 394 |
-
|
| 395 |
-
for i, segment in enumerate(segments):
|
| 396 |
-
if debug:
|
| 397 |
-
print(f"Processing segment {i+1}...")
|
| 398 |
-
|
| 399 |
-
# Process the segment to ensure it's properly formatted
|
| 400 |
-
processed_segment = process_segment(segment)
|
| 401 |
-
|
| 402 |
-
# Reshape for model input (batch, time, channels)
|
| 403 |
-
X = processed_segment.reshape(1, 5000, 1)
|
| 404 |
-
|
| 405 |
-
# Make predictions
|
| 406 |
-
predictions = model.predict(X, verbose=0)
|
| 407 |
-
all_predictions.append(predictions[0])
|
| 408 |
-
|
| 409 |
-
# Process segment results
|
| 410 |
-
segment_result = {"raw_predictions": predictions[0].tolist()}
|
| 411 |
-
|
| 412 |
-
# Decode labels if MLB is provided
|
| 413 |
-
if mlb is not None:
|
| 414 |
-
# Add class probabilities
|
| 415 |
-
class_probs = {}
|
| 416 |
-
for j, class_name in enumerate(mlb.classes_):
|
| 417 |
-
class_probs[class_name] = float(predictions[0][j])
|
| 418 |
-
|
| 419 |
-
segment_result["class_probabilities"] = class_probs
|
| 420 |
-
|
| 421 |
-
segment_results.append(segment_result)
|
| 422 |
|
| 423 |
-
|
| 424 |
-
final_result = {"segment_results": segment_results}
|
| 425 |
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
# Find the class with highest average probability
|
| 438 |
-
top_class = max(final_class_probs.items(), key=lambda x: x[1])
|
| 439 |
-
top_class_name = top_class[0]
|
| 440 |
-
|
| 441 |
-
final_result["final_class_probabilities"] = final_class_probs
|
| 442 |
-
final_result["top_class"] = top_class_name
|
| 443 |
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
return final_result
|
| 448 |
-
|
| 449 |
-
except Exception as e:
|
| 450 |
-
if debug:
|
| 451 |
-
print(f"Error in predict_with_voting: {str(e)}")
|
| 452 |
-
return {"error": str(e)}
|
| 453 |
|
| 454 |
-
def analyze_ecg_pdf(pdf_path, model_path,
|
| 455 |
"""
|
| 456 |
Complete ECG analysis pipeline: digitizes a PDF ECG, analyzes it with the model,
|
| 457 |
-
and returns the
|
| 458 |
|
| 459 |
Args:
|
| 460 |
pdf_path (str): Path to the ECG PDF file
|
| 461 |
model_path (str): Path to the model (.h5) file
|
| 462 |
-
mlb_path (str, optional): Path to the MultiLabelBinarizer file
|
| 463 |
-
temp_dat_file (str, optional): Path to save the temporary digitized file
|
| 464 |
cleanup (bool, optional): Whether to remove temporary files after processing
|
| 465 |
-
debug (bool, optional): Whether to print debug information
|
| 466 |
-
visualize (bool, optional): Whether to visualize the digitized signal
|
| 467 |
|
| 468 |
Returns:
|
| 469 |
dict: {
|
| 470 |
-
"
|
| 471 |
-
"probability": float, # Probability of top
|
| 472 |
-
"all_probabilities": dict, # All
|
| 473 |
"digitized_file": str # Path to digitized file (if cleanup=False)
|
| 474 |
}
|
| 475 |
"""
|
| 476 |
-
# Silence TensorFlow warnings
|
| 477 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
| 478 |
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
signal = read_lead_i_long_dat_file(dat_file_path, scale_factor=0.001)
|
| 488 |
-
visualize_ecg_signal(signal, title=f"Digitized ECG from {os.path.basename(pdf_path)}")
|
| 489 |
-
|
| 490 |
-
# 2. Process DAT file with model
|
| 491 |
-
if debug:
|
| 492 |
-
print("Processing digitized signal with model...")
|
| 493 |
-
|
| 494 |
-
results = predict_with_voting(
|
| 495 |
-
dat_file_path,
|
| 496 |
-
model_path,
|
| 497 |
-
mlb_path,
|
| 498 |
-
scale_factor=0.001, # Convert microvolts to millivolts
|
| 499 |
-
debug=debug
|
| 500 |
-
)
|
| 501 |
-
|
| 502 |
-
# 3. Extract top diagnosis (highest probability)
|
| 503 |
-
top_diagnosis = {
|
| 504 |
-
"diagnosis": None,
|
| 505 |
-
"probability": 0.0,
|
| 506 |
-
"all_probabilities": {},
|
| 507 |
-
"digitized_file": dat_file_path
|
| 508 |
-
}
|
| 509 |
-
|
| 510 |
-
# If we have class probabilities, find the highest one
|
| 511 |
-
if "final_class_probabilities" in results:
|
| 512 |
-
probs = results["final_class_probabilities"]
|
| 513 |
-
top_diagnosis["all_probabilities"] = probs
|
| 514 |
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
|
| 530 |
-
return top_diagnosis
|
| 531 |
|
| 532 |
-
# Example usage
|
| 533 |
-
if __name__ == "__main__":
|
| 534 |
-
# Path configuration
|
| 535 |
-
sample_pdf = 'samplebayez.pdf'
|
| 536 |
-
model_path = 'deep-multiclass.h5' # Update with actual path
|
| 537 |
-
mlb_path = 'deep-multiclass.pkl' # Update with actual path
|
| 538 |
-
|
| 539 |
-
# Analyze ECG with debug output and visualization
|
| 540 |
-
result = analyze_ecg_pdf(
|
| 541 |
-
sample_pdf,
|
| 542 |
-
model_path,
|
| 543 |
-
mlb_path,
|
| 544 |
-
cleanup=False, # Keep the digitized file
|
| 545 |
-
debug=False, # Print debug information
|
| 546 |
-
visualize=False # Visualize the digitized signal
|
| 547 |
-
)
|
| 548 |
-
|
| 549 |
-
# Display result
|
| 550 |
-
if result["diagnosis"]:
|
| 551 |
-
print(f"Diagnosis: {result['diagnosis']} ")
|
| 552 |
|
| 553 |
-
|
| 554 |
-
print("No clear diagnosis found")
|
|
|
|
| 1 |
"""
|
| 2 |
+
ECG Analysis Pipeline: From PDF to Arrhythmia Classification
|
| 3 |
+
-----------------------------------------------------------
|
| 4 |
This module provides functions to:
|
| 5 |
1. Digitize ECG from PDF files
|
| 6 |
2. Process the digitized ECG signal
|
| 7 |
+
3. Classify arrhythmias using a trained CNN model
|
| 8 |
"""
|
| 9 |
|
| 10 |
import cv2
|
|
|
|
| 13 |
import tensorflow as tf
|
| 14 |
import pickle
|
| 15 |
from scipy.interpolate import interp1d
|
|
|
|
| 16 |
from pdf2image import convert_from_path
|
|
|
|
| 17 |
|
| 18 |
+
ARRHYTHMIA_CLASSES = ["Conduction Abnormalities", "Atrial Arrhythmias", "Tachyarrhythmias", "Normal"]
|
| 19 |
+
SAMPLING_RATE = 500
|
| 20 |
+
SEGMENT_DURATION = 10.0
|
| 21 |
+
TARGET_SEGMENT_LENGTH = 5000
|
| 22 |
+
DEFAULT_OUTPUT_FILE = 'calibrated_ecg.dat'
|
| 23 |
+
DAT_SCALE_FACTOR = 0.001
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def digitize_ecg_from_pdf(pdf_path, output_file=None):
|
| 27 |
"""
|
| 28 |
Process an ECG PDF file and convert it to a .dat signal file.
|
| 29 |
|
| 30 |
Args:
|
| 31 |
pdf_path (str): Path to the ECG PDF file
|
| 32 |
+
output_file (str, optional): Path to save the output .dat file
|
|
|
|
| 33 |
|
| 34 |
Returns:
|
| 35 |
+
tuple: (path to the created .dat file, list of paths to segment files)
|
| 36 |
"""
|
| 37 |
+
if output_file is None:
|
| 38 |
+
output_file = DEFAULT_OUTPUT_FILE
|
| 39 |
|
|
|
|
| 40 |
images = convert_from_path(pdf_path)
|
| 41 |
temp_image_path = 'temp_ecg_image.jpg'
|
| 42 |
images[0].save(temp_image_path, 'JPEG')
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
|
| 45 |
height, width = img.shape
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
calibration = {
|
| 48 |
+
'seconds_per_pixel': 2.0 / 197.0,
|
| 49 |
+
'mv_per_pixel': 1.0 / 78.8,
|
| 50 |
}
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
layer1_start = int(height * 35.35 / 100)
|
| 53 |
layer1_end = int(height * 51.76 / 100)
|
| 54 |
layer2_start = int(height * 51.82 / 100)
|
|
|
|
| 56 |
layer3_start = int(height * 69.47 / 100)
|
| 57 |
layer3_end = int(height * 87.06 / 100)
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
layers = [
|
| 60 |
+
img[layer1_start:layer1_end, :],
|
| 61 |
+
img[layer2_start:layer2_end, :],
|
| 62 |
+
img[layer3_start:layer3_end, :]
|
| 63 |
]
|
| 64 |
|
|
|
|
| 65 |
signals = []
|
| 66 |
time_points = []
|
| 67 |
+
layer_duration = 10.0
|
| 68 |
|
| 69 |
for i, layer in enumerate(layers):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
_, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
|
| 71 |
|
|
|
|
| 72 |
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 73 |
+
waveform_contour = max(contours, key=cv2.contourArea)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
|
|
|
| 75 |
sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
|
| 76 |
x_coords = np.array([point[0][0] for point in sorted_contour])
|
| 77 |
y_coords = np.array([point[0][1] for point in sorted_contour])
|
| 78 |
|
|
|
|
| 79 |
isoelectric_line_y = layer.shape[0] * 0.6
|
| 80 |
|
|
|
|
| 81 |
x_min, x_max = np.min(x_coords), np.max(x_coords)
|
| 82 |
time = (x_coords - x_min) / (x_max - x_min) * layer_duration
|
| 83 |
|
|
|
|
| 84 |
signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
|
| 85 |
signal_mv = signal_mv - np.mean(signal_mv)
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
time_points.append(time)
|
| 88 |
signals.append(signal_mv)
|
| 89 |
|
|
|
|
| 90 |
total_duration = layer_duration * len(layers)
|
| 91 |
+
sampling_frequency = 500
|
| 92 |
num_samples = int(total_duration * sampling_frequency)
|
| 93 |
combined_time = np.linspace(0, total_duration, num_samples)
|
| 94 |
combined_signal = np.zeros(num_samples)
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
for i, (time, signal) in enumerate(zip(time_points, signals)):
|
| 97 |
start_time = i * layer_duration
|
| 98 |
mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
|
| 99 |
relevant_times = combined_time[mask]
|
| 100 |
interpolated_signal = np.interp(relevant_times, start_time + time, signal)
|
| 101 |
combined_signal[mask] = interpolated_signal
|
|
|
|
|
|
|
|
|
|
| 102 |
|
|
|
|
| 103 |
combined_signal = combined_signal - np.mean(combined_signal)
|
| 104 |
signal_peak = np.max(np.abs(combined_signal))
|
| 105 |
+
target_amplitude = 2.0
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
|
| 108 |
scaling_factor = target_amplitude / signal_peak
|
| 109 |
combined_signal = combined_signal * scaling_factor
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
+
adc_gain = 1000.0
|
|
|
|
| 112 |
int_signal = (combined_signal * adc_gain).astype(np.int16)
|
| 113 |
int_signal.tofile(output_file)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
if os.path.exists(temp_image_path):
|
| 116 |
os.remove(temp_image_path)
|
|
|
|
|
|
|
| 117 |
|
| 118 |
+
segment_files = []
|
| 119 |
+
samples_per_segment = int(layer_duration * sampling_frequency)
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
+
base_name = os.path.splitext(output_file)[0]
|
| 122 |
+
for i in range(3):
|
| 123 |
+
start_idx = i * samples_per_segment
|
| 124 |
+
end_idx = (i + 1) * samples_per_segment
|
| 125 |
+
segment = combined_signal[start_idx:end_idx]
|
| 126 |
+
|
| 127 |
+
segment_file = f"{base_name}_segment{i+1}.dat"
|
| 128 |
+
(segment * adc_gain).astype(np.int16).tofile(segment_file)
|
| 129 |
+
segment_files.append(segment_file)
|
| 130 |
+
|
| 131 |
+
return output_file, segment_files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
+
|
| 134 |
+
def read_ecg_dat_file(dat_file_path, scale_factor=None):
    """
    Read a raw int16 DAT file and scale it to floating-point signal values.

    Parameters:
    -----------
    dat_file_path : str
        Path to the .dat file (with or without .dat extension)
    scale_factor : float, optional
        Multiplier applied to every raw sample. When omitted, the
        module-level DAT_SCALE_FACTOR is used.

    Returns:
    --------
    numpy.ndarray
        ECG signal data with shape (total_samples,)

    Raises:
    -------
    OSError
        Propagated unchanged from numpy.fromfile when the file cannot be
        opened (e.g. FileNotFoundError).
    """
    if not dat_file_path.endswith('.dat'):
        dat_file_path += '.dat'

    if scale_factor is None:
        scale_factor = DAT_SCALE_FACTOR

    # The previous try/except only re-raised, so errors are simply allowed
    # to propagate to the caller.
    return np.fromfile(dat_file_path, dtype=np.int16) * scale_factor
|
| 158 |
|
| 159 |
+
def segment_signal(signal, sampling_rate=None, segment_duration=None):
    """
    Segment a signal into equal-length, non-overlapping segments.

    Trailing samples that do not fill a whole segment are discarded.

    Parameters:
    -----------
    signal : numpy.ndarray
        The full signal to segment
    sampling_rate : int or float, optional
        Samples per unit of time. Defaults to the module-level SAMPLING_RATE.
    segment_duration : float, optional
        Duration of one segment, in the time unit of `sampling_rate`.
        Defaults to the module-level SEGMENT_DURATION.

    Returns:
    --------
    list
        List of signal segments, each int(sampling_rate * segment_duration)
        samples long (empty when the signal is shorter than one segment)

    Raises:
    -------
    ValueError
        If the computed segment length is not positive.
    """
    if sampling_rate is None:
        sampling_rate = SAMPLING_RATE
    if segment_duration is None:
        segment_duration = SEGMENT_DURATION

    segment_samples = int(sampling_rate * segment_duration)
    if segment_samples <= 0:
        raise ValueError("Segment length must be at least one sample")

    num_segments = len(signal) // segment_samples

    # Collect every full-length slice; the comprehension guarantees each
    # slice is actually added to the returned list.
    return [
        signal[i * segment_samples:(i + 1) * segment_samples]
        for i in range(num_segments)
    ]
|
| 185 |
|
| 186 |
+
def process_segment(segment, target_length=None):
    """
    Process a segment of ECG data to ensure it's properly formatted for the model

    The segment is linearly resampled to the target length when needed and
    then z-score normalized (zero mean, unit standard deviation).

    Parameters:
    -----------
    segment : numpy.ndarray
        Raw ECG segment (one-dimensional)
    target_length : int, optional
        Number of samples in the returned segment. Defaults to the
        module-level TARGET_SEGMENT_LENGTH.

    Returns:
    --------
    numpy.ndarray
        Processed segment ready for model input

    Raises:
    -------
    ValueError
        If the segment contains no samples.
    """
    if target_length is None:
        target_length = TARGET_SEGMENT_LENGTH

    segment = np.asarray(segment, dtype=float)
    if segment.size == 0:
        raise ValueError("Cannot process an empty ECG segment")

    if len(segment) != target_length:
        # Both grids span [0, 1], so plain linear interpolation never has to
        # extrapolate; np.interp also copes with a single-sample segment.
        x = np.linspace(0, 1, len(segment))
        x_new = np.linspace(0, 1, target_length)
        segment = np.interp(x_new, x, segment)

    # The small epsilon keeps a flat (zero-variance) segment from dividing by zero.
    segment = (segment - np.mean(segment)) / (np.std(segment) + 1e-8)

    return segment
|
| 209 |
|
| 210 |
+
|
| 211 |
+
def predict_with_cnn_model(signal_data, model):
    """
    Process signal data and make predictions using the CNN model.

    The signal is split with segment_signal, each segment is prepared with
    process_segment and passed to the model individually, and the
    per-segment outputs are averaged.

    Parameters:
    -----------
    signal_data : numpy.ndarray
        Raw signal data
    model : tensorflow.keras.Model
        Loaded CNN model

    Returns:
    --------
    dict
        Dictionary containing predictions for each segment and final averaged prediction
        ("segment_predictions", "averaged_prediction", "top_class_index",
        "top_class", "probability"), or {"error": ...} when the signal
        yields no segments.
    """
    all_predictions = []

    for segment in segment_signal(signal_data):
        # The model is fed one segment at a time as (batch=1, samples, channels=1).
        model_input = process_segment(segment).reshape(1, TARGET_SEGMENT_LENGTH, 1)
        prediction = model.predict(model_input, verbose=0)
        all_predictions.append(prediction[0])

    if not all_predictions:
        return {"error": "No valid segments for prediction"}

    avg_prediction = np.mean(all_predictions, axis=0)
    # Convert to a plain int so the index is not a NumPy scalar
    # (keeps the result dict friendly to JSON serialization).
    top_class_idx = int(np.argmax(avg_prediction))

    return {
        "segment_predictions": all_predictions,
        "averaged_prediction": avg_prediction,
        "top_class_index": top_class_idx,
        "top_class": ARRHYTHMIA_CLASSES[top_class_idx],
        "probability": float(avg_prediction[top_class_idx])
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
+
def analyze_ecg_pdf(pdf_path, model_path, cleanup=True):
    """
    Complete ECG analysis pipeline: digitizes a PDF ECG, analyzes it with the model,
    and returns the arrhythmia classification with highest probability.

    Args:
        pdf_path (str): Path to the ECG PDF file
        model_path (str): Path to the model (.h5) file
        cleanup (bool, optional): Whether to remove temporary files after processing

    Returns:
        dict: {
            "arrhythmia_class": str,       # Top arrhythmia class (None if no segment could be classified)
            "probability": float,          # Probability of top class
            "all_probabilities": dict,     # All classes with probabilities
            "digitized_file": str          # Path to digitized file (if cleanup=False)
        }
        On any exception, {"error": "Error in ECG analysis: ..."} is returned instead.
    """
    # NOTE(review): TensorFlow is already imported when this runs, so this
    # setting may come too late to silence its native logging — confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    dat_file_path = None
    segment_files = []

    try:
        dat_file_path, segment_files = digitize_ecg_from_pdf(pdf_path)

        ecg_model = tf.keras.models.load_model(model_path)

        ecg_signal = read_ecg_dat_file(dat_file_path)

        classification_results = predict_with_cnn_model(ecg_signal, ecg_model)

        arrhythmia_result = {
            "arrhythmia_class": classification_results.get("top_class"),
            "probability": classification_results.get("probability", 0.0),
            "all_probabilities": {}
        }

        if "averaged_prediction" in classification_results:
            averaged = classification_results["averaged_prediction"]
            arrhythmia_result["all_probabilities"] = {
                class_name: float(averaged[idx])
                for idx, class_name in enumerate(ARRHYTHMIA_CLASSES)
            }

        if not cleanup:
            arrhythmia_result["digitized_file"] = dat_file_path

        return arrhythmia_result

    except Exception as e:
        # Top-level boundary: callers receive an error dict rather than an exception.
        error_msg = f"Error in ECG analysis: {str(e)}"
        return {"error": error_msg}

    finally:
        # Runs on success and on failure, so a model-loading or prediction
        # error no longer leaves the digitized .dat files behind.
        if cleanup:
            for path in [dat_file_path, *segment_files]:
                if path and os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        # Best-effort cleanup: a file that cannot be removed
                        # must not mask the analysis result.
                        pass
|
| 310 |
|
|
|
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
+
|
|
|
SkinBurns_Classification.py → SkinBurns/SkinBurns_Classification.py
RENAMED
|
File without changes
|
SkinBurns_Segmentation.py → SkinBurns/SkinBurns_Segmentation.py
RENAMED
|
File without changes
|
app.py
CHANGED
|
@@ -9,8 +9,8 @@ from pymongo.server_api import ServerApi
|
|
| 9 |
import cloudinary
|
| 10 |
import cloudinary.uploader
|
| 11 |
from cloudinary.utils import cloudinary_url
|
| 12 |
-
from SkinBurns_Classification import FullFeautures
|
| 13 |
-
from SkinBurns_Segmentation import segment_burn
|
| 14 |
import requests
|
| 15 |
import joblib
|
| 16 |
import numpy as np
|
|
@@ -63,7 +63,7 @@ except Exception as e:
|
|
| 63 |
cloudinary.config(
|
| 64 |
cloud_name = "darumyfpl",
|
| 65 |
api_key = "493972437417214",
|
| 66 |
-
api_secret = "jjOScVGochJYA7IxDam7L4HU2Ig",
|
| 67 |
secure=True
|
| 68 |
)
|
| 69 |
|
|
@@ -148,14 +148,14 @@ async def segment_burn_endpoint(reference: UploadFile = File(...), patient: Uplo
|
|
| 148 |
|
| 149 |
@app.post("/classify-ecg")
|
| 150 |
async def classify_ecg_endpoint(file: UploadFile = File(...)):
|
| 151 |
-
model = joblib.load('
|
| 152 |
|
| 153 |
try:
|
| 154 |
temp_file_path = f"temp_{file.filename}"
|
| 155 |
with open(temp_file_path, "wb") as temp_file:
|
| 156 |
temp_file.write(await file.read())
|
| 157 |
|
| 158 |
-
result = classify_ecg(temp_file_path, model,
|
| 159 |
|
| 160 |
os.remove(temp_file_path)
|
| 161 |
|
|
@@ -167,31 +167,22 @@ async def classify_ecg_endpoint(file: UploadFile = File(...)):
|
|
| 167 |
@app.post("/diagnose-ecg")
|
| 168 |
async def diagnose_ecg(file: UploadFile = File(...)):
|
| 169 |
try:
|
| 170 |
-
# Save the uploaded file temporarily
|
| 171 |
temp_file_path = f"temp_{file.filename}"
|
| 172 |
with open(temp_file_path, "wb") as temp_file:
|
| 173 |
temp_file.write(await file.read())
|
| 174 |
|
| 175 |
-
model_path = '
|
| 176 |
-
mlb_path = 'deep-multiclass.pkl' # Update with actual path
|
| 177 |
-
|
| 178 |
|
| 179 |
-
# Call the ECG classification function
|
| 180 |
result = analyze_ecg_pdf(
|
| 181 |
temp_file_path,
|
| 182 |
model_path,
|
| 183 |
-
|
| 184 |
-
cleanup=False, # Keep the digitized file
|
| 185 |
-
debug=False, # Print debug information
|
| 186 |
-
visualize=False # Visualize the digitized signal
|
| 187 |
)
|
| 188 |
-
|
| 189 |
|
| 190 |
-
# Remove the temporary file
|
| 191 |
os.remove(temp_file_path)
|
| 192 |
|
| 193 |
-
if result and result["
|
| 194 |
-
return {"result": result["
|
| 195 |
else:
|
| 196 |
return {"result": "No diagnosis"}
|
| 197 |
|
|
@@ -216,7 +207,6 @@ async def process_video(file: UploadFile = File(...)):
|
|
| 216 |
print("File content type:", file.content_type)
|
| 217 |
print("File filename:", file.filename)
|
| 218 |
|
| 219 |
-
# Prepare directories
|
| 220 |
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
| 221 |
os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
|
| 222 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
@@ -259,7 +249,6 @@ async def process_video(file: UploadFile = File(...)):
|
|
| 259 |
overwrite=True
|
| 260 |
)
|
| 261 |
|
| 262 |
-
# Add new warning with image_url and description
|
| 263 |
warnings.append({
|
| 264 |
"image_url": upload_result['secure_url'],
|
| 265 |
"description": description
|
|
@@ -279,7 +268,6 @@ async def process_video(file: UploadFile = File(...)):
|
|
| 279 |
else:
|
| 280 |
wholevideoURL = None
|
| 281 |
|
| 282 |
-
# Upload graph output
|
| 283 |
graphURL = None
|
| 284 |
if os.path.isfile(plot_output_path):
|
| 285 |
upload_graph_result = cloudinary.uploader.upload(
|
|
@@ -310,7 +298,7 @@ clients = set()
|
|
| 310 |
analyzer_thread = None
|
| 311 |
analysis_started = False
|
| 312 |
analyzer_lock = threading.Lock()
|
| 313 |
-
socket_server: AnalysisSocketServer = None
|
| 314 |
|
| 315 |
|
| 316 |
async def forward_results_from_queue(websocket: WebSocket, warning_queue):
|
|
@@ -367,11 +355,9 @@ async def websocket_analysis(websocket: WebSocket):
|
|
| 367 |
logger.info("[WebSocket] Flutter connected")
|
| 368 |
|
| 369 |
try:
|
| 370 |
-
# Wait for the client to send the stream URL as first message
|
| 371 |
source = await websocket.receive_text()
|
| 372 |
logger.info(f"[WebSocket] Received stream URL: {source}")
|
| 373 |
|
| 374 |
-
# Ensure analyzer starts only once using a thread-safe lock
|
| 375 |
with analyzer_lock:
|
| 376 |
if not analysis_started:
|
| 377 |
requested_fps = 30
|
|
@@ -386,7 +372,6 @@ async def websocket_analysis(websocket: WebSocket):
|
|
| 386 |
analysis_started = True
|
| 387 |
logger.info("[WebSocket] Analysis thread started")
|
| 388 |
|
| 389 |
-
# Rest of your existing code remains exactly the same...
|
| 390 |
while socket_server is None or socket_server.warning_queue is None:
|
| 391 |
await asyncio.sleep(0.1)
|
| 392 |
|
|
@@ -395,7 +380,7 @@ async def websocket_analysis(websocket: WebSocket):
|
|
| 395 |
)
|
| 396 |
|
| 397 |
while True:
|
| 398 |
-
await asyncio.sleep(1)
|
| 399 |
|
| 400 |
except WebSocketDisconnect:
|
| 401 |
logger.warning("[WebSocket] Client disconnected")
|
|
@@ -403,7 +388,7 @@ async def websocket_analysis(websocket: WebSocket):
|
|
| 403 |
forward_task.cancel()
|
| 404 |
except Exception as e:
|
| 405 |
logger.error(f"[WebSocket] Error receiving stream URL: {str(e)}")
|
| 406 |
-
await websocket.close(code=1011)
|
| 407 |
finally:
|
| 408 |
clients.discard(websocket)
|
| 409 |
logger.info(f"[WebSocket] Active clients: {len(clients)}")
|
|
|
|
| 9 |
import cloudinary
|
| 10 |
import cloudinary.uploader
|
| 11 |
from cloudinary.utils import cloudinary_url
|
| 12 |
+
from SkinBurns.SkinBurns_Classification import FullFeautures
|
| 13 |
+
from SkinBurns.SkinBurns_Segmentation import segment_burn
|
| 14 |
import requests
|
| 15 |
import joblib
|
| 16 |
import numpy as np
|
|
|
|
| 63 |
# SECURITY: Cloudinary credentials were committed in source. They are now read
# from the environment first; the literal fallbacks exist only so current
# deployments keep working. Rotate the API secret and remove the fallbacks
# once CLOUDINARY_CLOUD_NAME / CLOUDINARY_API_KEY / CLOUDINARY_API_SECRET are
# set in the deployment environment.
cloudinary.config(
    cloud_name=os.environ.get("CLOUDINARY_CLOUD_NAME", "darumyfpl"),
    api_key=os.environ.get("CLOUDINARY_API_KEY", "493972437417214"),
    api_secret=os.environ.get("CLOUDINARY_API_SECRET", "jjOScVGochJYA7IxDam7L4HU2Ig"),
    secure=True
)
|
| 69 |
|
|
|
|
| 148 |
|
| 149 |
@app.post("/classify-ecg")
|
| 150 |
async def classify_ecg_endpoint(file: UploadFile = File(...)):
|
| 151 |
+
model = joblib.load('voting_classifier_arrhythmia.pkl')
|
| 152 |
|
| 153 |
try:
|
| 154 |
temp_file_path = f"temp_{file.filename}"
|
| 155 |
with open(temp_file_path, "wb") as temp_file:
|
| 156 |
temp_file.write(await file.read())
|
| 157 |
|
| 158 |
+
result = classify_ecg(temp_file_path, model, is_pdf=True)
|
| 159 |
|
| 160 |
os.remove(temp_file_path)
|
| 161 |
|
|
|
|
| 167 |
@app.post("/diagnose-ecg")
|
| 168 |
async def diagnose_ecg(file: UploadFile = File(...)):
|
| 169 |
try:
|
|
|
|
| 170 |
temp_file_path = f"temp_{file.filename}"
|
| 171 |
with open(temp_file_path, "wb") as temp_file:
|
| 172 |
temp_file.write(await file.read())
|
| 173 |
|
| 174 |
+
model_path = 'Arrhythmia_Model_with_SMOTE.h5'
|
|
|
|
|
|
|
| 175 |
|
|
|
|
| 176 |
result = analyze_ecg_pdf(
|
| 177 |
temp_file_path,
|
| 178 |
model_path,
|
| 179 |
+
cleanup=False
|
|
|
|
|
|
|
|
|
|
| 180 |
)
|
|
|
|
| 181 |
|
|
|
|
| 182 |
os.remove(temp_file_path)
|
| 183 |
|
| 184 |
+
if result and result["arrhythmia_class"]:
|
| 185 |
+
return {"result": result["arrhythmia_class"]}
|
| 186 |
else:
|
| 187 |
return {"result": "No diagnosis"}
|
| 188 |
|
|
|
|
| 207 |
print("File content type:", file.content_type)
|
| 208 |
print("File filename:", file.filename)
|
| 209 |
|
|
|
|
| 210 |
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
| 211 |
os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
|
| 212 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
| 249 |
overwrite=True
|
| 250 |
)
|
| 251 |
|
|
|
|
| 252 |
warnings.append({
|
| 253 |
"image_url": upload_result['secure_url'],
|
| 254 |
"description": description
|
|
|
|
| 268 |
else:
|
| 269 |
wholevideoURL = None
|
| 270 |
|
|
|
|
| 271 |
graphURL = None
|
| 272 |
if os.path.isfile(plot_output_path):
|
| 273 |
upload_graph_result = cloudinary.uploader.upload(
|
|
|
|
| 298 |
analyzer_thread = None
|
| 299 |
analysis_started = False
|
| 300 |
analyzer_lock = threading.Lock()
|
| 301 |
+
socket_server: AnalysisSocketServer = None
|
| 302 |
|
| 303 |
|
| 304 |
async def forward_results_from_queue(websocket: WebSocket, warning_queue):
|
|
|
|
| 355 |
logger.info("[WebSocket] Flutter connected")
|
| 356 |
|
| 357 |
try:
|
|
|
|
| 358 |
source = await websocket.receive_text()
|
| 359 |
logger.info(f"[WebSocket] Received stream URL: {source}")
|
| 360 |
|
|
|
|
| 361 |
with analyzer_lock:
|
| 362 |
if not analysis_started:
|
| 363 |
requested_fps = 30
|
|
|
|
| 372 |
analysis_started = True
|
| 373 |
logger.info("[WebSocket] Analysis thread started")
|
| 374 |
|
|
|
|
| 375 |
while socket_server is None or socket_server.warning_queue is None:
|
| 376 |
await asyncio.sleep(0.1)
|
| 377 |
|
|
|
|
| 380 |
)
|
| 381 |
|
| 382 |
while True:
|
| 383 |
+
await asyncio.sleep(1)
|
| 384 |
|
| 385 |
except WebSocketDisconnect:
|
| 386 |
logger.warning("[WebSocket] Client disconnected")
|
|
|
|
| 388 |
forward_task.cancel()
|
| 389 |
except Exception as e:
|
| 390 |
logger.error(f"[WebSocket] Error receiving stream URL: {str(e)}")
|
| 391 |
+
await websocket.close(code=1011)
|
| 392 |
finally:
|
| 393 |
clients.discard(websocket)
|
| 394 |
logger.info(f"[WebSocket] Active clients: {len(clients)}")
|
voting_classifier.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:52e5d9789c5a5f6b42f595fceb67948bc15e9c9035de9c02f72cf29ff42c9d93
|
| 3 |
-
size 4084247
|
|
|
|
|
|
|
|
|
|
|
|
deep-multiclass.pkl → voting_classifier_arrhythmia.pkl
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f1a2aff5dffb25a19be3bcaa4db79373dcad23355ba9b166ca1d2a8978e3600
|
| 3 |
+
size 50223409
|