|
|
import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
from matplotlib.colors import ListedColormap
|
|
|
from skimage.measure import label, regionprops
|
|
|
import base64
|
|
|
from typing import Callable
|
|
|
|
|
|
def imshow_compare(data_dict, ax_size=4, draw_bbox=False, max_images=None):
    """
    Display label masks in a grid for side-by-side comparison.

    Each column shows one annotator, each row one image. For a given
    annotator, the mask lists of all its image types are concatenated in the
    dict's iteration order, so row ``j`` is the ``j``-th mask of that
    concatenation.

    Parameters
    ----------
    data_dict : dict
        Maps annotator name -> dict mapping image type -> list of masks.
        Masks are passed to ``ax.imshow`` with a three-colour map
        (black, red, green); presumably they are integer label images with
        values in {0, 1, 2} -- confirm against callers.
    ax_size : float
        Width and height, in inches, of each subplot.
    draw_bbox : bool
        If True, draw a yellow bounding box around every connected
        (``connectivity=2``) region of ``mask > 0``.
    max_images : int or None
        Maximum number of rows to draw; ``None`` means no limit.

    Returns
    -------
    fig, axes
        The matplotlib figure and a 2-D array of axes with shape
        ``(n_rows, n_annotators)``.

    Raises
    ------
    ValueError
        If ``data_dict`` is empty or there are no rows to draw.
    """
    cmap = ListedColormap(['black', 'red', 'green'])

    # Flatten each annotator's masks across image types, preserving order.
    data = {
        annotator: [mask for masks in images.values() for mask in masks]
        for annotator, images in data_dict.items()
    }

    annotators = list(data)
    if not annotators:
        raise ValueError("data_dict must contain at least one annotator")

    # Use the shortest annotator list so every column can fill every row.
    num_images = min(len(masks) for masks in data.values())
    if max_images is not None:
        num_images = min(num_images, max_images)
    if num_images <= 0:
        raise ValueError("no images to display")

    num_annotators = len(annotators)
    fig, axes = plt.subplots(
        num_images,
        num_annotators,
        figsize=(ax_size * num_annotators, ax_size * num_images),
        squeeze=False,
    )

    for col, annotator in enumerate(annotators):
        for row in range(num_images):
            ax = axes[row, col]
            mask = data[annotator][row]

            # Pin the colour range so a given label value always gets the same
            # colour, regardless of which labels happen to occur in this mask
            # (otherwise imshow rescales each mask to its own min/max).
            ax.imshow(mask, cmap=cmap, interpolation='nearest',
                      vmin=0, vmax=cmap.N - 1)
            ax.axis('off')
            ax.set_xticks([])
            ax.set_yticks([])

            if draw_bbox:
                labeled_mask = label(mask > 0, connectivity=2)
                for region in regionprops(labeled_mask):
                    # regionprops bbox is (min_row, min_col, max_row, max_col).
                    minr, minc, maxr, maxc = region.bbox
                    rect = plt.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                         fill=False, edgecolor='yellow', linewidth=0.5)
                    ax.add_patch(rect)

            if row == 0:
                ax.set_title(annotator)

    fig.tight_layout()
    return fig, axes
|
|
|
|
|
|
|
|
|
def _trace_values(values):
    """Return a trace's data (e.g. its ``y``) from ``fig.to_dict()`` as a 1-D numpy array.

    Depending on the plotly version and the input type, the data is either a
    plain sequence or a base64 typed-array dict ``{'dtype': ..., 'bdata': ...}``.
    A typed array with no ``dtype`` key is read as float64.
    """
    if isinstance(values, dict) and 'bdata' in values:
        raw = base64.b64decode(values['bdata'])
        return np.frombuffer(raw, dtype=np.dtype(values.get('dtype', 'f8')))
    return np.asarray(values, dtype=np.float64)


def _significance_symbol(pvalue):
    """Map a p-value to 'ns', '*', '**' or '***'; a NaN p-value maps to 'ns'."""
    # NaN fails every >= comparison, so it must be caught explicitly or it
    # would fall through to '***'.
    if np.isnan(pvalue) or pvalue >= 0.05:
        return 'ns'
    if pvalue >= 0.01:
        return '*'
    if pvalue >= 0.001:
        return '**'
    return '***'


def add_p_value_annotation(fig, array_columns, stats_test, subplot=None, _format=None):
    ''' Adds significance brackets and star labels between pairs of box plots.

    For every pair of columns a bracket is drawn above the plot area (in
    y-axis *domain* coordinates) and labelled 'ns', '*', '**' or '***' for
    p >= 0.05, p < 0.05, p < 0.01 and p < 0.001 respectively. A NaN p-value
    is labelled 'ns'.

    Parameters:
    ----------
    fig: figure
        plotly box-plot figure; shapes and annotations are added in place
    array_columns: sequence of pairs of int
        which columns to compare,
        e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2.
        The column numbers are also the x positions of the bracket ends.
    stats_test: callable or sequence of float
        either a test called as ``stats_test(a, b)`` on the two traces' ``y``
        data and returning a tuple whose second element is the p-value
        (e.g. ``scipy.stats.ttest_ind``), or precomputed p-values, one per
        entry of ``array_columns``
    subplot: None or int
        1-based number of the subplot to annotate; column numbers are then
        counted among the traces on that subplot's x-axis only. When falsy,
        column numbers index all traces and the first axes are used.
    _format: dict or None
        overrides for 'interline' (vertical gap between stacked brackets),
        'text_height' (factor applied to the bracket top to place the label)
        and 'color'; missing keys default to 0.07, 1.07 and 'black'

    Returns:
    -------
    fig: figure
        the same figure, with the added notation
    '''
    fmt = {'interline': 0.07, 'text_height': 1.07, 'color': 'black'}
    if _format:
        fmt.update(_format)
    interline = fmt['interline']
    color = fmt['color']

    traces = fig.to_dict()['data']

    if subplot:
        # plotly names the first subplot's axes 'x'/'y', later ones 'x2', 'x3', ...
        subplot_str = '' if subplot == 1 else str(subplot)
        # .get(): a trace without an explicit 'xaxis' is treated as being on 'x'.
        indices = [i for i, trace in enumerate(traces)
                   if trace.get('xaxis', 'x') == 'x' + subplot_str]
    else:
        subplot_str = ''
        indices = None

    xref = 'x' + subplot_str
    yref = 'y' + subplot_str + ' domain'

    def add_line(x0, y0, x1, y1):
        fig.add_shape(type="line", xref=xref, yref=yref,
                      x0=x0, y0=y0, x1=x1, y1=y1,
                      line=dict(color=color, width=2))

    for index, column_pair in enumerate(array_columns):
        left, right = column_pair[0], column_pair[1]

        if callable(stats_test):
            # Map subplot-local column numbers to indices into fig.data.
            if indices is not None:
                first, second = indices[left], indices[right]
            else:
                first, second = left, right
            d1 = _trace_values(traces[first]['y'])
            d2 = _trace_values(traces[second]['y'])
            pvalue = stats_test(d1, d2)[1]
        else:
            pvalue = stats_test[index]
        symbol = _significance_symbol(pvalue)

        # Each successive bracket is stacked `interline` higher than the last.
        y_low = 1.01 + index * interline
        y_high = 1.02 + index * interline

        add_line(left, y_low, left, y_high)     # left tick
        add_line(left, y_high, right, y_high)   # horizontal bar
        add_line(right, y_low, right, y_high)   # right tick

        fig.add_annotation(dict(font=dict(color=color, size=14),
                                x=(left + right) / 2,
                                y=y_high * fmt['text_height'],
                                showarrow=False,
                                text=symbol,
                                textangle=0,
                                xref=xref,
                                yref=yref))
    return fig