"""
HeartWatch AI Visualization Module

This module provides visualization functions for ECG analysis including:
- 12-lead ECG waveform plotting with clinical layout
- Diagnosis probability bar charts
- Risk assessment gauges
- ECG thumbnail generation for galleries
"""

import io

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Wedge
from PIL import Image

# Standard 12-lead ECG names in clinical order
LEAD_NAMES = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF',
              'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Clinical layout: 4 columns x 3 rows
# Col 1: I, II, III | Col 2: aVR, aVL, aVF | Col 3: V1, V2, V3 | Col 4: V4, V5, V6
LEAD_LAYOUT = [
    ['I', 'aVR', 'V1', 'V4'],
    ['II', 'aVL', 'V2', 'V5'],
    ['III', 'aVF', 'V3', 'V6'],
]

# Seconds of signal shown in each column of the 4x3 layout.
_SECONDS_PER_COLUMN = 2.5


def _ensure_lead_major(ecg_signal: np.ndarray) -> np.ndarray:
    """
    Return ``ecg_signal`` as a 2-D array of shape (12, n_samples).

    Input of shape (12, n_samples) is returned as-is; input of shape
    (n_samples, 12) is transposed. Array-likes are converted with ``np.asarray``.

    Raises
    ------
    ValueError
        If the input is not 2-D or neither axis has length 12.
    """
    ecg_signal = np.asarray(ecg_signal)
    # Check ndim first so 1-D input raises ValueError instead of IndexError on shape[1].
    if ecg_signal.ndim == 2:
        if ecg_signal.shape[0] == 12:
            return ecg_signal
        if ecg_signal.shape[1] == 12:
            return ecg_signal.T
    raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")


def _ecg_y_limits(signal_segment: np.ndarray) -> tuple:
    """
    Compute (y_min, y_max, major_tick_spacing) for one lead panel.

    The view spans the finite samples of ``signal_segment`` plus 10% padding on
    each side, with 4 major tick divisions across the signal range. Segments that
    are flat (range < 0.1), empty, or contain no finite samples get a fixed
    2.0-unit window centred on the signal (or on 0 when nothing is finite).
    """
    finite = signal_segment[np.isfinite(signal_segment)]
    if finite.size == 0:
        signal_min = signal_max = 0.0
    else:
        signal_min, signal_max = float(finite.min()), float(finite.max())

    signal_range = signal_max - signal_min
    if signal_range < 0.1:
        # Too flat to autoscale: centre a default 2.0-unit window on the signal so
        # the fallback range is actually applied to the limits, not only the ticks.
        center = (signal_min + signal_max) / 2
        signal_range = 2.0
        signal_min = center - signal_range / 2
        signal_max = center + signal_range / 2

    padding = signal_range * 0.1
    return signal_min - padding, signal_max + padding, signal_range / 4


def plot_ecg_waveform(ecg_signal: np.ndarray, sample_rate: int = 250,
                      title: str = "12-Lead ECG") -> plt.Figure:
    """
    Plot a 12-lead ECG waveform in clinical layout format.

    Leads are placed in a 3-row x 4-column grid following ``LEAD_LAYOUT``. When
    the record contains at least 4 x 2.5 s of samples, each column shows the next
    consecutive 2.5 s window (0-2.5 s, 2.5-5 s, 5-7.5 s, 7.5-10 s), as on a
    printed clinical ECG. Shorter records show their first 2.5 s (or less) in
    every column.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal array of shape (12, n_samples) or (n_samples, 12).
        Leads are indexed in ``LEAD_NAMES`` order.
    sample_rate : int, optional
        Sampling rate in Hz, default 250
    title : str, optional
        Figure title, default "12-Lead ECG"

    Returns
    -------
    plt.Figure
        Matplotlib figure with 4x3 ECG layout

    Raises
    ------
    ValueError
        If the signal does not have a 12-lead axis, or ``sample_rate`` is not positive.
    """
    ecg_signal = _ensure_lead_major(ecg_signal)
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")

    n_samples = ecg_signal.shape[1]
    samples_per_col = int(_SECONDS_PER_COLUMN * sample_rate)
    n_rows, n_cols = len(LEAD_LAYOUT), len(LEAD_LAYOUT[0])
    # Only use consecutive time windows when every column gets a full window;
    # otherwise later columns would be empty, so each column starts at t=0.
    sequential = n_samples >= n_cols * samples_per_col

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 8))
    fig.suptitle(title, fontsize=14, fontweight='bold', y=0.98)

    lead_to_idx = {name: i for i, name in enumerate(LEAD_NAMES)}
    major_x_offsets = np.arange(0, _SECONDS_PER_COLUMN + 0.1, 0.5)
    minor_x_offsets = np.arange(0, _SECONDS_PER_COLUMN + 0.1, 0.1)

    for row, lead_row in enumerate(LEAD_LAYOUT):
        for col, lead_name in enumerate(lead_row):
            ax = axes[row, col]
            lead_idx = lead_to_idx[lead_name]

            # Signal window for this column.
            start_sample = col * samples_per_col if sequential else 0
            end_sample = min(start_sample + samples_per_col, n_samples)
            signal_segment = ecg_signal[lead_idx, start_sample:end_sample]
            t0 = start_sample / sample_rate
            time_segment = t0 + np.arange(len(signal_segment)) / sample_rate

            # ECG-paper style background and grid (pink/red).
            ax.set_facecolor('#fff5f5')
            ax.set_axisbelow(True)
            ax.grid(True, which='major', color='#ffcccc', linewidth=0.8, linestyle='-')
            ax.grid(True, which='minor', color='#ffe6e6', linewidth=0.4, linestyle='-')

            # Time grid: major every 0.5 s, minor every 0.1 s.
            ax.set_xticks(t0 + major_x_offsets)
            ax.set_xticks(t0 + minor_x_offsets, minor=True)

            # Amplitude grid is data-dependent: 4 major / 20 minor divisions across the view.
            y_min, y_max, y_tick_spacing = _ecg_y_limits(signal_segment)
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing, y_tick_spacing))
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing / 5, y_tick_spacing / 5),
                          minor=True)

            ax.plot(time_segment, signal_segment, color='black', linewidth=0.8)

            # Lead label in the top-left corner of the panel.
            ax.text(0.02, 0.98, lead_name, transform=ax.transAxes, fontsize=10,
                    fontweight='bold', verticalalignment='top',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white',
                              edgecolor='none', alpha=0.7))

            ax.set_xlim(t0, t0 + _SECONDS_PER_COLUMN)
            ax.set_ylim(y_min, y_max)

            # Tick labels only on the bottom row and left column for a cleaner look.
            if row < n_rows - 1:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel('Time (s)', fontsize=8)
            if col > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel('Amplitude (mV)', fontsize=8)
            ax.tick_params(axis='both', which='both', labelsize=6)

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    return fig


def _confidence_color(prob: float) -> str:
    """Map a probability to its bar colour: green >= 0.7, yellow >= 0.3, gray otherwise."""
    if prob >= 0.7:
        return '#2ecc71'  # Green for high confidence
    if prob >= 0.3:
        return '#f1c40f'  # Yellow for moderate
    return '#95a5a6'      # Gray for low confidence


def plot_diagnosis_bars(diagnosis_77: dict, top_n: int = 10,
                        ground_truth: list = None) -> plt.Figure:
    """
    Plot horizontal bar chart of diagnosis probabilities.

    Parameters
    ----------
    diagnosis_77 : dict
        Dictionary mapping diagnosis names to probabilities (0-1)
    top_n : int, optional
        Number of top diagnoses to display, default 10
    ground_truth : list, optional
        List of ground truth diagnosis names to mark with star

    Returns
    -------
    plt.Figure
        Matplotlib figure with horizontal bar chart. The title reports how many
        diagnoses are actually shown (at most ``top_n``).
    """
    ground_truth_set = set(ground_truth) if ground_truth else set()

    # Highest probabilities first.
    top_diagnoses = sorted(diagnosis_77.items(), key=lambda item: item[1],
                           reverse=True)[:top_n]
    names = [name for name, _ in top_diagnoses]
    probs = [prob for _, prob in top_diagnoses]
    colors = [_confidence_color(p) for p in probs]

    fig, ax = plt.subplots(figsize=(8, 6))

    y_pos = np.arange(len(names))
    bars = ax.barh(y_pos, probs, color=colors, edgecolor='black', linewidth=0.5)

    # Probability labels: outside short bars in black, inside long bars in white.
    for bar, prob in zip(bars, probs):
        width = bar.get_width()
        inside = width >= 0.85
        label_x = width - 0.08 if inside else width + 0.02
        ax.text(label_x, bar.get_y() + bar.get_height() / 2, f'{prob:.1%}',
                va='center', fontsize=9, color='white' if inside else 'black')

    # Mark ground-truth diagnoses with a star.
    display_names = [f'{name} \u2605' if name in ground_truth_set else name
                     for name in names]
    ax.set_yticks(y_pos)
    ax.set_yticklabels(display_names, fontsize=9)

    ax.set_xlim(0, 1.0)
    ax.set_xlabel('Probability', fontsize=11)
    ax.set_title('Diagnosis Probabilities (Top {})'.format(len(names)),
                 fontsize=12, fontweight='bold', pad=10)

    legend_elements = [
        mpatches.Patch(facecolor='#2ecc71', edgecolor='black', label='High (\u2265 70%)'),
        mpatches.Patch(facecolor='#f1c40f', edgecolor='black', label='Moderate (30-70%)'),
        mpatches.Patch(facecolor='#95a5a6', edgecolor='black', label='Low (< 30%)'),
    ]
    if ground_truth:
        legend_elements.append(mpatches.Patch(facecolor='white', edgecolor='white',
                                              label='\u2605 = Ground Truth'))
    ax.legend(handles=legend_elements, loc='lower right', fontsize=8)

    ax.xaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)

    # Bars are drawn bottom-up; invert so the highest probability is at the top.
    ax.invert_yaxis()

    fig.tight_layout()
    return fig


def _gauge_color(pos: float) -> str:
    """
    Return the hex colour of the gauge arc at fractional position ``pos`` in [0, 1).

    Solid green below 0.3, green->yellow ramp from 0.3 to 0.6, yellow->red
    ramp from 0.6 upwards.
    """
    if pos < 0.3:
        return '#2ecc71'
    if pos < 0.6:
        t = (pos - 0.3) / 0.3
        r = int(46 + t * (241 - 46))
        g = int(204 + t * (196 - 204))
        b = int(113 + t * (15 - 113))
    else:
        t = (pos - 0.6) / 0.4
        r = int(241 + t * (231 - 241))
        g = int(196 - t * 196)
        b = int(15 - t * 15)
    return f'#{r:02x}{g:02x}{b:02x}'


def _draw_gauge(ax, value: float, title: str):
    """
    Draw a semicircular gauge on the given axes.

    The arc runs from 180 degrees (value 0, left) to 0 degrees (value 1, right).
    The needle points at ``value``, with threshold markers at 30% and 60%.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on (cleared first)
    value : float
        Value between 0 and 1 to display
    title : str
        Gauge title
    """
    ax.clear()
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.3, 1.3)
    ax.set_aspect('equal')
    ax.axis('off')

    # Gradient background arc (green -> yellow -> red) made of thin wedges.
    n_segments = 100
    step = 180 / n_segments
    for i in range(n_segments):
        theta1 = 180 - i * step
        theta2 = 180 - (i + 1) * step
        ax.add_patch(Wedge((0, 0), 1.0, theta2, theta1, width=0.3,
                           facecolor=_gauge_color(i / n_segments),
                           edgecolor='white', linewidth=0.5))

    # Needle from the centre towards the value's angle.
    needle_rad = np.radians(180 - value * 180)
    needle_length = 0.85
    ax.annotate('', xy=(needle_length * np.cos(needle_rad), needle_length * np.sin(needle_rad)),
                xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color='#2c3e50', lw=2))

    ax.add_patch(plt.Circle((0, 0), 0.1, color='#2c3e50', zorder=5))

    ax.text(0, -0.15, f'{value*100:.0f}%', ha='center', va='top',
            fontsize=14, fontweight='bold', color='#2c3e50')
    ax.text(0, 1.2, title, ha='center', va='bottom',
            fontsize=11, fontweight='bold', color='#2c3e50')

    # Zone labels.
    ax.text(-1.1, -0.05, 'Low', ha='center', va='top', fontsize=8, color='#27ae60')
    ax.text(0, 1.05, 'Moderate', ha='center', va='bottom', fontsize=8, color='#f39c12')
    ax.text(1.1, -0.05, 'High', ha='center', va='top', fontsize=8, color='#c0392b')

    # Threshold markers just outside the arc.
    for pct, label in [(0.3, '30%'), (0.6, '60%')]:
        rad = np.radians(180 - pct * 180)
        ax.text(1.05 * np.cos(rad), 1.05 * np.sin(rad), label,
                ha='center', va='center', fontsize=7, color='#7f8c8d')


def plot_risk_gauges(lvef_40: float, lvef_50: float, afib_5y: float) -> plt.Figure:
    """
    Plot risk assessment gauges for LVEF and AFib predictions.

    Parameters
    ----------
    lvef_40 : float
        Probability (0-1) of LVEF < 40%
    lvef_50 : float
        Probability (0-1) of LVEF < 50%
    afib_5y : float
        Probability (0-1) of AFib within 5 years

    Returns
    -------
    plt.Figure
        Matplotlib figure with 3 semicircular gauges. Inputs are clipped to [0, 1].
    """
    gauges = [
        (np.clip(lvef_40, 0, 1), 'LVEF < 40%'),
        (np.clip(lvef_50, 0, 1), 'LVEF < 50%'),
        (np.clip(afib_5y, 0, 1), 'AFib (5-year)'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    fig.suptitle('Risk Assessment', fontsize=14, fontweight='bold', y=0.98)

    for ax, (value, gauge_title) in zip(axes, gauges):
        _draw_gauge(ax, value, gauge_title)

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    return fig


def generate_thumbnail(ecg_signal: np.ndarray, label: str,
                       sample_rate: int = 250) -> Image.Image:
    """
    Generate a thumbnail preview image of Lead II for gallery display.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal array of shape (12, n_samples) or (n_samples, 12),
        leads indexed in ``LEAD_NAMES`` order
    label : str
        Label text to display on thumbnail
    sample_rate : int, optional
        Sampling rate in Hz, default 250

    Returns
    -------
    PIL.Image.Image
        Thumbnail image resized to exactly 300x150 pixels

    Raises
    ------
    ValueError
        If the signal does not have a 12-lead axis.
    """
    ecg_signal = _ensure_lead_major(ecg_signal)

    lead_ii = ecg_signal[LEAD_NAMES.index('II'), :]
    time = np.arange(len(lead_ii)) / sample_rate

    # figsize x dpi gives roughly 300x150 px before the final resize.
    fig, ax = plt.subplots(figsize=(3, 1.5), dpi=100)
    try:
        ax.plot(time, lead_ii, color='#e74c3c', linewidth=1.0)

        ax.set_facecolor('#fafafa')
        fig.patch.set_facecolor('#fafafa')

        # Minimal look: no ticks, no spines.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        ax.text(0.02, 0.98, label, transform=ax.transAxes, fontsize=8,
                fontweight='bold', verticalalignment='top', color='#2c3e50')
        ax.text(0.98, 0.02, 'Lead II', transform=ax.transAxes, fontsize=6,
                verticalalignment='bottom', horizontalalignment='right',
                color='#7f8c8d')

        fig.tight_layout(pad=0.2)

        buf = io.BytesIO()
        fig.savefig(buf, format='png', facecolor=fig.get_facecolor(),
                    edgecolor='none', bbox_inches='tight', pad_inches=0.05)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)

    # bbox_inches='tight' changes the pixel size, so force the gallery size.
    return img.resize((300, 150), Image.Resampling.LANCZOS)


if __name__ == '__main__':
    # Quick test
    print("Visualization module loaded successfully.")
    print("Available functions: plot_ecg_waveform, plot_diagnosis_bars, plot_risk_gauges, generate_thumbnail")