# NOTE(review): the original paste began with Hugging Face Spaces UI residue
# ("Spaces:", "Sleeping", "Sleeping") — status-page text, not Python source.
# Preserved here as a comment so the file parses.
| # from fastapi import FastAPI, UploadFile, File | |
| # from huggingface_hub import hf_hub_download | |
| # import numpy as np | |
| # from keras.models import load_model | |
| # from graph import zeropad, zeropad_output_shape | |
| # from config import get_config | |
| # app = FastAPI( | |
| # title="ECG Classification Backend", | |
| # description="REST API for ECG heartbeat classification", | |
| # version="1.1.0" | |
| # ) | |
| # config = get_config() | |
| # MODEL_PATH = "MLII-latest.keras" | |
| # print("๐น Loading model:", MODEL_PATH) | |
| # model = load_model( | |
| # MODEL_PATH, | |
| # custom_objects={ | |
| # "zeropad": zeropad, | |
| # "zeropad_output_shape": zeropad_output_shape | |
| # }, | |
| # compile=False | |
| # ) | |
| # CLASSES = ["N", "V", "/", "A", "F", "~"] | |
| # CLASS_NAMES = { | |
| # "N": "Normal sinus beat", | |
| # "V": "Premature Ventricular Contraction (PVC)", | |
| # "/": "Paced beat (Pacemaker)", | |
| # "A": "Atrial Premature Beat", | |
| # "F": "Fusion of Ventricular & Normal Beat", | |
| # "~": "Unclassifiable / Noise" | |
| # } | |
| # @app.get("/") | |
| # async def root(): | |
| # return {"message": "โ ECG Inference API is running successfully!"} | |
| # @app.post("/predict-ecg/") | |
| # async def predict_ecg(file: UploadFile = File(...)): | |
| # """ | |
| # Accepts a CSV or TXT file containing ECG signal samples. | |
| # Each value should be a single float per line. | |
| # """ | |
| # content = await file.read() | |
| # text = content.decode("utf-8").strip().splitlines() | |
| # try: | |
| # data = np.array([float(x.strip()) for x in text if x.strip() != ""]) | |
| # except Exception: | |
| # return {"error": "Invalid file format. Please upload numeric ECG values only."} | |
| # max_len = 256 | |
| # if len(data) > max_len: | |
| # data = data[:max_len] | |
| # elif len(data) < max_len: | |
| # data = np.pad(data, (0, max_len - len(data))) | |
| # data = data.reshape(1, max_len, 1) | |
| # preds = model.predict(data, verbose=0) | |
| # label_idx = int(np.argmax(preds)) | |
| # confidence = float(np.max(preds)) | |
| # label = CLASSES[label_idx] | |
| # description = CLASS_NAMES[label] | |
| # return { | |
| # "label": label, | |
| # "description": description, | |
| # "confidence": round(confidence, 4), | |
| # "samples_used": len(data[0]) | |
| # } | |
| from fastapi import FastAPI, UploadFile, File, Query | |
| from typing import Optional, List, Dict, Any | |
| import numpy as np | |
| from keras.models import load_model | |
| from graph import zeropad, zeropad_output_shape | |
| from config import get_config | |
# Resampling support is optional: scipy is only needed when a client asks us
# to change the sampling rate of an uploaded signal.
SCIPY_AVAILABLE = True
try:
    from scipy.signal import resample
except Exception:
    SCIPY_AVAILABLE = False
# -------------------------
# App & model setup
# -------------------------
app = FastAPI(
    title="ECG Classification Backend",
    description="REST API for ECG heartbeat classification (MIT-BIH MLII Compatible)",
    version="1.3.0"
)

# Project-level configuration from config.py (contents not visible in this file).
config = get_config()

# Keras model trained on MIT-BIH MLII beats. The custom zeropad layer helpers
# from graph.py are required to deserialize the saved architecture;
# compile=False because we only run inference, never training.
MODEL_PATH = "MLII-latest.keras"
print(f"๐น Loading model: {MODEL_PATH}")
model = load_model(
    MODEL_PATH,
    custom_objects={
        "zeropad": zeropad,
        "zeropad_output_shape": zeropad_output_shape
    },
    compile=False
)

# Class definitions: model output index i corresponds to CLASSES[i];
# CLASS_NAMES maps each short code to a human-readable description.
CLASSES = ["N", "V", "/", "A", "F", "~"]
CLASS_NAMES = {
    "N": "Normal sinus beat",
    "V": "Premature Ventricular Contraction (PVC)",
    "/": "Paced beat (Pacemaker)",
    "A": "Atrial Premature Beat",
    "F": "Fusion of Ventricular & Normal Beat",
    "~": "Unclassifiable / Noise"
}
| # ------------------------- | |
| # Helper functions | |
| # ------------------------- | |
| def _to_float_array(lines: List[str]) -> np.ndarray: | |
| """Convert lines of text to a numpy float array, skipping empty lines.""" | |
| vals = [] | |
| for ln in lines: | |
| s = ln.strip() | |
| if s == "": | |
| continue | |
| try: | |
| vals.append(float(s)) | |
| except Exception: | |
| # ignore unparsable lines silently | |
| continue | |
| return np.array(vals, dtype=np.float32) | |
| def _resample_if_requested(signal: np.ndarray, orig_fs: Optional[float], target_fs: Optional[float]) -> np.ndarray: | |
| """Resample signal to target_fs if both orig_fs and target_fs provided and scipy available.""" | |
| if orig_fs is None or target_fs is None: | |
| return signal | |
| if not SCIPY_AVAILABLE: | |
| # can't resample without scipy; return original signal and let user know via metadata | |
| return signal | |
| if orig_fs == target_fs: | |
| return signal | |
| target_len = int(round(len(signal) * (target_fs / orig_fs))) | |
| if target_len < 1: | |
| return signal | |
| return resample(signal, target_len) | |
| def _pad_or_truncate(signal: np.ndarray, length: int) -> np.ndarray: | |
| if len(signal) > length: | |
| return signal[:length].copy() | |
| if len(signal) < length: | |
| return np.pad(signal, (0, length - len(signal)), mode="constant") | |
| return signal | |
| def _zscore(signal: np.ndarray) -> np.ndarray: | |
| mean = np.mean(signal) | |
| std = np.std(signal) | |
| if std < 1e-6: | |
| std = 1.0 | |
| return (signal - mean) / std | |
| def _normalize_hardware_adc(signal: np.ndarray, adc_max: float = 4095.0, vref: float = 3.3) -> np.ndarray: | |
| """ | |
| Convert ADC counts to volts, center, and z-score to match MLII amplitude behavior. | |
| This assumes the ADC output is unipolar (0..adc_max). AD8232 output is around midrail; | |
| your firmware already applies HPF/LPF which centers the waveform; if you send raw ADC, | |
| this conversion will center the waveform. | |
| """ | |
| # Convert ADC counts -> Volts | |
| volts = (signal / float(adc_max)) * float(vref) | |
| # Center on zero | |
| volts = volts - np.mean(volts) | |
| # Z-score scale | |
| return _zscore(volts) | |
| def _segment_stream(signal: np.ndarray, window: int = 256, step: int = 128) -> List[np.ndarray]: | |
| """Make a list of overlapping segments from a longer stream.""" | |
| if len(signal) <= window: | |
| return [ _pad_or_truncate(signal, window) ] | |
| segments = [] | |
| for start in range(0, len(signal) - window + 1, step): | |
| segments.append(signal[start:start+window].copy()) | |
| # if final tail remains and wasn't included, optionally include last window (pad) | |
| if (len(signal) - window) % step != 0 and (len(signal) % step) != 0: | |
| segments.append(_pad_or_truncate(signal[-window:], window)) | |
| return segments | |
def _predict_segments(segments: List[np.ndarray]):
    """Batch the 1-D segments into shape (N, window, 1) and run the model.

    Returns the model's prediction matrix of shape (num_segments, num_classes).
    """
    batch = np.stack(segments, axis=0)[..., np.newaxis]
    return model.predict(batch, verbose=0)
| # ------------------------- | |
| # Routes | |
| # ------------------------- | |
@app.get("/")
async def root():
    """Health-check endpoint returning a simple status message.

    Bug fix: the ``@app.get("/")`` decorator (present in the commented-out
    earlier revision of this file) was missing, so the route was never
    registered with FastAPI and GET / returned 404.
    """
    return {"message": "ECG Inference API is running successfully!"}
@app.post("/predict-ecg/")
async def predict_ecg(
    file: UploadFile = File(...),
    original_fs: Optional[float] = Query(None, description="(optional) sampling rate of the uploaded data in Hz"),
    resample_to: Optional[float] = Query(None, description="(optional) resample uploaded data to this fs (Hz) if provided)")
) -> Dict[str, Any]:
    """
    Accepts a CSV or TXT file containing ECG samples (one float per line).

    Optional query params:
      - original_fs : sampling rate of the provided file (Hz)
      - resample_to : if set and scipy available, resample to this rate

    Processing:
      - Convert lines -> floats
      - Optionally resample
      - Auto-detect hardware ADC vs dataset input and normalize accordingly
      - Segment into 256-sample windows (50% overlap), run inference per segment
      - Aggregate predictions (majority vote); return per-segment results + summary

    Bug fix: the ``@app.post("/predict-ecg/")`` decorator (present in the
    commented-out earlier revision) was missing, so this endpoint was never
    registered with FastAPI. Also, non-UTF-8 uploads now return the same
    ``{"error": ...}`` shape used for other client mistakes instead of
    raising an unhandled UnicodeDecodeError (HTTP 500).
    """
    from collections import Counter

    # ---- Read & parse the uploaded file -------------------------------
    content = await file.read()
    try:
        text_lines = content.decode("utf-8").strip().splitlines()
    except UnicodeDecodeError:
        return {"error": "File is not valid UTF-8 text."}
    data = _to_float_array(text_lines)
    if data.size == 0:
        return {"error": "No numeric samples found in uploaded file."}

    # ---- Optional resampling ------------------------------------------
    if original_fs is not None and resample_to is not None and SCIPY_AVAILABLE:
        data = _resample_if_requested(data, orig_fs=original_fs, target_fs=resample_to)
    # Surface "resampling was requested but scipy is missing" in the response.
    resample_unavailable = (
        original_fs is not None and resample_to is not None and not SCIPY_AVAILABLE
    )

    # ---- Input-source detection & normalization -----------------------
    # Heuristic: ADC counts are usually large integers (>> 100) while
    # MIT-BIH MLII samples are small floats (~ -2 .. +2), so a max absolute
    # value above 100 is treated as hardware ADC input.
    max_val = float(np.max(np.abs(data)))
    mean_before = float(np.mean(data))
    std_before = float(np.std(data))
    is_adc_like = max_val > 100.0  # heuristic threshold; adjust if needed
    if is_adc_like:
        norm_data = _normalize_hardware_adc(data)
        source_type = "hardware_adc"
    else:
        # Dataset-like (MLII): z-score to match the model's training prep.
        norm_data = _zscore(data)
        source_type = "dataset_like"

    # ---- Segment & predict --------------------------------------------
    window = 256
    step = window // 2  # 50% overlap
    segments = _segment_stream(norm_data, window=window, step=step)
    preds = _predict_segments(segments)  # (num_segments, num_classes)
    pred_labels_idx = np.argmax(preds, axis=1)
    pred_confidences = np.max(preds, axis=1)

    # ---- Per-segment results ------------------------------------------
    seg_results = []
    for i, (lab_idx, conf) in enumerate(zip(pred_labels_idx, pred_confidences)):
        code = CLASSES[int(lab_idx)]
        seg_results.append({
            "segment": i,
            "label": code,
            "label_name": CLASS_NAMES[code],
            "confidence": float(round(float(conf), 4)),
            "mean_before": float(round(float(np.mean(segments[i])), 6)),
            "std_before": float(round(float(np.std(segments[i])), 6))
        })

    # ---- Aggregate via majority vote ----------------------------------
    votes = [CLASSES[int(i)] for i in pred_labels_idx]
    vote_counts = Counter(votes)
    majority_label, majority_count = vote_counts.most_common(1)[0]
    # Mean confidence over only the segments that voted for the winner.
    majority_confidence = float(np.mean(
        [c for c, lab in zip(pred_confidences, votes) if lab == majority_label]
    ))

    return {
        "source_type": source_type,
        "original_stats": {"max": round(max_val, 6), "mean": round(mean_before, 6), "std": round(std_before, 6)},
        "num_samples_uploaded": int(len(data)),
        "num_segments": len(segments),
        "per_segment": seg_results,
        "aggregate": {
            "label": majority_label,
            "label_name": CLASS_NAMES[majority_label],
            "votes_for_label": int(majority_count),
            "majority_confidence_mean": round(majority_confidence, 4)
        },
        "resample_unavailable": resample_unavailable
    }