# Lower-Limb-Similarity-Analysis / plot_similarity.py
# jmontp's picture
# Major update
# c3a09bd
"""Visualization functions for similarity measures"""
from shared_styling import PLOT_COLORS, set_plot_style, get_plot_style
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import Optional, Union
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap
def plot_similarity_measure(measure_data: np.ndarray, ax: plt.Axes = None,
                            plot_type: str = 'input', cbar: bool = True,
                            task_x_name: str = '', task_y_name: str = '',
                            xlabel: Optional[str] = None, ylabel: Optional[str] = None,
                            cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
                            title: Optional[str] = None,
                            fontsize: int = 16,
                            y_label_pad: int = 20,
                            cbar_labels: bool = True,
                            high_level_plot: bool = False,
                            cutoff_treshold: float | None = None):
    """Plot a similarity measure as a heatmap with consistent styling.

    The data is flipped vertically before plotting so that 0% of the gait
    cycle sits at the bottom of the y-axis. Colors are clamped to [0, 1].

    Args:
        measure_data: 2D array of similarity measures.
        ax: matplotlib axis to draw on. A new figure/axis is created when
            None.
        plot_type: selects the default colormap from ``PLOT_COLORS``:
            'input', 'output' or 'output_biomechanical'; any other value
            falls back to the 'conflict' colormap. Also controls the
            colorbar annotation (skipped for 'output').
        cbar: whether to draw a colorbar to the right of the heatmap.
        task_x_name: name of the x-axis task (used if xlabel is None).
        task_y_name: name of the y-axis task (used if ylabel is None).
        xlabel: optional custom x-axis label, appended under
            'Gait Cycle (%)'.
        ylabel: optional custom y-axis label, placed above
            'Gait Cycle (%)'.
        cmap: optional colormap overriding the ``plot_type`` default.
        title: optional axis title.
        fontsize: font size for all text elements.
        y_label_pad: extra padding added to the styled y-label padding.
        cbar_labels: whether to show the '0.00' / '1.00' colorbar ticks.
        high_level_plot: accepted for interface compatibility; not read by
            the current implementation.
        cutoff_treshold: accepted for interface compatibility; not read by
            the current implementation. The colorbar annotation always
            shows the mean of ``measure_data`` as a percentage.

    Returns:
        The object returned by ``sns.heatmap`` for this plot.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Pick the default colormap for this plot type; unknown types use 'conflict'
    if cmap is None:
        palette_key = {
            'input': 'input_similarity',
            'output': 'output_difference',
            'output_biomechanical': 'output_biomechanical',
        }.get(plot_type, 'conflict')
        cmap = PLOT_COLORS[palette_key]

    # Fetch the shared style once instead of once per setting
    style = get_plot_style()

    # Create heatmap with flipped y-axis
    flipped = np.flipud(measure_data)
    hm = sns.heatmap(flipped, vmin=0, vmax=1, ax=ax, cmap=cmap,
                     cbar=False, rasterized=False)

    # Put the 0% / 100% ticks on the first and last cell of each axis.
    # Derived from the data shape so inputs that are not 150x150 are
    # labelled correctly (150x150 still yields ticks at 0 and 149).
    n_rows, n_cols = flipped.shape
    ax.set_xticks([0, n_cols - 1])
    ax.set_xticklabels(['0%', '100%'], fontsize=fontsize)
    ax.set_yticks([0, n_rows - 1])
    ax.set_yticklabels(['100%', '0%'], fontsize=fontsize)  # Flipped y-axis labels

    # Axis labels: an explicit label wins, then the task name, then the bare label
    if xlabel is not None:
        x_text = f'Gait Cycle (%)\n{xlabel}'
    elif task_x_name:
        x_text = f'Gait Cycle (%)\n{task_x_name}'
    else:
        x_text = 'Gait Cycle (%)'
    ax.set_xlabel(x_text, labelpad=style['label_pad_x'] - 3, fontsize=fontsize)

    if ylabel is not None:
        y_text = f'{ylabel}\nGait Cycle (%)'
    elif task_y_name:
        y_text = f'{task_y_name}\nGait Cycle (%)'
    else:
        y_text = 'Gait Cycle (%)'
    ax.set_ylabel(y_text, labelpad=style['label_pad_y'] + y_label_pad, fontsize=fontsize)

    # Configure ticks
    ax.tick_params(axis='both', which='both',
                   length=style['tick_length'],
                   width=style['tick_width'],
                   pad=style['tick_pad'],
                   labelsize=fontsize)
    # Ensure x-axis labels are horizontal
    ax.tick_params(axis='x', rotation=0)

    # Set the title
    if title is not None:
        ax.set_title(title, pad=2, fontsize=fontsize)

    if cbar:
        # Carve the colorbar axis out of the heatmap axis so both stay aligned
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cbar_obj = plt.colorbar(hm.collections[0], cax=cax)
        cbar_obj.outline.set_visible(False)
        cbar_obj.ax.yaxis.set_ticks_position('right')
        cbar_obj.ax.tick_params(length=0, labelsize=fontsize)

        if cbar_labels:
            cbar_obj.set_ticks([0, 1])
            cbar_obj.set_ticklabels(['0.00', '1.00'])
        else:
            cbar_obj.set_ticks([])

        # The annotation is drawn on the colorbar axis, so it only exists
        # when a colorbar was requested.
        if plot_type != 'output':
            percent = np.mean(measure_data) * 100
            # Format as e.g. "12.3%"
            annotation = r"$\tilde C_{total} = $" + f"{percent:.1f}%"
            # Place annotation to the right of the colorbar
            cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
                             fontsize=fontsize, rotation=90,
                             transform=cbar_obj.ax.transAxes)

    # Set aspect ratio to equal
    ax.set_aspect('equal')
    # NOTE: pyplot's tight_layout acts on the current figure
    plt.tight_layout(pad=1.2)
    return hm
# Build custom colormaps whose lowest values fade to full transparency
def create_transparent_cmap(colors):
    """Return a colormap running from transparent white through ``colors``.

    Args:
        colors: iterable of '#RRGGBB' hex strings, in the order they should
            appear along the colormap.

    Returns:
        A ``LinearSegmentedColormap`` named 'custom_cmap' whose first stop is
        fully transparent white and whose remaining stops are the given
        colors at full opacity.
    """
    def _opaque_rgba(hex_code):
        # Characters 1-2, 3-4 and 5-6 hold the red, green and blue bytes
        red, green, blue = (int(hex_code[i:i + 2], 16) / 255 for i in (1, 3, 5))
        return (red, green, blue, 1.0)

    # Transparent white first, then every requested color fully opaque
    stops = [(1, 1, 1, 0), *(_opaque_rgba(code) for code in colors)]
    return LinearSegmentedColormap.from_list('custom_cmap', stops)
# Light-to-dark hex ramps used as the opaque stops of each colormap
blue_colors = [
    '#F7FBFF', '#DEEBF7', '#C6DBEF', '#9ECAE1',
    '#6BAED6', '#4292C6', '#2171B5',
]
orange_colors = [
    '#FFF5EB', '#FEE6CE', '#FDD0A2', '#FDAE6B',
    '#FD8D3C', '#F16913', '#D94801',
]
green_colors = [
    '#F7FCF5', '#E5F5E0', '#C7E9C0', '#A1D99B',
    '#74C476', '#41AB5D', '#238B45',
]

# Module-level colormaps built from the ramps above
sim_cmap = create_transparent_cmap(blue_colors)  # Blue sequential
diff_cmap = create_transparent_cmap(orange_colors)  # Orange sequential
conflict_cmap = create_transparent_cmap(green_colors)  # Green sequential