"""Visualization functions for similarity measures."""
from shared_styling import PLOT_COLORS, set_plot_style, get_plot_style
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import Optional, Union
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap


_GAIT_CYCLE_LABEL = 'Gait Cycle (%)'

# plot_type -> PLOT_COLORS key used when no explicit cmap is given.
# Any plot_type not listed here falls back to the 'conflict' colormap.
_DEFAULT_CMAP_KEYS = {
    'input': 'input_similarity',
    'output': 'output_difference',
    'output_biomechanical': 'output_biomechanical',
}


def _compose_axis_label(custom: Optional[str], task_name: str, name_first: bool) -> str:
    """Join the 'Gait Cycle (%)' label with an optional custom/task name line."""
    # An explicit custom label wins even when it is '' (yielding a trailing
    # newline); an empty task name means "no extra line".
    name = custom if custom is not None else (task_name or None)
    if name is None:
        return _GAIT_CYCLE_LABEL
    if name_first:
        return f'{name}\n{_GAIT_CYCLE_LABEL}'
    return f'{_GAIT_CYCLE_LABEL}\n{name}'


def _annotation_percent(measure_data: np.ndarray,
                        cutoff_treshold: Optional[float],
                        high_level_plot: bool) -> float:
    """Return the percentage printed next to the colorbar.

    Without a threshold this is the mean of ``measure_data`` times 100.
    With a threshold it is the share of entries strictly above it (optionally
    ignoring the diagonal), times 100.
    """
    if cutoff_treshold is None:
        return np.mean(measure_data) * 100
    above = measure_data > cutoff_treshold
    if high_level_plot:
        # Self-comparison entries on the diagonal are trivially 1, so they are
        # dropped from both the count and the total.
        above = above[~np.eye(*above.shape, dtype=bool)]
    if above.size == 0:
        return 0.0
    return np.mean(above) * 100


def plot_similarity_measure(measure_data: np.ndarray, ax: Optional[plt.Axes] = None,
                            plot_type: str = 'input', cbar: bool = True,
                            task_x_name: str = '', task_y_name: str = '',
                            xlabel: Optional[str] = None, ylabel: Optional[str] = None,
                            cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
                            title: Optional[str] = None, fontsize: int = 16,
                            y_label_pad: int = 20, cbar_labels: bool = True,
                            high_level_plot: bool = False,
                            cutoff_treshold: Optional[float] = None):
    """Plot a 2D similarity measure as a heatmap with consistent styling.

    The data is flipped vertically before plotting, so its first row is drawn
    at the bottom. Both axes are labelled from 0% to 100% of the gait cycle.

    Args:
        measure_data: 2D array of similarity measures; colors span the fixed
            range 0..1.
        ax: matplotlib axis to draw on. If None, a new figure and axis are
            created.
        plot_type: one of 'input', 'output', 'output_biomechanical' or
            'conflict'. Selects the default colormap from PLOT_COLORS (any
            other value uses the 'conflict' colormap). No percentage
            annotation is drawn for 'output'.
        cbar: whether to draw a colorbar to the right of the heatmap. The
            percentage annotation is only drawn when this is True.
        task_x_name: name of the x-axis task (used if xlabel is not provided).
        task_y_name: name of the y-axis task (used if ylabel is not provided).
        xlabel: optional custom text shown under the x-axis 'Gait Cycle (%)'.
        ylabel: optional custom text shown above the y-axis 'Gait Cycle (%)'.
        cmap: optional colormap name or object; overrides the plot_type
            default.
        title: optional axis title.
        fontsize: font size for ticks, axis labels, title and annotation.
        y_label_pad: extra padding added to the style's y-label padding.
        cbar_labels: whether to show '0.00' / '1.00' colorbar tick labels.
        high_level_plot: only used together with cutoff_treshold. When True,
            diagonal entries are excluded from the thresholded percentage,
            because comparing a task to itself makes them trivially 1.
        cutoff_treshold: optional threshold above which a value counts as a
            conflict. If given, the annotation shows the percentage of values
            strictly above it. If None, the annotation shows the mean of
            measure_data as a percentage.

    Returns:
        The matplotlib Axes returned by ``seaborn.heatmap``.
    """
    measure_data = np.asarray(measure_data)
    # Validate the shape before possibly creating a new (otherwise leaked) figure.
    n_rows, n_cols = measure_data.shape

    if ax is None:
        _, ax = plt.subplots()

    if cmap is None:
        cmap = PLOT_COLORS[_DEFAULT_CMAP_KEYS.get(plot_type, 'conflict')]

    style = get_plot_style()

    # Flip rows so the first row of measure_data ends up at the bottom.
    hm = sns.heatmap(np.flipud(measure_data), vmin=0, vmax=1, ax=ax,
                     cmap=cmap, cbar=False, rasterized=False)

    # Only the first and last positions are ticked; derived from the data
    # shape instead of assuming 150 samples per axis.
    ax.set_xticks([0, n_cols - 1])
    ax.set_xticklabels(['0%', '100%'], fontsize=fontsize)
    ax.set_yticks([0, n_rows - 1])
    ax.set_yticklabels(['100%', '0%'], fontsize=fontsize)  # Flipped y-axis labels

    ax.set_xlabel(_compose_axis_label(xlabel, task_x_name, name_first=False),
                  labelpad=style['label_pad_x'] - 3, fontsize=fontsize)
    ax.set_ylabel(_compose_axis_label(ylabel, task_y_name, name_first=True),
                  labelpad=style['label_pad_y'] + y_label_pad, fontsize=fontsize)

    ax.tick_params(axis='both', which='both',
                   length=style['tick_length'],
                   width=style['tick_width'],
                   pad=style['tick_pad'],
                   labelsize=fontsize)
    # Keep x tick labels horizontal.
    ax.tick_params(axis='x', rotation=0)

    if title is not None:
        ax.set_title(title, pad=2, fontsize=fontsize)

    if cbar:
        # A dedicated colorbar axis keeps the heatmap itself square.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        # Use the axis' own figure rather than pyplot's "current" figure.
        cbar_obj = ax.figure.colorbar(hm.collections[0], cax=cax)
        cbar_obj.outline.set_visible(False)
        cbar_obj.ax.yaxis.set_ticks_position('right')
        cbar_obj.ax.tick_params(length=0, labelsize=fontsize)
        if cbar_labels:
            cbar_obj.set_ticks([0, 1])
            cbar_obj.set_ticklabels(['0.00', '1.00'])
        else:
            cbar_obj.set_ticks([])

        if plot_type != 'output':
            percent = _annotation_percent(measure_data, cutoff_treshold, high_level_plot)
            # Format as e.g. "12.3%"
            annotation = r"$\tilde C_{total} = $" + f"{percent:.1f}%"
            # Place annotation to the right of the colorbar
            cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
                             fontsize=fontsize, rotation=90,
                             transform=cbar_obj.ax.transAxes)

    ax.set_aspect('equal')
    # Lay out the figure that owns ax, not whichever figure pyplot considers current.
    ax.figure.tight_layout(pad=1.2)

    return hm


def create_transparent_cmap(colors, name='custom_cmap'):
    """Build a colormap that starts fully transparent and ramps through *colors*.

    Args:
        colors: sequence of matplotlib color specs (e.g. '#RRGGBB' hex
            strings). Each is forced to full opacity.
        name: name given to the resulting colormap.

    Returns:
        A LinearSegmentedColormap whose lowest value is transparent white
        (alpha 0), so the lowest values let the background show through.
    """
    rgba_colors = [(1, 1, 1, 0)]  # Start with transparent white
    rgba_colors.extend(mpl.colors.to_rgba(color, alpha=1.0) for color in colors)
    return LinearSegmentedColormap.from_list(name, rgba_colors)


# Define color sequences
blue_colors = ['#F7FBFF', '#DEEBF7', '#C6DBEF', '#9ECAE1', '#6BAED6', '#4292C6', '#2171B5']
orange_colors = ['#FFF5EB', '#FEE6CE', '#FDD0A2', '#FDAE6B', '#FD8D3C', '#F16913', '#D94801']
green_colors = ['#F7FCF5', '#E5F5E0', '#C7E9C0', '#A1D99B', '#74C476', '#41AB5D', '#238B45']

sim_cmap = create_transparent_cmap(blue_colors)         # Blue sequential
diff_cmap = create_transparent_cmap(orange_colors)      # Orange (peach) sequential
conflict_cmap = create_transparent_cmap(green_colors)   # Green sequential