# from fastapi import FastAPI, UploadFile, File # from huggingface_hub import hf_hub_download # import numpy as np # from keras.models import load_model # from graph import zeropad, zeropad_output_shape # from config import get_config # app = FastAPI( # title="ECG Classification Backend", # description="REST API for ECG heartbeat classification", # version="1.1.0" # ) # config = get_config() # MODEL_PATH = "MLII-latest.keras" # print("🔹 Loading model:", MODEL_PATH) # model = load_model( # MODEL_PATH, # custom_objects={ # "zeropad": zeropad, # "zeropad_output_shape": zeropad_output_shape # }, # compile=False # ) # CLASSES = ["N", "V", "/", "A", "F", "~"] # CLASS_NAMES = { # "N": "Normal sinus beat", # "V": "Premature Ventricular Contraction (PVC)", # "/": "Paced beat (Pacemaker)", # "A": "Atrial Premature Beat", # "F": "Fusion of Ventricular & Normal Beat", # "~": "Unclassifiable / Noise" # } # @app.get("/") # async def root(): # return {"message": "✅ ECG Inference API is running successfully!"} # @app.post("/predict-ecg/") # async def predict_ecg(file: UploadFile = File(...)): # """ # Accepts a CSV or TXT file containing ECG signal samples. # Each value should be a single float per line. # """ # content = await file.read() # text = content.decode("utf-8").strip().splitlines() # try: # data = np.array([float(x.strip()) for x in text if x.strip() != ""]) # except Exception: # return {"error": "Invalid file format. 
"""FastAPI backend for ECG heartbeat classification (MIT-BIH MLII compatible).

Accepts CSV/TXT uploads of raw ECG samples (one float per line), optionally
resamples them, normalizes them to match the model's training preprocessing,
segments the stream into 256-sample windows (50% overlap), and returns
per-segment predictions plus a majority-vote aggregate.
"""

from collections import Counter
from typing import Any, Dict, List, Optional

import numpy as np
from fastapi import FastAPI, File, Query, UploadFile
from keras.models import load_model

from config import get_config
from graph import zeropad, zeropad_output_shape

# Optional dependency: scipy is only required when the caller asks for
# resampling; the service degrades gracefully without it.
try:
    from scipy.signal import resample
    SCIPY_AVAILABLE = True
except Exception:
    SCIPY_AVAILABLE = False

# -------------------------
# App & model setup
# -------------------------
app = FastAPI(
    title="ECG Classification Backend",
    description="REST API for ECG heartbeat classification (MIT-BIH MLII Compatible)",
    version="1.3.0"
)

config = get_config()

MODEL_PATH = "MLII-latest.keras"
print(f"🔹 Loading model: {MODEL_PATH}")

# compile=False: inference only, no optimizer state needed.
model = load_model(
    MODEL_PATH,
    custom_objects={
        "zeropad": zeropad,
        "zeropad_output_shape": zeropad_output_shape
    },
    compile=False
)

# Class codes, index-aligned with the model's softmax output.
CLASSES = ["N", "V", "/", "A", "F", "~"]
CLASS_NAMES = {
    "N": "Normal sinus beat",
    "V": "Premature Ventricular Contraction (PVC)",
    "/": "Paced beat (Pacemaker)",
    "A": "Atrial Premature Beat",
    "F": "Fusion of Ventricular & Normal Beat",
    "~": "Unclassifiable / Noise"
}

# -------------------------
# Helper functions
# -------------------------


def _to_float_array(lines: List[str]) -> np.ndarray:
    """Convert lines of text to a float32 array, skipping blank lines.

    Unparsable lines are dropped silently (by design) so a stray header
    row or trailing junk does not abort the whole request.
    """
    vals = []
    for ln in lines:
        s = ln.strip()
        if not s:
            continue
        try:
            vals.append(float(s))
        except ValueError:
            # Ignore non-numeric lines (e.g. CSV headers).
            continue
    return np.array(vals, dtype=np.float32)


def _resample_if_requested(signal: np.ndarray,
                           orig_fs: Optional[float],
                           target_fs: Optional[float]) -> np.ndarray:
    """Resample *signal* to target_fs when both rates are given and scipy is available.

    Returns the signal unchanged when resampling is not requested, not
    possible (no scipy), not needed (rates equal), or would produce an
    empty result.
    """
    if orig_fs is None or target_fs is None:
        return signal
    if not SCIPY_AVAILABLE:
        # Can't resample without scipy; caller is notified via response metadata.
        return signal
    if orig_fs == target_fs:
        return signal
    target_len = int(round(len(signal) * (target_fs / orig_fs)))
    if target_len < 1:
        return signal
    return resample(signal, target_len)


def _pad_or_truncate(signal: np.ndarray, length: int) -> np.ndarray:
    """Return *signal* cut or zero-padded (at the end) to exactly *length* samples."""
    if len(signal) > length:
        return signal[:length].copy()
    if len(signal) < length:
        return np.pad(signal, (0, length - len(signal)), mode="constant")
    return signal


def _zscore(signal: np.ndarray) -> np.ndarray:
    """Z-score normalize; a near-zero std falls back to 1.0 to avoid division blow-up."""
    mean = np.mean(signal)
    std = np.std(signal)
    if std < 1e-6:
        std = 1.0
    return (signal - mean) / std


def _normalize_hardware_adc(signal: np.ndarray,
                            adc_max: float = 4095.0,
                            vref: float = 3.3) -> np.ndarray:
    """Convert ADC counts to volts, center, and z-score to match MLII amplitude behavior.

    Assumes a unipolar ADC output (0..adc_max). AD8232 output sits around
    mid-rail; firmware-side HPF/LPF already centers the waveform, and raw
    ADC input is centered here by subtracting the mean.
    """
    # ADC counts -> Volts.
    volts = (signal / float(adc_max)) * float(vref)
    # Center on zero.
    volts = volts - np.mean(volts)
    # Z-score scale to match the model's training preprocessing.
    return _zscore(volts)


def _segment_stream(signal: np.ndarray,
                    window: int = 256,
                    step: int = 128) -> List[np.ndarray]:
    """Split a stream into overlapping *window*-sample segments.

    A signal shorter than one window is zero-padded into a single segment.
    When the stride does not land exactly on the end of the signal, one
    extra segment covering the final *window* samples is appended so the
    tail is never dropped.
    """
    if len(signal) <= window:
        return [_pad_or_truncate(signal, window)]
    segments = []
    for start in range(0, len(signal) - window + 1, step):
        segments.append(signal[start:start + window].copy())
    # Tail not reached by the stride: add a last window ending at the signal end.
    # (signal[-window:] is already exactly `window` samples; no padding needed.)
    if (len(signal) - window) % step != 0:
        segments.append(signal[-window:].copy())
    return segments


def _predict_segments(segments: List[np.ndarray]):
    """Run model.predict on a list of 1D segments -> (num_segments, num_classes)."""
    X = np.stack(segments, axis=0)   # (N, window)
    X = X[..., np.newaxis]           # (N, window, 1)
    return model.predict(X, verbose=0)


# -------------------------
# Routes
# -------------------------


@app.get("/")
async def root():
    return {"message": "ECG Inference API is running successfully!"}


@app.post("/predict-ecg/")
async def predict_ecg(
    file: UploadFile = File(...),
    original_fs: Optional[float] = Query(
        None, description="(optional) sampling rate of the uploaded data in Hz"),
    resample_to: Optional[float] = Query(
        None, description="(optional) resample uploaded data to this fs (Hz) if provided")
) -> Dict[str, Any]:
    """Classify an uploaded ECG recording.

    Accepts a CSV or TXT file containing ECG samples (one float per line).

    Query params:
        original_fs: sampling rate of the provided file (Hz).
        resample_to: if set and scipy is available, resample to this rate.

    Processing:
        1. Parse lines -> float32 array (blank/non-numeric lines skipped).
        2. Optionally resample.
        3. Auto-detect hardware ADC vs dataset input and normalize accordingly.
        4. Segment into 256-sample windows (50% overlap), infer per segment.
        5. Aggregate via majority vote; return per-segment results + summary.
    """
    # Read & parse the upload.
    content = await file.read()
    try:
        text_lines = content.decode("utf-8").strip().splitlines()
    except UnicodeDecodeError:
        return {"error": "File is not valid UTF-8 text. Please upload a plain text/CSV file."}
    data = _to_float_array(text_lines)

    if data.size == 0:
        return {"error": "No numeric samples found in uploaded file."}

    # Optional resample (only when both rates given and scipy is present).
    if original_fs is not None and resample_to is not None and SCIPY_AVAILABLE:
        data = _resample_if_requested(data, orig_fs=original_fs, target_fs=resample_to)

    # If resampling was requested but scipy is missing, surface that in metadata.
    resample_unavailable = (original_fs is not None
                            and resample_to is not None
                            and not SCIPY_AVAILABLE)

    # Heuristic source detection: ADC counts are large integers (>> 100),
    # while MIT-BIH MLII signals are small floats (~ -2 .. +2).
    max_val = float(np.max(np.abs(data)))
    mean_before = float(np.mean(data))
    std_before = float(np.std(data))
    is_adc_like = max_val > 100.0  # heuristic threshold; adjust if needed

    if is_adc_like:
        # Hardware input: convert to volts, center, z-score.
        norm_data = _normalize_hardware_adc(data)
        source_type = "hardware_adc"
    else:
        # Dataset-like (MLII): z-score to match model training preprocessing.
        norm_data = _zscore(data)
        source_type = "dataset_like"

    # Segment into 256-sample windows with 50% overlap and run inference.
    window = 256
    step = window // 2
    segments = _segment_stream(norm_data, window=window, step=step)
    preds = _predict_segments(segments)          # (num_segments, num_classes)
    pred_labels_idx = np.argmax(preds, axis=1)
    pred_confidences = np.max(preds, axis=1)

    # Per-segment results. NOTE: "mean_before"/"std_before" here are stats of
    # the *normalized* segment (key names kept for response compatibility).
    seg_results = []
    for i, (lab_idx, conf) in enumerate(zip(pred_labels_idx, pred_confidences)):
        code = CLASSES[int(lab_idx)]
        seg_results.append({
            "segment": i,
            "label": code,
            "label_name": CLASS_NAMES[code],
            "confidence": round(float(conf), 4),
            "mean_before": float(round(float(np.mean(segments[i])), 6)),
            "std_before": float(round(float(np.std(segments[i])), 6))
        })

    # Aggregate: majority vote over segment labels, with mean confidence of
    # the winning class across its voting segments.
    votes = [CLASSES[int(i)] for i in pred_labels_idx]
    vote_counts = Counter(votes)
    majority_label, majority_count = vote_counts.most_common(1)[0]
    majority_confidence = float(np.mean(
        [c for c, lab in zip(pred_confidences, votes) if lab == majority_label]))

    return {
        "source_type": source_type,
        "original_stats": {"max": round(max_val, 6),
                           "mean": round(mean_before, 6),
                           "std": round(std_before, 6)},
        # Count is post-resampling when resampling was applied.
        "num_samples_uploaded": int(len(data)),
        "num_segments": len(segments),
        "per_segment": seg_results,
        "aggregate": {
            "label": majority_label,
            "label_name": CLASS_NAMES[majority_label],
            "votes_for_label": int(majority_count),
            "majority_confidence_mean": round(majority_confidence, 4)
        },
        "resample_unavailable": resample_unavailable
    }