# ecg / app.py
# lukiod's picture
# Update app.py
# 48fb238 verified
# raw
# history blame
# 3.51 kB
import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import neurokit2 as nk
# Custom activation functions required by the model
def sin_activation(x):
    """Return the sine of ``x`` computed with ``tf.math.sin``.

    Registered under the name ``'sin'`` when the saved model is loaded.
    """
    activated = tf.math.sin(x)
    return activated
def cos_activation(x):
    """Return the cosine of ``x`` computed with ``tf.math.cos``.

    Registered under the name ``'cos'`` when the saved model is loaded.
    """
    activated = tf.math.cos(x)
    return activated
# Load model with custom objects
@st.cache_resource
def load_model():
    """Load the Keras model saved as ``model.keras``.

    The model references activations by name, so the names ``'sin'``,
    ``'cos'`` and ``'gelu'`` are mapped to their implementations and handed
    to Keras as ``custom_objects``. The ``st.cache_resource`` decorator lets
    Streamlit reuse the loaded model instead of reloading it on each rerun.

    Returns:
        The object returned by ``tf.keras.models.load_model``.
    """
    activation_registry = {
        'sin': sin_activation,
        'cos': cos_activation,
        'gelu': tf.keras.activations.gelu,
    }
    loaded = tf.keras.models.load_model(
        "model.keras",
        custom_objects=activation_registry,
    )
    return loaded
# Module-level model handle used by the prediction code below; load_model()
# is decorated with @st.cache_resource.
model = load_model()
# AAMI class mapping matching training code.
# The position of each label in the tuple is the class index used as the key.
CLASS_MAP = dict(enumerate((
    "Normal",
    "Supraventricular Ectopic (SVEB)",
    "Ventricular Ectopic (VEB)",
    "Fusion Beat",
    "Unknown",
)))
def process_mitbih_file(dat_file, *, sampling_rate=360, gain=200.0, window_size=257):
    """Extract fixed-length beat windows from an uploaded ``.dat`` signal file.

    The raw bytes are decoded as one channel of little-endian 16-bit samples,
    scaled by ``gain``, and passed to NeuroKit2 to locate R-peaks. One window
    of ``window_size`` samples is then cut around each peak.

    NOTE(review): the file is assumed to be WFDB format 16, single channel.
    PhysioNet's MIT-BIH Arrhythmia records are, to my knowledge, distributed
    as two-channel format 212 -- confirm against the record's ``.hea`` header
    that uploaded files really match this layout.

    Args:
        dat_file: Uploaded file object exposing ``getbuffer()``.
        sampling_rate: Sampling rate in Hz passed to NeuroKit2.
        gain: ADC units per mV; samples are divided by this value.
        window_size: Number of samples in each extracted beat.

    Returns:
        A float32 array of shape ``(n_beats, window_size)`` (``n_beats`` may
        be 0), or ``None`` if processing failed. Failures are reported to the
        user through ``st.error``.
    """
    try:
        if window_size <= 0:
            raise ValueError("window_size must be positive")
        if gain == 0:
            raise ValueError("gain must be non-zero")

        # "<i2" pins little-endian 16-bit samples so decoding does not depend
        # on the host's native byte order.
        signal = np.frombuffer(dat_file.getbuffer(), dtype="<i2").astype(np.float32)
        signal /= gain  # Convert ADC units to mV

        # Only the R-peak positions are needed; the processed-signals frame
        # returned alongside them is discarded.
        _, info = nk.ecg_process(signal, sampling_rate=sampling_rate)
        r_peaks = info["ECG_R_Peaks"]

        # Extract beats with same parameters as training.
        # Peaks closer than window_size // 2 samples to the start get a window
        # clamped to index 0 (so the peak is not centered); windows that would
        # run past the end of the signal are skipped.
        half_window = window_size // 2
        n_samples = len(signal)
        beats = []
        for r in r_peaks:
            start = max(0, r - half_window)
            end = start + window_size
            if end <= n_samples:
                beats.append(signal[start:end])

        if not beats:
            # Keep a consistent 2-D shape even when no beat was extracted.
            return np.empty((0, window_size), dtype=np.float32)
        return np.array(beats)
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        st.error(f"File processing error: {str(e)}")
        return None
# Streamlit UI: upload a .dat file, extract beats, classify them, show results.
st.title("ECG Arrhythmia Detection")
st.write("Upload MIT-BIH .dat file")

uploaded_file = st.file_uploader(
    "Select .dat file",
    type=["dat"],
    accept_multiple_files=False,
)

if uploaded_file is not None:
    if st.button("Analyze"):
        with st.spinner("Processing ECG signal..."):
            beats = process_mitbih_file(uploaded_file)

            if beats is not None and len(beats) > 0:
                # Prepare data for model: one channel per sample, float32.
                beats = beats.reshape((-1, 257, 1)).astype(np.float32)

                # Make predictions; the highest-scoring column is the class.
                predictions = model.predict(beats)
                pred_classes = np.argmax(predictions, axis=1)
                confidences = predictions.max(axis=1)

                # Show results
                st.subheader("Analysis Results")
                results = pd.DataFrame({
                    "Beat Index": range(len(beats)),
                    "Predicted Class": [CLASS_MAP[c] for c in pred_classes],
                    "Confidence": [f"{score:.1%}" for score in confidences],
                })
                st.dataframe(results)

                # Visualizations
                st.subheader("ECG Signal")
                fig, ax = plt.subplots(1, 2, figsize=(15, 4))
                try:
                    ax[0].plot(beats[0].flatten())
                    ax[0].set_title("Sample ECG Beat")

                    class_dist = results["Predicted Class"].value_counts()
                    ax[1].bar(class_dist.index, class_dist.values)
                    ax[1].set_title("Class Distribution")
                    ax[1].tick_params(axis='x', rotation=45)
                    st.pyplot(fig)
                finally:
                    # pyplot keeps every figure it creates alive until it is
                    # closed; without this each "Analyze" click leaks one.
                    plt.close(fig)
            else:
                st.error("Failed to extract valid beats from the signal")