Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import tensorflow as tf | |
| import wfdb | |
| import tempfile | |
| import os | |
| from scipy.signal import resample | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
# Custom activation functions
def sin_activation(x):
    """Apply the sine function element-wise to tensor ``x``.

    Registered under the name 'sin' in the ``custom_objects`` passed to
    ``tf.keras.models.load_model``.
    """
    return tf.sin(x)
def cos_activation(x):
    """Apply the cosine function element-wise to tensor ``x``.

    Registered under the name 'cos' in the ``custom_objects`` passed to
    ``tf.keras.models.load_model``.
    """
    return tf.cos(x)
# Load model with custom objects
@st.cache_resource
def load_model(path="model.keras"):
    """Load the trained Keras classifier from ``path``.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` makes the (slow) deserialisation happen once per
    server process instead of on every rerun. The cached model object is
    shared by all sessions.

    ``custom_objects`` maps the names 'sin', 'cos' and 'gelu' to callables for
    deserialisation (presumably the names the activations were saved under
    during training; confirm against the training code).

    Args:
        path: location of the saved ``.keras`` model. Defaults to
            "model.keras" relative to the working directory.

    Returns:
        The loaded ``tf.keras`` model.
    """
    return tf.keras.models.load_model(
        path,
        custom_objects={
            'sin': sin_activation,
            'cos': cos_activation,
            'gelu': tf.keras.activations.gelu,
        },
    )


model = load_model()
# AAMI beat classes, keyed by the index returned by argmax over the model's
# output vector (see how predictions are mapped to names further down).
class_map = dict(enumerate([
    "Normal",
    "Supraventricular Ectopic (SVEB)",
    "Ventricular Ectopic (VEB)",
    "Fusion Beat",
    "Unknown",
]))
# Function to extract beats from record
def extract_beats(record, annotation, window_size=257, channel=0):
    """Cut a fixed-length window of signal around every annotated sample.

    Args:
        record: object exposing ``p_signal`` as a 2-D array indexed
            ``[sample, channel]`` (e.g. a ``wfdb.Record``).
        annotation: object exposing ``sample``, the annotation positions in
            samples (e.g. a ``wfdb.Annotation``).
        window_size: window length in samples. The annotated sample sits at
            index ``window_size // 2`` of each window.
        channel: signal column to use. Default 0 is the first channel.

    Returns:
        Array of shape ``(n_windows, window_size)``. Annotations whose window
        would extend past either end of the signal are skipped, so
        ``n_windows`` can be smaller than ``len(annotation.sample)``. When no
        window fits, the result has shape ``(0, window_size)``.
    """
    signal = record.p_signal[:, channel]
    half = window_size // 2
    beats = []
    # NOTE(review): annotation.sample holds every annotation position, which
    # may include non-beat markers (e.g. rhythm/noise annotations in MIT-BIH
    # .atr files). Those are windowed here too. Consider filtering on
    # annotation.symbol if only beats are wanted (TODO confirm).
    for r in annotation.sample:
        start = r - half
        # end is derived from start (not r + half + 1) so the window is exactly
        # window_size long for even sizes as well as odd ones.
        end = start + window_size
        # Skip windows that would be clipped at the signal edges.
        if start < 0 or end > len(signal):
            continue
        beats.append(signal[start:end])
    if not beats:
        return np.empty((0, window_size), dtype=signal.dtype)
    return np.array(beats)
# Function to detect the last Conv1D layer in the model
def get_last_conv_layer_name(model):
    """Return the name of the last top-level Conv1D layer of ``model``.

    Only ``model.layers`` is scanned, so layers nested inside sub-models are
    not considered. If no Conv1D layer exists, an error is shown on the
    Streamlit page and None is returned.
    """
    conv_name = next(
        (
            candidate.name
            for candidate in reversed(model.layers)
            if isinstance(candidate, tf.keras.layers.Conv1D)
        ),
        None,
    )
    if conv_name is None:
        st.error("No Conv1D layer found in the model. Grad-CAM requires a convolution layer.")
    return conv_name
# Function to generate Grad-CAM heatmap for a given beat and class index
def generate_grad_cam(model, sample, layer_name, class_idx=None):
    """Compute a Grad-CAM style heatmap for a single input sample.

    Args:
        model: Keras model to explain.
        sample: model input with a leading batch axis of size 1. The
            ``tf.squeeze(..., axis=0)`` below fails for larger batches.
        layer_name: name of the layer whose activations are weighted. This is
            normally the last Conv1D layer.
        class_idx: output index to explain. Defaults to the class the model
            predicts for the sample (argmax of its output).

    Returns:
        Numpy array over the chosen layer's temporal axis, scaled into
        [0, 1]. Its length is that layer's output length, which may differ
        from the input length.
    """
    grad_model = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output],
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(sample)
        if class_idx is None:
            class_idx = tf.argmax(predictions[0])
        loss = predictions[:, class_idx]
    grads = tape.gradient(loss, conv_outputs)
    # One weight per channel: gradients averaged over the batch and time axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
    conv_outputs = tf.squeeze(conv_outputs, axis=0)
    cam = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1)
    # NOTE: canonical Grad-CAM applies ReLU here (keeping only positive
    # evidence). abs() is kept deliberately so strongly negative regions are
    # highlighted too. Switch to tf.nn.relu for the textbook variant.
    cam = tf.abs(cam)
    # Epsilon avoids division by zero when the map is all zeros.
    cam = cam / (tf.reduce_max(cam) + 1e-8)
    return cam.numpy()
# Seed session-state keys on the first run only; later reruns keep whatever
# values earlier interactions stored.
for _state_key, _state_default in (
    ("record_loaded", False),
    ("record", None),
    ("annotation", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Streamlit App Layout
st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")
# Load Record 108 Button
if st.button("Load Record 108"):
    # Read from the current working directory; no pn_dir is given, so wfdb
    # does not fetch from PhysioNet.
    base_name = "108"
    try:
        record_108 = wfdb.rdrecord(base_name)
        annotation_108 = wfdb.rdann(base_name, 'atr')
    except Exception as e:
        st.error(f"Error loading Record 108: {str(e)}")
    else:
        # Commit record and annotation together so a failing rdann cannot
        # leave a new record paired with a previously loaded annotation.
        st.session_state.record = record_108
        st.session_state.annotation = annotation_108
        st.session_state.record_loaded = True
# File uploader
uploaded_files = st.file_uploader(
    "Or upload your own files",
    type=["dat", "hea", "atr"],
    accept_multiple_files=True,
)


def _read_uploaded_record(files):
    """Write uploaded WFDB files to a temporary directory and read one record.

    Only base names that have both a ``.hea`` header and an ``.atr``
    annotation file qualify. If several qualify, the alphabetically first one
    is used so the choice is deterministic.

    Returns:
        ``(record, annotation)`` as returned by ``wfdb.rdrecord`` and
        ``wfdb.rdann``.

    Raises:
        ValueError: no base name has both a ``.hea`` and an ``.atr`` file.
        Exceptions raised by wfdb while parsing propagate unchanged.
    """
    extensions_by_base = {}
    for uploaded in files:
        base, ext = os.path.splitext(os.path.basename(uploaded.name))
        extensions_by_base.setdefault(base, set()).add(ext)
    complete = sorted(
        base for base, exts in extensions_by_base.items() if {".hea", ".atr"} <= exts
    )
    if not complete:
        raise ValueError("upload a .hea and an .atr file that share the same base name")
    with tempfile.TemporaryDirectory() as tmpdir:
        for uploaded in files:
            # basename(): never let an uploaded file name point outside tmpdir.
            with open(os.path.join(tmpdir, os.path.basename(uploaded.name)), "wb") as f_out:
                f_out.write(uploaded.getbuffer())
        record_path = os.path.join(tmpdir, complete[0])
        # The returned objects are used after tmpdir is deleted, as in the
        # original code, i.e. they hold their data in memory.
        return wfdb.rdrecord(record_path), wfdb.rdann(record_path, 'atr')


# Re-read only when the set of uploaded files changes. This lets a new upload
# replace a previously loaded record (the old `not record_loaded` guard ignored
# every upload after the first record). It also avoids re-parsing on each rerun
# and keeps an unchanged upload from overriding a later "Load Record 108".
if uploaded_files:
    upload_signature = tuple(sorted((f.name, f.size) for f in uploaded_files))
    if upload_signature != st.session_state.get("upload_signature"):
        try:
            uploaded_record, uploaded_annotation = _read_uploaded_record(uploaded_files)
        except Exception as e:
            st.error(f"Error reading uploaded files: {str(e)}")
        else:
            st.session_state.record = uploaded_record
            st.session_state.annotation = uploaded_annotation
            st.session_state.record_loaded = True
            st.session_state.upload_signature = upload_signature
else:
    # Forget the last upload so re-uploading the same files reloads them.
    st.session_state.pop("upload_signature", None)
def _plot_beat_with_heatmap(beat, heatmap, title):
    """Draw ``beat`` over its Grad-CAM ``heatmap`` and return the figure.

    The heatmap row is stretched across the beat's full x- and y-range, so it
    may be shorter than the beat (e.g. if the explained layer's output is
    downsampled).
    """
    fig, ax = plt.subplots(figsize=(8, 2))
    y_min, y_max = beat.min(), beat.max()
    ax.imshow(
        np.expand_dims(heatmap, axis=0),
        aspect='auto',
        cmap='jet',
        alpha=0.5,
        # generate_grad_cam scales heatmaps into [0, 1]. Fixing the colour
        # range keeps colours comparable across beats and both columns.
        vmin=0.0,
        vmax=1.0,
        extent=[0, len(beat), y_min, y_max],
    )
    ax.plot(beat, linewidth=2, color='blue')
    ax.axis('off')
    ax.set_title(title)
    ax.set_xlim(0, len(beat))
    ax.set_ylim(y_min, y_max)
    return fig


# Process the record if loaded
if st.session_state.record_loaded and st.session_state.record is not None and st.session_state.annotation is not None:
    beats = extract_beats(st.session_state.record, st.session_state.annotation)
    if len(beats) == 0:
        st.error("No valid beats found in the record")
        st.stop()
    # (n_beats, window) -> (n_beats, window, 1): one input channel per beat.
    # The window length comes from the data instead of a hard-coded 257.
    beats = beats.reshape((len(beats), -1, 1)).astype(np.float32)
    predictions = model.predict(beats)
    predicted_classes = np.argmax(predictions, axis=1)

    st.subheader("Classification Results")
    results = pd.DataFrame({
        "Beat Index": range(len(beats)),
        "Predicted Class": [class_map[c] for c in predicted_classes],
        "Confidence": np.max(predictions, axis=1),
    })
    st.dataframe(results)

    st.subheader("Class Distribution")
    class_indices = list(class_map.keys())
    class_names = [class_map[i] for i in class_indices]
    counts = [np.sum(predicted_classes == i) for i in class_indices]
    distribution_df = pd.DataFrame({
        "Class": class_names,
        "Count": counts,
    })
    col1, col2 = st.columns([1, 2])
    with col1:
        st.dataframe(distribution_df.style.format({'Count': '{:,}'}))
    with col2:
        st.bar_chart(distribution_df.set_index('Class'))

    st.subheader("Sample ECG Beat")
    fig, ax = plt.subplots()
    ax.plot(beats[0].flatten(), label="ECG Beat")
    ax.legend()
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed, and the script reruns on
    # each interaction, so close figures once Streamlit has rendered them.
    plt.close(fig)

    st.subheader("Class Comparison with Grad-CAM")
    st.write("Compare model explanations between classes present in this record")
    conv_layer_name = get_last_conv_layer_name(model)
    # get_last_conv_layer_name already shows an error when it returns None.
    # Skip the comparison in that case: model.get_layer(None) would raise.
    if conv_layer_name is not None:
        st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")
        present_classes = distribution_df[distribution_df['Count'] > 0]['Class'].tolist()
        if not present_classes:
            st.warning("No classes with detected beats to compare")
            st.stop()
        col1, col2, col3 = st.columns([1, 1, 1])
        with col1:
            left_class = st.selectbox("Left Class:", options=present_classes, index=0)
        with col2:
            right_index = 1 if len(present_classes) > 1 else 0
            right_class = st.selectbox("Right Class:", options=present_classes, index=right_index)
        with col3:
            num_beats = st.number_input("Beats per class:", min_value=1, max_value=10, value=3)
        class_name_to_idx = {v: k for k, v in class_map.items()}
        left_class_idx = class_name_to_idx[left_class]
        right_class_idx = class_name_to_idx[right_class]
        left_indices = np.where(predicted_classes == left_class_idx)[0]
        right_indices = np.where(predicted_classes == right_class_idx)[0]
        left_col, right_col = st.columns(2)

        def display_class_beats(col, class_name, beat_indices, num_beats):
            """Render up to ``num_beats`` Grad-CAM overlays for one class in ``col``."""
            with col:
                st.subheader(class_name)
                if len(beat_indices) == 0:
                    st.warning(f"No {class_name} beats found")
                    return
                for beat_idx in beat_indices[:num_beats]:
                    beat = beats[beat_idx].flatten()
                    sample = beat.reshape(1, -1, 1).astype(np.float32)
                    heatmap = generate_grad_cam(model, sample, conv_layer_name)
                    beat_fig = _plot_beat_with_heatmap(beat, heatmap, f"Beat {beat_idx}")
                    st.pyplot(beat_fig)
                    plt.close(beat_fig)

        display_class_beats(left_col, left_class, left_indices, num_beats)
        display_class_beats(right_col, right_class, right_indices, num_beats)
        if left_class == right_class:
            st.info("Comparing different instances of the same class. Note: This shows intra-class variation.")