Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| import neurokit2 as nk | |
# Custom activation functions required by the model
def sin_activation(x):
    """Return the element-wise sine of ``x``.

    Supplied to Keras under the custom-object key ``'sin'`` when the saved
    model is loaded.
    """
    activated = tf.math.sin(x)
    return activated
def cos_activation(x):
    """Return the element-wise cosine of ``x``.

    Supplied to Keras under the custom-object key ``'cos'`` when the saved
    model is loaded.
    """
    activated = tf.math.cos(x)
    return activated
# Load model with custom objects
@st.cache_resource
def load_model():
    """Load the trained Keras model from ``model.keras``.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the loaded model as a shared resource, so it
    is read from disk once instead of on every rerun.

    ``custom_objects`` maps the names the saved model refers to onto the
    callables that implement them. NOTE(review): the keys ``'sin'``, ``'cos'``
    and ``'gelu'`` are presumably the names used when the model was saved;
    confirm against the training code.
    """
    return tf.keras.models.load_model(
        "model.keras",
        custom_objects={
            'sin': sin_activation,
            'cos': cos_activation,
            'gelu': tf.keras.activations.gelu,
        },
    )


model = load_model()
# AAMI beat classes keyed by the model's output index. The order must match
# the label encoding used by the training code.
CLASS_MAP = dict(enumerate((
    "Normal",
    "Supraventricular Ectopic (SVEB)",
    "Ventricular Ectopic (VEB)",
    "Fusion Beat",
    "Unknown",
)))
def process_mitbih_file(dat_file, *, sampling_rate=360, gain=200.0, window_size=257):
    """Extract R-peak-centred beat windows from an uploaded MIT-BIH ``.dat`` file.

    The raw bytes are read as little-endian signed 16-bit samples (a
    single-channel "format 16" layout). They are divided by ``gain`` to
    convert to mV and passed to NeuroKit2 for R-peak detection. For every
    detected R-peak, a window of ``window_size`` samples centred on that peak
    is cut from the scaled signal.

    Peaks too close to either end of the record for a full centred window are
    skipped. As a result, every returned beat has its R-peak at index
    ``window_size // 2``.

    NOTE(review): MIT-BIH Arrhythmia Database records are usually stored in
    WFDB format 212 with two interleaved channels. Reading such a file as
    plain int16 does not recover the ECG. Confirm that uploads really are
    single-channel format 16, or parse the accompanying ``.hea`` header.

    Args:
        dat_file: Uploaded file object providing ``getbuffer()`` (e.g. a
            Streamlit ``UploadedFile``).
        sampling_rate: Sampling rate in Hz assumed for R-peak detection.
        gain: ADC units per millivolt used to scale the raw samples.
        window_size: Number of samples in each extracted beat.

    Returns:
        On success, a float32 array of shape ``(n_beats, window_size)``; it
        has zero rows when no full window fits. On failure, ``None``, after
        the error has been shown via ``st.error``.
    """
    try:
        # "<i2" pins the little-endian int16 byte order of WFDB format 16
        # regardless of host byte order. astype() copies, so the in-place
        # division below never touches the read-only buffer view.
        signal = np.frombuffer(dat_file.getbuffer(), dtype="<i2").astype(np.float32)
        signal /= gain

        _, info = nk.ecg_process(signal, sampling_rate=sampling_rate)
        r_peaks = info["ECG_R_Peaks"]

        half = window_size // 2
        beats = []
        for r in r_peaks:
            start = int(r) - half
            end = start + window_size
            # Skip peaks whose centred window would run off either end.
            # Clamping the start to 0 instead would shift the R-peak away
            # from the centre and give the model a misaligned beat.
            if start >= 0 and end <= len(signal):
                beats.append(signal[start:end])
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the app.
        st.error(f"File processing error: {str(e)}")
        return None

    if not beats:
        return np.empty((0, window_size), dtype=np.float32)
    return np.stack(beats)
# Streamlit UI
st.title("ECG Arrhythmia Detection")
st.write("Upload MIT-BIH .dat file")
uploaded_file = st.file_uploader(
    "Select .dat file",
    type=["dat"],
    accept_multiple_files=False,
)
if uploaded_file is not None:
    if st.button("Analyze"):
        with st.spinner("Processing ECG signal..."):
            beats = process_mitbih_file(uploaded_file)
        if beats is not None and len(beats) > 0:
            # Model input: (n_beats, 257, 1) float32, i.e. one channel per beat window.
            beats = beats.reshape((-1, 257, 1)).astype(np.float32)
            predictions = model.predict(beats)
            # The predicted class is the argmax of each row of scores; that
            # row's maximum score is reported as the confidence.
            pred_classes = np.argmax(predictions, axis=1)
            confidences = np.max(predictions, axis=1)

            st.subheader("Analysis Results")
            results = pd.DataFrame({
                "Beat Index": range(len(beats)),
                "Predicted Class": [CLASS_MAP[c] for c in pred_classes],
                "Confidence": [f"{p:.1%}" for p in confidences],
            })
            st.dataframe(results)

            # Left panel: the first extracted beat. Right panel: count of
            # beats per predicted class.
            st.subheader("ECG Signal")
            fig, (beat_ax, dist_ax) = plt.subplots(1, 2, figsize=(15, 4))
            beat_ax.plot(beats[0].flatten())
            beat_ax.set_title("Sample ECG Beat")
            class_dist = results["Predicted Class"].value_counts()
            dist_ax.bar(class_dist.index, class_dist.values)
            dist_ax.set_title("Class Distribution")
            dist_ax.tick_params(axis='x', rotation=45)
            # Keep the long rotated class labels inside the figure.
            fig.tight_layout()
            st.pyplot(fig)
            # pyplot keeps every figure alive until it is closed. This script
            # re-runs on each interaction in a long-lived server, so close it
            # explicitly to avoid leaking one figure per analysis.
            plt.close(fig)
        else:
            st.error("Failed to extract valid beats from the signal")