| import matplotlib.pyplot as plt |
| from mpl_toolkits.mplot3d import Axes3D |
| from matplotlib.lines import Line2D |
| import numpy as np |
| from DeepDeformationMapRegistration.utils.visualization import add_axes_arrows_3d, remove_tick_labels, set_axes_size |
| import os |
|
|
|
|
def _plot_graph(graph, ax, nodes_colour='C3', edges_colour='C1', plot_nodes=True, plot_edges=True, add_axes=True,
                axis_limits=(0, 63)):
    """Draw a skeleton graph on a 3D matplotlib axis.

    Each edge is drawn as a polyline that starts at the start node's ``'o'`` coordinate,
    runs through the edge's ``'pts'`` samples and ends at the end node's ``'o'`` coordinate.
    Nodes are drawn as a scatter of their ``'o'`` coordinates.

    Args:
        graph: Graph exposing ``edges()``, ``nodes`` / ``nodes()`` and ``graph[u][v]`` access
            (networkx-style API). Edges must carry a ``'pts'`` attribute and nodes an ``'o'``
            attribute; presumably (N, 3) and (3,) coordinates respectively — TODO confirm.
        ax: 3D axis (created with ``projection='3d'``) to draw on.
        nodes_colour: Matplotlib colour for the node markers.
        edges_colour: Matplotlib format string for the edge lines.
        plot_nodes: Whether to draw the nodes.
        plot_edges: Whether to draw the edges.
        add_axes: Whether to draw the RGB axis arrows with ``add_axes_arrows_3d``.
        axis_limits: ``(low, high)`` limits applied to the x, y and z axes.

    Returns:
        The same ``ax``, for chaining.
    """
    if plot_edges:
        for start_node, end_node in graph.edges():
            # Add the node positions at both ends so the polyline reaches its endpoints.
            edge_pts = np.vstack([graph.nodes[start_node]['o'],
                                  graph[start_node][end_node]['pts'],
                                  graph.nodes[end_node]['o']])
            ax.plot(edge_pts[:, 0], edge_pts[:, 1], edge_pts[:, 2], edges_colour)

    if plot_nodes:
        nodes = graph.nodes()
        node_pts = np.array([nodes[n]['o'] for n in nodes])
        # An empty graph gives a 0-size array: there is nothing to draw, and indexing would raise.
        if node_pts.size:
            node_pts = np.atleast_2d(node_pts)
            # Pass the colour by keyword: Axes3D.scatter's 4th positional parameter is
            # ``zdir``, so a positional colour would be silently ignored.
            ax.scatter(node_pts[:, 0], node_pts[:, 1], node_pts[:, 2], color=nodes_colour)

    low, high = axis_limits
    ax.set_xlim(low, high)
    ax.set_ylim(low, high)
    ax.set_zlim(low, high)
    remove_tick_labels(ax, True)
    if add_axes:
        add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b')
    ax.view_init(None, 45)

    return ax
|
|
|
|
def plot_skeleton(img, skeleton, graph, filename='skeleton', extension=('.png',)):
    """Save three renders of a segmentation, its skeleton and the skeleton graph.

    One file is written per extension for each of:
      * ``<filename>_segmentation_skeleton<ext>``: voxelised ``img`` plus the skeleton voxels.
      * ``<filename>_combined<ext>``: the previous render with ``graph`` overlaid in red.
      * ``<filename>_graph<ext>``: ``graph`` on its own.

    Args:
        img: 3D array passed to ``Axes3D.voxels`` (truthy voxels are filled).
        skeleton: 3D array; the coordinates of its non-zero voxels are drawn as points.
        graph: Skeleton graph drawn by ``_plot_graph``.
        filename: Output path prefix.
        extension: A single extension string (e.g. ``'.png'``) or an iterable of them.
    """
    # A plain string is one extension; anything else is treated as a collection of extensions.
    extensions = [extension] if isinstance(extension, str) else list(extension)

    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(111, projection='3d')

        coords = np.argwhere(skeleton)
        ax.voxels(img, facecolors=(0., 0., 1., 0.3), label='image')
        ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], color='C1', label='skeleton', s=1)
        ax.set_xlim(0, 63)
        ax.set_ylim(0, 63)
        ax.set_zlim(0, 63)
        remove_tick_labels(ax, True)
        add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b')
        ax.view_init(None, 45)
        for ext in extensions:
            fig.savefig(filename + '_segmentation_skeleton' + ext)

        # Overlay the graph on the same axis for the combined render.
        _plot_graph(graph, ax, 'r', 'r')
        for ext in extensions:
            fig.savefig(filename + '_combined' + ext)
    finally:
        # Close this specific figure, even if drawing or saving failed.
        plt.close(fig)

    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(111, projection='3d')
        _plot_graph(graph, ax)
        for ext in extensions:
            fig.savefig(filename + '_graph' + ext)
    finally:
        plt.close(fig)
|
|
|
|
|
|
|
|
def compare_graphs(graph_0, graph_1, graph_names=None, filename='compare_graphs'):
    """Save a three-panel figure comparing two skeleton graphs.

    Panels, left to right: ``graph_0`` alone, ``graph_1`` alone, and the edges of both
    graphs overlaid (``graph_0`` in C2, ``graph_1`` in C4) with a legend. The figure is
    written to ``<filename>_compare_graphs.png``.

    Args:
        graph_0: First graph (drawn with ``_plot_graph``).
        graph_1: Second graph (drawn with ``_plot_graph``).
        graph_names: Two labels used as panel titles and legend entries.
            Defaults to ``['graph_0', 'graph_1']``.
        filename: Output path prefix.

    Raises:
        ValueError: If ``graph_names`` does not contain exactly two names.
    """
    # Validate before creating the figure so a bad argument doesn't leave a figure open.
    if graph_names is None:
        graph_names = ['graph_0', 'graph_1']
    elif len(graph_names) != 2:
        raise ValueError('A different name is expected for each graph')

    f = plt.figure(figsize=(12, 5))
    try:
        ax = f.add_subplot(131, projection='3d')
        ax = _plot_graph(graph_0, ax)
        # NOTE(review): only this title is moved below the axis (y=-0.01); confirm it is intended.
        ax.set_title(graph_names[0], y=-0.01)

        ax = f.add_subplot(132, projection='3d')
        ax = _plot_graph(graph_1, ax)
        ax.set_title(graph_names[1])

        # Overlay: edges only, one colour per graph, so the two skeletons can be compared.
        ax = f.add_subplot(133, projection='3d')
        ax = _plot_graph(graph_0, ax, 'C2', 'C2', plot_nodes=False)
        ax = _plot_graph(graph_1, ax, 'C4', 'C4', plot_nodes=False)
        legend_elements = [Line2D([0], [0], color='C2', lw=2, label=graph_names[0]),
                           Line2D([0], [0], color='C4', lw=2, label=graph_names[1])]
        ax.legend(handles=legend_elements)

        f.savefig(filename + '_compare_graphs.png')
    finally:
        plt.close(f)
|
|
|
|
def plot_cpd_registration_step(iteration, error, X, Y, out_folder, add_axes=True, pdf=True):
    """Save a snapshot of one CPD registration iteration.

    Writes ``<out_folder>/<iteration:04d>.png``, plus a ``.pdf`` copy when ``pdf`` is true.
    ``out_folder`` is created if needed.

    Args:
        iteration: Integer iteration index. It is shown on the plot and used as the file
            name (formatted with ``{:d}``).
        error: Not used by this function; kept for the caller's interface.
            NOTE(review): presumably the registration error at this iteration.
        X: Fixed point cloud; columns 0-2 are plotted (presumably shape (N, 3) — confirm).
        Y: Moving point cloud, same layout as ``X``.
        out_folder: Directory where the snapshots are written.
        add_axes: When true, fit the axis limits to the data, hide tick labels and draw
            the RGB axis arrows.
        pdf: Also save a PDF copy.
    """
    os.makedirs(out_folder, exist_ok=True)

    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_axes([0, 0, .9, .9], projection='3d')
        ax.scatter(X[:, 0], X[:, 1], X[:, 2], color='C1', label='Fixed')
        ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], color='C2', label='Moving')

        ax.text2D(0.95, 0.98, 'Iteration: {:d}'.format(iteration), horizontalalignment='right',
                  verticalalignment='center', transform=ax.transAxes, fontsize='x-large')
        ax.legend(loc='upper left', fontsize='x-large')

        if add_axes:
            # Fit the limits tightly to both clouds: per-axis min/max in a single pass.
            all_pts = np.vstack([X[:, :3], Y[:, :3]])
            low, high = all_pts.min(axis=0), all_pts.max(axis=0)
            ax.set_xlim(low[0], high[0])
            ax.set_ylim(low[1], high[1])
            ax.set_zlim(low[2], high[2])

            remove_tick_labels(ax, True)
            add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b', arrow_length=25, dist_arrow_text=3)
        ax.view_init(None, 45)

        fig.savefig(os.path.join(out_folder, '{:04d}.png'.format(iteration)))
        if pdf:
            fig.savefig(os.path.join(out_folder, '{:04d}.pdf'.format(iteration)))
    finally:
        # Close this specific figure, even if drawing or saving failed.
        plt.close(fig)
|
|
|
|
def plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, file_name, pdf=True):
    """Save a 3D plot of the fixed and moving point clouds and their centroids.

    Each centroid is drawn as a filled dot in its cloud's colour, surrounded by a blue ring.
    The figure is written to ``<file_name>.png`` and, when ``pdf`` is true, ``<file_name>.pdf``.

    Args:
        fix_pts: Fixed point cloud; columns 0-2 are plotted (presumably (N, 3) — confirm).
        mov_pts: Moving point cloud, same layout as ``fix_pts``.
        fix_centroid: Centroid of the fixed cloud; elements 0-2 are used.
        mov_centroid: Centroid of the moving cloud; elements 0-2 are used.
        file_name: Output path without extension.
        pdf: Also save a PDF copy (default True, matching the previous behaviour).
    """
    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_axes([0, 0, .9, .9], projection='3d')
        ax.scatter(fix_pts[:, 0], fix_pts[:, 1], fix_pts[:, 2], color='C1', label='Fixed')
        ax.scatter(mov_pts[:, 0], mov_pts[:, 1], mov_pts[:, 2], color='C2', label='Moving')
        # Hollow blue rings mark both centroids (labelled once); filled dots give each cloud's colour.
        ax.scatter(fix_centroid[0], fix_centroid[1], fix_centroid[2], color='none', s=100, edgecolor='b',
                   label='Centroid')
        ax.scatter(mov_centroid[0], mov_centroid[1], mov_centroid[2], color='none', s=100, edgecolor='b')
        ax.scatter(fix_centroid[0], fix_centroid[1], fix_centroid[2], color='C1')
        ax.scatter(mov_centroid[0], mov_centroid[1], mov_centroid[2], color='C2')

        # Fit the limits to every plotted point (clouds and centroids) in one pass.
        all_pts = np.vstack([fix_pts[:, :3], mov_pts[:, :3],
                             np.asarray(fix_centroid)[:3], np.asarray(mov_centroid)[:3]])
        low, high = all_pts.min(axis=0), all_pts.max(axis=0)
        ax.set_xlim(low[0], high[0])
        ax.set_ylim(low[1], high[1])
        ax.set_zlim(low[2], high[2])

        remove_tick_labels(ax, True)
        add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b', arrow_length=25, dist_arrow_text=3)
        ax.view_init(None, 45)
        ax.legend(fontsize='xx-large')

        fig.savefig(file_name + '.png')
        if pdf:
            fig.savefig(file_name + '.pdf')
    finally:
        # Close this specific figure, even if drawing or saving failed.
        plt.close(fig)
|
|
|
|