Spaces:
Runtime error
Runtime error
| """Visualization functions for similarity measures""" | |
| from shared_styling import PLOT_COLORS, set_plot_style, get_plot_style | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import seaborn as sns | |
| from mpl_toolkits.axes_grid1 import make_axes_locatable | |
| from typing import Optional, Union | |
| import matplotlib as mpl | |
| from matplotlib.colors import LinearSegmentedColormap | |
def plot_similarity_measure(measure_data: np.ndarray, ax: Optional[plt.Axes] = None,
                            plot_type: str = 'input', cbar: bool = True,
                            task_x_name: str = '', task_y_name: str = '',
                            xlabel: Optional[str] = None, ylabel: Optional[str] = None,
                            cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
                            title: Optional[str] = None,
                            fontsize: int = 16,
                            y_label_pad: int = 20,
                            cbar_labels: bool = True,
                            high_level_plot: bool = False,
                            cutoff_treshold: Optional[float] = None):
    """Plot a 2D similarity measure as a gait-cycle heatmap with consistent styling.

    The matrix is drawn with ``sns.heatmap`` on a fixed [0, 1] color scale. It is
    first flipped vertically (``np.flipud``), so row 0 of ``measure_data`` is drawn
    at the bottom. Each axis is labelled only at its first and last cell
    ('0%' / '100%'), and the axes aspect ratio is set to equal.

    Args:
        measure_data: 2D array of similarity measures. Tick positions are taken
            from its shape, so any (rows, cols) size is supported.
        ax: matplotlib Axes to draw into. A new figure and axes are created if None.
        plot_type: chooses the default colormap from ``PLOT_COLORS`` when ``cmap``
            is None: 'input' -> 'input_similarity', 'output' -> 'output_difference',
            'output_biomechanical' -> 'output_biomechanical', anything else ->
            'conflict'. Every plot_type except 'output' also gets the mean
            annotation next to the colorbar (see ``cbar``).
        cbar: whether to add a colorbar in a new axes appended to the right of
            ``ax``. When True and ``plot_type != 'output'``, the colorbar is
            annotated with the mean of ``measure_data`` as a percentage.
        task_x_name: x-axis task name, used if ``xlabel`` is not provided.
        task_y_name: y-axis task name, used if ``ylabel`` is not provided.
        xlabel: optional custom x-axis label (placed under 'Gait Cycle (%)').
        ylabel: optional custom y-axis label (placed above 'Gait Cycle (%)').
        cmap: optional colormap name or object; overrides the plot_type default.
        title: optional axes title.
        fontsize: font size for tick labels, axis labels, title and colorbar text.
        y_label_pad: extra padding added to the style's 'label_pad_y' for the
            y-axis label.
        cbar_labels: if True, label the colorbar ends '0.00' and '1.00';
            if False, hide the colorbar ticks.
        high_level_plot: accepted for API compatibility but currently ignored.
        cutoff_treshold: accepted for API compatibility but currently ignored.
            NOTE(review): the intended behavior was to annotate the share of
            values above this threshold, excluding the trivially-1 diagonal when
            ``high_level_plot`` is True. That is not implemented; the annotation
            always shows the mean of ``measure_data``.

    Returns:
        The object returned by ``sns.heatmap`` (the Axes holding the heatmap).
    """
    if ax is None:
        _, ax = plt.subplots()
    style = get_plot_style()  # fetched once and reused for all padding/tick settings

    # Default colormap per plot type; unknown types fall back to 'conflict'.
    if cmap is None:
        cmap_key = {
            'input': 'input_similarity',
            'output': 'output_difference',
            'output_biomechanical': 'output_biomechanical',
        }.get(plot_type, 'conflict')
        cmap = PLOT_COLORS[cmap_key]

    # Flip vertically so the first data row ends up at the bottom of the plot.
    hm = sns.heatmap(np.flipud(measure_data), vmin=0, vmax=1, ax=ax, cmap=cmap,
                     cbar=False, rasterized=False)

    # Tick only the first and last cell of each axis. These positions come from
    # the data shape; for a 150-sample gait cycle they are [0, 149].
    n_rows, n_cols = np.shape(measure_data)
    ax.set_xticks([0, n_cols - 1])
    ax.set_xticklabels(['0%', '100%'], fontsize=fontsize)
    ax.set_yticks([0, n_rows - 1])
    # After the flip the top row corresponds to 100%, hence the reversed labels.
    ax.set_yticklabels(['100%', '0%'], fontsize=fontsize)

    # Axis labels: a custom label wins over the task name. A non-empty task name
    # is used otherwise; with neither, only the generic label is shown.
    base_label = 'Gait Cycle (%)'
    x_name = xlabel if xlabel is not None else (task_x_name or None)
    y_name = ylabel if ylabel is not None else (task_y_name or None)
    ax.set_xlabel(base_label if x_name is None else f'{base_label}\n{x_name}',
                  labelpad=style['label_pad_x'] - 3, fontsize=fontsize)
    ax.set_ylabel(base_label if y_name is None else f'{y_name}\n{base_label}',
                  labelpad=style['label_pad_y'] + y_label_pad, fontsize=fontsize)

    ax.tick_params(axis='both', which='both',
                   length=style['tick_length'],
                   width=style['tick_width'],
                   pad=style['tick_pad'],
                   labelsize=fontsize)
    # Keep x tick labels horizontal.
    ax.tick_params(axis='x', rotation=0)

    if title is not None:
        ax.set_title(title, pad=2, fontsize=fontsize)

    if cbar:
        # Dedicated colorbar axes so the heatmap keeps its own size/aspect.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cbar_obj = plt.colorbar(hm.collections[0], cax=cax)
        cbar_obj.outline.set_visible(False)
        cbar_obj.ax.yaxis.set_ticks_position('right')
        cbar_obj.ax.tick_params(length=0, labelsize=fontsize)
        if cbar_labels:
            cbar_obj.set_ticks([0, 1])
            cbar_obj.set_ticklabels(['0.00', '1.00'])
        else:
            cbar_obj.set_ticks([])

        if plot_type != 'output':
            # Mean of the raw (unthresholded) data, shown as e.g. "12.3%".
            percent = np.mean(measure_data) * 100
            annotation = r"$\tilde C_{total} = $" + f"{percent:.1f}%"
            # Vertical text placed to the right of the colorbar (axes coords).
            cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
                             fontsize=fontsize, rotation=90,
                             transform=cbar_obj.ax.transAxes)

    ax.set_aspect('equal')
    plt.tight_layout(pad=1.2)
    return hm
# Create custom colormaps with transparency for low values
def create_transparent_cmap(colors, name: str = 'custom_cmap'):
    """Build a sequential colormap whose lowest value is fully transparent.

    The colormap starts at transparent white ``(1, 1, 1, 0)``. It then passes
    through ``colors`` in the given order, each forced to full opacity
    (alpha = 1). ``LinearSegmentedColormap.from_list`` spaces these anchor colors
    evenly over [0, 1], so values near 0 are drawn (almost) invisibly.

    Args:
        colors: sequence of matplotlib color specifications. Besides the
            '#RRGGBB' hex strings used in this module, anything accepted by
            ``matplotlib.colors.to_rgba`` works (short hex, named colors,
            RGB/RGBA tuples). Any alpha they carry is overridden with 1.0.
        name: name assigned to the resulting colormap.

    Returns:
        A ``LinearSegmentedColormap``.
    """
    # Transparent white anchors the low end of the map.
    anchors = [(1.0, 1.0, 1.0, 0.0)]
    # to_rgba parses every supported color format (identical values to manual
    # hex parsing for '#RRGGBB'). alpha=1.0 keeps the listed colors fully opaque.
    anchors.extend(mpl.colors.to_rgba(color, alpha=1.0) for color in colors)
    return LinearSegmentedColormap.from_list(name, anchors)
# Define color sequences: the seven lightest steps of the ColorBrewer 9-class
# sequential palettes Blues, Oranges and Greens.
blue_colors = ['#F7FBFF', '#DEEBF7', '#C6DBEF', '#9ECAE1', '#6BAED6', '#4292C6', '#2171B5']
orange_colors = ['#FFF5EB', '#FEE6CE', '#FDD0A2', '#FDAE6B', '#FD8D3C', '#F16913', '#D94801']
green_colors = ['#F7FCF5', '#E5F5E0', '#C7E9C0', '#A1D99B', '#74C476', '#41AB5D', '#238B45']

# Module-level colormaps built with create_transparent_cmap, so each one starts
# from transparent white at its low end.
sim_cmap = create_transparent_cmap(blue_colors)  # Blue sequential
diff_cmap = create_transparent_cmap(orange_colors)  # Orange sequential
conflict_cmap = create_transparent_cmap(green_colors)  # Green sequential