File size: 10,871 Bytes
49edb6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d09f0a
 
08caee5
 
 
 
 
2d09f0a
 
 
 
 
 
49edb6e
2d09f0a
 
 
08caee5
 
49edb6e
2d09f0a
08caee5
 
af5990b
 
08caee5
49edb6e
08caee5
 
 
 
 
 
 
 
 
2d09f0a
08caee5
 
 
 
 
 
 
 
 
 
2d09f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08caee5
 
49edb6e
08caee5
 
 
2d09f0a
 
 
 
 
08caee5
49edb6e
2d09f0a
 
 
 
 
 
 
 
 
08caee5
49edb6e
2d09f0a
08caee5
2d09f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08caee5
49edb6e
2d09f0a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# from fastapi import FastAPI, UploadFile, File
# from huggingface_hub import hf_hub_download
# import numpy as np
# from keras.models import load_model
# from graph import zeropad, zeropad_output_shape
# from config import get_config

# app = FastAPI(
#     title="ECG Classification Backend",
#     description="REST API for ECG heartbeat classification",
#     version="1.1.0"
# )


# config = get_config()
# MODEL_PATH = "MLII-latest.keras"

# print("๐Ÿ”น Loading model:", MODEL_PATH)
# model = load_model(
#     MODEL_PATH,
#     custom_objects={
#         "zeropad": zeropad,
#         "zeropad_output_shape": zeropad_output_shape
#     },
#     compile=False
# )


# CLASSES = ["N", "V", "/", "A", "F", "~"]
# CLASS_NAMES = {
#     "N": "Normal sinus beat",
#     "V": "Premature Ventricular Contraction (PVC)",
#     "/": "Paced beat (Pacemaker)",
#     "A": "Atrial Premature Beat",
#     "F": "Fusion of Ventricular & Normal Beat",
#     "~": "Unclassifiable / Noise"
# }

# @app.get("/")
# async def root():
#     return {"message": "โœ… ECG Inference API is running successfully!"}


# @app.post("/predict-ecg/")
# async def predict_ecg(file: UploadFile = File(...)):
#     """
#     Accepts a CSV or TXT file containing ECG signal samples.
#     Each value should be a single float per line.
#     """
#     content = await file.read()
#     text = content.decode("utf-8").strip().splitlines()

 
#     try:
#         data = np.array([float(x.strip()) for x in text if x.strip() != ""])
#     except Exception:
#         return {"error": "Invalid file format. Please upload numeric ECG values only."}


#     max_len = 256
#     if len(data) > max_len:
#         data = data[:max_len]
#     elif len(data) < max_len:
#         data = np.pad(data, (0, max_len - len(data)))

#     data = data.reshape(1, max_len, 1)

  
#     preds = model.predict(data, verbose=0)
#     label_idx = int(np.argmax(preds))
#     confidence = float(np.max(preds))
#     label = CLASSES[label_idx]
#     description = CLASS_NAMES[label]

#     return {
#         "label": label,
#         "description": description,
#         "confidence": round(confidence, 4),
#         "samples_used": len(data[0])
#     }


from fastapi import FastAPI, UploadFile, File, Query
from typing import Optional, List, Dict, Any
import numpy as np
from keras.models import load_model
from graph import zeropad, zeropad_output_shape
from config import get_config

# Optional import for resampling if user needs it
try:
    # scipy is only needed for the optional resampling feature; the API
    # degrades gracefully when it is absent (see _resample_if_requested).
    from scipy.signal import resample
    SCIPY_AVAILABLE = True  # checked before every resample attempt
except Exception:
    SCIPY_AVAILABLE = False

# -------------------------
# App & model setup
# -------------------------
app = FastAPI(
    title="ECG Classification Backend",
    description="REST API for ECG heartbeat classification (MIT-BIH MLII Compatible)",
    version="1.3.0"
)

config = get_config()
# Model file is expected in the working directory at startup.
MODEL_PATH = "MLII-latest.keras"

# NOTE(review): the leading characters in this log line are mojibake (a UTF-8
# emoji decoded under the wrong codepage) — consider fixing the literal.
print(f"๐Ÿ”น Loading model: {MODEL_PATH}")
# Loaded once at import time; load_model raises if the file is missing, so
# the service fails fast rather than serving without a model.
model = load_model(
    MODEL_PATH,
    custom_objects={
        "zeropad": zeropad,
        "zeropad_output_shape": zeropad_output_shape
    },
    compile=False
)

# Class definitions
# NOTE(review): the order of CLASSES must match the model's output-layer
# ordering — argmax indices below are mapped through this list. Confirm
# against the training pipeline.
CLASSES = ["N", "V", "/", "A", "F", "~"]
CLASS_NAMES = {
    "N": "Normal sinus beat",
    "V": "Premature Ventricular Contraction (PVC)",
    "/": "Paced beat (Pacemaker)",
    "A": "Atrial Premature Beat",
    "F": "Fusion of Ventricular & Normal Beat",
    "~": "Unclassifiable / Noise"
}

# -------------------------
# Helper functions
# -------------------------
def _to_float_array(lines: List[str]) -> np.ndarray:
    """Convert lines of text to a numpy float array, skipping empty lines."""
    vals = []
    for ln in lines:
        s = ln.strip()
        if s == "":
            continue
        try:
            vals.append(float(s))
        except Exception:
            # ignore unparsable lines silently
            continue
    return np.array(vals, dtype=np.float32)

def _resample_if_requested(signal: np.ndarray, orig_fs: Optional[float], target_fs: Optional[float]) -> np.ndarray:
    """Resample signal to target_fs if both orig_fs and target_fs provided and scipy available."""
    if orig_fs is None or target_fs is None:
        return signal
    if not SCIPY_AVAILABLE:
        # can't resample without scipy; return original signal and let user know via metadata
        return signal
    if orig_fs == target_fs:
        return signal
    target_len = int(round(len(signal) * (target_fs / orig_fs)))
    if target_len < 1:
        return signal
    return resample(signal, target_len)

def _pad_or_truncate(signal: np.ndarray, length: int) -> np.ndarray:
    if len(signal) > length:
        return signal[:length].copy()
    if len(signal) < length:
        return np.pad(signal, (0, length - len(signal)), mode="constant")
    return signal

def _zscore(signal: np.ndarray) -> np.ndarray:
    mean = np.mean(signal)
    std = np.std(signal)
    if std < 1e-6:
        std = 1.0
    return (signal - mean) / std

def _normalize_hardware_adc(signal: np.ndarray, adc_max: float = 4095.0, vref: float = 3.3) -> np.ndarray:
    """
    Map raw unipolar ADC counts (0..adc_max) onto an MLII-like scale.

    Counts are converted to volts, centred on zero, then z-scored so the
    amplitude statistics resemble the dataset signal the model expects.
    NOTE(review): the defaults assume a 12-bit ADC referenced at 3.3 V
    (AD8232-style front end) — confirm against the firmware.
    """
    # counts -> volts, then remove the DC offset before standardising
    volts = (signal / float(adc_max)) * float(vref)
    centered = volts - np.mean(volts)
    return _zscore(centered)

def _segment_stream(signal: np.ndarray, window: int = 256, step: int = 128) -> List[np.ndarray]:
    """Make a list of overlapping segments from a longer stream."""
    if len(signal) <= window:
        return [ _pad_or_truncate(signal, window) ]
    segments = []
    for start in range(0, len(signal) - window + 1, step):
        segments.append(signal[start:start+window].copy())
    # if final tail remains and wasn't included, optionally include last window (pad)
    if (len(signal) - window) % step != 0 and (len(signal) % step) != 0:
        segments.append(_pad_or_truncate(signal[-window:], window))
    return segments

def _predict_segments(segments: List[np.ndarray]):
    """Batch the 1-D segments and run a single model.predict call.

    Returns the raw prediction array of shape (num_segments, num_classes).
    """
    batch = np.expand_dims(np.stack(segments, axis=0), axis=-1)  # (N, window, 1)
    return model.predict(batch, verbose=0)

# -------------------------
# Routes
# -------------------------
@app.get("/")
async def root():
    return {"message": "ECG Inference API is running successfully!"}


@app.post("/predict-ecg/")
async def predict_ecg(
    file: UploadFile = File(...),
    original_fs: Optional[float] = Query(None, description="(optional) sampling rate of the uploaded data in Hz"),
    resample_to: Optional[float] = Query(None, description="(optional) resample uploaded data to this fs (Hz) if provided)")
) -> Dict[str, Any]:
    """
    Accepts a CSV or TXT file containing ECG samples (one float per line).
    Optional query params:
      - original_fs : sampling rate of the provided file (Hz)
      - resample_to  : if set and scipy available, resample to this rate
    Processing:
      - Convert lines -> floats
      - Optionally resample
      - Auto-detect hardware ADC vs dataset and normalize accordingly
      - Segment into 256-sample windows (50% overlap) and run inference per-segment
      - Aggregate predictions (majority vote) and return per-segment results + summary
    """

    # Read & parse file
    content = await file.read()
    text_lines = content.decode("utf-8").strip().splitlines()
    data = _to_float_array(text_lines)
    if data.size == 0:
        return {"error": "No numeric samples found in uploaded file."}

    # Optional resample (if requested)
    if original_fs is not None and resample_to is not None and SCIPY_AVAILABLE:
        data = _resample_if_requested(data, orig_fs=original_fs, target_fs=resample_to)

    # If user requested resampling but scipy not available -> notify in response metadata
    resample_unavailable = (original_fs is not None and resample_to is not None and not SCIPY_AVAILABLE)

    # Detect whether this is ADC-like hardware input or dataset (MIT-BIH style)
    # Heuristic: ADC counts are usually large integers (>> 100). MIT-BIH signals are small floats (~ -2 .. +2)
    max_val = float(np.max(np.abs(data)))
    mean_before = float(np.mean(data))
    std_before = float(np.std(data))

    is_adc_like = max_val > 100.0  # heuristic threshold; adjust if needed

    # If ADC-like -> convert to volts and z-score
    if is_adc_like:
        norm_data = _normalize_hardware_adc(data)
        source_type = "hardware_adc"
    else:
        # Assume dataset-like (MLII). We still apply z-score to match model training preprocessing.
        norm_data = _zscore(data)
        source_type = "dataset_like"

    # Segment into windows (256 samples, 50% overlap)
    window = 256
    step = window // 2
    segments = _segment_stream(norm_data, window=window, step=step)

    preds = _predict_segments(segments)  # (num_segments, num_classes)
    pred_labels_idx = np.argmax(preds, axis=1)
    pred_confidences = np.max(preds, axis=1)

    # Aggregate: majority vote for labels; also compute mean confidence for that class
    # Map indices -> class codes
    seg_results = []
    for i, (lab_idx, conf) in enumerate(zip(pred_labels_idx, pred_confidences)):
        code = CLASSES[int(lab_idx)]
        seg_results.append({
            "segment": i,
            "label": code,
            "label_name": CLASS_NAMES[code],
            "confidence": float(round(float(conf), 4)),
            "mean_before": float(round(float(np.mean(segments[i])), 6)),
            "std_before": float(round(float(np.std(segments[i])), 6))
        })

    # Majority vote
    from collections import Counter
    votes = [CLASSES[int(i)] for i in pred_labels_idx]
    vote_counts = Counter(votes)
    majority_label, majority_count = vote_counts.most_common(1)[0]
    majority_confidence = float(np.mean([c for c,lab in zip(pred_confidences, votes) if lab == majority_label]))

    response = {
        "source_type": source_type,
        "original_stats": {"max": round(max_val, 6), "mean": round(mean_before, 6), "std": round(std_before, 6)},
        "num_samples_uploaded": int(len(data)),
        "num_segments": len(segments),
        "per_segment": seg_results,
        "aggregate": {
            "label": majority_label,
            "label_name": CLASS_NAMES[majority_label],
            "votes_for_label": int(majority_count),
            "majority_confidence_mean": round(majority_confidence, 4)
        },
        "resample_unavailable": resample_unavailable
    }

    return response