Spaces:
Sleeping
Sleeping
| """ | |
| HeartWatch AI Visualization Module | |
| This module provides visualization functions for ECG analysis including: | |
| - 12-lead ECG waveform plotting with clinical layout | |
| - Diagnosis probability bar charts | |
| - Risk assessment gauges | |
| - ECG thumbnail generation for galleries | |
| """ | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib.patches import Wedge | |
| from PIL import Image | |
| import io | |
# The twelve standard ECG leads in conventional clinical order: limb leads
# (I, II, III), augmented limb leads (aVR, aVL, aVF), then precordial leads
# (V1-V6). A lead's position in this list is the row index used for it in a
# lead-major (12, n_samples) signal array.
LEAD_NAMES = [
    'I', 'II', 'III',
    'aVR', 'aVL', 'aVF',
    'V1', 'V2', 'V3',
    'V4', 'V5', 'V6',
]

# Printed-report grid of 3 rows x 4 columns. Each column is one consecutive
# group of three leads from LEAD_NAMES, so row r is every third lead starting
# at offset r:
#   row 0: I,   aVR, V1, V4
#   row 1: II,  aVL, V2, V5
#   row 2: III, aVF, V3, V6
LEAD_LAYOUT = [LEAD_NAMES[offset::3] for offset in range(3)]
def plot_ecg_waveform(ecg_signal: np.ndarray, sample_rate: int = 250,
                      title: str = "12-Lead ECG") -> plt.Figure:
    """
    Plot a 12-lead ECG in the 3-row x 4-column clinical layout.

    Panels are arranged according to ``LEAD_LAYOUT``. Every panel shows the
    first 2.5 seconds of its lead (all panels share the same time window) on
    a pale pink, ECG-paper-style grid. Each panel's y-axis is scaled to that
    lead's own amplitude range.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal of shape (12, n_samples) or (n_samples, 12), with leads in
        ``LEAD_NAMES`` order. Any array-like input is converted with
        ``np.asarray``.
    sample_rate : int, optional
        Sampling rate in Hz, default 250.
    title : str, optional
        Figure title, default "12-Lead ECG".

    Returns
    -------
    plt.Figure
        Figure holding a 3x4 grid of lead panels. It is not closed here; the
        caller owns it.

    Raises
    ------
    ValueError
        If the signal is not 2-D with an axis of length 12, contains no
        samples, or ``sample_rate`` is not positive.
    """
    ecg_signal = np.asarray(ecg_signal)
    # Normalise to lead-major (12, n_samples). Checking ndim first turns a
    # 1-D input into a clear ValueError rather than an IndexError on shape[1].
    if ecg_signal.ndim != 2:
        raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")
    if ecg_signal.shape[0] != 12:
        if ecg_signal.shape[1] == 12:
            ecg_signal = ecg_signal.T
        else:
            raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")
    n_samples = ecg_signal.shape[1]
    if n_samples == 0:
        raise ValueError("ECG signal contains no samples")
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")

    # 2.5 s per panel. Keep at least one sample so the per-panel min/max
    # below never runs on an empty slice (e.g. a tiny fractional sample_rate).
    samples_per_col = int(2.5 * sample_rate)
    end_sample = max(1, min(samples_per_col, n_samples))

    fig, axes = plt.subplots(3, 4, figsize=(14, 8))
    fig.suptitle(title, fontsize=14, fontweight='bold', y=0.98)

    lead_to_idx = {name: i for i, name in enumerate(LEAD_NAMES)}

    for row, layout_row in enumerate(LEAD_LAYOUT):
        for col, lead_name in enumerate(layout_row):
            ax = axes[row, col]
            signal_segment = ecg_signal[lead_to_idx[lead_name], :end_sample]
            time_segment = np.arange(len(signal_segment)) / sample_rate

            # ECG-paper look: pale pink background, grid drawn beneath the trace.
            ax.set_facecolor('#fff5f5')
            ax.set_axisbelow(True)
            ax.grid(True, which='major', color='#ffcccc', linewidth=0.8, linestyle='-')
            ax.grid(True, which='minor', color='#ffe6e6', linewidth=0.4, linestyle='-')

            # x grid: major every 0.5 s, minor every 0.1 s.
            ax.set_xticks(np.arange(0, 2.6, 0.5))
            ax.set_xticks(np.arange(0, 2.6, 0.1), minor=True)

            # Scale from finite samples only: NaN/inf (e.g. lead dropout) would
            # make set_ylim/np.arange raise. An all-non-finite segment is
            # treated like a flat line at 0.
            finite = signal_segment[np.isfinite(signal_segment)]
            if finite.size:
                signal_min, signal_max = float(finite.min()), float(finite.max())
            else:
                signal_min = signal_max = 0.0
            signal_range = signal_max - signal_min
            if signal_range < 0.1:
                # Near-flat lead: show a fixed 2.0-unit window centred on the
                # trace. Re-deriving min/max here is what makes the default
                # range actually take effect on the axis limits.
                center = (signal_min + signal_max) / 2
                signal_min, signal_max = center - 1.0, center + 1.0
                signal_range = 2.0
            padding = signal_range * 0.1
            y_min = signal_min - padding
            y_max = signal_max + padding

            # y grid: major every quarter of the lead's range, minor every
            # fifth of that (relative to the lead, not fixed mV steps).
            y_tick_spacing = signal_range / 4
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing, y_tick_spacing))
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing / 5, y_tick_spacing / 5),
                          minor=True)

            ax.plot(time_segment, signal_segment, color='black', linewidth=0.8)

            # Lead name in the top-left corner on a translucent white box.
            ax.text(0.02, 0.98, lead_name, transform=ax.transAxes,
                    fontsize=10, fontweight='bold', verticalalignment='top',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white',
                              edgecolor='none', alpha=0.7))

            ax.set_xlim(0, 2.5)
            ax.set_ylim(y_min, y_max)

            # Tick labels only on the bottom row (x) and left column (y).
            if row < 2:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel('Time (s)', fontsize=8)
            if col > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel('Amplitude (mV)', fontsize=8)
            ax.tick_params(axis='both', which='both', labelsize=6)

    # Leave the top 4% of the figure for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    return fig
def plot_diagnosis_bars(diagnosis_77: dict, top_n: int = 10,
                        ground_truth: list = None) -> plt.Figure:
    """
    Plot a horizontal bar chart of the highest-probability diagnoses.

    Bars are colour-coded by probability: green for >= 0.7, yellow for
    >= 0.3, grey otherwise. Diagnoses listed in ``ground_truth`` get a star
    appended to their axis label. The highest probability is drawn at the top.

    Parameters
    ----------
    diagnosis_77 : dict
        Mapping of diagnosis name -> probability (expected in [0, 1]).
    top_n : int, optional
        Maximum number of diagnoses to display, default 10. Negative values
        are treated as 0.
    ground_truth : list, optional
        Diagnosis names to mark with a star.

    Returns
    -------
    plt.Figure
        Figure containing the bar chart. It is not closed here; the caller
        owns it.
    """
    # Set membership is O(1); `or ()` covers the None default.
    truth = set(ground_truth or ())

    # Highest probability first. sorted() stays stable with reverse=True, so
    # ties keep their dict order.
    ranked = sorted(diagnosis_77.items(), key=lambda item: item[1], reverse=True)
    # Clamp so a negative top_n shows nothing rather than "all but the last k".
    top_diagnoses = ranked[:max(top_n, 0)]

    names = [name for name, _ in top_diagnoses]
    probs = [prob for _, prob in top_diagnoses]

    def bar_color(p):
        # Thresholds must stay in sync with the legend entries below.
        if p >= 0.7:
            return '#2ecc71'  # green: high confidence
        if p >= 0.3:
            return '#f1c40f'  # yellow: moderate
        return '#95a5a6'      # grey: low confidence

    colors = [bar_color(p) for p in probs]

    fig, ax = plt.subplots(figsize=(8, 6))

    y_pos = np.arange(len(names))
    bars = ax.barh(y_pos, probs, color=colors, edgecolor='black', linewidth=0.5)

    # Percentage labels: to the right of shorter bars in black. Bars of width
    # >= 0.85 get the label inside in white so it stays within the 0-1 x-range.
    for bar, prob in zip(bars, probs):
        width = bar.get_width()
        outside = width < 0.85
        label_x = width + 0.02 if outside else width - 0.08
        ax.text(label_x, bar.get_y() + bar.get_height() / 2,
                f'{prob:.1%}', va='center', fontsize=9,
                color='black' if outside else 'white')

    # Star the ground-truth diagnoses in the axis labels.
    display_names = [f'{name} \u2605' if name in truth else name for name in names]
    ax.set_yticks(y_pos)
    ax.set_yticklabels(display_names, fontsize=9)

    ax.set_xlim(0, 1.0)
    ax.set_xlabel('Probability', fontsize=11)
    # Report how many bars are actually shown; this can be fewer than top_n
    # when the dict is small.
    ax.set_title(f'Diagnosis Probabilities (Top {len(names)})',
                 fontsize=12, fontweight='bold', pad=10)

    legend_elements = [
        mpatches.Patch(facecolor='#2ecc71', edgecolor='black', label='High (\u2265 70%)'),
        mpatches.Patch(facecolor='#f1c40f', edgecolor='black', label='Moderate (30-70%)'),
        mpatches.Patch(facecolor='#95a5a6', edgecolor='black', label='Low (< 30%)'),
    ]
    if truth:
        # Invisible swatch: this entry only explains the star marker.
        legend_elements.append(mpatches.Patch(facecolor='white', edgecolor='white',
                                              label='\u2605 = Ground Truth'))
    ax.legend(handles=legend_elements, loc='lower right', fontsize=8)

    ax.xaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)

    # barh puts index 0 at the bottom; invert so the top-ranked bar is on top.
    ax.invert_yaxis()

    fig.tight_layout()
    return fig
def _draw_gauge(ax, value: float, title: str):
    """
    Render a half-circle gauge with a needle onto ``ax``.

    The axes are cleared, given an equal aspect ratio and hidden. The arc
    sweeps from 180 degrees (value 0, left) to 0 degrees (value 1, right) and
    is painted in 100 slices: solid green below 30%, a green-to-yellow blend
    from 30% to 60%, and a yellow-to-red blend above 60%.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes; any existing content is removed.
    value : float
        Fraction that positions the needle and the percentage readout. It is
        not clipped here, so callers should pass a value in [0, 1].
    title : str
        Text shown above the arc.
    """
    ax.clear()
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.3, 1.3)
    ax.set_aspect('equal')
    ax.axis('off')

    def zone_color(frac):
        # Piecewise colour ramp along the arc; int() truncates each channel.
        if frac < 0.3:
            return '#2ecc71'
        if frac < 0.6:
            t = (frac - 0.3) / 0.3
            rgb = (int(46 + t * (241 - 46)),
                   int(204 + t * (196 - 204)),
                   int(113 + t * (15 - 113)))
        else:
            t = (frac - 0.6) / 0.4
            rgb = (int(241 + t * (231 - 241)),
                   int(196 - t * 196),
                   int(15 - t * 15))
        return '#{:02x}{:02x}{:02x}'.format(*rgb)

    # Background arc, drawn slice by slice from the left end to the right end.
    n_slices = 100
    step = 180 / n_slices
    for k in range(n_slices):
        start_deg = 180 - (k + 1) * step
        end_deg = 180 - k * step
        ax.add_patch(Wedge((0, 0), 1.0, start_deg, end_deg, width=0.3,
                           facecolor=zone_color(k / n_slices),
                           edgecolor='white', linewidth=0.5))

    # Needle: value 0 points left (180 deg), value 1 points right (0 deg).
    angle_rad = np.radians(180 - value * 180)
    tip = (0.85 * np.cos(angle_rad), 0.85 * np.sin(angle_rad))
    ax.annotate('', xy=tip, xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color='#2c3e50', lw=2))

    # Hub over the needle's base.
    ax.add_patch(plt.Circle((0, 0), 0.1, color='#2c3e50', zorder=5))

    # Percentage readout below the hub and title above the arc.
    ax.text(0, -0.15, f'{value*100:.0f}%', ha='center', va='top',
            fontsize=14, fontweight='bold', color='#2c3e50')
    ax.text(0, 1.2, title, ha='center', va='bottom',
            fontsize=11, fontweight='bold', color='#2c3e50')

    # Zone captions at the left end, top, and right end of the arc.
    captions = (
        (-1.1, -0.05, 'Low', 'top', '#27ae60'),
        (0, 1.05, 'Moderate', 'bottom', '#f39c12'),
        (1.1, -0.05, 'High', 'top', '#c0392b'),
    )
    for x, y, text, vert, color in captions:
        ax.text(x, y, text, ha='center', va=vert, fontsize=8, color=color)

    # Percentage markers just outside the arc at the colour-zone boundaries.
    for frac, text in ((0.3, '30%'), (0.6, '60%')):
        rad = np.radians(180 - frac * 180)
        ax.text(1.05 * np.cos(rad), 1.05 * np.sin(rad), text,
                ha='center', va='center', fontsize=7, color='#7f8c8d')
def plot_risk_gauges(lvef_40: float, lvef_50: float, afib_5y: float) -> plt.Figure:
    """
    Draw three side-by-side risk gauges in a single figure.

    Left to right, the gauges show P(LVEF < 40%), P(LVEF < 50%) and the
    probability of AFib within 5 years. Each input is clipped to [0, 1] with
    ``np.clip`` before drawing; note that ``np.clip`` leaves NaN unchanged.

    Parameters
    ----------
    lvef_40 : float
        Probability (0-1) of LVEF < 40%.
    lvef_50 : float
        Probability (0-1) of LVEF < 50%.
    afib_5y : float
        Probability (0-1) of AFib within 5 years.

    Returns
    -------
    plt.Figure
        Figure with three semicircular gauges under a "Risk Assessment" title.
    """
    # Clip every value before any figure exists, so a bad input raises
    # without leaving an open figure behind.
    gauges = [
        (np.clip(lvef_40, 0, 1), 'LVEF < 40%'),
        (np.clip(lvef_50, 0, 1), 'LVEF < 50%'),
        (np.clip(afib_5y, 0, 1), 'AFib (5-year)'),
    ]

    fig, axes = plt.subplots(1, len(gauges), figsize=(14, 4))
    fig.suptitle('Risk Assessment', fontsize=14, fontweight='bold', y=0.98)

    for ax, (prob, label) in zip(axes, gauges):
        _draw_gauge(ax, prob, label)

    # Reserve the top 5% of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig
def generate_thumbnail(ecg_signal: np.ndarray, label: str,
                       sample_rate: int = 250) -> Image.Image:
    """
    Render a small Lead II preview image for gallery display.

    The whole Lead II trace is drawn in red on a light background with no
    axes, with ``label`` in the top-left corner and a "Lead II" tag in the
    bottom-right corner.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal of shape (12, n_samples) or (n_samples, 12), with leads in
        ``LEAD_NAMES`` order (Lead II at index 1). Any array-like input is
        converted with ``np.asarray``.
    label : str
        Text displayed on the thumbnail.
    sample_rate : int, optional
        Sampling rate in Hz, default 250. It only scales the (hidden) x-axis.

    Returns
    -------
    PIL.Image.Image
        The rendered PNG, resized to exactly 300x150 pixels. The final resize
        does not preserve the aspect ratio of the tight-cropped render.

    Raises
    ------
    ValueError
        If the signal is not 2-D with an axis of length 12.
    """
    ecg_signal = np.asarray(ecg_signal)
    # Normalise to lead-major (12, n_samples). A 1-D input gets a clear
    # ValueError instead of an IndexError on shape[1].
    if ecg_signal.ndim != 2:
        raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")
    if ecg_signal.shape[0] != 12:
        if ecg_signal.shape[1] == 12:
            ecg_signal = ecg_signal.T
        else:
            raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")

    lead_ii = ecg_signal[1, :]
    t_axis = np.arange(len(lead_ii)) / sample_rate

    # 3 x 1.5 inches at 100 dpi gives roughly 300x150 px before cropping.
    fig, ax = plt.subplots(figsize=(3, 1.5), dpi=100)
    buf = io.BytesIO()
    # pyplot keeps a reference to every figure until it is closed, so close it
    # even if drawing or saving raises; otherwise repeated gallery builds leak.
    try:
        ax.plot(t_axis, lead_ii, color='#e74c3c', linewidth=1.0)

        ax.set_facecolor('#fafafa')
        fig.patch.set_facecolor('#fafafa')

        # Minimal look: no ticks, no frame.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        ax.text(0.02, 0.98, label, transform=ax.transAxes,
                fontsize=8, fontweight='bold', verticalalignment='top',
                color='#2c3e50')
        ax.text(0.98, 0.02, 'Lead II', transform=ax.transAxes,
                fontsize=6, verticalalignment='bottom', horizontalalignment='right',
                color='#7f8c8d')

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout(pad=0.2)
        fig.savefig(buf, format='png', facecolor=fig.get_facecolor(),
                    edgecolor='none', bbox_inches='tight', pad_inches=0.05)
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    # bbox_inches='tight' changes the pixel size, so force the gallery size.
    return img.resize((300, 150), Image.Resampling.LANCZOS)
if __name__ == '__main__':
    # Smoke check: reaching this point means the module and its imports
    # (numpy, matplotlib, PIL) loaded successfully.
    print("Visualization module loaded successfully.")
    # Plain string: the original f-string had no placeholders (ruff F541).
    print("Available functions: plot_ecg_waveform, plot_diagnosis_bars, "
          "plot_risk_gauges, generate_thumbnail")