"""Streamlit app: classify ECG beats of an MIT-BIH record with a Keras model.

The user either loads record 108 from the working directory or uploads the
.dat/.hea/.atr files of one record. A fixed-length window is cut around every
annotated beat on the first signal channel, each window is classified into an
AAMI class, and the results are shown as tables, a bar chart and a sample plot.
"""
import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
import wfdb
import tempfile
import os
# NOTE(review): `resample` is imported but never used, so records whose
# sampling rate differs from the model's training data are not resampled.
from scipy.signal import resample
import matplotlib.pyplot as plt

# Samples per beat window; must match the model's input length.
BEAT_WINDOW = 257

# MIT-BIH annotation codes that mark beats. Other codes found in .atr files
# (rhythm changes '+', noise '~', comments '"', ...) are not beats and must
# not be classified.
BEAT_SYMBOLS = frozenset("NLRBAaJSVrFejnE/fQ?")

# File extensions that together make up one readable record + annotation.
REQUIRED_EXTENSIONS = frozenset({".dat", ".hea", ".atr"})


# Custom activation functions
def sin_activation(x):
    """Return the elementwise sine of ``x`` (custom Keras activation)."""
    return tf.math.sin(x)


def cos_activation(x):
    """Return the elementwise cosine of ``x`` (custom Keras activation)."""
    return tf.math.cos(x)


@st.cache_resource
def load_model():
    """Load ``model.keras`` with the custom activations it references.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    across script reruns instead of being reloaded on every interaction.
    """
    return tf.keras.models.load_model(
        "model.keras",
        custom_objects={
            'sin': sin_activation,
            'cos': cos_activation,
            'gelu': tf.keras.activations.gelu,
        },
    )


model = load_model()

# AAMI class map: model output index -> display name.
class_map = {
    0: "Normal",
    1: "Supraventricular Ectopic (SVEB)",
    2: "Ventricular Ectopic (VEB)",
    3: "Fusion Beat",
    4: "Unknown",
}


def extract_beats(record, annotation, window_size=BEAT_WINDOW, *,
                  beat_symbols=BEAT_SYMBOLS, return_positions=False):
    """Cut a fixed-length window around every annotated beat.

    Args:
        record: wfdb record; only channel 0 of ``record.p_signal`` is used.
        annotation: wfdb annotation; ``annotation.sample`` gives the
            annotation locations and ``annotation.symbol`` their codes.
        window_size: number of samples per window. Each window starts
            ``window_size // 2`` samples before the annotated location.
        beat_symbols: annotation codes to keep. ``None`` keeps every
            annotation regardless of its code.
        return_positions: if True, also return the sample index of every
            beat that was kept.

    Returns:
        A float array of shape ``(n_beats, window_size)``, or the tuple
        ``(beats, positions)`` when ``return_positions`` is True. Beats whose
        window would run past either end of the signal are dropped.
    """
    signal = record.p_signal[:, 0]  # first channel only
    positions = np.asarray(annotation.sample)

    symbols = getattr(annotation, "symbol", None)
    if beat_symbols is not None and symbols is not None:
        # Skip non-beat annotations (rhythm changes, noise, comments, ...).
        is_beat = np.array([s in beat_symbols for s in symbols], dtype=bool)
        positions = positions[is_beat]

    half = window_size // 2
    beats = []
    kept = []
    for r in positions:
        start = r - half
        # Derive the end from the start (not r + half + 1) so that even
        # window sizes also yield exactly `window_size` samples.
        end = start + window_size
        if start >= 0 and end <= len(signal):
            beats.append(signal[start:end])
            kept.append(r)

    # reshape keeps the result 2-D even when no beat fits.
    beats = np.array(beats).reshape(-1, window_size)
    if return_positions:
        return beats, np.asarray(kept, dtype=int)
    return beats


def _read_uploaded_record(uploaded_files):
    """Write uploaded files to a temp dir and read one record + annotation.

    Returns:
        ``(record, annotation)`` as read by ``wfdb.rdrecord`` / ``wfdb.rdann``.

    Raises:
        ValueError: if no base name has all of the .dat, .hea and .atr files.
        Any exception raised by wfdb while parsing the files.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        extensions_by_base = {}
        for uploaded in uploaded_files:
            # basename() drops any directory part of the client-supplied name.
            file_name = os.path.basename(uploaded.name)
            with open(os.path.join(tmpdir, file_name), "wb") as f_out:
                f_out.write(uploaded.getbuffer())
            base, ext = os.path.splitext(file_name)
            extensions_by_base.setdefault(base, set()).add(ext)

        complete = sorted(
            base for base, exts in extensions_by_base.items()
            if REQUIRED_EXTENSIONS <= exts
        )
        if not complete:
            raise ValueError(
                "upload the .dat, .hea and .atr files of the same record"
            )
        if len(complete) > 1:
            st.warning(f"Several records uploaded; using record {complete[0]}")

        # Read while the temp dir still exists; wfdb loads the data into memory.
        base_path = os.path.join(tmpdir, complete[0])
        record = wfdb.rdrecord(base_path)
        annotation = wfdb.rdann(base_path, 'atr')
    return record, annotation


st.title("ECG Arrhythmia Classification")
st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")

record_loaded = False
record = None
annotation = None

# Load Record 108 Button (reads 108.* from the working directory)
if st.button("Load Record 108"):
    try:
        base_name = "108"
        record = wfdb.rdrecord(base_name)
        annotation = wfdb.rdann(base_name, 'atr')
        record_loaded = True
    except Exception as e:
        st.error(f"Error loading Record 108: {str(e)}")

# File uploader
uploaded_files = st.file_uploader(
    "Or upload your own files",
    type=["dat", "hea", "atr"],
    accept_multiple_files=True,
)

if uploaded_files and not record_loaded:
    try:
        record, annotation = _read_uploaded_record(uploaded_files)
        record_loaded = True
    except Exception as e:
        st.error(f"Error reading uploaded files: {str(e)}")

# Run processing if record is loaded
if record_loaded and record is not None and annotation is not None:
    beats, beat_positions = extract_beats(
        record, annotation, return_positions=True
    )
    if len(beats) == 0:
        st.error("No valid beats found in the record")
        st.stop()

    # Model input: (n_beats, window, 1) float32.
    beats = beats[:, :, np.newaxis].astype(np.float32)
    predictions = model.predict(beats)
    predicted_classes = np.argmax(predictions, axis=1)

    st.subheader("Classification Results")
    results = pd.DataFrame({
        "Beat Index": range(len(beats)),
        # Sample index of the annotated beat, so rows can be traced back to
        # the signal even though edge beats were dropped.
        "Sample": beat_positions,
        "Predicted Class": [
            class_map.get(int(c), f"Class {int(c)}") for c in predicted_classes
        ],
        "Confidence": np.max(predictions, axis=1),
    })
    st.dataframe(results)

    # Class Distribution Section
    st.subheader("Class Distribution")

    # Get counts for all classes (zero counts included)
    class_indices = list(class_map.keys())
    class_names = [class_map[i] for i in class_indices]
    counts = [int(np.sum(predicted_classes == i)) for i in class_indices]

    distribution_df = pd.DataFrame({
        "Class": class_names,
        "Count": counts,
    })

    # Display in two columns
    col1, col2 = st.columns([1, 2])
    with col1:
        st.dataframe(distribution_df.style.format({'Count': '{:,}'}))
    with col2:
        st.bar_chart(distribution_df.set_index('Class'))

    st.subheader("Sample ECG Beat")
    fig, ax = plt.subplots()
    ax.plot(beats[0].flatten())
    st.pyplot(fig)
    # pyplot keeps a reference to every figure it creates; close it so
    # figures do not accumulate across reruns.
    plt.close(fig)