Spaces:
Sleeping
Sleeping
File size: 3,510 Bytes
28dbb9e 48fb238 28dbb9e 48fb238 28dbb9e 48fb238 28dbb9e 48fb238 28dbb9e 48fb238 28dbb9e 48fb238 28dbb9e 48fb238 28dbb9e 48fb238 28dbb9e 48fb238 28dbb9e 48fb238 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import neurokit2 as nk
# Custom activation functions required by the model
def sin_activation(x):
    """Return the element-wise sine of ``x``.

    Supplied to Keras as the ``'sin'`` custom object when the model is loaded.
    """
    activated = tf.math.sin(x)
    return activated
def cos_activation(x):
    """Return the element-wise cosine of ``x``.

    Supplied to Keras as the ``'cos'`` custom object when the model is loaded.
    """
    activated = tf.math.cos(x)
    return activated
# Load model with custom objects
@st.cache_resource
def load_model():
    """Load the Keras model stored in ``model.keras``.

    The model references custom activations by name, so the ``sin``/``cos``
    wrappers (and Keras' own ``gelu``) are registered as custom objects.
    ``st.cache_resource`` makes Streamlit reuse the loaded model across
    script reruns instead of reloading it each time.
    """
    custom_objects = {
        "sin": sin_activation,
        "cos": cos_activation,
        "gelu": tf.keras.activations.gelu,
    }
    return tf.keras.models.load_model("model.keras", custom_objects=custom_objects)
# Load the (cached) model once when the script runs.
model = load_model()

# AAMI class mapping matching training code: model output index -> label.
CLASS_MAP = dict(enumerate([
    "Normal",
    "Supraventricular Ectopic (SVEB)",
    "Ventricular Ectopic (VEB)",
    "Fusion Beat",
    "Unknown",
]))
def process_mitbih_file(dat_file):
    """Extract fixed-width, R-peak-centred beats from an uploaded ``.dat`` file.

    The file's bytes are decoded as single-channel 16-bit little-endian signed
    samples (WFDB "format 16") and scaled to mV with a gain of 200 units/mV.
    NeuroKit2 then locates R-peaks, assuming a 360 Hz sampling rate. A window
    of 257 samples centred on each R-peak is cut out. R-peaks too close to
    either end of the signal to fit a full window are skipped.

    NOTE(review): MIT-BIH Arrhythmia Database records are normally stored in
    WFDB format 212 (12-bit samples, two interleaved channels), not format 16.
    Decoding such a file as int16 produces a meaningless signal. Confirm the
    format of the files users actually upload.

    Args:
        dat_file: Uploaded file object exposing ``getbuffer()`` (for example
            the Streamlit upload passed in by the UI).

    Returns:
        A float32 array of shape ``(n_beats, 257)`` (``n_beats`` may be 0),
        or ``None`` if reading or processing failed. In that case the error
        has already been shown in the app via ``st.error``.
    """
    window_size = 257  # beat length; must match the model's input width
    half_window = window_size // 2
    try:
        # Format 16 is little-endian; be explicit rather than relying on the
        # host byte order.
        samples = np.frombuffer(dat_file.getbuffer(), dtype="<i2").astype(np.float32)
        samples /= 200.0  # ADC units -> mV using the standard gain

        # NeuroKit2 processing with assumed 360Hz sampling rate
        _, info = nk.ecg_process(samples, sampling_rate=360)
        r_peaks = np.asarray(info["ECG_R_Peaks"], dtype=np.int64)

        # Keep only windows lying fully inside the signal. Clamping the start
        # at 0 would shift the R-peak away from the window centre, so
        # incomplete windows are dropped at both ends.
        starts = r_peaks - half_window
        in_bounds = (starts >= 0) & (starts + window_size <= len(samples))
        beats = [samples[s:s + window_size] for s in starts[in_bounds]]

        if not beats:
            return np.empty((0, window_size), dtype=np.float32)
        return np.stack(beats)
    except Exception as e:  # UI boundary: surface any failure to the user
        st.error(f"File processing error: {str(e)}")
        return None
# Streamlit UI
st.title("ECG Arrhythmia Detection")
st.write("Upload MIT-BIH .dat file")
uploaded_file = st.file_uploader(
    "Select .dat file",
    type=["dat"],
    accept_multiple_files=False,
)

# The Analyze button is only rendered once a file has been uploaded.
if uploaded_file is not None and st.button("Analyze"):
    with st.spinner("Processing ECG signal..."):
        beats = process_mitbih_file(uploaded_file)
        if beats is not None and len(beats) > 0:
            # Model input: one channel per beat -> (n_beats, 257, 1) float32.
            beats = beats.reshape((-1, 257, 1)).astype(np.float32)

            # One prediction row per beat. "Confidence" is the highest score
            # in the row (presumably a class probability; it is formatted as %).
            predictions = model.predict(beats, verbose=0)
            pred_classes = np.argmax(predictions, axis=1)
            confidences = np.max(predictions, axis=1)

            # Show results
            st.subheader("Analysis Results")
            results = pd.DataFrame({
                "Beat Index": range(len(beats)),
                "Predicted Class": [CLASS_MAP[c] for c in pred_classes],
                "Confidence": [f"{p:.1%}" for p in confidences],
            })
            st.dataframe(results)

            # Visualizations: first extracted beat and per-class beat counts.
            st.subheader("ECG Signal")
            fig, (beat_ax, dist_ax) = plt.subplots(1, 2, figsize=(15, 4))
            beat_ax.plot(beats[0].flatten())
            beat_ax.set_title("Sample ECG Beat")
            class_dist = results["Predicted Class"].value_counts()
            dist_ax.bar(class_dist.index, class_dist.values)
            dist_ax.set_title("Class Distribution")
            dist_ax.tick_params(axis='x', rotation=45)
            fig.tight_layout()  # keep the rotated class labels inside the figure
            st.pyplot(fig)
            # pyplot keeps every figure open until it is closed explicitly.
            # Each analysis run creates a new one, so release it here.
            plt.close(fig)
        else:
            st.error("Failed to extract valid beats from the signal")