| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from multiprocessing.pool import Pool |
|
|
| import numpy as np |
| import SimpleITK as sitk |
| from nnunet.utilities.task_name_id_conversion import convert_task_name_to_id, convert_id_to_task_name |
| from batchgenerators.utilities.file_and_folder_operations import * |
| from nnunet.paths import * |
|
|
# Overlay palette as RRGGBB hex strings (15 entries). Entry 0 is black, so whichever label is mapped to
# index 0 receives no added color when used by generate_overlay.
color_cycle = (
    "000000", "4363d8", "f58231", "3cb44b", "e6194B",
    "911eb4", "ffe119", "bfef45", "42d4f4", "f032e6",
    "000075", "9A6324", "808000", "800000", "469990",
)
|
|
|
|
def hex_to_rgb(hex: str):
    """
    Convert a 6-character hex color string (RRGGBB, without a leading '#') to an (r, g, b) tuple of ints.

    :param hex: color string of exactly 6 characters, e.g. "4363d8".
    :return: tuple of three ints, one per two-digit channel, parsed as base 16.
    :raises AssertionError: if the string is not exactly 6 characters long.
    """
    assert len(hex) == 6
    red, green, blue = hex[0:2], hex[2:4], hex[4:6]
    return tuple(int(channel, 16) for channel in (red, green, blue))
|
|
|
|
def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None, color_cycle=color_cycle,
                     overlay_intensity=0.6):
    """
    Blend a label map onto a 2D image and return the result as an RGB uint8 image.

    The image must be 2D: either grayscale (H, W), in which case it is tiled to 3 color channels, or
    color with the channels in the last dimension (H, W, 3).

    :param input_image: 2D grayscale image or 2D color image with 3 channels last. It is not modified.
    :param segmentation: label map with the same spatial shape as the image (without color channels).
    :param mapping: dict label_id -> index into color_cycle. If None, the unique labels of the segmentation
        are mapped in sorted order to 0, 1, 2, ... (the smallest label gets color_cycle[0]).
        Indices wrap around modulo len(color_cycle), so more labels than colors are allowed.
    :param color_cycle: sequence of 6-character RRGGBB hex color strings.
    :param overlay_intensity: weight with which a label's RGB color is added to the normalized image.
    :return: uint8 array of shape (H, W, 3). The image is rescaled to [0, 255] before and after the overlay;
        a constant (zero-range) image is returned as zeros instead of NaN-derived garbage.
    :raises AssertionError: if a 3D image does not have exactly 3 channels in the last dimension.
    :raises RuntimeError: if the image is neither 2D nor 3D.
    """
    image = np.array(input_image, copy=True)
    if not np.issubdtype(image.dtype, np.floating):
        # Work in float64: subtracting the minimum in a narrow integer type (e.g. int16) can overflow.
        image = image.astype(np.float64)

    if image.ndim == 2:
        image = np.tile(image[:, :, None], (1, 1, 3))
    elif image.ndim == 3:
        assert image.shape[2] == 3, 'if 3d image is given the last dimension must be the color channels ' \
                                    '(3 channels). Only 2D images are supported'
    else:
        raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in "
                           "last dimension) are supported")

    # Normalize intensities to [0, 255]; skip the division for a constant image to avoid 0 / 0 = NaN.
    image = image - image.min()
    value_range = image.max()
    if value_range != 0:
        image = image / value_range * 255

    if mapping is None:
        mapping = {label: color_idx for color_idx, label in enumerate(np.unique(segmentation))}

    for label, color_idx in mapping.items():
        # Wrap around the palette so that segmentations with many labels do not raise IndexError.
        color = np.array(hex_to_rgb(color_cycle[color_idx % len(color_cycle)]))
        image[segmentation == label] += overlay_intensity * color

    # Rescale once more because the added colors can push values above 255. No minimum subtraction here,
    # so unlabeled dark pixels keep their relative intensity.
    max_value = image.max()
    if max_value != 0:
        image = image / max_value * 255
    return image.astype(np.uint8)
|
|
|
|
def plot_overlay(image_file: str, segmentation_file: str, output_file: str, overlay_intensity: float = 0.6):
    """
    Write an overlay picture of the slice with the most foreground of a 3D image/segmentation pair.

    Both files are read with SimpleITK. Slices are taken along the first axis of the array returned by
    sitk.GetArrayFromImage. Spacing information is ignored.

    :param image_file: path to the image; its array must be 3D.
    :param segmentation_file: path to the label map; its array must have exactly the image's shape.
    :param output_file: destination path, written with matplotlib.pyplot.imsave.
    :param overlay_intensity: forwarded to generate_overlay.
    :raises AssertionError: if the shapes differ or the arrays are not 3D.
    """
    import matplotlib.pyplot as plt

    image = sitk.GetArrayFromImage(sitk.ReadImage(image_file))
    seg = sitk.GetArrayFromImage(sitk.ReadImage(segmentation_file))
    # Compare complete shapes: zipping the two shape tuples would ignore a mismatch in the number of dimensions.
    assert image.shape == seg.shape, "image and seg do not have the same shape: %s, %s (files: %s, %s)" % (
        image.shape, seg.shape, image_file, segmentation_file)
    assert image.ndim == 3, 'only 3D images/segs are supported'

    # Choose the slice with the most non-zero labels; np.argmax returns 0 if there is no foreground at all.
    fg_per_slice = (seg != 0).sum(axis=(1, 2))
    selected_slice = np.argmax(fg_per_slice)

    overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)
    plt.imsave(output_file, overlay)
|
|
|
|
def plot_overlay_preprocessed(case_file: str, output_file: str, overlay_intensity: float = 0.6, modality_index=0):
    """
    Write an overlay picture of the slice with the most foreground of a preprocessed .npz case.

    The array stored under key 'data' is expected to hold the image modalities followed by the segmentation
    as its last entry along axis 0. Slices are taken along axis 0 of each modality.

    :param case_file: path to an .npz file containing the key 'data'.
    :param output_file: destination path, written with matplotlib.pyplot.imsave.
    :param overlay_intensity: forwarded to generate_overlay.
    :param modality_index: which non-segmentation channel to use as the image; must be in
        [0, number of modalities - 1].
    :raises AssertionError: if modality_index is out of range.
    """
    import matplotlib.pyplot as plt

    # Use the NpzFile as a context manager so the underlying file handle is closed promptly;
    # indexing it reads the array fully into memory.
    with np.load(case_file) as npz:
        data = npz['data']

    num_modalities = data.shape[0] - 1  # the last entry along axis 0 is the segmentation
    # A lower bound is required too: e.g. -1 would otherwise select the segmentation itself as the image.
    assert 0 <= modality_index < num_modalities, \
        'This dataset only supports modality index up to %d' % (num_modalities - 1)

    image = data[modality_index]
    seg = data[-1]
    seg[seg < 0] = 0  # negative label values are treated as background

    fg_per_slice = (seg > 0).sum(axis=(1, 2))
    selected_slice = np.argmax(fg_per_slice)  # 0 if there is no foreground at all

    overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)
    plt.imsave(output_file, overlay)
|
|
|
|
def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, list_of_output_files, overlay_intensity,
                                 num_processes=8):
    """
    Run plot_overlay for every (image, segmentation, output) triple in a pool of worker processes.

    The three lists are consumed in lockstep with zip, so entries beyond the shortest list are ignored.
    Blocks until all jobs have finished; an exception raised in a worker is re-raised here.

    :param list_of_image_files: image paths.
    :param list_of_seg_files: segmentation paths, parallel to list_of_image_files.
    :param list_of_output_files: output picture paths, parallel to the inputs.
    :param overlay_intensity: overlay intensity used for every case.
    :param num_processes: number of worker processes.
    """
    intensities = [overlay_intensity] * len(list_of_output_files)
    # The context manager tears the pool down even when a worker raises, so no processes are leaked.
    with Pool(num_processes) as pool:
        pool.starmap(plot_overlay, zip(list_of_image_files, list_of_seg_files, list_of_output_files, intensities))
|
|
|
|
def multiprocessing_plot_overlay_preprocessed(list_of_case_files, list_of_output_files, overlay_intensity,
                                              num_processes=8, modality_index=0):
    """
    Run plot_overlay_preprocessed for every (case file, output file) pair in a pool of worker processes.

    The lists are consumed in lockstep with zip, so entries beyond the shorter list are ignored.
    Blocks until all jobs have finished; an exception raised in a worker is re-raised here.

    :param list_of_case_files: preprocessed .npz case paths.
    :param list_of_output_files: output picture paths, parallel to list_of_case_files.
    :param overlay_intensity: overlay intensity used for every case.
    :param num_processes: number of worker processes.
    :param modality_index: modality channel used for every case.
    """
    num_jobs = len(list_of_output_files)
    job_args = zip(list_of_case_files, list_of_output_files,
                   [overlay_intensity] * num_jobs, [modality_index] * num_jobs)
    # The context manager tears the pool down even when a worker raises, so no processes are leaked.
    with Pool(num_processes) as pool:
        pool.starmap(plot_overlay_preprocessed, job_args)
|
|
|
|
def generate_overlays_for_task(task_name_or_id, output_folder, num_processes=8, modality_idx=0, use_preprocessed=True,
                               data_identifier=default_data_identifier, overlay_intensity=0.6):
    """
    Write one overlay picture per training case of a task into output_folder.

    :param task_name_or_id: full task name starting with "Task", or a numeric task ID (int or digit string).
    :param output_folder: destination folder (created if missing); files are named <case identifier>.png.
    :param num_processes: number of worker processes.
    :param modality_idx: modality used as the image. Raw data uses files <case>_<modality_idx:04d>.nii.gz;
        preprocessed data uses channel modality_idx of the stored array.
    :param use_preprocessed: if True, read .npz cases from the preprocessing output folder, else raw NIfTI data.
    :param data_identifier: prefix of the preprocessed stage folders (<data_identifier>_stage*).
    :param overlay_intensity: forwarded to the overlay generation (default 0.6, the previously fixed value).
    :raises RuntimeError: if preprocessed data is requested but no matching preprocessing output exists.
    :raises AssertionError: if raw image or label files are missing.
    """
    if isinstance(task_name_or_id, str) and task_name_or_id.startswith("Task"):
        task_name = task_name_or_id
    else:
        # numeric task ID, given either as an int or as a string such as "4"
        task_name = convert_id_to_task_name(int(task_name_or_id))

    if not use_preprocessed:
        folder = join(nnUNet_raw_data, task_name)

        label_suffix = '.nii.gz'
        identifiers = [i[:-len(label_suffix)]
                       for i in subfiles(join(folder, 'labelsTr'), suffix=label_suffix, join=False)]

        image_files = [join(folder, 'imagesTr', "%s_%04d.nii.gz" % (i, modality_idx)) for i in identifiers]
        seg_files = [join(folder, 'labelsTr', i + ".nii.gz") for i in identifiers]

        assert all([isfile(i) for i in image_files])
        assert all([isfile(i) for i in seg_files])

        maybe_mkdir_p(output_folder)
        output_files = [join(output_folder, i + '.png') for i in identifiers]
        multiprocessing_plot_overlay(image_files, seg_files, output_files, overlay_intensity, num_processes)
    else:
        folder = join(preprocessing_output_dir, task_name)
        if not isdir(folder):
            raise RuntimeError("run preprocessing for that task first")
        matching_folders = subdirs(folder, prefix=data_identifier + "_stage")
        if len(matching_folders) == 0:
            # previously this message was a bare string expression and nothing was raised
            raise RuntimeError("run preprocessing for that task first (use default experiment planner!)")
        # the lexicographically last matching stage folder is used
        folder = sorted(matching_folders)[-1]

        identifiers = [i[:-len('.npz')] for i in subfiles(folder, suffix='.npz', join=False)]
        case_files = [join(folder, i + ".npz") for i in identifiers]
        output_files = [join(output_folder, i + '.png') for i in identifiers]

        maybe_mkdir_p(output_folder)
        multiprocessing_plot_overlay_preprocessed(case_files, output_files, overlay_intensity=overlay_intensity,
                                                  num_processes=num_processes, modality_index=modality_idx)
|
|
|
|
def entry_point_generate_overlay():
    """
    Command-line entry point: parse the arguments and call generate_overlays_for_task.

    Options: -t task name or ID, -o output folder, -num_processes, -modality_idx, and --use_raw to read raw
    data instead of preprocessed data.
    """
    import argparse

    # Pass the text as `description`: the first positional argument of ArgumentParser is `prog`,
    # which would print this whole sentence in place of the program name in the usage line.
    parser = argparse.ArgumentParser(description="Plots png overlays of the slice with the most foreground. "
                                                 "Note that this disregards spacing information!")
    parser.add_argument('-t', type=str, help="task name or task ID", required=True)
    parser.add_argument('-o', type=str, help="output folder", required=True)
    parser.add_argument('-num_processes', type=int, default=8, required=False,
                        help="number of processes used. Default: 8")
    parser.add_argument('-modality_idx', type=int, default=0, required=False,
                        help="modality index used (0 = _0000.nii.gz). Default: 0")
    parser.add_argument('--use_raw', action='store_true', required=False,
                        help="if set then we use raw data. else we use preprocessed")
    args = parser.parse_args()

    generate_overlays_for_task(args.t, args.o, args.num_processes, args.modality_idx, use_preprocessed=not args.use_raw)