"""Streamlit app: classify heartbeats in an uploaded ECG ``.dat`` file.

The uploaded signal is decoded to millivolts, R-peaks are located with
NeuroKit2, a fixed-width window is cut around each R-peak, and every window
is classified by the pre-trained Keras model stored in ``model.keras``.
"""

import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import neurokit2 as nk

# Signal / windowing parameters. Per the original code these must match the
# values used at training time (the training code is not visible here).
SAMPLING_RATE = 360  # Hz, assumed for every uploaded file
ADC_GAIN = 200.0  # ADC units per mV, assumed for every uploaded file
WINDOW_SIZE = 257  # samples per beat window == model input length


# Custom activation functions required by the model
def sin_activation(x):
    """Elementwise sine, registered under the name ``'sin'`` when loading."""
    return tf.math.sin(x)


def cos_activation(x):
    """Elementwise cosine, registered under the name ``'cos'`` when loading."""
    return tf.math.cos(x)


@st.cache_resource
def load_model():
    """Load the Keras model; ``st.cache_resource`` keeps it across reruns.

    ``custom_objects`` maps the activation names stored in the saved model
    to their implementations so deserialization succeeds.
    """
    return tf.keras.models.load_model(
        "model.keras",
        custom_objects={
            'sin': sin_activation,
            'cos': cos_activation,
            'gelu': tf.keras.activations.gelu,
        },
    )


model = load_model()

# AAMI class mapping matching training code
CLASS_MAP = {
    0: "Normal",
    1: "Supraventricular Ectopic (SVEB)",
    2: "Ventricular Ectopic (VEB)",
    3: "Fusion Beat",
    4: "Unknown",
}


def _extract_beats(signal, r_peaks, window_size=WINDOW_SIZE):
    """Cut one ``window_size``-sample window centred on each R-peak.

    A peak is skipped when a full centred window does not fit inside the
    signal (at either end), so the R-peak is always at index
    ``window_size // 2`` of its window. Returns an array of shape
    ``(n_beats, window_size)``; ``n_beats`` may be 0.
    """
    half = window_size // 2
    beats = [
        signal[r - half : r - half + window_size]
        for r in r_peaks
        if r - half >= 0 and r - half + window_size <= len(signal)
    ]
    if not beats:
        return np.empty((0, window_size), dtype=signal.dtype)
    return np.stack(beats)


def process_mitbih_file(dat_file):
    """Decode an uploaded ``.dat`` file and return its beat windows.

    Assumes a single-channel WFDB "format 16" record: little-endian signed
    16-bit samples at ``SAMPLING_RATE`` Hz, gain ``ADC_GAIN`` and baseline 0.

    NOTE(review): records of the PhysioNet MIT-BIH Arrhythmia Database are,
    to my knowledge, stored in WFDB format 212 (12-bit packed, two
    interleaved channels, non-zero baseline). This decoder does NOT handle
    that layout, so such files would decode incorrectly. Confirm the format
    against the record's ``.hea`` header.

    Returns an array of shape ``(n_beats, WINDOW_SIZE)`` (possibly empty),
    or ``None`` after reporting the problem in the UI with ``st.error``.
    """
    try:
        raw = dat_file.getvalue()
        if not raw:
            st.error("File processing error: the uploaded file is empty")
            return None
        if len(raw) % 2:
            st.error(
                "File processing error: file size is not a whole number "
                "of 16-bit samples"
            )
            return None

        # Explicit little-endian int16 so decoding does not depend on the
        # host byte order; astype() copies, so the in-place divide is safe.
        signal = np.frombuffer(raw, dtype="<i2").astype(np.float32)
        signal /= ADC_GAIN  # ADC units -> mV

        _, info = nk.ecg_process(signal, sampling_rate=SAMPLING_RATE)
        r_peaks = np.asarray(info["ECG_R_Peaks"], dtype=np.int64)

        # Beats are cut from the decoded (uncleaned) signal, as before.
        return _extract_beats(signal, r_peaks)
    except Exception as e:  # UI boundary: report any failure to the user
        st.error(f"File processing error: {str(e)}")
        return None


# Streamlit UI
st.title("ECG Arrhythmia Detection")
st.write("Upload MIT-BIH .dat file")

uploaded_file = st.file_uploader(
    "Select .dat file",
    type=["dat"],
    accept_multiple_files=False,
)

# The button is only rendered once a file has been uploaded.
if uploaded_file is not None and st.button("Analyze"):
    with st.spinner("Processing ECG signal..."):
        beats = process_mitbih_file(uploaded_file)

    if beats is None:
        # process_mitbih_file has already shown the specific error.
        pass
    elif len(beats) == 0:
        st.error("Failed to extract valid beats from the signal")
    else:
        # Prepare data for model: (n_beats, WINDOW_SIZE, 1)
        model_input = beats.reshape((-1, WINDOW_SIZE, 1)).astype(np.float32)

        # Make predictions
        predictions = model.predict(model_input)
        pred_classes = np.argmax(predictions, axis=1)

        # Show results
        st.subheader("Analysis Results")
        results = pd.DataFrame({
            "Beat Index": range(len(model_input)),
            # .get guards against a model output index missing from CLASS_MAP.
            "Predicted Class": [
                CLASS_MAP.get(int(c), f"Class {int(c)}") for c in pred_classes
            ],
            "Confidence": [f"{p:.1%}" for p in predictions.max(axis=1)],
        })
        st.dataframe(results)

        # Visualizations
        st.subheader("ECG Signal")
        fig, ax = plt.subplots(1, 2, figsize=(15, 4))
        ax[0].plot(model_input[0].ravel())
        ax[0].set_title("Sample ECG Beat")

        class_dist = results["Predicted Class"].value_counts()
        ax[1].bar(class_dist.index, class_dist.values)
        ax[1].set_title("Class Distribution")
        ax[1].tick_params(axis='x', rotation=45)

        fig.tight_layout()  # keep rotated tick labels inside the figure
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed; without this each
        # analysis in the long-running Streamlit server leaks a figure.
        plt.close(fig)