Spaces:
Sleeping
Sleeping
File size: 10,871 Bytes
49edb6e 2d09f0a 08caee5 2d09f0a 49edb6e 2d09f0a 08caee5 49edb6e 2d09f0a 08caee5 af5990b 08caee5 49edb6e 08caee5 2d09f0a 08caee5 2d09f0a 08caee5 49edb6e 08caee5 2d09f0a 08caee5 49edb6e 2d09f0a 08caee5 49edb6e 2d09f0a 08caee5 2d09f0a 08caee5 49edb6e 2d09f0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 |
# from fastapi import FastAPI, UploadFile, File
# from huggingface_hub import hf_hub_download
# import numpy as np
# from keras.models import load_model
# from graph import zeropad, zeropad_output_shape
# from config import get_config
# app = FastAPI(
# title="ECG Classification Backend",
# description="REST API for ECG heartbeat classification",
# version="1.1.0"
# )
# config = get_config()
# MODEL_PATH = "MLII-latest.keras"
# print("Loading model:", MODEL_PATH)
# model = load_model(
# MODEL_PATH,
# custom_objects={
# "zeropad": zeropad,
# "zeropad_output_shape": zeropad_output_shape
# },
# compile=False
# )
# CLASSES = ["N", "V", "/", "A", "F", "~"]
# CLASS_NAMES = {
# "N": "Normal sinus beat",
# "V": "Premature Ventricular Contraction (PVC)",
# "/": "Paced beat (Pacemaker)",
# "A": "Atrial Premature Beat",
# "F": "Fusion of Ventricular & Normal Beat",
# "~": "Unclassifiable / Noise"
# }
# @app.get("/")
# async def root():
# return {"message": "ECG Inference API is running successfully!"}
# @app.post("/predict-ecg/")
# async def predict_ecg(file: UploadFile = File(...)):
# """
# Accepts a CSV or TXT file containing ECG signal samples.
# Each value should be a single float per line.
# """
# content = await file.read()
# text = content.decode("utf-8").strip().splitlines()
# try:
# data = np.array([float(x.strip()) for x in text if x.strip() != ""])
# except Exception:
# return {"error": "Invalid file format. Please upload numeric ECG values only."}
# max_len = 256
# if len(data) > max_len:
# data = data[:max_len]
# elif len(data) < max_len:
# data = np.pad(data, (0, max_len - len(data)))
# data = data.reshape(1, max_len, 1)
# preds = model.predict(data, verbose=0)
# label_idx = int(np.argmax(preds))
# confidence = float(np.max(preds))
# label = CLASSES[label_idx]
# description = CLASS_NAMES[label]
# return {
# "label": label,
# "description": description,
# "confidence": round(confidence, 4),
# "samples_used": len(data[0])
# }
from fastapi import FastAPI, UploadFile, File, Query
from typing import Optional, List, Dict, Any
import numpy as np
from keras.models import load_model
from graph import zeropad, zeropad_output_shape
from config import get_config
# Optional import for resampling if user needs it.
# SciPy is treated as an optional dependency: when it is absent we set a flag
# and /predict-ecg/ reports "resample_unavailable" instead of failing.
try:
    from scipy.signal import resample
    SCIPY_AVAILABLE = True
except ImportError:
    # Narrowed from a blanket `except Exception` — only a failed import
    # should be silenced here; anything else is a real bug worth surfacing.
    SCIPY_AVAILABLE = False
# -------------------------
# App & model setup
# -------------------------
app = FastAPI(
    title="ECG Classification Backend",
    description="REST API for ECG heartbeat classification (MIT-BIH MLII Compatible)",
    version="1.3.0"
)
config = get_config()

# Keras model trained on MIT-BIH MLII beats. The custom layers must be
# registered via custom_objects or deserialization fails at load time.
MODEL_PATH = "MLII-latest.keras"
print(f"Loading model: {MODEL_PATH}")  # fixed: message previously contained mojibake ("๐น")
model = load_model(
    MODEL_PATH,
    custom_objects={
        "zeropad": zeropad,
        "zeropad_output_shape": zeropad_output_shape
    },
    compile=False  # inference only; no optimizer/loss state needed
)
# Class definitions
# Model output indices map positionally onto these MIT-BIH beat annotation codes.
CLASSES = ["N", "V", "/", "A", "F", "~"]
# Human-readable descriptions for each beat code, returned in API responses.
CLASS_NAMES = {
    "N": "Normal sinus beat",
    "V": "Premature Ventricular Contraction (PVC)",
    "/": "Paced beat (Pacemaker)",
    "A": "Atrial Premature Beat",
    "F": "Fusion of Ventricular & Normal Beat",
    "~": "Unclassifiable / Noise"
}
# -------------------------
# Helper functions
# -------------------------
def _to_float_array(lines: List[str]) -> np.ndarray:
"""Convert lines of text to a numpy float array, skipping empty lines."""
vals = []
for ln in lines:
s = ln.strip()
if s == "":
continue
try:
vals.append(float(s))
except Exception:
# ignore unparsable lines silently
continue
return np.array(vals, dtype=np.float32)
def _resample_if_requested(signal: np.ndarray, orig_fs: Optional[float], target_fs: Optional[float]) -> np.ndarray:
    """Resample `signal` from orig_fs to target_fs when possible.

    Returns the signal unchanged when either rate is missing or non-positive,
    when scipy is unavailable, or when the rates already match.

    :param signal: 1-D sample array
    :param orig_fs: sampling rate of `signal` in Hz (or None)
    :param target_fs: desired sampling rate in Hz (or None)
    """
    if orig_fs is None or target_fs is None:
        return signal
    if not SCIPY_AVAILABLE:
        # Can't resample without scipy; the caller reports this condition
        # via the "resample_unavailable" response field.
        return signal
    if orig_fs <= 0 or target_fs <= 0:
        # Guard: orig_fs == 0 would raise ZeroDivisionError below, and
        # negative rates are meaningless.
        return signal
    if orig_fs == target_fs:
        return signal
    target_len = int(round(len(signal) * (target_fs / orig_fs)))
    if target_len < 1:
        return signal
    return resample(signal, target_len)
def _pad_or_truncate(signal: np.ndarray, length: int) -> np.ndarray:
if len(signal) > length:
return signal[:length].copy()
if len(signal) < length:
return np.pad(signal, (0, length - len(signal)), mode="constant")
return signal
def _zscore(signal: np.ndarray) -> np.ndarray:
mean = np.mean(signal)
std = np.std(signal)
if std < 1e-6:
std = 1.0
return (signal - mean) / std
def _normalize_hardware_adc(signal: np.ndarray, adc_max: float = 4095.0, vref: float = 3.3) -> np.ndarray:
    """
    Convert raw ADC counts to volts, center on zero, then z-score so hardware
    input matches the amplitude behavior of the MLII training data.

    Assumes a unipolar ADC output in 0..adc_max (defaults match a 12-bit ADC
    with a 3.3 V reference, e.g. AD8232 front-end). Centering here handles the
    midrail offset if raw (unfiltered) ADC values are sent.
    """
    # ADC counts -> volts
    as_volts = (signal / float(adc_max)) * float(vref)
    # Remove DC offset so the waveform sits around zero
    centered = as_volts - np.mean(as_volts)
    # Standardize amplitude to match training-time preprocessing
    return _zscore(centered)
def _segment_stream(signal: np.ndarray, window: int = 256, step: int = 128) -> List[np.ndarray]:
"""Make a list of overlapping segments from a longer stream."""
if len(signal) <= window:
return [ _pad_or_truncate(signal, window) ]
segments = []
for start in range(0, len(signal) - window + 1, step):
segments.append(signal[start:start+window].copy())
# if final tail remains and wasn't included, optionally include last window (pad)
if (len(signal) - window) % step != 0 and (len(signal) % step) != 0:
segments.append(_pad_or_truncate(signal[-window:], window))
return segments
def _predict_segments(segments: List[np.ndarray]):
    """Batch the 1-D segments and run a single model.predict call.

    :param segments: list of equal-length 1-D arrays
    :return: prediction array of shape (num_segments, num_classes)
    """
    batch = np.stack(segments, axis=0)       # (N, window)
    batch = np.expand_dims(batch, axis=-1)   # (N, window, 1) — channel dim for the CNN
    return model.predict(batch, verbose=0)
# -------------------------
# Routes
# -------------------------
@app.get("/")
async def root():
    # Liveness probe: reaching this handler implies the module imported and
    # the model at the top of the file loaded without raising.
    return {"message": "ECG Inference API is running successfully!"}
@app.post("/predict-ecg/")
async def predict_ecg(
    file: UploadFile = File(...),
    original_fs: Optional[float] = Query(None, description="(optional) sampling rate of the uploaded data in Hz"),
    resample_to: Optional[float] = Query(None, description="(optional) resample uploaded data to this fs (Hz) if provided)")
) -> Dict[str, Any]:
    """
    Accepts a CSV or TXT file containing ECG samples (one float per line).
    Optional query params:
      - original_fs : sampling rate of the provided file (Hz)
      - resample_to : if set and scipy available, resample to this rate
    Processing:
      - Convert lines -> floats
      - Optionally resample
      - Auto-detect hardware ADC vs dataset and normalize accordingly
      - Segment into 256-sample windows (50% overlap) and run inference per-segment
      - Aggregate predictions (majority vote) and return per-segment results + summary
    """
    from collections import Counter  # hoisted from mid-function; stdlib only

    # Read & parse file
    content = await file.read()
    try:
        text_lines = content.decode("utf-8").strip().splitlines()
    except UnicodeDecodeError:
        # Binary/non-UTF-8 uploads previously raised a 500; return the same
        # structured error shape used by the other failure paths instead.
        return {"error": "File is not valid UTF-8 text. Please upload numeric ECG values only."}
    data = _to_float_array(text_lines)
    if data.size == 0:
        return {"error": "No numeric samples found in uploaded file."}

    # Capture the upload size BEFORE optional resampling — previously this was
    # measured after resampling, so "num_samples_uploaded" misreported the file.
    num_samples_uploaded = int(data.size)

    # Optional resample (if requested)
    if original_fs is not None and resample_to is not None and SCIPY_AVAILABLE:
        data = _resample_if_requested(data, orig_fs=original_fs, target_fs=resample_to)
    # If user requested resampling but scipy not available -> notify in response metadata
    resample_unavailable = (original_fs is not None and resample_to is not None and not SCIPY_AVAILABLE)

    # Detect whether this is ADC-like hardware input or dataset (MIT-BIH style).
    # Heuristic: ADC counts are usually large integers (>> 100); MIT-BIH
    # signals are small floats (~ -2 .. +2).
    max_val = float(np.max(np.abs(data)))
    mean_before = float(np.mean(data))
    std_before = float(np.std(data))
    is_adc_like = max_val > 100.0  # heuristic threshold; adjust if needed

    if is_adc_like:
        # ADC-like -> convert counts to volts, center, z-score
        norm_data = _normalize_hardware_adc(data)
        source_type = "hardware_adc"
    else:
        # Assume dataset-like (MLII); z-score to match model training preprocessing.
        norm_data = _zscore(data)
        source_type = "dataset_like"

    # Segment into windows (256 samples, 50% overlap) and classify each window
    window = 256
    step = window // 2
    segments = _segment_stream(norm_data, window=window, step=step)
    preds = _predict_segments(segments)          # (num_segments, num_classes)
    pred_labels_idx = np.argmax(preds, axis=1)
    pred_confidences = np.max(preds, axis=1)

    # Per-segment results (indices mapped to class codes)
    seg_results = []
    for i, (lab_idx, conf) in enumerate(zip(pred_labels_idx, pred_confidences)):
        code = CLASSES[int(lab_idx)]
        seg_results.append({
            "segment": i,
            "label": code,
            "label_name": CLASS_NAMES[code],
            "confidence": float(round(float(conf), 4)),
            # NOTE: these are stats of the NORMALIZED segment fed to the model,
            # not of the raw uploaded samples (key names kept for compatibility).
            "mean_before": float(round(float(np.mean(segments[i])), 6)),
            "std_before": float(round(float(np.std(segments[i])), 6))
        })

    # Aggregate: majority vote across segments; mean confidence of the winner
    votes = [CLASSES[int(i)] for i in pred_labels_idx]
    vote_counts = Counter(votes)
    majority_label, majority_count = vote_counts.most_common(1)[0]
    majority_confidence = float(np.mean([c for c, lab in zip(pred_confidences, votes) if lab == majority_label]))

    response = {
        "source_type": source_type,
        "original_stats": {"max": round(max_val, 6), "mean": round(mean_before, 6), "std": round(std_before, 6)},
        "num_samples_uploaded": num_samples_uploaded,
        "num_segments": len(segments),
        "per_segment": seg_results,
        "aggregate": {
            "label": majority_label,
            "label_name": CLASS_NAMES[majority_label],
            "votes_for_label": int(majority_count),
            "majority_confidence_mean": round(majority_confidence, 4)
        },
        "resample_unavailable": resample_unavailable
    }
    return response
|