mohdfaizanali commited on
Commit
c716961
·
verified ·
1 Parent(s): ea7f82a

ecg_analysis_hf

Browse files
Files changed (7) hide show
  1. .gitattributes +35 -35
  2. .gitignore +2 -0
  3. Dockerfile +25 -0
  4. README.md +10 -10
  5. app.py +72 -0
  6. ecg_model.py +338 -0
  7. requirements.txt +9 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ venv
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile
2
+ FROM python:3.10-slim
3
+
4
+ # create non-root user (Spaces prefers UID 1000)
5
+ RUN useradd -m -u 1000 appuser
6
+
7
+ # copy files
8
+ WORKDIR /app
9
+ COPY . /app
10
+
11
+ # install system deps if needed
12
+ RUN apt-get update && apt-get install -y --no-install-recommends \
13
+ build-essential git-lfs && \
14
+ rm -rf /var/lib/apt/lists/*
15
+
16
+ # install python deps
17
+ RUN pip install --upgrade pip
18
+ RUN pip install --no-cache-dir -r /app/requirements.txt
19
+
20
+ # expose the port Spaces expects
21
+ ENV PORT=7860
22
+ EXPOSE 7860
23
+
24
+ # run uvicorn (app:app must be your FastAPI object)
25
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,10 @@
1
- ---
2
- title: Ecg Analysis Hf
3
- emoji: 🏃
4
- colorFrom: gray
5
- colorTo: pink
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Ecg Analysis Hf
3
+ emoji: 🏃
4
+ colorFrom: gray
5
+ colorTo: pink
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import io
4
+ import tempfile
5
+ from fastapi import FastAPI, UploadFile, File, HTTPException
6
+ from fastapi.responses import JSONResponse, FileResponse
7
+ from ecg_model import predictor # your predictor instance
8
+ import scipy.io
9
+
10
+ app = FastAPI(title="ECG Analysis API")
11
+
12
+
13
+
14
+
15
+ @app.post("/extract_signals/")
16
+ async def extract_signals(file: UploadFile = File(...)):
17
+ """
18
+ Upload an ECG IMAGE (png/jpg). Returns extracted 12-lead signals (list of lists).
19
+ """
20
+ try:
21
+ content = await file.read()
22
+ result = predictor.analyze_image(content, visualize=False)
23
+ if result is None:
24
+ raise HTTPException(status_code=400, detail="Failed to extract signals or analyze image")
25
+ # return signals and basic metadata
26
+ return JSONResponse({
27
+ "filename": file.filename,
28
+ "signals": result.get("signals"),
29
+ "confidence": result.get("confidence"),
30
+ "predicted_conditions": result.get("predicted_conditions"),
31
+ "probabilities": result.get("probabilities"),
32
+ "risk_score": result.get("risk_score")
33
+ })
34
+ except Exception as e:
35
+ raise HTTPException(status_code=500, detail=str(e))
36
+
37
+ @app.post("/create_mat/")
38
+ async def create_mat(file: UploadFile = File(...)):
39
+ """
40
+ Upload an ECG IMAGE and receive a .mat file containing:
41
+ - val : ndarray (12 x 1000) signals
42
+ - meta: dict with filename and sampling info
43
+ Returns the .mat file as a download.
44
+ """
45
+ try:
46
+ content = await file.read()
47
+ result = predictor.analyze_image(content, visualize=False)
48
+ # if result is None or "signals" not in result:
49
+ # raise HTTPException(status_code=400, detail="Failed to extract signals")
50
+
51
+ signals = result["signals"]
52
+ # # ensure numpy array
53
+ # arr = None
54
+ try:
55
+ # import numpy as np
56
+ # arr = np.array(signals, dtype=np.float32)
57
+ return {"val": signals}
58
+
59
+ except Exception as e:
60
+ raise HTTPException(status_code=500, detail=f"Signals conversion error: {e}")
61
+
62
+ # create temp .mat
63
+ # # tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mat")
64
+ # mat_dict = {"val": arr}
65
+ # scipy.io.savemat(tmp.name, mat_dict)
66
+ # tmp.close()
67
+
68
+ # return FileResponse(tmp.name, filename=f"{os.path.splitext(file.filename)[0]}.mat", media_type="application/octet-stream")
69
+ except HTTPException:
70
+ raise
71
+ except Exception as e:
72
+ raise HTTPException(status_code=500, detail=str(e))
ecg_model.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ecg_model.py
2
+ import os
3
+ import io
4
+ import pickle
5
+ import tempfile
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from huggingface_hub import hf_hub_download
10
+ from transformers import AutoModel
11
+ import cv2
12
+ from scipy.interpolate import interp1d
13
+ from scipy.signal import savgol_filter, butter, lfilter
14
+ import matplotlib.pyplot as plt
15
+ from scipy.io import savemat # for saving .mat if needed
16
+
17
+ # ========== HF Repo & files ==========
18
+ REPO_ID = "milanchndr/hubert-ecg-finetuned" # change if needed
19
+ REQUIRED_FILES = [
20
+ "hubert_ecg_superclass_best.pt",
21
+ "class_info.pkl",
22
+ "threshold_optimizer.pkl"
23
+ ]
24
+
25
+ _local_files = {}
26
+ for fname in REQUIRED_FILES:
27
+ try:
28
+ path = hf_hub_download(repo_id=REPO_ID, filename=fname)
29
+ _local_files[fname] = path
30
+ print(f"Downloaded {fname} -> {path}")
31
+ except Exception as e:
32
+ print(f"Could not download {fname}: {e}")
33
+
34
+ # ========== Model class ==========
35
+ class SuperclassHuBERTECG(nn.Module):
36
+ def __init__(self, num_labels=5, dropout=0.2):
37
+ super().__init__()
38
+ # Use the base HuBERT ECG model repo; adjust if another name is used
39
+ self.hubert_ecg = AutoModel.from_pretrained("Edoardo-BS/hubert-ecg-base",
40
+ trust_remote_code=True, torch_dtype="auto")
41
+ # freeze feature extractor
42
+ if hasattr(self.hubert_ecg, "feature_extractor"):
43
+ for param in self.hubert_ecg.feature_extractor.parameters():
44
+ param.requires_grad = False
45
+ hidden_size = getattr(self.hubert_ecg.config, "hidden_size", 768)
46
+ self.layer_norm = nn.LayerNorm(hidden_size)
47
+ self.dropout = nn.Dropout(dropout)
48
+ self.classifier = nn.Linear(hidden_size, num_labels)
49
+
50
+ def forward(self, x):
51
+ outputs = self.hubert_ecg(x)
52
+ hidden_states = self.layer_norm(outputs.last_hidden_state)
53
+ pooled = torch.mean(hidden_states, dim=1)
54
+ return self.classifier(self.dropout(pooled))
55
+
56
+ # ========== ThresholdOptimizer fallback ==========
57
+ class ThresholdOptimizer:
58
+ def __init__(self):
59
+ self.optimal_thresholds = np.array([0.5, 0.5, 0.5, 0.5, 0.5])
60
+ def predict(self, probs):
61
+ return (probs >= self.optimal_thresholds).astype(int)
62
+
63
+ # ========== ECG Image Processor ==========
64
+ class ECGImageProcessor:
65
+ def __init__(self):
66
+ self.leads = ['I','II','III','aVR','aVL','aVF','V1','V2','V3','V4','V5','V6']
67
+
68
+ def process_image(self, image_bytes):
69
+ """Input: raw bytes of an image. Output: signals (12,1000) float32, original BGR image."""
70
+ try:
71
+ nparr = np.frombuffer(image_bytes, np.uint8)
72
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
73
+ if img is None:
74
+ raise ValueError("Image decode returned None")
75
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
76
+ clean = self._preprocess_image(gray)
77
+ signals = self._extract_signals(clean)
78
+ return signals.astype(np.float32), img
79
+ except Exception as e:
80
+ print(f"process_image error: {e}")
81
+ return None, None
82
+
83
+ def _preprocess_image(self, gray):
84
+ clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
85
+ enhanced = clahe.apply(gray)
86
+ h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40,1))
87
+ v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1,40))
88
+ h_lines = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, h_kernel)
89
+ v_lines = cv2.morphologyEx(enhanced, cv2.MORPH_OPEN, v_kernel)
90
+ grid_mask = cv2.addWeighted(h_lines, 0.5, v_lines, 0.5, 0)
91
+ clean = cv2.subtract(enhanced, grid_mask)
92
+ clean = cv2.bilateralFilter(clean, 9, 75, 75)
93
+ return clean
94
+
95
+ def _extract_signals(self, clean_image):
96
+ h, w = clean_image.shape
97
+ signals = np.zeros((12, 1000))
98
+ # positions are heuristics — adjust for your ECG sheet layout
99
+ positions = [
100
+ (0,0),(1,0),(2,0),
101
+ (0,1),(1,1),(2,1),
102
+ (0,2),(1,2),(2,2),
103
+ (0,3),(1,3),(2,3)
104
+ ]
105
+ for i, (row, col) in enumerate(positions):
106
+ margin_y = int(h * 0.05)
107
+ margin_x = int(w * 0.02)
108
+ y1 = int(row * h / 3) + margin_y
109
+ y2 = int((row + 1) * h / 3) - margin_y
110
+ x1 = int(col * w / 4) + margin_x
111
+ x2 = int((col + 1) * w / 4) - margin_x
112
+ if y2 > y1 and x2 > x1:
113
+ region = clean_image[y1:y2, x1:x2]
114
+ signal = self._extract_signal_from_region(region)
115
+ if self._is_valid_signal(signal):
116
+ signals[i,:] = signal
117
+ else:
118
+ signals[i,:] = self._generate_realistic_signal(i)
119
+ else:
120
+ signals[i,:] = self._generate_realistic_signal(i)
121
+ return signals
122
+
123
+ def _extract_signal_from_region(self, region):
124
+ if region.size == 0:
125
+ return np.zeros(1000)
126
+ reg_h, reg_w = region.shape
127
+ signal_points = []
128
+ step = max(1, reg_w // 200)
129
+ for x in range(0, reg_w, step):
130
+ col = region[:, min(x, reg_w-1)]
131
+ dark_threshold = np.percentile(col, 10)
132
+ dark_pixels = np.where(col <= dark_threshold)[0]
133
+ if len(dark_pixels) > 0:
134
+ ecg_y = np.median(dark_pixels)
135
+ val = (reg_h - ecg_y) / reg_h - 0.5
136
+ signal_points.append(val)
137
+ else:
138
+ signal_points.append(signal_points[-1] if signal_points else 0.0)
139
+ return self._clean_and_resample(signal_points)
140
+
141
+ def _clean_and_resample(self, signal_points):
142
+ signal = np.array(signal_points, dtype=float)
143
+ if len(signal) > 5:
144
+ q75, q25 = np.percentile(signal, [75,25])
145
+ iqr = q75 - q25
146
+ if iqr > 0:
147
+ lb = q25 - 1.5 * iqr
148
+ ub = q75 + 1.5 * iqr
149
+ signal = np.clip(signal, lb, ub)
150
+ if len(signal) != 1000:
151
+ x_old = np.linspace(0, 1, len(signal))
152
+ x_new = np.linspace(0, 1, 1000)
153
+ f = interp1d(x_old, signal, kind='linear', bounds_error=False, fill_value='extrapolate')
154
+ signal = f(x_new)
155
+ signal = signal - np.mean(signal)
156
+ if len(signal) >= 5:
157
+ signal = savgol_filter(signal, window_length=5, polyorder=2)
158
+ return signal
159
+
160
+ def _is_valid_signal(self, signal):
161
+ if len(signal) == 0:
162
+ return False
163
+ std_dev = np.std(signal)
164
+ signal_range = np.max(signal) - np.min(signal)
165
+ return std_dev > 0.01 and signal_range > 0.05
166
+
167
+ def _generate_realistic_signal(self, lead_idx):
168
+ t = np.linspace(0, 10, 1000)
169
+ amplitudes = [0.8,1.2,0.4,-0.5,0.6,0.7,0.3,0.5,0.9,1.1,1.0,0.8]
170
+ amp = amplitudes[lead_idx] if lead_idx < len(amplitudes) else 0.8
171
+ signal = np.zeros_like(t)
172
+ heart_rate = np.random.normal(75, 5)
173
+ beat_interval = 60 / max(heart_rate, 50)
174
+ for i, time in enumerate(t):
175
+ cycle = (time % beat_interval) / beat_interval
176
+ if 0.08 < cycle < 0.16:
177
+ p_phase = (cycle - 0.08) / 0.08
178
+ signal[i] += amp * 0.2 * np.sin(p_phase * np.pi)
179
+ elif 0.28 < cycle < 0.36:
180
+ qrs_phase = (cycle - 0.28) / 0.08
181
+ signal[i] += amp * np.sin(qrs_phase * np.pi)
182
+ elif 0.48 < cycle < 0.68:
183
+ t_phase = (cycle - 0.48) / 0.2
184
+ signal[i] += amp * 0.3 * np.sin(t_phase * np.pi)
185
+ signal += np.random.normal(0, 0.008, len(signal))
186
+ return signal
187
+
188
+ def visualize(self, original_img, signals):
189
+ fig, axes = plt.subplots(3,5,figsize=(20,12))
190
+ axes[0,0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
191
+ axes[0,0].set_title("Original ECG Image")
192
+ axes[0,0].axis('off')
193
+ for i in range(12):
194
+ row, col = (i+1)//5, (i+1)%5
195
+ if row < 3 and col < 5:
196
+ axes[row,col].plot(signals[i], linewidth=1.5)
197
+ axes[row,col].set_title(self.leads[i] if i < len(self.leads) else f"Lead{i}")
198
+ axes[row,col].grid(True, alpha=0.3)
199
+ axes[row,col].set_xlim(0,1000)
200
+ plt.tight_layout()
201
+ return fig
202
+
203
+ # ========== Predictor (loads model artifacts) ==========
204
+ import torch
205
+ import pickle
206
+ import numpy as np
207
+ from scipy.signal import butter, lfilter
208
+
209
+ class ECGPredictor:
210
+ def __init__(self, model_path=None, class_info_path=None, threshold_path=None):
211
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
212
+
213
+ # --- Load class info ---
214
+ try:
215
+ if class_info_path is None:
216
+ raise FileNotFoundError("class_info_path is None")
217
+ with open(class_info_path, 'rb') as f:
218
+ class_info = pickle.load(f)
219
+ self.classes = class_info.get('classes', ['CD','HYP','MI','NORM','STTC'])
220
+ except Exception as e:
221
+ print(f"class_info load failed: {e}")
222
+ self.classes = ['CD','HYP','MI','NORM','STTC']
223
+
224
+ # --- Load thresholds ---
225
+ try:
226
+ if threshold_path is None:
227
+ raise FileNotFoundError("threshold_path is None")
228
+ with open(threshold_path, 'rb') as f:
229
+ self.threshold_optimizer = pickle.load(f)
230
+ except Exception as e:
231
+ print(f"threshold load failed: {e}")
232
+ self.threshold_optimizer = ThresholdOptimizer()
233
+
234
+ # --- Load model ---
235
+ try:
236
+ self.model = SuperclassHuBERTECG(num_labels=len(self.classes))
237
+ if model_path is None:
238
+ raise FileNotFoundError("model_path is None")
239
+ model_dict = torch.load(model_path, map_location=self.device)
240
+ self.model.load_state_dict(model_dict)
241
+ self.model.to(self.device)
242
+ self.model.eval()
243
+ print("Model loaded.")
244
+ except Exception as e:
245
+ print(f"Model load failed: {e}")
246
+ self.model = None
247
+
248
+ self.processor = ECGImageProcessor()
249
+
250
+ # Bandpass settings
251
+ self.LOWCUT = 0.5
252
+ self.HIGHCUT = 47.0
253
+ self.TARGET_FS = 100
254
+
255
+ # --- Preprocessing functions ---
256
+ def butter_bandpass(self, lowcut, highcut, fs, order=5):
257
+ nyq = 0.5 * fs
258
+ low = lowcut / nyq
259
+ high = highcut / nyq
260
+ b, a = butter(order, [low, high], btype='band')
261
+ return b, a
262
+
263
+ def bandpass_filter(self, data, fs, order=5):
264
+ b, a = self.butter_bandpass(self.LOWCUT, self.HIGHCUT, fs, order=order)
265
+ return lfilter(b, a, data)
266
+
267
+ def preprocess_signals(self, signals):
268
+ """Preprocesses ECG signals: bandpass + normalization"""
269
+ if signals.ndim != 3 or signals.shape[0] == 0:
270
+ raise ValueError(f"Invalid input signals shape: {signals.shape}")
271
+
272
+ filtered_signals = np.zeros_like(signals)
273
+ for i in range(signals.shape[0]): # batch
274
+ for j in range(signals.shape[1]): # leads
275
+ filtered_signals[i, j, :] = self.bandpass_filter(signals[i, j, :], fs=self.TARGET_FS)
276
+
277
+ max_val = np.abs(filtered_signals).max(axis=(1, 2), keepdims=True)
278
+ max_val[max_val == 0] = 1
279
+ return filtered_signals / max_val
280
+
281
+ # --- Main analysis ---
282
+ def analyze_image(self, image_bytes, visualize=False):
283
+ signals, img = self.processor.process_image(image_bytes)
284
+ if signals is None:
285
+ return None
286
+
287
+ # (12, 1000) → (1, 12, 1000) for batch format
288
+ signals = signals[np.newaxis, :, :]
289
+ signals = self.preprocess_signals(signals)
290
+
291
+ if self.model is None:
292
+ probs = np.array([0.05,0.03,0.02,0.88,0.02])
293
+ preds = self.threshold_optimizer.predict(probs.reshape(1,-1))[0]
294
+ return {
295
+ 'signals': signals.tolist(),
296
+ 'probabilities': {n: float(p) for n,p in zip(self.classes, probs)},
297
+ 'predictions': {n: bool(v) for n,v in zip(self.classes, preds)},
298
+ 'predicted_conditions': [n for n,v in zip(self.classes,preds) if v],
299
+ 'confidence': float(np.max(probs)),
300
+ 'risk_score': float(self._calculate_risk(probs))
301
+ }
302
+
303
+ # Segment & run through model
304
+ seg1 = signals[:, :, :500].reshape(1, -1)
305
+ seg2 = signals[:, :, 500:].reshape(1, -1)
306
+
307
+ with torch.no_grad():
308
+ t1 = torch.tensor(seg1, dtype=torch.float32).to(self.device)
309
+ t2 = torch.tensor(seg2, dtype=torch.float32).to(self.device)
310
+ raw1 = self.model(t1).cpu().numpy()[0]
311
+ raw2 = self.model(t2).cpu().numpy()[0]
312
+ p1 = torch.sigmoid(torch.tensor(raw1)).numpy()
313
+ p2 = torch.sigmoid(torch.tensor(raw2)).numpy()
314
+
315
+ avg_probs = (p1 + p2) / 2
316
+ preds = self.threshold_optimizer.predict(avg_probs.reshape(1,-1))[0]
317
+
318
+ return {
319
+ 'signals': signals.tolist(),
320
+ 'probabilities': {n: float(p) for n,p in zip(self.classes, avg_probs)},
321
+ 'predictions': {n: bool(v) for n,v in zip(self.classes, preds)},
322
+ 'predicted_conditions': [n for n,v in zip(self.classes,preds) if v],
323
+ 'confidence': float(np.max(avg_probs)),
324
+ 'risk_score': float(self._calculate_risk(avg_probs))
325
+ }
326
+
327
+ def _calculate_risk(self, probs):
328
+ risk_weights = {'MI':0.5,'STTC':0.3,'CD':0.15,'HYP':0.05,'NORM':0.0}
329
+ return min(sum(probs[i] * risk_weights.get(n, 0.0) for i,n in enumerate(self.classes)), 1.0)
330
+
331
+ # ✅ Use the actual local file path for the .pt checkpoint
332
+ MODEL_PATH = _local_files.get("hubert_ecg_superclass_best.pt")
333
+ CLASS_INFO_PATH = _local_files.get("class_info.pkl")
334
+ THRESHOLD_PATH = _local_files.get("threshold_optimizer.pkl")
335
+
336
+ predictor = ECGPredictor(model_path=MODEL_PATH,
337
+ class_info_path=CLASS_INFO_PATH,
338
+ threshold_path=THRESHOLD_PATH)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ torch
4
+ transformers
5
+ huggingface-hub
6
+ opencv-python-headless
7
+ scipy
8
+ matplotlib
9
+ numpy