Hussein El-Hadidy committed on
Commit
af48e90
·
1 Parent(s): e036440

Added new endpoints for ECG

Browse files
Dockerfile CHANGED
@@ -1,13 +1,14 @@
1
  # Use the correct Python version (3.10)
2
  FROM python:3.10
3
 
4
- # Install system dependencies
5
  RUN apt-get update && \
6
  apt-get install -y \
7
  build-essential \
8
  libssl-dev \
9
  ca-certificates \
10
  libgl1 \
 
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
  # Create a user for non-root operation
 
1
  # Use the correct Python version (3.10)
2
  FROM python:3.10
3
 
4
+ # Install system dependencies including poppler-utils for pdf2image
5
  RUN apt-get update && \
6
  apt-get install -y \
7
  build-essential \
8
  libssl-dev \
9
  ca-certificates \
10
  libgl1 \
11
+ poppler-utils \
12
  && rm -rf /var/lib/apt/lists/*
13
 
14
  # Create a user for non-root operation
ECG.py DELETED
@@ -1,73 +0,0 @@
1
- import wfdb # To read the ECG files
2
- from wfdb import processing # For QRS detection
3
- import numpy as np # Numerical operations
4
- import joblib # To load the saved model
5
- import pywt # For wavelet feature extraction
6
-
7
- def extract_features_from_signal(signal):
8
- features = []
9
- features.append(np.mean(signal))
10
- features.append(np.std(signal))
11
- features.append(np.median(signal))
12
- features.append(np.min(signal))
13
- features.append(np.max(signal))
14
- features.append(np.percentile(signal, 25))
15
- features.append(np.percentile(signal, 75))
16
- features.append(np.mean(np.diff(signal)))
17
-
18
- coeffs = pywt.wavedec(signal, 'db4', level=5)
19
- for coeff in coeffs:
20
- features.append(np.mean(coeff))
21
- features.append(np.std(coeff))
22
- features.append(np.min(coeff))
23
- features.append(np.max(coeff))
24
-
25
- return features
26
-
27
- def classify_new_ecg(file_path, model):
28
- try:
29
- record = wfdb.rdrecord(file_path)
30
-
31
- available_leads = record.sig_name
32
- lead_index = next((available_leads.index(lead) for lead in ["II", "MLII", "I"] if lead in available_leads), None)
33
- if lead_index is None:
34
- return "Unsupported lead"
35
-
36
- signal = record.p_signal[:, lead_index]
37
- signal = (signal - np.mean(signal)) / np.std(signal)
38
-
39
- try:
40
- xqrs = processing.XQRS(sig=signal, fs=record.fs)
41
- xqrs.detect()
42
- r_peaks = xqrs.qrs_inds
43
- except:
44
- r_peaks = processing.gqrs_detect(sig=signal, fs=record.fs)
45
-
46
- if len(r_peaks) < 5:
47
- return "Insufficient beats"
48
-
49
- rr_intervals = np.diff(r_peaks) / record.fs
50
- qrs_durations = np.array([r_peaks[i] - r_peaks[i - 1] for i in range(1, len(r_peaks))])
51
-
52
- features = extract_features_from_signal(signal)
53
- features.extend([
54
- len(r_peaks),
55
- np.mean(rr_intervals) if len(rr_intervals) > 0 else 0,
56
- np.std(rr_intervals) if len(rr_intervals) > 0 else 0,
57
- np.median(rr_intervals) if len(rr_intervals) > 0 else 0,
58
- np.mean(qrs_durations) if len(qrs_durations) > 0 else 0,
59
- np.std(qrs_durations) if len(qrs_durations) > 0 else 0
60
- ])
61
-
62
- prediction = model.predict([features])[0]
63
- return "Abnormal" if prediction == 1 else "Normal"
64
-
65
- except Exception as e:
66
- return f"Error: {str(e)}"
67
-
68
- # Load the saved model
69
- #voting_loaded = joblib.load('voting_classifier.pkl')
70
-
71
- #file_path = "00001_hr"
72
- #result = classify_new_ecg(file_path, voting_loaded)
73
- #print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ECG/00001_hr.hea DELETED
@@ -1,13 +0,0 @@
1
- 00001_hr 12 500 5000
2
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -115 13047 0 I
3
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -50 11561 0 II
4
- 00001_hr.dat 16 1000.0(0)/mV 16 0 65 64050 0 III
5
- 00001_hr.dat 16 1000.0(0)/mV 16 0 82 53190 0 AVR
6
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -90 7539 0 AVL
7
- 00001_hr.dat 16 1000.0(0)/mV 16 0 7 5145 0 AVF
8
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -65 59817 0 V1
9
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -40 44027 0 V2
10
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -5 64232 0 V3
11
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -35 50309 0 V4
12
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -35 4821 0 V5
13
- 00001_hr.dat 16 1000.0(0)/mV 16 0 -75 12159 0 V6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ECG/00008_hr.hea DELETED
@@ -1,13 +0,0 @@
1
- 00008_hr 12 500 5000
2
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -40 12319 0 I
3
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -75 22545 0 II
4
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -35 10283 0 III
5
- 00008_hr.dat 16 1000.0(0)/mV 16 0 58 47892 0 AVR
6
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -2 891 0 AVL
7
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -55 16258 0 AVF
8
- 00008_hr.dat 16 1000.0(0)/mV 16 0 45 511 0 V1
9
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -5 64894 0 V2
10
- 00008_hr.dat 16 1000.0(0)/mV 16 0 0 57055 0 V3
11
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -55 33262 0 V4
12
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -70 18240 0 V5
13
- 00008_hr.dat 16 1000.0(0)/mV 16 0 -40 6332 0 V6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ECG/ECG_Classify.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wfdb # To read the ECG files
2
+ from wfdb import processing # For QRS detection
3
+ import numpy as np # Numerical operations
4
+ import joblib # To load the saved model
5
+ import pywt # For wavelet feature extraction
6
+ import os # For file operations
7
+ import cv2 # For image processing
8
+ from pdf2image import convert_from_path # For PDF to image conversion
9
+ import warnings
10
+ import pickle
11
+ import sklearn
12
+
13
+ # Let's modify the digitize_ecg_from_pdf function to return segment information
14
+ def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=False, save_segments=True):
15
+ """
16
+ Process an ECG PDF file and convert it to a .dat signal file.
17
+
18
+ Args:
19
+ pdf_path (str): Path to the ECG PDF file
20
+ output_file (str): Path to save the output .dat file (default: 'calibrated_ecg.dat')
21
+ debug (bool): Whether to print debug information
22
+ save_segments (bool): Whether to save individual segments
23
+
24
+ Returns:
25
+ tuple: (path to the created .dat file, list of paths to segment files)
26
+ """
27
+ if debug:
28
+ print(f"Starting ECG digitization from PDF: {pdf_path}")
29
+
30
+ # Convert PDF to image
31
+ images = convert_from_path(pdf_path)
32
+ temp_image_path = 'temp_ecg_image.jpg'
33
+ images[0].save(temp_image_path, 'JPEG')
34
+
35
+ if debug:
36
+ print(f"Converted PDF to image: {temp_image_path}")
37
+
38
+ # Load the image
39
+ img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
40
+ height, width = img.shape
41
+
42
+ if debug:
43
+ print(f"Image dimensions: {width}x{height}")
44
+
45
+ # Fixed calibration parameters
46
+ calibration = {
47
+ 'seconds_per_pixel': 2.0 / 197.0, # 197 pixels = 2 seconds
48
+ 'mv_per_pixel': 1.0 / 78.8, # 78.8 pixels = 1 mV
49
+ }
50
+
51
+ if debug:
52
+ print(f"Calibration parameters: {calibration}")
53
+
54
+ # Calculate layer boundaries using percentages
55
+ layer1_start = int(height * 35.35 / 100)
56
+ layer1_end = int(height * 51.76 / 100)
57
+ layer2_start = int(height * 51.82 / 100)
58
+ layer2_end = int(height * 69.41 / 100)
59
+ layer3_start = int(height * 69.47 / 100)
60
+ layer3_end = int(height * 87.06 / 100)
61
+
62
+ if debug:
63
+ print(f"Layer 1 boundaries: {layer1_start}-{layer1_end}")
64
+ print(f"Layer 2 boundaries: {layer2_start}-{layer2_end}")
65
+ print(f"Layer 3 boundaries: {layer3_start}-{layer3_end}")
66
+
67
+ # Crop each layer
68
+ layers = [
69
+ img[layer1_start:layer1_end, :], # Layer 1
70
+ img[layer2_start:layer2_end, :], # Layer 2
71
+ img[layer3_start:layer3_end, :] # Layer 3
72
+ ]
73
+
74
+ # Process each layer to extract waveform contours
75
+ signals = []
76
+ time_points = []
77
+ layer_duration = 10.0 # Each layer is 10 seconds long
78
+
79
+ for i, layer in enumerate(layers):
80
+ if debug:
81
+ print(f"Processing layer {i+1}...")
82
+
83
+ # Binary thresholding
84
+ _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
85
+
86
+ # Detect contours
87
+ contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
88
+ waveform_contour = max(contours, key=cv2.contourArea) # Largest contour is the ECG
89
+
90
+ if debug:
91
+ print(f" - Found {len(contours)} contours")
92
+ print(f" - Selected contour with {len(waveform_contour)} points")
93
+
94
+ # Sort contour points and extract coordinates
95
+ sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
96
+ x_coords = np.array([point[0][0] for point in sorted_contour])
97
+ y_coords = np.array([point[0][1] for point in sorted_contour])
98
+
99
+ # Estimate the isoelectric baseline at 60% of the layer height, measured from the top
100
+ isoelectric_line_y = layer.shape[0] * 0.6
101
+
102
+ # Convert to time using fixed layer duration
103
+ x_min, x_max = np.min(x_coords), np.max(x_coords)
104
+ time = (x_coords - x_min) / (x_max - x_min) * layer_duration
105
+
106
+ # Calculate signal in millivolts and apply baseline correction
107
+ signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
108
+ signal_mv = signal_mv - np.mean(signal_mv)
109
+
110
+ if debug:
111
+ print(f" - Layer {i+1} signal range: {np.min(signal_mv):.2f} mV to {np.max(signal_mv):.2f} mV")
112
+
113
+ # Store the time points and calibrated signal
114
+ time_points.append(time)
115
+ signals.append(signal_mv)
116
+
117
+ # Save individual segments if requested
118
+ segment_files = []
119
+ sampling_frequency = 500 # Standard ECG frequency
120
+ samples_per_segment = int(layer_duration * sampling_frequency) # 5000 samples per 10-second segment
121
+
122
+ if save_segments:
123
+ base_name = os.path.splitext(output_file)[0]
124
+
125
+ for i, signal in enumerate(signals):
126
+ # Interpolate to get evenly sampled signal
127
+ segment_time = np.linspace(0, layer_duration, samples_per_segment)
128
+ interpolated_signal = np.interp(segment_time, time_points[i], signals[i])
129
+
130
+ # Normalize and scale
131
+ interpolated_signal = interpolated_signal - np.mean(interpolated_signal)
132
+ signal_peak = np.max(np.abs(interpolated_signal))
133
+
134
+ if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
135
+ scaling_factor = 2.0 / signal_peak # Target peak amplitude of 2.0 mV
136
+ interpolated_signal = interpolated_signal * scaling_factor
137
+
138
+ # Convert to 16-bit integers
139
+ adc_gain = 1000.0
140
+ int_signal = (interpolated_signal * adc_gain).astype(np.int16)
141
+
142
+ # Save segment
143
+ segment_file = f"{base_name}_segment{i+1}.dat"
144
+ int_signal.reshape(-1, 1).tofile(segment_file)
145
+ segment_files.append(segment_file)
146
+
147
+ if debug:
148
+ print(f"Saved segment {i+1} to {segment_file}")
149
+
150
+ # Combine signals with proper time alignment for the full record
151
+ total_duration = layer_duration * len(layers)
152
+ num_samples = int(total_duration * sampling_frequency)
153
+ combined_time = np.linspace(0, total_duration, num_samples)
154
+ combined_signal = np.zeros(num_samples)
155
+
156
+ if debug:
157
+ print(f"Combining signals with {sampling_frequency} Hz sampling rate, total duration: {total_duration}s")
158
+
159
+ # Place each lead at the correct time position
160
+ for i, (time, signal) in enumerate(zip(time_points, signals)):
161
+ start_time = i * layer_duration
162
+ mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
163
+ relevant_times = combined_time[mask]
164
+ interpolated_signal = np.interp(relevant_times, start_time + time, signal)
165
+ combined_signal[mask] = interpolated_signal
166
+
167
+ if debug:
168
+ print(f" - Added layer {i+1} signal from {start_time}s to {start_time + layer_duration}s")
169
+
170
+ # Baseline correction and amplitude scaling
171
+ combined_signal = combined_signal - np.mean(combined_signal)
172
+ signal_peak = np.max(np.abs(combined_signal))
173
+ target_amplitude = 2.0 # Target peak amplitude in mV
174
+
175
+ if debug:
176
+ print(f"Signal peak before scaling: {signal_peak:.2f} mV")
177
+
178
+ if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
179
+ scaling_factor = target_amplitude / signal_peak
180
+ combined_signal = combined_signal * scaling_factor
181
+ if debug:
182
+ print(f"Applied scaling factor: {scaling_factor:.2f}")
183
+ print(f"Signal peak after scaling: {np.max(np.abs(combined_signal)):.2f} mV")
184
+
185
+ # Convert to 16-bit integers and save as .dat file
186
+ adc_gain = 1000.0 # Standard gain: 1000 units per mV
187
+ int_signal = (combined_signal * adc_gain).astype(np.int16)
188
+ int_signal.tofile(output_file)
189
+
190
+ if debug:
191
+ print(f"Saved signal to {output_file} with {len(int_signal)} samples")
192
+ print(f"Integer signal range: {np.min(int_signal)} to {np.max(int_signal)}")
193
+
194
+ # Clean up temporary files
195
+ if os.path.exists(temp_image_path):
196
+ os.remove(temp_image_path)
197
+ if debug:
198
+ print(f"Removed temporary image: {temp_image_path}")
199
+
200
+ return output_file, segment_files
201
+
202
+ # Add a function to split a DAT file into segments
203
+ def split_dat_into_segments(file_path, segment_duration=10.0, debug=False):
204
+ """
205
+ Split a DAT file into equal segments.
206
+
207
+ Args:
208
+ file_path (str): Path to the DAT file (without extension)
209
+ segment_duration (float): Duration of each segment in seconds
210
+ debug (bool): Whether to print debug information
211
+
212
+ Returns:
213
+ list: Paths to the segment files
214
+ """
215
+ try:
216
+ # Load the signal
217
+ signal_all_leads, fs = load_dat_signal(file_path, debug=debug)
218
+
219
+ if debug:
220
+ print(f"Loaded signal with shape {signal_all_leads.shape}")
221
+
222
+ # Choose a lead
223
+ if signal_all_leads.shape[1] == 1:
224
+ lead_index = 0
225
+ else:
226
+ lead_priority = [1, 0] # Try Lead II (index 1), then I (index 0)
227
+ lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
228
+
229
+ signal = signal_all_leads[:, lead_index]
230
+
231
+ # Calculate samples per segment
232
+ samples_per_segment = int(segment_duration * fs)
233
+ total_samples = len(signal)
234
+ num_segments = total_samples // samples_per_segment
235
+
236
+ if debug:
237
+ print(f"Splitting signal into {num_segments} segments of {segment_duration} seconds each")
238
+
239
+ segment_files = []
240
+
241
+ # Split and save each segment
242
+ base_name = os.path.splitext(file_path)[0]
243
+
244
+ for i in range(num_segments):
245
+ start_idx = i * samples_per_segment
246
+ end_idx = (i + 1) * samples_per_segment
247
+ segment = signal[start_idx:end_idx]
248
+
249
+ # Save segment
250
+ segment_file = f"{base_name}_segment{i+1}.dat"
251
+ segment.reshape(-1, 1).tofile(segment_file)
252
+ segment_files.append(segment_file)
253
+
254
+ if debug:
255
+ print(f"Saved segment {i+1} to {segment_file}")
256
+
257
+ return segment_files
258
+
259
+ except Exception as e:
260
+ if debug:
261
+ print(f"Error splitting DAT file: {str(e)}")
262
+ return []
263
+
264
+ # Add function to load DAT signals
265
+ def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16, debug=False):
266
+ """
267
+ Load a DAT file containing ECG signal data.
268
+
269
+ Args:
270
+ file_path (str): Path to the DAT file (without extension)
271
+ n_leads (int): Number of leads in the signal
272
+ n_samples (int): Number of samples per lead
273
+ dtype: Data type of the signal
274
+ debug (bool): Whether to print debug information
275
+
276
+ Returns:
277
+ tuple: (numpy array of signal data, sampling frequency)
278
+ """
279
+ try:
280
+ # Handle both cases: with and without .dat extension
281
+ if file_path.endswith('.dat'):
282
+ dat_path = file_path
283
+ else:
284
+ dat_path = file_path + '.dat'
285
+
286
+ if debug:
287
+ print(f"Loading signal from: {dat_path}")
288
+
289
+ raw = np.fromfile(dat_path, dtype=dtype)
290
+
291
+ if debug:
292
+ print(f"Raw data size: {raw.size}")
293
+
294
+ # Try to infer number of leads if read size doesn't match expected
295
+ if raw.size != n_leads * n_samples:
296
+ if debug:
297
+ print(f"Unexpected size: {raw.size}, expected {n_leads * n_samples}")
298
+ print("Attempting to infer number of leads...")
299
+
300
+ # Check if single lead
301
+ if raw.size == n_samples:
302
+ if debug:
303
+ print("Detected single lead signal")
304
+ signal = raw.reshape(n_samples, 1)
305
+ return signal, 500
306
+
307
+ # Try common lead counts
308
+ possible_leads = [1, 2, 3, 6, 12]
309
+ for possible_lead_count in possible_leads:
310
+ if raw.size % possible_lead_count == 0:
311
+ actual_samples = raw.size // possible_lead_count
312
+ if debug:
313
+ print(f"Inferred {possible_lead_count} leads with {actual_samples} samples each")
314
+ signal = raw.reshape(actual_samples, possible_lead_count)
315
+ return signal, 500
316
+
317
+ # If we can't determine it reliably, reshape as single lead
318
+ if debug:
319
+ print("Could not infer lead count, reshaping as single lead")
320
+ signal = raw.reshape(-1, 1)
321
+ return signal, 500
322
+
323
+ # Normal case when size matches expectation
324
+ signal = raw.reshape(n_samples, n_leads)
325
+ return signal, 500 # Signal + sampling frequency
326
+ except Exception as e:
327
+ if debug:
328
+ print(f"Error loading DAT file: {str(e)}")
329
+ # Return empty signal with single channel
330
+ return np.zeros((n_samples, 1)), 500
331
+
332
+ # Add the feature extraction function
333
+ def extract_features_from_signal(signal, debug=False):
334
+ """
335
+ Extract features from an ECG signal.
336
+
337
+ Args:
338
+ signal (numpy.ndarray): ECG signal
339
+ debug (bool): Whether to print debug information
340
+
341
+ Returns:
342
+ list: Features extracted from the signal
343
+ """
344
+ if debug:
345
+ print("Extracting features from signal...")
346
+
347
+ features = []
348
+ features.append(np.mean(signal))
349
+ features.append(np.std(signal))
350
+ features.append(np.median(signal))
351
+ features.append(np.min(signal))
352
+ features.append(np.max(signal))
353
+ features.append(np.percentile(signal, 25))
354
+ features.append(np.percentile(signal, 75))
355
+ features.append(np.mean(np.diff(signal)))
356
+
357
+ if debug:
358
+ print("Computing wavelet decomposition...")
359
+
360
+ coeffs = pywt.wavedec(signal, 'db4', level=5)
361
+ for i, coeff in enumerate(coeffs):
362
+ features.append(np.mean(coeff))
363
+ features.append(np.std(coeff))
364
+ features.append(np.min(coeff))
365
+ features.append(np.max(coeff))
366
+
367
+ if debug and i == 0:
368
+ print(f"Wavelet features for level {i}: mean={np.mean(coeff):.4f}, std={np.std(coeff):.4f}")
369
+
370
+ if debug:
371
+ print(f"Extracted {len(features)} features")
372
+
373
+ return features
374
+
375
+ # Add the classify_new_ecg function
376
+ def classify_new_ecg(file_path, model, debug=False):
377
+ """
378
+ Classify a new ECG file.
379
+
380
+ Args:
381
+ file_path (str): Path to the ECG file (without extension)
382
+ model: The trained model for classification
383
+ debug (bool): Whether to print debug information
384
+
385
+ Returns:
386
+ str: Classification result ("Normal", "Abnormal", or error message)
387
+ """
388
+ try:
389
+ if debug:
390
+ print(f"Classifying ECG from: {file_path}")
391
+
392
+ signal_all_leads, fs = load_dat_signal(file_path, debug=debug)
393
+
394
+ if debug:
395
+ print(f"Loaded signal with shape {signal_all_leads.shape}, sampling rate {fs} Hz")
396
+
397
+ # Choose lead for analysis - priority order
398
+ if signal_all_leads.shape[1] == 1:
399
+ lead_index = 0
400
+ if debug:
401
+ print("Using single lead")
402
+ else:
403
+ lead_priority = [1, 0] # Try Lead II (index 1), then I (index 0)
404
+ lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
405
+ if debug:
406
+ print(f"Using lead index {lead_index}")
407
+
408
+ # Extract the signal
409
+ signal = signal_all_leads[:, lead_index]
410
+
411
+ # Normalize signal
412
+ signal = (signal - np.mean(signal)) / np.std(signal)
413
+
414
+ if debug:
415
+ print("Signal normalized")
416
+ print(f"Detecting QRS complexes...")
417
+
418
+ # Detect QRS complexes
419
+ try:
420
+ xqrs = processing.XQRS(sig=signal, fs=fs)
421
+ xqrs.detect()
422
+ r_peaks = xqrs.qrs_inds
423
+ if debug:
424
+ print(f"Detected {len(r_peaks)} QRS complexes with XQRS method")
425
+ except Exception as e:
426
+ if debug:
427
+ print(f"XQRS detection failed: {str(e)}")
428
+ print("Falling back to GQRS detector")
429
+ r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
430
+ if debug:
431
+ print(f"Detected {len(r_peaks)} QRS complexes with GQRS method")
432
+
433
+ # Check if we found enough QRS complexes
434
+ if len(r_peaks) < 5:
435
+ if debug:
436
+ print(f"Insufficient beats detected: {len(r_peaks)}")
437
+ return "Insufficient beats"
438
+
439
+ # Calculate RR intervals and QRS durations
440
+ rr_intervals = np.diff(r_peaks) / fs
441
+ qrs_durations = np.array([r_peaks[i] - r_peaks[i - 1] for i in range(1, len(r_peaks))])
442
+
443
+ if debug:
444
+ print(f"Mean RR interval: {np.mean(rr_intervals):.4f} s")
445
+ print(f"Mean QRS duration: {np.mean(qrs_durations) / fs:.4f} s")
446
+
447
+ # Extract features
448
+ features = extract_features_from_signal(signal, debug=debug)
449
+
450
+ # Add rhythm features
451
+ features.extend([
452
+ len(r_peaks),
453
+ np.mean(rr_intervals) if len(rr_intervals) > 0 else 0,
454
+ np.std(rr_intervals) if len(rr_intervals) > 0 else 0,
455
+ np.median(rr_intervals) if len(rr_intervals) > 0 else 0,
456
+ np.mean(qrs_durations) if len(qrs_durations) > 0 else 0,
457
+ np.std(qrs_durations) if len(qrs_durations) > 0 else 0
458
+ ])
459
+
460
+ if debug:
461
+ print(f"Final feature vector length: {len(features)}")
462
+
463
+ # Make prediction
464
+ prediction = model.predict([features])[0]
465
+ result = "Abnormal" if prediction == 1 else "Normal"
466
+
467
+ if debug:
468
+ print(f"Classification result: {result} (prediction value: {prediction})")
469
+
470
+ return result
471
+
472
+ except Exception as e:
473
+ error_msg = f"Error: {str(e)}"
474
+ if debug:
475
+ print(error_msg)
476
+ return error_msg
477
+
478
+ # Modify the classify_ecg wrapper function to use the voting approach
479
+ def classify_ecg(file_path, model, is_pdf=False, debug=False):
480
+ """
481
+ Wrapper function that handles both PDF and DAT ECG files with segment voting.
482
+
483
+ Args:
484
+ file_path (str): Path to the ECG file (.pdf or without extension for .dat)
485
+ model: The trained model for classification
486
+ is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)
487
+ debug (bool): Enable debug output
488
+
489
+ Returns:
490
+ str: Classification result ("Normal", "Abnormal", or error message)
491
+ """
492
+ try:
493
+ # Check if model is valid
494
+ if model is None:
495
+ return "Error: Model not loaded. Please check model compatibility."
496
+
497
+ if is_pdf:
498
+ if debug:
499
+ print(f"Processing PDF file: {file_path}")
500
+
501
+ # Extract file name without extension for output
502
+ base_name = os.path.splitext(os.path.basename(file_path))[0]
503
+ output_dat = f"{base_name}_digitized.dat"
504
+
505
+ # Digitize the PDF to a DAT file and get segment files
506
+ dat_path, segment_files = digitize_ecg_from_pdf(
507
+ pdf_path=file_path,
508
+ output_file=output_dat,
509
+ debug=debug
510
+ )
511
+
512
+ if debug:
513
+ print(f"Digitized ECG saved to: {dat_path}")
514
+ print(f"Created {len(segment_files)} segment files")
515
+ else:
516
+ if debug:
517
+ print(f"Processing DAT file: {file_path}")
518
+
519
+ # For DAT files, we need to split into segments
520
+ segment_files = split_dat_into_segments(file_path, debug=debug)
521
+
522
+ if not segment_files:
523
+ # If splitting failed, try classifying the whole file
524
+ return classify_new_ecg(file_path, model, debug=debug)
525
+
526
+ # Process each segment and collect votes
527
+ segment_results = []
528
+
529
+ for i, segment_file in enumerate(segment_files):
530
+ if debug:
531
+ print(f"\n--- Processing Segment {i+1} ---")
532
+
533
+ # Get file path without extension
534
+ segment_path = os.path.splitext(segment_file)[0]
535
+
536
+ # Classify this segment
537
+ result = classify_new_ecg(segment_path, model, debug=debug)
538
+
539
+ if debug:
540
+ print(f"Segment {i+1} classification: {result}")
541
+
542
+ segment_results.append(result)
543
+
544
+ # Remove temporary segment files
545
+ try:
546
+ os.remove(segment_file)
547
+ if debug:
548
+ print(f"Removed temporary segment file: {segment_file}")
549
+ except:
550
+ pass
551
+
552
+ # Count results and use majority voting
553
+ if segment_results:
554
+ normal_count = segment_results.count("Normal")
555
+ abnormal_count = segment_results.count("Abnormal")
556
+ error_count = len(segment_results) - normal_count - abnormal_count
557
+
558
+ if debug:
559
+ print(f"\n--- Voting Results ---")
560
+ print(f"Normal votes: {normal_count}")
561
+ print(f"Abnormal votes: {abnormal_count}")
562
+ print(f"Errors/Inconclusive: {error_count}")
563
+
564
+ # Decision rules:
565
+ # 1. If any segment is abnormal, classify as abnormal
566
+ # 2. Only classify as normal if majority of segments are normal
567
+ if abnormal_count > normal_count:
568
+ final_result = "Abnormal"
569
+ elif normal_count > abnormal_count:
570
+ final_result = "Normal"
571
+ else:
572
+ final_result = "Inconclusive"
573
+
574
+ if debug:
575
+ print(f"Final decision: {final_result}")
576
+
577
+ return final_result
578
+ else:
579
+ return "Error: No valid segments to classify"
580
+
581
+ except Exception as e:
582
+ error_msg = f"Classification error: {str(e)}"
583
+ if debug:
584
+ print(error_msg)
585
+ return error_msg
586
+ # Load the saved model
587
+ try:
588
+ model_path = 'voting_classifier.pkl'
589
+ if os.path.exists(model_path):
590
+ voting_loaded = joblib.load(model_path)
591
+ else:
592
+ # Try to find the model in the current or parent directories
593
+ for root, dirs, files in os.walk('.'):
594
+ for file in files:
595
+ if file.endswith('.pkl') and 'voting' in file.lower():
596
+ model_path = os.path.join(root, file)
597
+ voting_loaded = joblib.load(model_path)
598
+ break
599
+ if 'voting_loaded' in locals():
600
+ break
601
+
602
+ if 'voting_loaded' not in locals():
603
+ voting_loaded = None
604
+ except Exception as e:
605
+ voting_loaded = None
606
+
607
+ # Simple test for the classify_ecg function
608
+ test_pdf_path = "sample.pdf"
609
+ if os.path.exists(test_pdf_path) and voting_loaded is not None:
610
+ result_pdf = classify_ecg(test_pdf_path, voting_loaded, is_pdf=True)
611
+ print(f"Classification result: {result_pdf}")
ECG/ECG_MultiClass.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ECG Analysis Pipeline: From PDF to Diagnosis
3
+ -------------------------------------------
4
+ This module provides functions to:
5
+ 1. Digitize ECG from PDF files
6
+ 2. Process the digitized ECG signal
7
+ 3. Make diagnoses using a trained model
8
+ """
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import os
13
+ import tensorflow as tf
14
+ import pickle
15
+ from scipy.interpolate import interp1d
16
+ from collections import Counter
17
+ from pdf2image import convert_from_path
18
+ import matplotlib.pyplot as plt # Added for visualization
19
+
20
+ def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=False):
21
+ """
22
+ Process an ECG PDF file and convert it to a .dat signal file.
23
+
24
+ Args:
25
+ pdf_path (str): Path to the ECG PDF file
26
+ output_file (str): Path to save the output .dat file (default: 'calibrated_ecg.dat')
27
+ debug (bool): Whether to print debug information
28
+
29
+ Returns:
30
+ str: Path to the created .dat file
31
+ """
32
+ if debug:
33
+ print(f"Starting ECG digitization from PDF: {pdf_path}")
34
+
35
+ # Convert PDF to image
36
+ images = convert_from_path(pdf_path)
37
+ temp_image_path = 'temp_ecg_image.jpg'
38
+ images[0].save(temp_image_path, 'JPEG')
39
+
40
+ if debug:
41
+ print(f"Converted PDF to image: {temp_image_path}")
42
+
43
+ # Load the image
44
+ img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
45
+ height, width = img.shape
46
+
47
+ if debug:
48
+ print(f"Image dimensions: {width}x{height}")
49
+
50
+ # Fixed calibration parameters
51
+ calibration = {
52
+ 'seconds_per_pixel': 2.0 / 197.0, # 197 pixels = 2 seconds
53
+ 'mv_per_pixel': 1.0 / 78.8, # 78.8 pixels = 1 mV
54
+ }
55
+
56
+ if debug:
57
+ print(f"Calibration parameters: {calibration}")
58
+
59
+ # Calculate layer boundaries using percentages
60
+ layer1_start = int(height * 35.35 / 100)
61
+ layer1_end = int(height * 51.76 / 100)
62
+ layer2_start = int(height * 51.82 / 100)
63
+ layer2_end = int(height * 69.41 / 100)
64
+ layer3_start = int(height * 69.47 / 100)
65
+ layer3_end = int(height * 87.06 / 100)
66
+
67
+ if debug:
68
+ print(f"Layer 1 boundaries: {layer1_start}-{layer1_end}")
69
+ print(f"Layer 2 boundaries: {layer2_start}-{layer2_end}")
70
+ print(f"Layer 3 boundaries: {layer3_start}-{layer3_end}")
71
+
72
+ # Crop each layer
73
+ layers = [
74
+ img[layer1_start:layer1_end, :], # Layer 1
75
+ img[layer2_start:layer2_end, :], # Layer 2
76
+ img[layer3_start:layer3_end, :] # Layer 3
77
+ ]
78
+
79
+ # Process each layer to extract waveform contours
80
+ signals = []
81
+ time_points = []
82
+ layer_duration = 10.0 # Each layer is 10 seconds long
83
+
84
+ for i, layer in enumerate(layers):
85
+ if debug:
86
+ print(f"Processing layer {i+1}...")
87
+
88
+ # Binary thresholding
89
+ _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
90
+
91
+ # Detect contours
92
+ contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
93
+ waveform_contour = max(contours, key=cv2.contourArea) # Largest contour is the ECG
94
+
95
+ if debug:
96
+ print(f" - Found {len(contours)} contours")
97
+ print(f" - Selected contour with {len(waveform_contour)} points")
98
+
99
+ # Sort contour points and extract coordinates
100
+ sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
101
+ x_coords = np.array([point[0][0] for point in sorted_contour])
102
+ y_coords = np.array([point[0][1] for point in sorted_contour])
103
+
104
+ # Estimate isoelectric line at 60% of the layer height from the top (i.e. 40% from the bottom)
105
+ isoelectric_line_y = layer.shape[0] * 0.6
106
+
107
+ # Convert to time using fixed layer duration
108
+ x_min, x_max = np.min(x_coords), np.max(x_coords)
109
+ time = (x_coords - x_min) / (x_max - x_min) * layer_duration
110
+
111
+ # Calculate signal in millivolts and apply baseline correction
112
+ signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
113
+ signal_mv = signal_mv - np.mean(signal_mv)
114
+
115
+ if debug:
116
+ print(f" - Layer {i+1} signal range: {np.min(signal_mv):.2f} mV to {np.max(signal_mv):.2f} mV")
117
+
118
+ # Store the time points and calibrated signal
119
+ time_points.append(time)
120
+ signals.append(signal_mv)
121
+
122
+ # Combine signals with proper time alignment
123
+ total_duration = layer_duration * len(layers)
124
+ sampling_frequency = 500 # Standard ECG frequency
125
+ num_samples = int(total_duration * sampling_frequency)
126
+ combined_time = np.linspace(0, total_duration, num_samples)
127
+ combined_signal = np.zeros(num_samples)
128
+
129
+ if debug:
130
+ print(f"Combining signals with {sampling_frequency} Hz sampling rate, total duration: {total_duration}s")
131
+
132
+ # Place each lead at the correct time position
133
+ for i, (time, signal) in enumerate(zip(time_points, signals)):
134
+ start_time = i * layer_duration
135
+ mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
136
+ relevant_times = combined_time[mask]
137
+ interpolated_signal = np.interp(relevant_times, start_time + time, signal)
138
+ combined_signal[mask] = interpolated_signal
139
+
140
+ if debug:
141
+ print(f" - Added layer {i+1} signal from {start_time}s to {start_time + layer_duration}s")
142
+
143
+ # Baseline correction and amplitude scaling
144
+ combined_signal = combined_signal - np.mean(combined_signal)
145
+ signal_peak = np.max(np.abs(combined_signal))
146
+ target_amplitude = 2.0 # Target peak amplitude in mV
147
+
148
+ if debug:
149
+ print(f"Signal peak before scaling: {signal_peak:.2f} mV")
150
+
151
+ if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
152
+ scaling_factor = target_amplitude / signal_peak
153
+ combined_signal = combined_signal * scaling_factor
154
+ if debug:
155
+ print(f"Applied scaling factor: {scaling_factor:.2f}")
156
+ print(f"Signal peak after scaling: {np.max(np.abs(combined_signal)):.2f} mV")
157
+
158
+ # Convert to 16-bit integers and save as .dat file
159
+ adc_gain = 1000.0 # Standard gain: 1000 units per mV
160
+ int_signal = (combined_signal * adc_gain).astype(np.int16)
161
+ int_signal.tofile(output_file)
162
+
163
+ if debug:
164
+ print(f"Saved signal to {output_file} with {len(int_signal)} samples")
165
+ print(f"Integer signal range: {np.min(int_signal)} to {np.max(int_signal)}")
166
+
167
+ # Clean up temporary files
168
+ if os.path.exists(temp_image_path):
169
+ os.remove(temp_image_path)
170
+ if debug:
171
+ print(f"Removed temporary image: {temp_image_path}")
172
+
173
+ return output_file
174
+
175
def visualize_ecg_signal(signal, sampling_rate=500, title="Digitized ECG Signal"):
    """
    Plot an ECG trace against a time axis derived from the sampling rate.

    Parameters
    ----------
    signal : numpy.ndarray
        ECG signal data (amplitude assumed to be in mV).
    sampling_rate : int
        Sampling rate in Hz.
    title : str
        Plot title.
    """
    # Time axis in seconds: sample index divided by the sampling rate
    t = np.arange(len(signal)) / sampling_rate

    plt.figure(figsize=(15, 5))
    plt.plot(t, signal)
    plt.title(title)
    plt.xlabel('Time (seconds)')
    plt.ylabel('Amplitude (mV)')
    plt.grid(True)

    # Vertical red bar at t = 1 s as a 1 mV amplitude reference
    plt.plot([1, 1], [-0.5, 0.5], 'r-', linewidth=2)
    plt.text(1.1, 0, '1mV', va='center')

    # Horizontal red bar along the signal minimum as a 1 s time reference
    baseline = np.min(signal)
    plt.plot([1, 2], [baseline, baseline], 'r-', linewidth=2)
    plt.text(1.5, baseline - 0.1, '1s', ha='center')

    plt.tight_layout()
    plt.show()
210
+
211
def read_lead_i_long_dat_file(dat_file_path, sampling_rate=500, data_format='16', scale_factor=0.001):
    """
    Read a 30-second pure Lead I .dat file directly and properly scale it.

    Parameters
    ----------
    dat_file_path : str
        Path to the .dat file (with or without .dat extension).
    sampling_rate : int
        Sampling rate in Hz (default 500Hz).
    data_format : str
        Data format of the binary file: '16' for 16-bit integers,
        '32' for 32-bit floats.
    scale_factor : float
        Scale factor to convert units (0.001 for converting µV to mV).

    Returns
    -------
    numpy.ndarray
        ECG signal data for Lead I with shape (sampling_rate * 30,).
        Signals shorter than 30 s are zero-padded at the end; longer
        ones are truncated.

    Raises
    ------
    ValueError
        If `data_format` is neither '16' nor '32'.
    OSError
        If the file cannot be read (propagated from numpy.fromfile).
    """
    # Ensure the path ends with .dat
    if not dat_file_path.endswith('.dat'):
        dat_file_path += '.dat'

    # Expected samples for a full 30-second recording
    expected_samples = sampling_rate * 30

    # Read the raw binary samples. Note: the original code wrapped this in a
    # `try: ... except Exception: raise` block, which is a no-op re-raise;
    # exceptions simply propagate to the caller.
    if data_format == '16':
        # 16-bit signed integers (common format for ECG)
        data = np.fromfile(dat_file_path, dtype=np.int16)
    elif data_format == '32':
        # 32-bit floating point (less common)
        data = np.fromfile(dat_file_path, dtype=np.float32)
    else:
        raise ValueError(f"Unsupported data format: {data_format}")

    # Apply scaling to convert units (e.g. µV -> mV)
    signal = data * scale_factor

    # Normalize the length to exactly 30 seconds
    if len(signal) < expected_samples:
        # Pad with zeros if too short
        padded_signal = np.zeros(expected_samples)
        padded_signal[:len(signal)] = signal
        signal = padded_signal
    elif len(signal) > expected_samples:
        # Truncate if too long
        signal = signal[:expected_samples]

    return signal
266
+
267
def segment_signal(signal, sampling_rate=500):
    """
    Split a 30-second signal into three consecutive 10-second segments.

    If the input is not exactly 30 * sampling_rate samples long, it is
    first resampled onto that length with linear interpolation.

    Parameters
    ----------
    signal : numpy.ndarray
        The full signal to segment.
    sampling_rate : int
        Sampling rate in Hz.

    Returns
    -------
    list
        Three numpy arrays, each 10 seconds (sampling_rate * 10 samples).
    """
    samples_per_segment = sampling_rate * 10
    total_samples = sampling_rate * 30

    # Resample to exactly 30 seconds when the length does not match
    if len(signal) != total_samples:
        old_grid = np.linspace(0, 1, len(signal))
        new_grid = np.linspace(0, 1, total_samples)
        resampler = interp1d(old_grid, signal, kind='linear',
                             bounds_error=False, fill_value="extrapolate")
        signal = resampler(new_grid)

    # Slice out the three consecutive 10-second windows
    return [
        signal[k * samples_per_segment:(k + 1) * samples_per_segment]
        for k in range(3)
    ]
306
+
307
def process_segment(segment, sampling_rate=500):
    """
    Prepare one ECG segment for model input.

    Guarantees the segment is exactly 5000 samples (10 s at 500 Hz),
    resampling with linear interpolation when necessary.

    Parameters
    ----------
    segment : numpy.ndarray
        Raw ECG segment.
    sampling_rate : int
        Sampling rate of the ECG (kept for interface compatibility; not
        used when the segment already has 5000 samples).

    Returns
    -------
    numpy.ndarray
        Segment of exactly 5000 samples, ready for model input.
    """
    target_len = 5000
    if len(segment) == target_len:
        # Already the right length — return unchanged
        return segment

    # Linearly resample onto a 5000-point grid
    src_grid = np.linspace(0, 1, len(segment))
    dst_grid = np.linspace(0, 1, target_len)
    resampler = interp1d(src_grid, segment, kind='linear',
                         bounds_error=False, fill_value="extrapolate")
    return resampler(dst_grid)
331
+
332
def predict_with_voting(dat_file_path, model_path, mlb_path=None, sampling_rate=500, scale_factor=0.001, debug=False):
    """
    Process a 30-second .dat file, properly scale it, split it into three 10-second segments,
    make predictions on each segment, and return the class with highest average probability.

    Parameters:
    -----------
    dat_file_path : str
        Path to the .dat file
    model_path : str
        Path to the saved model (.h5 file)
    mlb_path : str, optional
        Path to the saved MultiLabelBinarizer pickle file for label decoding
    sampling_rate : int
        Sampling rate in Hz (default 500Hz)
    scale_factor : float
        Scale factor to convert units (0.001 for converting µV to mV)
    debug : bool
        Whether to print debug information

    Returns:
    --------
    dict
        Dictionary containing segment predictions and final class probabilities,
        or {"error": <message>} if any step raises.
    """
    try:
        # Step 1: Read the 30-second ECG data (pure Lead I) and apply scaling
        if debug:
            print(f"Reading signal from {dat_file_path}")

        full_signal = read_lead_i_long_dat_file(
            dat_file_path,
            sampling_rate=sampling_rate,
            scale_factor=scale_factor
        )

        if debug:
            print(f"Signal loaded: {len(full_signal)} samples, range: {np.min(full_signal):.2f} to {np.max(full_signal):.2f} mV")

        # Step 2: Split into three 10-second segments
        segments = segment_signal(full_signal, sampling_rate)

        if debug:
            print(f"Split into {len(segments)} segments of {len(segments[0])} samples each")

        # Step 3: Load the model (load once to improve performance)
        # NOTE(review): "once" means once per call — the model is still
        # re-loaded on every invocation; consider caching at module level
        # if this function is called per request.
        if debug:
            print(f"Loading model from {model_path}")

        model = tf.keras.models.load_model(model_path)

        # Load MLB if provided (used to decode class indices into names)
        mlb = None
        if mlb_path and os.path.exists(mlb_path):
            if debug:
                print(f"Loading label binarizer from {mlb_path}")
            with open(mlb_path, 'rb') as f:
                # NOTE(review): pickle.load on an attacker-controlled file is
                # unsafe — confirm mlb_path is always a trusted local artifact.
                mlb = pickle.load(f)

        # Step 4: Process each segment and collect predictions
        segment_results = []
        all_predictions = []

        for i, segment in enumerate(segments):
            if debug:
                print(f"Processing segment {i+1}...")

            # Process the segment to ensure it's properly formatted
            processed_segment = process_segment(segment)

            # Reshape for model input (batch, time, channels)
            X = processed_segment.reshape(1, 5000, 1)

            # Make predictions
            predictions = model.predict(X, verbose=0)
            all_predictions.append(predictions[0])

            # Process segment results
            segment_result = {"raw_predictions": predictions[0].tolist()}

            # Decode labels if MLB is provided
            if mlb is not None:
                # Add class probabilities
                class_probs = {}
                for j, class_name in enumerate(mlb.classes_):
                    class_probs[class_name] = float(predictions[0][j])

                segment_result["class_probabilities"] = class_probs

            segment_results.append(segment_result)

        # Step 5: Calculate average probabilities across all segments
        final_result = {"segment_results": segment_results}

        # Average the raw predictions
        avg_predictions = np.mean(all_predictions, axis=0)
        final_result["averaged_raw_predictions"] = avg_predictions.tolist()

        # Calculate final class probabilities (average across segments)
        if mlb is not None:
            # Calculate average probability for each class
            final_class_probs = {}
            for cls_idx, cls_name in enumerate(mlb.classes_):
                final_class_probs[cls_name] = float(np.mean([pred[cls_idx] for pred in all_predictions]))

            # Find the class with highest average probability
            top_class = max(final_class_probs.items(), key=lambda x: x[1])
            top_class_name = top_class[0]

            final_result["final_class_probabilities"] = final_class_probs
            final_result["top_class"] = top_class_name

            if debug:
                print(f"Top class by average probability: {top_class_name} ({top_class[1]:.2f})")

        return final_result

    except Exception as e:
        # Failures are reported as a payload rather than raised, so callers
        # must check for the "error" key in the returned dict.
        if debug:
            print(f"Error in predict_with_voting: {str(e)}")
        return {"error": str(e)}
453
+
454
def analyze_ecg_pdf(pdf_path, model_path, mlb_path=None, temp_dat_file='calibrated_ecg.dat', cleanup=True, debug=False, visualize=False):
    """
    Complete ECG analysis pipeline: digitizes a PDF ECG, analyzes it with the model,
    and returns the diagnosis with highest probability.

    Args:
        pdf_path (str): Path to the ECG PDF file
        model_path (str): Path to the model (.h5) file
        mlb_path (str, optional): Path to the MultiLabelBinarizer file
        temp_dat_file (str, optional): Path to save the temporary digitized file
        cleanup (bool, optional): Whether to remove temporary files after processing
        debug (bool, optional): Whether to print debug information
        visualize (bool, optional): Whether to visualize the digitized signal

    Returns:
        dict: {
            "diagnosis": str,            # Top diagnosis (highest average probability), or None
            "probability": float,        # Probability of top diagnosis (0.0 if none)
            "all_probabilities": dict,   # All diagnoses with probabilities
            "digitized_file": str        # Path to digitized file (present only if cleanup=False)
        }
    """
    # Silence TensorFlow warnings (module-level side effect on the environment)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    if debug:
        print(f"Starting ECG analysis pipeline for {pdf_path}")

    # 1. Digitize ECG from PDF to DAT file
    dat_file_path = digitize_ecg_from_pdf(pdf_path, output_file=temp_dat_file, debug=debug)

    # Visualize the digitized signal if requested
    if visualize:
        signal = read_lead_i_long_dat_file(dat_file_path, scale_factor=0.001)
        visualize_ecg_signal(signal, title=f"Digitized ECG from {os.path.basename(pdf_path)}")

    # 2. Process DAT file with model
    if debug:
        print("Processing digitized signal with model...")

    results = predict_with_voting(
        dat_file_path,
        model_path,
        mlb_path,
        scale_factor=0.001,  # Convert microvolts to millivolts
        debug=debug
    )

    # 3. Extract top diagnosis (highest probability)
    # NOTE(review): if predict_with_voting returned {"error": ...}, none of
    # the keys below are filled in and "diagnosis" stays None.
    top_diagnosis = {
        "diagnosis": None,
        "probability": 0.0,
        "all_probabilities": {},
        "digitized_file": dat_file_path
    }

    # If we have class probabilities, find the highest one
    if "final_class_probabilities" in results:
        probs = results["final_class_probabilities"]
        top_diagnosis["all_probabilities"] = probs

        # Use the top class directly from the results
        if "top_class" in results:
            top_diagnosis["diagnosis"] = results["top_class"]
            top_diagnosis["probability"] = probs[results["top_class"]]

    # Clean up temporary files if requested
    if cleanup and os.path.exists(dat_file_path):
        if debug:
            print(f"Cleaning up temporary file: {dat_file_path}")
        os.remove(dat_file_path)
        top_diagnosis.pop("digitized_file")

    if debug:
        print(f"Analysis complete. Diagnosis: {top_diagnosis['diagnosis']} (Probability: {top_diagnosis['probability']:.2f})")

    return top_diagnosis
531
+
532
# Example usage: run the full PDF -> diagnosis pipeline on a sample file
if __name__ == "__main__":
    # Path configuration
    sample_pdf = 'samplebayez.pdf'
    model_path = 'deep-multiclass.h5'  # Update with actual path
    mlb_path = 'deep-multiclass.pkl'  # Update with actual path

    # Analyze the ECG (flags below are all off by default)
    result = analyze_ecg_pdf(
        sample_pdf,
        model_path,
        mlb_path,
        cleanup=False,  # Keep the digitized .dat file on disk
        debug=False,  # Set True to print debug information
        visualize=False  # Set True to plot the digitized signal
    )

    # Display result ("diagnosis" is None when classification failed)
    if result["diagnosis"]:
        print(f"Diagnosis: {result['diagnosis']} ")

    else:
        print("No clear diagnosis found")
app.py CHANGED
@@ -14,7 +14,8 @@ from SkinBurns_Segmentation import segment_burn
14
  import requests
15
  import joblib
16
  import numpy as np
17
- from ECG import classify_new_ecg
 
18
  from ultralytics import YOLO
19
  import tensorflow as tf
20
  from fastapi import HTTPException
@@ -208,35 +209,62 @@ def transform_image():
208
  return {"error": str(e)}
209
 
210
  @app.post("/classify-ecg")
211
- async def classify_ecg(files: list[UploadFile] = File(...)):
212
  model = joblib.load('voting_classifier.pkl')
213
-
214
- temp_dir = f"temp_ecg_{uuid.uuid4()}"
215
- os.makedirs(temp_dir, exist_ok=True)
216
 
217
  try:
218
- for file in files:
219
- file_path = os.path.join(temp_dir, file.filename)
220
- with open(file_path, "wb") as f:
221
- f.write(file.file.read()) # Replacing shutil.copyfileobj
222
 
223
- # Assume both .hea and .dat have same base name
224
- base_names = set(os.path.splitext(file.filename)[0] for file in files)
225
- if len(base_names) != 1:
226
- return JSONResponse(content={"error": "Files must have the same base name"}, status_code=400)
227
 
228
- base_name = list(base_names)[0]
229
- file_path = os.path.join(temp_dir, base_name)
230
 
231
- result = classify_new_ecg(file_path, model)
232
  return {"result": result}
233
 
234
- finally:
235
- # Replace shutil.rmtree with os removal operations
236
- for file_name in os.listdir(temp_dir):
237
- file_path = os.path.join(temp_dir, file_name)
238
- os.remove(file_path)
239
- os.rmdir(temp_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
 
242
  @app.post("/process_video")
 
14
  import requests
15
  import joblib
16
  import numpy as np
17
+ from ECG.ECG_Classify import classify_ecg
18
+ from ECG.ECG_MultiClass import analyze_ecg_pdf
19
  from ultralytics import YOLO
20
  import tensorflow as tf
21
  from fastapi import HTTPException
 
209
  return {"error": str(e)}
210
 
211
@app.post("/classify-ecg")
async def classify_ecg_endpoint(file: UploadFile = File(...)):
    """
    Classify an uploaded ECG report (PDF) with the pre-trained voting classifier.

    The upload is persisted to a temporary file so the classifier can read it
    from disk, then classified via `classify_ecg`. Returns {"result": ...} on
    success, or {"error": ...} with HTTP 500 on any failure. The temporary
    file is always removed, even when classification raises.
    """
    try:
        # NOTE(review): the model is re-loaded from disk on every request;
        # consider loading it once at startup if latency matters.
        model = joblib.load('voting_classifier.pkl')

        # Save the uploaded file temporarily
        temp_file_path = f"temp_{file.filename}"
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(await file.read())

        try:
            # Call the ECG classification function (input is a PDF)
            result = classify_ecg(temp_file_path, model, debug=True, is_pdf=True)
        finally:
            # Fix: previously the temp file leaked when classify_ecg raised;
            # always remove it regardless of outcome.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

        return {"result": result}

    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
232
+
233
+
234
@app.post("/diagnose-ecg")
async def diagnose_ecg(file: UploadFile = File(...)):
    """
    Run the full ECG PDF analysis pipeline and return the top diagnosis.

    The upload is saved to a temporary file, digitized and classified via
    `analyze_ecg_pdf`, and the endpoint returns {"result": <diagnosis>} or
    {"result": "No diagnosis"} when no diagnosis was found. Errors return
    {"error": ...} with HTTP 500. The uploaded temp file is always removed,
    even when the analysis raises.
    """
    try:
        # Save the uploaded file temporarily
        temp_file_path = f"temp_{file.filename}"
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(await file.read())

        model_path = 'deep-multiclass.h5'  # Update with actual path
        mlb_path = 'deep-multiclass.pkl'  # Update with actual path

        try:
            # Call the ECG multiclass analysis pipeline
            result = analyze_ecg_pdf(
                temp_file_path,
                model_path,
                mlb_path,
                # NOTE(review): cleanup=False leaves the digitized .dat file
                # on disk after every request — confirm this is intended.
                cleanup=False,
                debug=False,  # Set True to print debug information
                visualize=False  # Set True to plot the digitized signal
            )
        finally:
            # Fix: previously the temp file leaked when analyze_ecg_pdf
            # raised; always remove it regardless of outcome.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

        if result and result["diagnosis"]:
            return {"result": result["diagnosis"]}
        else:
            return {"result": "No diagnosis"}

    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
267
+
268
 
269
 
270
  @app.post("/process_video")
ECG/00008_hr.dat → deep-multiclass.h5 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:119d2eaf8e8aaa4a091f9f76523bf559691cfd8dd4d936cf724af1f630357233
3
- size 120000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2fdc1bcbf7820ae426a9a74f0210c738884ee2bb872bccff55a4036e2c642e1
3
+ size 8941912
ECG/00001_hr.dat → deep-multiclass.pkl RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0797d0a8c43d5cc05bb0a73026a0f6d9f358b459849674ebd38abd9120241bcf
3
- size 120000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6bd3507a56fc3e77f54e6fe0772888f61b22a2273021a04c288170237e1bc4b
3
+ size 346
requirements.txt CHANGED
@@ -102,4 +102,5 @@ wfdb==4.3.0
102
  wrapt==1.17.2
103
  yarl==1.20.0
104
  websockets
105
- xgboost
 
 
102
  wrapt==1.17.2
103
  yarl==1.20.0
104
  websockets
105
+ xgboost
106
+ pdf2image