# ecg-backend / app.py
# Author: niol08
# Last change: "Update app.py" (commit 2d09f0a, verified)
# from fastapi import FastAPI, UploadFile, File
# from huggingface_hub import hf_hub_download
# import numpy as np
# from keras.models import load_model
# from graph import zeropad, zeropad_output_shape
# from config import get_config
# app = FastAPI(
# title="ECG Classification Backend",
# description="REST API for ECG heartbeat classification",
# version="1.1.0"
# )
# config = get_config()
# MODEL_PATH = "MLII-latest.keras"
# print("๐Ÿ”น Loading model:", MODEL_PATH)
# model = load_model(
# MODEL_PATH,
# custom_objects={
# "zeropad": zeropad,
# "zeropad_output_shape": zeropad_output_shape
# },
# compile=False
# )
# CLASSES = ["N", "V", "/", "A", "F", "~"]
# CLASS_NAMES = {
# "N": "Normal sinus beat",
# "V": "Premature Ventricular Contraction (PVC)",
# "/": "Paced beat (Pacemaker)",
# "A": "Atrial Premature Beat",
# "F": "Fusion of Ventricular & Normal Beat",
# "~": "Unclassifiable / Noise"
# }
# @app.get("/")
# async def root():
# return {"message": "โœ… ECG Inference API is running successfully!"}
# @app.post("/predict-ecg/")
# async def predict_ecg(file: UploadFile = File(...)):
# """
# Accepts a CSV or TXT file containing ECG signal samples.
# Each value should be a single float per line.
# """
# content = await file.read()
# text = content.decode("utf-8").strip().splitlines()
# try:
# data = np.array([float(x.strip()) for x in text if x.strip() != ""])
# except Exception:
# return {"error": "Invalid file format. Please upload numeric ECG values only."}
# max_len = 256
# if len(data) > max_len:
# data = data[:max_len]
# elif len(data) < max_len:
# data = np.pad(data, (0, max_len - len(data)))
# data = data.reshape(1, max_len, 1)
# preds = model.predict(data, verbose=0)
# label_idx = int(np.argmax(preds))
# confidence = float(np.max(preds))
# label = CLASSES[label_idx]
# description = CLASS_NAMES[label]
# return {
# "label": label,
# "description": description,
# "confidence": round(confidence, 4),
# "samples_used": len(data[0])
# }
from fastapi import FastAPI, UploadFile, File, Query
from typing import Optional, List, Dict, Any
import numpy as np
from keras.models import load_model
from graph import zeropad, zeropad_output_shape
from config import get_config
# Optional import for resampling if user needs it.
# scipy is only needed when the caller supplies original_fs/resample_to query
# params; when it is missing the API degrades gracefully (signal is returned
# unresampled and the response flags resample_unavailable).
try:
    from scipy.signal import resample
    SCIPY_AVAILABLE = True
except Exception:
    # Any import failure (package absent, broken install) disables resampling.
    SCIPY_AVAILABLE = False
# -------------------------
# App & model setup
# -------------------------
app = FastAPI(
    title="ECG Classification Backend",
    description="REST API for ECG heartbeat classification (MIT-BIH MLII Compatible)",
    version="1.3.0"
)

# Project-level configuration (see config.py); loaded once at startup.
config = get_config()

# Keras model file expected in the working directory. The custom layers from
# graph.py must be registered via custom_objects or deserialization fails.
MODEL_PATH = "MLII-latest.keras"
print(f"๐Ÿ”น Loading model: {MODEL_PATH}")
model = load_model(
    MODEL_PATH,
    custom_objects={
        "zeropad": zeropad,
        "zeropad_output_shape": zeropad_output_shape
    },
    compile=False  # inference only; optimizer/loss are not needed
)

# Class definitions
# Index order must match the model's output layer ordering.
CLASSES = ["N", "V", "/", "A", "F", "~"]
# Human-readable descriptions keyed by class code.
CLASS_NAMES = {
    "N": "Normal sinus beat",
    "V": "Premature Ventricular Contraction (PVC)",
    "/": "Paced beat (Pacemaker)",
    "A": "Atrial Premature Beat",
    "F": "Fusion of Ventricular & Normal Beat",
    "~": "Unclassifiable / Noise"
}
# -------------------------
# Helper functions
# -------------------------
def _to_float_array(lines: List[str]) -> np.ndarray:
"""Convert lines of text to a numpy float array, skipping empty lines."""
vals = []
for ln in lines:
s = ln.strip()
if s == "":
continue
try:
vals.append(float(s))
except Exception:
# ignore unparsable lines silently
continue
return np.array(vals, dtype=np.float32)
def _resample_if_requested(signal: np.ndarray, orig_fs: Optional[float], target_fs: Optional[float]) -> np.ndarray:
"""Resample signal to target_fs if both orig_fs and target_fs provided and scipy available."""
if orig_fs is None or target_fs is None:
return signal
if not SCIPY_AVAILABLE:
# can't resample without scipy; return original signal and let user know via metadata
return signal
if orig_fs == target_fs:
return signal
target_len = int(round(len(signal) * (target_fs / orig_fs)))
if target_len < 1:
return signal
return resample(signal, target_len)
def _pad_or_truncate(signal: np.ndarray, length: int) -> np.ndarray:
if len(signal) > length:
return signal[:length].copy()
if len(signal) < length:
return np.pad(signal, (0, length - len(signal)), mode="constant")
return signal
def _zscore(signal: np.ndarray) -> np.ndarray:
mean = np.mean(signal)
std = np.std(signal)
if std < 1e-6:
std = 1.0
return (signal - mean) / std
def _normalize_hardware_adc(signal: np.ndarray, adc_max: float = 4095.0, vref: float = 3.3) -> np.ndarray:
"""
Convert ADC counts to volts, center, and z-score to match MLII amplitude behavior.
This assumes the ADC output is unipolar (0..adc_max). AD8232 output is around midrail;
your firmware already applies HPF/LPF which centers the waveform; if you send raw ADC,
this conversion will center the waveform.
"""
# Convert ADC counts -> Volts
volts = (signal / float(adc_max)) * float(vref)
# Center on zero
volts = volts - np.mean(volts)
# Z-score scale
return _zscore(volts)
def _segment_stream(signal: np.ndarray, window: int = 256, step: int = 128) -> List[np.ndarray]:
"""Make a list of overlapping segments from a longer stream."""
if len(signal) <= window:
return [ _pad_or_truncate(signal, window) ]
segments = []
for start in range(0, len(signal) - window + 1, step):
segments.append(signal[start:start+window].copy())
# if final tail remains and wasn't included, optionally include last window (pad)
if (len(signal) - window) % step != 0 and (len(signal) % step) != 0:
segments.append(_pad_or_truncate(signal[-window:], window))
return segments
def _predict_segments(segments: List[np.ndarray]):
    """Batch-run the Keras model over a list of 1-D segments.

    Stacks the segments into a (num_segments, window, 1) tensor, which is
    the input shape the model expects, and returns the raw prediction
    matrix of shape (num_segments, num_classes).
    """
    batch = np.expand_dims(np.stack(segments, axis=0), axis=-1)
    return model.predict(batch, verbose=0)
# -------------------------
# Routes
# -------------------------
@app.get("/")
async def root():
    """Liveness endpoint: confirms the API process is serving requests."""
    payload = {"message": "ECG Inference API is running successfully!"}
    return payload
@app.post("/predict-ecg/")
async def predict_ecg(
    file: UploadFile = File(...),
    original_fs: Optional[float] = Query(None, description="(optional) sampling rate of the uploaded data in Hz"),
    resample_to: Optional[float] = Query(None, description="(optional) resample uploaded data to this fs (Hz) if provided)")
) -> Dict[str, Any]:
    """
    Accepts a CSV or TXT file containing ECG samples (one float per line).

    Optional query params:
      - original_fs : sampling rate of the provided file (Hz)
      - resample_to : if set and scipy available, resample to this rate

    Processing:
      - Convert lines -> floats
      - Optionally resample
      - Auto-detect hardware ADC vs dataset and normalize accordingly
      - Segment into 256-sample windows (50% overlap) and run inference per-segment
      - Aggregate predictions (majority vote) and return per-segment results + summary

    Fix: a non-UTF-8 (e.g. binary) upload previously raised an unhandled
    UnicodeDecodeError (HTTP 500); it now returns the same JSON error
    shape used for other input problems.
    """
    # Read & parse file.
    content = await file.read()
    try:
        text_lines = content.decode("utf-8").strip().splitlines()
    except UnicodeDecodeError:
        return {"error": "File is not valid UTF-8 text. Please upload a CSV/TXT of numeric ECG values."}
    data = _to_float_array(text_lines)
    if data.size == 0:
        return {"error": "No numeric samples found in uploaded file."}

    # Optional resample (if requested and scipy is importable).
    if original_fs is not None and resample_to is not None and SCIPY_AVAILABLE:
        data = _resample_if_requested(data, orig_fs=original_fs, target_fs=resample_to)
    # If resampling was requested but scipy is missing, flag it in the response.
    resample_unavailable = (original_fs is not None and resample_to is not None and not SCIPY_AVAILABLE)

    # Detect whether this is ADC-like hardware input or dataset (MIT-BIH style).
    # Heuristic: ADC counts are usually large integers (>> 100); MIT-BIH signals
    # are small floats (~ -2 .. +2).
    max_val = float(np.max(np.abs(data)))
    mean_before = float(np.mean(data))
    std_before = float(np.std(data))
    is_adc_like = max_val > 100.0  # heuristic threshold; adjust if needed

    if is_adc_like:
        # ADC counts -> volts -> centered z-score.
        norm_data = _normalize_hardware_adc(data)
        source_type = "hardware_adc"
    else:
        # Dataset-like (MLII): z-score to match model training preprocessing.
        norm_data = _zscore(data)
        source_type = "dataset_like"

    # Segment into windows (256 samples, 50% overlap) and run inference.
    window = 256
    step = window // 2
    segments = _segment_stream(norm_data, window=window, step=step)
    preds = _predict_segments(segments)  # (num_segments, num_classes)
    pred_labels_idx = np.argmax(preds, axis=1)
    pred_confidences = np.max(preds, axis=1)

    # Per-segment report. NOTE: "mean_before"/"std_before" here are stats of
    # the already-normalized segment; key names kept for backward compatibility.
    seg_results = []
    for i, (lab_idx, conf) in enumerate(zip(pred_labels_idx, pred_confidences)):
        code = CLASSES[int(lab_idx)]
        seg_results.append({
            "segment": i,
            "label": code,
            "label_name": CLASS_NAMES[code],
            "confidence": float(round(float(conf), 4)),
            "mean_before": float(round(float(np.mean(segments[i])), 6)),
            "std_before": float(round(float(np.std(segments[i])), 6))
        })

    # Aggregate: majority vote over segment labels; summary confidence is the
    # mean confidence of the segments that voted for the winning class.
    from collections import Counter
    votes = [CLASSES[int(i)] for i in pred_labels_idx]
    vote_counts = Counter(votes)
    majority_label, majority_count = vote_counts.most_common(1)[0]
    majority_confidence = float(np.mean(
        [c for c, lab in zip(pred_confidences, votes) if lab == majority_label]
    ))

    return {
        "source_type": source_type,
        "original_stats": {"max": round(max_val, 6), "mean": round(mean_before, 6), "std": round(std_before, 6)},
        "num_samples_uploaded": int(len(data)),
        "num_segments": len(segments),
        "per_segment": seg_results,
        "aggregate": {
            "label": majority_label,
            "label_name": CLASS_NAMES[majority_label],
            "votes_for_label": int(majority_count),
            "majority_confidence_mean": round(majority_confidence, 4)
        },
        "resample_unavailable": resample_unavailable
    }