# Provenance (residue from the hosting web page): uploaded by ClementP — "Upload 55 files", commit 69591a9 (verified).
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from skimage.measure import label, regionprops
import base64
from typing import Callable
def imshow_compare(data_dict, ax_size=4, draw_bbox=False, max_images=None):
    """
    Display the images in a grid format for comparison.

    Each key is an annotator, each value is another dict, where the key is the
    image type and the value the list of corresponding images.

    One column is drawn per annotator and one row per image. Each annotator's image
    lists are concatenated in the dict's iteration order. The number of rows comes
    from the first annotator, capped at ``max_images``. Every other annotator must
    provide at least that many images, otherwise an IndexError is raised.

    Parameters
    ----------
    data_dict : dict
        ``{annotator: {image_type: [mask, ...]}}``.
    ax_size : float
        Width and height, in inches, of each subplot.
    draw_bbox : bool
        If True, draw a yellow rectangle around each 8-connected region of
        non-zero pixels (masks are assumed 2-D, as the 4-value bbox unpacking requires).
    max_images : int or None
        Maximum number of rows to display; ``None`` shows every image.

    Returns
    -------
    fig, axes
        The matplotlib figure and the 2-D array of axes, shape ``(rows, annotators)``.
    """
    # 0 is black, 1 is red, 2 is green
    cmap = ListedColormap(['black', 'red', 'green'])

    # Flatten {image_type: [masks]} into one list per annotator (dict order preserved).
    data = {
        annotator: [mask for masks in images.values() for mask in masks]
        for annotator, images in data_dict.items()
    }
    annotators = list(data.keys())
    num_images = len(data[annotators[0]])
    if max_images is not None:
        num_images = min(num_images, max_images)
    num_annotators = len(annotators)

    fig_size = (ax_size * num_annotators, ax_size * num_images)
    fig, axes = plt.subplots(num_images, num_annotators, figsize=fig_size, squeeze=False)
    for i, annotator in enumerate(annotators):
        for j in range(num_images):
            ax = axes[j, i]
            mask = data[annotator][j]
            # Pin the colour scale so that each value always gets the same colour.
            # Without vmin/vmax, imshow rescales to each mask's own range, and a
            # 0/1 mask would render its 1s green instead of red.
            ax.imshow(mask, cmap=cmap, vmin=0, vmax=cmap.N - 1, interpolation='nearest')
            ax.axis('off')  # also hides the ticks
            if draw_bbox:
                labeled_mask = label(mask > 0, connectivity=2)
                for region in regionprops(labeled_mask):
                    minr, minc, maxr, maxc = region.bbox
                    # With imshow's default extent, pixel (r, c) is centred at (c, r),
                    # so pixel edges lie on half-integers. Shifting by 0.5 makes the
                    # rectangle enclose the region's pixels exactly.
                    rect = plt.Rectangle((minc - 0.5, minr - 0.5), maxc - minc, maxr - minr,
                                         fill=False, edgecolor='yellow', linewidth=0.5)
                    ax.add_patch(rect)
            if j == 0:
                ax.set_title(annotator)
    fig.tight_layout()
    return fig, axes
def _p_value_symbol(pvalue):
    """Map a p-value to a significance marker: 'ns', '*', '**' or '***'.

    A NaN p-value (e.g. from a test run on constant data) is reported as 'ns'.
    Without this check every comparison below is False, and NaN would fall
    through to '***'.
    """
    if np.isnan(pvalue) or pvalue >= 0.05:
        return 'ns'
    if pvalue >= 0.01:
        return '*'
    if pvalue >= 0.001:
        return '**'
    return '***'


def _trace_values(y):
    """Return the ``y`` entry of a trace from ``fig.to_dict()`` as a numpy array.

    Numpy-backed data may be serialised as a base64 typed-array dict
    ``{'bdata': ..., 'dtype': ...}``; the declared dtype is honoured, with float64
    used when none is given. Plain sequences are converted to float64.
    """
    if isinstance(y, dict) and 'bdata' in y:
        raw = base64.b64decode(y['bdata'])
        return np.frombuffer(raw, dtype=np.dtype(y.get('dtype', 'f8')))
    return np.asarray(y, dtype=np.float64)


def _subplot_trace_indices(fig_dict, subplot_str):
    """Return the indices of the traces drawn on x-axis ``'x' + subplot_str``."""
    target = 'x' + subplot_str
    # A trace on the default axis may have no 'xaxis' key at all; treat it as 'x'.
    return [
        i for i, trace in enumerate(fig_dict['data'])
        if trace.get('xaxis', 'x') == target
    ]


def add_p_value_annotation(fig, array_columns, stats_test, subplot=None, _format=None):
    ''' Adds notations giving the p-value between two box plot data.

    For each pair of columns, a bracket (two vertical ticks joined by a horizontal
    line) is drawn above the plot. A significance marker is written over it:
    'ns' (p >= 0.05 or NaN), '*' (p < 0.05), '**' (p < 0.01), '***' (p < 0.001).
    Successive brackets are stacked ``interline`` apart in y-domain units.

    Parameters:
    ----------
    fig: figure
        plotly boxplot figure
    array_columns: sequence of pairs
        positions of the columns to compare,
        e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2.
        With ``subplot``, positions are counted among that subplot's traces.
    stats_test: callable or sequence of float
        Either a function called as ``stats_test(d1, d2)`` on the y-values of the
        two traces, whose result's item ``[1]`` is the p-value (e.g. a test that
        returns ``(statistic, pvalue)``), or a sequence of precomputed p-values,
        one per entry of ``array_columns``.
    subplot: None or int
        specifies if the figure has subplots and which subplot to add the notation to
    _format: dict or None
        format characteristics for the lines. Supported keys:
        'interline' (vertical gap between stacked brackets, default 0.07),
        'text_height' (multiplier applied to the bracket top to place the marker,
        default 1.07) and 'color' (default 'black'). Missing keys use the defaults.

    Returns:
    -------
    fig: figure
        the same figure, with the shapes and annotations added
    '''
    fmt = {'interline': 0.07, 'text_height': 1.07, 'color': 'black'}
    if _format is not None:
        fmt.update(_format)

    fig_dict = fig.to_dict()
    if subplot:
        subplot_str = '' if subplot == 1 else str(subplot)
        # Map column positions within the subplot to global trace indices.
        indices = _subplot_trace_indices(fig_dict, subplot_str)
    else:
        subplot_str = ''
        indices = None
    xref = 'x' + subplot_str
    yref = 'y' + subplot_str + ' domain'

    for index, column_pair in enumerate(array_columns):
        left, right = column_pair[0], column_pair[1]

        if callable(stats_test):
            if indices is not None:
                first, second = indices[left], indices[right]
            else:
                first, second = left, right
            d1 = _trace_values(fig_dict['data'][first]['y'])
            d2 = _trace_values(fig_dict['data'][second]['y'])
            pvalue = stats_test(d1, d2)[1]
        else:
            pvalue = stats_test[index]
        symbol = _p_value_symbol(pvalue)

        # Bracket extent in y-domain coordinates; each new pair sits one interline higher.
        y_bottom = 1.01 + index * fmt['interline']
        y_top = 1.02 + index * fmt['interline']

        # Left vertical tick, horizontal bar, right vertical tick.
        segments = (
            (left, y_bottom, left, y_top),
            (left, y_top, right, y_top),
            (right, y_bottom, right, y_top),
        )
        for x0, y0, x1, y1 in segments:
            fig.add_shape(type="line",
                          xref=xref, yref=yref,
                          x0=x0, y0=y0, x1=x1, y1=y1,
                          line=dict(color=fmt['color'], width=2))

        # For box/bar plots the column position maps directly to the x coordinate 0, 1, 2...
        fig.add_annotation(dict(font=dict(color=fmt['color'], size=14),
                                x=(left + right) / 2,
                                y=y_top * fmt['text_height'],
                                showarrow=False,
                                text=symbol,
                                textangle=0,
                                xref=xref,
                                yref=yref))
    return fig