| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import monai |
| | import torch |
| | import os |
| | import json |
| | import matplotlib |
| | import shutil |
| | from torchview import draw_graph |
| |
|
| | def plot_architecture(network, img_shape, batch_size, name, save_dir): |
| | if name == 'SegNet': |
| | num_channels = 1 |
| | else: |
| | num_channels = 2 |
| | |
| | H, D, W = img_shape |
| | model_graph = draw_graph(network, |
| | input_size=(batch_size, num_channels, H, D, W), |
| | device='meta', |
| | roll=True, |
| | expand_nested=True, |
| | save_graph=True, |
| | filename=f"{name}_Graph", |
| | directory=save_dir) |
| | |
| | def make_if_dont_exist(folder_path, overwrite=False): |
| | if os.path.exists(folder_path): |
| | if not overwrite: |
| | print(f'{folder_path} exists.') |
| | else: |
| | print(f"{folder_path} overwritten") |
| | shutil.rmtree(folder_path, ignore_errors = True) |
| | os.makedirs(folder_path) |
| | else: |
| | os.makedirs(folder_path) |
| | print(f"{folder_path} created!") |
| |
|
| | def preview_image(image_array, normalize_by="volume", cmap=None, figsize=(12, 12), threshold=None): |
| | """ |
| | Display three orthogonal slices of the given 3D image. |
| | |
| | image_array is assumed to be of shape (H,W,D) |
| | |
| | If a number is provided for threshold, then pixels for which the value |
| | is below the threshold will be shown in red |
| | """ |
| | plt.figure() |
| | if normalize_by == "slice": |
| | vmin = None |
| | vmax = None |
| | elif normalize_by == "volume": |
| | vmin = 0 |
| | vmax = image_array.max().item() |
| | else: |
| | raise(ValueError( |
| | f"Invalid value '{normalize_by}' given for normalize_by")) |
| |
|
| | |
| | x, y, z = np.array(image_array.shape)//2 |
| | imgs = (image_array[x, :, :], image_array[:, y, :], image_array[:, :, z]) |
| |
|
| | fig, axs = plt.subplots(1, 3, figsize=figsize) |
| | for ax, im in zip(axs, imgs): |
| | ax.axis('off') |
| | ax.imshow(im, origin='lower', vmin=vmin, vmax=vmax, cmap=cmap) |
| |
|
| | |
| | |
| | if threshold is not None: |
| | red = np.zeros(im.shape+(4,)) |
| | red[im <= threshold] = [1, 0, 0, 1] |
| | ax.imshow(red, origin='lower') |
| |
|
| | plt.savefig('test.png') |
| |
|
| |
|
| | def plot_2D_vector_field(vector_field, downsampling): |
| | """Plot a 2D vector field given as a tensor of shape (2,H,W). |
| | |
| | The plot origin will be in the lower left. |
| | Using "x" and "y" for the rightward and upward directions respectively, |
| | the vector at location (x,y) in the plot image will have |
| | vector_field[1,y,x] as its x-component and |
| | vector_field[0,y,x] as its y-component. |
| | """ |
| | downsample2D = monai.networks.layers.factories.Pool['AVG', 2]( |
| | kernel_size=downsampling) |
| | vf_downsampled = downsample2D(vector_field.unsqueeze(0))[0] |
| | plt.quiver( |
| | vf_downsampled[1, :, :], vf_downsampled[0, :, :], |
| | angles='xy', scale_units='xy', scale=downsampling, |
| | headwidth=4. |
| | ) |
| |
|
| |
|
| | def preview_3D_vector_field(vector_field, downsampling=None, ep=None, path=None): |
| | """ |
| | Display three orthogonal slices of the given 3D vector field. |
| | |
| | vector_field should be a tensor of shape (3,H,W,D) |
| | |
| | Vectors are projected into the viewing plane, so you are only seeing |
| | their components in the viewing plane. |
| | """ |
| |
|
| | if downsampling is None: |
| | |
| | downsampling = max(1, int(max(vector_field.shape[1:])) >> 5) |
| |
|
| | x, y, z = np.array(vector_field.shape[1:])//2 |
| | plt.figure(figsize=(18, 6)) |
| | plt.subplot(1, 3, 1) |
| | plt.axis('off') |
| | plot_2D_vector_field(vector_field[[1, 2], x, :, :], downsampling) |
| | plt.subplot(1, 3, 2) |
| | plt.axis('off') |
| | plot_2D_vector_field(vector_field[[0, 2], :, y, :], downsampling) |
| | plt.subplot(1, 3, 3) |
| | plt.axis('off') |
| | plot_2D_vector_field(vector_field[[0, 1], :, :, z], downsampling) |
| | plt.savefig(os.path.join(path, f'df_{ep}.png')) |
| |
|
| |
|
| | def plot_2D_deformation(vector_field, grid_spacing, **kwargs): |
| | """ |
| | Interpret vector_field as a displacement vector field defining a deformation, |
| | and plot an x-y grid warped by this deformation. |
| | |
| | vector_field should be a tensor of shape (2,H,W) |
| | """ |
| | _, H, W = vector_field.shape |
| | grid_img = np.zeros((H, W)) |
| | grid_img[np.arange(0, H, grid_spacing), :] = 1 |
| | grid_img[:, np.arange(0, W, grid_spacing)] = 1 |
| | grid_img = torch.tensor(grid_img, dtype=vector_field.dtype).unsqueeze( |
| | 0) |
| | warp = monai.networks.blocks.Warp(mode="bilinear", padding_mode="zeros") |
| | grid_img_warped = warp(grid_img.unsqueeze(0), vector_field.unsqueeze(0))[0] |
| | plt.imshow(grid_img_warped[0], origin='lower', cmap='gist_gray') |
| |
|
| |
|
| | def preview_3D_deformation(vector_field, grid_spacing, **kwargs): |
| | """ |
| | Interpret vector_field as a displacement vector field defining a deformation, |
| | and plot warped grids along three orthogonal slices. |
| | |
| | vector_field should be a tensor of shape (3,H,W,D) |
| | kwargs are passed to matplotlib plotting |
| | |
| | Deformations are projected into the viewing plane, so you are only seeing |
| | their components in the viewing plane. |
| | """ |
| | x, y, z = np.array(vector_field.shape[1:])//2 |
| | plt.figure(figsize=(18, 6)) |
| | plt.subplot(1, 3, 1) |
| | plt.axis('off') |
| | plot_2D_deformation(vector_field[[1, 2], x, :, :], grid_spacing, **kwargs) |
| | plt.subplot(1, 3, 2) |
| | plt.axis('off') |
| | plot_2D_deformation(vector_field[[0, 2], :, y, :], grid_spacing, **kwargs) |
| | plt.subplot(1, 3, 3) |
| | plt.axis('off') |
| | plot_2D_deformation(vector_field[[0, 1], :, :, z], grid_spacing, **kwargs) |
| | plt.show() |
| |
|
| |
|
| | def jacobian_determinant(vf): |
| | """ |
| | Given a displacement vector field vf, compute the jacobian determinant scalar field. |
| | |
| | vf is assumed to be a vector field of shape (3,H,W,D), |
| | and it is interpreted as the displacement field. |
| | So it is defining a discretely sampled map from a subset of 3-space into 3-space, |
| | namely the map that sends point (x,y,z) to the point (x,y,z)+vf[:,x,y,z]. |
| | This function computes a jacobian determinant by taking discrete differences in each spatial direction. |
| | |
| | Returns a numpy array of shape (H-1,W-1,D-1). |
| | """ |
| |
|
| | _, H, W, D = vf.shape |
| |
|
| | |
| | def diff_and_trim(array, axis): return np.diff( |
| | array, axis=axis)[:, :(H-1), :(W-1), :(D-1)] |
| | dx = diff_and_trim(vf, 1) |
| | dy = diff_and_trim(vf, 2) |
| | dz = diff_and_trim(vf, 3) |
| |
|
| | |
| | dx[0] += 1 |
| | dy[1] += 1 |
| | dz[2] += 1 |
| |
|
| | |
| | det = dx[0]*(dy[1]*dz[2]-dz[1]*dy[2]) - dy[0]*(dx[1]*dz[2] - |
| | dz[1]*dx[2]) + dz[0]*(dx[1]*dy[2]-dy[1]*dx[2]) |
| |
|
| | return det |
| |
|
| | def load_json(json_path): |
| | assert type(json_path) == str |
| | fjson = open(json_path, 'r') |
| | json_file = json.load(fjson) |
| | return json_file |
| |
|
| | def plot_progress(logger, save_dir, train_loss, val_loss, name): |
| | """ |
| | Should probably by improved |
| | :return: |
| | """ |
| | assert len(train_loss) != 0 |
| | train_loss = np.array(train_loss) |
| | try: |
| | font = {'weight': 'normal', |
| | 'size': 18} |
| |
|
| | matplotlib.rc('font', **font) |
| |
|
| | fig = plt.figure(figsize=(30, 24)) |
| | ax = fig.add_subplot(111) |
| | ax.plot(train_loss[:,0], train_loss[:,1], color='b', ls='-', label="loss_tr") |
| | if len(val_loss) != 0: |
| | val_loss = np.array(val_loss) |
| | ax.plot(val_loss[:, 0], val_loss[:, 1], color='r', ls='-', label="loss_val") |
| |
|
| | ax.set_xlabel("epoch") |
| | ax.set_ylabel("loss") |
| | ax.legend() |
| | ax.set_title(name) |
| | fig.savefig(os.path.join(save_dir, name + ".png")) |
| | plt.cla() |
| | plt.close(fig) |
| | except: |
| | logger.info(f"failed to plot {name} training progress") |
| |
|
| | def save_reg_checkpoint(network, optimizer, epoch, best_loss, sim_loss=None, regular_loss=None, ana_loss=None, total_loss=None, save_dir=None, name=None): |
| | all_loss = { |
| | 'best_loss': best_loss, |
| | 'total_loss': total_loss, |
| | } |
| | if sim_loss is not None: |
| | all_loss['sim_loss'] = sim_loss |
| | if regular_loss is not None: |
| | all_loss['regular_loss'] = regular_loss |
| | if ana_loss is not None: |
| | all_loss['ana_loss'] = ana_loss |
| | |
| | torch.save({ |
| | 'epoch': epoch, |
| | 'network_state_dict': network.state_dict(), |
| | 'optimizer_state_dict': optimizer.state_dict(), |
| | 'all_loss': all_loss, |
| | }, os.path.join(save_dir, name+'_checkpoint.pth')) |
| |
|
| |
|
| | def save_seg_checkpoint(network, optimizer, epoch, best_loss, super_loss=None, ana_loss=None, total_loss=None, save_dir=None, name=None): |
| | all_loss = { |
| | 'best_loss': best_loss, |
| | 'total_loss': total_loss, |
| | } |
| | if super_loss is not None: |
| | all_loss['super_loss'] = super_loss |
| | if ana_loss is not None: |
| | all_loss['ana_loss'] = ana_loss |
| | |
| | torch.save({ |
| | 'epoch': epoch, |
| | 'network_state_dict': network.state_dict(), |
| | 'optimizer_state_dict': optimizer.state_dict(), |
| | 'all_loss': all_loss, |
| | }, os.path.join(save_dir, name+'_checkpoint.pth')) |
| |
|
| |
|
| | def load_latest_checkpoint(path, network, optimizer, device): |
| | checkpoint_path = os.path.join(path, 'latest_checkpoint.pth') |
| | checkpoint = torch.load(checkpoint_path, map_location=device) |
| | network.load_state_dict(checkpoint['network_state_dict']) |
| | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
| | all_loss = checkpoint['all_loss'] |
| | return network, optimizer, all_loss |
| |
|
| | def load_valid_checkpoint(path, device): |
| | checkpoint_path = os.path.join(path, 'valid_checkpoint.pth') |
| | checkpoint = torch.load(checkpoint_path, map_location=device) |
| | all_loss = checkpoint['all_loss'] |
| | return all_loss |
| |
|
| | def load_best_checkpoint(path, device): |
| | checkpoint_path = os.path.join(path, 'best_checkpoint.pth') |
| | checkpoint = torch.load(checkpoint_path, map_location=device) |
| | best_loss = checkpoint['all_loss']['best_loss'] |
| | return best_loss |
| |
|
| |
|