Hussein El-Hadidy committed on
Commit
676f928
·
1 Parent(s): ef22c1c

Latest ECG

Browse files
deep-multiclass.h5 → Arrhythmia_Model_with_SMOTE.h5 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a2fdc1bcbf7820ae426a9a74f0210c738884ee2bb872bccff55a4036e2c642e1
3
- size 8941912
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7145bc906536b54953df3e7034d0da786c24261ec4080fbfd51031d19895e713
3
+ size 8923320
ECG/ECG_Classify.py CHANGED
@@ -1,57 +1,113 @@
1
- import wfdb # To read the ECG files
2
- from wfdb import processing # For QRS detection
3
- import numpy as np # Numerical operations
4
- import joblib # To load the saved model
5
- import pywt # For wavelet feature extraction
6
- import os # For file operations
7
- import cv2 # For image processing
8
- from pdf2image import convert_from_path # For PDF to image conversion
9
  import warnings
10
  import pickle
11
- import sklearn
 
12
 
13
- # Let's modify the digitize_ecg_from_pdf function to return segment information
14
- def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=False, save_segments=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  """
16
  Process an ECG PDF file and convert it to a .dat signal file.
17
 
18
  Args:
19
  pdf_path (str): Path to the ECG PDF file
20
- output_file (str): Path to save the output .dat file (default: 'calibrated_ecg.dat')
21
- debug (bool): Whether to print debug information
22
- save_segments (bool): Whether to save individual segments
23
 
24
  Returns:
25
  tuple: (path to the created .dat file, list of paths to segment files)
26
  """
27
- if debug:
28
- print(f"Starting ECG digitization from PDF: {pdf_path}")
29
 
30
- # Convert PDF to image
31
  images = convert_from_path(pdf_path)
32
  temp_image_path = 'temp_ecg_image.jpg'
33
  images[0].save(temp_image_path, 'JPEG')
34
 
35
- if debug:
36
- print(f"Converted PDF to image: {temp_image_path}")
37
-
38
- # Load the image
39
  img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
40
  height, width = img.shape
41
 
42
- if debug:
43
- print(f"Image dimensions: {width}x{height}")
44
-
45
- # Fixed calibration parameters
46
  calibration = {
47
- 'seconds_per_pixel': 2.0 / 197.0, # 197 pixels = 2 seconds
48
- 'mv_per_pixel': 1.0 / 78.8, # 78.8 pixels = 1 mV
49
  }
50
 
51
- if debug:
52
- print(f"Calibration parameters: {calibration}")
53
-
54
- # Calculate layer boundaries using percentages
55
  layer1_start = int(height * 35.35 / 100)
56
  layer1_end = int(height * 51.76 / 100)
57
  layer2_start = int(height * 51.82 / 100)
@@ -59,210 +115,123 @@ def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=Fals
59
  layer3_start = int(height * 69.47 / 100)
60
  layer3_end = int(height * 87.06 / 100)
61
 
62
- if debug:
63
- print(f"Layer 1 boundaries: {layer1_start}-{layer1_end}")
64
- print(f"Layer 2 boundaries: {layer2_start}-{layer2_end}")
65
- print(f"Layer 3 boundaries: {layer3_start}-{layer3_end}")
66
-
67
- # Crop each layer
68
  layers = [
69
- img[layer1_start:layer1_end, :], # Layer 1
70
- img[layer2_start:layer2_end, :], # Layer 2
71
- img[layer3_start:layer3_end, :] # Layer 3
72
  ]
73
 
74
- # Process each layer to extract waveform contours
75
  signals = []
76
  time_points = []
77
- layer_duration = 10.0 # Each layer is 10 seconds long
78
 
79
  for i, layer in enumerate(layers):
80
- if debug:
81
- print(f"Processing layer {i+1}...")
82
-
83
- # Binary thresholding
84
  _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
85
 
86
- # Detect contours
87
  contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
88
- waveform_contour = max(contours, key=cv2.contourArea) # Largest contour is the ECG
89
-
90
- if debug:
91
- print(f" - Found {len(contours)} contours")
92
- print(f" - Selected contour with {len(waveform_contour)} points")
93
 
94
- # Sort contour points and extract coordinates
95
  sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
96
  x_coords = np.array([point[0][0] for point in sorted_contour])
97
  y_coords = np.array([point[0][1] for point in sorted_contour])
98
 
99
- # Calculate isoelectric line (one-third from the bottom)
100
  isoelectric_line_y = layer.shape[0] * 0.6
101
 
102
- # Convert to time using fixed layer duration
103
  x_min, x_max = np.min(x_coords), np.max(x_coords)
104
  time = (x_coords - x_min) / (x_max - x_min) * layer_duration
105
 
106
- # Calculate signal in millivolts and apply baseline correction
107
  signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
108
  signal_mv = signal_mv - np.mean(signal_mv)
109
 
110
- if debug:
111
- print(f" - Layer {i+1} signal range: {np.min(signal_mv):.2f} mV to {np.max(signal_mv):.2f} mV")
112
-
113
- # Store the time points and calibrated signal
114
  time_points.append(time)
115
  signals.append(signal_mv)
116
 
117
- # Save individual segments if requested
118
- segment_files = []
119
- sampling_frequency = 500 # Standard ECG frequency
120
- samples_per_segment = int(layer_duration * sampling_frequency) # 5000 samples per 10-second segment
121
-
122
- if save_segments:
123
- base_name = os.path.splitext(output_file)[0]
124
-
125
- for i, signal in enumerate(signals):
126
- # Interpolate to get evenly sampled signal
127
- segment_time = np.linspace(0, layer_duration, samples_per_segment)
128
- interpolated_signal = np.interp(segment_time, time_points[i], signals[i])
129
-
130
- # Normalize and scale
131
- interpolated_signal = interpolated_signal - np.mean(interpolated_signal)
132
- signal_peak = np.max(np.abs(interpolated_signal))
133
-
134
- if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
135
- scaling_factor = 2.0 / signal_peak # Target peak amplitude of 2.0 mV
136
- interpolated_signal = interpolated_signal * scaling_factor
137
-
138
- # Convert to 16-bit integers
139
- adc_gain = 1000.0
140
- int_signal = (interpolated_signal * adc_gain).astype(np.int16)
141
-
142
- # Save segment
143
- segment_file = f"{base_name}_segment{i+1}.dat"
144
- int_signal.reshape(-1, 1).tofile(segment_file)
145
- segment_files.append(segment_file)
146
-
147
- if debug:
148
- print(f"Saved segment {i+1} to {segment_file}")
149
-
150
- # Combine signals with proper time alignment for the full record
151
  total_duration = layer_duration * len(layers)
 
152
  num_samples = int(total_duration * sampling_frequency)
153
  combined_time = np.linspace(0, total_duration, num_samples)
154
  combined_signal = np.zeros(num_samples)
155
 
156
- if debug:
157
- print(f"Combining signals with {sampling_frequency} Hz sampling rate, total duration: {total_duration}s")
158
-
159
- # Place each lead at the correct time position
160
  for i, (time, signal) in enumerate(zip(time_points, signals)):
161
  start_time = i * layer_duration
162
  mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
163
  relevant_times = combined_time[mask]
164
  interpolated_signal = np.interp(relevant_times, start_time + time, signal)
165
  combined_signal[mask] = interpolated_signal
166
-
167
- if debug:
168
- print(f" - Added layer {i+1} signal from {start_time}s to {start_time + layer_duration}s")
169
 
170
- # Baseline correction and amplitude scaling
171
  combined_signal = combined_signal - np.mean(combined_signal)
172
  signal_peak = np.max(np.abs(combined_signal))
173
- target_amplitude = 2.0 # Target peak amplitude in mV
174
-
175
- if debug:
176
- print(f"Signal peak before scaling: {signal_peak:.2f} mV")
177
 
178
  if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
179
  scaling_factor = target_amplitude / signal_peak
180
  combined_signal = combined_signal * scaling_factor
181
- if debug:
182
- print(f"Applied scaling factor: {scaling_factor:.2f}")
183
- print(f"Signal peak after scaling: {np.max(np.abs(combined_signal)):.2f} mV")
184
 
185
- # Convert to 16-bit integers and save as .dat file
186
- adc_gain = 1000.0 # Standard gain: 1000 units per mV
187
  int_signal = (combined_signal * adc_gain).astype(np.int16)
188
  int_signal.tofile(output_file)
189
 
190
- if debug:
191
- print(f"Saved signal to {output_file} with {len(int_signal)} samples")
192
- print(f"Integer signal range: {np.min(int_signal)} to {np.max(int_signal)}")
193
-
194
- # Clean up temporary files
195
  if os.path.exists(temp_image_path):
196
  os.remove(temp_image_path)
197
- if debug:
198
- print(f"Removed temporary image: {temp_image_path}")
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  return output_file, segment_files
201
 
202
- # Add a function to split a DAT file into segments
203
- def split_dat_into_segments(file_path, segment_duration=10.0, debug=False):
204
  """
205
  Split a DAT file into equal segments.
206
 
207
  Args:
208
  file_path (str): Path to the DAT file (without extension)
209
  segment_duration (float): Duration of each segment in seconds
210
- debug (bool): Whether to print debug information
211
 
212
  Returns:
213
  list: Paths to the segment files
214
  """
215
- try:
216
- # Load the signal
217
- signal_all_leads, fs = load_dat_signal(file_path, debug=debug)
218
-
219
- if debug:
220
- print(f"Loaded signal with shape {signal_all_leads.shape}")
221
-
222
- # Choose a lead
223
- if signal_all_leads.shape[1] == 1:
224
- lead_index = 0
225
- else:
226
- lead_priority = [1, 0] # Try Lead II (index 1), then I (index 0)
227
- lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
228
-
229
- signal = signal_all_leads[:, lead_index]
230
-
231
- # Calculate samples per segment
232
- samples_per_segment = int(segment_duration * fs)
233
- total_samples = len(signal)
234
- num_segments = total_samples // samples_per_segment
235
-
236
- if debug:
237
- print(f"Splitting signal into {num_segments} segments of {segment_duration} seconds each")
238
-
239
- segment_files = []
240
-
241
- # Split and save each segment
242
- base_name = os.path.splitext(file_path)[0]
243
 
244
- for i in range(num_segments):
245
- start_idx = i * samples_per_segment
246
- end_idx = (i + 1) * samples_per_segment
247
- segment = signal[start_idx:end_idx]
248
-
249
- # Save segment
250
- segment_file = f"{base_name}_segment{i+1}.dat"
251
- segment.reshape(-1, 1).tofile(segment_file)
252
- segment_files.append(segment_file)
 
 
 
 
 
 
 
 
 
253
 
254
- if debug:
255
- print(f"Saved segment {i+1} to {segment_file}")
256
-
257
- return segment_files
258
-
259
- except Exception as e:
260
- if debug:
261
- print(f"Error splitting DAT file: {str(e)}")
262
- return []
263
 
264
- # Add function to load DAT signals
265
- def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16, debug=False):
266
  """
267
  Load a DAT file containing ECG signal data.
268
 
@@ -271,79 +240,46 @@ def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16, debug
271
  n_leads (int): Number of leads in the signal
272
  n_samples (int): Number of samples per lead
273
  dtype: Data type of the signal
274
- debug (bool): Whether to print debug information
275
 
276
  Returns:
277
  tuple: (numpy array of signal data, sampling frequency)
278
  """
279
- try:
280
- # Handle both cases: with and without .dat extension
281
- if file_path.endswith('.dat'):
282
- dat_path = file_path
283
- else:
284
- dat_path = file_path + '.dat'
285
-
286
- if debug:
287
- print(f"Loading signal from: {dat_path}")
288
-
289
- raw = np.fromfile(dat_path, dtype=dtype)
290
 
291
- if debug:
292
- print(f"Raw data size: {raw.size}")
293
-
294
- # Try to infer number of leads if read size doesn't match expected
295
- if raw.size != n_leads * n_samples:
296
- if debug:
297
- print(f"Unexpected size: {raw.size}, expected {n_leads * n_samples}")
298
- print("Attempting to infer number of leads...")
299
-
300
- # Check if single lead
301
- if raw.size == n_samples:
302
- if debug:
303
- print("Detected single lead signal")
304
- signal = raw.reshape(n_samples, 1)
305
- return signal, 500
306
-
307
- # Try common lead counts
308
- possible_leads = [1, 2, 3, 6, 12]
309
- for possible_lead_count in possible_leads:
310
- if raw.size % possible_lead_count == 0:
311
- actual_samples = raw.size // possible_lead_count
312
- if debug:
313
- print(f"Inferred {possible_lead_count} leads with {actual_samples} samples each")
314
- signal = raw.reshape(actual_samples, possible_lead_count)
315
- return signal, 500
316
-
317
- # If we can't determine it reliably, reshape as single lead
318
- if debug:
319
- print("Could not infer lead count, reshaping as single lead")
320
- signal = raw.reshape(-1, 1)
321
  return signal, 500
322
 
323
- # Normal case when size matches expectation
324
- signal = raw.reshape(n_samples, n_leads)
325
- return signal, 500 # Signal + sampling frequency
326
- except Exception as e:
327
- if debug:
328
- print(f"Error loading DAT file: {str(e)}")
329
- # Return empty signal with single channel
330
- return np.zeros((n_samples, 1)), 500
 
 
 
 
 
331
 
332
- # Add the feature extraction function
333
- def extract_features_from_signal(signal, debug=False):
334
  """
335
  Extract features from an ECG signal.
336
 
337
  Args:
338
  signal (numpy.ndarray): ECG signal
339
- debug (bool): Whether to print debug information
340
 
341
  Returns:
342
- list: Features extracted from the signal
343
  """
344
- if debug:
345
- print("Extracting features from signal...")
346
-
347
  features = []
348
  features.append(np.mean(signal))
349
  features.append(np.std(signal))
@@ -354,129 +290,108 @@ def extract_features_from_signal(signal, debug=False):
354
  features.append(np.percentile(signal, 75))
355
  features.append(np.mean(np.diff(signal)))
356
 
357
- if debug:
358
- print("Computing wavelet decomposition...")
359
-
360
  coeffs = pywt.wavedec(signal, 'db4', level=5)
361
- for i, coeff in enumerate(coeffs):
362
  features.append(np.mean(coeff))
363
  features.append(np.std(coeff))
364
  features.append(np.min(coeff))
365
  features.append(np.max(coeff))
366
-
367
- if debug and i == 0:
368
- print(f"Wavelet features for level {i}: mean={np.mean(coeff):.4f}, std={np.std(coeff):.4f}")
369
 
370
- if debug:
371
- print(f"Extracted {len(features)} features")
372
-
373
  return features
374
 
375
- # Add the classify_new_ecg function
376
- def classify_new_ecg(file_path, model, debug=False):
377
  """
378
  Classify a new ECG file.
379
 
380
  Args:
381
  file_path (str): Path to the ECG file (without extension)
382
  model: The trained model for classification
383
- debug (bool): Whether to print debug information
384
 
385
  Returns:
386
  str: Classification result ("Normal", "Abnormal", or error message)
387
  """
388
- try:
389
- if debug:
390
- print(f"Classifying ECG from: {file_path}")
391
-
392
- signal_all_leads, fs = load_dat_signal(file_path, debug=debug)
393
-
394
- if debug:
395
- print(f"Loaded signal with shape {signal_all_leads.shape}, sampling rate {fs} Hz")
396
-
397
- # Choose lead for analysis - priority order
398
- if signal_all_leads.shape[1] == 1:
399
- lead_index = 0
400
- if debug:
401
- print("Using single lead")
402
- else:
403
- lead_priority = [1, 0] # Try Lead II (index 1), then I (index 0)
404
- lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
405
- if debug:
406
- print(f"Using lead index {lead_index}")
407
-
408
- # Extract the signal
409
- signal = signal_all_leads[:, lead_index]
410
-
411
- # Normalize signal
412
- signal = (signal - np.mean(signal)) / np.std(signal)
413
-
414
- if debug:
415
- print("Signal normalized")
416
- print(f"Detecting QRS complexes...")
417
-
418
- # Detect QRS complexes
419
- try:
420
- xqrs = processing.XQRS(sig=signal, fs=fs)
421
- xqrs.detect()
422
- r_peaks = xqrs.qrs_inds
423
- if debug:
424
- print(f"Detected {len(r_peaks)} QRS complexes with XQRS method")
425
- except Exception as e:
426
- if debug:
427
- print(f"XQRS detection failed: {str(e)}")
428
- print("Falling back to GQRS detector")
429
- r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
430
- if debug:
431
- print(f"Detected {len(r_peaks)} QRS complexes with GQRS method")
432
 
433
- # Check if we found enough QRS complexes
434
- if len(r_peaks) < 5:
435
- if debug:
436
- print(f"Insufficient beats detected: {len(r_peaks)}")
437
- return "Insufficient beats"
 
 
438
 
439
- # Calculate RR intervals and QRS durations
 
 
 
440
  rr_intervals = np.diff(r_peaks) / fs
441
  qrs_durations = np.array([r_peaks[i] - r_peaks[i - 1] for i in range(1, len(r_peaks))])
442
-
443
- if debug:
444
- print(f"Mean RR interval: {np.mean(rr_intervals):.4f} s")
445
- print(f"Mean QRS duration: {np.mean(qrs_durations) / fs:.4f} s")
446
 
447
- # Extract features
448
- features = extract_features_from_signal(signal, debug=debug)
 
449
 
450
- # Add rhythm features
451
- features.extend([
452
  len(r_peaks),
453
  np.mean(rr_intervals) if len(rr_intervals) > 0 else 0,
454
  np.std(rr_intervals) if len(rr_intervals) > 0 else 0,
455
  np.median(rr_intervals) if len(rr_intervals) > 0 else 0,
456
- np.mean(qrs_durations) if len(qrs_durations) > 0 else 0,
457
- np.std(qrs_durations) if len(qrs_durations) > 0 else 0
458
  ])
459
 
460
- if debug:
461
- print(f"Final feature vector length: {len(features)}")
462
-
463
- # Make prediction
464
- prediction = model.predict([features])[0]
465
- result = "Abnormal" if prediction == 1 else "Normal"
466
 
467
- if debug:
468
- print(f"Classification result: {result} (prediction value: {prediction})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
- return result
 
 
 
 
 
 
 
 
 
 
471
 
472
- except Exception as e:
473
- error_msg = f"Error: {str(e)}"
474
- if debug:
475
- print(error_msg)
476
- return error_msg
477
 
478
- # Modify the classify_ecg wrapper function to use the voting approach
479
- def classify_ecg(file_path, model, is_pdf=False, debug=False):
480
  """
481
  Wrapper function that handles both PDF and DAT ECG files with segment voting.
482
 
@@ -484,86 +399,44 @@ def classify_ecg(file_path, model, is_pdf=False, debug=False):
484
  file_path (str): Path to the ECG file (.pdf or without extension for .dat)
485
  model: The trained model for classification
486
  is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)
487
- debug (bool): Enable debug output
488
 
489
  Returns:
490
  str: Classification result ("Normal", "Abnormal", or error message)
491
  """
492
  try:
493
- # Check if model is valid
494
  if model is None:
495
  return "Error: Model not loaded. Please check model compatibility."
496
 
497
  if is_pdf:
498
- if debug:
499
- print(f"Processing PDF file: {file_path}")
500
-
501
- # Extract file name without extension for output
502
  base_name = os.path.splitext(os.path.basename(file_path))[0]
503
  output_dat = f"{base_name}_digitized.dat"
504
 
505
- # Digitize the PDF to a DAT file and get segment files
506
  dat_path, segment_files = digitize_ecg_from_pdf(
507
  pdf_path=file_path,
508
- output_file=output_dat,
509
- debug=debug
510
  )
511
-
512
- if debug:
513
- print(f"Digitized ECG saved to: {dat_path}")
514
- print(f"Created {len(segment_files)} segment files")
515
  else:
516
- if debug:
517
- print(f"Processing DAT file: {file_path}")
518
-
519
- # For DAT files, we need to split into segments
520
- segment_files = split_dat_into_segments(file_path, debug=debug)
521
 
522
  if not segment_files:
523
- # If splitting failed, try classifying the whole file
524
- return classify_new_ecg(file_path, model, debug=debug)
525
 
526
- # Process each segment and collect votes
527
  segment_results = []
528
 
529
- for i, segment_file in enumerate(segment_files):
530
- if debug:
531
- print(f"\n--- Processing Segment {i+1} ---")
532
-
533
- # Get file path without extension
534
  segment_path = os.path.splitext(segment_file)[0]
535
-
536
- # Classify this segment
537
- result = classify_new_ecg(segment_path, model, debug=debug)
538
-
539
- if debug:
540
- print(f"Segment {i+1} classification: {result}")
541
-
542
  segment_results.append(result)
543
 
544
- # Remove temporary segment files
545
  try:
546
  os.remove(segment_file)
547
- if debug:
548
- print(f"Removed temporary segment file: {segment_file}")
549
  except:
550
  pass
551
 
552
- # Count results and use majority voting
553
  if segment_results:
554
  normal_count = segment_results.count("Normal")
555
  abnormal_count = segment_results.count("Abnormal")
556
- error_count = len(segment_results) - normal_count - abnormal_count
557
 
558
- if debug:
559
- print(f"\n--- Voting Results ---")
560
- print(f"Normal votes: {normal_count}")
561
- print(f"Abnormal votes: {abnormal_count}")
562
- print(f"Errors/Inconclusive: {error_count}")
563
-
564
- # Decision rules:
565
- # 1. If any segment is abnormal, classify as abnormal
566
- # 2. Only classify as normal if majority of segments are normal
567
  if abnormal_count > normal_count:
568
  final_result = "Abnormal"
569
  elif normal_count > abnormal_count:
@@ -571,41 +444,14 @@ def classify_ecg(file_path, model, is_pdf=False, debug=False):
571
  else:
572
  final_result = "Inconclusive"
573
 
574
- if debug:
575
- print(f"Final decision: {final_result}")
576
-
577
  return final_result
578
  else:
579
  return "Error: No valid segments to classify"
580
 
581
  except Exception as e:
582
  error_msg = f"Classification error: {str(e)}"
583
- if debug:
584
- print(error_msg)
585
  return error_msg
586
- # Load the saved model
587
- try:
588
- model_path = 'voting_classifier.pkl'
589
- if os.path.exists(model_path):
590
- voting_loaded = joblib.load(model_path)
591
- else:
592
- # Try to find the model in the current or parent directories
593
- for root, dirs, files in os.walk('.'):
594
- for file in files:
595
- if file.endswith('.pkl') and 'voting' in file.lower():
596
- model_path = os.path.join(root, file)
597
- voting_loaded = joblib.load(model_path)
598
- break
599
- if 'voting_loaded' in locals():
600
- break
601
-
602
- if 'voting_loaded' not in locals():
603
- voting_loaded = None
604
- except Exception as e:
605
- voting_loaded = None
606
 
607
- # Simple test for the classify_ecg function
608
- test_pdf_path = "sample.pdf"
609
- if os.path.exists(test_pdf_path) and voting_loaded is not None:
610
- result_pdf = classify_ecg(test_pdf_path, voting_loaded, is_pdf=True)
611
- print(f"Classification result: {result_pdf}")
 
1
+ import wfdb
2
+ from wfdb import processing
3
+ import numpy as np
4
+ import joblib
5
+ import pywt
6
+ import os
7
+ import cv2
8
+ from pdf2image import convert_from_path
9
  import warnings
10
  import pickle
11
+ from scipy import signal as sg
12
+ warnings.filterwarnings('ignore')
13
 
14
+
15
def extract_hrv_features(rr_intervals):
    """
    Extract time-domain heart rate variability features from RR intervals.

    Args:
        rr_intervals (numpy.ndarray): RR intervals in seconds.

    Returns:
        list: Four HRV features [sdnn, rmssd, pnn50, tri_index] where
            sdnn is the standard deviation of the intervals, rmssd the
            root mean square of successive differences, pnn50 the
            percentage of successive differences larger than 50 ms, and
            tri_index the HRV triangular index (sample count divided by
            the histogram mode). All zeros when fewer than two intervals
            are supplied.
    """
    if len(rr_intervals) < 2:
        return [0, 0, 0, 0]

    sdnn = np.std(rr_intervals)
    diff_rr = np.diff(rr_intervals)
    rmssd = np.sqrt(np.mean(diff_rr**2)) if len(diff_rr) > 0 else 0
    # 0.05 s == 50 ms, the conventional pNN50 threshold.
    pnn50 = 100 * np.sum(np.abs(diff_rr) > 0.05) / len(diff_rr) if len(diff_rr) > 0 else 0

    if len(rr_intervals) > 2:
        bin_width = 1 / 128  # standard bin width for the triangular index
        bins = np.arange(min(rr_intervals), max(rr_intervals) + bin_width, bin_width)
        if bins.size < 2:
            # Fix: a constant RR series makes np.arange produce a single
            # bin edge, which previously crashed np.histogram / np.max on
            # an empty count array. Every interval then falls in the same
            # (implicit) bin, so the triangular index is N / N = 1.
            tri_index = 1.0
        else:
            n, _ = np.histogram(rr_intervals, bins=bins)
            tri_index = len(rr_intervals) / np.max(n) if np.max(n) > 0 else 0
    else:
        tri_index = 0

    return [sdnn, rmssd, pnn50, tri_index]
42
+
43
+
44
def extract_qrs_features(signal, r_peaks, fs):
    """
    Derive QRS-complex morphology features from an ECG trace.

    Args:
        signal (numpy.ndarray): 1-D ECG samples.
        r_peaks (numpy.ndarray): Sample indices of the detected R peaks.
        fs (int): Sampling frequency in Hz.

    Returns:
        list: [qrs_width_mean, qrs_width_std, qrs_amplitude_mean]; all
            zeros when fewer than two R peaks are available.
    """
    if len(r_peaks) < 2:
        return [0, 0, 0]

    half_window = int(0.1 * fs)  # search 100 ms on either side of each R peak
    widths = []
    for r_pos in r_peaks:
        lo = max(0, r_pos - half_window)
        hi = min(len(signal) - 1, r_pos + half_window)

        # Q: deepest trough in the window preceding the R peak.
        q_pos = lo if r_pos <= lo else lo + np.argmin(signal[lo:r_pos])
        # S: deepest trough in the window following the R peak.
        s_pos = r_pos if r_pos >= hi else r_pos + np.argmin(signal[r_pos:hi])

        if s_pos > q_pos:
            widths.append((s_pos - q_pos) / fs)

    mean_width = np.mean(widths) if widths else 0
    std_width = np.std(widths) if widths else 0
    mean_amplitude = np.mean([signal[r] for r in r_peaks]) if len(r_peaks) > 0 else 0

    return [mean_width, std_width, mean_amplitude]
83
+
84
+
85
+ def digitize_ecg_from_pdf(pdf_path, output_file=None):
86
  """
87
  Process an ECG PDF file and convert it to a .dat signal file.
88
 
89
  Args:
90
  pdf_path (str): Path to the ECG PDF file
91
+ output_file (str, optional): Path to save the output .dat file
 
 
92
 
93
  Returns:
94
  tuple: (path to the created .dat file, list of paths to segment files)
95
  """
96
+ if output_file is None:
97
+ output_file = 'calibrated_ecg.dat'
98
 
 
99
  images = convert_from_path(pdf_path)
100
  temp_image_path = 'temp_ecg_image.jpg'
101
  images[0].save(temp_image_path, 'JPEG')
102
 
 
 
 
 
103
  img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
104
  height, width = img.shape
105
 
 
 
 
 
106
  calibration = {
107
+ 'seconds_per_pixel': 2.0 / 197.0,
108
+ 'mv_per_pixel': 1.0 / 78.8,
109
  }
110
 
 
 
 
 
111
  layer1_start = int(height * 35.35 / 100)
112
  layer1_end = int(height * 51.76 / 100)
113
  layer2_start = int(height * 51.82 / 100)
 
115
  layer3_start = int(height * 69.47 / 100)
116
  layer3_end = int(height * 87.06 / 100)
117
 
 
 
 
 
 
 
118
  layers = [
119
+ img[layer1_start:layer1_end, :],
120
+ img[layer2_start:layer2_end, :],
121
+ img[layer3_start:layer3_end, :]
122
  ]
123
 
 
124
  signals = []
125
  time_points = []
126
+ layer_duration = 10.0
127
 
128
  for i, layer in enumerate(layers):
 
 
 
 
129
  _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
130
 
 
131
  contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
132
+ waveform_contour = max(contours, key=cv2.contourArea)
 
 
 
 
133
 
 
134
  sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
135
  x_coords = np.array([point[0][0] for point in sorted_contour])
136
  y_coords = np.array([point[0][1] for point in sorted_contour])
137
 
 
138
  isoelectric_line_y = layer.shape[0] * 0.6
139
 
 
140
  x_min, x_max = np.min(x_coords), np.max(x_coords)
141
  time = (x_coords - x_min) / (x_max - x_min) * layer_duration
142
 
 
143
  signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
144
  signal_mv = signal_mv - np.mean(signal_mv)
145
 
 
 
 
 
146
  time_points.append(time)
147
  signals.append(signal_mv)
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  total_duration = layer_duration * len(layers)
150
+ sampling_frequency = 500
151
  num_samples = int(total_duration * sampling_frequency)
152
  combined_time = np.linspace(0, total_duration, num_samples)
153
  combined_signal = np.zeros(num_samples)
154
 
 
 
 
 
155
  for i, (time, signal) in enumerate(zip(time_points, signals)):
156
  start_time = i * layer_duration
157
  mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
158
  relevant_times = combined_time[mask]
159
  interpolated_signal = np.interp(relevant_times, start_time + time, signal)
160
  combined_signal[mask] = interpolated_signal
 
 
 
161
 
 
162
  combined_signal = combined_signal - np.mean(combined_signal)
163
  signal_peak = np.max(np.abs(combined_signal))
164
+ target_amplitude = 2.0
 
 
 
165
 
166
  if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
167
  scaling_factor = target_amplitude / signal_peak
168
  combined_signal = combined_signal * scaling_factor
 
 
 
169
 
170
+ adc_gain = 1000.0
 
171
  int_signal = (combined_signal * adc_gain).astype(np.int16)
172
  int_signal.tofile(output_file)
173
 
 
 
 
 
 
174
  if os.path.exists(temp_image_path):
175
  os.remove(temp_image_path)
176
+
177
+ segment_files = []
178
+ samples_per_segment = int(layer_duration * sampling_frequency)
179
+
180
+ base_name = os.path.splitext(output_file)[0]
181
+ for i in range(3):
182
+ start_idx = i * samples_per_segment
183
+ end_idx = (i + 1) * samples_per_segment
184
+ segment = combined_signal[start_idx:end_idx]
185
+
186
+ segment_file = f"{base_name}_segment{i+1}.dat"
187
+ (segment * adc_gain).astype(np.int16).tofile(segment_file)
188
+ segment_files.append(segment_file)
189
 
190
  return output_file, segment_files
191
 
192
+
193
def split_dat_into_segments(file_path, segment_duration=10.0):
    """
    Cut a DAT recording into consecutive fixed-length segment files.

    Args:
        file_path (str): Path to the DAT file (with or without extension).
        segment_duration (float): Length of each segment in seconds.

    Returns:
        list: Paths of the segment files written next to the source file.
    """
    all_leads, fs = load_dat_signal(file_path)

    n_leads = all_leads.shape[1]
    if n_leads == 1:
        chosen = 0
    else:
        # Prefer Lead II (index 1), then Lead I (index 0).
        chosen = next((idx for idx in (1, 0) if idx < n_leads), 0)

    trace = all_leads[:, chosen]
    seg_len = int(segment_duration * fs)
    n_segments = len(trace) // seg_len  # trailing partial segment is dropped
    stem = os.path.splitext(file_path)[0]

    written = []
    for seg_no in range(n_segments):
        chunk = trace[seg_no * seg_len:(seg_no + 1) * seg_len]
        out_path = f"{stem}_segment{seg_no + 1}.dat"
        # Persist as a single-lead column, matching load_dat_signal's layout.
        chunk.reshape(-1, 1).tofile(out_path)
        written.append(out_path)

    return written
232
+
 
 
 
 
 
 
 
233
 
234
def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16):
    """
    Load raw ECG samples from a DAT file.

    Args:
        file_path (str): Path to the DAT file, with or without the
            ``.dat`` extension.
        n_leads (int): Expected number of leads.
        n_samples (int): Expected number of samples per lead.
        dtype: Binary sample type stored in the file.

    Returns:
        tuple: (2-D numpy array shaped (samples, leads), sampling
            frequency in Hz — always 500).
    """
    dat_path = file_path if file_path.endswith('.dat') else file_path + '.dat'
    raw = np.fromfile(dat_path, dtype=dtype)

    # Exact match with the expected layout: reshape as declared.
    if raw.size == n_leads * n_samples:
        return raw.reshape(n_samples, n_leads), 500

    # Exactly one lead's worth of samples: single-lead record.
    if raw.size == n_samples:
        return raw.reshape(n_samples, 1), 500

    # Fall back to guessing the lead count from common configurations.
    # NOTE(review): 1 divides every size, so this loop always matches on
    # its first candidate and any other size is returned as a single-lead
    # column — the 2/3/6/12 candidates and the reshape(-1, 1) fallback
    # are effectively unreachable. The single-lead digitized-PDF pipeline
    # appears to rely on that outcome, so the behavior is preserved here;
    # confirm before reordering the candidates.
    for candidate in (1, 2, 3, 6, 12):
        if raw.size % candidate == 0:
            return raw.reshape(raw.size // candidate, candidate), 500

    return raw.reshape(-1, 1), 500
271
+
272
 
273
+ def extract_features_from_signal(signal):
 
274
  """
275
  Extract features from an ECG signal.
276
 
277
  Args:
278
  signal (numpy.ndarray): ECG signal
 
279
 
280
  Returns:
281
+ list: Basic features extracted from the signal (32 features)
282
  """
 
 
 
283
  features = []
284
  features.append(np.mean(signal))
285
  features.append(np.std(signal))
 
290
  features.append(np.percentile(signal, 75))
291
  features.append(np.mean(np.diff(signal)))
292
 
 
 
 
293
  coeffs = pywt.wavedec(signal, 'db4', level=5)
294
+ for coeff in coeffs:
295
  features.append(np.mean(coeff))
296
  features.append(np.std(coeff))
297
  features.append(np.min(coeff))
298
  features.append(np.max(coeff))
 
 
 
299
 
 
 
 
300
  return features
301
 
302
+
303
+ def classify_new_ecg(file_path, model):
304
  """
305
  Classify a new ECG file.
306
 
307
  Args:
308
  file_path (str): Path to the ECG file (without extension)
309
  model: The trained model for classification
 
310
 
311
  Returns:
312
  str: Classification result ("Normal", "Abnormal", or error message)
313
  """
314
+ signal_all_leads, fs = load_dat_signal(file_path)
315
+
316
+ if signal_all_leads.shape[1] == 1:
317
+ lead_index = 0
318
+ else:
319
+ lead_priority = [1, 0] # Try Lead II (index 1), then I (index 0)
320
+ lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
+ signal = signal_all_leads[:, lead_index]
323
+ signal = (signal - np.mean(signal)) / np.std(signal)
324
+
325
+ try:
326
+ r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
327
+ except:
328
+ r_peaks = np.array([])
329
 
330
+ if len(r_peaks) < 2:
331
+ basic_features = extract_features_from_signal(signal)
332
+ record_features = basic_features + [0] * (45 - len(basic_features))
333
+ else:
334
  rr_intervals = np.diff(r_peaks) / fs
335
  qrs_durations = np.array([r_peaks[i] - r_peaks[i - 1] for i in range(1, len(r_peaks))])
 
 
 
 
336
 
337
+ record_features = []
338
+ basic_features = extract_features_from_signal(signal)
339
+ record_features.extend(basic_features)
340
 
341
+ record_features.extend([
 
342
  len(r_peaks),
343
  np.mean(rr_intervals) if len(rr_intervals) > 0 else 0,
344
  np.std(rr_intervals) if len(rr_intervals) > 0 else 0,
345
  np.median(rr_intervals) if len(rr_intervals) > 0 else 0,
346
+ np.mean(qrs_durations) / fs if len(qrs_durations) > 0 else 0,
347
+ np.std(qrs_durations) / fs if len(qrs_durations) > 0 else 0
348
  ])
349
 
350
+ hrv_features = extract_hrv_features(rr_intervals)
351
+ record_features.extend(hrv_features)
 
 
 
 
352
 
353
+ qrs_features = extract_qrs_features(signal, r_peaks, fs)
354
+ record_features.extend(qrs_features)
355
+
356
+ if len(rr_intervals) >= 4:
357
+ try:
358
+ rr_times = np.cumsum(rr_intervals)
359
+ rr_times = np.insert(rr_times, 0, 0)
360
+
361
+ fs_interp = 4.0
362
+ t_interp = np.arange(0, rr_times[-1], 1/fs_interp)
363
+ rr_interp = np.interp(t_interp, rr_times[:-1], rr_intervals)
364
+
365
+ freq, psd = sg.welch(rr_interp, fs=fs_interp, nperseg=min(256, len(rr_interp)))
366
+
367
+ vlf_mask = (freq >= 0.0033) & (freq < 0.04)
368
+ lf_mask = (freq >= 0.04) & (freq < 0.15)
369
+ hf_mask = (freq >= 0.15) & (freq < 0.4)
370
+
371
+ lf_power = np.trapz(psd[lf_mask], freq[lf_mask]) if np.any(lf_mask) else 0
372
+ hf_power = np.trapz(psd[hf_mask], freq[hf_mask]) if np.any(hf_mask) else 0
373
+
374
+ lf_hf_ratio = lf_power / hf_power if hf_power > 0 else 0
375
+ normalized_lf = lf_power / (lf_power + hf_power) if (lf_power + hf_power) > 0 else 0
376
+ except:
377
+ lf_power = hf_power = lf_hf_ratio = normalized_lf = 0
378
+ else:
379
+ lf_power = hf_power = lf_hf_ratio = normalized_lf = 0
380
 
381
+ record_features.extend([lf_power, hf_power, lf_hf_ratio, normalized_lf])
382
+
383
+ if len(record_features) < 45:
384
+ record_features.extend([0] * (45 - len(record_features)))
385
+ elif len(record_features) > 45:
386
+ record_features = record_features[:45]
387
+
388
+ prediction = model.predict([record_features])[0]
389
+ result = "Abnormal" if prediction == 1 else "Normal"
390
+
391
+ return result
392
 
 
 
 
 
 
393
 
394
+ def classify_ecg(file_path, model, is_pdf=False):
 
395
  """
396
  Wrapper function that handles both PDF and DAT ECG files with segment voting.
397
 
 
399
  file_path (str): Path to the ECG file (.pdf or without extension for .dat)
400
  model: The trained model for classification
401
  is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)
 
402
 
403
  Returns:
404
  str: Classification result ("Normal", "Abnormal", or error message)
405
  """
406
  try:
 
407
  if model is None:
408
  return "Error: Model not loaded. Please check model compatibility."
409
 
410
  if is_pdf:
 
 
 
 
411
  base_name = os.path.splitext(os.path.basename(file_path))[0]
412
  output_dat = f"{base_name}_digitized.dat"
413
 
 
414
  dat_path, segment_files = digitize_ecg_from_pdf(
415
  pdf_path=file_path,
416
+ output_file=output_dat
 
417
  )
 
 
 
 
418
  else:
419
+ segment_files = split_dat_into_segments(file_path)
 
 
 
 
420
 
421
  if not segment_files:
422
+ return classify_new_ecg(file_path, model)
 
423
 
 
424
  segment_results = []
425
 
426
+ for segment_file in segment_files:
 
 
 
 
427
  segment_path = os.path.splitext(segment_file)[0]
428
+ result = classify_new_ecg(segment_path, model)
 
 
 
 
 
 
429
  segment_results.append(result)
430
 
 
431
  try:
432
  os.remove(segment_file)
 
 
433
  except:
434
  pass
435
 
 
436
  if segment_results:
437
  normal_count = segment_results.count("Normal")
438
  abnormal_count = segment_results.count("Abnormal")
 
439
 
 
 
 
 
 
 
 
 
 
440
  if abnormal_count > normal_count:
441
  final_result = "Abnormal"
442
  elif normal_count > abnormal_count:
 
444
  else:
445
  final_result = "Inconclusive"
446
 
 
 
 
447
  return final_result
448
  else:
449
  return "Error: No valid segments to classify"
450
 
451
  except Exception as e:
452
  error_msg = f"Classification error: {str(e)}"
 
 
453
  return error_msg
454
+
455
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
+
 
 
 
 
ECG/ECG_MultiClass.py CHANGED
@@ -1,10 +1,10 @@
1
  """
2
- ECG Analysis Pipeline: From PDF to Diagnosis
3
- -------------------------------------------
4
  This module provides functions to:
5
  1. Digitize ECG from PDF files
6
  2. Process the digitized ECG signal
7
- 3. Make diagnoses using a trained model
8
  """
9
 
10
  import cv2
@@ -13,50 +13,42 @@ import os
13
  import tensorflow as tf
14
  import pickle
15
  from scipy.interpolate import interp1d
16
- from collections import Counter
17
  from pdf2image import convert_from_path
18
- import matplotlib.pyplot as plt # Added for visualization
19
 
20
- def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=False):
 
 
 
 
 
 
 
 
21
  """
22
  Process an ECG PDF file and convert it to a .dat signal file.
23
 
24
  Args:
25
  pdf_path (str): Path to the ECG PDF file
26
- output_file (str): Path to save the output .dat file (default: 'calibrated_ecg.dat')
27
- debug (bool): Whether to print debug information
28
 
29
  Returns:
30
- str: Path to the created .dat file
31
  """
32
- if debug:
33
- print(f"Starting ECG digitization from PDF: {pdf_path}")
34
 
35
- # Convert PDF to image
36
  images = convert_from_path(pdf_path)
37
  temp_image_path = 'temp_ecg_image.jpg'
38
  images[0].save(temp_image_path, 'JPEG')
39
 
40
- if debug:
41
- print(f"Converted PDF to image: {temp_image_path}")
42
-
43
- # Load the image
44
  img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
45
  height, width = img.shape
46
 
47
- if debug:
48
- print(f"Image dimensions: {width}x{height}")
49
-
50
- # Fixed calibration parameters
51
  calibration = {
52
- 'seconds_per_pixel': 2.0 / 197.0, # 197 pixels = 2 seconds
53
- 'mv_per_pixel': 1.0 / 78.8, # 78.8 pixels = 1 mV
54
  }
55
 
56
- if debug:
57
- print(f"Calibration parameters: {calibration}")
58
-
59
- # Calculate layer boundaries using percentages
60
  layer1_start = int(height * 35.35 / 100)
61
  layer1_end = int(height * 51.76 / 100)
62
  layer2_start = int(height * 51.82 / 100)
@@ -64,239 +56,126 @@ def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=Fals
64
  layer3_start = int(height * 69.47 / 100)
65
  layer3_end = int(height * 87.06 / 100)
66
 
67
- if debug:
68
- print(f"Layer 1 boundaries: {layer1_start}-{layer1_end}")
69
- print(f"Layer 2 boundaries: {layer2_start}-{layer2_end}")
70
- print(f"Layer 3 boundaries: {layer3_start}-{layer3_end}")
71
-
72
- # Crop each layer
73
  layers = [
74
- img[layer1_start:layer1_end, :], # Layer 1
75
- img[layer2_start:layer2_end, :], # Layer 2
76
- img[layer3_start:layer3_end, :] # Layer 3
77
  ]
78
 
79
- # Process each layer to extract waveform contours
80
  signals = []
81
  time_points = []
82
- layer_duration = 10.0 # Each layer is 10 seconds long
83
 
84
  for i, layer in enumerate(layers):
85
- if debug:
86
- print(f"Processing layer {i+1}...")
87
-
88
- # Binary thresholding
89
  _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
90
 
91
- # Detect contours
92
  contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
93
- waveform_contour = max(contours, key=cv2.contourArea) # Largest contour is the ECG
94
-
95
- if debug:
96
- print(f" - Found {len(contours)} contours")
97
- print(f" - Selected contour with {len(waveform_contour)} points")
98
 
99
- # Sort contour points and extract coordinates
100
  sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
101
  x_coords = np.array([point[0][0] for point in sorted_contour])
102
  y_coords = np.array([point[0][1] for point in sorted_contour])
103
 
104
- # Calculate isoelectric line (one-third from the bottom)
105
  isoelectric_line_y = layer.shape[0] * 0.6
106
 
107
- # Convert to time using fixed layer duration
108
  x_min, x_max = np.min(x_coords), np.max(x_coords)
109
  time = (x_coords - x_min) / (x_max - x_min) * layer_duration
110
 
111
- # Calculate signal in millivolts and apply baseline correction
112
  signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
113
  signal_mv = signal_mv - np.mean(signal_mv)
114
 
115
- if debug:
116
- print(f" - Layer {i+1} signal range: {np.min(signal_mv):.2f} mV to {np.max(signal_mv):.2f} mV")
117
-
118
- # Store the time points and calibrated signal
119
  time_points.append(time)
120
  signals.append(signal_mv)
121
 
122
- # Combine signals with proper time alignment
123
  total_duration = layer_duration * len(layers)
124
- sampling_frequency = 500 # Standard ECG frequency
125
  num_samples = int(total_duration * sampling_frequency)
126
  combined_time = np.linspace(0, total_duration, num_samples)
127
  combined_signal = np.zeros(num_samples)
128
 
129
- if debug:
130
- print(f"Combining signals with {sampling_frequency} Hz sampling rate, total duration: {total_duration}s")
131
-
132
- # Place each lead at the correct time position
133
  for i, (time, signal) in enumerate(zip(time_points, signals)):
134
  start_time = i * layer_duration
135
  mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
136
  relevant_times = combined_time[mask]
137
  interpolated_signal = np.interp(relevant_times, start_time + time, signal)
138
  combined_signal[mask] = interpolated_signal
139
-
140
- if debug:
141
- print(f" - Added layer {i+1} signal from {start_time}s to {start_time + layer_duration}s")
142
 
143
- # Baseline correction and amplitude scaling
144
  combined_signal = combined_signal - np.mean(combined_signal)
145
  signal_peak = np.max(np.abs(combined_signal))
146
- target_amplitude = 2.0 # Target peak amplitude in mV
147
-
148
- if debug:
149
- print(f"Signal peak before scaling: {signal_peak:.2f} mV")
150
 
151
  if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
152
  scaling_factor = target_amplitude / signal_peak
153
  combined_signal = combined_signal * scaling_factor
154
- if debug:
155
- print(f"Applied scaling factor: {scaling_factor:.2f}")
156
- print(f"Signal peak after scaling: {np.max(np.abs(combined_signal)):.2f} mV")
157
 
158
- # Convert to 16-bit integers and save as .dat file
159
- adc_gain = 1000.0 # Standard gain: 1000 units per mV
160
  int_signal = (combined_signal * adc_gain).astype(np.int16)
161
  int_signal.tofile(output_file)
162
 
163
- if debug:
164
- print(f"Saved signal to {output_file} with {len(int_signal)} samples")
165
- print(f"Integer signal range: {np.min(int_signal)} to {np.max(int_signal)}")
166
-
167
- # Clean up temporary files
168
  if os.path.exists(temp_image_path):
169
  os.remove(temp_image_path)
170
- if debug:
171
- print(f"Removed temporary image: {temp_image_path}")
172
 
173
- return output_file
174
-
175
- def visualize_ecg_signal(signal, sampling_rate=500, title="Digitized ECG Signal"):
176
- """
177
- Visualize an ECG signal with proper time axis.
178
 
179
- Parameters:
180
- -----------
181
- signal : numpy.ndarray
182
- ECG signal data
183
- sampling_rate : int
184
- Sampling rate in Hz
185
- title : str
186
- Plot title
187
- """
188
- # Calculate time axis
189
- time = np.arange(len(signal)) / sampling_rate
190
-
191
- # Create figure with appropriate size
192
- plt.figure(figsize=(15, 5))
193
- plt.plot(time, signal)
194
- plt.title(title)
195
- plt.xlabel('Time (seconds)')
196
- plt.ylabel('Amplitude (mV)')
197
- plt.grid(True)
198
-
199
- # Add 1mV scale bar
200
- plt.plot([1, 1], [-0.5, 0.5], 'r-', linewidth=2)
201
- plt.text(1.1, 0, '1mV', va='center')
202
-
203
- # Add time scale bar (1 second)
204
- y_min = np.min(signal)
205
- plt.plot([1, 2], [y_min, y_min], 'r-', linewidth=2)
206
- plt.text(1.5, y_min - 0.1, '1s', ha='center')
207
-
208
- plt.tight_layout()
209
- plt.show()
210
 
211
- def read_lead_i_long_dat_file(dat_file_path, sampling_rate=500, data_format='16', scale_factor=0.001):
 
212
  """
213
- Read a 30-second pure Lead I .dat file directly and properly scale it
214
 
215
  Parameters:
216
  -----------
217
  dat_file_path : str
218
  Path to the .dat file (with or without .dat extension)
219
- sampling_rate : int
220
- Sampling rate in Hz (default 500Hz)
221
- data_format : str
222
- Data format of the binary file: '16' for 16-bit integers, '32' for 32-bit floats
223
- scale_factor : float
224
- Scale factor to convert units (0.001 for converting Β΅V to mV)
225
 
226
  Returns:
227
  --------
228
  numpy.ndarray
229
- ECG signal data for Lead I with shape (total_samples,)
230
  """
231
- # Ensure the path ends with .dat
232
  if not dat_file_path.endswith('.dat'):
233
  dat_file_path += '.dat'
234
 
235
- # Expected samples for full 30 seconds
236
- expected_samples = sampling_rate * 30
237
-
238
- # Read the binary data
239
  try:
240
- if data_format == '16':
241
- # 16-bit signed integers (common format for ECG)
242
- data = np.fromfile(dat_file_path, dtype=np.int16)
243
- elif data_format == '32':
244
- # 32-bit floating point (less common)
245
- data = np.fromfile(dat_file_path, dtype=np.float32)
246
- else:
247
- raise ValueError(f"Unsupported data format: {data_format}")
248
-
249
- # Apply scaling to convert Β΅V to mV
250
- signal = data * scale_factor
251
-
252
- # Handle if signal is not exactly 30 seconds
253
- if len(signal) < expected_samples:
254
- # Pad with zeros if too short
255
- padded_signal = np.zeros(expected_samples)
256
- padded_signal[:len(signal)] = signal
257
- signal = padded_signal
258
- elif len(signal) > expected_samples:
259
- # Truncate if too long
260
- signal = signal[:expected_samples]
261
-
262
  return signal
263
 
264
  except Exception as e:
265
  raise
266
 
267
- def segment_signal(signal, sampling_rate=500):
268
  """
269
- Segment a 30-second signal into three 10-second segments
270
 
271
  Parameters:
272
  -----------
273
  signal : numpy.ndarray
274
  The full signal to segment
275
- sampling_rate : int
276
- Sampling rate in Hz
277
 
278
  Returns:
279
  --------
280
  list
281
- List of three 10-second signal segments
282
  """
283
- # Calculate samples per segment (10 seconds)
284
- segment_samples = sampling_rate * 10
285
-
286
- # Expected samples for full 30 seconds
287
- expected_samples = sampling_rate * 30
288
-
289
- # Ensure the signal is 30 seconds long
290
- if len(signal) != expected_samples:
291
- # Resample to 30 seconds
292
- x = np.linspace(0, 1, len(signal))
293
- x_new = np.linspace(0, 1, expected_samples)
294
- f = interp1d(x, signal, kind='linear', bounds_error=False, fill_value="extrapolate")
295
- signal = f(x_new)
296
 
297
- # Split the signal into three 10-second segments
298
  segments = []
299
- for i in range(3):
 
 
300
  start_idx = i * segment_samples
301
  end_idx = (i + 1) * segment_samples
302
  segment = signal[start_idx:end_idx]
@@ -304,7 +183,7 @@ def segment_signal(signal, sampling_rate=500):
304
 
305
  return segments
306
 
307
- def process_segment(segment, sampling_rate=500):
308
  """
309
  Process a segment of ECG data to ensure it's properly formatted for the model
310
 
@@ -312,243 +191,123 @@ def process_segment(segment, sampling_rate=500):
312
  -----------
313
  segment : numpy.ndarray
314
  Raw ECG segment
315
- sampling_rate : int
316
- Sampling rate of the ECG
317
 
318
  Returns:
319
  --------
320
  numpy.ndarray
321
  Processed segment ready for model input
322
  """
323
- # Ensure correct length (5000 samples for 10 seconds)
324
- if len(segment) != 5000:
325
  x = np.linspace(0, 1, len(segment))
326
- x_new = np.linspace(0, 1, 5000)
327
  f = interp1d(x, segment, kind='linear', bounds_error=False, fill_value="extrapolate")
328
  segment = f(x_new)
329
 
 
 
330
  return segment
331
 
332
- def predict_with_voting(dat_file_path, model_path, mlb_path=None, sampling_rate=500, scale_factor=0.001, debug=False):
 
333
  """
334
- Process a 30-second .dat file, properly scale it, split it into three 10-second segments,
335
- make predictions on each segment, and return the class with highest average probability.
336
 
337
  Parameters:
338
  -----------
339
- dat_file_path : str
340
- Path to the .dat file
341
- model_path : str
342
- Path to the saved model (.h5 file)
343
- mlb_path : str, optional
344
- Path to the saved MultiLabelBinarizer pickle file for label decoding
345
- sampling_rate : int
346
- Sampling rate in Hz (default 500Hz)
347
- scale_factor : float
348
- Scale factor to convert units (0.001 for converting Β΅V to mV)
349
- debug : bool
350
- Whether to print debug information
351
 
352
  Returns:
353
  --------
354
  dict
355
- Dictionary containing segment predictions and final class probabilities
356
  """
357
- try:
358
- # Step 1: Read the 30-second ECG data (pure Lead I) and apply scaling
359
- if debug:
360
- print(f"Reading signal from {dat_file_path}")
361
-
362
- full_signal = read_lead_i_long_dat_file(
363
- dat_file_path,
364
- sampling_rate=sampling_rate,
365
- scale_factor=scale_factor
366
- )
367
-
368
- if debug:
369
- print(f"Signal loaded: {len(full_signal)} samples, range: {np.min(full_signal):.2f} to {np.max(full_signal):.2f} mV")
370
-
371
- # Step 2: Split into three 10-second segments
372
- segments = segment_signal(full_signal, sampling_rate)
373
-
374
- if debug:
375
- print(f"Split into {len(segments)} segments of {len(segments[0])} samples each")
376
-
377
- # Step 3: Load the model (load once to improve performance)
378
- if debug:
379
- print(f"Loading model from {model_path}")
380
-
381
- model = tf.keras.models.load_model(model_path)
382
-
383
- # Load MLB if provided
384
- mlb = None
385
- if mlb_path and os.path.exists(mlb_path):
386
- if debug:
387
- print(f"Loading label binarizer from {mlb_path}")
388
- with open(mlb_path, 'rb') as f:
389
- mlb = pickle.load(f)
390
-
391
- # Step 4: Process each segment and collect predictions
392
- segment_results = []
393
- all_predictions = []
394
-
395
- for i, segment in enumerate(segments):
396
- if debug:
397
- print(f"Processing segment {i+1}...")
398
-
399
- # Process the segment to ensure it's properly formatted
400
- processed_segment = process_segment(segment)
401
-
402
- # Reshape for model input (batch, time, channels)
403
- X = processed_segment.reshape(1, 5000, 1)
404
-
405
- # Make predictions
406
- predictions = model.predict(X, verbose=0)
407
- all_predictions.append(predictions[0])
408
-
409
- # Process segment results
410
- segment_result = {"raw_predictions": predictions[0].tolist()}
411
-
412
- # Decode labels if MLB is provided
413
- if mlb is not None:
414
- # Add class probabilities
415
- class_probs = {}
416
- for j, class_name in enumerate(mlb.classes_):
417
- class_probs[class_name] = float(predictions[0][j])
418
-
419
- segment_result["class_probabilities"] = class_probs
420
-
421
- segment_results.append(segment_result)
422
 
423
- # Step 5: Calculate average probabilities across all segments
424
- final_result = {"segment_results": segment_results}
425
 
426
- # Average the raw predictions
427
- avg_predictions = np.mean(all_predictions, axis=0)
428
- final_result["averaged_raw_predictions"] = avg_predictions.tolist()
 
 
 
429
 
430
- # Calculate final class probabilities (average across segments)
431
- if mlb is not None:
432
- # Calculate average probability for each class
433
- final_class_probs = {}
434
- for cls_idx, cls_name in enumerate(mlb.classes_):
435
- final_class_probs[cls_name] = float(np.mean([pred[cls_idx] for pred in all_predictions]))
436
-
437
- # Find the class with highest average probability
438
- top_class = max(final_class_probs.items(), key=lambda x: x[1])
439
- top_class_name = top_class[0]
440
-
441
- final_result["final_class_probabilities"] = final_class_probs
442
- final_result["top_class"] = top_class_name
443
 
444
- if debug:
445
- print(f"Top class by average probability: {top_class_name} ({top_class[1]:.2f})")
446
-
447
- return final_result
448
-
449
- except Exception as e:
450
- if debug:
451
- print(f"Error in predict_with_voting: {str(e)}")
452
- return {"error": str(e)}
453
 
454
- def analyze_ecg_pdf(pdf_path, model_path, mlb_path=None, temp_dat_file='calibrated_ecg.dat', cleanup=True, debug=False, visualize=False):
455
  """
456
  Complete ECG analysis pipeline: digitizes a PDF ECG, analyzes it with the model,
457
- and returns the diagnosis with highest probability.
458
 
459
  Args:
460
  pdf_path (str): Path to the ECG PDF file
461
  model_path (str): Path to the model (.h5) file
462
- mlb_path (str, optional): Path to the MultiLabelBinarizer file
463
- temp_dat_file (str, optional): Path to save the temporary digitized file
464
  cleanup (bool, optional): Whether to remove temporary files after processing
465
- debug (bool, optional): Whether to print debug information
466
- visualize (bool, optional): Whether to visualize the digitized signal
467
 
468
  Returns:
469
  dict: {
470
- "diagnosis": str, # Top diagnosis (highest average probability)
471
- "probability": float, # Probability of top diagnosis
472
- "all_probabilities": dict, # All diagnoses with probabilities
473
  "digitized_file": str # Path to digitized file (if cleanup=False)
474
  }
475
  """
476
- # Silence TensorFlow warnings
477
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
478
 
479
- if debug:
480
- print(f"Starting ECG analysis pipeline for {pdf_path}")
481
-
482
- # 1. Digitize ECG from PDF to DAT file
483
- dat_file_path = digitize_ecg_from_pdf(pdf_path, output_file=temp_dat_file, debug=debug)
484
-
485
- # Visualize the digitized signal if requested
486
- if visualize:
487
- signal = read_lead_i_long_dat_file(dat_file_path, scale_factor=0.001)
488
- visualize_ecg_signal(signal, title=f"Digitized ECG from {os.path.basename(pdf_path)}")
489
-
490
- # 2. Process DAT file with model
491
- if debug:
492
- print("Processing digitized signal with model...")
493
-
494
- results = predict_with_voting(
495
- dat_file_path,
496
- model_path,
497
- mlb_path,
498
- scale_factor=0.001, # Convert microvolts to millivolts
499
- debug=debug
500
- )
501
-
502
- # 3. Extract top diagnosis (highest probability)
503
- top_diagnosis = {
504
- "diagnosis": None,
505
- "probability": 0.0,
506
- "all_probabilities": {},
507
- "digitized_file": dat_file_path
508
- }
509
-
510
- # If we have class probabilities, find the highest one
511
- if "final_class_probabilities" in results:
512
- probs = results["final_class_probabilities"]
513
- top_diagnosis["all_probabilities"] = probs
514
 
515
- # Use the top class directly from the results
516
- if "top_class" in results:
517
- top_diagnosis["diagnosis"] = results["top_class"]
518
- top_diagnosis["probability"] = probs[results["top_class"]]
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
- # Clean up temporary files if requested
521
- if cleanup and os.path.exists(dat_file_path):
522
- if debug:
523
- print(f"Cleaning up temporary file: {dat_file_path}")
524
- os.remove(dat_file_path)
525
- top_diagnosis.pop("digitized_file")
526
-
527
- if debug:
528
- print(f"Analysis complete. Diagnosis: {top_diagnosis['diagnosis']} (Probability: {top_diagnosis['probability']:.2f})")
529
 
530
- return top_diagnosis
531
 
532
- # Example usage
533
- if __name__ == "__main__":
534
- # Path configuration
535
- sample_pdf = 'samplebayez.pdf'
536
- model_path = 'deep-multiclass.h5' # Update with actual path
537
- mlb_path = 'deep-multiclass.pkl' # Update with actual path
538
-
539
- # Analyze ECG with debug output and visualization
540
- result = analyze_ecg_pdf(
541
- sample_pdf,
542
- model_path,
543
- mlb_path,
544
- cleanup=False, # Keep the digitized file
545
- debug=False, # Print debug information
546
- visualize=False # Visualize the digitized signal
547
- )
548
-
549
- # Display result
550
- if result["diagnosis"]:
551
- print(f"Diagnosis: {result['diagnosis']} ")
552
 
553
- else:
554
- print("No clear diagnosis found")
 
1
  """
2
+ ECG Analysis Pipeline: From PDF to Arrhythmia Classification
3
+ -----------------------------------------------------------
4
  This module provides functions to:
5
  1. Digitize ECG from PDF files
6
  2. Process the digitized ECG signal
7
+ 3. Classify arrhythmias using a trained CNN model
8
  """
9
 
10
  import cv2
 
13
  import tensorflow as tf
14
  import pickle
15
  from scipy.interpolate import interp1d
 
16
  from pdf2image import convert_from_path
 
17
 
18
+ ARRHYTHMIA_CLASSES = ["Conduction Abnormalities", "Atrial Arrhythmias", "Tachyarrhythmias", "Normal"]
19
+ SAMPLING_RATE = 500
20
+ SEGMENT_DURATION = 10.0
21
+ TARGET_SEGMENT_LENGTH = 5000
22
+ DEFAULT_OUTPUT_FILE = 'calibrated_ecg.dat'
23
+ DAT_SCALE_FACTOR = 0.001
24
+
25
+
26
+ def digitize_ecg_from_pdf(pdf_path, output_file=None):
27
  """
28
  Process an ECG PDF file and convert it to a .dat signal file.
29
 
30
  Args:
31
  pdf_path (str): Path to the ECG PDF file
32
+ output_file (str, optional): Path to save the output .dat file
 
33
 
34
  Returns:
35
+ tuple: (path to the created .dat file, list of paths to segment files)
36
  """
37
+ if output_file is None:
38
+ output_file = DEFAULT_OUTPUT_FILE
39
 
 
40
  images = convert_from_path(pdf_path)
41
  temp_image_path = 'temp_ecg_image.jpg'
42
  images[0].save(temp_image_path, 'JPEG')
43
 
 
 
 
 
44
  img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
45
  height, width = img.shape
46
 
 
 
 
 
47
  calibration = {
48
+ 'seconds_per_pixel': 2.0 / 197.0,
49
+ 'mv_per_pixel': 1.0 / 78.8,
50
  }
51
 
 
 
 
 
52
  layer1_start = int(height * 35.35 / 100)
53
  layer1_end = int(height * 51.76 / 100)
54
  layer2_start = int(height * 51.82 / 100)
 
56
  layer3_start = int(height * 69.47 / 100)
57
  layer3_end = int(height * 87.06 / 100)
58
 
 
 
 
 
 
 
59
  layers = [
60
+ img[layer1_start:layer1_end, :],
61
+ img[layer2_start:layer2_end, :],
62
+ img[layer3_start:layer3_end, :]
63
  ]
64
 
 
65
  signals = []
66
  time_points = []
67
+ layer_duration = 10.0
68
 
69
  for i, layer in enumerate(layers):
 
 
 
 
70
  _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
71
 
 
72
  contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
73
+ waveform_contour = max(contours, key=cv2.contourArea)
 
 
 
 
74
 
 
75
  sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
76
  x_coords = np.array([point[0][0] for point in sorted_contour])
77
  y_coords = np.array([point[0][1] for point in sorted_contour])
78
 
 
79
  isoelectric_line_y = layer.shape[0] * 0.6
80
 
 
81
  x_min, x_max = np.min(x_coords), np.max(x_coords)
82
  time = (x_coords - x_min) / (x_max - x_min) * layer_duration
83
 
 
84
  signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
85
  signal_mv = signal_mv - np.mean(signal_mv)
86
 
 
 
 
 
87
  time_points.append(time)
88
  signals.append(signal_mv)
89
 
 
90
  total_duration = layer_duration * len(layers)
91
+ sampling_frequency = 500
92
  num_samples = int(total_duration * sampling_frequency)
93
  combined_time = np.linspace(0, total_duration, num_samples)
94
  combined_signal = np.zeros(num_samples)
95
 
 
 
 
 
96
  for i, (time, signal) in enumerate(zip(time_points, signals)):
97
  start_time = i * layer_duration
98
  mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
99
  relevant_times = combined_time[mask]
100
  interpolated_signal = np.interp(relevant_times, start_time + time, signal)
101
  combined_signal[mask] = interpolated_signal
 
 
 
102
 
 
103
  combined_signal = combined_signal - np.mean(combined_signal)
104
  signal_peak = np.max(np.abs(combined_signal))
105
+ target_amplitude = 2.0
 
 
 
106
 
107
  if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
108
  scaling_factor = target_amplitude / signal_peak
109
  combined_signal = combined_signal * scaling_factor
 
 
 
110
 
111
+ adc_gain = 1000.0
 
112
  int_signal = (combined_signal * adc_gain).astype(np.int16)
113
  int_signal.tofile(output_file)
114
 
 
 
 
 
 
115
  if os.path.exists(temp_image_path):
116
  os.remove(temp_image_path)
 
 
117
 
118
+ segment_files = []
119
+ samples_per_segment = int(layer_duration * sampling_frequency)
 
 
 
120
 
121
+ base_name = os.path.splitext(output_file)[0]
122
+ for i in range(3):
123
+ start_idx = i * samples_per_segment
124
+ end_idx = (i + 1) * samples_per_segment
125
+ segment = combined_signal[start_idx:end_idx]
126
+
127
+ segment_file = f"{base_name}_segment{i+1}.dat"
128
+ (segment * adc_gain).astype(np.int16).tofile(segment_file)
129
+ segment_files.append(segment_file)
130
+
131
+ return output_file, segment_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+
134
+ def read_ecg_dat_file(dat_file_path):
135
  """
136
+ Read a DAT file directly and properly scale it
137
 
138
  Parameters:
139
  -----------
140
  dat_file_path : str
141
  Path to the .dat file (with or without .dat extension)
 
 
 
 
 
 
142
 
143
  Returns:
144
  --------
145
  numpy.ndarray
146
+ ECG signal data with shape (total_samples,)
147
  """
 
148
  if not dat_file_path.endswith('.dat'):
149
  dat_file_path += '.dat'
150
 
 
 
 
 
151
  try:
152
+ data = np.fromfile(dat_file_path, dtype=np.int16)
153
+ signal = data * DAT_SCALE_FACTOR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  return signal
155
 
156
  except Exception as e:
157
  raise
158
 
159
+ def segment_signal(signal):
160
  """
161
+ Segment a signal into equal-length segments
162
 
163
  Parameters:
164
  -----------
165
  signal : numpy.ndarray
166
  The full signal to segment
 
 
167
 
168
  Returns:
169
  --------
170
  list
171
+ List of signal segments
172
  """
173
+ segment_samples = int(SAMPLING_RATE * SEGMENT_DURATION)
 
 
 
 
 
 
 
 
 
 
 
 
174
 
 
175
  segments = []
176
+ num_segments = len(signal) // segment_samples
177
+
178
+ for i in range(num_segments):
179
  start_idx = i * segment_samples
180
  end_idx = (i + 1) * segment_samples
181
  segment = signal[start_idx:end_idx]
 
183
 
184
  return segments
185
 
186
+ def process_segment(segment):
187
  """
188
  Process a segment of ECG data to ensure it's properly formatted for the model
189
 
 
191
  -----------
192
  segment : numpy.ndarray
193
  Raw ECG segment
 
 
194
 
195
  Returns:
196
  --------
197
  numpy.ndarray
198
  Processed segment ready for model input
199
  """
200
+ if len(segment) != TARGET_SEGMENT_LENGTH:
 
201
  x = np.linspace(0, 1, len(segment))
202
+ x_new = np.linspace(0, 1, TARGET_SEGMENT_LENGTH)
203
  f = interp1d(x, segment, kind='linear', bounds_error=False, fill_value="extrapolate")
204
  segment = f(x_new)
205
 
206
+ segment = (segment - np.mean(segment)) / (np.std(segment) + 1e-8)
207
+
208
  return segment
209
 
210
+
211
+ def predict_with_cnn_model(signal_data, model):
212
  """
213
+ Process signal data and make predictions using the CNN model.
 
214
 
215
  Parameters:
216
  -----------
217
+ signal_data : numpy.ndarray
218
+ Raw signal data
219
+ model : tensorflow.keras.Model
220
+ Loaded CNN model
 
 
 
 
 
 
 
 
221
 
222
  Returns:
223
  --------
224
  dict
225
+ Dictionary containing predictions for each segment and final averaged prediction
226
  """
227
+ segments = segment_signal(signal_data)
228
+
229
+ all_predictions = []
230
+
231
+ for i, segment in enumerate(segments):
232
+ processed_segment = process_segment(segment)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
+ X = processed_segment.reshape(1, TARGET_SEGMENT_LENGTH, 1)
 
235
 
236
+ prediction = model.predict(X, verbose=0)
237
+ all_predictions.append(prediction[0])
238
+
239
+ if all_predictions:
240
+ avg_prediction = np.mean(all_predictions, axis=0)
241
+ top_class_idx = np.argmax(avg_prediction)
242
 
243
+ results = {
244
+ "segment_predictions": all_predictions,
245
+ "averaged_prediction": avg_prediction,
246
+ "top_class_index": top_class_idx,
247
+ "top_class": ARRHYTHMIA_CLASSES[top_class_idx],
248
+ "probability": float(avg_prediction[top_class_idx])
249
+ }
 
 
 
 
 
 
250
 
251
+ return results
252
+ else:
253
+ return {"error": "No valid segments for prediction"}
 
 
 
 
 
 
254
 
255
+ def analyze_ecg_pdf(pdf_path, model_path, cleanup=True):
256
  """
257
  Complete ECG analysis pipeline: digitizes a PDF ECG, analyzes it with the model,
258
+ and returns the arrhythmia classification with highest probability.
259
 
260
  Args:
261
  pdf_path (str): Path to the ECG PDF file
262
  model_path (str): Path to the model (.h5) file
 
 
263
  cleanup (bool, optional): Whether to remove temporary files after processing
 
 
264
 
265
  Returns:
266
  dict: {
267
+ "arrhythmia_class": str, # Top arrhythmia class
268
+ "probability": float, # Probability of top class
269
+ "all_probabilities": dict, # All classes with probabilities
270
  "digitized_file": str # Path to digitized file (if cleanup=False)
271
  }
272
  """
 
273
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
274
 
275
+ try:
276
+ dat_file_path, segment_files = digitize_ecg_from_pdf(pdf_path)
277
+
278
+ ecg_model = tf.keras.models.load_model(model_path)
279
+
280
+ ecg_signal = read_ecg_dat_file(dat_file_path)
281
+
282
+ classification_results = predict_with_cnn_model(ecg_signal, ecg_model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
+ arrhythmia_result = {
285
+ "arrhythmia_class": classification_results.get("top_class"),
286
+ "probability": classification_results.get("probability", 0.0),
287
+ "all_probabilities": {}
288
+ }
289
+
290
+ if "averaged_prediction" in classification_results:
291
+ for idx, class_name in enumerate(ARRHYTHMIA_CLASSES):
292
+ arrhythmia_result["all_probabilities"][class_name] = float(classification_results["averaged_prediction"][idx])
293
+
294
+ if not cleanup:
295
+ arrhythmia_result["digitized_file"] = dat_file_path
296
+
297
+ if cleanup:
298
+ if os.path.exists(dat_file_path):
299
+ os.remove(dat_file_path)
300
 
301
+ for segment_file in segment_files:
302
+ if os.path.exists(segment_file):
303
+ os.remove(segment_file)
304
+
305
+ return arrhythmia_result
306
+
307
+ except Exception as e:
308
+ error_msg = f"Error in ECG analysis: {str(e)}"
309
+ return {"error": error_msg}
310
 
 
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
+
 
SkinBurns_Classification.py β†’ SkinBurns/SkinBurns_Classification.py RENAMED
File without changes
SkinBurns_Segmentation.py β†’ SkinBurns/SkinBurns_Segmentation.py RENAMED
File without changes
app.py CHANGED
@@ -9,8 +9,8 @@ from pymongo.server_api import ServerApi
9
  import cloudinary
10
  import cloudinary.uploader
11
  from cloudinary.utils import cloudinary_url
12
- from SkinBurns_Classification import FullFeautures
13
- from SkinBurns_Segmentation import segment_burn
14
  import requests
15
  import joblib
16
  import numpy as np
@@ -63,7 +63,7 @@ except Exception as e:
63
  cloudinary.config(
64
  cloud_name = "darumyfpl",
65
  api_key = "493972437417214",
66
- api_secret = "jjOScVGochJYA7IxDam7L4HU2Ig", # Replace in production
67
  secure=True
68
  )
69
 
@@ -148,14 +148,14 @@ async def segment_burn_endpoint(reference: UploadFile = File(...), patient: Uplo
148
 
149
  @app.post("/classify-ecg")
150
  async def classify_ecg_endpoint(file: UploadFile = File(...)):
151
- model = joblib.load('voting_classifier.pkl')
152
 
153
  try:
154
  temp_file_path = f"temp_{file.filename}"
155
  with open(temp_file_path, "wb") as temp_file:
156
  temp_file.write(await file.read())
157
 
158
- result = classify_ecg(temp_file_path, model, debug=True, is_pdf=True)
159
 
160
  os.remove(temp_file_path)
161
 
@@ -167,31 +167,22 @@ async def classify_ecg_endpoint(file: UploadFile = File(...)):
167
  @app.post("/diagnose-ecg")
168
  async def diagnose_ecg(file: UploadFile = File(...)):
169
  try:
170
- # Save the uploaded file temporarily
171
  temp_file_path = f"temp_{file.filename}"
172
  with open(temp_file_path, "wb") as temp_file:
173
  temp_file.write(await file.read())
174
 
175
- model_path = 'deep-multiclass.h5' # Update with actual path
176
- mlb_path = 'deep-multiclass.pkl' # Update with actual path
177
-
178
 
179
- # Call the ECG classification function
180
  result = analyze_ecg_pdf(
181
  temp_file_path,
182
  model_path,
183
- mlb_path,
184
- cleanup=False, # Keep the digitized file
185
- debug=False, # Print debug information
186
- visualize=False # Visualize the digitized signal
187
  )
188
-
189
 
190
- # Remove the temporary file
191
  os.remove(temp_file_path)
192
 
193
- if result and result["diagnosis"]:
194
- return {"result": result["diagnosis"]}
195
  else:
196
  return {"result": "No diagnosis"}
197
 
@@ -216,7 +207,6 @@ async def process_video(file: UploadFile = File(...)):
216
  print("File content type:", file.content_type)
217
  print("File filename:", file.filename)
218
 
219
- # Prepare directories
220
  os.makedirs(UPLOAD_DIR, exist_ok=True)
221
  os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
222
  os.makedirs(OUTPUT_DIR, exist_ok=True)
@@ -259,7 +249,6 @@ async def process_video(file: UploadFile = File(...)):
259
  overwrite=True
260
  )
261
 
262
- # Add new warning with image_url and description
263
  warnings.append({
264
  "image_url": upload_result['secure_url'],
265
  "description": description
@@ -279,7 +268,6 @@ async def process_video(file: UploadFile = File(...)):
279
  else:
280
  wholevideoURL = None
281
 
282
- # Upload graph output
283
  graphURL = None
284
  if os.path.isfile(plot_output_path):
285
  upload_graph_result = cloudinary.uploader.upload(
@@ -310,7 +298,7 @@ clients = set()
310
  analyzer_thread = None
311
  analysis_started = False
312
  analyzer_lock = threading.Lock()
313
- socket_server: AnalysisSocketServer = None # Global reference
314
 
315
 
316
  async def forward_results_from_queue(websocket: WebSocket, warning_queue):
@@ -367,11 +355,9 @@ async def websocket_analysis(websocket: WebSocket):
367
  logger.info("[WebSocket] Flutter connected")
368
 
369
  try:
370
- # Wait for the client to send the stream URL as first message
371
  source = await websocket.receive_text()
372
  logger.info(f"[WebSocket] Received stream URL: {source}")
373
 
374
- # Ensure analyzer starts only once using a thread-safe lock
375
  with analyzer_lock:
376
  if not analysis_started:
377
  requested_fps = 30
@@ -386,7 +372,6 @@ async def websocket_analysis(websocket: WebSocket):
386
  analysis_started = True
387
  logger.info("[WebSocket] Analysis thread started")
388
 
389
- # Rest of your existing code remains exactly the same...
390
  while socket_server is None or socket_server.warning_queue is None:
391
  await asyncio.sleep(0.1)
392
 
@@ -395,7 +380,7 @@ async def websocket_analysis(websocket: WebSocket):
395
  )
396
 
397
  while True:
398
- await asyncio.sleep(1) # Keep alive
399
 
400
  except WebSocketDisconnect:
401
  logger.warning("[WebSocket] Client disconnected")
@@ -403,7 +388,7 @@ async def websocket_analysis(websocket: WebSocket):
403
  forward_task.cancel()
404
  except Exception as e:
405
  logger.error(f"[WebSocket] Error receiving stream URL: {str(e)}")
406
- await websocket.close(code=1011) # 1011 = Internal Error
407
  finally:
408
  clients.discard(websocket)
409
  logger.info(f"[WebSocket] Active clients: {len(clients)}")
 
9
  import cloudinary
10
  import cloudinary.uploader
11
  from cloudinary.utils import cloudinary_url
12
+ from SkinBurns.SkinBurns_Classification import FullFeautures
13
+ from SkinBurns.SkinBurns_Segmentation import segment_burn
14
  import requests
15
  import joblib
16
  import numpy as np
 
63
  cloudinary.config(
64
  cloud_name = "darumyfpl",
65
  api_key = "493972437417214",
66
+ api_secret = "jjOScVGochJYA7IxDam7L4HU2Ig",
67
  secure=True
68
  )
69
 
 
148
 
149
  @app.post("/classify-ecg")
150
  async def classify_ecg_endpoint(file: UploadFile = File(...)):
151
+ model = joblib.load('voting_classifier_arrhythmia.pkl')
152
 
153
  try:
154
  temp_file_path = f"temp_{file.filename}"
155
  with open(temp_file_path, "wb") as temp_file:
156
  temp_file.write(await file.read())
157
 
158
+ result = classify_ecg(temp_file_path, model, is_pdf=True)
159
 
160
  os.remove(temp_file_path)
161
 
 
167
  @app.post("/diagnose-ecg")
168
  async def diagnose_ecg(file: UploadFile = File(...)):
169
  try:
 
170
  temp_file_path = f"temp_{file.filename}"
171
  with open(temp_file_path, "wb") as temp_file:
172
  temp_file.write(await file.read())
173
 
174
+ model_path = 'Arrhythmia_Model_with_SMOTE.h5'
 
 
175
 
 
176
  result = analyze_ecg_pdf(
177
  temp_file_path,
178
  model_path,
179
+ cleanup=False
 
 
 
180
  )
 
181
 
 
182
  os.remove(temp_file_path)
183
 
184
+ if result and result["arrhythmia_class"]:
185
+ return {"result": result["arrhythmia_class"]}
186
  else:
187
  return {"result": "No diagnosis"}
188
 
 
207
  print("File content type:", file.content_type)
208
  print("File filename:", file.filename)
209
 
 
210
  os.makedirs(UPLOAD_DIR, exist_ok=True)
211
  os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
212
  os.makedirs(OUTPUT_DIR, exist_ok=True)
 
249
  overwrite=True
250
  )
251
 
 
252
  warnings.append({
253
  "image_url": upload_result['secure_url'],
254
  "description": description
 
268
  else:
269
  wholevideoURL = None
270
 
 
271
  graphURL = None
272
  if os.path.isfile(plot_output_path):
273
  upload_graph_result = cloudinary.uploader.upload(
 
298
  analyzer_thread = None
299
  analysis_started = False
300
  analyzer_lock = threading.Lock()
301
+ socket_server: AnalysisSocketServer = None
302
 
303
 
304
  async def forward_results_from_queue(websocket: WebSocket, warning_queue):
 
355
  logger.info("[WebSocket] Flutter connected")
356
 
357
  try:
 
358
  source = await websocket.receive_text()
359
  logger.info(f"[WebSocket] Received stream URL: {source}")
360
 
 
361
  with analyzer_lock:
362
  if not analysis_started:
363
  requested_fps = 30
 
372
  analysis_started = True
373
  logger.info("[WebSocket] Analysis thread started")
374
 
 
375
  while socket_server is None or socket_server.warning_queue is None:
376
  await asyncio.sleep(0.1)
377
 
 
380
  )
381
 
382
  while True:
383
+ await asyncio.sleep(1)
384
 
385
  except WebSocketDisconnect:
386
  logger.warning("[WebSocket] Client disconnected")
 
388
  forward_task.cancel()
389
  except Exception as e:
390
  logger.error(f"[WebSocket] Error receiving stream URL: {str(e)}")
391
+ await websocket.close(code=1011)
392
  finally:
393
  clients.discard(websocket)
394
  logger.info(f"[WebSocket] Active clients: {len(clients)}")
voting_classifier.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:52e5d9789c5a5f6b42f595fceb67948bc15e9c9035de9c02f72cf29ff42c9d93
3
- size 4084247
 
 
 
 
deep-multiclass.pkl β†’ voting_classifier_arrhythmia.pkl RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a6bd3507a56fc3e77f54e6fe0772888f61b22a2273021a04c288170237e1bc4b
3
- size 346
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f1a2aff5dffb25a19be3bcaa4db79373dcad23355ba9b166ca1d2a8978e3600
3
+ size 50223409