import base64
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from skimage.measure import label, regionprops


def _show_mask(ax, mask, cmap, draw_bbox):
    """Draw one mask on ``ax``, optionally boxing each connected region."""
    # Fix the colour limits so pixel values 0/1/2 always map to the three
    # colours of ``cmap``; with autoscaling a mask holding only {0, 1}
    # would be drawn black/green instead of black/red.
    ax.imshow(mask, cmap=cmap, vmin=0, vmax=2, interpolation='nearest')
    if not draw_bbox:
        return
    labeled_mask = label(np.asarray(mask) > 0, connectivity=2)
    for region in regionprops(labeled_mask):
        minr, minc, maxr, maxc = region.bbox
        # imshow centres pixel (r, c) on (c, r), so a pixel's outer edge lies
        # at -0.5; the bbox max is exclusive, so width/height need no fix-up.
        rect = plt.Rectangle((minc - 0.5, minr - 0.5), maxc - minc, maxr - minr,
                             fill=False, edgecolor='yellow', linewidth=0.5)
        ax.add_patch(rect)


def imshow_compare(data_dict, ax_size=4, draw_bbox=False, max_images=None):
    """Display annotators' masks side by side in a grid for comparison.

    Each column shows one annotator and each row one image.

    Parameters
    ----------
    data_dict : dict
        Maps annotator name -> dict mapping image type -> list of masks.
        The masks of all image types are concatenated (in dict order) into a
        single column per annotator.
    ax_size : float
        Width and height, in inches, of each subplot.
    draw_bbox : bool
        If True, draw a yellow box around every connected non-zero region
        (8-connectivity) of each mask.
    max_images : int or None
        Maximum number of rows to show; ``None`` shows every image.

    Returns
    -------
    fig, axes
        The matplotlib figure and a 2-D array of axes with shape
        ``(num_images, num_annotators)``.

    Raises
    ------
    ValueError
        If ``data_dict`` is empty or there are no images to display.

    Notes
    -----
    The row count is taken from the first annotator; annotators with fewer
    masks get blank cells.
    """
    if not data_dict:
        raise ValueError("data_dict must contain at least one annotator")

    # 0 is black, 1 is red, 2 is green
    cmap = ListedColormap(['black', 'red', 'green'])

    # Flatten {annotator: {image_type: [masks]}} into {annotator: [masks]}.
    data = {
        annotator: [mask for masks in images.values() for mask in masks]
        for annotator, images in data_dict.items()
    }
    annotators = list(data)

    num_images = len(data[annotators[0]])
    if max_images is not None:
        num_images = min(num_images, max_images)
    if num_images <= 0:
        raise ValueError("no images to display")
    num_annotators = len(annotators)

    fig_size = (ax_size * num_annotators, ax_size * num_images)
    fig, axes = plt.subplots(num_images, num_annotators, figsize=fig_size,
                             squeeze=False)

    for col, annotator in enumerate(annotators):
        masks = data[annotator]
        for row in range(num_images):
            ax = axes[row, col]
            # Leave the cell blank when this annotator has fewer masks.
            if row < len(masks):
                _show_mask(ax, masks[row], cmap, draw_bbox)
            ax.axis('off')
            ax.set_xticks([])
            ax.set_yticks([])
            if row == 0:
                ax.set_title(annotator)

    fig.tight_layout()
    return fig, axes


# Default bracket/text layout for add_p_value_annotation.
_DEFAULT_PVALUE_FORMAT = dict(interline=0.07, text_height=1.07, color='black')


def _trace_values(values):
    """Return a trace's ``y`` data from ``fig.to_dict()`` as a float64 array.

    Newer plotly versions serialise NumPy-backed data as a typed-array spec
    ``{'dtype': <code>, 'bdata': <base64>}``; older versions (or list-backed
    data) give a plain sequence.
    """
    if isinstance(values, dict) and 'bdata' in values:
        raw = base64.b64decode(values['bdata'])
        # plotly may downcast integers (e.g. int64 -> 'i1'), so honour the
        # declared dtype instead of assuming float64.
        dtype = np.dtype(values.get('dtype', 'f8'))
        return np.frombuffer(raw, dtype=dtype).astype(np.float64)
    return np.asarray(values, dtype=np.float64)


def _pvalue_symbol(pvalue):
    """Map a p-value to the conventional significance marker."""
    if pvalue >= 0.05:
        return 'ns'
    if pvalue >= 0.01:
        return '*'
    if pvalue >= 0.001:
        return '**'
    return '***'


def add_p_value_annotation(fig, array_columns, stats_test, subplot=None,
                           _format=None):
    '''
    Add significance brackets between pairs of box plots in a plotly figure.

    Parameters:
    ----------
    fig: figure
        plotly boxplot figure
    array_columns: sequence of pairs
        which columns to compare, e.g. [[0, 1], [1, 2]] compares column 0
        with 1 and 1 with 2. The column numbers are also used as x positions
        of the bracket (assumes categorical boxes at x = 0, 1, 2, ...).
    stats_test: callable or sequence of float
        Either a two-sample test called as ``stats_test(d1, d2)`` whose
        result's element ``[1]`` is the p-value (e.g. ``scipy.stats.ttest_ind``),
        or precomputed p-values, one per entry of ``array_columns``.
    subplot: None or int
        specifies if the figure has subplots and which subplot to annotate;
        column numbers then index the traces drawn on that subplot
    _format: dict or None
        format characteristics for the lines; any of ``interline``,
        ``text_height`` and ``color``. Missing keys use the defaults
        (0.07, 1.07, 'black').

    Returns:
    -------
    fig: figure
        the same figure with the added notation
    '''
    fmt = {**_DEFAULT_PVALUE_FORMAT, **(_format or {})}
    fig_dict = fig.to_dict()

    if subplot:
        subplot_str = '' if subplot == 1 else str(subplot)
        xaxis = 'x' + subplot_str
        # Trace positions on this subplot, in trace order; traces without an
        # explicit 'xaxis' belong to the default 'x' axis.
        indices = [i for i, trace in enumerate(fig_dict['data'])
                   if trace.get('xaxis', 'x') == xaxis]
    else:
        subplot_str = ''
        indices = None

    xref = 'x' + subplot_str
    yref = 'y' + subplot_str + ' domain'

    for index, column_pair in enumerate(array_columns):
        left, right = column_pair[0], column_pair[1]

        if isinstance(stats_test, Callable):
            if indices is None:
                first, second = left, right
            else:
                first, second = indices[left], indices[right]
            d1 = _trace_values(fig_dict['data'][first]['y'])
            d2 = _trace_values(fig_dict['data'][second]['y'])
            pvalue = stats_test(d1, d2)[1]
        else:
            pvalue = stats_test[index]

        symbol = _pvalue_symbol(pvalue)

        # Stack successive brackets above the plot area (paper "domain"
        # coordinates, where 1 is the top of the subplot).
        y_low = 1.01 + index * fmt['interline']
        y_high = 1.02 + index * fmt['interline']
        segments = (
            (left, y_low, left, y_high),     # left vertical tick
            (left, y_high, right, y_high),   # horizontal bar
            (right, y_low, right, y_high),   # right vertical tick
        )
        for x0, y0, x1, y1 in segments:
            fig.add_shape(type="line",
                          xref=xref, yref=yref,
                          x0=x0, y0=y0, x1=x1, y1=y1,
                          line=dict(color=fmt['color'], width=2))

        # Centre the significance marker above the bracket.
        fig.add_annotation(dict(font=dict(color=fmt['color'], size=14),
                                x=(left + right) / 2,
                                y=y_high * fmt['text_height'],
                                showarrow=False,
                                text=symbol,
                                textangle=0,
                                xref=xref,
                                yref=yref))
    return fig