Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import tensorflow as tf | |
| import wfdb | |
| import tempfile | |
| import os | |
| from scipy.signal import resample | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
# Custom activation functions
def sin_activation(x):
    """Element-wise sine, registered with the model loader as the 'sin' activation.

    Args:
        x: Tensor-like input.

    Returns:
        Tensor holding sin(x) for every element of x.
    """
    return tf.sin(x)
def cos_activation(x):
    """Element-wise cosine, registered with the model loader as the 'cos' activation.

    Args:
        x: Tensor-like input.

    Returns:
        Tensor holding cos(x) for every element of x.
    """
    return tf.cos(x)
# Load model with custom objects
@st.cache_resource
def load_model(path="model.keras"):
    """Load the trained Keras classifier, resolving its custom activations.

    Streamlit re-executes this whole script on every widget interaction, so the
    model is cached with ``st.cache_resource`` and deserialized once per process
    instead of on every rerun.

    Args:
        path: Location of the saved Keras model; defaults to ``"model.keras"``
            relative to the working directory.

    Returns:
        The loaded Keras model.
    """
    return tf.keras.models.load_model(
        path,
        custom_objects={
            # Names under which the saved model refers to these activations.
            'sin': sin_activation,
            'cos': cos_activation,
            'gelu': tf.keras.activations.gelu,
        },
    )


model = load_model()
# AAMI beat classes, keyed by the model's output index (the argmax position).
class_map = dict(
    enumerate(
        [
            "Normal",
            "Supraventricular Ectopic (SVEB)",
            "Ventricular Ectopic (VEB)",
            "Fusion Beat",
            "Unknown",
        ]
    )
)
# Function to extract beats from record
def extract_beats(record, annotation, window_size=257, channel=0):
    """Cut a fixed-length window of signal around every annotated sample.

    Each window starts ``window_size // 2`` samples before the annotation and
    spans exactly ``window_size`` samples, so for odd sizes the annotation sits
    in the centre. Windows that would run past either end of the signal are
    skipped, not padded.

    NOTE(review): every entry of ``annotation.sample`` is used as-is, including
    any non-beat annotations the file may contain; filter upstream if needed.

    Args:
        record: Object exposing ``p_signal`` as a 2-D array indexed
            ``[sample, channel]`` (e.g. a ``wfdb.Record``).
        annotation: Object exposing ``sample``, an iterable of sample indices
            (e.g. a ``wfdb.Annotation``).
        window_size: Number of samples per extracted beat.
        channel: Column of ``p_signal`` to read; defaults to the first channel.

    Returns:
        Array of shape ``(n_beats, window_size)``; ``n_beats`` is 0 when no
        complete window fits.
    """
    signal = np.asarray(record.p_signal)[:, channel]
    half = window_size // 2
    beats = []
    for r in annotation.sample:
        start = int(r) - half
        # Derive the end from the start so the window is exactly window_size
        # long for even sizes too (r + half + 1 overshoots by one when even).
        end = start + window_size
        # Skip windows truncated by the start or end of the recording.
        if start < 0 or end > len(signal):
            continue
        beats.append(signal[start:end])
    if not beats:
        return np.empty((0, window_size), dtype=signal.dtype)
    return np.stack(beats)
# Function to detect the last Conv1D layer in the model
def get_last_conv_layer_name(model):
    """Return the name of the last ``Conv1D`` layer in ``model.layers`` order.

    Only the top-level ``model.layers`` list is searched. When it holds no
    ``Conv1D`` layer, an error is shown on the Streamlit page and ``None`` is
    returned.
    """
    conv_names = [
        layer.name
        for layer in model.layers
        if isinstance(layer, tf.keras.layers.Conv1D)
    ]
    if not conv_names:
        st.error("No Conv1D layer found in the model. Grad-CAM requires a convolution layer.")
        return None
    return conv_names[-1]
# Function to generate Grad-CAM heatmap for a given beat and class index
def make_gradcam_heatmap(beat, model, conv_layer_name, class_index):
    """Compute a 1-D Grad-CAM heatmap explaining one class score for one beat.

    Args:
        beat: One input example without the batch axis; in this app a
            ``(257, 1)`` float32 array.
        model: Keras model whose output holds one score per class.
        conv_layer_name: Name of the layer of ``model`` whose activations are
            explained (normally the last ``Conv1D``).
        class_index: Column of the model output to explain.

    Returns:
        numpy array of length ``beat.shape[0]`` with values in ``[0, 1]``;
        larger values mark time steps that pushed the class score up. All
        zeros when no region contributes positively.

    Raises:
        ValueError: If the class score has no gradient w.r.t. the layer.
    """
    # Model mapping the input to (conv activations, predictions). model.inputs
    # is already a list, so it is passed as-is instead of being wrapped in a
    # second list, which would nest the input structure.
    grad_model = tf.keras.models.Model(
        model.inputs,
        [model.get_layer(conv_layer_name).output, model.output],
    )

    # Record operations for automatic differentiation.
    with tf.GradientTape() as tape:
        # Add the batch axis, e.g. (257, 1) -> (1, 257, 1).
        beat_tensor = tf.expand_dims(beat, axis=0)
        conv_outputs, predictions = grad_model(beat_tensor)
        loss = predictions[:, class_index]

    # Gradients of the class score w.r.t. the feature map (batch, time, channels).
    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise ValueError(
            f"Class score has no gradient with respect to layer '{conv_layer_name}'."
        )

    # Channel importance: gradients averaged over the time axis -> (1, channels).
    weights = tf.reduce_mean(grads, axis=1)
    # Importance-weighted sum of feature maps over channels -> (time,). Take the
    # batch row explicitly; tf.squeeze would also drop a time axis of length 1.
    cam = tf.reduce_sum(weights[:, tf.newaxis, :] * conv_outputs, axis=-1)[0]
    # ReLU: keep only regions that increase the class score.
    heatmap = tf.maximum(cam, 0).numpy()

    # Normalize to [0, 1]; an all-zero map stays all zeros.
    peak = float(heatmap.max())
    heatmap = heatmap / peak if peak > 0 else np.zeros_like(heatmap)

    # Stretch from the conv layer's time resolution to the input length using
    # half-pixel-aligned linear interpolation (the same sampling grid as image
    # resizing). cv2.resize is not used here: OpenCV reads a 1-D array as an
    # (N, 1) column, so resizing it to (length, 1) interpolated across the wrong
    # axis and produced a constant map.
    target_len = beat.shape[0]
    source_len = heatmap.shape[0]
    positions = (np.arange(target_len) + 0.5) * (source_len / target_len) - 0.5
    # np.interp clamps positions outside [0, source_len - 1] to the edge values.
    return np.interp(positions, np.arange(source_len), heatmap)
# Streamlit App Layout
st.title("ECG Arrhythmia Classification with Grad-CAM Visualization")
st.write("Upload MIT-BIH record files (.dat, .hea, .atr) or load record 108")

record_loaded = False
record = None
annotation = None

# st.button returns True only on the single rerun triggered by the click.
# Remember the choice in session state; otherwise the first interaction with
# any later widget (e.g. the Grad-CAM controls) reruns the script with the
# button False and all results vanish.
_RECORD_108_KEY = "record_108_requested"


def _forget_record_108():
    """Uploader callback: a change in uploaded files takes over from record 108."""
    st.session_state[_RECORD_108_KEY] = False


# Load Record 108 Button
if st.button("Load Record 108"):
    st.session_state[_RECORD_108_KEY] = True

if st.session_state.get(_RECORD_108_KEY, False):
    try:
        base_name = "108"
        record = wfdb.rdrecord(base_name)
        annotation = wfdb.rdann(base_name, 'atr')
        record_loaded = True
    except Exception as e:
        st.error(f"Error loading Record 108: {str(e)}")

# File uploader
uploaded_files = st.file_uploader(
    "Or upload your own files",
    type=["dat", "hea", "atr"],
    accept_multiple_files=True,
    on_change=_forget_record_108,
)

if uploaded_files and not record_loaded:
    with tempfile.TemporaryDirectory() as tmpdir:
        exts_by_base = {}
        for f in uploaded_files:
            # The file name is client-supplied: keep only its final component
            # so it cannot escape tmpdir (e.g. "../evil.dat").
            safe_name = os.path.basename(f.name)
            base, ext = os.path.splitext(safe_name)
            exts_by_base.setdefault(base, set()).add(ext)
            with open(os.path.join(tmpdir, safe_name), "wb") as f_out:
                f_out.write(f.getbuffer())

        # Prefer a record whose header and annotation files were both uploaded;
        # sort so the choice is deterministic when several records are present.
        complete = sorted(b for b, exts in exts_by_base.items() if {".hea", ".atr"} <= exts)
        common_base = complete[0] if complete else sorted(exts_by_base)[0]
        if len(exts_by_base) > 1:
            st.warning(f"Multiple records uploaded; using '{common_base}'.")

        try:
            record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
            annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
            record_loaded = True
        except Exception as e:
            st.error(f"Error reading uploaded files: {str(e)}")
# Process the record if loaded
if record_loaded and record is not None and annotation is not None:
    beats = extract_beats(record, annotation)
    if len(beats) == 0:
        st.error("No valid beats found in the record")
        st.stop()
    # Add the trailing channel axis the model input expects: (n_beats, length, 1).
    beats = beats.reshape((len(beats), -1, 1)).astype(np.float32)
    predictions = model.predict(beats)
    predicted_classes = np.argmax(predictions, axis=1)

    def _class_label(class_idx):
        """Return the display name for a predicted class index.

        Falls back to a generic name instead of raising KeyError if the model
        emits an index that class_map does not list.
        """
        return class_map.get(int(class_idx), f"Class {int(class_idx)}")

    st.subheader("Classification Results")
    results = pd.DataFrame({
        "Beat Index": range(len(beats)),
        "Predicted Class": [_class_label(c) for c in predicted_classes],
        "Confidence": np.max(predictions, axis=1),
    })
    st.dataframe(results)

    # Class Distribution Section
    st.subheader("Class Distribution")
    class_indices = list(class_map.keys())
    class_names = [class_map[i] for i in class_indices]
    counts = [np.sum(predicted_classes == i) for i in class_indices]
    distribution_df = pd.DataFrame({
        "Class": class_names,
        "Count": counts,
    })
    col1, col2 = st.columns([1, 2])
    with col1:
        st.dataframe(distribution_df.style.format({'Count': '{:,}'}))
    with col2:
        st.bar_chart(distribution_df.set_index('Class'))

    # Display a Sample ECG Beat
    st.subheader("Sample ECG Beat")
    fig, ax = plt.subplots()
    ax.plot(beats[0].flatten(), label="ECG Beat")
    ax.legend()
    st.pyplot(fig)
    # pyplot keeps every figure open until it is closed; free it once rendered.
    plt.close(fig)

    # ---------------- Grad-CAM Visualization Section ----------------
    st.subheader("Grad-CAM Heatmap Visualization for Each Beat")
    st.write("Below are Grad-CAM heatmaps overlaying each beat. The heatmaps show the regions contributing most to the predicted class.")

    # Automatically detect the last convolutional layer name
    conv_layer_name = get_last_conv_layer_name(model)
    if conv_layer_name is not None:
        st.write(f"Using Conv1D layer: **{conv_layer_name}** for Grad-CAM.")

        # Grad-CAM needs one gradient pass per beat, so only the first few beats
        # are rendered unless the user explicitly asks for all of them.
        show_all = st.checkbox("Show Grad-CAM for all beats", value=False)
        if show_all:
            num_beats_to_show = len(beats)
        else:
            # number_input rejects a default above max_value, so clamp the
            # default for records with fewer than 5 beats.
            num_beats_to_show = st.number_input(
                "Number of beats to show:",
                min_value=1,
                max_value=len(beats),
                value=min(5, len(beats)),
            )

        # Generate a Grad-CAM overlay for each selected beat.
        for idx in range(int(num_beats_to_show)):
            beat = beats[idx]
            pred_class = predicted_classes[idx]
            predicted_label = _class_label(pred_class)

            heatmap = make_gradcam_heatmap(beat, model, conv_layer_name, pred_class)

            fig, ax = plt.subplots(figsize=(10, 3))
            # Raw beat as a line, with points coloured by Grad-CAM intensity.
            ax.plot(beat.flatten(), color="black", label="ECG Beat")
            sc = ax.scatter(np.arange(len(beat)), beat.flatten(), c=heatmap, cmap="jet", s=25)
            ax.set_title(f"Beat {idx} - Predicted: {predicted_label}")
            ax.set_xlabel("Time Index")
            ax.set_ylabel("Amplitude")
            fig.colorbar(sc, ax=ax, label="Grad-CAM Intensity")
            st.pyplot(fig)
            # Without closing, figures from this loop (and every rerun) pile up
            # in pyplot's figure registry.
            plt.close(fig)