ecg / app.py
lukiod's picture
Update app.py
fd3c0db verified
raw
history blame
3.97 kB
import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
import wfdb
import tempfile
import os
from scipy.signal import resample
import matplotlib.pyplot as plt
# Custom activation functions
def sin_activation(x):
    """Return the element-wise sine of ``x``.

    Passed to ``load_model`` under the custom-object name 'sin'.
    """
    activated = tf.math.sin(x)
    return activated
def cos_activation(x):
    """Return the element-wise cosine of ``x``.

    Passed to ``load_model`` under the custom-object name 'cos'.
    """
    activated = tf.math.cos(x)
    return activated
# Load model with custom objects
@st.cache_resource
def load_model():
    """Load the Keras model stored in "model.keras".

    ``st.cache_resource`` keeps the returned model cached, so Streamlit
    reruns reuse it instead of reading the file again. The custom-object
    mapping resolves the activation names 'sin', 'cos' and 'gelu'
    (presumably the names recorded in the saved model; confirm against
    the training code).
    """
    activation_lookup = {
        'sin': sin_activation,
        'cos': cos_activation,
        'gelu': tf.keras.activations.gelu,
    }
    return tf.keras.models.load_model("model.keras", custom_objects=activation_lookup)


# Module-level handle used for all predictions below.
model = load_model()
# AAMI class map: index of the model's argmax output -> human-readable label.
class_map = dict(enumerate([
    "Normal",
    "Supraventricular Ectopic (SVEB)",
    "Ventricular Ectopic (VEB)",
    "Fusion Beat",
    "Unknown",
]))
def extract_beats(record, annotation, window_size=257, channel=0):
    """Cut a fixed-length signal window around every annotated sample.

    Parameters
    ----------
    record:
        Object exposing ``p_signal``, a 2-D array indexed as
        ``[sample, channel]`` (e.g. the result of ``wfdb.rdrecord``).
    annotation:
        Object exposing ``sample``, the sample positions used as window
        centres (e.g. the result of ``wfdb.rdann``).
    window_size:
        Number of samples per window. Odd sizes are centred exactly on the
        annotated sample; even sizes put one more sample before it than
        after it.
    channel:
        Column of ``p_signal`` to read; defaults to the first channel.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_beats, window_size)``. Annotations whose window
        would run past either end of the signal are skipped, so ``n_beats``
        may be smaller than ``len(annotation.sample)`` and may be 0.

    NOTE(review): every annotation position is used regardless of
    ``annotation.symbol``. If the annotation file contains non-beat
    markers they also produce windows; confirm whether filtering is wanted.
    """
    signal = np.asarray(record.p_signal)[:, channel]
    n_samples = len(signal)
    half = window_size // 2
    beats = []
    for r in annotation.sample:
        # Deriving end from start keeps every slice exactly window_size long,
        # for odd and even sizes alike.
        start = int(r) - half
        end = start + window_size
        if start < 0 or end > n_samples:
            # Window would be truncated at a record edge; skip this beat.
            continue
        beats.append(signal[start:end])
    if not beats:
        # Keep a 2-D shape so callers can reshape/inspect consistently.
        return np.empty((0, window_size), dtype=signal.dtype)
    return np.array(beats)
st.title("ECG Arrhythmia Classification")
st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")

# State shared by the two input paths; the processing section below runs
# only when one of them sets all three.
record = None
annotation = None
record_loaded = False

# Load Record 108 Button: reads record "108" and its 'atr' annotations by
# bare record name (local files, presumably relative to the working
# directory — confirm deployment layout).
if st.button("Load Record 108"):
    base_name = "108"
    try:
        record = wfdb.rdrecord(base_name)
        annotation = wfdb.rdann(base_name, 'atr')
        record_loaded = True
    except Exception as e:
        st.error(f"Error loading Record 108: {str(e)}")
# File uploader: the user supplies the .dat/.hea/.atr files of one record.
# They are written to a temporary directory so wfdb can read them from disk;
# the record must be read before the directory is removed.
uploaded_files = st.file_uploader(
    "Or upload your own files",
    type=["dat", "hea", "atr"],
    accept_multiple_files=True
)
if uploaded_files and not record_loaded:
    with tempfile.TemporaryDirectory() as tmpdir:
        saved_names = []
        for f in uploaded_files:
            # The file name is client-supplied; keep only its final path
            # component so it cannot resolve outside tmpdir.
            safe_name = os.path.basename(f.name)
            with open(os.path.join(tmpdir, safe_name), "wb") as f_out:
                f_out.write(f.getbuffer())
            saved_names.append(safe_name)
        # wfdb is given the record path without extension, so every uploaded
        # file must share one base name; report a mismatch instead of
        # guessing which record the user meant.
        base_names = {os.path.splitext(name)[0] for name in saved_names}
        if len(base_names) != 1:
            st.error(
                "Uploaded files must belong to a single record (same base name), "
                f"got: {', '.join(sorted(base_names))}"
            )
        else:
            common_base = next(iter(base_names))
            record_path = os.path.join(tmpdir, common_base)
            try:
                record = wfdb.rdrecord(record_path)
                annotation = wfdb.rdann(record_path, 'atr')
                record_loaded = True
            except Exception as e:
                st.error(f"Error reading uploaded files: {str(e)}")
# Run processing if record is loaded
if record_loaded and record is not None and annotation is not None:
    beats = extract_beats(record, annotation)
    if len(beats) == 0:
        st.error("No valid beats found in the record")
        st.stop()
    # Shape the windows as (n_beats, 257, 1) float32 before prediction.
    beats = beats.reshape((-1, 257, 1)).astype(np.float32)
    predictions = model.predict(beats)
    predicted_classes = np.argmax(predictions, axis=1)

    st.subheader("Classification Results")
    # "Beat Index" numbers the extracted windows, not the annotations:
    # annotations too close to a record edge produce no window and are skipped.
    results = pd.DataFrame({
        "Beat Index": range(len(beats)),
        "Predicted Class": [class_map[c] for c in predicted_classes],
        "Confidence": np.max(predictions, axis=1)
    })
    st.dataframe(results)

    # Class Distribution Section
    st.subheader("Class Distribution")
    # Count every known class, including those with zero predictions.
    class_indices = list(class_map.keys())
    class_names = [class_map[i] for i in class_indices]
    counts = [np.sum(predicted_classes == i) for i in class_indices]
    distribution_df = pd.DataFrame({
        "Class": class_names,
        "Count": counts
    })
    # Display table and bar chart side by side.
    col1, col2 = st.columns([1, 2])
    with col1:
        st.dataframe(distribution_df.style.format({'Count': '{:,}'}))
    with col2:
        st.bar_chart(distribution_df.set_index('Class'))

    st.subheader("Sample ECG Beat")
    fig, ax = plt.subplots()
    ax.plot(beats[0].flatten())
    st.pyplot(fig)
    # Figures created via plt.subplots stay registered with pyplot until
    # closed; without this, every Streamlit rerun leaks one figure.
    plt.close(fig)