HeartWatchAI / visualization.py
Ashkan Taghipour (The University of Western Australia)
Initial HeartWatch AI demo release
cd846d7
"""
HeartWatch AI Visualization Module
This module provides visualization functions for ECG analysis including:
- 12-lead ECG waveform plotting with clinical layout
- Diagnosis probability bar charts
- Risk assessment gauges
- ECG thumbnail generation for galleries
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Wedge
from PIL import Image
import io
# The twelve standard ECG leads in conventional clinical order: the limb
# leads, the augmented limb leads, then the precordial leads V1..V6.
LEAD_NAMES = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF'] + [f'V{k}' for k in range(1, 7)]
# Placement of the leads on the 3-row x 4-column clinical printout.
# Column c holds three consecutive entries of LEAD_NAMES:
#   column 0: I, II, III   | column 1: aVR, aVL, aVF
#   column 2: V1, V2, V3   | column 3: V4, V5, V6
# i.e. LEAD_LAYOUT[row][col] == LEAD_NAMES[3 * col + row].
LEAD_LAYOUT = [
    [LEAD_NAMES[3 * col + row] for col in range(4)]
    for row in range(3)
]
def plot_ecg_waveform(ecg_signal: np.ndarray, sample_rate: int = 250,
                      title: str = "12-Lead ECG") -> plt.Figure:
    """
    Plot a 12-lead ECG waveform in clinical layout format.

    Leads are placed on a 3-row x 4-column grid following ``LEAD_LAYOUT``.
    Each panel spans 2.5 seconds. When the recording is long enough to fill
    every column (4 x 2.5 s), column ``k`` shows the consecutive window
    starting at ``2.5 * k`` seconds, as on a standard printout. For shorter
    recordings every column shows the first 2.5 seconds instead, so no panel
    is ever empty.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal array of shape (12, n_samples) or (n_samples, 12).
        Each row/column represents one of the 12 standard leads, in the
        order of ``LEAD_NAMES``. Array-likes are converted with
        ``np.asarray``. The y-axis is labelled in mV; this assumes the
        signal is already in millivolts.
    sample_rate : int, optional
        Sampling rate in Hz, default 250
    title : str, optional
        Figure title, default "12-Lead ECG"

    Returns
    -------
    plt.Figure
        Matplotlib figure with 4x3 ECG layout. The figure is registered with
        pyplot; the caller should ``plt.close(fig)`` when done with it.

    Raises
    ------
    ValueError
        If the signal is not 2-D with a 12-lead axis, contains no samples,
        or ``sample_rate`` is too low to hold a single sample per panel.
    """
    ecg_signal = np.asarray(ecg_signal)
    # Normalize to lead-first layout (12, n_samples). Reject non-2-D input
    # up front; otherwise ``shape[1]`` would raise an unhelpful IndexError.
    if ecg_signal.ndim != 2:
        raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")
    if ecg_signal.shape[0] != 12:
        if ecg_signal.shape[1] == 12:
            ecg_signal = ecg_signal.T
        else:
            raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")

    n_samples = ecg_signal.shape[1]
    if n_samples == 0:
        raise ValueError("ECG signal contains no samples")

    window_s = 2.5  # seconds shown per column
    samples_per_col = int(window_s * sample_rate)
    if samples_per_col < 1:
        raise ValueError(
            f"sample_rate {sample_rate} is too low to draw a {window_s} s window")

    n_rows = len(LEAD_LAYOUT)
    n_cols = len(LEAD_LAYOUT[0])
    # Use consecutive windows only when the recording covers every column.
    consecutive_windows = n_samples >= n_cols * samples_per_col

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 8))
    fig.suptitle(title, fontsize=14, fontweight='bold', y=0.98)

    lead_to_idx = {name: i for i, name in enumerate(LEAD_NAMES)}

    for row, lead_row in enumerate(LEAD_LAYOUT):
        for col, lead_name in enumerate(lead_row):
            ax = axes[row, col]
            lead_idx = lead_to_idx[lead_name]

            # Pick this column's 2.5 s window.
            start_sample = col * samples_per_col if consecutive_windows else 0
            end_sample = min(start_sample + samples_per_col, n_samples)
            signal_segment = ecg_signal[lead_idx, start_sample:end_sample]
            # Absolute time, so bottom-row tick labels report the real window.
            t0 = start_sample / sample_rate
            time_segment = (start_sample + np.arange(len(signal_segment))) / sample_rate

            # ECG-paper style background (pink/red)
            ax.set_facecolor('#fff5f5')
            ax.set_axisbelow(True)
            ax.grid(True, which='major', color='#ffcccc', linewidth=0.8, linestyle='-')
            ax.grid(True, which='minor', color='#ffe6e6', linewidth=0.4, linestyle='-')

            # x grid: major every 0.5 s, minor every 0.1 s
            ax.set_xticks(t0 + np.arange(0, window_s + 0.1, 0.5))
            ax.set_xticks(t0 + np.arange(0, window_s + 0.1, 0.1), minor=True)

            # y-limits follow the segment's own range (not a calibrated
            # mV scale); a nearly flat trace gets a fixed default range.
            signal_min, signal_max = signal_segment.min(), signal_segment.max()
            signal_range = signal_max - signal_min
            if signal_range < 0.1:
                signal_range = 2.0
            padding = signal_range * 0.1
            y_min = signal_min - padding
            y_max = signal_max + padding

            # y grid: 4 major divisions of the range, 5 minor per major
            y_tick_spacing = signal_range / 4
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing, y_tick_spacing))
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing / 5, y_tick_spacing / 5),
                          minor=True)

            ax.plot(time_segment, signal_segment, color='black', linewidth=0.8)

            # Lead label in the top-left corner of the panel
            ax.text(0.02, 0.98, lead_name, transform=ax.transAxes,
                    fontsize=10, fontweight='bold', verticalalignment='top',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white',
                              edgecolor='none', alpha=0.7))

            ax.set_xlim(t0, t0 + window_s)
            ax.set_ylim(y_min, y_max)

            # Tick labels only on the bottom row and left column
            if row < n_rows - 1:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel('Time (s)', fontsize=8)
            if col > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel('Amplitude (mV)', fontsize=8)
            ax.tick_params(axis='both', which='both', labelsize=6)

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    return fig
def plot_diagnosis_bars(diagnosis_77: dict, top_n: int = 10,
                        ground_truth: list = None) -> plt.Figure:
    """
    Plot horizontal bar chart of diagnosis probabilities.

    Diagnoses are sorted by probability, highest at the top. Each bar is
    coloured by confidence band: >= 0.7 green, >= 0.3 yellow, otherwise gray.

    Parameters
    ----------
    diagnosis_77 : dict
        Dictionary mapping diagnosis names to probabilities (0-1)
    top_n : int, optional
        Maximum number of top diagnoses to display, default 10. Negative
        values are treated as 0.
    ground_truth : list, optional
        Ground truth diagnosis names. Those that appear among the displayed
        bars get a star appended to their label.

    Returns
    -------
    plt.Figure
        Matplotlib figure with horizontal bar chart. The title reports how
        many bars are actually shown, which is fewer than ``top_n`` when
        ``diagnosis_77`` has fewer entries.
    """
    # Set for O(1) membership checks when labelling bars.
    truth_names = set(ground_truth) if ground_truth is not None else set()
    # A negative slice bound would silently drop entries from the end.
    top_n = max(top_n, 0)

    ranked = sorted(diagnosis_77.items(), key=lambda item: item[1], reverse=True)
    top_diagnoses = ranked[:top_n]
    names = [name for name, _ in top_diagnoses]
    probs = [prob for _, prob in top_diagnoses]

    # Confidence-band colours
    colors = []
    for p in probs:
        if p >= 0.7:
            colors.append('#2ecc71')  # Green for high confidence
        elif p >= 0.3:
            colors.append('#f1c40f')  # Yellow for moderate
        else:
            colors.append('#95a5a6')  # Gray for low confidence

    fig, ax = plt.subplots(figsize=(8, 6))

    y_pos = np.arange(len(names))
    bars = ax.barh(y_pos, probs, color=colors, edgecolor='black', linewidth=0.5)

    # Percentage labels: outside short bars, inside (white) long bars
    for bar, prob in zip(bars, probs):
        width = bar.get_width()
        label_x = width + 0.02 if width < 0.85 else width - 0.08
        label_color = 'black' if width < 0.85 else 'white'
        ax.text(label_x, bar.get_y() + bar.get_height() / 2,
                f'{prob:.1%}', va='center', fontsize=9, color=label_color)

    # Star ground-truth diagnoses
    display_names = [f'{name} \u2605' if name in truth_names else name for name in names]

    ax.set_yticks(y_pos)
    ax.set_yticklabels(display_names, fontsize=9)

    ax.set_xlim(0, 1.0)
    ax.set_xlabel('Probability', fontsize=11)
    ax.set_title(f'Diagnosis Probabilities (Top {len(names)})',
                 fontsize=12, fontweight='bold', pad=10)

    legend_elements = [
        mpatches.Patch(facecolor='#2ecc71', edgecolor='black', label='High (\u2265 70%)'),
        mpatches.Patch(facecolor='#f1c40f', edgecolor='black', label='Moderate (30-70%)'),
        mpatches.Patch(facecolor='#95a5a6', edgecolor='black', label='Low (< 30%)'),
    ]
    if truth_names:
        legend_elements.append(mpatches.Patch(facecolor='white', edgecolor='white',
                                              label='\u2605 = Ground Truth'))
    ax.legend(handles=legend_elements, loc='lower right', fontsize=8)

    ax.xaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)

    # Highest probability at the top
    ax.invert_yaxis()

    fig.tight_layout()
    return fig
def _gauge_color(pos):
    """Return the hex fill colour of the gauge arc at fractional position *pos*.

    *pos* runs from 0 (left, low risk) towards 1 (right, high risk):
    solid green below 0.3, a linear green->yellow blend on [0.3, 0.6), and
    a linear yellow->red blend from 0.6 on. Channels are truncated to int.
    """
    if pos < 0.3:
        return '#2ecc71'
    if pos < 0.6:
        frac = (pos - 0.3) / 0.3
        channels = (46 + frac * (241 - 46),
                    204 + frac * (196 - 204),
                    113 + frac * (15 - 113))
    else:
        frac = (pos - 0.6) / 0.4
        channels = (241 + frac * (231 - 241),
                    196 - frac * 196,
                    15 - frac * 15)
    return '#' + ''.join(format(int(c), '02x') for c in channels)


def _draw_gauge(ax, value: float, title: str):
    """
    Draw a semicircular risk gauge on the given axes.

    The arc sweeps from 180 degrees (left, value 0) to 0 degrees (right,
    value 1). It is coloured green -> yellow -> red, with a needle pointing
    at ``value`` and the value printed as a percentage below the hub.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; cleared before drawing.
    value : float
        Value between 0 and 1 to display. It is not clamped here.
    title : str
        Gauge title drawn above the arc.
    """
    ax.clear()
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.3, 1.3)
    ax.set_aspect('equal')
    ax.axis('off')

    # Background: 100 thin annular wedges, left to right, each filled with
    # the colour for its position along the arc.
    segments = 100
    step = 180 / segments
    for seg in range(segments):
        upper = 180 - seg * step
        lower = 180 - (seg + 1) * step
        ax.add_patch(Wedge((0, 0), 1.0, lower, upper, width=0.3,
                           facecolor=_gauge_color(seg / segments),
                           edgecolor='white', linewidth=0.5))

    # Needle from the hub towards the value's angle
    needle_rad = np.radians(180 - value * 180)
    tip = (0.85 * np.cos(needle_rad), 0.85 * np.sin(needle_rad))
    ax.annotate('', xy=tip, xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color='#2c3e50', lw=2))

    # Hub drawn above the needle base
    ax.add_patch(plt.Circle((0, 0), 0.1, color='#2c3e50', zorder=5))

    ax.text(0, -0.15, f'{value*100:.0f}%', ha='center', va='top',
            fontsize=14, fontweight='bold', color='#2c3e50')
    ax.text(0, 1.2, title, ha='center', va='bottom',
            fontsize=11, fontweight='bold', color='#2c3e50')

    # Zone captions at the left end, top, and right end of the arc
    zone_captions = (
        (-1.1, -0.05, 'Low', 'top', '#27ae60'),
        (0, 1.05, 'Moderate', 'bottom', '#f39c12'),
        (1.1, -0.05, 'High', 'top', '#c0392b'),
    )
    for x, y, caption, vert, shade in zone_captions:
        ax.text(x, y, caption, ha='center', va=vert, fontsize=8, color=shade)

    # Threshold labels just outside the arc
    for pct, label in ((0.3, '30%'), (0.6, '60%')):
        rad = np.radians(180 - pct * 180)
        ax.text(1.05 * np.cos(rad), 1.05 * np.sin(rad), label,
                ha='center', va='center', fontsize=7, color='#7f8c8d')
def plot_risk_gauges(lvef_40: float, lvef_50: float, afib_5y: float) -> plt.Figure:
    """
    Plot risk assessment gauges for LVEF and AFib predictions.

    Each probability is clamped into [0, 1] and drawn as a semicircular
    gauge, left to right: LVEF < 40%, LVEF < 50%, AFib within 5 years.

    Parameters
    ----------
    lvef_40 : float
        Probability (0-1) of LVEF < 40%
    lvef_50 : float
        Probability (0-1) of LVEF < 50%
    afib_5y : float
        Probability (0-1) of AFib within 5 years

    Returns
    -------
    plt.Figure
        Matplotlib figure with 3 semicircular gauges
    """
    # Clamp before any figure is created so bad input fails without
    # leaving a half-built figure behind.
    panels = [
        (np.clip(lvef_40, 0, 1), 'LVEF < 40%'),
        (np.clip(lvef_50, 0, 1), 'LVEF < 50%'),
        (np.clip(afib_5y, 0, 1), 'AFib (5-year)'),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(14, 4))
    fig.suptitle('Risk Assessment', fontsize=14, fontweight='bold', y=0.98)

    for ax, (prob, caption) in zip(axes, panels):
        _draw_gauge(ax, prob, caption)

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    return fig
def generate_thumbnail(ecg_signal: np.ndarray, label: str,
                       sample_rate: int = 250) -> Image.Image:
    """
    Generate a thumbnail preview image of Lead II for gallery display.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal array of shape (12, n_samples) or (n_samples, 12).
        Lead II is taken as index 1 of the lead axis. Array-likes are
        converted with ``np.asarray``.
    label : str
        Label text to display on thumbnail
    sample_rate : int, optional
        Sampling rate in Hz, default 250

    Returns
    -------
    PIL.Image.Image
        Thumbnail image resized to exactly 300x150 pixels

    Raises
    ------
    ValueError
        If the signal is not 2-D with a 12-lead axis.
    """
    ecg_signal = np.asarray(ecg_signal)
    # Normalize to lead-first layout (12, n_samples). Reject non-2-D input
    # up front; otherwise ``shape[1]`` would raise an unhelpful IndexError.
    if ecg_signal.ndim != 2:
        raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")
    if ecg_signal.shape[0] != 12:
        if ecg_signal.shape[1] == 12:
            ecg_signal = ecg_signal.T
        else:
            raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")

    lead_ii = ecg_signal[1, :]  # Lead II is index 1 in clinical order
    t = np.arange(lead_ii.shape[0]) / sample_rate

    # 3 x 1.5 inches at 100 dpi, roughly 300x150 px before cropping
    fig, ax = plt.subplots(figsize=(3, 1.5), dpi=100)
    # Close the figure even if drawing/saving fails, so failed calls do not
    # leave figures open in pyplot's registry.
    try:
        ax.plot(t, lead_ii, color='#e74c3c', linewidth=1.0)

        ax.set_facecolor('#fafafa')
        fig.patch.set_facecolor('#fafafa')

        # Minimal look: no ticks, no spines
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        ax.text(0.02, 0.98, label, transform=ax.transAxes,
                fontsize=8, fontweight='bold', verticalalignment='top',
                color='#2c3e50')
        ax.text(0.98, 0.02, 'Lead II', transform=ax.transAxes,
                fontsize=6, verticalalignment='bottom', horizontalalignment='right',
                color='#7f8c8d')

        fig.tight_layout(pad=0.2)

        buf = io.BytesIO()
        fig.savefig(buf, format='png', facecolor=fig.get_facecolor(),
                    edgecolor='none', bbox_inches='tight', pad_inches=0.05)
    finally:
        plt.close(fig)

    buf.seek(0)
    # resize() returns a new, fully loaded image, so the source handle
    # can be released immediately.
    with Image.open(buf) as rendered:
        # bbox_inches='tight' crops, so force the exact gallery size
        img = rendered.resize((300, 150), Image.Resampling.LANCZOS)
    return img
if __name__ == '__main__':
    # Smoke check: reaching this point means the module and its
    # matplotlib/PIL imports loaded without error.
    print("Visualization module loaded successfully.")
    print("Available functions: plot_ecg_waveform, plot_diagnosis_bars, "
          "plot_risk_gauges, generate_thumbnail")