# ECG Arrhythmia Classification — Streamlit app (Hugging Face Space).
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
import wfdb
from scipy.signal import resample
# Custom activation functions baked into the saved Keras model; they must be
# importable at deserialization time (see load_model's custom_objects).
def sin_activation(x):
    """Element-wise sine activation (registered under the name 'sin')."""
    return tf.math.sin(x)
def cos_activation(x):
    """Element-wise cosine activation (registered under the name 'cos')."""
    return tf.math.cos(x)
# Load the trained model, registering the custom activations so Keras can
# deserialize layers that reference 'sin', 'cos', or 'gelu' by name.
@st.cache_resource
def load_model():
    """Load ``model.keras`` with custom activations registered.

    Decorated with ``st.cache_resource`` so the model is read from disk once
    per server process instead of on every Streamlit script rerun (the
    original reloaded it on each interaction, which is slow and wasteful).

    Returns:
        The compiled ``tf.keras.Model`` stored in ``model.keras``.
    """
    return tf.keras.models.load_model(
        "model.keras",
        custom_objects={
            'sin': sin_activation,
            'cos': cos_activation,
            'gelu': tf.keras.activations.gelu
        }
    )

# Module-level singleton used by the inference section below.
model = load_model()
# AAMI beat classes, keyed by the model's output class index
# (argmax over the prediction vector).
class_map = {
    0: "Normal",
    1: "Supraventricular Ectopic (SVEB)",
    2: "Ventricular Ectopic (VEB)",
    3: "Fusion Beat",
    4: "Unknown"
}
def extract_beats(record, annotation, window_size=257):
    """Cut fixed-length windows centered on each annotated sample location.

    Uses only the first signal channel. Windows that would run past either
    end of the signal (and therefore come out shorter than ``window_size``)
    are silently dropped.

    Args:
        record: WFDB record whose ``p_signal`` is a (samples, channels) array.
        annotation: WFDB annotation whose ``sample`` holds R-peak indices.
        window_size: Number of samples per extracted beat (default 257).

    Returns:
        np.ndarray of shape (n_beats, window_size).
    """
    channel = record.p_signal[:, 0]
    half = window_size // 2
    total = len(channel)
    windows = []
    for peak in annotation.sample:
        lo = max(0, peak - half)
        hi = min(total, peak + half + 1)
        # Keep only full-length windows; truncated edge beats are discarded.
        if hi - lo == window_size:
            windows.append(channel[lo:hi])
    return np.array(windows)
# --- Page header and record-loading state ---
st.title("ECG Arrhythmia Classification")
st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")

# Reset on every rerun; set by either of the two loading paths below.
record_loaded = False
record = None
annotation = None
# Convenience path: read the bundled MIT-BIH record 108 ("108.dat"/"108.hea"/
# "108.atr" expected in the working directory).
if st.button("Load Record 108"):
    try:
        record = wfdb.rdrecord("108")
        annotation = wfdb.rdann("108", 'atr')
        record_loaded = True
    except Exception as e:
        st.error(f"Error loading Record 108: {str(e)}")
# Alternative path: let the user supply their own WFDB record parts.
uploaded_files = st.file_uploader(
    "Or upload your own files",
    type=["dat", "hea", "atr"],
    accept_multiple_files=True
)
# Read an uploaded record: write all parts to a temp dir so wfdb can resolve
# the .dat/.hea/.atr trio by their shared base name.
if uploaded_files and not record_loaded:
    with tempfile.TemporaryDirectory() as tmpdir:
        for f in uploaded_files:
            file_path = os.path.join(tmpdir, f.name)
            with open(file_path, "wb") as f_out:
                f_out.write(f.getbuffer())
        base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
        # Bug fix: the original took list(base_names)[0], i.e. an arbitrary
        # element of a set — with files from several records the choice was
        # nondeterministic. Pick deterministically and tell the user.
        common_base = min(base_names)
        if len(base_names) > 1:
            st.warning(
                f"Files for multiple records uploaded "
                f"({', '.join(sorted(base_names))}); using '{common_base}'."
            )
        try:
            record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
            annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
            record_loaded = True
        except Exception as e:
            st.error(f"Error reading uploaded files: {str(e)}")
# --- Inference and result display (runs once a record is available) ---
if record_loaded and record is not None and annotation is not None:
    beats = extract_beats(record, annotation)
    if len(beats) == 0:
        st.error("No valid beats found in the record")
        st.stop()

    # The model expects input of shape (batch, 257, 1) in float32.
    beats = beats.reshape((-1, 257, 1)).astype(np.float32)
    predictions = model.predict(beats)
    predicted_classes = np.argmax(predictions, axis=1)

    # Per-beat table: predicted label plus top-class probability.
    st.subheader("Classification Results")
    results = pd.DataFrame({
        "Beat Index": range(len(beats)),
        "Predicted Class": [class_map[c] for c in predicted_classes],
        "Confidence": np.max(predictions, axis=1)
    })
    st.dataframe(results)

    # Distribution over every AAMI class (zero counts included).
    st.subheader("Class Distribution")
    label_ids = list(class_map.keys())
    label_names = [class_map[i] for i in label_ids]
    per_class_counts = [np.sum(predicted_classes == i) for i in label_ids]
    distribution_df = pd.DataFrame({
        "Class": label_names,
        "Count": per_class_counts
    })

    # Table on the left, bar chart on the right.
    col1, col2 = st.columns([1, 2])
    with col1:
        st.dataframe(distribution_df.style.format({'Count': '{:,}'}))
    with col2:
        st.bar_chart(distribution_df.set_index('Class'))

    # Plot the first extracted beat as a visual sanity check.
    st.subheader("Sample ECG Beat")
    fig, ax = plt.subplots()
    ax.plot(beats[0].flatten())
    st.pyplot(fig)